diff --git a/docs/api/ml.rl.training.rst b/docs/api/ml.rl.training.rst index f86cacfa..57785f36 100644 --- a/docs/api/ml.rl.training.rst +++ b/docs/api/ml.rl.training.rst @@ -64,7 +64,7 @@ ml.rl.training.imitator\_training module ml.rl.training.loss\_reporter module ------------------------------------ -.. automodule:: ml.rl.training.rl_reporter +.. automodule:: ml.rl.training.loss_reporter :members: :undoc-members: :show-inheritance: diff --git a/reagent/__init__.py b/reagent/__init__.py index e69de29b..5be5087f 100644 --- a/reagent/__init__.py +++ b/reagent/__init__.py @@ -0,0 +1,2 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. diff --git a/reagent/core/aggregators.py b/reagent/core/aggregators.py index af24693d..ebb2b114 100644 --- a/reagent/core/aggregators.py +++ b/reagent/core/aggregators.py @@ -3,100 +3,21 @@ import logging from collections import deque -from typing import Any, Callable, Deque, Dict, List, Optional +from typing import Callable, Deque, Dict, List, Optional import numpy as np import torch +from reagent.core.tracker import Aggregator from reagent.tensorboardX import SummaryWriterContext logger = logging.getLogger(__name__) -class Aggregator: - def __init__(self, key: str, interval: Optional[int] = None): - super().__init__() - self.key = key - self.iteration = 0 - self.interval = interval - self.aggregate_epoch = interval is None - self.intermediate_values: List[Any] = [] - - def update(self, key: str, value): - self.intermediate_values.append(value) - self.iteration += 1 - # pyre-fixme[6]: Expected `int` for 1st param but got `Optional[int]`. - if self.interval and self.iteration % self.interval == 0: - logger.info( - f"Interval Agg. Update: {self.key}; iteration {self.iteration}; " - f"aggregator: {self.__class__.__name__}" - ) - self(self.key, self.intermediate_values) - self.intermediate_values = [] - - def finish_epoch(self): - # We need to reset iteration here to avoid aggregating on the same data multiple - # times - logger.info( - f"Epoch finished. Flushing: {self.key}; " - f"aggregator: {self.__class__.__name__}; points: {len(self.intermediate_values)}" - ) - self.iteration = 0 - if self.aggregate_epoch: - self(self.key, self.intermediate_values) - self.intermediate_values = [] - - def __call__(self, key: str, values): - assert key == self.key, f"Got {key}; expected {self.key}" - self.aggregate(values) - - def aggregate(self, intermediate_values): - pass - - def get_recent(self, count): - raise NotImplementedError() - - def get_all(self): - raise NotImplementedError() - - -class AppendAggregator(Aggregator): - def __init__(self, key: str, interval: Optional[int] = None): - super().__init__(key, interval) - self.values = [] - - def __call__(self, key: str, values): - assert key == self.key, f"Got {key}; expected {self.key}" - self.aggregate(values) - - def aggregate(self, intermediate_values): - self.values.extend(intermediate_values) - - def get_recent(self, count): - if len(self.values) == 0: - return [] - return self.values[-count:] - - def get_all(self): - return self.values - - class TensorAggregator(Aggregator): - def __call__(self, key: str, values, interval: Optional[int] = None): - if len(values) == 0: - return super().__call__(key, torch.tensor([0.0])) + def __call__(self, key: str, values): # Ensure that tensor is on cpu before aggregation. - reshaped_values = [] - for value in values: - if isinstance(value, list): - reshaped_values.append(torch.tensor(value)) - elif not hasattr(value, "size"): - reshaped_values.append(torch.tensor(value).unsqueeze(0)) - elif len(value.size()) == 0: - reshaped_values.append(value.unsqueeze(0)) - else: - reshaped_values.append(value) - values = torch.cat(reshaped_values, dim=0).cpu() + values = torch.cat(values, dim=0).cpu() return super().__call__(key, values) @@ -114,8 +35,8 @@ def _log_histogram_and_mean(log_key, val): class TensorBoardHistogramAndMeanAggregator(TensorAggregator): - def __init__(self, key: str, log_key: str, interval: Optional[int] = None): - super().__init__(key, interval) + def __init__(self, key: str, log_key: str): + super().__init__(key) self.log_key = log_key def aggregate(self, values): @@ -133,9 +54,8 @@ class TensorBoardActionHistogramAndMeanAggregator(TensorAggregator): title: str, actions: List[str], log_key_prefix: Optional[str] = None, - interval: Optional[int] = None, ): - super().__init__(key, interval) + super().__init__(key) self.log_key_prefix = log_key_prefix or f"{category}/{title}" self.actions = actions SummaryWriterContext.add_custom_scalars_multilinechart( @@ -157,10 +77,8 @@ class TensorBoardActionHistogramAndMeanAggregator(TensorAggregator): class TensorBoardActionCountAggregator(TensorAggregator): - def __init__( - self, key: str, title: str, actions: List[str], interval: Optional[int] = None - ): - super().__init__(key, interval) + def __init__(self, key: str, title: str, actions: List[str]): + super().__init__(key) self.log_key = f"actions/{title}" self.actions = actions SummaryWriterContext.add_custom_scalars_multilinechart( @@ -177,8 +95,8 @@ class TensorBoardActionCountAggregator(TensorAggregator): class MeanAggregator(TensorAggregator): - def __init__(self, key: str, interval: Optional[int] = None): - super().__init__(key, interval) + def __init__(self, key: str): + super().__init__(key) self.values: List[float] = [] def aggregate(self, values): @@ -186,14 +104,6 @@ class MeanAggregator(TensorAggregator): logger.info(f"{self.key}: {mean}") self.values.append(mean) - def get_recent(self, count): - if len(self.values) == 0: - return [] - return self.values[-count:] - - def get_all(self): - return self.values - class FunctionsByActionAggregator(TensorAggregator): """ @@ -234,14 +144,8 @@ class FunctionsByActionAggregator(TensorAggregator): } """ - def __init__( - self, - key: str, - actions: List[str], - fns: Dict[str, Callable], - interval: Optional[int] = None, - ): - super().__init__(key, interval) + def __init__(self, key: str, actions: List[str], fns: Dict[str, Callable]): + super().__init__(key) self.actions = actions self.values: Dict[str, Dict[str, List[float]]] = { fn: {action: [] for action in self.actions} for fn in fns @@ -268,8 +172,8 @@ class ActionCountAggregator(TensorAggregator): `len(actions) - 1`. The input is assumed to contain action index. """ - def __init__(self, key: str, actions: List[str], interval: Optional[int] = None): - super().__init__(key, interval) + def __init__(self, key: str, actions: List[str]): + super().__init__(key) self.actions = actions self.values: Dict[str, List[int]] = {action: [] for action in actions} @@ -286,7 +190,7 @@ class ActionCountAggregator(TensorAggregator): """ totals = np.array([sum(counts) for counts in zip(*self.values.values())]) return { - action: (np.array(counts) / np.clip(totals, 1, None)).tolist() + action: (np.array(counts) / totals).tolist() for action, counts in self.values.items() } @@ -294,7 +198,7 @@ class ActionCountAggregator(TensorAggregator): """ Returns the cumulative distributions in each aggregating step """ - totals = max(1, sum(sum(counts) for counts in zip(*self.values.values()))) + totals = sum(sum(counts) for counts in zip(*self.values.values())) return {action: sum(counts) / totals for action, counts in self.values.items()} @@ -302,20 +206,10 @@ _RECENT_DEFAULT_SIZE = int(1e6) class RecentValuesAggregator(TensorAggregator): - def __init__( - self, key: str, size: int = _RECENT_DEFAULT_SIZE, interval: Optional[int] = None - ): - super().__init__(key, interval) + def __init__(self, key: str, size: int = _RECENT_DEFAULT_SIZE): + super().__init__(key) self.values: Deque[float] = deque(maxlen=size) def aggregate(self, values): flattened = torch.flatten(values).tolist() self.values.extend(flattened) - - def get_recent(self, count): - if len(self.values) == 0: - return [] - return self.values[-count:] - - def get_all(self): - return self.values diff --git a/reagent/core/async_wrapper.py b/reagent/core/async_wrapper.py deleted file mode 100644 index bf156f5c..00000000 --- a/reagent/core/async_wrapper.py +++ /dev/null @@ -1,29 +0,0 @@ -#!/usr/bin/env python3 - -import functools -import importlib - - -if importlib.util.find_spec("fblearner") is not None: - import fblearner.flow.api as flow - - class AsyncWrapper: - def __init__(self, **kwargs): - self.async_wrapper = flow.flow_async(**kwargs) - self.type_wrapper = flow.typed() - - def __call__(self, func): - return self.async_wrapper(self.type_wrapper(func)) - - -else: - - def AsyncWrapper(**outer_kwargs): - def async_wrapper_internal(func): - @functools.wraps(func) - def async_wrapper_repeat(*args, **kwargs): - return func(*args, **kwargs) - - return async_wrapper_repeat - - return async_wrapper_internal diff --git a/reagent/core/observers.py b/reagent/core/observers.py new file mode 100644 index 00000000..4fe1c6cb --- /dev/null +++ b/reagent/core/observers.py @@ -0,0 +1,103 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved + +import logging +from typing import Any, Dict, Iterable, List, Optional + +from reagent.core.tracker import Aggregator, Observer + + +logger = logging.getLogger(__name__) + + +class CompositeObserver(Observer): + """ + A composite observer which takes care of dispatching values to child observers + """ + + def __init__(self, observers: Iterable[Observer]): + self.observers: Dict[str, List[Observer]] = {} + for observer in observers: + observing_keys = observer.get_observing_keys() + for key in observing_keys: + self.observers.setdefault(key, []).append(observer) + super().__init__(list(self.observers)) + + def update(self, key: str, value): + for observer in self.observers[key]: + observer.update(key, value) + + +class EpochEndObserver(Observer): + """ + Call the callback function with epoch # when the epoch ends + """ + + def __init__(self, callback, key: str = "epoch_end"): + super().__init__(observing_keys=[key]) + self.callback = callback + + def update(self, key: str, value): + self.callback(value) + + +class ValueListObserver(Observer): + """ + Simple observer that collect values into a list + """ + + def __init__(self, observing_key: str): + super().__init__(observing_keys=[observing_key]) + self.observing_key = observing_key + self.values: List[Any] = [] + + def update(self, key: str, value): + self.values.append(value) + + def reset(self): + self.values = [] + + +class IntervalAggregatingObserver(Observer): + def __init__( + self, + interval: Optional[int], + aggregator: Aggregator, + observe_epoch_end: bool = True, + ): + self.key = aggregator.key + obs_keys = ["epoch_end"] if observe_epoch_end else [] + obs_keys.append(self.key) + super().__init__(observing_keys=obs_keys) + self.iteration = 0 + self.interval = interval + self.intermediate_values: List[Any] = [] + self.aggregator = aggregator + + def update(self, key: str, value): + if key == "epoch_end": + self.flush() + return + + self.intermediate_values.append(value) + self.iteration += 1 + # pyre-fixme[6]: Expected `int` for 1st param but got `Optional[int]`. + if self.interval and self.iteration % self.interval == 0: + logger.info( + f"Interval Agg. Update: {self.key}; iteration {self.iteration}; " + f"aggregator: {self.aggregator.__class__.__name__}" + ) + self.aggregator(self.key, self.intermediate_values) + self.intermediate_values = [] + + def flush(self): + # We need to reset iteration here to avoid aggregating on the same data multiple + # times + logger.info( + f"Interval Agg. Flushing: {self.key}; iteration: {self.iteration}; " + f"aggregator: {self.aggregator.__class__.__name__}; points: {len(self.intermediate_values)}" + ) + self.iteration = 0 + if self.intermediate_values: + self.aggregator(self.key, self.intermediate_values) + self.intermediate_values = [] diff --git a/reagent/core/registry_meta.py b/reagent/core/registry_meta.py index 0d87f9da..b8bef96b 100644 --- a/reagent/core/registry_meta.py +++ b/reagent/core/registry_meta.py @@ -16,7 +16,7 @@ class RegistryMeta(abc.ABCMeta): def __init__(cls, name, bases, attrs): if not hasattr(cls, "REGISTRY"): # Put REGISTRY on cls. This only happens once on the base class - logger.debug("Adding REGISTRY to type {}".format(name)) + logger.info("Adding REGISTRY to type {}".format(name)) cls.REGISTRY: Dict[str, Type] = {} cls.REGISTRY_NAME = name cls.REGISTRY_FROZEN = False @@ -28,19 +28,12 @@ class RegistryMeta(abc.ABCMeta): if not cls.__abstractmethods__ and name != cls.REGISTRY_NAME: # Only register fully-defined classes + logger.info(f"Registering {name} to {cls.REGISTRY_NAME}") if hasattr(cls, "__registry_name__"): registry_name = cls.__registry_name__ - logger.info( - f"Registering {name} with alias {registry_name} to {cls.REGISTRY_NAME}" - ) + logger.info(f"Using {registry_name} instead of {name}") name = registry_name - else: - logger.info(f"Registering {name} to {cls.REGISTRY_NAME}") - # assert name not in cls.REGISTRY - # TODO: Combine FB and OSS model managers and then bring back this assert. - # For now this works because FB model managers inherit from their OSS counterparts - if name in cls.REGISTRY: - logger.warning(f"Overwriting open source {name} with internal version") + assert name not in cls.REGISTRY cls.REGISTRY[name] = cls else: logger.info( diff --git a/reagent/core/result_types.py b/reagent/core/result_types.py index 116acb79..a22bb6bf 100644 --- a/reagent/core/result_types.py +++ b/reagent/core/result_types.py @@ -2,7 +2,7 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. from reagent.core.dataclasses import dataclass -from reagent.reporting.result_registries import PublishingResult, ValidationResult +from reagent.workflow.result_registries import PublishingResult, ValidationResult @dataclass diff --git a/reagent/core/rl_training_output.py b/reagent/core/rl_training_output.py deleted file mode 100644 index 950c7802..00000000 --- a/reagent/core/rl_training_output.py +++ /dev/null @@ -1,19 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. - -from dataclasses import dataclass -from typing import Optional - -from reagent.core.union import ( - PublishingResult__Union, - TrainingReport__Union, - ValidationResult__Union, -) - - -@dataclass -class RLTrainingOutput: - validation_result: Optional[ValidationResult__Union] = None - publishing_result: Optional[PublishingResult__Union] = None - training_report: Optional[TrainingReport__Union] = None - local_output_path: Optional[str] = None diff --git a/reagent/core/tracker.py b/reagent/core/tracker.py new file mode 100644 index 00000000..0f03090f --- /dev/null +++ b/reagent/core/tracker.py @@ -0,0 +1,117 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +import functools +import logging +from typing import List + +import torch + + +logger = logging.getLogger(__name__) + + +class Observer: + """ + Base class for observers + """ + + def __init__(self, observing_keys: List[str]): + super().__init__() + assert isinstance(observing_keys, list) + self.observing_keys = observing_keys + + def get_observing_keys(self) -> List[str]: + return self.observing_keys + + def update(self, key: str, value): + pass + + +class Aggregator: + def __init__(self, key: str): + super().__init__() + self.key = key + + def __call__(self, key: str, values): + assert key == self.key, f"Got {key}; expected {self.key}" + self.aggregate(values) + + def aggregate(self, values): + pass + + +def observable(cls=None, **kwargs): # noqa: C901 + """ + Decorator to mark a class as producing observable values. The names of the + observable values are the names of keyword arguments. The values of keyword + arguments are the types of the value. The type is currently not used for + anything. + """ + assert kwargs + observable_value_types = kwargs + + def wrap(cls): + assert not hasattr(cls, "add_observer") + assert not hasattr(cls, "notify_observers") + + original_init = cls.__init__ + + @functools.wraps(original_init) + def new_init(self, *args, **kwargs): + original_init(self, *args, **kwargs) + assert not hasattr(self, "_observable_value_types") + assert not hasattr(self, "_observers") + self._observable_value_types = observable_value_types + self._observers = {v: [] for v in observable_value_types} + + cls.__init__ = new_init + + def add_observer(self, observer: Observer) -> None: + observing_keys = observer.get_observing_keys() + unknown_keys = [ + k for k in observing_keys if k not in self._observable_value_types + ] + if unknown_keys: + logger.warning(f"{unknown_keys} cannot be observed in {type(self)}") + for k in observing_keys: + if k in self._observers and observer not in self._observers[k]: + self._observers[k].append(observer) + return self + + cls.add_observer = add_observer + + def add_observers(self, observers: List[Observer]) -> None: + for observer in observers: + self.add_observer(observer) + return self + + cls.add_observers = add_observers + + def notify_observers(self, **kwargs): + for key, value in kwargs.items(): + if value is None: + # Allow optional reporting + continue + + assert key in self._observers, f"Unknown key: {key}" + + # TODO: Create a generic framework for type conversion + if self._observable_value_types[key] == torch.Tensor: + if not isinstance(value, torch.Tensor): + value = torch.tensor(value) + if len(value.shape) == 0: + value = value.reshape(1) + value = value.detach() + + for observer in self._observers[key]: + observer.update(key, value) + + cls.notify_observers = notify_observers + + return cls + + if cls is None: + return wrap + + return wrap(cls) diff --git a/reagent/core/types.py b/reagent/core/types.py index 495e9d56..6e871fbb 100644 --- a/reagent/core/types.py +++ b/reagent/core/types.py @@ -1,26 +1,32 @@ #!/usr/bin/env python3 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -import dataclasses -import logging - -# The dataclasses in this file should be vanilla dataclass to have minimal overhead -from dataclasses import dataclass, field from datetime import datetime as RecurringPeriod # noqa -from typing import Dict, List, NamedTuple, Optional, Tuple, Union +from typing import Dict, List, Optional -import torch -import torch.nn.functional as F -from reagent.base_dataclass import BaseDataClass -from reagent.core.configuration import param_hash -from reagent.core.dataclasses import dataclass as pydantic_dataclass -from reagent.preprocessing.normalization_constants import ( +# Triggering registration to registries +import reagent.core.result_types # noqa +import reagent.workflow.training_reports # noqa +from reagent.core.dataclasses import dataclass +from reagent.core.fb_checker import IS_FB_ENVIRONMENT +from reagent.core.tagged_union import TaggedUnion # noqa F401 +from reagent.models.model_feature_config_provider import ModelFeatureConfigProvider +from reagent.preprocessing.normalization import ( DEFAULT_MAX_QUANTILE_SIZE, DEFAULT_MAX_UNIQUE_ENUM, DEFAULT_NUM_SAMPLES, DEFAULT_QUANTILE_K2_THRESHOLD, ) -from reagent.preprocessing.types import InputColumn +from reagent.types import BaseDataClass +from reagent.workflow.result_registries import PublishingResult, ValidationResult +from reagent.workflow.training_reports import TrainingReport + + +if IS_FB_ENVIRONMENT: + from reagent.fb.models.model_feature_config_builder import ( # noqa + ConfigeratorModelFeatureConfigProvider, + ) + import reagent.core.fb.fb_types # noqa @dataclass @@ -34,7 +40,7 @@ class OssDataset(Dataset): @dataclass -class TableSpec(BaseDataClass): +class TableSpec: table: str table_sample: Optional[float] = None eval_table_sample: Optional[float] = None @@ -44,11 +50,27 @@ class TableSpec(BaseDataClass): class RewardOptions: custom_reward_expression: Optional[str] = None metric_reward_values: Optional[Dict[str, float]] = None + additional_reward_expression: Optional[str] = None + + # for ranking + # key: feature id in slate_reward column, value: linear coefficient + slate_reward_values: Optional[Dict[str, float]] = None + # key: feature id in item_reward column, value: linear coefficient + item_reward_values: Optional[Dict[str, float]] = None @dataclass class ReaderOptions: - pass + num_threads: int = 32 + skip_smaller_batches: bool = True + num_workers: int = 0 + koski_logging_level: int = 2 + # distributed reader + distributed_reader: bool = False + distributed_master_mem: str = "20G" + distributed_worker_mem: str = "20G" + distributed_num_workers: int = 2 + gang_name: str = "" @dataclass @@ -58,7 +80,10 @@ class OssReaderOptions(ReaderOptions): @dataclass class ResourceOptions: - pass + cpu: Optional[int] = None + # "-1" or "xxG" where "xx" is a positive integer + memory: Optional[str] = "40g" + gpu: int = 1 @dataclass @@ -84,713 +109,45 @@ class PreprocessingOptions(BaseDataClass): set_missing_value_to_zero: Optional[bool] = False whitelist_features: Optional[List[int]] = None assert_whitelist_feature_coverage: bool = True - variance_threshold: VarianceThreshold = VarianceThreshold() - sequence_feature_id: Optional[int] = None - ignore_sanity_check_failure: bool = IGNORE_SANITY_CHECK_FAILURE ignore_sanity_check_task: bool = False + variance_threshold: VarianceThreshold = VarianceThreshold() load_from_operator_id: Optional[int] = None skip_sanity_check: bool = False - - # IdMappings are stored in manifold folder: - # "tree/{namespace}/{tablename}/{ds}/{base_mapping_name}/{embedding_table_name}" - base_mapping_name: str = "DefaultMappingName" + sequence_feature_id: Optional[int] = None ### below here for preprocessing sparse features ### # If the number of occurrences of any raw features ids is lower than this, we # ignore those feature ids when constructing the IdMapping sparse_threshold: int = 0 + # IdMappings are stored in manifold folder: + # "tree/{namespace}/{tablename}/{ds}/{base_mapping_name}/{embedding_table_name}" + base_mapping_name: str = "DefaultMappingName" -class NoDuplicatedWarningLogger: - def __init__(self, logger): - self.logger = logger - self.msg = set() - - def warning(self, msg): - if msg not in self.msg: - self.logger.warning(msg) - self.msg.add(msg) +@ModelFeatureConfigProvider.fill_union() +class ModelFeatureConfigProvider__Union(TaggedUnion): + pass -logger = logging.getLogger(__name__) -no_dup_logger = NoDuplicatedWarningLogger(logger) +@PublishingResult.fill_union() +class PublishingResult__Union(TaggedUnion): + pass -def isinstance_namedtuple(x): - return isinstance(x, tuple) and hasattr(x, "_fields") +@ValidationResult.fill_union() +class ValidationResult__Union(TaggedUnion): + pass + + +@TrainingReport.fill_union() +class RLTrainingReport(TaggedUnion): + pass @dataclass -class TensorDataClass(BaseDataClass): - def __getattr__(self, attr): - if attr.startswith("__") and attr.endswith("__"): - raise AttributeError - - tensor_attr = getattr(torch.Tensor, attr, None) - - if tensor_attr is None or not callable(tensor_attr): - logger.error( - f"Attemping to call torch.Tensor.{attr} on " - f"{type(self)} (instance of TensorDataClass)." - ) - if tensor_attr is None: - raise AttributeError(f"torch.Tensor doesn't have {attr} attribute.") - else: - raise RuntimeError(f"Tensor.{attr} is not callable.") - - def continuation(*args, **kwargs): - def f(v): - # if possible, returns v.attr(*args, **kwargs). - # otws, return v - if isinstance(v, (torch.Tensor, TensorDataClass)): - return getattr(v, attr)(*args, **kwargs) - elif isinstance(v, dict): - return {kk: f(vv) for kk, vv in v.items()} - elif isinstance(v, tuple): - return tuple(f(vv) for vv in v) - return v - - return type(self)(**f(self.__dict__)) - - return continuation - - def cuda(self, *args, **kwargs): - cuda_tensor = {} - for k, v in self.__dict__.items(): # noqa F402 - if isinstance(v, torch.Tensor): - kwargs["non_blocking"] = kwargs.get("non_blocking", True) - cuda_tensor[k] = v.cuda(*args, **kwargs) - elif isinstance(v, TensorDataClass): - cuda_tensor[k] = v.cuda(*args, **kwargs) - else: - cuda_tensor[k] = v - return type(self)(**cuda_tensor) - - -# (offset, value) -IdListFeatureValue = Tuple[torch.Tensor, torch.Tensor] -# (offset, key, value) -IdScoreListFeatureValue = Tuple[torch.Tensor, torch.Tensor, torch.Tensor] -# name -> value -IdListFeature = Dict[str, IdListFeatureValue] -IdScoreListFeature = Dict[str, IdScoreListFeatureValue] -# id -> value -ServingIdListFeature = Dict[int, IdListFeatureValue] -ServingIdScoreListFeature = Dict[int, IdScoreListFeatureValue] - - -##### -# FIXME: These config types are misplaced but we need to write FBL config adapter -# if we moved them. -###### - - -@pydantic_dataclass -class IdListFeatureConfig(BaseDataClass): - name: str - # integer feature ID - feature_id: int - # name of the embedding table to use - id_mapping_name: str - - -@pydantic_dataclass -class IdScoreListFeatureConfig(BaseDataClass): - name: str - # integer feature ID - feature_id: int - # name of the embedding table to use - id_mapping_name: str - - -@pydantic_dataclass -class FloatFeatureInfo(BaseDataClass): - name: str - feature_id: int - - -@pydantic_dataclass -class IdMapping(object): - __hash__ = param_hash - - ids: List[int] = field(default_factory=list) - - def __post_init_post_parse__(self): - """ - used in preprocessing - ids list represents mapping from idx -> value - we want the reverse: from feature to embedding table indices - """ - self._id2index: Dict[int, int] = {} - - @property - def id2index(self) -> Dict[int, int]: - # pyre-fixme[16]: `IdMapping` has no attribute `_id2index`. - if not self._id2index: - self._id2index = {id: i for i, id in enumerate(self.ids)} - return self._id2index - - @property - def table_size(self): - return len(self.ids) - - -@pydantic_dataclass -class ModelFeatureConfig(BaseDataClass): - float_feature_infos: List[FloatFeatureInfo] = field(default_factory=list) - # table name -> id mapping - id_mapping_config: Dict[str, IdMapping] = field(default_factory=dict) - # id_list_feature_configs is feature_id -> list of values - id_list_feature_configs: List[IdListFeatureConfig] = field(default_factory=list) - # id_score_list_feature_configs is feature_id -> (keys -> values) - id_score_list_feature_configs: List[IdScoreListFeatureConfig] = field( - default_factory=list - ) - - def __post_init_post_parse__(self): - both_lists = self.id_list_feature_configs + self.id_score_list_feature_configs - if not self.only_dense: - # sanity check for keys in mapping config - ids = [config.feature_id for config in both_lists] - names = [config.name for config in both_lists] - assert len(ids) == len(set(ids)), f"duplicates in ids: {ids}" - assert len(names) == len(set(names)), f"duplicates in names: {names}" - assert len(ids) == len(names), f"{len(ids)} != {len(names)}" - - self._id2name = {config.feature_id: config.name for config in both_lists} - self._name2id = {config.name: config.feature_id for config in both_lists} - self._id2config = {config.feature_id: config for config in both_lists} - self._name2config = {config.name: config for config in both_lists} - - @property - def only_dense(self): - return not (self.id_list_feature_configs or self.id_score_list_feature_configs) - - @property - def id2name(self): - return self._id2name - - @property - def name2id(self): - return self._name2id - - @property - def id2config(self): - return self._id2config - - @property - def name2config(self): - return self._name2config - - -###### -# dataclasses for internal API -###### - - -@dataclass -class ValuePresence(TensorDataClass): - value: torch.Tensor - presence: Optional[torch.Tensor] - - -@dataclass -class ActorOutput(TensorDataClass): - action: torch.Tensor - log_prob: Optional[torch.Tensor] = None - squashed_mean: Optional[torch.Tensor] = None - - -@dataclass -class DocList(TensorDataClass): - # the shape is (batch_size, num_candidates, num_document_features) - float_features: torch.Tensor - # the shapes are (batch_size, num_candidates) - mask: torch.Tensor - value: torch.Tensor - - def __post_init__(self): - assert ( - len(self.float_features.shape) == 3 - ), f"Unexpected shape: {self.float_features.shape}" - - # pyre-fixme[56]: Decorator `torch.no_grad(...)` could not be called, because - # its type `no_grad` is not callable. - @torch.no_grad() - def select_slate(self, action: torch.Tensor): - row_idx = torch.repeat_interleave( - torch.arange(action.shape[0]).unsqueeze(1), action.shape[1], dim=1 - ) - mask = self.mask[row_idx, action] - # Make sure the indices are in the right range - assert mask.to(torch.bool).all() - float_features = self.float_features[row_idx, action] - value = self.value[row_idx, action] - return DocList(float_features, mask, value) - - def as_feature_data(self): - _batch_size, _slate_size, feature_dim = self.float_features.shape - return FeatureData(self.float_features.view(-1, feature_dim)) - - -@dataclass -class FeatureData(TensorDataClass): - # For dense features, shape is (batch_size, feature_dim) - float_features: torch.Tensor - id_list_features: IdListFeature = dataclasses.field(default_factory=dict) - id_score_list_features: IdScoreListFeature = dataclasses.field(default_factory=dict) - # For sequence, shape is (stack_size, batch_size, feature_dim) - stacked_float_features: Optional[torch.Tensor] = None - # For ranking algos, - candidate_docs: Optional[DocList] = None - # Experimental: sticking this here instead of putting it in float_features - # because a lot of places derive the shape of float_features from - # normalization parameters. - time_since_first: Optional[torch.Tensor] = None - - def __post_init__(self): - def usage(): - return ( - "For sequence features, use `stacked_float_features`." - "For document features, use `candidate_doc_float_features`." - ) - - if self.float_features.ndim == 3: - no_dup_logger.warning(f"`float_features` should be 2D.\n{usage()}") - elif self.float_features.ndim != 2: - raise ValueError( - f"float_features should be 2D; got {self.float_features.shape}.\n{usage()}" - ) - - @property - def has_float_features_only(self) -> bool: - return ( - not self.id_list_features - and self.time_since_first is None - and self.candidate_docs is None - ) - - def get_tiled_batch(self, num_tiles: int): - assert ( - self.has_float_features_only - ), f"only works for float features now: {self}" - """ - tiled_feature should be (batch_size * num_tiles, feature_dim) - forall i in [batch_size], - tiled_feature[i*num_tiles:(i+1)*num_tiles] should be feat[i] - """ - feat = self.float_features - assert ( - len(feat.shape) == 2 - ), f"Need feat shape to be (batch_size, feature_dim), got {feat.shape}." - batch_size, _ = feat.shape - # pyre-fixme[16]: `Tensor` has no attribute `repeat_interleave`. - tiled_feat = feat.repeat_interleave(repeats=num_tiles, dim=0) - return FeatureData(float_features=tiled_feat) - - -class TensorFeatureData(torch.nn.Module): - """ - Primarily for using in nn.Sequential - """ - - def forward(self, input: torch.Tensor) -> FeatureData: - assert isinstance(input, torch.Tensor) - return FeatureData(input) - - -class ServingFeatureData(NamedTuple): - float_features_with_presence: Tuple[torch.Tensor, torch.Tensor] - id_list_features: ServingIdListFeature - id_score_list_features: ServingIdScoreListFeature - - -@dataclass -class PreprocessedRankingInput(TensorDataClass): - state: FeatureData - src_seq: FeatureData - src_src_mask: torch.Tensor - tgt_in_seq: Optional[FeatureData] = None - tgt_out_seq: Optional[FeatureData] = None - tgt_tgt_mask: Optional[torch.Tensor] = None - slate_reward: Optional[torch.Tensor] = None - position_reward: Optional[torch.Tensor] = None - # all indices will be +2 to account for padding - # symbol (0) and decoder_start_symbol (1) - src_in_idx: Optional[torch.Tensor] = None - tgt_in_idx: Optional[torch.Tensor] = None - tgt_out_idx: Optional[torch.Tensor] = None - tgt_out_probs: Optional[torch.Tensor] = None - # store ground-truth target sequences - optim_tgt_in_idx: Optional[torch.Tensor] = None - optim_tgt_out_idx: Optional[torch.Tensor] = None - optim_tgt_in_seq: Optional[FeatureData] = None - optim_tgt_out_seq: Optional[FeatureData] = None - - def batch_size(self) -> int: - return self.state.float_features.size()[0] - - @classmethod - def from_tensors( - cls, - state: torch.Tensor, - src_seq: torch.Tensor, - src_src_mask: torch.Tensor, - tgt_in_seq: Optional[torch.Tensor] = None, - tgt_out_seq: Optional[torch.Tensor] = None, - tgt_tgt_mask: Optional[torch.Tensor] = None, - slate_reward: Optional[torch.Tensor] = None, - position_reward: Optional[torch.Tensor] = None, - src_in_idx: Optional[torch.Tensor] = None, - tgt_in_idx: Optional[torch.Tensor] = None, - tgt_out_idx: Optional[torch.Tensor] = None, - tgt_out_probs: Optional[torch.Tensor] = None, - optim_tgt_in_idx: Optional[torch.Tensor] = None, - optim_tgt_out_idx: Optional[torch.Tensor] = None, - optim_tgt_in_seq: Optional[torch.Tensor] = None, - optim_tgt_out_seq: Optional[torch.Tensor] = None, - **kwargs, - ): - assert isinstance(state, torch.Tensor) - assert isinstance(src_seq, torch.Tensor) - assert isinstance(src_src_mask, torch.Tensor) - assert tgt_in_seq is None or isinstance(tgt_in_seq, torch.Tensor) - assert tgt_out_seq is None or isinstance(tgt_out_seq, torch.Tensor) - assert tgt_tgt_mask is None or isinstance(tgt_tgt_mask, torch.Tensor) - assert slate_reward is None or isinstance(slate_reward, torch.Tensor) - assert position_reward is None or isinstance(position_reward, torch.Tensor) - assert src_in_idx is None or isinstance(src_in_idx, torch.Tensor) - assert tgt_in_idx is None or isinstance(tgt_in_idx, torch.Tensor) - assert tgt_out_idx is None or isinstance(tgt_out_idx, torch.Tensor) - assert tgt_out_probs is None or isinstance(tgt_out_probs, torch.Tensor) - assert optim_tgt_out_idx is None or isinstance(optim_tgt_out_idx, torch.Tensor) - assert optim_tgt_out_idx is None or isinstance(optim_tgt_out_idx, torch.Tensor) - assert optim_tgt_in_seq is None or isinstance(optim_tgt_in_seq, torch.Tensor) - assert optim_tgt_out_seq is None or isinstance(optim_tgt_out_seq, torch.Tensor) - - return cls( - state=FeatureData(float_features=state), - src_seq=FeatureData(float_features=src_seq), - src_src_mask=src_src_mask, - tgt_in_seq=FeatureData(float_features=tgt_in_seq) - if tgt_in_seq is not None - else None, - tgt_out_seq=FeatureData(float_features=tgt_out_seq) - if tgt_out_seq is not None - else None, - tgt_tgt_mask=tgt_tgt_mask, - slate_reward=slate_reward, - position_reward=position_reward, - src_in_idx=src_in_idx, - tgt_in_idx=tgt_in_idx, - tgt_out_idx=tgt_out_idx, - tgt_out_probs=tgt_out_probs, - optim_tgt_in_idx=optim_tgt_in_idx, - optim_tgt_out_idx=optim_tgt_out_idx, - optim_tgt_in_seq=FeatureData(float_features=optim_tgt_in_seq) - if optim_tgt_in_seq is not None - else None, - optim_tgt_out_seq=FeatureData(float_features=optim_tgt_out_seq) - if optim_tgt_out_seq is not None - else None, - ) - - def __post_init__(self): - if ( - isinstance(self.state, torch.Tensor) - or isinstance(self.src_seq, torch.Tensor) - or isinstance(self.tgt_in_seq, torch.Tensor) - or isinstance(self.tgt_out_seq, torch.Tensor) - or isinstance(self.optim_tgt_in_seq, torch.Tensor) - or isinstance(self.optim_tgt_out_seq, torch.Tensor) - ): - raise ValueError( - f"Use from_tensors() {type(self.state)} {type(self.src_seq)} " - f"{type(self.tgt_in_seq)} {type(self.tgt_out_seq)} " - f"{type(self.optim_tgt_in_seq)} {type(self.optim_tgt_out_seq)} " - ) - - -@dataclass -class BaseInput(TensorDataClass): - """ - Base class for all inputs, both raw and preprocessed - """ - - state: FeatureData - next_state: FeatureData - reward: torch.Tensor - time_diff: torch.Tensor - step: Optional[torch.Tensor] - not_terminal: torch.Tensor - - def batch_size(self): - return self.state.float_features.size()[0] - - @classmethod - def from_dict(cls, batch): - id_list_features = batch.get(InputColumn.STATE_ID_LIST_FEATURES, None) or {} - id_score_list_features = ( - batch.get(InputColumn.STATE_ID_SCORE_LIST_FEATURES, None) or {} - ) - next_id_list_features = ( - batch.get(InputColumn.NEXT_STATE_ID_LIST_FEATURES, None) or {} - ) - next_id_score_list_features = ( - batch.get(InputColumn.NEXT_STATE_ID_SCORE_LIST_FEATURES, None) or {} - ) - return BaseInput( - state=FeatureData( - float_features=batch[InputColumn.STATE_FEATURES], - id_list_features=id_list_features, - id_score_list_features=id_score_list_features, - ), - next_state=FeatureData( - float_features=batch[InputColumn.NEXT_STATE_FEATURES], - id_list_features=next_id_list_features, - id_score_list_features=next_id_score_list_features, - ), - reward=batch[InputColumn.REWARD], - time_diff=batch[InputColumn.TIME_DIFF], - step=batch[InputColumn.STEP], - not_terminal=batch[InputColumn.NOT_TERMINAL], - ) - - -@dataclass -class ExtraData(TensorDataClass): - mdp_id: Optional[torch.Tensor] = None - sequence_number: Optional[torch.Tensor] = None - action_probability: Optional[torch.Tensor] = None - max_num_actions: Optional[int] = None - metrics: Optional[torch.Tensor] = None - - @classmethod - def from_dict(cls, d): - return cls(**{f.name: d.get(f.name, None) for f in dataclasses.fields(cls)}) - - -@dataclass -class DiscreteDqnInput(BaseInput): - action: torch.Tensor - next_action: torch.Tensor - possible_actions_mask: torch.Tensor - possible_next_actions_mask: torch.Tensor - extras: ExtraData - - @classmethod - def from_dict(cls, batch): - base = super().from_dict(batch) - return cls( - state=base.state, - next_state=base.next_state, - reward=base.reward, - time_diff=base.time_diff, - step=base.step, - not_terminal=base.not_terminal, - action=batch[InputColumn.ACTION], - next_action=batch[InputColumn.NEXT_ACTION], - possible_actions_mask=batch[InputColumn.POSSIBLE_ACTIONS_MASK], - possible_next_actions_mask=batch[InputColumn.POSSIBLE_NEXT_ACTIONS_MASK], - extras=batch[InputColumn.EXTRAS], - ) - - -@dataclass -class SlateQInput(BaseInput): - """ - The shapes of `reward`, `reward_mask`, & `next_item_mask` are - `(batch_size, slate_size)`. - - `reward_mask` indicated whether the reward could be observed, e.g., - the item got into viewport or not. - """ - - action: torch.Tensor - next_action: torch.Tensor - reward_mask: torch.Tensor - extras: Optional[ExtraData] = None - - @classmethod - def from_dict(cls, d): - action = d["action"] - next_action = d["next_action"] - return cls( - state=FeatureData( - float_features=d["state_features"], - candidate_docs=DocList( - float_features=d["candidate_features"], - mask=d["item_mask"], - value=d["item_probability"], - ), - ), - next_state=FeatureData( - float_features=d["next_state_features"], - candidate_docs=DocList( - float_features=d["next_candidate_features"], - mask=d["next_item_mask"], - value=d["next_item_probability"], - ), - ), - action=action, - next_action=next_action, - reward=d["position_reward"], - reward_mask=d["reward_mask"], - time_diff=d["time_diff"], - not_terminal=d["not_terminal"], - step=None, - extras=ExtraData.from_dict(d), - ) - - -@dataclass -class ParametricDqnInput(BaseInput): - action: FeatureData - next_action: FeatureData - possible_actions: FeatureData - possible_actions_mask: torch.Tensor - possible_next_actions: FeatureData - possible_next_actions_mask: torch.Tensor - extras: Optional[ExtraData] = None - - @classmethod - def from_dict(cls, batch): - return cls( - state=FeatureData(float_features=batch["state_features"]), - action=FeatureData(float_features=batch["action"]), - next_state=FeatureData(float_features=batch["next_state_features"]), - next_action=FeatureData(float_features=batch["next_action"]), - possible_actions=FeatureData(float_features=batch["possible_actions"]), - possible_actions_mask=batch["possible_actions_mask"], - possible_next_actions=FeatureData( - float_features=batch["possible_next_actions"] - ), - possible_next_actions_mask=batch["possible_next_actions_mask"], - reward=batch["reward"], - not_terminal=batch["not_terminal"], - time_diff=batch["time_diff"], - step=batch["step"], - extras=batch["extras"], - ) - - -@dataclass -class PolicyNetworkInput(BaseInput): - action: FeatureData - next_action: FeatureData - extras: Optional[ExtraData] = None - - @classmethod - def from_dict(cls, batch): - return cls( - state=FeatureData(float_features=batch["state_features"]), - action=FeatureData(float_features=batch["action"]), - next_state=FeatureData(float_features=batch["next_state_features"]), - next_action=FeatureData(float_features=batch["next_action"]), - reward=batch["reward"], - not_terminal=batch["not_terminal"], - time_diff=batch["time_diff"], - step=batch["step"], - extras=batch["extras"], - ) - - def batch_size(self) -> int: - return self.state.float_features.shape[0] - - -@dataclass -class PolicyGradientInput(BaseDataClass): - state: FeatureData - action: torch.Tensor - reward: torch.Tensor - log_prob: torch.Tensor - - @classmethod - def input_prototype(cls): - num_classes = 5 - batch_size = 10 - state_dim = 3 - return cls( - state=FeatureData(float_features=torch.randn(batch_size, state_dim)), - action=F.one_hot(torch.randint(high=num_classes, size=(batch_size,))), - reward=torch.rand(batch_size), - log_prob=torch.log(torch.rand(batch_size)), - ) - - -@dataclass -class MemoryNetworkInput(BaseInput): - action: torch.Tensor - - def batch_size(self): - if len(self.state.float_features.size()) == 2: - return self.state.float_features.size()[0] - elif len(self.state.float_features.size()) == 3: - return self.state.float_features.size()[1] - else: - raise NotImplementedError() - - -@dataclass -class PreprocessedTrainingBatch(TensorDataClass): - training_input: Union[PreprocessedRankingInput] - # TODO: deplicate this and move into individual ones. - extras: ExtraData = field(default_factory=ExtraData) - - def batch_size(self): - return self.training_input.state.float_features.size()[0] - - -@dataclass -class MemoryNetworkOutput(TensorDataClass): - mus: torch.Tensor - sigmas: torch.Tensor - logpi: torch.Tensor - reward: torch.Tensor - not_terminal: torch.Tensor - last_step_lstm_hidden: torch.Tensor - last_step_lstm_cell: torch.Tensor - all_steps_lstm_hidden: torch.Tensor - - -@dataclass -class Seq2RewardOutput(TensorDataClass): - acc_reward: torch.Tensor - - -@dataclass -class DqnPolicyActionSet(TensorDataClass): - greedy: int - softmax: Optional[int] = None - greedy_act_name: Optional[str] = None - softmax_act_name: Optional[str] = None - softmax_act_prob: Optional[float] = None - - -@dataclass -class PlanningPolicyOutput(TensorDataClass): - # best action to take next - next_best_continuous_action: Optional[torch.Tensor] = None - next_best_discrete_action_one_hot: Optional[torch.Tensor] = None - next_best_discrete_action_idx: Optional[int] = None - - -@dataclass -class RankingOutput(TensorDataClass): - # a tensor of integer indices w.r.t. to possible candidates - # shape: batch_size, tgt_seq_len - ranked_tgt_out_idx: Optional[torch.Tensor] = None - # generative probability of ranked tgt sequences at each decoding step - # shape: batch_size, tgt_seq_len, candidate_size - ranked_tgt_out_probs: Optional[torch.Tensor] = None - # log probabilities of given tgt sequences are used in REINFORCE - # shape: batch_size - log_probs: Optional[torch.Tensor] = None - # encoder scores in tgt_out_idx order - encoder_scores: Optional[torch.Tensor] = None - - -@dataclass -class RewardNetworkOutput(TensorDataClass): - predicted_reward: torch.Tensor +class RLTrainingOutput: + validation_result: Optional[ValidationResult__Union] = None + publishing_result: Optional[PublishingResult__Union] = None + training_report: Optional[RLTrainingReport] = None + output_path: Optional[str] = None diff --git a/reagent/core/union.py b/reagent/core/union.py deleted file mode 100644 index 4fde8dba..00000000 --- a/reagent/core/union.py +++ /dev/null @@ -1,39 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. - -from reagent.core.fb_checker import IS_FB_ENVIRONMENT -from reagent.core.tagged_union import TaggedUnion -from reagent.models.model_feature_config_provider import ModelFeatureConfigProvider -from reagent.reporting.result_registries import PublishingResult, ValidationResult -from reagent.reporting.training_reports import TrainingReport - - -if True: # Register modules for unions - import reagent.reporting.oss_training_reports # noqa - import reagent.core.result_types # noqa - - if IS_FB_ENVIRONMENT: - import reagent.reporting.fb.fb_training_reports # noqa - import reagent.fb.models.model_feature_config_builder # noqa - import reagent.core.fb.fb_result_types # noqa - import reagent.core.fb.fb_types # noqa - - -@ModelFeatureConfigProvider.fill_union() -class ModelFeatureConfigProvider__Union(TaggedUnion): - pass - - -@PublishingResult.fill_union() -class PublishingResult__Union(TaggedUnion): - pass - - -@ValidationResult.fill_union() -class ValidationResult__Union(TaggedUnion): - pass - - -@TrainingReport.fill_union() -class TrainingReport__Union(TaggedUnion): - pass diff --git a/reagent/data_fetchers/__init__.py b/reagent/data_fetchers/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/reagent/data_fetchers/data_fetcher.py b/reagent/data_fetchers/data_fetcher.py deleted file mode 100644 index e2f65198..00000000 --- a/reagent/data_fetchers/data_fetcher.py +++ /dev/null @@ -1,41 +0,0 @@ -#!/usr/bin/env python3 - - -import logging -from typing import Dict, Optional - -from reagent.core.types import Dataset, PreprocessingOptions, ReaderOptions, TableSpec -from reagent.parameters import NormalizationParameters -from reagent.preprocessing.batch_preprocessor import BatchPreprocessor - - -logger = logging.getLogger(__name__) - - -class DataFetcher: - # TODO: T71636145 Make a more specific API for DataFetcher - def query_data(self, **kwargs): - raise NotImplementedError() - - # TODO: T71636145 Make a more specific API for DataFetcher - def query_data_parametric(self, **kwargs): - raise NotImplementedError() - - def identify_normalization_parameters( - self, - table_spec: TableSpec, - column_name: str, - preprocessing_options: PreprocessingOptions, - seed: Optional[int] = None, - ) -> Dict[int, NormalizationParameters]: - raise NotImplementedError() - - def get_dataloader( - self, - dataset: Dataset, - batch_size: int, - batch_preprocessor: Optional[BatchPreprocessor], - use_gpu: bool, - reader_options: ReaderOptions, - ): - raise NotImplementedError() diff --git a/reagent/evaluation/compress_model_evaluator.py b/reagent/evaluation/compress_model_evaluator.py index 339947ab..f163563b 100644 --- a/reagent/evaluation/compress_model_evaluator.py +++ b/reagent/evaluation/compress_model_evaluator.py @@ -3,8 +3,8 @@ import logging import torch -from reagent.core.types import MemoryNetworkInput from reagent.training.world_model.compress_model_trainer import CompressModelTrainer +from reagent.types import MemoryNetworkInput logger = logging.getLogger(__name__) diff --git a/reagent/evaluation/evaluation_data_page.py b/reagent/evaluation/evaluation_data_page.py index f42a8a3a..c5e15f83 100644 --- a/reagent/evaluation/evaluation_data_page.py +++ b/reagent/evaluation/evaluation_data_page.py @@ -8,17 +8,8 @@ from typing import NamedTuple, Optional, cast import numpy as np import torch import torch.nn as nn -from reagent.core import types as rlt +from reagent import types as rlt from reagent.models.seq2slate import Seq2SlateMode, Seq2SlateTransformerNet -from reagent.ope.estimators.sequential_estimators import ( - Action, - ActionSpace, - RLEstimatorInput, - RLPolicy, - State, - Transition, - ValueFunction, -) from reagent.torch_utils import masked_softmax from reagent.training import ParametricDQNTrainer from reagent.training.dqn_trainer import DQNTrainer @@ -51,7 +42,6 @@ class EvaluationDataPage(NamedTuple): model_metrics_values_for_logged_action: Optional[torch.Tensor] = None possible_actions_state_concat: Optional[torch.Tensor] = None contexts: Optional[torch.Tensor] = None - sequential_estimator_input: Optional[RLEstimatorInput] = None @classmethod def create_from_training_batch( @@ -320,83 +310,6 @@ class EvaluationDataPage(NamedTuple): eval_action_idxs=eval_action_idxs, ) - @staticmethod - def create_rl_estimator_input_from_tensors_dqn( - trainer: DQNTrainer, - mdp_ids: torch.Tensor, - states: rlt.FeatureData, - actions: rlt.FeatureData, - propensities: torch.Tensor, - rewards: torch.Tensor, - ): - class DQNRLPolicy(RLPolicy): - def __init__(self, trainer: DQNTrainer): - super().__init__(ActionSpace(trainer.num_actions)) - self._trainer = trainer - - def action_dist(self, state: State): - feat_data = rlt.FeatureData(float_features=state.value.reshape(1, -1)) - # Only 1 batch - q_values = self._trainer.get_detached_q_values(feat_data)[0][0] - return self._action_space.distribution( - torch.nn.Softmax(dim=0)(q_values) - ) - - class CPEValueFunction(ValueFunction): - def __init__(self, trainer: DQNTrainer): - self._trainer = trainer - - def state_action_value(self, state: State, action: Action) -> float: - feat_data = rlt.FeatureData(float_features=state.value.reshape(1, -1)) - model_values = self._trainer.q_network_cpe(feat_data)[ - :, 0 : self._trainer.num_actions - ][0] - return model_values[action.value].item() - - def state_value(self, state: State) -> float: - feat_data = rlt.FeatureData(float_features=state.value.reshape(1, -1)) - model_values = self._trainer.q_network_cpe(feat_data)[ - :, 0 : self._trainer.num_actions - ][0] - q_values = self._trainer.get_detached_q_values(feat_data)[0][0] - dist = torch.nn.Softmax(dim=0)(q_values) - assert dist.shape == model_values.shape - return torch.dot(dist, model_values).item() - - def reset(self): - pass - - states_tensor = states.float_features - logged_actions = torch.argmax(actions.float(), dim=1) - log = [] - cur_mdp = [] - i = 0 - while i < mdp_ids.shape[0]: - if i + 1 < mdp_ids.shape[0] and mdp_ids[i, 0] == mdp_ids[i + 1, 0]: - cur_mdp.append( - Transition( - last_state=State(states_tensor[i]), - action=Action(logged_actions[i].item()), - action_prob=propensities[i, 0].item(), - state=State(states_tensor[i + 1]), - reward=rewards[i, 0].item(), - status=Transition.Status.NORMAL, - ) - ) - elif len(cur_mdp) > 0: - log.append(cur_mdp) - cur_mdp = [] - i += 1 - - # Temporary value of gamma - return RLEstimatorInput( - gamma=1.0, - log=log, - target_policy=DQNRLPolicy(trainer), - value_function=CPEValueFunction(trainer), - discrete_states=False, - ) - @classmethod # pyre-fixme[56]: Decorator `torch.no_grad(...)` could not be called, because # its type `no_grad` is not callable. @@ -541,9 +454,6 @@ class EvaluationDataPage(NamedTuple): possible_actions_mask=possible_actions_mask, optimal_q_values=optimal_q_values, eval_action_idxs=eval_action_idxs, - sequential_estimator_input=EvaluationDataPage.create_rl_estimator_input_from_tensors_dqn( - trainer, mdp_ids, states, actions, propensities, rewards - ), ) def append(self, edp): @@ -560,15 +470,6 @@ class EvaluationDataPage(NamedTuple): new_edp[x] = torch.cat((t, other_t), dim=0) elif isinstance(t, np.ndarray): new_edp[x] = np.concatenate((t, other_t), axis=0) - elif isinstance(t, RLEstimatorInput): - t.log.extend(other_t.log) - new_edp[x] = RLEstimatorInput( - gamma=t.gamma, - log=t.log, - target_policy=t.target_policy, - value_function=t.value_function, - discrete_states=t.discrete_states, - ) else: raise Exception("Invalid type in training data page") else: @@ -583,10 +484,7 @@ class EvaluationDataPage(NamedTuple): new_edp = {} for x in EvaluationDataPage._fields: t = getattr(self, x) - if hasattr(t, "__getitem__"): - new_edp[x] = t[sorted_idxs] if t is not None else None - else: - new_edp[x] = t + new_edp[x] = t[sorted_idxs] if t is not None else None return EvaluationDataPage(**new_edp) diff --git a/reagent/evaluation/evaluator.py b/reagent/evaluation/evaluator.py index 3affbb07..7df5e08e 100644 --- a/reagent/evaluation/evaluator.py +++ b/reagent/evaluation/evaluator.py @@ -7,7 +7,7 @@ from typing import Dict, List, Optional import torch import torch.nn.functional as F -from reagent.core import types as rlt +from reagent.core.tracker import observable from reagent.evaluation.cpe import CpeDetails, CpeEstimateSet from reagent.evaluation.doubly_robust_estimator import DoublyRobustEstimator from reagent.evaluation.evaluation_data_page import EvaluationDataPage @@ -53,6 +53,7 @@ def get_metrics_to_score(metric_reward_values: Optional[Dict[str, float]]) -> Li return sorted([*metric_reward_values.keys()]) +@observable(cpe_details=CpeDetails) class Evaluator: NUM_J_STEPS_FOR_MAGIC_ESTIMATOR = 25 @@ -69,15 +70,7 @@ class Evaluator: gamma ) - self.reporter = None - - def evaluate(self, eval_input: rlt.TensorDataClass) -> None: - pass - - def finish(self): - pass - - def evaluate_one_shot(self, edp: EvaluationDataPage) -> CpeDetails: + def evaluate_post_training(self, edp: EvaluationDataPage) -> CpeDetails: cpe_details = CpeDetails() cpe_details.reward_estimates = self.score_cpe("Reward", edp) @@ -123,9 +116,8 @@ class Evaluator: cpe_details.mc_loss = float( F.mse_loss(edp.logged_values, edp.model_values_for_logged_action) ) - - assert self.reporter is not None, "Missing reporter" - self.reporter.report(cpe_results=cpe_details) + # pyre-fixme[16]: `Evaluator` has no attribute `notify_observers`. + self.notify_observers(cpe_details=cpe_details) return cpe_details def score_cpe(self, metric_name, edp: EvaluationDataPage): diff --git a/reagent/evaluation/ope_adapter.py b/reagent/evaluation/ope_adapter.py index f0c3e74a..0397fea9 100644 --- a/reagent/evaluation/ope_adapter.py +++ b/reagent/evaluation/ope_adapter.py @@ -11,6 +11,9 @@ from reagent.evaluation.cpe import ( ) from reagent.evaluation.evaluation_data_page import EvaluationDataPage from reagent.evaluation.evaluator import Evaluator +from reagent.evaluation.weighted_sequential_doubly_robust_estimator import ( + WeightedSequentialDoublyRobustEstimator, +) from reagent.ope.estimators.contextual_bandits_estimators import ( BanditsEstimatorInput, DMEstimator, @@ -31,6 +34,10 @@ from reagent.ope.estimators.sequential_estimators import ( MAGICEstimator, RLEstimator, RLEstimatorInput, + RLPolicy, + State, + Transition, + ValueFunction, ) from reagent.ope.estimators.types import ActionSpace @@ -109,6 +116,92 @@ class SequentialOPEstimatorAdapter: self.gamma = gamma self._device = device + class EDPSeqPolicy(RLPolicy): + def __init__( + self, num_actions: int, model_propensities: torch.Tensor, device=None + ): + super().__init__(ActionSpace(num_actions), device) + self.model_propensities = model_propensities + + def action_dist(self, state: State) -> ActionDistribution: + # "state" is (trajectory, step) + return self.model_propensities[state.value] + + class EDPValueFunc(ValueFunction): + def __init__( + self, model_values: torch.Tensor, target_propensities: torch.Tensor + ): + self.model_values = model_values + self.target_propensities = target_propensities + + def state_action_value(self, state: State, action: Action) -> float: + return self.model_values[state.value][action].item() + + def state_value(self, state: State) -> float: + return torch.dot( + self.model_values[state.value], self.target_propensities[state.value] + ).item() + + def reset(self): + pass + + @staticmethod + def edp_to_rl_input( + edp: EvaluationDataPage, gamma, device=None + ) -> RLEstimatorInput: + assert edp.model_values is not None + eq_len = WeightedSequentialDoublyRobustEstimator.transform_to_equal_length_trajectories( + edp.mdp_id, + edp.action_mask.cpu().numpy(), + edp.logged_rewards.cpu().numpy().flatten(), + edp.logged_propensities.cpu().numpy().flatten(), + edp.model_propensities.cpu().numpy(), + edp.model_values.cpu().numpy(), + ) + + ( + actions, + rewards, + logged_propensities, + target_propensities, + estimated_q_values, + ) = ( + torch.tensor(x, dtype=torch.double, device=device, requires_grad=True) + for x in eq_len + ) + + num_examples = logged_propensities.shape[0] + horizon = logged_propensities.shape[1] + + log = [] + for traj in range(num_examples): + log.append( + [ + Transition( + last_state=State((traj, i)), + action=torch.argmax(actions[traj, i]).item(), + action_prob=logged_propensities[traj, i].item(), + state=State((traj, i + 1)), + reward=rewards[traj, i].item(), + ) + for i in range(horizon - 1) + if actions[traj, i][torch.argmax(actions[traj, i]).item()] != 0.0 + ] + ) + + return RLEstimatorInput( + gamma=gamma, + log=log, + target_policy=SequentialOPEstimatorAdapter.EDPSeqPolicy( + actions.shape[2], target_propensities + ), + value_function=SequentialOPEstimatorAdapter.EDPValueFunc( + estimated_q_values, target_propensities + ), + ground_truth=None, + horizon=horizon, + ) + @staticmethod def estimator_results_to_cpe_estimate( estimator_results: EstimatorResults, @@ -144,16 +237,8 @@ class SequentialOPEstimatorAdapter: ) def estimate(self, edp: EvaluationDataPage) -> CpeEstimate: - est_input = edp.sequential_estimator_input - assert est_input is not None, "EDP does not contain sequential estimator inputs" estimator_results = self.seq_ope_estimator.evaluate( - RLEstimatorInput( - gamma=self.gamma, - log=est_input.log, - target_policy=est_input.target_policy, - value_function=est_input.value_function, - discrete_states=est_input.discrete_states, - ) + SequentialOPEstimatorAdapter.edp_to_rl_input(edp, self.gamma, self._device) ) assert isinstance(estimator_results, EstimatorResults) return SequentialOPEstimatorAdapter.estimator_results_to_cpe_estimate( diff --git a/reagent/evaluation/ranking_listwise_evaluator.py b/reagent/evaluation/ranking_listwise_evaluator.py index 708d3d2d..21a45af6 100644 --- a/reagent/evaluation/ranking_listwise_evaluator.py +++ b/reagent/evaluation/ranking_listwise_evaluator.py @@ -7,8 +7,9 @@ from typing import Optional import numpy as np import torch import torch.nn as nn -from reagent.core.types import PreprocessedTrainingBatch +from reagent.core.tracker import observable from reagent.models.seq2slate import Seq2SlateMode +from reagent.types import PreprocessedTrainingBatch from sklearn.metrics import ( average_precision_score, dcg_score, @@ -28,6 +29,17 @@ class ListwiseRankingMetrics: cross_entropy_loss: Optional[float] = 0.0 +@observable( + cross_entropy_loss=torch.Tensor, + dcg=torch.Tensor, + ndcg=torch.Tensor, + mean_ap=torch.Tensor, + auc=torch.Tensor, + base_dcg=torch.Tensor, + base_ndcg=torch.Tensor, + base_map=torch.Tensor, + base_auc=torch.Tensor, +) class RankingListwiseEvaluator: """ Evaluate listwise ranking models on common ranking metrics """ @@ -43,7 +55,6 @@ class RankingListwiseEvaluator: self.base_map = [] self.log_softmax = nn.LogSoftmax(dim=1) self.kl_loss = nn.KLDivLoss(reduction="batchmean") - self.reporter = None # pyre-fixme[56]: Decorator `torch.no_grad(...)` could not be called, because # its type `no_grad` is not callable. @@ -72,7 +83,9 @@ class RankingListwiseEvaluator: self.seq2slate_net.train(seq2slate_net_prev_mode) if not self.calc_cpe: - self.reporter.report_evaluation_minibatch(cross_entropy_loss=ce_loss) + # pyre-fixme[16]: `RankingListwiseEvaluator` has no attribute + # `notify_observers`. + self.notify_observers(cross_entropy_loss=ce_loss) return # shape: batch_size, tgt_seq_len @@ -119,7 +132,7 @@ class RankingListwiseEvaluator: batch_base_dcg.append(dcg_score(truth_scores, base_scores)) batch_base_ndcg.append(ndcg_score(truth_scores, base_scores)) - self.reporter.report_evaluation_minibatch( + self.notify_observers( cross_entropy_loss=ce_loss, dcg=torch.mean(torch.tensor(batch_dcg)).reshape(1), ndcg=torch.mean(torch.tensor(batch_ndcg)).reshape(1), @@ -132,5 +145,5 @@ class RankingListwiseEvaluator: ) @torch.no_grad() - def evaluate_one_shot(self): + def evaluate_post_training(self): pass diff --git a/reagent/evaluation/ranking_policy_gradient_evaluator.py b/reagent/evaluation/ranking_policy_gradient_evaluator.py index 6b9f7514..801ea4e6 100644 --- a/reagent/evaluation/ranking_policy_gradient_evaluator.py +++ b/reagent/evaluation/ranking_policy_gradient_evaluator.py @@ -8,15 +8,24 @@ from typing import Optional import torch import torch.nn as nn import torch.nn.functional as F -from reagent.core.types import PreprocessedTrainingBatch +from reagent.core.tracker import observable from reagent.evaluation.evaluation_data_page import EvaluationDataPage from reagent.models.seq2slate import Seq2SlateMode from reagent.training.ranking.seq2slate_trainer import Seq2SlateTrainer +from reagent.types import PreprocessedTrainingBatch logger = logging.getLogger(__name__) +@observable( + eval_baseline_loss=torch.Tensor, + eval_advantages=torch.Tensor, + logged_slate_rank_probs=torch.Tensor, + ranked_slate_rank_probs=torch.Tensor, + eval_data_pages_g=EvaluationDataPage, + eval_data_pages_ng=EvaluationDataPage, +) class RankingPolicyGradientEvaluator: """ Evaluate ranking models that are learned through policy gradient """ @@ -30,12 +39,13 @@ class RankingPolicyGradientEvaluator: self.trainer = trainer self.calc_cpe = calc_cpe self.reward_network = reward_network - self.reporter = None # Evaluate greedy/non-greedy version of the ranking model self.eval_data_pages_g: Optional[EvaluationDataPage] = None self.eval_data_pages_ng: Optional[EvaluationDataPage] = None + # pyre-fixme[56]: Decorator `torch.no_grad(...)` could not be called, because + # its type `no_grad` is not callable. @torch.no_grad() def evaluate(self, eval_tdp: PreprocessedTrainingBatch) -> None: seq2slate_net = self.trainer.seq2slate_net @@ -117,7 +127,9 @@ class RankingPolicyGradientEvaluator: else: self.eval_data_pages_ng = self.eval_data_pages_ng.append(edp_ng) - self.reporter.report_evaluation_minibatch( + # pyre-fixme[16]: `RankingPolicyGradientEvaluator` has no attribute + # `notify_observers`. + self.notify_observers( eval_baseline_loss=eval_baseline_loss, eval_advantages=eval_advantage, logged_slate_rank_probs=logged_slate_rank_prob, @@ -125,13 +137,11 @@ class RankingPolicyGradientEvaluator: ) @torch.no_grad() - def finish(self): - self.reporter.report_evaluation_epoch( + def evaluate_post_training(self): + self.notify_observers( + # Use ValueListObserver as aggregating_observers requires input to be Tensor eval_data_pages_g=self.eval_data_pages_g, eval_data_pages_ng=self.eval_data_pages_ng, ) self.eval_data_pages_g = None self.eval_data_pages_ng = None - - def evaluate_one_shot(self, edp: EvaluationDataPage): - pass diff --git a/reagent/evaluation/reward_net_evaluator.py b/reagent/evaluation/reward_net_evaluator.py index bf9d6afc..0da77c0b 100644 --- a/reagent/evaluation/reward_net_evaluator.py +++ b/reagent/evaluation/reward_net_evaluator.py @@ -6,10 +6,9 @@ import logging import numpy as np import torch import torch.nn.functional as F -from reagent.core import types as rlt -from reagent.core.types import PreprocessedTrainingBatch -from reagent.evaluation.evaluation_data_page import EvaluationDataPage +from reagent import types as rlt from reagent.training.reward_network_trainer import RewardNetTrainer +from reagent.types import PreprocessedTrainingBatch logger = logging.getLogger(__name__) @@ -22,6 +21,7 @@ class RewardNetEvaluator: self.trainer = trainer self.mse_loss = [] self.rewards = [] + self.best_model = None self.best_model_loss = 1e9 # pyre-fixme[56]: Decorator `torch.no_grad(...)` could not be called, because @@ -47,7 +47,7 @@ class RewardNetEvaluator: reward_net.train(reward_net_prev_mode) @torch.no_grad() - def finish(self): + def evaluate_post_training(self): mean_mse_loss = np.mean(self.mse_loss) logger.info(f"Evaluation MSE={mean_mse_loss}") eval_res = {"mse": mean_mse_loss, "rewards": torch.cat(self.rewards)} @@ -56,9 +56,6 @@ class RewardNetEvaluator: if mean_mse_loss < self.best_model_loss: self.best_model_loss = mean_mse_loss - self.trainer.best_model = copy.deepcopy(self.trainer.reward_net) + self.best_model = copy.deepcopy(self.trainer.reward_net) return eval_res - - def evaluate_one_shot(self, edp: EvaluationDataPage): - pass diff --git a/reagent/evaluation/seq2reward_evaluator.py b/reagent/evaluation/seq2reward_evaluator.py index afda5153..08e7d642 100644 --- a/reagent/evaluation/seq2reward_evaluator.py +++ b/reagent/evaluation/seq2reward_evaluator.py @@ -3,8 +3,8 @@ import logging import torch -from reagent.core.types import PreprocessedTrainingBatch from reagent.training.world_model.seq2reward_trainer import Seq2RewardTrainer +from reagent.types import PreprocessedTrainingBatch logger = logging.getLogger(__name__) @@ -15,13 +15,15 @@ class Seq2RewardEvaluator: self.trainer = trainer self.reward_net = self.trainer.seq2reward_network + # pyre-fixme[56]: Decorator `torch.no_grad(...)` could not be called, because + # its type `no_grad` is not callable. @torch.no_grad() def evaluate(self, eval_tdp: PreprocessedTrainingBatch): reward_net_prev_mode = self.reward_net.training self.reward_net.eval() # pyre-fixme[6]: Expected `MemoryNetworkInput` for 1st param but got # `PreprocessedTrainingBatch`. - loss = self.trainer.compute_loss(eval_tdp) + loss = self.trainer.get_loss(eval_tdp) detached_loss = loss.cpu().detach().item() q_values = ( self.trainer.get_Q( @@ -37,6 +39,3 @@ class Seq2RewardEvaluator: ) self.reward_net.train(reward_net_prev_mode) return (detached_loss, q_values) - - def finish(self): - pass diff --git a/reagent/evaluation/world_model_evaluator.py b/reagent/evaluation/world_model_evaluator.py index 0b0ff82e..62c695e1 100644 --- a/reagent/evaluation/world_model_evaluator.py +++ b/reagent/evaluation/world_model_evaluator.py @@ -4,28 +4,23 @@ import logging from typing import Dict, List import torch -from reagent.core.types import FeatureData, MemoryNetworkInput -from reagent.reporting.world_model_reporter import ( - DebugToolsReporter, - WorldModelReporter, -) from reagent.training.world_model.mdnrnn_trainer import MDNRNNTrainer +from reagent.types import FeatureData, MemoryNetworkInput logger = logging.getLogger(__name__) -class WorldModelLossEvaluator(object): +class LossEvaluator(object): """ Evaluate losses on data pages """ def __init__(self, trainer: MDNRNNTrainer, state_dim: int) -> None: self.trainer = trainer self.state_dim = state_dim - self.reporter = WorldModelReporter(1) - def evaluate(self, tdp: MemoryNetworkInput) -> None: + def evaluate(self, tdp: MemoryNetworkInput) -> Dict[str, float]: self.trainer.memory_network.mdnrnn.eval() - losses = self.trainer.compute_loss(tdp, state_dim=self.state_dim) + losses = self.trainer.get_loss(tdp, state_dim=self.state_dim) detached_losses = { "loss": losses["loss"].cpu().detach().item(), "gmm": losses["gmm"].cpu().detach().item(), @@ -34,10 +29,7 @@ class WorldModelLossEvaluator(object): } del losses self.trainer.memory_network.mdnrnn.train() - self.reporter.report(**detached_losses) - - def finish(self): - pass + return detached_losses class FeatureImportanceEvaluator(object): @@ -65,7 +57,6 @@ class FeatureImportanceEvaluator(object): self.action_feature_num = action_feature_num self.sorted_action_feature_start_indices = sorted_action_feature_start_indices self.sorted_state_feature_start_indices = sorted_state_feature_start_indices - self.reporter = DebugToolsReporter() def evaluate(self, batch: MemoryNetworkInput): """ Calculate feature importance: setting each state/action feature to @@ -80,7 +71,7 @@ class FeatureImportanceEvaluator(object): state_feature_num = self.state_feature_num feature_importance = torch.zeros(action_feature_num + state_feature_num) - orig_losses = self.trainer.compute_loss(batch, state_dim=state_dim) + orig_losses = self.trainer.get_loss(batch, state_dim=state_dim) orig_loss = orig_losses["loss"].cpu().detach().item() del orig_losses @@ -124,7 +115,7 @@ class FeatureImportanceEvaluator(object): not_terminal=batch.not_terminal, step=None, ) - losses = self.trainer.compute_loss(new_batch, state_dim=state_dim) + losses = self.trainer.get_loss(new_batch, state_dim=state_dim) feature_importance[i] = losses["loss"].cpu().detach().item() - orig_loss del losses @@ -151,7 +142,7 @@ class FeatureImportanceEvaluator(object): not_terminal=batch.not_terminal, step=None, ) - losses = self.trainer.compute_loss(new_batch, state_dim=state_dim) + losses = self.trainer.get_loss(new_batch, state_dim=state_dim) feature_importance[i + action_feature_num] = ( losses["loss"].cpu().detach().item() - orig_loss ) @@ -161,7 +152,6 @@ class FeatureImportanceEvaluator(object): logger.info( "**** Debug tool feature importance ****: {}".format(feature_importance) ) - self.reporter.report(feature_importance=feature_importance.tolist()) return {"feature_loss_increase": feature_importance.numpy()} def compute_median_feature_value(self, features): @@ -180,9 +170,6 @@ class FeatureImportanceEvaluator(object): median_feature = features.mean(dim=0) return median_feature - def finish(self): - pass - class FeatureSensitivityEvaluator(object): """ Evaluate state feature sensitivity caused by varying actions """ @@ -196,7 +183,6 @@ class FeatureSensitivityEvaluator(object): self.trainer = trainer self.state_feature_num = state_feature_num self.sorted_state_feature_start_indices = sorted_state_feature_start_indices - self.reporter = DebugToolsReporter() def evaluate(self, batch: MemoryNetworkInput): """ Calculate state feature sensitivity due to actions: @@ -254,8 +240,4 @@ class FeatureSensitivityEvaluator(object): logger.info( "**** Debug tool feature sensitivity ****: {}".format(feature_sensitivity) ) - self.reporter.report(feature_sensitivity=feature_sensitivity.tolist()) return {"feature_sensitivity": feature_sensitivity.numpy()} - - def finish(self): - pass diff --git a/reagent/gym/envs/changing_arms.py b/reagent/gym/envs/changing_arms.py index b596e362..a89cd96b 100644 --- a/reagent/gym/envs/changing_arms.py +++ b/reagent/gym/envs/changing_arms.py @@ -19,7 +19,7 @@ import random import gym import numpy as np -import reagent.core.types as rlt +import reagent.types as rlt import torch from reagent.core.dataclasses import dataclass from reagent.gym.envs.env_wrapper import EnvWrapper diff --git a/reagent/gym/envs/env_wrapper.py b/reagent/gym/envs/env_wrapper.py index 350f5299..dfc2d327 100644 --- a/reagent/gym/envs/env_wrapper.py +++ b/reagent/gym/envs/env_wrapper.py @@ -7,7 +7,7 @@ from typing import Callable, Optional import gym import numpy as np -import reagent.core.types as rlt +import reagent.types as rlt import torch from gym import spaces from reagent.core.dataclasses import dataclass diff --git a/reagent/gym/envs/gym.py b/reagent/gym/envs/gym.py index 2a9933e4..3375e8e7 100644 --- a/reagent/gym/envs/gym.py +++ b/reagent/gym/envs/gym.py @@ -5,7 +5,7 @@ from typing import Optional, Tuple import gym import numpy as np -import reagent.core.types as rlt +import reagent.types as rlt import torch from gym import spaces from gym_minigrid.wrappers import ReseedWrapper diff --git a/reagent/gym/envs/pomdp/state_embed_env.py b/reagent/gym/envs/pomdp/state_embed_env.py index d22f3637..beafa5be 100644 --- a/reagent/gym/envs/pomdp/state_embed_env.py +++ b/reagent/gym/envs/pomdp/state_embed_env.py @@ -14,7 +14,7 @@ from typing import Optional import gym import numpy as np -import reagent.core.types as rlt +import reagent.types as rlt import torch from gym.spaces import Box from reagent.gym.envs.env_wrapper import EnvWrapper diff --git a/reagent/gym/envs/recsim.py b/reagent/gym/envs/recsim.py index 934e7e09..e5d376d2 100644 --- a/reagent/gym/envs/recsim.py +++ b/reagent/gym/envs/recsim.py @@ -5,7 +5,7 @@ import logging import gym import numpy as np -import reagent.core.types as rlt +import reagent.types as rlt from reagent.core.dataclasses import dataclass from reagent.gym.envs.env_wrapper import EnvWrapper from reagent.gym.envs.wrappers.recsim import ValueWrapper diff --git a/reagent/gym/policies/policy.py b/reagent/gym/policies/policy.py index e491c4bf..e83104f4 100644 --- a/reagent/gym/policies/policy.py +++ b/reagent/gym/policies/policy.py @@ -4,7 +4,7 @@ from typing import Any, Optional import numpy as np -import reagent.core.types as rlt +import reagent.types as rlt from reagent.gym.types import Sampler, Scorer diff --git a/reagent/gym/policies/predictor_policies.py b/reagent/gym/policies/predictor_policies.py index cf245370..b46225ff 100644 --- a/reagent/gym/policies/predictor_policies.py +++ b/reagent/gym/policies/predictor_policies.py @@ -4,7 +4,7 @@ from typing import Any, Optional, Tuple, Union import numpy as np -import reagent.core.types as rlt +import reagent.types as rlt import torch from reagent.core.fb_checker import IS_FB_ENVIRONMENT from reagent.gym.policies import Policy diff --git a/reagent/gym/policies/random_policies.py b/reagent/gym/policies/random_policies.py index f0cd0741..31f11c91 100644 --- a/reagent/gym/policies/random_policies.py +++ b/reagent/gym/policies/random_policies.py @@ -5,7 +5,7 @@ from typing import List, Optional import gym import numpy as np -import reagent.core.types as rlt +import reagent.types as rlt import torch import torch.nn.functional as F from reagent.gym.policies.policy import Policy diff --git a/reagent/gym/policies/samplers/continuous_sampler.py b/reagent/gym/policies/samplers/continuous_sampler.py index 628a1ef7..0775e39f 100644 --- a/reagent/gym/policies/samplers/continuous_sampler.py +++ b/reagent/gym/policies/samplers/continuous_sampler.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -import reagent.core.types as rlt +import reagent.types as rlt import torch from reagent.gym.types import GaussianSamplerScore, Sampler diff --git a/reagent/gym/policies/samplers/discrete_sampler.py b/reagent/gym/policies/samplers/discrete_sampler.py index ba62aa65..5a6649fa 100644 --- a/reagent/gym/policies/samplers/discrete_sampler.py +++ b/reagent/gym/policies/samplers/discrete_sampler.py @@ -2,7 +2,7 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -import reagent.core.types as rlt +import reagent.types as rlt import torch import torch.nn.functional as F from reagent.gym.types import Sampler @@ -41,9 +41,6 @@ class SoftmaxActionSampler(Sampler): assert raw_action.shape == ( batch_size, ), f"{raw_action.shape} != ({batch_size}, )" - assert ( - int(raw_action.max().item()) < num_actions - ), f"Invalid action: {int(raw_action.max().item())}" action = F.one_hot(raw_action, num_actions) assert action.ndim == 2 log_prob = m.log_prob(raw_action) diff --git a/reagent/gym/policies/samplers/top_k_sampler.py b/reagent/gym/policies/samplers/top_k_sampler.py index 77f3cd5b..3d814486 100644 --- a/reagent/gym/policies/samplers/top_k_sampler.py +++ b/reagent/gym/policies/samplers/top_k_sampler.py @@ -2,7 +2,7 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -import reagent.core.types as rlt +import reagent.types as rlt import torch from reagent.gym.types import Sampler diff --git a/reagent/gym/policies/scorers/continuous_scorer.py b/reagent/gym/policies/scorers/continuous_scorer.py index 78265730..6a5892fb 100644 --- a/reagent/gym/policies/scorers/continuous_scorer.py +++ b/reagent/gym/policies/scorers/continuous_scorer.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -import reagent.core.types as rlt +import reagent.types as rlt import torch from reagent.gym.types import GaussianSamplerScore, Scorer from reagent.models.base import ModelBase diff --git a/reagent/gym/policies/scorers/discrete_scorer.py b/reagent/gym/policies/scorers/discrete_scorer.py index 895a29f8..3e461ab3 100644 --- a/reagent/gym/policies/scorers/discrete_scorer.py +++ b/reagent/gym/policies/scorers/discrete_scorer.py @@ -4,7 +4,7 @@ from typing import Optional, Tuple import numpy as np -import reagent.core.types as rlt +import reagent.types as rlt import torch from reagent.gym.preprocessors.trainer_preprocessor import get_possible_actions_for_gym from reagent.gym.types import Scorer diff --git a/reagent/gym/policies/scorers/slate_q_scorer.py b/reagent/gym/policies/scorers/slate_q_scorer.py index 517df220..d304b763 100644 --- a/reagent/gym/policies/scorers/slate_q_scorer.py +++ b/reagent/gym/policies/scorers/slate_q_scorer.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -import reagent.core.types as rlt +import reagent.types as rlt import torch import torch.nn.functional as F from reagent.gym.types import Scorer diff --git a/reagent/gym/preprocessors/default_preprocessors.py b/reagent/gym/preprocessors/default_preprocessors.py index 864a8922..edd43fb7 100644 --- a/reagent/gym/preprocessors/default_preprocessors.py +++ b/reagent/gym/preprocessors/default_preprocessors.py @@ -7,7 +7,7 @@ import logging from typing import List, Optional, Tuple import numpy as np -import reagent.core.types as rlt +import reagent.types as rlt import torch import torch.nn.functional as F from gym import Env, spaces diff --git a/reagent/gym/preprocessors/trainer_preprocessor.py b/reagent/gym/preprocessors/trainer_preprocessor.py index 77cd7740..c23e2a49 100644 --- a/reagent/gym/preprocessors/trainer_preprocessor.py +++ b/reagent/gym/preprocessors/trainer_preprocessor.py @@ -9,7 +9,7 @@ from typing import Optional import gym import numpy as np -import reagent.core.types as rlt +import reagent.types as rlt import torch import torch.nn.functional as F from reagent.parameters import CONTINUOUS_TRAINING_ACTION_RANGE diff --git a/reagent/runners/__init__.py b/reagent/gym/runners/__init__.py similarity index 100% rename from reagent/runners/__init__.py rename to reagent/gym/runners/__init__.py diff --git a/reagent/gym/tests/configs/sparse/discrete_dqn_changing_arms_online.yaml b/reagent/gym/tests/configs/sparse/discrete_dqn_changing_arms_online.yaml index 97a49222..933ada54 100644 --- a/reagent/gym/tests/configs/sparse/discrete_dqn_changing_arms_online.yaml +++ b/reagent/gym/tests/configs/sparse/discrete_dqn_changing_arms_online.yaml @@ -75,5 +75,5 @@ train_every_ts: 1 train_after_ts: 20000 num_train_episodes: 10 num_eval_episodes: 10 -passing_score_bar: 190 +passing_score_bar: 200 use_gpu: false diff --git a/reagent/gym/tests/test_gym.py b/reagent/gym/tests/test_gym.py index ed1c74c7..47c5763d 100644 --- a/reagent/gym/tests/test_gym.py +++ b/reagent/gym/tests/test_gym.py @@ -17,21 +17,13 @@ from reagent.gym.agents.post_step import train_with_replay_buffer_post_step from reagent.gym.envs.union import Env__Union from reagent.gym.runners.gymrunner import evaluate_for_n_episodes, run_episode from reagent.gym.utils import build_normalizer, fill_replay_buffer -from reagent.model_managers.model_manager import ModelManager -from reagent.model_managers.union import ModelManager__Union from reagent.replay_memory.circular_replay_buffer import ReplayBuffer from reagent.tensorboardX import summary_writer_context from reagent.test.base.horizon_test_base import HorizonTestBase +from reagent.workflow.model_managers.union import ModelManager__Union from torch.utils.tensorboard import SummaryWriter -try: - # Use internal runner or OSS otherwise - from reagent.runners.fb.fb_batch_runner import FbBatchRunner as BatchRunner -except ImportError: - from reagent.runners.oss_batch_runner import OssBatchRunner as BatchRunner - - # for seeding the environment SEED = 0 logger = logging.getLogger(__name__) @@ -116,12 +108,13 @@ def run_test( normalization = build_normalizer(env) logger.info(f"Normalization is: \n{pprint.pformat(normalization)}") - manager: ModelManager = model.value - runner = BatchRunner(use_gpu, manager, RewardOptions(), normalization) - trainer = runner.initialize_trainer() - reporter = manager.get_reporter() - trainer.reporter = reporter - training_policy = manager.create_policy(trainer) + manager = model.value + trainer = manager.initialize_trainer( + use_gpu=use_gpu, + reward_options=RewardOptions(), + normalization_data_map=normalization, + ) + training_policy = manager.create_policy(serving=False) replay_buffer = ReplayBuffer( replay_capacity=replay_memory_size, batch_size=trainer.minibatch_size @@ -172,7 +165,7 @@ def run_test( f"{len(train_rewards)} episodes is less than < {passing_score_bar}.\n" ) - serving_policy = manager.create_serving_policy(normalization, trainer) + serving_policy = manager.create_policy(serving=True) agent = Agent.create_for_env_with_serving_policy(env, serving_policy) eval_rewards = evaluate_for_n_episodes( diff --git a/reagent/gym/tests/test_gym_offline.py b/reagent/gym/tests/test_gym_offline.py index 8cfd8e83..578b2fe8 100644 --- a/reagent/gym/tests/test_gym_offline.py +++ b/reagent/gym/tests/test_gym_offline.py @@ -17,22 +17,14 @@ from reagent.gym.envs.gym import Gym from reagent.gym.preprocessors import make_replay_buffer_trainer_preprocessor from reagent.gym.runners.gymrunner import evaluate_for_n_episodes from reagent.gym.utils import build_normalizer, fill_replay_buffer -from reagent.model_managers.union import ModelManager__Union from reagent.replay_memory.circular_replay_buffer import ReplayBuffer -from reagent.runners.oss_batch_runner import OssBatchRunner from reagent.tensorboardX import summary_writer_context from reagent.test.base.horizon_test_base import HorizonTestBase +from reagent.workflow.model_managers.union import ModelManager__Union from torch.utils.tensorboard import SummaryWriter from tqdm import tqdm -try: - # Use internal runner or OSS otherwise - from reagent.runners.fb.fb_batch_runner import FbBatchRunner as BatchRunner -except ImportError: - from reagent.runners.oss_batch_runner import OssBatchRunner as BatchRunner - - # for seeding the environment SEED = 0 logger = logging.getLogger(__name__) @@ -86,7 +78,7 @@ class TestGymOffline(HorizonTestBase): def evaluate_cem(env, manager, num_eval_episodes: int): # NOTE: for CEM, serving isn't implemented - policy = manager.create_policy() + policy = manager.create_policy(serving=False) agent = Agent.create_for_env(env, policy) return evaluate_for_n_episodes( n=num_eval_episodes, env=env, agent=agent, max_steps=env.max_steps @@ -110,13 +102,11 @@ def run_test_offline( logger.info(f"Normalization is: \n{pprint.pformat(normalization)}") manager = model.value - runner = OssBatchRunner( - use_gpu, - manager, + trainer = manager.initialize_trainer( + use_gpu=use_gpu, reward_options=RewardOptions(), normalization_data_map=normalization, ) - trainer = runner.initialize_trainer() # first fill the replay buffer to burn_in replay_buffer = ReplayBuffer( diff --git a/reagent/gym/tests/test_seq2reward_model.py b/reagent/gym/tests/test_seq2reward_model.py index e8ecf8f2..b2adb3eb 100644 --- a/reagent/gym/tests/test_seq2reward_model.py +++ b/reagent/gym/tests/test_seq2reward_model.py @@ -4,7 +4,7 @@ import logging import os import unittest -from typing import Optional, cast +from typing import Optional import torch from reagent.core.types import RewardOptions @@ -12,18 +12,10 @@ from reagent.gym.envs.env_wrapper import EnvWrapper from reagent.gym.envs.gym import Gym from reagent.gym.preprocessors import make_replay_buffer_trainer_preprocessor from reagent.gym.utils import build_normalizer, fill_replay_buffer -from reagent.model_managers.union import ModelManager__Union from reagent.replay_memory.circular_replay_buffer import ReplayBuffer -from reagent.runners.oss_batch_runner import OssBatchRunner from reagent.test.base.horizon_test_base import HorizonTestBase from reagent.training.world_model.seq2reward_trainer import Seq2RewardTrainer - - -try: - # Use internal runner or OSS otherwise - from reagent.runners.fb.fb_batch_runner import FbBatchRunner as BatchRunner -except ImportError: - from reagent.runners.oss_batch_runner import OssBatchRunner as BatchRunner +from reagent.workflow.model_managers.union import ModelManager__Union logging.basicConfig(level=logging.INFO) @@ -79,8 +71,8 @@ def train_seq2reward( ) preprocessed_test_batch = trainer_preprocessor(test_batch) adhoc_action_padding(preprocessed_test_batch, state_dim=state_dim) - # valid_losses = trainer.get_loss(preprocessed_test_batch) - # print_seq2reward_losses(epoch, "validation", valid_losses) + valid_losses = trainer.get_loss(preprocessed_test_batch) + print_seq2reward_losses(epoch, "validation", valid_losses) trainer.seq2reward_network.train() return trainer @@ -117,13 +109,11 @@ def train_seq2reward_and_compute_reward_mse( env.seed(SEED) manager = model.value - runner = OssBatchRunner( - use_gpu, - manager, + trainer = manager.initialize_trainer( + use_gpu=use_gpu, reward_options=RewardOptions(), normalization_data_map=build_normalizer(env), ) - trainer = cast(Seq2RewardTrainer, runner.initialize_trainer()) device = "cuda" if use_gpu else "cpu" # pyre-fixme[6]: Expected `device` for 2nd param but got `str`. @@ -159,7 +149,7 @@ def train_seq2reward_and_compute_reward_mse( ) preprocessed_test_batch = trainer_preprocessor(test_batch) adhoc_action_padding(preprocessed_test_batch, state_dim=state_dim) - losses = trainer.compute_loss(preprocessed_test_batch) + losses = trainer.get_loss(preprocessed_test_batch) detached_losses = losses.cpu().detach().item() trainer.seq2reward_network.train() return detached_losses diff --git a/reagent/gym/tests/test_world_model.py b/reagent/gym/tests/test_world_model.py index 80e6a3d0..c671a92b 100644 --- a/reagent/gym/tests/test_world_model.py +++ b/reagent/gym/tests/test_world_model.py @@ -3,11 +3,11 @@ import logging import os import unittest -from typing import Dict, List, Optional, cast +from typing import Dict, List, Optional import gym import numpy as np -import reagent.core.types as rlt +import reagent.types as rlt import torch from reagent.core.types import RewardOptions from reagent.evaluation.world_model_evaluator import ( @@ -21,21 +21,14 @@ from reagent.gym.envs.pomdp.state_embed_env import StateEmbedEnvironment from reagent.gym.preprocessors import make_replay_buffer_trainer_preprocessor from reagent.gym.runners.gymrunner import evaluate_for_n_episodes from reagent.gym.utils import build_normalizer, fill_replay_buffer -from reagent.model_managers.union import ModelManager__Union from reagent.models.world_model import MemoryNetwork from reagent.replay_memory.circular_replay_buffer import ReplayBuffer from reagent.test.base.horizon_test_base import HorizonTestBase from reagent.training.world_model.mdnrnn_trainer import MDNRNNTrainer +from reagent.workflow.model_managers.union import ModelManager__Union from tqdm import tqdm -try: - # Use internal runner or OSS otherwise - from reagent.runners.fb.fb_batch_runner import FbBatchRunner as BatchRunner -except ImportError: - from reagent.runners.oss_batch_runner import OssBatchRunner as BatchRunner - - logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) @@ -156,7 +149,7 @@ def train_mdnrnn( batch_size=batch_size ) preprocessed_test_batch = trainer_preprocessor(test_batch) - valid_losses = trainer.compute_loss(preprocessed_test_batch) + valid_losses = trainer.get_loss(preprocessed_test_batch) print_mdnrnn_losses(epoch, "validation", valid_losses) trainer.memory_network.mdnrnn.train() return trainer @@ -178,13 +171,11 @@ def train_mdnrnn_and_compute_feature_stats( env.seed(SEED) manager = model.value - runner = BatchRunner( - use_gpu, - manager, + trainer = manager.initialize_trainer( + use_gpu=use_gpu, reward_options=RewardOptions(), normalization_data_map=build_normalizer(env), ) - trainer = cast(MDNRNNTrainer, runner.initialize_trainer()) device = "cuda" if use_gpu else "cpu" # pyre-fixme[6]: Expected `device` for 2nd param but got `str`. @@ -297,13 +288,11 @@ def train_mdnrnn_and_train_on_embedded_env( env.seed(SEED) embedding_manager = embedding_model.value - embedding_runner = BatchRunner( - use_gpu, - embedding_manager, + embedding_trainer = embedding_manager.initialize_trainer( + use_gpu=use_gpu, reward_options=RewardOptions(), normalization_data_map=build_normalizer(env), ) - embedding_trainer = cast(MDNRNNTrainer, embedding_runner.initialize_trainer()) device = "cuda" if use_gpu else "cpu" embedding_trainer_preprocessor = make_replay_buffer_trainer_preprocessor( @@ -347,14 +336,13 @@ def train_mdnrnn_and_train_on_embedded_env( state_max_value=state_max, ) agent_manager = train_model.value - agent_trainer = agent_manager.build_trainer( + agent_trainer = agent_manager.initialize_trainer( use_gpu=use_gpu, reward_options=RewardOptions(), # pyre-fixme[6]: Expected `EnvWrapper` for 1st param but got # `StateEmbedEnvironment`. normalization_data_map=build_normalizer(embed_env), ) - agent_trainer.reporter = agent_manager.get_reporter() device = "cuda" if use_gpu else "cpu" agent_trainer_preprocessor = make_replay_buffer_trainer_preprocessor( agent_trainer, @@ -371,7 +359,7 @@ def train_mdnrnn_and_train_on_embedded_env( # evaluate model rewards = [] - policy = agent_manager.create_policy(agent_trainer) + policy = agent_manager.create_policy(serving=False) # pyre-fixme[6]: Expected `EnvWrapper` for 1st param but got # `StateEmbedEnvironment`. agent = Agent.create_for_env(embed_env, policy=policy, device=device) diff --git a/reagent/gym/types.py b/reagent/gym/types.py index 3a5ccee8..a068db9e 100644 --- a/reagent/gym/types.py +++ b/reagent/gym/types.py @@ -9,7 +9,7 @@ from dataclasses import asdict, dataclass, field, fields from typing import Any, Callable, Dict, List, Optional, Union import numpy as np -import reagent.core.types as rlt +import reagent.types as rlt import torch diff --git a/reagent/json_serialize.py b/reagent/json_serialize.py index b31f81c9..7169308e 100644 --- a/reagent/json_serialize.py +++ b/reagent/json_serialize.py @@ -1,9 +1,10 @@ #!/usr/bin/env python3 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +import collections import json import logging -from dataclasses import asdict, fields, is_dataclass +from dataclasses import asdict, dataclass, fields, is_dataclass from typing import Any, NamedTuple, Type, Union diff --git a/reagent/model_managers/discrete_dqn_base.py b/reagent/model_managers/discrete_dqn_base.py deleted file mode 100644 index 2a854e07..00000000 --- a/reagent/model_managers/discrete_dqn_base.py +++ /dev/null @@ -1,148 +0,0 @@ -#!/usr/bin/env python3 - -import logging -from typing import Dict, List, Optional, Tuple - -from reagent.core import types as rlt -from reagent.core.dataclasses import dataclass, field -from reagent.core.types import ( - Dataset, - PreprocessingOptions, - ReaderOptions, - RewardOptions, - TableSpec, -) -from reagent.core.union import ModelFeatureConfigProvider__Union -from reagent.data_fetchers.data_fetcher import DataFetcher -from reagent.evaluation.evaluator import Evaluator, get_metrics_to_score -from reagent.gym.policies.policy import Policy -from reagent.gym.policies.samplers.discrete_sampler import SoftmaxActionSampler -from reagent.gym.policies.scorers.discrete_scorer import discrete_dqn_scorer -from reagent.model_managers.model_manager import ModelManager -from reagent.models.base import ModelBase -from reagent.models.model_feature_config_provider import RawModelFeatureConfigProvider -from reagent.parameters import EvaluationParameters, NormalizationData, NormalizationKey -from reagent.preprocessing.batch_preprocessor import ( - BatchPreprocessor, - DiscreteDqnBatchPreprocessor, -) -from reagent.preprocessing.preprocessor import Preprocessor -from reagent.preprocessing.types import InputColumn -from reagent.reporting.discrete_dqn_reporter import DiscreteDQNReporter - - -logger = logging.getLogger(__name__) - - -@dataclass -class DiscreteDQNBase(ModelManager): - target_action_distribution: Optional[List[float]] = None - state_feature_config_provider: ModelFeatureConfigProvider__Union = field( - # pyre-fixme[28]: Unexpected keyword argument `raw`. - # pyre-fixme[28]: Unexpected keyword argument `raw`. - default_factory=lambda: ModelFeatureConfigProvider__Union( - raw=RawModelFeatureConfigProvider(float_feature_infos=[]) - ) - ) - eval_parameters: EvaluationParameters = field(default_factory=EvaluationParameters) - preprocessing_options: Optional[PreprocessingOptions] = None - reader_options: Optional[ReaderOptions] = None - - def __post_init_post_parse__(self): - super().__init__() - - def create_policy(self, trainer) -> Policy: - """ Create an online DiscreteDQN Policy from env. """ - sampler = SoftmaxActionSampler(temperature=self.trainer_param.rl.temperature) - scorer = discrete_dqn_scorer(trainer.q_network) - return Policy(scorer=scorer, sampler=sampler) - - @property - def state_feature_config(self) -> rlt.ModelFeatureConfig: - return self.state_feature_config_provider.value.get_model_feature_config() - - def metrics_to_score(self, reward_options: RewardOptions) -> List[str]: - return get_metrics_to_score(reward_options.metric_reward_values) - - @property - def should_generate_eval_dataset(self) -> bool: - return self.eval_parameters.calc_cpe_in_training - - @property - def required_normalization_keys(self) -> List[str]: - return [NormalizationKey.STATE] - - def run_feature_identification( - self, data_fetcher: DataFetcher, input_table_spec: TableSpec - ) -> Dict[str, NormalizationData]: - preprocessing_options = self.preprocessing_options or PreprocessingOptions() - logger.info("Overriding whitelist_features") - state_features = [ - ffi.feature_id for ffi in self.state_feature_config.float_feature_infos - ] - preprocessing_options = preprocessing_options._replace( - whitelist_features=state_features - ) - return { - NormalizationKey.STATE: NormalizationData( - dense_normalization_parameters=data_fetcher.identify_normalization_parameters( - input_table_spec, InputColumn.STATE_FEATURES, preprocessing_options - ) - ) - } - - def query_data( - self, - data_fetcher: DataFetcher, - input_table_spec: TableSpec, - sample_range: Optional[Tuple[float, float]], - reward_options: RewardOptions, - ) -> Dataset: - return data_fetcher.query_data( - input_table_spec=input_table_spec, - discrete_action=True, - actions=self.trainer_param.actions, - include_possible_actions=True, - sample_range=sample_range, - custom_reward_expression=reward_options.custom_reward_expression, - multi_steps=self.multi_steps, - gamma=self.trainer_param.rl.gamma, - ) - - @property - def multi_steps(self) -> Optional[int]: - return self.trainer_param.rl.multi_steps - - def build_batch_preprocessor( - self, - reader_options: ReaderOptions, - use_gpu: bool, - batch_size: int, - normalization_data_map: Dict[str, NormalizationData], - reward_options: RewardOptions, - ) -> BatchPreprocessor: - state_preprocessor = Preprocessor( - normalization_data_map[ - NormalizationKey.STATE - ].dense_normalization_parameters, - use_gpu=use_gpu, - ) - return DiscreteDqnBatchPreprocessor( - num_actions=len(self.trainer_param.actions), - state_preprocessor=state_preprocessor, - use_gpu=use_gpu, - ) - - def get_reporter(self): - return DiscreteDQNReporter( - self.trainer_param.actions, - target_action_distribution=self.target_action_distribution, - ) - - def get_evaluator(self, trainer, reward_options: RewardOptions): - return Evaluator( - self.trainer_param.actions, - self.trainer_param.rl.gamma, - trainer, - metrics_to_score=self.metrics_to_score(reward_options), - ) diff --git a/reagent/model_managers/model_manager.py b/reagent/model_managers/model_manager.py deleted file mode 100644 index 4995992d..00000000 --- a/reagent/model_managers/model_manager.py +++ /dev/null @@ -1,130 +0,0 @@ -#!/usr/bin/env python3 - -import abc -import logging -from typing import Dict, List, Optional, Tuple - -import torch -from reagent.core.registry_meta import RegistryMeta -from reagent.core.types import Dataset, ReaderOptions, RewardOptions, TableSpec -from reagent.data_fetchers.data_fetcher import DataFetcher -from reagent.gym.policies.policy import Policy -from reagent.gym.policies.predictor_policies import create_predictor_policy_from_model -from reagent.parameters import NormalizationData -from reagent.preprocessing.batch_preprocessor import BatchPreprocessor -from reagent.training.trainer import Trainer - - -logger = logging.getLogger(__name__) - - -class ModelManager(metaclass=RegistryMeta): - """ - ModelManager manages how to train models. - - Each type of models can have their own config type, implemented as - `config_type()` class method. `__init__()` of the concrete class must take - this type. - - ModelManager abstracts over common phases of training, i.e.,: - 1. `run_feature_identification()` defines how to derive feature preprocessing - parameters from given data. - 2. `query_data()` massages the input table into the format expected by the trainer - 3. `initialize_trainer()` creates the trainer - 4. `train()` - 5. `build_serving_module()` builds the module for prediction - 6. `save_trainer()` saves the trainer for warmstarting - """ - - @abc.abstractmethod - def run_feature_identification( - self, data_fetcher: DataFetcher, input_table_spec: TableSpec - ) -> Dict[str, NormalizationData]: - """ - Derive preprocessing parameters from data. The keys of the dict should - match the keys from `required_normalization_keys()` - """ - pass - - @property - @abc.abstractmethod - def required_normalization_keys(self) -> List[str]: - """ Get the normalization keys required for current instance """ - pass - - @property - @abc.abstractmethod - def should_generate_eval_dataset(self) -> bool: - raise NotImplementedError() - - def get_evaluator(self, trainer, reward_options: RewardOptions): - return None - - @abc.abstractmethod - def query_data( - self, - data_fetcher: DataFetcher, - input_table_spec: TableSpec, - sample_range: Optional[Tuple[float, float]], - reward_options: RewardOptions, - ) -> Dataset: - """ - Massage input table into the format expected by the trainer - """ - pass - - @abc.abstractmethod - def get_reporter(self): - """ - Get the reporter that displays statistics after training - """ - pass - - @abc.abstractmethod - def build_batch_preprocessor( - self, - reader_options: ReaderOptions, - use_gpu: bool, - batch_size: int, - normalization_data_map: Dict[str, NormalizationData], - reward_options: RewardOptions, - ) -> BatchPreprocessor: - """ - The Batch Preprocessor is a module that transforms data to a form that can be (1) read by the trainer - or (2) used in part of the serving module. For training, the batch preprocessor is typically run - on reader machines in parallel so the GPUs on the trainer machines can be fully utilized. - """ - pass - - @abc.abstractmethod - def build_trainer( - self, - use_gpu: bool, - normalization_data_map: Dict[str, NormalizationData], - reward_options: RewardOptions, - ) -> Trainer: - """ - Implement this to build the trainer, given the config - """ - pass - - def create_policy(self, trainer) -> Policy: - """ Create a Policy from env. """ - raise NotImplementedError() - - def create_serving_policy( - self, normalization_data_map: Dict[str, NormalizationData], trainer - ) -> Policy: - """ Create an online Policy from env. """ - return create_predictor_policy_from_model( - self.build_serving_module(normalization_data_map, trainer) - ) - - @abc.abstractmethod - def build_serving_module( - self, normalization_data_map: Dict[str, NormalizationData], trainer - ) -> torch.nn.Module: - """ - Returns TorchScript module to be used in predictor - """ - pass diff --git a/reagent/model_managers/parametric/parametric_dqn.py b/reagent/model_managers/parametric/parametric_dqn.py deleted file mode 100644 index ddf0b929..00000000 --- a/reagent/model_managers/parametric/parametric_dqn.py +++ /dev/null @@ -1,96 +0,0 @@ -#!/usr/bin/env python3 - -import logging -from typing import Dict - -import torch -from reagent.core.dataclasses import dataclass, field -from reagent.core.types import RewardOptions -from reagent.model_managers.parametric_dqn_base import ParametricDQNBase -from reagent.net_builder.parametric_dqn.fully_connected import FullyConnected -from reagent.net_builder.unions import ParametricDQNNetBuilder__Union -from reagent.parameters import NormalizationData, NormalizationKey, param_hash -from reagent.preprocessing.normalization import ( - get_feature_config, - get_num_output_features, -) -from reagent.training import ParametricDQNTrainer, ParametricDQNTrainerParameters - - -logger = logging.getLogger(__name__) - - -@dataclass -class ParametricDQN(ParametricDQNBase): - __hash__ = param_hash - - trainer_param: ParametricDQNTrainerParameters = field( - default_factory=ParametricDQNTrainerParameters - ) - net_builder: ParametricDQNNetBuilder__Union = field( - # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`. - # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`. - default_factory=lambda: ParametricDQNNetBuilder__Union( - FullyConnected=FullyConnected() - ) - ) - - def __post_init_post_parse__(self): - super().__post_init_post_parse__() - - def build_trainer( - self, - use_gpu: bool, - normalization_data_map: Dict[str, NormalizationData], - reward_options: RewardOptions, - ) -> ParametricDQNTrainer: - net_builder = self.net_builder.value - q_network = net_builder.build_q_network( - normalization_data_map[NormalizationKey.STATE], - normalization_data_map[NormalizationKey.ACTION], - ) - # Metrics + reward - reward_output_dim = len(self.metrics_to_score(reward_options)) + 1 - reward_network = net_builder.build_q_network( - normalization_data_map[NormalizationKey.STATE], - normalization_data_map[NormalizationKey.ACTION], - output_dim=reward_output_dim, - ) - - if use_gpu: - q_network = q_network.cuda() - reward_network = reward_network.cuda() - - q_network_target = q_network.get_target_network() - trainer = ParametricDQNTrainer( - q_network=q_network, - q_network_target=q_network_target, - reward_network=reward_network, - use_gpu=use_gpu, - # pyre-fixme[16]: `ParametricDQNTrainerParameters` has no attribute - # `asdict`. - # pyre-fixme[16]: `ParametricDQNTrainerParameters` has no attribute - # `asdict`. - **self.trainer_param.asdict(), - ) - - # HACK: injecting num_actions to build policies for gym - trainer.num_gym_actions = get_num_output_features( - normalization_data_map[ - NormalizationKey.ACTION - ].dense_normalization_parameters - ) - - return trainer - - def build_serving_module( - self, - normalization_data_map: Dict[str, NormalizationData], - trainer: ParametricDQNTrainer, - ) -> torch.nn.Module: - net_builder = self.net_builder.value - return net_builder.build_serving_module( - trainer.q_network, - normalization_data_map[NormalizationKey.STATE], - normalization_data_map[NormalizationKey.ACTION], - ) diff --git a/reagent/models/actor.py b/reagent/models/actor.py index 4858ded0..c08782dd 100644 --- a/reagent/models/actor.py +++ b/reagent/models/actor.py @@ -5,7 +5,7 @@ import math from typing import List, Optional import torch -from reagent.core import types as rlt +from reagent import types as rlt from reagent.models.base import ModelBase from reagent.models.fully_connected_network import FullyConnectedNetwork from reagent.parameters import CONTINUOUS_TRAINING_ACTION_RANGE diff --git a/reagent/models/base.py b/reagent/models/base.py index 539e1d34..a7ce445d 100644 --- a/reagent/models/base.py +++ b/reagent/models/base.py @@ -5,7 +5,7 @@ from copy import deepcopy from typing import Any, Optional import torch.nn as nn -from reagent.core import types as rlt +from reagent import types as rlt # add ABCMeta once https://github.com/sphinx-doc/sphinx/issues/5995 is fixed diff --git a/reagent/models/categorical_dqn.py b/reagent/models/categorical_dqn.py index e859759d..f0dce217 100644 --- a/reagent/models/categorical_dqn.py +++ b/reagent/models/categorical_dqn.py @@ -3,7 +3,7 @@ import torch import torch.nn.functional as F -from reagent.core import types as rlt +from reagent import types as rlt from reagent.models.base import ModelBase diff --git a/reagent/models/cem_planner.py b/reagent/models/cem_planner.py index 741fd619..dafdb301 100644 --- a/reagent/models/cem_planner.py +++ b/reagent/models/cem_planner.py @@ -17,7 +17,7 @@ import numpy as np import scipy.stats as stats import torch import torch.nn as nn -from reagent.core import types as rlt +from reagent import types as rlt from reagent.models.base import ModelBase from reagent.models.world_model import MemoryNetwork from reagent.parameters import CONTINUOUS_TRAINING_ACTION_RANGE diff --git a/reagent/models/critic.py b/reagent/models/critic.py index dd32cb37..5d570c55 100644 --- a/reagent/models/critic.py +++ b/reagent/models/critic.py @@ -4,7 +4,7 @@ from typing import List import torch -from reagent.core import types as rlt +from reagent import types as rlt from reagent.models.base import ModelBase from reagent.models.fully_connected_network import FullyConnectedNetwork diff --git a/reagent/models/dqn.py b/reagent/models/dqn.py index 4ad90754..61d7c2b3 100644 --- a/reagent/models/dqn.py +++ b/reagent/models/dqn.py @@ -4,7 +4,7 @@ from typing import Optional import torch -from reagent.core import types as rlt +from reagent import types as rlt from reagent.models.base import ModelBase from reagent.models.fully_connected_network import FullyConnectedNetwork diff --git a/reagent/models/dueling_q_network.py b/reagent/models/dueling_q_network.py index fd5f23ab..3681a9f6 100644 --- a/reagent/models/dueling_q_network.py +++ b/reagent/models/dueling_q_network.py @@ -5,7 +5,7 @@ import logging from typing import List, Optional, Tuple import torch -from reagent.core import types as rlt +from reagent import types as rlt from reagent.models.base import ModelBase from reagent.models.critic import FullyConnectedCritic from reagent.models.dqn import FullyConnectedDQN diff --git a/reagent/models/embedding_bag_concat.py b/reagent/models/embedding_bag_concat.py index a4e3ec76..bfb1a8cf 100644 --- a/reagent/models/embedding_bag_concat.py +++ b/reagent/models/embedding_bag_concat.py @@ -4,7 +4,7 @@ from typing import Dict, List import torch -from reagent.core import types as rlt +from reagent import types as rlt from reagent.models.base import ModelBase diff --git a/reagent/models/mdn_rnn.py b/reagent/models/mdn_rnn.py index caf1a667..5aed52cb 100644 --- a/reagent/models/mdn_rnn.py +++ b/reagent/models/mdn_rnn.py @@ -8,7 +8,7 @@ import numpy as np import torch import torch.nn as nn import torch.nn.functional as f -from reagent.core import types as rlt +from reagent import types as rlt from reagent.torch_utils import stack from torch.distributions.normal import Normal diff --git a/reagent/models/model_feature_config_provider.py b/reagent/models/model_feature_config_provider.py index b885e650..c711d69e 100644 --- a/reagent/models/model_feature_config_provider.py +++ b/reagent/models/model_feature_config_provider.py @@ -2,7 +2,7 @@ import abc -import reagent.core.types as rlt +import reagent.types as rlt from reagent.core.dataclasses import dataclass from reagent.core.registry_meta import RegistryMeta diff --git a/reagent/models/seq2reward_model.py b/reagent/models/seq2reward_model.py index a67cde98..319144ee 100644 --- a/reagent/models/seq2reward_model.py +++ b/reagent/models/seq2reward_model.py @@ -3,7 +3,7 @@ import torch import torch.nn as nn -from reagent.core import types as rlt +from reagent import types as rlt from reagent.models.base import ModelBase diff --git a/reagent/models/seq2slate.py b/reagent/models/seq2slate.py index 522da13d..c21a7ccf 100644 --- a/reagent/models/seq2slate.py +++ b/reagent/models/seq2slate.py @@ -10,7 +10,7 @@ import numpy as np import torch import torch.nn as nn import torch.nn.functional as F -from reagent.core import types as rlt +from reagent import types as rlt from reagent.models.base import ModelBase from torch.nn.parallel.distributed import DistributedDataParallel diff --git a/reagent/models/seq2slate_reward.py b/reagent/models/seq2slate_reward.py index cfe456fd..68c2ac12 100644 --- a/reagent/models/seq2slate_reward.py +++ b/reagent/models/seq2slate_reward.py @@ -6,7 +6,7 @@ import logging import torch import torch.nn as nn import torch.nn.functional as F -from reagent.core import types as rlt +from reagent import types as rlt from reagent.models.base import ModelBase from reagent.models.seq2slate import ( DECODER_START_SYMBOL, diff --git a/reagent/models/world_model.py b/reagent/models/world_model.py index 6f6fd6ef..e6beabd8 100644 --- a/reagent/models/world_model.py +++ b/reagent/models/world_model.py @@ -2,7 +2,7 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. import torch -from reagent.core import types as rlt +from reagent import types as rlt from reagent.models.base import ModelBase from reagent.models.mdn_rnn import MDNRNN diff --git a/reagent/net_builder/categorical_dqn_net_builder.py b/reagent/net_builder/categorical_dqn_net_builder.py index 164c5034..7125d6bc 100644 --- a/reagent/net_builder/categorical_dqn_net_builder.py +++ b/reagent/net_builder/categorical_dqn_net_builder.py @@ -3,7 +3,7 @@ import abc from typing import List -import reagent.core.types as rlt +import reagent.types as rlt import torch from reagent.core.fb_checker import IS_FB_ENVIRONMENT from reagent.core.registry_meta import RegistryMeta diff --git a/reagent/net_builder/discrete_dqn/dueling.py b/reagent/net_builder/discrete_dqn/dueling.py index 07d412af..fc2fe4b2 100644 --- a/reagent/net_builder/discrete_dqn/dueling.py +++ b/reagent/net_builder/discrete_dqn/dueling.py @@ -2,7 +2,7 @@ from typing import List -from reagent.core import types as rlt +from reagent import types as rlt from reagent.core.dataclasses import dataclass, field from reagent.models.base import ModelBase from reagent.models.dueling_q_network import DuelingQNetwork diff --git a/reagent/net_builder/discrete_dqn/fully_connected.py b/reagent/net_builder/discrete_dqn/fully_connected.py index 33000f69..fa2d033a 100644 --- a/reagent/net_builder/discrete_dqn/fully_connected.py +++ b/reagent/net_builder/discrete_dqn/fully_connected.py @@ -2,7 +2,7 @@ from typing import List -from reagent.core import types as rlt +from reagent import types as rlt from reagent.core.dataclasses import dataclass, field from reagent.models.base import ModelBase from reagent.models.dqn import FullyConnectedDQN diff --git a/reagent/net_builder/discrete_dqn/fully_connected_with_embedding.py b/reagent/net_builder/discrete_dqn/fully_connected_with_embedding.py index 2c95b40c..6795ff1c 100644 --- a/reagent/net_builder/discrete_dqn/fully_connected_with_embedding.py +++ b/reagent/net_builder/discrete_dqn/fully_connected_with_embedding.py @@ -3,7 +3,7 @@ from typing import List import reagent.models as models -from reagent.core import types as rlt +from reagent import types as rlt from reagent.core.dataclasses import dataclass, field from reagent.net_builder.discrete_dqn_net_builder import DiscreteDQNNetBuilder from reagent.parameters import NormalizationData, param_hash diff --git a/reagent/net_builder/discrete_dqn_net_builder.py b/reagent/net_builder/discrete_dqn_net_builder.py index b86e71e3..5acd0b62 100644 --- a/reagent/net_builder/discrete_dqn_net_builder.py +++ b/reagent/net_builder/discrete_dqn_net_builder.py @@ -3,7 +3,7 @@ import abc from typing import List -import reagent.core.types as rlt +import reagent.types as rlt import torch from reagent.core.fb_checker import IS_FB_ENVIRONMENT from reagent.core.registry_meta import RegistryMeta diff --git a/reagent/net_builder/quantile_dqn_net_builder.py b/reagent/net_builder/quantile_dqn_net_builder.py index 105c390d..d05cf99d 100644 --- a/reagent/net_builder/quantile_dqn_net_builder.py +++ b/reagent/net_builder/quantile_dqn_net_builder.py @@ -3,7 +3,7 @@ import abc from typing import List -import reagent.core.types as rlt +import reagent.types as rlt import torch from reagent.core.fb_checker import IS_FB_ENVIRONMENT from reagent.core.registry_meta import RegistryMeta diff --git a/reagent/ope/estimators/sequential_estimators.py b/reagent/ope/estimators/sequential_estimators.py index 2e46b206..b52b0b5d 100644 --- a/reagent/ope/estimators/sequential_estimators.py +++ b/reagent/ope/estimators/sequential_estimators.py @@ -687,27 +687,15 @@ class NeuralDualDICE(RLEstimator): ), "Expected all fields to be present" tgt_dist = input.target_policy.action_dist(t.state) tgt_action = tgt_dist.sample()[0] - samples["init_state"].append( - state.value.cpu().numpy() - if isinstance(state.value, torch.Tensor) - else state.value - ) + samples["init_state"].append(state.value) samples["init_action"].append( torch.nn.functional.one_hot( torch.tensor(tgt_init_action.value, dtype=torch.long), self.action_dim, ).float() ) - samples["last_state"].append( - t.last_state.value.cpu().numpy() - if isinstance(t.last_state.value, torch.Tensor) - else t.last_state.value - ) - samples["state"].append( - t.state.value.cpu().numpy() - if isinstance(t.state.value, torch.Tensor) - else t.state.value - ) + samples["last_state"].append(t.last_state.value) + samples["state"].append(t.state.value) samples["log_action"].append( torch.nn.functional.one_hot( torch.tensor(t.action.value, dtype=torch.long), self.action_dim diff --git a/reagent/parameters.py b/reagent/parameters.py index 950001f7..635fd8b9 100644 --- a/reagent/parameters.py +++ b/reagent/parameters.py @@ -58,7 +58,6 @@ class MDNRNNTrainerParameters(BaseDataClass): action_dim: int = 2 action_names: List[str] = field(default_factory=lambda: []) multi_steps: int = 1 - shuffle_training_data: bool = False @dataclass(frozen=True) diff --git a/reagent/parameters_seq2slate.py b/reagent/parameters_seq2slate.py index 14784834..d680d82d 100644 --- a/reagent/parameters_seq2slate.py +++ b/reagent/parameters_seq2slate.py @@ -5,7 +5,7 @@ from enum import Enum from typing import Dict, Optional from reagent.core.dataclasses import dataclass -from reagent.core.types import BaseDataClass +from reagent.types import BaseDataClass class LearningMethod(Enum): diff --git a/reagent/prediction/predictor_wrapper.py b/reagent/prediction/predictor_wrapper.py index b0173d46..ea0db9dc 100644 --- a/reagent/prediction/predictor_wrapper.py +++ b/reagent/prediction/predictor_wrapper.py @@ -4,7 +4,7 @@ import logging from typing import Dict, List, Optional, Tuple -import reagent.core.types as rlt +import reagent.types as rlt import torch from reagent.models.base import ModelBase from reagent.models.seq2slate import Seq2SlateMode, Seq2SlateTransformerNet diff --git a/reagent/preprocessing/batch_preprocessor.py b/reagent/preprocessing/batch_preprocessor.py index 37797e3c..b2bfd7f6 100644 --- a/reagent/preprocessing/batch_preprocessor.py +++ b/reagent/preprocessing/batch_preprocessor.py @@ -6,7 +6,7 @@ from typing import Dict import torch import torch.nn as nn import torch.nn.functional as F -from reagent.core import types as rlt +from reagent import types as rlt from reagent.preprocessing.preprocessor import Preprocessor diff --git a/reagent/preprocessing/normalization.py b/reagent/preprocessing/normalization.py index b4426372..d3600926 100644 --- a/reagent/preprocessing/normalization.py +++ b/reagent/preprocessing/normalization.py @@ -7,24 +7,12 @@ from dataclasses import asdict from typing import Dict, List, Optional, Tuple import numpy as np -import reagent.core.types as rlt +import reagent.types as rlt import six import torch from reagent.parameters import NormalizationData, NormalizationParameters from reagent.preprocessing import identify_types from reagent.preprocessing.identify_types import DEFAULT_MAX_UNIQUE_ENUM, FEATURE_TYPES -from reagent.preprocessing.normalization_constants import ( - BOX_COX_MARGIN, - BOX_COX_MAX_STDDEV, - DEFAULT_MAX_QUANTILE_SIZE, - DEFAULT_NUM_SAMPLES, - DEFAULT_QUANTILE_K2_THRESHOLD, - EPS, - MAX_FEATURE_VALUE, - MIN_FEATURE_VALUE, - MINIMUM_SAMPLES_TO_IDENTIFY, - MISSING_VALUE, -) from scipy import stats from scipy.stats.mstats import mquantiles @@ -32,6 +20,18 @@ from scipy.stats.mstats import mquantiles logger = logging.getLogger(__name__) +BOX_COX_MAX_STDDEV = 1e8 +BOX_COX_MARGIN = 1e-4 +MISSING_VALUE = -1337.1337 +DEFAULT_QUANTILE_K2_THRESHOLD = 1000.0 +MINIMUM_SAMPLES_TO_IDENTIFY = 20 +DEFAULT_MAX_QUANTILE_SIZE = 20 +DEFAULT_NUM_SAMPLES = 100000 +MAX_FEATURE_VALUE = 6.0 +MIN_FEATURE_VALUE = MAX_FEATURE_VALUE * -1 +EPS = 1e-6 + + def no_op_feature(): return NormalizationParameters( identify_types.CONTINUOUS, None, 0, 0, 1, None, None, None, None diff --git a/reagent/preprocessing/normalization_constants.py b/reagent/preprocessing/normalization_constants.py deleted file mode 100644 index d2dbc07e..00000000 --- a/reagent/preprocessing/normalization_constants.py +++ /dev/null @@ -1,19 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. - -from reagent.preprocessing.identify_types import ( # noqa - DEFAULT_MAX_UNIQUE_ENUM, - FEATURE_TYPES, -) - - -BOX_COX_MAX_STDDEV = 1e8 -BOX_COX_MARGIN = 1e-4 -MISSING_VALUE = -1337.1337 -DEFAULT_QUANTILE_K2_THRESHOLD = 1000.0 -MINIMUM_SAMPLES_TO_IDENTIFY = 20 -DEFAULT_MAX_QUANTILE_SIZE = 20 -DEFAULT_NUM_SAMPLES = 100000 -MAX_FEATURE_VALUE = 6.0 -MIN_FEATURE_VALUE = MAX_FEATURE_VALUE * -1 -EPS = 1e-6 diff --git a/reagent/preprocessing/sparse_preprocessor.py b/reagent/preprocessing/sparse_preprocessor.py index 268b218e..00e250e9 100644 --- a/reagent/preprocessing/sparse_preprocessor.py +++ b/reagent/preprocessing/sparse_preprocessor.py @@ -4,7 +4,7 @@ import logging from typing import Dict, Tuple -import reagent.core.types as rlt +import reagent.types as rlt import torch diff --git a/reagent/preprocessing/transforms.py b/reagent/preprocessing/transforms.py index fbac6e73..fff4789d 100644 --- a/reagent/preprocessing/transforms.py +++ b/reagent/preprocessing/transforms.py @@ -5,7 +5,7 @@ import logging from typing import Callable, Dict, List, Optional, Tuple import numpy as np -import reagent.core.types as rlt +import reagent.types as rlt import torch import torch.nn.functional as F from reagent.parameters import NormalizationData diff --git a/reagent/publishers/file_system_publisher.py b/reagent/publishers/file_system_publisher.py index 37fe49bd..8d6bc59f 100644 --- a/reagent/publishers/file_system_publisher.py +++ b/reagent/publishers/file_system_publisher.py @@ -6,10 +6,9 @@ from typing import Optional from reagent.core.dataclasses import dataclass from reagent.core.result_types import NoPublishingResults -from reagent.core.rl_training_output import RLTrainingOutput -from reagent.core.types import RecurringPeriod -from reagent.model_managers.model_manager import ModelManager +from reagent.core.types import RecurringPeriod, RLTrainingOutput from reagent.publishers.model_publisher import ModelPublisher +from reagent.workflow.model_managers.model_manager import ModelManager try: @@ -73,7 +72,7 @@ if HAS_TINYDB: child_workflow_id: int, recurring_period: Optional[RecurringPeriod], ) -> NoPublishingResults: - path = training_output.local_output_path + path = training_output.output_path assert path is not None, f"Given path is None." assert os.path.exists(path), f"Given path {path} doesn't exist." Model = Query() diff --git a/reagent/publishers/model_publisher.py b/reagent/publishers/model_publisher.py index ceae6f89..83baa66a 100644 --- a/reagent/publishers/model_publisher.py +++ b/reagent/publishers/model_publisher.py @@ -5,10 +5,9 @@ import inspect from typing import Optional from reagent.core.registry_meta import RegistryMeta -from reagent.core.rl_training_output import RLTrainingOutput -from reagent.core.types import RecurringPeriod -from reagent.model_managers.model_manager import ModelManager -from reagent.reporting.result_registries import PublishingResult +from reagent.core.types import RecurringPeriod, RLTrainingOutput +from reagent.workflow.model_managers.model_manager import ModelManager +from reagent.workflow.result_registries import PublishingResult class ModelPublisher(metaclass=RegistryMeta): @@ -39,7 +38,7 @@ class ModelPublisher(metaclass=RegistryMeta): recurring_period, ) # Avoid circular dependency at import time - from reagent.core.union import PublishingResult__Union + from reagent.core.types import PublishingResult__Union # We need to use inspection because the result can be a future when running on # FBL diff --git a/reagent/publishers/no_publishing.py b/reagent/publishers/no_publishing.py index 670d05d6..1eda17da 100644 --- a/reagent/publishers/no_publishing.py +++ b/reagent/publishers/no_publishing.py @@ -4,10 +4,9 @@ from typing import Optional from reagent.core.dataclasses import dataclass from reagent.core.result_types import NoPublishingResults -from reagent.core.rl_training_output import RLTrainingOutput -from reagent.core.types import RecurringPeriod -from reagent.model_managers.model_manager import ModelManager +from reagent.core.types import RecurringPeriod, RLTrainingOutput from reagent.publishers.model_publisher import ModelPublisher +from reagent.workflow.model_managers.model_manager import ModelManager @dataclass diff --git a/reagent/register.py b/reagent/register.py deleted file mode 100644 index 52d3a489..00000000 --- a/reagent/register.py +++ /dev/null @@ -1,24 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. - -from reagent.core.fb_checker import IS_FB_ENVIRONMENT - - -if True: # To prevent auto sorting of inputs - # Triggering registration to registries - import reagent.core.result_types # noqa - import reagent.reporting.oss_training_reports # noqa - from reagent.model_managers.union import * # noqa - - if IS_FB_ENVIRONMENT: - import reagent.core.fb.fb_result_types # noqa - - # Register all unions - from reagent.core.union import * # noqa - from reagent.model_managers.union import * # noqa - from reagent.optimizer.union import * # noqa - from reagent.publishers.union import * # noqa - from reagent.validators.union import * # noqa - - if IS_FB_ENVIRONMENT: - from reagent.model_managers.fb.union import * # noqa diff --git a/reagent/reporting/__init__.py b/reagent/reporting/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/reagent/reporting/actor_critic_reporter.py b/reagent/reporting/actor_critic_reporter.py deleted file mode 100644 index 96d7a315..00000000 --- a/reagent/reporting/actor_critic_reporter.py +++ /dev/null @@ -1,55 +0,0 @@ -#!/usr/bin/env python3 - -import itertools -import logging - -from reagent.core import aggregators as agg -from reagent.core.rl_training_output import RLTrainingOutput -from reagent.core.union import TrainingReport__Union -from reagent.reporting.oss_training_reports import OssActorCriticTrainingReport -from reagent.reporting.reporter_base import ReporterBase - - -logger = logging.getLogger(__name__) - - -class ActorCriticReporter(ReporterBase): - def __init__(self, report_interval: int = 100): - aggregators = itertools.chain( - [ - ("cpe_results", agg.AppendAggregator("cpe_details")), - ("td_loss", agg.MeanAggregator("td_loss", interval=report_interval)), - ( - "reward_loss", - agg.MeanAggregator("reward_loss", interval=report_interval), - ), - ( - "recent_rewards", - agg.RecentValuesAggregator( - "logged_rewards", interval=report_interval - ), - ), - ], - [ - ( - f"{key}_tb", - agg.TensorBoardHistogramAndMeanAggregator( - key, log_key, interval=report_interval - ), - ) - for key, log_key in [ - ("td_loss", "td_loss"), - ("reward_loss", "reward_loss"), - ("logged_propensities", "propensities/logged"), - ("logged_rewards", "reward/logged"), - ] - ], - ) - super().__init__(aggregators) - - # TODO: T71636196 write this for OSS - def publish(self) -> RLTrainingOutput: - report = OssActorCriticTrainingReport() - return RLTrainingOutput( - training_report=TrainingReport__Union(oss_actor_critic_report=report) - ) diff --git a/reagent/reporting/discrete_dqn_reporter.py b/reagent/reporting/discrete_dqn_reporter.py deleted file mode 100644 index e8f2a89f..00000000 --- a/reagent/reporting/discrete_dqn_reporter.py +++ /dev/null @@ -1,109 +0,0 @@ -#!/usr/bin/env python3 - -import itertools -import logging -from typing import List, Optional - -import torch -from reagent.core import aggregators as agg -from reagent.core.rl_training_output import RLTrainingOutput -from reagent.core.union import TrainingReport__Union -from reagent.reporting.oss_training_reports import OssDQNTrainingReport -from reagent.reporting.reporter_base import ReporterBase - - -logger = logging.getLogger(__name__) - - -class DiscreteDQNReporter(ReporterBase): - def __init__( - self, - actions: List[str], - report_interval: int = 100, - target_action_distribution: Optional[List[float]] = None, - recent_window_size: int = 100, - ): - aggregators = itertools.chain( - [ - ("CPE Results", agg.AppendAggregator("cpe_details")), - ("TD Loss", agg.MeanAggregator("td_loss", interval=report_interval)), - ( - "Reward Loss", - agg.MeanAggregator("reward_loss", interval=report_interval), - ), - ( - "Model Action Values", - agg.FunctionsByActionAggregator( - "model_values", - actions, - {"mean": torch.mean, "std": torch.std}, - interval=report_interval, - ), - ), - ( - "Logged Actions", - agg.ActionCountAggregator( - "logged_actions", actions, interval=report_interval - ), - ), - ( - "model_action", - agg.ActionCountAggregator( - "model_action_idxs", actions, interval=report_interval - ), - ), - ( - "Recent Logged Rewards", - agg.RecentValuesAggregator( - "logged_rewards", interval=report_interval - ), - ), - ], - [ - ( - f"{key}_tb", - agg.TensorBoardActionCountAggregator( - key, title, actions, interval=report_interval - ), - ) - for key, title in [ - ("logged_actions", "logged"), - ("model_action_idxs", "model"), - ] - ], - [ - ( - f"{key}_tb", - agg.TensorBoardHistogramAndMeanAggregator( - key, log_key, interval=report_interval - ), - ) - for key, log_key in [ - ("td_loss", "td_loss"), - ("reward_loss", "reward_loss"), - ("logged_propensities", "propensities/logged"), - ("logged_rewards", "reward/logged"), - ] - ], - [ - ( - f"{key}_tb", - agg.TensorBoardActionHistogramAndMeanAggregator( - key, category, title, actions, interval=report_interval - ), - ) - for key, category, title in [ - ("model_propensities", "propensities", "model"), - ("model_rewards", "reward", "model"), - ("model_values", "value", "model"), - ] - ], - ) - super().__init__(aggregators) - self.target_action_distribution = target_action_distribution - self.recent_window_size = recent_window_size - - def publish(self) -> RLTrainingOutput: - return RLTrainingOutput( - training_report=TrainingReport__Union(oss_dqn_report=OssDQNTrainingReport()) - ) diff --git a/reagent/reporting/oss_training_reports.py b/reagent/reporting/oss_training_reports.py deleted file mode 100644 index 52f9c893..00000000 --- a/reagent/reporting/oss_training_reports.py +++ /dev/null @@ -1,62 +0,0 @@ -#!/usr/bin/env python3 - -from typing import List, Optional - -from reagent.core.dataclasses import dataclass -from reagent.evaluation.cpe import CpeEstimate -from reagent.reporting.training_reports import TrainingReport - - -@dataclass -class OssDQNTrainingReport(TrainingReport): - __registry_name__ = "oss_dqn_report" - - td_loss: Optional[float] = None - mc_loss: Optional[float] = None - reward_ips: Optional[CpeEstimate] = None - reward_dm: Optional[CpeEstimate] = None - reward_dr: Optional[CpeEstimate] = None - value_sequential_dr: Optional[CpeEstimate] = None - value_weighted_dr: Optional[CpeEstimate] = None - value_magic_dr: Optional[CpeEstimate] = None - - -@dataclass -class OssActorCriticTrainingReport(TrainingReport): - __registry_name__ = "oss_actor_critic_report" - - -@dataclass -class OssParametricDQNTrainingReport(TrainingReport): - __registry_name__ = "oss_parametric_dqn_report" - - td_loss: Optional[float] = None - mc_loss: Optional[float] = None - reward_ips: Optional[CpeEstimate] = None - reward_dm: Optional[CpeEstimate] = None - reward_dr: Optional[CpeEstimate] = None - value_sequential_dr: Optional[CpeEstimate] = None - value_weighted_dr: Optional[CpeEstimate] = None - value_magic_dr: Optional[CpeEstimate] = None - - -@dataclass -class OssWorldModelTrainingReport(TrainingReport): - __registry_name__ = "oss_world_model_report" - loss: List[float] - gmm: List[float] - bce: List[float] - mse: List[float] - - -@dataclass -class DebugToolsReport(TrainingReport): - __registry_name__ = "oss_debug_tools_report" - - feature_importance: Optional[List[float]] = None - feature_sensitivity: Optional[List[float]] = None - - -@dataclass -class OssRankingModelTrainingReport(TrainingReport): - __registry_name__ = "oss_ranking_model_training_report" diff --git a/reagent/reporting/parametric_dqn_reporter.py b/reagent/reporting/parametric_dqn_reporter.py deleted file mode 100644 index f348f200..00000000 --- a/reagent/reporting/parametric_dqn_reporter.py +++ /dev/null @@ -1,64 +0,0 @@ -#!/usr/bin/env python3 - -import itertools -import logging -from typing import List, Optional - -from reagent.core import aggregators as agg -from reagent.core.rl_training_output import RLTrainingOutput -from reagent.core.union import TrainingReport__Union -from reagent.reporting.oss_training_reports import OssParametricDQNTrainingReport -from reagent.reporting.reporter_base import ReporterBase - - -logger = logging.getLogger(__name__) - - -class ParametricDQNReporter(ReporterBase): - def __init__( - self, - report_interval: int = 100, - target_action_distribution: Optional[List[float]] = None, - recent_window_size: int = 100, - ): - aggregators = itertools.chain( - [ - ("cpe_results", agg.AppendAggregator("cpe_results")), - ("td_loss", agg.MeanAggregator("td_loss", interval=report_interval)), - ( - "reward_loss", - agg.MeanAggregator("reward_loss", interval=report_interval), - ), - ( - "logged_rewards", - agg.RecentValuesAggregator( - "logged_rewards", interval=report_interval - ), - ), - ], - [ - ( - f"{key}_tb", - agg.TensorBoardHistogramAndMeanAggregator( - key, log_key, interval=report_interval - ), - ) - for key, log_key in [ - ("td_loss", "td_loss"), - ("reward_loss", "reward_loss"), - ("logged_propensities", "propensities/logged"), - ("logged_rewards", "reward/logged"), - ] - ], - ) - super().__init__(aggregators) - self.target_action_distribution = target_action_distribution - self.recent_window_size = recent_window_size - - # TODO: T71636218 write this for OSS - def publish(self) -> RLTrainingOutput: - cpe_results = self.cpe_results.values - report = OssParametricDQNTrainingReport() - return RLTrainingOutput( - training_report=TrainingReport__Union(oss_parametric_dqn_report=report) - ) diff --git a/reagent/reporting/ranking_model_reporter.py b/reagent/reporting/ranking_model_reporter.py deleted file mode 100644 index 3c77de52..00000000 --- a/reagent/reporting/ranking_model_reporter.py +++ /dev/null @@ -1,60 +0,0 @@ -#!/usr/bin/env python3 - -import logging - -from reagent.core import aggregators as agg -from reagent.core.rl_training_output import RLTrainingOutput -from reagent.core.union import TrainingReport__Union -from reagent.reporting.oss_training_reports import OssRankingModelTrainingReport -from reagent.reporting.reporter_base import ReporterBase - - -logger = logging.getLogger(__name__) - - -class RankingModelReporter(ReporterBase): - def __init__(self, report_interval: int = 100): - """ - For Ranking model: - 'pg' (policy gradient loss) - 'baseline' (the baseline model's loss, usually for fitting V(s)) - 'kendall_tau' (kendall_tau coefficient between advantage and log_probs, - used in evaluation page handlers) - 'kendaull_tau_p_value' (the p-value for kendall_tau test, used in - evaluation page handlers) - """ - aggregators = [ - ("pg", agg.MeanAggregator("pg", interval=report_interval)), - ("baseline", agg.MeanAggregator("baseline", interval=report_interval)), - ( - "kendall_tau", - agg.MeanAggregator("kendall_tau", interval=report_interval), - ), - ( - "kendaull_tau_p_value", - agg.MeanAggregator("kendaull_tau_p_value", interval=report_interval), - ), - ] + [ - ( - f"{key}_tb", - agg.TensorBoardHistogramAndMeanAggregator( - key, log_key, interval=report_interval - ), - ) - for key, log_key in [ - ("pg", "pg"), - ("baseline", "baseline"), - ("kendall_tau", "kendall_tau"), - ("kendaull_tau_p_value", "kendaull_tau_p_value"), - ] - ] - super().__init__(aggregators) - - # TODO: T71636236 write this for OSS - def publish(self) -> RLTrainingOutput: - report = OssRankingModelTrainingReport() - return RLTrainingOutput( - training_report=TrainingReport__Union( - oss_ranking_model_training_report=report - ) - ) diff --git a/reagent/reporting/reporter_base.py b/reagent/reporting/reporter_base.py deleted file mode 100644 index ba1f2682..00000000 --- a/reagent/reporting/reporter_base.py +++ /dev/null @@ -1,59 +0,0 @@ -#!/usr/bin/env python3 - -import logging -from collections import OrderedDict -from typing import Any, Dict, List, Tuple - -import torch -from reagent.core import aggregators as agg -from reagent.core.rl_training_output import RLTrainingOutput - - -logger = logging.getLogger(__name__) - - -class ReporterBase: - def __init__(self, aggregators: List[Tuple[str, agg.Aggregator]]): - self.aggregators = OrderedDict(aggregators) - - def report(self, **kwargs: Dict[str, Any]): - for name, value in kwargs.items(): - for aggregator in self.aggregators.values(): - if aggregator.key == name: - aggregator.update(name, value) - - def finish_epoch(self): - for aggregator in self.aggregators.values(): - aggregator.finish_epoch() - - def publish(self) -> RLTrainingOutput: - pass - - def get_recent(self, key: str, count: int, average: bool): - for _, aggregator in self.aggregators.items(): - if aggregator.key == key: - recent = aggregator.aggregator.get_recent(count) - if len(recent) == 0: - return None - if average: - return float(torch.mean(torch.tensor(recent))) - return recent - return None - - def get_all(self, key: str, average: bool): - for _, aggregator in self.aggregators.items(): - if aggregator.key == key: - all_data = aggregator.aggregator.get_all() - if len(all_data) == 0: - return None - if average: - return float(torch.mean(torch.tensor(all_data))) - return all_data - return None - - def __getattr__(self, key: str): - return self.aggregators[key] - - def end_epoch(self): - for aggregator in self.aggregators.values(): - aggregator.end_epoch() diff --git a/reagent/reporting/training_reporter.py b/reagent/reporting/training_reporter.py deleted file mode 100644 index d6e41c67..00000000 --- a/reagent/reporting/training_reporter.py +++ /dev/null @@ -1,363 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. - -import logging -import math -from collections import deque -from typing import Deque, List, NamedTuple, Optional - -import numpy as np -import torch -from reagent.tensorboardX import SummaryWriterContext - - -logger = logging.getLogger(__name__) -logger.setLevel(logging.INFO) - -LOSS_REPORT_INTERVAL = 100 - - -class BatchStats(NamedTuple): - td_loss: Optional[torch.Tensor] = None - reward_loss: Optional[torch.Tensor] = None - imitator_loss: Optional[torch.Tensor] = None - logged_actions: Optional[torch.Tensor] = None - logged_propensities: Optional[torch.Tensor] = None - logged_rewards: Optional[torch.Tensor] = None - logged_values: Optional[torch.Tensor] = None - model_propensities: Optional[torch.Tensor] = None - model_rewards: Optional[torch.Tensor] = None - model_values: Optional[torch.Tensor] = None - model_values_on_logged_actions: Optional[torch.Tensor] = None - model_action_idxs: Optional[torch.Tensor] = None - - def write_summary(self, actions: List[str]): - if actions: - for field, log_key in [ - ("logged_actions", "actions/logged"), - ("model_action_idxs", "actions/model"), - ]: - val = getattr(self, field) - if val is None: - continue - for i, action in enumerate(actions): - # pyre-fixme[16]: `SummaryWriterContext` has no attribute - # `add_scalar`. - SummaryWriterContext.add_scalar( - "{}/{}".format(log_key, action), (val == i).sum().item() - ) - - for field, log_key in [ - ("td_loss", "td_loss"), - ("imitator_loss", "imitator_loss"), - ("reward_loss", "reward_loss"), - ("logged_propensities", "propensities/logged"), - ("logged_rewards", "reward/logged"), - ("logged_values", "value/logged"), - ("model_values_on_logged_actions", "value/model_logged_action"), - ]: - val = getattr(self, field) - if val is None: - continue - assert len(val.shape) == 1 or ( - len(val.shape) == 2 and val.shape[1] == 1 - ), "Unexpected shape for {}: {}".format(field, val.shape) - self._log_histogram_and_mean(log_key, val) - - for field, log_key in [ - ("model_propensities", "propensities/model"), - ("model_rewards", "reward/model"), - ("model_values", "value/model"), - ]: - val = getattr(self, field) - if val is None: - continue - if ( - len(val.shape) == 1 or (len(val.shape) == 2 and val.shape[1] == 1) - ) and not actions: - self._log_histogram_and_mean(log_key, val) - elif len(val.shape) == 2 and val.shape[1] == len(actions): - for i, action in enumerate(actions): - self._log_histogram_and_mean(f"{log_key}/{action}", val[:, i]) - else: - raise ValueError( - "Unexpected shape for {}: {}; actions: {}".format( - field, val.shape, actions - ) - ) - - def _log_histogram_and_mean(self, log_key, val): - try: - SummaryWriterContext.add_histogram(log_key, val) - SummaryWriterContext.add_scalar(f"{log_key}/mean", val.mean()) - except ValueError: - logger.warning( - f"Cannot create histogram for key: {log_key}; " - "this is likely because you have NULL value in your input; " - f"value: {val}" - ) - raise - - @staticmethod - def add_custom_scalars(action_names: Optional[List[str]]): - if not action_names: - return - - SummaryWriterContext.add_custom_scalars_multilinechart( - [ - "propensities/model/{}/mean".format(action_name) - for action_name in action_names - ], - category="propensities", - title="model", - ) - SummaryWriterContext.add_custom_scalars_multilinechart( - [ - "propensities/logged/{}/mean".format(action_name) - for action_name in action_names - ], - category="propensities", - title="logged", - ) - SummaryWriterContext.add_custom_scalars_multilinechart( - ["actions/logged/{}".format(action_name) for action_name in action_names], - category="actions", - title="logged", - ) - SummaryWriterContext.add_custom_scalars_multilinechart( - ["actions/model/{}".format(action_name) for action_name in action_names], - category="actions", - title="model", - ) - - -def merge_tensor_namedtuple_list(l, cls): - def merge_tensor(f): - vals = [getattr(e, f) for e in l] - not_none_vals = [v for v in vals if v is not None] - assert len(not_none_vals) == 0 or len(not_none_vals) == len(vals) - if not not_none_vals: - return None - return torch.cat(not_none_vals, dim=0) - - return cls(**{f: merge_tensor(f) for f in cls._fields}) - - -class StatsByAction(object): - def __init__(self, actions): - self.stats = {action: [] for action in actions} - - def append(self, stats): - for k in stats: - assert k in self.stats - for k in self.stats: - v = stats.get(k, 0) - if isinstance(v, torch.Tensor): - v = v.item() - self.stats[k].append(v) - - def items(self): - return self.stats.items() - - def __len__(self): - return len(self.stats) - - -class NoOpTrainingReporter: - def report(self, **kwargs): - pass - - def flush(self): - pass - - -class TrainingReporter(object): - RECENT_WINDOW_SIZE = 100 - - def __init__(self, action_names: Optional[List[str]] = None): - assert action_names is None or len(action_names) > 0 - self.action_names: List[str] = action_names or [] - self.loss_report_interval = LOSS_REPORT_INTERVAL - BatchStats.add_custom_scalars(action_names) - self.clear() - - def clear(self): - self.running_reward: Deque[float] = deque(maxlen=int(1e6)) - - self.td_loss: List[float] = [] - self.reward_loss: List[float] = [] - self.imitator_loss: List[float] = [] - self.logged_action_q_value: List[float] = [] - self.logged_action_counts = {action: 0 for action in self.action_names} - self.model_values = StatsByAction(self.action_names) - self.model_value_stds = StatsByAction(self.action_names) - self.model_action_counts = StatsByAction(self.action_names) - self.model_action_counts_cumulative = { - action: 0 for action in self.action_names - } - self.model_action_distr = StatsByAction(self.action_names) - - self.incoming_stats: List[BatchStats] = [] - - @property - def num_batches(self): - return len(self.td_loss) - - def report(self, **kwargs): - def _to_tensor(v): - if v is None: - return None - if not isinstance(v, torch.Tensor): - v = torch.tensor(v) - if len(v.shape) == 0: - v = v.reshape(1) - return v.detach().cpu() - - kwargs = {k: _to_tensor(v) for k, v in kwargs.items()} - batch_stats = BatchStats(**kwargs) - self.incoming_stats.append(batch_stats) - if len(self.incoming_stats) >= self.loss_report_interval: - self.flush() - - @torch.no_grad() - def flush(self): - if not len(self.incoming_stats): - logger.info("Nothing to report") - return - - logger.info("Loss on {} batches".format(len(self.incoming_stats))) - - batch_stats = merge_tensor_namedtuple_list(self.incoming_stats, BatchStats) - batch_stats.write_summary(self.action_names) - - print_details = "Loss:\n" - - td_loss_mean = float(batch_stats.td_loss.mean()) - self.td_loss.append(td_loss_mean) - print_details = print_details + "TD LOSS: {0:.3f}\n".format(td_loss_mean) - - if batch_stats.logged_rewards is not None: - flattened_rewards = torch.flatten(batch_stats.logged_rewards).tolist() - self.running_reward.extend(flattened_rewards) - - if batch_stats.reward_loss is not None: - reward_loss_mean = float(batch_stats.reward_loss.mean()) - self.reward_loss.append(reward_loss_mean) - print_details = print_details + "REWARD LOSS: {0:.3f}\n".format( - reward_loss_mean - ) - - if batch_stats.imitator_loss is not None: - imitator_loss_mean = float(batch_stats.imitator_loss.mean()) - self.imitator_loss.append(imitator_loss_mean) - print_details = print_details + "IMITATOR LOSS: {0:.3f}\n".format( - imitator_loss_mean - ) - - if batch_stats.model_values is not None and self.action_names: - self.model_values.append( - dict(zip(self.action_names, batch_stats.model_values.mean(dim=0))) - ) - self.model_value_stds.append( - dict(zip(self.action_names, batch_stats.model_values.std(dim=0))) - ) - - if batch_stats.model_values_on_logged_actions is not None: - self.logged_action_q_value.append( - batch_stats.model_values_on_logged_actions.mean().item() - ) - - if ( - batch_stats.logged_actions is not None - and batch_stats.model_action_idxs is not None - ): - logged_action_counts = { - action: (batch_stats.logged_actions == i).sum().item() - for i, action in enumerate(self.action_names) - } - model_action_counts = { - action: (batch_stats.model_action_idxs == i).sum().item() - for i, action in enumerate(self.action_names) - } - print_details += "The distribution of logged actions : {}\n".format( - logged_action_counts - ) - print_details += "The distribution of model actions : {}\n".format( - model_action_counts - ) - for action, count in logged_action_counts.items(): - self.logged_action_counts[action] += count - - self.model_action_counts.append(model_action_counts) - - for action, count in model_action_counts.items(): - self.model_action_counts_cumulative[action] += count - - total = float(sum(model_action_counts.values())) - self.model_action_distr.append( - {action: count / total for action, count in model_action_counts.items()} - ) - - print_details += "Batch Evaluator Finished" - for print_detail in print_details.split("\n"): - logger.info(print_detail) - - self.incoming_stats.clear() - - def get_td_loss_after_n(self, n): - return self.td_loss[n:] - - def get_recent_td_loss(self): - return TrainingReporter.calculate_recent_window_average( - self.td_loss, TrainingReporter.RECENT_WINDOW_SIZE, num_entries=1 - ) - - def get_recent_reward_loss(self): - return TrainingReporter.calculate_recent_window_average( - self.reward_loss, TrainingReporter.RECENT_WINDOW_SIZE, num_entries=1 - ) - - def get_recent_imitator_loss(self): - return TrainingReporter.calculate_recent_window_average( - self.imitator_loss, TrainingReporter.RECENT_WINDOW_SIZE, num_entries=1 - ) - - def get_logged_action_distribution(self): - total_actions = 1.0 * sum(self.logged_action_counts.values()) - return {k: (v / total_actions) for k, v in self.logged_action_counts.items()} - - def get_model_action_distribution(self): - total_actions = 1.0 * sum(self.model_action_counts_cumulative.values()) - return { - k: (v / total_actions) - for k, v in self.model_action_counts_cumulative.items() - } - - def get_recent_rewards(self): - return self.running_reward - - def log_to_tensorboard(self, epoch: int) -> None: - def none_to_zero(x: Optional[float]) -> float: - if x is None or math.isnan(x): - return 0.0 - return x - - for name, value in [ - ("Training/td_loss", self.get_recent_td_loss()), - ("Training/reward_loss", self.get_recent_reward_loss()), - ("Training/imitator_loss", self.get_recent_imitator_loss()), - ]: - # pyre-fixme[16]: `SummaryWriterContext` has no attribute `add_scalar`. - SummaryWriterContext.add_scalar(name, none_to_zero(value), epoch) - - @staticmethod - def calculate_recent_window_average(arr, window_size, num_entries): - if len(arr) > 0: - begin = max(0, len(arr) - window_size) - return np.mean(np.array(arr[begin:]), axis=0) - else: - logger.error("Not enough samples for evaluation.") - if num_entries == 1: - return float("nan") - else: - return [float("nan")] * num_entries diff --git a/reagent/reporting/training_reports.py b/reagent/reporting/training_reports.py deleted file mode 100644 index d3c42feb..00000000 --- a/reagent/reporting/training_reports.py +++ /dev/null @@ -1,9 +0,0 @@ -#!/usr/bin/env python3 - -from typing import Optional - -from reagent.core.registry_meta import RegistryMeta - - -class TrainingReport(metaclass=RegistryMeta): - pass diff --git a/reagent/reporting/world_model_reporter.py b/reagent/reporting/world_model_reporter.py deleted file mode 100644 index 6dde6c95..00000000 --- a/reagent/reporting/world_model_reporter.py +++ /dev/null @@ -1,95 +0,0 @@ -#!/usr/bin/env python3 - -import itertools -import logging -from typing import List, Tuple - -from reagent.core import aggregators as agg -from reagent.core.rl_training_output import RLTrainingOutput -from reagent.core.union import TrainingReport__Union -from reagent.reporting.oss_training_reports import ( - DebugToolsReport, - OssWorldModelTrainingReport, -) -from reagent.reporting.reporter_base import ReporterBase - - -logger = logging.getLogger(__name__) - - -class WorldModelReporter(ReporterBase): - def __init__(self, report_interval: int = 10): - """ - For world model: - 'loss' (referring to total loss), - 'bce' (loss for predicting not_terminal), - 'gmm' (loss for next state prediction), - 'mse' (loss for predicting reward) - """ - aggregators: List[Tuple[str, agg.Aggregator]] = list( - itertools.chain( - [ - ("loss", agg.MeanAggregator("loss", interval=report_interval)), - ("bce", agg.MeanAggregator("bce", interval=report_interval)), - ("gmm", agg.MeanAggregator("gmm", interval=report_interval)), - ("mse", agg.MeanAggregator("mse", interval=report_interval)), - ], - [ - ( - f"{key}_tb", - agg.TensorBoardHistogramAndMeanAggregator( - key, log_key, interval=report_interval - ), - ) - for key, log_key in [ - ("loss", "loss"), - ("bce", "bce"), - ("gmm", "gmm"), - ("mse", "mse"), - ] - ], - ) - ) - super().__init__(aggregators) - - def publish(self) -> RLTrainingOutput: - report = OssWorldModelTrainingReport( - loss=self.loss.values, - bce=self.bce.values, - gmm=self.gmm.values, - mse=self.mse.values, - ) - return RLTrainingOutput( - training_report=TrainingReport__Union(oss_world_model_report=report) - ) - - -class DebugToolsReporter(ReporterBase): - def __init__(self, report_interval: int = 1): - """ - For debug tools: feature_importance, feature_sensitivity - """ - aggregators: List[Tuple[str, agg.Aggregator]] = [ - ("feature_importance", agg.AppendAggregator("feature_importance")), - ("feature_sensitivity", agg.AppendAggregator("feature_sensitivity")), - ] - super().__init__(aggregators) - - def publish(self) -> RLTrainingOutput: - feature_importance = ( - [] - if len(self.feature_importance.values) == 0 - else self.feature_importance.values[-1] - ) - feature_sensitivity = ( - [] - if len(self.feature_sensitivity.values) == 0 - else self.feature_sensitivity.values[-1] - ) - report = DebugToolsReport( - feature_importance=feature_importance, - feature_sensitivity=feature_sensitivity, - ) - return RLTrainingOutput( - training_report=TrainingReport__Union(oss_debug_tools_report=report) - ) diff --git a/reagent/runners/batch_runner.py b/reagent/runners/batch_runner.py deleted file mode 100644 index 8335873c..00000000 --- a/reagent/runners/batch_runner.py +++ /dev/null @@ -1,402 +0,0 @@ -#!/usr/bin/env python3 - -import dataclasses -import logging -import time -from contextlib import contextmanager -from typing import Dict, NamedTuple, Optional, Tuple - -import torch -from reagent.core.rl_training_output import RLTrainingOutput -from reagent.core.types import ( - Dataset, - ReaderOptions, - RecurringPeriod, - ResourceOptions, - RewardOptions, - TableSpec, -) -from reagent.data_fetchers.data_fetcher import DataFetcher -from reagent.evaluation.evaluator import Evaluator -from reagent.model_managers.model_manager import ModelManager -from reagent.parameters import NormalizationData -from reagent.preprocessing.batch_preprocessor import BatchPreprocessor -from reagent.publishers.model_publisher import ModelPublisher -from reagent.tensorboardX import SummaryWriterContext, summary_writer_context -from reagent.training.trainer import Trainer -from reagent.validators.model_validator import ModelValidator -from reagent.workflow_utils.iterators import DataLoaderWrapper -from torch.utils.tensorboard import SummaryWriter - - -logger = logging.getLogger(__name__) - - -class TrainEvalSampleRanges(NamedTuple): - train_sample_range: Tuple[float, float] - eval_sample_range: Tuple[float, float] - - -class BatchRunner: - def __init__( - self, - use_gpu: bool, - model_manager: ModelManager, - data_fetcher: DataFetcher, - reward_options: RewardOptions, - normalization_data_map: Dict[str, NormalizationData], - warmstart_path: Optional[str] = None, - ): - self.use_gpu = use_gpu - self.model_manager = model_manager - self.data_fetcher = data_fetcher - self.normalization_data_map = normalization_data_map - self.reward_options = reward_options - self.warmstart_path = warmstart_path - - def get_workflow_id(self) -> int: - raise NotImplementedError() - - def initialize_trainer(self) -> Trainer: - # validate that we have all the required keys - for normalization_key in self.model_manager.required_normalization_keys: - normalization_data = self.normalization_data_map.get( - normalization_key, None - ) - assert normalization_data is not None, ( - f"NormalizationData for {normalization_key} " - "is required but not provided." - ) - # NOTE: Don't need this check in the future, for non-dense parameters - assert normalization_data.dense_normalization_parameters is not None, ( - f"Dense normalization parameters for " - f"{normalization_key} is not provided." - ) - trainer = self.model_manager.build_trainer( - self.use_gpu, self.normalization_data_map, self.reward_options - ) - if self.warmstart_path is not None: - trainer_state = torch.load(self.warmstart_path) - trainer.load_state_dict(trainer_state) - - self.trainer = trainer - return trainer - - def save_trainer(self, trainer: Trainer, output_path: str) -> None: - """ - Save the trainer for warmstarting/checkpointing. - """ - trainer_state = trainer.state_dict() - torch.save(trainer_state, output_path) - - @staticmethod - def get_sample_range( - input_table_spec: TableSpec, calc_cpe_in_training: bool - ) -> TrainEvalSampleRanges: - table_sample = input_table_spec.table_sample - eval_table_sample = input_table_spec.eval_table_sample - - if not calc_cpe_in_training: - # use all data if table sample = None - if table_sample is None: - train_sample_range = (0.0, 100.0) - else: - train_sample_range = (0.0, table_sample) - return TrainEvalSampleRanges( - train_sample_range=train_sample_range, - # eval samples will not be used - eval_sample_range=(0.0, 0.0), - ) - - error_msg = ( - "calc_cpe_in_training is set to True. " - f"Please specify table_sample(current={table_sample}) and " - f"eval_table_sample(current={eval_table_sample}) such that " - "eval_table_sample + table_sample <= 100. " - "In order to reliably calculate CPE, eval_table_sample " - "should not be too small." - ) - assert table_sample is not None, error_msg - assert eval_table_sample is not None, error_msg - assert (eval_table_sample + table_sample) <= (100.0 + 1e-3), error_msg - - return TrainEvalSampleRanges( - train_sample_range=(0.0, table_sample), - eval_sample_range=(100.0 - eval_table_sample, 100.0), - ) - - def query( - self, - input_table_spec: TableSpec, - reader_options: ReaderOptions, - resource_options: ResourceOptions, - ) -> Tuple[Dataset, Dataset]: - logger.info("Starting query") - - calc_cpe_in_training = self.model_manager.should_generate_eval_dataset - sample_range_output = BatchRunner.get_sample_range( - input_table_spec, calc_cpe_in_training - ) - train_dataset = self.model_manager.query_data( - data_fetcher=self.data_fetcher, - input_table_spec=input_table_spec, - sample_range=sample_range_output.train_sample_range, - reward_options=self.reward_options, - ) - eval_dataset = None - if calc_cpe_in_training: - eval_dataset = self.model_manager.query_data( - data_fetcher=self.data_fetcher, - input_table_spec=input_table_spec, - sample_range=sample_range_output.eval_sample_range, - reward_options=self.reward_options, - ) - - return (train_dataset, eval_dataset) - - def run_feature_identification( - self, input_table_spec: TableSpec - ) -> Dict[str, NormalizationData]: - return self.model_manager.run_feature_identification( - self.data_fetcher, input_table_spec - ) - - def train( - self, - train_dataset: Dataset, - eval_dataset: Dataset, - normalization_data_map: Dict[str, NormalizationData], - num_epochs: int, - reader_options: ReaderOptions, - resource_options: Optional[ResourceOptions] = None, - warmstart_path: Optional[str] = None, - validator: Optional[ModelValidator] = None, - parent_workflow_id: Optional[int] = None, - recurring_period: Optional[RecurringPeriod] = None, - ) -> RLTrainingOutput: - logger.info(f"{reader_options}") - child_workflow_id = self.get_workflow_id() - if parent_workflow_id is None: - parent_workflow_id = child_workflow_id - - resource_options = resource_options or ResourceOptions() - - logger.info("Starting training") - results = self.train_workflow( - train_dataset, - eval_dataset, - num_epochs, - parent_workflow_id=parent_workflow_id, - child_workflow_id=child_workflow_id, - reader_options=reader_options, - resource_options=resource_options, - ) - - if validator is not None: - results = self.run_validator(validator, results) - - return results - - def run_validator( - self, model_validator: ModelValidator, training_output: RLTrainingOutput - ) -> RLTrainingOutput: - assert ( - training_output.validation_result is None - ), f"validation_output was set to f{training_output.validation_output}" - validation_result = model_validator.validate(training_output) - return dataclasses.replace(training_output, validation_result=validation_result) - - def run_publisher( - self, - model_publisher: ModelPublisher, - training_output: RLTrainingOutput, - recurring_workflow_id: int, - child_workflow_id: int, - recurring_period: Optional[RecurringPeriod], - ) -> RLTrainingOutput: - assert ( - training_output.publishing_result is None - ), f"publishing_output was set to f{training_output.publishing_output}" - publishing_result = model_publisher.publish( - self.model_manager, - training_output, - recurring_workflow_id, - child_workflow_id, - recurring_period, - ) - return dataclasses.replace(training_output, publishing_result=publishing_result) - - def train_workflow( - self, - train_dataset: Dataset, - eval_dataset: Optional[Dataset], - num_epochs: int, - parent_workflow_id: int, - child_workflow_id: int, - reader_options: ReaderOptions, - resource_options: Optional[ResourceOptions] = None, - ) -> RLTrainingOutput: - writer = SummaryWriter() - logger.info("TensorBoard logging location is: {}".format(writer.log_dir)) - - trainer = self.initialize_trainer() - - with summary_writer_context(writer): - train_output: RLTrainingOutput = self._train( - train_dataset, eval_dataset, num_epochs, reader_options, trainer - ) - - torchscript_output_path = f"model_{round(time.time())}.torchscript" - serving_module = self.model_manager.build_serving_module( - self.normalization_data_map, trainer - ) - torch.jit.save(serving_module, torchscript_output_path) - logger.info(f"Saved torchscript model to {torchscript_output_path}") - return dataclasses.replace( - train_output, local_output_path=torchscript_output_path - ) - - def _train( - self, - train_dataset: Dataset, - eval_dataset: Optional[Dataset], - num_epochs: int, - reader_options: ReaderOptions, - trainer: Trainer, - ) -> RLTrainingOutput: - reporter = self.model_manager.get_reporter() - trainer.reporter = reporter - - evaluator = self.model_manager.get_evaluator(trainer, self.reward_options) - if evaluator is not None: - evaluator.reporter = reporter - - batch_preprocessor = self.model_manager.build_batch_preprocessor( - reader_options, - self.use_gpu, - trainer.minibatch_size, - self.normalization_data_map, - self.reward_options, - ) - return self.train_and_evaluate_generic( - train_dataset, - eval_dataset, - trainer, - num_epochs, - self.use_gpu, - batch_preprocessor, - evaluator, - reader_options, - ) - - def run_on_dataset_batches( - self, - run_on_batch_fn, - dataset: Dataset, - minibatch_size: int, - batch_preprocessor: BatchPreprocessor, - use_gpu: bool, - reader_options: ReaderOptions, - dataset_size: Optional[int] = None, - ) -> torch.utils.data.DataLoader: - logger.info(f"{reader_options}") - """ run_on_batch_fn is a function f that expects batches """ - if dataset_size is None: - dataset_size = self.data_fetcher.get_table_row_count(dataset) - assert dataset_size is not None - assert dataset_size > 0, f"{dataset_size} is expected to be positive" - - @contextmanager - def cleanup_dataloader_session(data_loader): - try: - yield data_loader - finally: - logger.info("Closing data loader") - if hasattr(data_loader, "destroy_session"): - logger.info("Closing DistributedDataLoader") - data_loader.destroy_session() - - _dataloader = self.data_fetcher.get_dataloader( - dataset=dataset, - batch_size=minibatch_size, - batch_preprocessor=batch_preprocessor, - use_gpu=use_gpu, - reader_options=reader_options, - ) - with cleanup_dataloader_session(_dataloader) as dataloader: - post_dataloader_preprocessor = self.data_fetcher.get_post_dataloader_preprocessor( - reader_options=reader_options, use_gpu=use_gpu - ) - dataloader_wrapper = DataLoaderWrapper( - dataloader=dataloader, - dataloader_size=dataset_size, - post_dataloader_preprocessor=post_dataloader_preprocessor, - ) - for batch in dataloader_wrapper: - run_on_batch_fn(batch) - return dataloader - - def train_and_evaluate_generic( - self, - train_dataset: Dataset, - eval_dataset: Optional[Dataset], - trainer: Trainer, - num_epochs: int, - use_gpu: bool, - batch_preprocessor: BatchPreprocessor, - evaluator: Optional[Evaluator], - reader_options: ReaderOptions, - sort_eval_data: bool = True, - ) -> RLTrainingOutput: - logger.info(f"{reader_options}") - assert num_epochs > 0, f"Epoch should be positive, got {num_epochs}" - train_dataset_size = self.data_fetcher.get_table_row_count(train_dataset) - if eval_dataset is not None and not sort_eval_data: - eval_dataset_size = self.data_fetcher.get_table_row_count(eval_dataset) - - for epoch in range(num_epochs): - SummaryWriterContext._reset_globals() - logger.info(f"Starting training epoch {epoch}.") - data_loader = self.run_on_dataset_batches( - run_on_batch_fn=trainer.train, - dataset=train_dataset, - minibatch_size=trainer.minibatch_size, - batch_preprocessor=batch_preprocessor, - use_gpu=use_gpu, - reader_options=reader_options, - dataset_size=train_dataset_size, - ) - if eval_dataset is not None and evaluator is not None: - if sort_eval_data: - logger.info( - f"Starting evaluation epoch {epoch} by sorting and one shot" - ) - eval_data = self.data_fetcher.gather_and_sort_eval_data( - trainer=trainer, - eval_dataset=eval_dataset, - batch_preprocessor=batch_preprocessor, - use_gpu=use_gpu, - reader_options=reader_options, - ) - evaluator.evaluate_one_shot(eval_data) - evaluator.finish() - else: - logger.info( - f"Starting evaluation epoch {epoch} by running on batches" - ) - data_loader = self.run_on_dataset_batches( - run_on_batch_fn=evaluator.evaluate, - dataset=eval_dataset, - minibatch_size=trainer.minibatch_size, - batch_preprocessor=batch_preprocessor, - use_gpu=use_gpu, - reader_options=reader_options, - dataset_size=eval_dataset_size, - ) - evaluator.finish() - trainer.reporter.finish_epoch() - report = trainer.reporter.publish() - - if hasattr(data_loader, "shutdown"): - data_loader.shutdown() - return report diff --git a/reagent/runners/oss_batch_runner.py b/reagent/runners/oss_batch_runner.py deleted file mode 100644 index ed391445..00000000 --- a/reagent/runners/oss_batch_runner.py +++ /dev/null @@ -1,39 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. - -import logging -import random -from typing import Dict, Optional - -from reagent.core.types import RewardOptions -from reagent.data_fetchers.oss_data_fetcher import OssDataFetcher -from reagent.model_managers.model_manager import ModelManager -from reagent.parameters import NormalizationData -from reagent.runners.batch_runner import BatchRunner - - -logger = logging.getLogger(__name__) - - -class OssBatchRunner(BatchRunner): - def __init__( - self, - use_gpu: bool, - model_manager: ModelManager, - reward_options: RewardOptions, - normalization_data_map: Dict[str, NormalizationData], - warmstart_path: Optional[str] = None, - ): - super().__init__( - use_gpu, - model_manager, - OssDataFetcher(), - reward_options, - normalization_data_map, - warmstart_path, - ) - # Generate a random workflow id for this batch runner - self.workflow_id = random.randint(1000, 10000000) - - def get_workflow_id(self) -> int: - return self.workflow_id diff --git a/reagent/test/core/tracker_test.py b/reagent/test/core/tracker_test.py new file mode 100644 index 00000000..51484498 --- /dev/null +++ b/reagent/test/core/tracker_test.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + + +import unittest + +from reagent.core.observers import ValueListObserver +from reagent.core.tracker import observable + + +class TestObservable(unittest.TestCase): + def test_observable(self): + @observable(td_loss=float, str_val=str) + class DummyClass: + def __init__(self, a, b, c=10): + super().__init__() + self.a = a + self.b = b + self.c = c + + def do_something(self, i): + self.notify_observers(td_loss=i, str_val="not_used") + + instance = DummyClass(1, 2) + self.assertIsInstance(instance, DummyClass) + self.assertEqual(instance.a, 1) + self.assertEqual(instance.b, 2) + self.assertEqual(instance.c, 10) + + observers = [ValueListObserver("td_loss") for _i in range(3)] + instance.add_observers(observers) + # Adding twice should not result in double update + instance.add_observer(observers[0]) + + for i in range(10): + instance.do_something(float(i)) + + for observer in observers: + self.assertEqual(observer.values, [float(i) for i in range(10)]) + + def test_no_observable_values(self): + try: + + @observable() + class NoObservableValues: + pass + + except AssertionError: + pass diff --git a/reagent/test/evaluation/test_evaluation_data_page.py b/reagent/test/evaluation/test_evaluation_data_page.py index fa2d2828..8fa9a372 100644 --- a/reagent/test/evaluation/test_evaluation_data_page.py +++ b/reagent/test/evaluation/test_evaluation_data_page.py @@ -8,7 +8,7 @@ from typing import Optional import numpy as np import torch import torch.nn as nn -from reagent.core import types as rlt +from reagent import types as rlt from reagent.evaluation.doubly_robust_estimator import DoublyRobustEstimator from reagent.evaluation.evaluation_data_page import EvaluationDataPage from reagent.evaluation.ope_adapter import OPEstimatorAdapter diff --git a/reagent/test/evaluation/test_ope_integration.py b/reagent/test/evaluation/test_ope_integration.py index 948ada42..3c46abbf 100644 --- a/reagent/test/evaluation/test_ope_integration.py +++ b/reagent/test/evaluation/test_ope_integration.py @@ -1,11 +1,15 @@ import logging +import random import unittest import numpy as np import torch -from reagent.core import types as rlt +from reagent import types as rlt from reagent.evaluation.evaluation_data_page import EvaluationDataPage -from reagent.evaluation.ope_adapter import OPEstimatorAdapter +from reagent.evaluation.ope_adapter import ( + OPEstimatorAdapter, + SequentialOPEstimatorAdapter, +) from reagent.ope.estimators.contextual_bandits_estimators import ( DMEstimator, DoublyRobustEstimator, @@ -13,6 +17,20 @@ from reagent.ope.estimators.contextual_bandits_estimators import ( SwitchDREstimator, SwitchEstimator, ) +from reagent.ope.estimators.sequential_estimators import ( + DoublyRobustEstimator as SeqDREstimator, + EpsilonGreedyRLPolicy, + RandomRLPolicy, + RLEstimatorInput, +) +from reagent.ope.estimators.types import Action, ActionSpace +from reagent.ope.test.envs import PolicyLogGenerator +from reagent.ope.test.gridworld import GridWorld, NoiseGridWorldModel +from reagent.ope.trainers.rl_tabular_trainers import ( + DPTrainer, + DPValueFunction, + TabularPolicy, +) from reagent.test.evaluation.test_evaluation_data_page import ( FakeSeq2SlateRewardNetwork, FakeSeq2SlateTransformerNet, @@ -22,6 +40,56 @@ from reagent.test.evaluation.test_evaluation_data_page import ( logger = logging.getLogger(__name__) +def rlestimator_input_to_edp( + input: RLEstimatorInput, num_actions: int +) -> EvaluationDataPage: + mdp_ids = [] + logged_propensities = [] + logged_rewards = [] + action_mask = [] + model_propensities = [] + model_values = [] + + for mdp in input.log: + mdp_id = len(mdp_ids) + for t in mdp: + mdp_ids.append(mdp_id) + logged_propensities.append(t.action_prob) + logged_rewards.append(t.reward) + assert t.action is not None + action_mask.append( + [1 if x == t.action.value else 0 for x in range(num_actions)] + ) + assert t.last_state is not None + model_propensities.append( + [ + input.target_policy(t.last_state)[Action(x)] + for x in range(num_actions) + ] + ) + assert input.value_function is not None + model_values.append( + [ + input.value_function(t.last_state, Action(x)) + for x in range(num_actions) + ] + ) + + return EvaluationDataPage( + mdp_id=torch.tensor(mdp_ids).reshape(len(mdp_ids), 1), + logged_propensities=torch.tensor(logged_propensities).reshape( + (len(logged_propensities), 1) + ), + logged_rewards=torch.tensor(logged_rewards).reshape((len(logged_rewards), 1)), + action_mask=torch.tensor(action_mask), + model_propensities=torch.tensor(model_propensities), + model_values=torch.tensor(model_values), + sequence_number=torch.tensor([]), + model_rewards=torch.tensor([]), + model_rewards_for_logged_action=torch.tensor([]), + ) + + class TestOPEModuleAlgs(unittest.TestCase): GAMMA = 0.9 CPE_PASS_BAR = 1.0 @@ -30,6 +98,98 @@ class TestOPEModuleAlgs(unittest.TestCase): NOISE_EPSILON = 0.3 EPISODES = 2 + def test_gridworld_sequential_adapter(self): + """ + Create a gridworld environment, logging policy, and target policy + Evaluates target policy using the direct OPE sequential doubly robust estimator, + then transforms the log into an evaluation data page which is passed to the ope adapter. + + This test is meant to verify the adaptation of EDPs into RLEstimatorInputs as employed + by ReAgent since ReAgent provides EDPs to Evaluators. Going from EDP -> RLEstimatorInput + is more involved than RLEstimatorInput -> EDP since the EDP does not store the state + at each timestep in each MDP, only the corresponding logged outputs & model outputs. + Thus, the adapter must do some tricks to represent these timesteps as states so the + ope module can extract the correct outputs. + + Note that there is some randomness in the model outputs since the model is purposefully + noisy. However, the same target policy is being evaluated on the same logged walks through + the gridworld, so the two results should be close in value (within 1). + + """ + random.seed(0) + np.random.seed(0) + torch.random.manual_seed(0) + + device = torch.device("cuda") if torch.cuda.is_available() else None + + gridworld = GridWorld.from_grid( + [ + ["s", "0", "0", "0", "0"], + ["0", "0", "0", "W", "0"], + ["0", "0", "0", "0", "0"], + ["0", "W", "0", "0", "0"], + ["0", "0", "0", "0", "g"], + ], + max_horizon=TestOPEModuleAlgs.MAX_HORIZON, + ) + + action_space = ActionSpace(4) + opt_policy = TabularPolicy(action_space) + trainer = DPTrainer(gridworld, opt_policy) + value_func = trainer.train(gamma=TestOPEModuleAlgs.GAMMA) + + behavivor_policy = RandomRLPolicy(action_space) + target_policy = EpsilonGreedyRLPolicy( + opt_policy, TestOPEModuleAlgs.NOISE_EPSILON + ) + model = NoiseGridWorldModel( + gridworld, + action_space, + epsilon=TestOPEModuleAlgs.NOISE_EPSILON, + max_horizon=TestOPEModuleAlgs.MAX_HORIZON, + ) + value_func = DPValueFunction(target_policy, model, TestOPEModuleAlgs.GAMMA) + ground_truth = DPValueFunction( + target_policy, gridworld, TestOPEModuleAlgs.GAMMA + ) + + log = [] + log_generator = PolicyLogGenerator(gridworld, behavivor_policy) + num_episodes = TestOPEModuleAlgs.EPISODES + for state in gridworld.states: + for _ in range(num_episodes): + log.append(log_generator.generate_log(state)) + + estimator_input = RLEstimatorInput( + gamma=TestOPEModuleAlgs.GAMMA, + log=log, + target_policy=target_policy, + value_function=value_func, + ground_truth=ground_truth, + ) + + edp = rlestimator_input_to_edp(estimator_input, len(model.action_space)) + + dr_estimator = SeqDREstimator( + weight_clamper=None, weighted=False, device=device + ) + + module_results = SequentialOPEstimatorAdapter.estimator_results_to_cpe_estimate( + dr_estimator.evaluate(estimator_input) + ) + adapter_results = SequentialOPEstimatorAdapter( + dr_estimator, TestOPEModuleAlgs.GAMMA, device=device + ).estimate(edp) + + self.assertAlmostEqual( + adapter_results.raw, + module_results.raw, + delta=TestOPEModuleAlgs.CPE_PASS_BAR, + ), f"OPE adapter results differed too much from underlying module (Diff: {abs(adapter_results.raw - module_results.raw)} > {TestOPEModuleAlgs.CPE_PASS_BAR})" + self.assertLess( + adapter_results.raw, TestOPEModuleAlgs.CPE_MAX_VALUE + ), f"OPE adapter results are too large ({adapter_results.raw} > {TestOPEModuleAlgs.CPE_MAX_VALUE})" + def test_seq2slate_eval_data_page(self): """ Create 3 slate ranking logs and evaluate using Direct Method, Inverse diff --git a/reagent/test/models/test_base.py b/reagent/test/models/test_base.py index 3201a186..d162a587 100644 --- a/reagent/test/models/test_base.py +++ b/reagent/test/models/test_base.py @@ -8,7 +8,7 @@ from typing import Any import torch import torch.nn as nn -from reagent.core import types as rlt +from reagent import types as rlt from reagent.models.base import ModelBase from reagent.test.models.test_utils import check_save_load diff --git a/reagent/test/models/test_bcq.py b/reagent/test/models/test_bcq.py index a496a87c..08876344 100644 --- a/reagent/test/models/test_bcq.py +++ b/reagent/test/models/test_bcq.py @@ -7,7 +7,7 @@ import unittest import numpy.testing as npt import torch import torch.nn.init as init -from reagent.core import types as rlt +from reagent import types as rlt from reagent.models.bcq import BatchConstrainedDQN from reagent.models.dqn import FullyConnectedDQN from reagent.models.fully_connected_network import FullyConnectedNetwork diff --git a/reagent/test/models/test_no_soft_update_embedding.py b/reagent/test/models/test_no_soft_update_embedding.py index a9ac839d..0dd19143 100644 --- a/reagent/test/models/test_no_soft_update_embedding.py +++ b/reagent/test/models/test_no_soft_update_embedding.py @@ -43,9 +43,7 @@ class TestNoSoftUpdteEmbedding(unittest.TestCase): self.assertEqual(1, len(params)) param = params[0].detach().numpy() - trainer = RLTrainer( - rl_parameters=RLParameters(), minibatch_size=1024, use_gpu=False - ) + trainer = RLTrainer(rl_parameters=RLParameters(), use_gpu=False) trainer._soft_update(model, target_model, 0.1) target_params = list(target_model.parameters()) diff --git a/reagent/test/net_builder/test_discrete_dqn_net_builder.py b/reagent/test/net_builder/test_discrete_dqn_net_builder.py index 7aea22c0..bae53c0e 100644 --- a/reagent/test/net_builder/test_discrete_dqn_net_builder.py +++ b/reagent/test/net_builder/test_discrete_dqn_net_builder.py @@ -4,7 +4,7 @@ import unittest from typing import Optional -from reagent.core import types as rlt +from reagent import types as rlt from reagent.core.fb_checker import IS_FB_ENVIRONMENT from reagent.net_builder import discrete_dqn from reagent.net_builder.unions import DiscreteDQNNetBuilder__Union diff --git a/reagent/test/prediction/test_predictor_wrapper.py b/reagent/test/prediction/test_predictor_wrapper.py index dd217c0e..a920c653 100644 --- a/reagent/test/prediction/test_predictor_wrapper.py +++ b/reagent/test/prediction/test_predictor_wrapper.py @@ -3,8 +3,8 @@ import unittest -import reagent.core.types as rlt import reagent.models as models +import reagent.types as rlt import torch from reagent.models.seq2slate import Seq2SlateMode, Seq2SlateTransformerNet from reagent.prediction.predictor_wrapper import ( diff --git a/reagent/test/workflow/reagent_sql_test_base.py b/reagent/test/workflow/reagent_sql_test_base.py index aaf1b3ed..35aefdb0 100644 --- a/reagent/test/workflow/reagent_sql_test_base.py +++ b/reagent/test/workflow/reagent_sql_test_base.py @@ -11,7 +11,7 @@ import torch # pyre-fixme[21]: Could not find `pyspark`. from pyspark import SparkConf -from reagent.workflow.spark_utils import DEFAULT_SPARK_CONFIG, TEST_SPARK_SESSION +from reagent.workflow.spark_utils import DEFAULT_SPARK_CONFIG # pyre-fixme[21]: Could not find `sparktestingbase`. from sparktestingbase.sqltestcase import SQLTestCase @@ -49,7 +49,6 @@ class ReagentSQLTestBase(SQLTestCase): def setUp(self): super().setUp() - TEST_SPARK_SESSION = self.sc assert not os.path.isdir( HIVE_METASTORE ), f"{HIVE_METASTORE} already exists! Try deleting it." diff --git a/reagent/test/workflow/test_oss_workflows.py b/reagent/test/workflow/test_oss_workflows.py index 781a9662..1eae8105 100644 --- a/reagent/test/workflow/test_oss_workflows.py +++ b/reagent/test/workflow/test_oss_workflows.py @@ -12,7 +12,7 @@ import reagent import reagent.workflow.cli as cli import torch from click.testing import CliRunner -from reagent.core.types import OssDataset +from reagent.core.types import Dataset, OssDataset from reagent.parameters import NormalizationParameters from reagent.test.base.horizon_test_base import HorizonTestBase from ruamel.yaml import YAML @@ -36,7 +36,7 @@ DQN_WORKFLOW_YAML = os.path.join( NEW_CONFIG_NAME = "config.yaml" # module to patch -OSS_DATA_FECTHER = "reagent.data_fetchers.oss_data_fetcher" +DISCRETE_DQN_BASE = "reagent.workflow.model_managers.discrete_dqn_base" def get_test_workflow_config(path_to_config: str, use_gpu: bool): @@ -93,9 +93,9 @@ class TestOSSWorkflows(HorizonTestBase): ) mock_normalization = mock_cartpole_normalization() with patch( - f"{OSS_DATA_FECTHER}.query_data", return_value=mock_dataset + f"{DISCRETE_DQN_BASE}.query_data", return_value=mock_dataset ), patch( - f"{OSS_DATA_FECTHER}.identify_normalization_parameters", + f"{DISCRETE_DQN_BASE}.identify_normalization_parameters", return_value=mock_normalization, ): # call the cli test diff --git a/reagent/test/workflow/test_preprocessing.py b/reagent/test/workflow/test_preprocessing.py index e90baa57..96298b03 100644 --- a/reagent/test/workflow/test_preprocessing.py +++ b/reagent/test/workflow/test_preprocessing.py @@ -9,11 +9,11 @@ import numpy as np # pyre-fixme[21]: Could not find `pytest`. import pytest from reagent.core.types import PreprocessingOptions, TableSpec -from reagent.data_fetchers.oss_data_fetcher import OssDataFetcher from reagent.preprocessing.identify_types import CONTINUOUS # pyre-fixme[21]: Could not find `workflow`. from reagent.test.workflow.reagent_sql_test_base import ReagentSQLTestBase +from reagent.workflow.identify_types_flow import identify_normalization_parameters logger = logging.getLogger(__name__) @@ -52,8 +52,7 @@ class TestPreprocessing(ReagentSQLTestBase): table_spec = TableSpec(table=TABLE_NAME) - df = OssDataFetcher() - normalization_params = df.identify_normalization_parameters( + normalization_params = identify_normalization_parameters( table_spec, COL_NAME, preprocessing_options, seed=self.test_class_seed ) diff --git a/reagent/test/workflow/test_query_data.py b/reagent/test/workflow/test_query_data.py index b7eabaae..dadd57ae 100644 --- a/reagent/test/workflow/test_query_data.py +++ b/reagent/test/workflow/test_query_data.py @@ -14,11 +14,11 @@ from pyspark.sql.functions import asc from reagent.core.types import Dataset, TableSpec # pyre-fixme[21]: Could not find `workflow`. -from reagent.data_fetchers.oss_data_fetcher import query_data from reagent.test.workflow.reagent_sql_test_base import ReagentSQLTestBase # pyre-fixme[21]: Could not find module `reagent.test.workflow.test_data.ex_mdps`. from reagent.test.workflow.test_data.ex_mdps import generate_discrete_mdp_pandas_df +from reagent.workflow.data_fetcher import query_data logger = logging.getLogger(__name__) diff --git a/reagent/test/workflow/test_query_data_parametric.py b/reagent/test/workflow/test_query_data_parametric.py index ba3a082f..58961b32 100644 --- a/reagent/test/workflow/test_query_data_parametric.py +++ b/reagent/test/workflow/test_query_data_parametric.py @@ -14,11 +14,11 @@ from pyspark.sql.functions import asc from reagent.core.types import Dataset, TableSpec # pyre-fixme[21]: Could not find `workflow`. -from reagent.data_fetchers.oss_data_fetcher import query_data from reagent.test.workflow.reagent_sql_test_base import ReagentSQLTestBase # pyre-fixme[21]: Could not find module `reagent.test.workflow.test_data.ex_mdps`. from reagent.test.workflow.test_data.ex_mdps import generate_parametric_mdp_pandas_df +from reagent.workflow.data_fetcher import query_data logger = logging.getLogger(__name__) diff --git a/reagent/test/world_model/test_mdnrnn.py b/reagent/test/world_model/test_mdnrnn.py index 1a5df22b..4705dc87 100644 --- a/reagent/test/world_model/test_mdnrnn.py +++ b/reagent/test/world_model/test_mdnrnn.py @@ -9,7 +9,6 @@ import torch from reagent.models.mdn_rnn import MDNRNNMemoryPool, gmm_loss from reagent.models.world_model import MemoryNetwork from reagent.parameters import MDNRNNTrainerParameters -from reagent.reporting.world_model_reporter import WorldModelReporter from reagent.test.world_model.simulated_world_model import SimulatedWorldModel from reagent.training.world_model.mdnrnn_trainer import MDNRNNTrainer from torch.distributions.categorical import Categorical @@ -145,27 +144,39 @@ class TestMDNRNN(unittest.TestCase): ) if use_gpu: mdnrnn_net = mdnrnn_net.cuda() - trainer = MDNRNNTrainer(memory_network=mdnrnn_net, params=mdnrnn_params) - trainer.reporter = WorldModelReporter(1) + trainer = MDNRNNTrainer( + memory_network=mdnrnn_net, params=mdnrnn_params, cum_loss_hist=num_batch + ) for e in range(num_epochs): for i in range(num_batch): training_batch = replay_buffer.sample_memories( batch_size, use_gpu=use_gpu ) - trainer.train(training_batch) + losses = trainer.train(training_batch) + logger.info( + "{}-th epoch, {}-th minibatch: \n" + "loss={}, bce={}, gmm={}, mse={} \n" + "cum loss={}, cum bce={}, cum gmm={}, cum mse={}\n".format( + e, + i, + losses["loss"], + losses["bce"], + losses["gmm"], + losses["mse"], + np.mean(trainer.cum_loss), + np.mean(trainer.cum_bce), + np.mean(trainer.cum_gmm), + np.mean(trainer.cum_mse), + ) + ) - trainer.reporter.finish_epoch() - report = trainer.reporter.publish().training_report.oss_world_model_report - loss = np.mean(report.loss) - bce = np.mean(report.bce) - gmm = np.mean(report.gmm) - mse = np.mean(report.mse) - logger.info( - f"{e}-th epoch: \n" f"loss={loss}, bce={bce}, gmm={gmm}, mse={mse}" - ) - - if loss < 0 and gmm < -3.0 and bce < 0.6 and mse < 0.2: - return + if ( + np.mean(trainer.cum_loss) < 0 + and np.mean(trainer.cum_gmm) < -3.0 + and np.mean(trainer.cum_bce) < 0.6 + and np.mean(trainer.cum_mse) < 0.2 + ): + return raise RuntimeError("losses not reduced significantly during training") diff --git a/reagent/training/__init__.py b/reagent/training/__init__.py index ddce98cf..5eb0741d 100644 --- a/reagent/training/__init__.py +++ b/reagent/training/__init__.py @@ -11,7 +11,6 @@ from reagent.training.rl_trainer_pytorch import RLTrainer from reagent.training.sac_trainer import SACTrainer from reagent.training.slate_q_trainer import SlateQTrainer from reagent.training.td3_trainer import TD3Trainer -from reagent.training.trainer import Trainer from reagent.training.world_model.mdnrnn_trainer import MDNRNNTrainer from .parameters import ( diff --git a/reagent/training/c51_trainer.py b/reagent/training/c51_trainer.py index 5e99d08e..36fc2ab0 100644 --- a/reagent/training/c51_trainer.py +++ b/reagent/training/c51_trainer.py @@ -3,15 +3,24 @@ from typing import List -import reagent.core.types as rlt +import reagent.types as rlt import torch from reagent.core.configuration import resolve_defaults from reagent.core.dataclasses import field +from reagent.core.tracker import observable from reagent.optimizer.union import Optimizer__Union from reagent.parameters import EvaluationParameters, RLParameters from reagent.training.rl_trainer_pytorch import RLTrainer +@observable( + td_loss=torch.Tensor, + logged_actions=torch.Tensor, + logged_propensities=torch.Tensor, + logged_rewards=torch.Tensor, + model_values=torch.Tensor, + model_action_idxs=torch.Tensor, +) class C51Trainer(RLTrainer): """ Implementation of 51 Categorical DQN (C51) @@ -25,7 +34,7 @@ class C51Trainer(RLTrainer): q_network, q_network_target, metrics_to_score=None, - reporter=None, + loss_reporter=None, use_gpu: bool = False, actions: List[str] = field(default_factory=list), # noqa: B008 rl: RLParameters = field(default_factory=RLParameters), # noqa: B008 @@ -46,10 +55,9 @@ class C51Trainer(RLTrainer): self, rl, use_gpu=use_gpu, - minibatch_size=minibatch_size, metrics_to_score=metrics_to_score, actions=actions, - reporter=reporter, + loss_reporter=loss_reporter, ) self.double_q_learning = double_q_learning @@ -169,7 +177,8 @@ class C51Trainer(RLTrainer): possible_actions_mask if self.maxq_learning else training_batch.action, ) - self.reporter.report( + # pyre-fixme[16]: `C51Trainer` has no attribute `notify_observers`. + self.notify_observers( td_loss=loss, logged_actions=torch.argmax(training_batch.action, dim=1, keepdim=True), logged_propensities=training_batch.extras.action_probability, diff --git a/reagent/training/cem_trainer.py b/reagent/training/cem_trainer.py index 002c1752..4036e92a 100644 --- a/reagent/training/cem_trainer.py +++ b/reagent/training/cem_trainer.py @@ -11,7 +11,7 @@ The idea is inspired by: https://arxiv.org/abs/1805.12114 import logging from typing import List -import reagent.core.types as rlt +import reagent.types as rlt from reagent.models.cem_planner import CEMPlannerNetwork from reagent.parameters import CEMTrainerParameters from reagent.training.rl_trainer_pytorch import RLTrainer @@ -21,6 +21,14 @@ from reagent.training.world_model.mdnrnn_trainer import MDNRNNTrainer logger = logging.getLogger(__name__) +def print_mdnrnn_losses(minibatch, model_index, losses) -> None: + logger.info( + f"{minibatch}-th minibatch {model_index}-th model: \n" + f'loss={losses["loss"]}, bce={losses["bce"]}, ' + f'gmm={losses["gmm"]}, mse={losses["mse"]}\n' + ) + + class CEMTrainer(RLTrainer): def __init__( self, @@ -29,15 +37,15 @@ class CEMTrainer(RLTrainer): parameters: CEMTrainerParameters, use_gpu: bool = False, ) -> None: - super().__init__( - parameters.rl, - use_gpu=use_gpu, - minibatch_size=parameters.mdnrnn.minibatch_size, - ) + super().__init__(parameters.rl, use_gpu=use_gpu) self.cem_planner_network = cem_planner_network self.world_model_trainers = world_model_trainers + self.minibatch_size = parameters.mdnrnn.minibatch_size def train(self, training_batch: rlt.MemoryNetworkInput) -> None: - for _, trainer in enumerate(self.world_model_trainers): - trainer.train(training_batch) + for i, trainer in enumerate(self.world_model_trainers): + losses = trainer.train(training_batch) + # TODO: report losses instead of printing them + # print_mdnrnn_losses(self.minibatch, i, losses) + self.minibatch += 1 diff --git a/reagent/training/dqn_trainer.py b/reagent/training/dqn_trainer.py index df83f805..e7df54c3 100644 --- a/reagent/training/dqn_trainer.py +++ b/reagent/training/dqn_trainer.py @@ -4,10 +4,11 @@ import logging from typing import List, Optional, Tuple -import reagent.core.types as rlt +import reagent.types as rlt import torch from reagent.core.configuration import resolve_defaults from reagent.core.dataclasses import dataclass, field +from reagent.core.tracker import observable from reagent.optimizer.union import Optimizer__Union from reagent.parameters import EvaluationParameters, RLParameters from reagent.training.dqn_trainer_base import DQNTrainerBase @@ -23,6 +24,17 @@ class BCQConfig: drop_threshold: float = 0.1 +@observable( + td_loss=torch.Tensor, + reward_loss=torch.Tensor, + logged_actions=torch.Tensor, + logged_propensities=torch.Tensor, + logged_rewards=torch.Tensor, + model_propensities=torch.Tensor, + model_rewards=torch.Tensor, + model_values=torch.Tensor, + model_action_idxs=torch.Tensor, +) class DQNTrainer(DQNTrainerBase): @resolve_defaults def __init__( @@ -34,7 +46,7 @@ class DQNTrainer(DQNTrainerBase): q_network_cpe_target=None, metrics_to_score=None, imitator=None, - reporter=None, + loss_reporter=None, use_gpu: bool = False, actions: List[str] = field(default_factory=list), # noqa: B008 rl: RLParameters = field(default_factory=RLParameters), # noqa: B008 @@ -55,8 +67,7 @@ class DQNTrainer(DQNTrainerBase): metrics_to_score=metrics_to_score, actions=actions, evaluation_parameters=evaluation, - reporter=reporter, - minibatch_size=minibatch_size, + loss_reporter=loss_reporter, ) assert self._actions is not None, "Discrete-action DQN needs action names" self.double_q_learning = double_q_learning @@ -213,20 +224,29 @@ class DQNTrainer(DQNTrainerBase): possible_actions_mask if self.maxq_learning else training_batch.action, )[1] - self.reporter.report( + # pyre-fixme[16]: `DQNTrainer` has no attribute `notify_observers`. + self.notify_observers( td_loss=self.loss, + reward_loss=reward_loss, + logged_actions=logged_action_idxs, + logged_propensities=training_batch.extras.action_probability, + logged_rewards=rewards, + model_propensities=model_propensities, + model_rewards=model_rewards, + model_values=self.all_action_scores, + model_action_idxs=model_action_idxs, + ) + + self.loss_reporter.report( + td_loss=self.loss, + reward_loss=reward_loss, logged_actions=logged_action_idxs, logged_propensities=training_batch.extras.action_probability, logged_rewards=rewards, logged_values=None, # Compute at end of each epoch for CPE + model_propensities=model_propensities, + model_rewards=model_rewards, model_values=self.all_action_scores, model_values_on_logged_actions=None, # Compute at end of each epoch for CPE model_action_idxs=model_action_idxs, ) - - if reward_loss is not None: - self.reporter.report( - reward_loss=reward_loss, - model_propensities=model_propensities, - model_rewards=model_rewards, - ) diff --git a/reagent/training/loss_reporter.py b/reagent/training/loss_reporter.py index ad262810..f21677e9 100644 --- a/reagent/training/loss_reporter.py +++ b/reagent/training/loss_reporter.py @@ -43,10 +43,9 @@ class BatchStats(NamedTuple): for i, action in enumerate(actions): # pyre-fixme[16]: `SummaryWriterContext` has no attribute # `add_scalar`. - # SummaryWriterContext.add_scalar( - # "{}/{}".format(log_key, action), (val == i).sum().item() - # ) - pass + SummaryWriterContext.add_scalar( + "{}/{}".format(log_key, action), (val == i).sum().item() + ) for field, log_key in [ ("td_loss", "td_loss"), @@ -89,9 +88,8 @@ class BatchStats(NamedTuple): def _log_histogram_and_mean(self, log_key, val): try: - # SummaryWriterContext.add_histogram(log_key, val) - # SummaryWriterContext.add_scalar(f"{log_key}/mean", val.mean()) - pass + SummaryWriterContext.add_histogram(log_key, val) + SummaryWriterContext.add_scalar(f"{log_key}/mean", val.mean()) except ValueError: logger.warning( f"Cannot create histogram for key: {log_key}; " @@ -105,32 +103,32 @@ class BatchStats(NamedTuple): if not action_names: return - # SummaryWriterContext.add_custom_scalars_multilinechart( - # [ - # "propensities/model/{}/mean".format(action_name) - # for action_name in action_names - # ], - # category="propensities", - # title="model", - # ) - # SummaryWriterContext.add_custom_scalars_multilinechart( - # [ - # "propensities/logged/{}/mean".format(action_name) - # for action_name in action_names - # ], - # category="propensities", - # title="logged", - # ) - # SummaryWriterContext.add_custom_scalars_multilinechart( - # ["actions/logged/{}".format(action_name) for action_name in action_names], - # category="actions", - # title="logged", - # ) - # SummaryWriterContext.add_custom_scalars_multilinechart( - # ["actions/model/{}".format(action_name) for action_name in action_names], - # category="actions", - # title="model", - # ) + SummaryWriterContext.add_custom_scalars_multilinechart( + [ + "propensities/model/{}/mean".format(action_name) + for action_name in action_names + ], + category="propensities", + title="model", + ) + SummaryWriterContext.add_custom_scalars_multilinechart( + [ + "propensities/logged/{}/mean".format(action_name) + for action_name in action_names + ], + category="propensities", + title="logged", + ) + SummaryWriterContext.add_custom_scalars_multilinechart( + ["actions/logged/{}".format(action_name) for action_name in action_names], + category="actions", + title="logged", + ) + SummaryWriterContext.add_custom_scalars_multilinechart( + ["actions/model/{}".format(action_name) for action_name in action_names], + category="actions", + title="model", + ) def merge_tensor_namedtuple_list(l, cls): @@ -350,8 +348,7 @@ class LossReporter(object): ("Training/imitator_loss", self.get_recent_imitator_loss()), ]: # pyre-fixme[16]: `SummaryWriterContext` has no attribute `add_scalar`. - # SummaryWriterContext.add_scalar(name, none_to_zero(value), epoch) - pass + SummaryWriterContext.add_scalar(name, none_to_zero(value), epoch) @staticmethod def calculate_recent_window_average(arr, window_size, num_entries): diff --git a/reagent/training/parameters.py b/reagent/training/parameters.py index 055639f0..d07cbd05 100644 --- a/reagent/training/parameters.py +++ b/reagent/training/parameters.py @@ -2,7 +2,7 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. from reagent.core.configuration import make_config_class -from reagent.core.types import BaseDataClass +from reagent.types import BaseDataClass from .c51_trainer import C51Trainer from .dqn_trainer import DQNTrainer @@ -57,8 +57,7 @@ class ParametricDQNTrainerParameters: "q_network_cpe_target", "metrics_to_score", "imitator", - "reporter", - "evaluation", + "loss_reporter", ], ) class DQNTrainerParameters: @@ -75,8 +74,7 @@ class DQNTrainerParameters: "reward_network", "q_network_cpe", "q_network_cpe_target", - "reporter", - "evaluation", + "loss_reporter", ], ) class QRDQNTrainerParameters: @@ -90,8 +88,7 @@ class QRDQNTrainerParameters: "q_network", "q_network_target", "metrics_to_score", - "reporter", - "evaluation", + "loss_reporter", ], ) class C51TrainerParameters: diff --git a/reagent/training/parametric_dqn_trainer.py b/reagent/training/parametric_dqn_trainer.py index ef14a587..ce469ea6 100644 --- a/reagent/training/parametric_dqn_trainer.py +++ b/reagent/training/parametric_dqn_trainer.py @@ -4,8 +4,8 @@ import logging from typing import Tuple -import reagent.core.types as rlt import reagent.parameters as rlp +import reagent.types as rlt import torch import torch.nn.functional as F from reagent.core.configuration import resolve_defaults @@ -34,7 +34,7 @@ class ParametricDQNTrainer(DQNTrainerBase): default_factory=Optimizer__Union.default ), ) -> None: - super().__init__(rl, minibatch_size=minibatch_size, use_gpu=use_gpu) + super().__init__(rl, use_gpu=use_gpu) self.double_q_learning = double_q_learning self.minibatch_size = minibatch_size @@ -161,7 +161,7 @@ class ParametricDQNTrainer(DQNTrainerBase): self.reward_network_optimizer, self.minibatches_per_step ) - self.reporter.report( + self.loss_reporter.report( td_loss=td_loss.detach().cpu(), reward_loss=reward_loss.detach().cpu(), logged_rewards=reward, diff --git a/reagent/training/qrdqn_trainer.py b/reagent/training/qrdqn_trainer.py index 225cff5c..10b78ff3 100644 --- a/reagent/training/qrdqn_trainer.py +++ b/reagent/training/qrdqn_trainer.py @@ -4,10 +4,11 @@ import logging from typing import List, Tuple -import reagent.core.types as rlt +import reagent.types as rlt import torch from reagent.core.configuration import resolve_defaults from reagent.core.dataclasses import field +from reagent.core.tracker import observable from reagent.optimizer.union import Optimizer__Union from reagent.parameters import EvaluationParameters, RLParameters from reagent.training.dqn_trainer_base import DQNTrainerBase @@ -16,6 +17,16 @@ from reagent.training.dqn_trainer_base import DQNTrainerBase logger = logging.getLogger(__name__) +@observable( + td_loss=torch.Tensor, + logged_actions=torch.Tensor, + logged_propensities=torch.Tensor, + logged_rewards=torch.Tensor, + model_propensities=torch.Tensor, + model_rewards=torch.Tensor, + model_values=torch.Tensor, + model_action_idxs=torch.Tensor, +) class QRDQNTrainer(DQNTrainerBase): """ Implementation of QR-DQN (Quantile Regression Deep Q-Network) @@ -32,7 +43,7 @@ class QRDQNTrainer(DQNTrainerBase): reward_network=None, q_network_cpe=None, q_network_cpe_target=None, - reporter=None, + loss_reporter=None, use_gpu: bool = False, actions: List[str] = field(default_factory=list), # noqa: B008 rl: RLParameters = field(default_factory=RLParameters), # noqa: B008 @@ -56,11 +67,11 @@ class QRDQNTrainer(DQNTrainerBase): metrics_to_score=metrics_to_score, actions=actions, evaluation_parameters=evaluation, - reporter=reporter, - minibatch_size=minibatch_size, + loss_reporter=loss_reporter, ) self.double_q_learning = double_q_learning + self.minibatch_size = minibatch_size self.minibatches_per_step = minibatches_per_step self._actions = actions @@ -183,21 +194,30 @@ class QRDQNTrainer(DQNTrainerBase): possible_actions_mask if self.maxq_learning else training_batch.action, ) - self.reporter.report( + # pyre-fixme[16]: `QRDQNTrainer` has no attribute `notify_observers`. + self.notify_observers( td_loss=loss, logged_actions=logged_action_idxs, logged_propensities=training_batch.extras.action_probability, logged_rewards=rewards, + model_propensities=model_propensities, + model_rewards=model_rewards, model_values=all_q_values, model_action_idxs=model_action_idxs, ) - if reward_loss is not None: - self.reporter.report( - reward_loss=reward_loss, - model_propensities=model_propensities, - model_rewards=model_rewards, - ) + self.loss_reporter.report( + td_loss=loss, + logged_actions=logged_action_idxs, + logged_propensities=training_batch.extras.action_probability, + logged_rewards=rewards, + logged_values=None, # Compute at end of each epoch for CPE + model_propensities=model_propensities, + model_rewards=model_rewards, + model_values=all_q_values, + model_values_on_logged_actions=None, # Compute at end of each epoch for CPE + model_action_idxs=model_action_idxs, + ) # pyre-fixme[56]: Decorator `torch.no_grad(...)` could not be called, because # its type `no_grad` is not callable. diff --git a/reagent/training/ranking/seq2slate_attn_trainer.py b/reagent/training/ranking/seq2slate_attn_trainer.py index 476a2b71..203a4515 100644 --- a/reagent/training/ranking/seq2slate_attn_trainer.py +++ b/reagent/training/ranking/seq2slate_attn_trainer.py @@ -2,19 +2,22 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. import logging -import reagent.core.types as rlt +import reagent.types as rlt import torch import torch.nn as nn from reagent.core.dataclasses import field +from reagent.core.tracker import observable from reagent.models.seq2slate import Seq2SlateMode, Seq2SlateTransformerNet from reagent.optimizer.union import Optimizer__Union from reagent.parameters import TransformerParameters +from reagent.training.loss_reporter import NoOpLossReporter from reagent.training.trainer import Trainer logger = logging.getLogger(__name__) +@observable(cross_entropy_loss=torch.Tensor) class Seq2SlatePairwiseAttnTrainer(Trainer): """ Seq2Slate without a decoder learned in a supervised learning fashion ( @@ -25,13 +28,13 @@ class Seq2SlatePairwiseAttnTrainer(Trainer): self, seq2slate_net: Seq2SlateTransformerNet, minibatch_size: int = 1024, - reporter=None, + loss_reporter=None, use_gpu: bool = False, policy_optimizer: Optimizer__Union = field( # noqa: B008 default_factory=Optimizer__Union.default ), ) -> None: - self.reporter = reporter + self.loss_reporter = loss_reporter self.use_gpu = use_gpu self.seq2slate_net = seq2slate_net self.minibatch_size = minibatch_size @@ -41,6 +44,8 @@ class Seq2SlatePairwiseAttnTrainer(Trainer): ) self.log_softmax = nn.LogSoftmax(dim=1) self.kl_loss = nn.KLDivLoss(reduction="batchmean") + if self.loss_reporter is None: + self.loss_reporter = NoOpLossReporter() def warm_start_components(self): components = ["seq2slate_net"] @@ -67,6 +72,8 @@ class Seq2SlatePairwiseAttnTrainer(Trainer): loss = loss.detach() self.minibatch += 1 - self.reporter.report(cross_entropy_loss=loss) + # pyre-fixme[16]: `Seq2SlatePairwiseAttnTrainer` has no attribute + # `notify_observers`. + self.notify_observers(cross_entropy_loss=loss) return {"cross_entropy_loss": loss} diff --git a/reagent/training/ranking/seq2slate_dr_trainer.py b/reagent/training/ranking/seq2slate_dr_trainer.py index 0c5fc6e6..890afcc1 100644 --- a/reagent/training/ranking/seq2slate_dr_trainer.py +++ b/reagent/training/ranking/seq2slate_dr_trainer.py @@ -2,7 +2,7 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. import logging -import reagent.core.types as rlt +import reagent.types as rlt import torch import torch.nn as nn import torch.nn.functional as F diff --git a/reagent/training/ranking/seq2slate_sim_trainer.py b/reagent/training/ranking/seq2slate_sim_trainer.py index ed2c086e..658acfe0 100644 --- a/reagent/training/ranking/seq2slate_sim_trainer.py +++ b/reagent/training/ranking/seq2slate_sim_trainer.py @@ -6,9 +6,10 @@ from itertools import permutations from typing import List, Optional import numpy as np -import reagent.core.types as rlt +import reagent.types as rlt import torch from reagent.core.dataclasses import field +from reagent.core.tracker import observable from reagent.models.seq2slate import ( DECODER_START_SYMBOL, BaselineNet, @@ -62,6 +63,15 @@ def swap_dist(idx: List[int]): return swap_dist_in_slate(idx) + swap_dist_out_slate(idx) +@observable( + train_ips_score=torch.Tensor, + train_clamped_ips_score=torch.Tensor, + train_baseline_loss=torch.Tensor, + train_log_probs=torch.Tensor, + train_ips_ratio=torch.Tensor, + train_clamped_ips_ratio=torch.Tensor, + train_advantage=torch.Tensor, +) class Seq2SlateSimulationTrainer(Trainer): """ Seq2Slate learned with simulation data, with the action @@ -224,7 +234,7 @@ class Seq2SlateSimulationTrainer(Trainer): ) return on_policy_input - def train(self, training_batch: rlt.PreprocessedTrainingBatch) -> None: + def train(self, training_batch: rlt.PreprocessedTrainingBatch): assert type(training_batch) is rlt.PreprocessedTrainingBatch training_input = training_batch.training_input assert isinstance(training_input, rlt.PreprocessedRankingInput) diff --git a/reagent/training/ranking/seq2slate_tf_trainer.py b/reagent/training/ranking/seq2slate_tf_trainer.py index ddbe07a9..02d022a2 100644 --- a/reagent/training/ranking/seq2slate_tf_trainer.py +++ b/reagent/training/ranking/seq2slate_tf_trainer.py @@ -2,7 +2,7 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. import logging -import reagent.core.types as rlt +import reagent.types as rlt import torch import torch.nn as nn import torch.nn.functional as F diff --git a/reagent/training/ranking/seq2slate_trainer.py b/reagent/training/ranking/seq2slate_trainer.py index b10222a5..4ed819be 100644 --- a/reagent/training/ranking/seq2slate_trainer.py +++ b/reagent/training/ranking/seq2slate_trainer.py @@ -3,13 +3,13 @@ import logging from typing import Optional, Tuple -import reagent.core.types as rlt +import reagent.types as rlt import torch from reagent.core.dataclasses import field +from reagent.core.tracker import observable from reagent.models.seq2slate import BaselineNet, Seq2SlateMode, Seq2SlateTransformerNet from reagent.optimizer.union import Optimizer__Union from reagent.parameters import Seq2SlateParameters -from reagent.reporting.ranking_model_reporter import RankingModelReporter from reagent.training.ranking.helper import ips_clamp from reagent.training.trainer import Trainer @@ -17,6 +17,15 @@ from reagent.training.trainer import Trainer logger = logging.getLogger(__name__) +@observable( + train_ips_score=torch.Tensor, + train_clamped_ips_score=torch.Tensor, + train_baseline_loss=torch.Tensor, + train_log_probs=torch.Tensor, + train_ips_ratio=torch.Tensor, + train_clamped_ips_ratio=torch.Tensor, + train_advantages=torch.Tensor, +) class Seq2SlateTrainer(Trainer): def __init__( self, @@ -54,8 +63,6 @@ class Seq2SlateTrainer(Trainer): self.baseline_net.parameters() ) - self.reporter = RankingModelReporter() - def warm_start_components(self): components = ["seq2slate_net"] if self.baseline_net: @@ -76,7 +83,7 @@ class Seq2SlateTrainer(Trainer): clamped_impt_smpl = ips_clamp(impt_smpl, self.parameters.ips_clamp) return impt_smpl, clamped_impt_smpl - def train(self, training_batch: rlt.PreprocessedTrainingBatch) -> None: + def train(self, training_batch: rlt.PreprocessedTrainingBatch): assert type(training_batch) is rlt.PreprocessedTrainingBatch training_input = training_batch.training_input assert isinstance(training_input, rlt.PreprocessedRankingInput) @@ -168,8 +175,10 @@ class Seq2SlateTrainer(Trainer): torch.mean(impt_smpl), ) ) - - self.reporter.report( + # See RankingTrainingPageHandler.finish() function in page_handler.py + # pyre-fixme[16]: `Seq2SlateTrainer` has no attribute + # `notify_observers`. + self.notify_observers( train_ips_score=torch.tensor(ips_rl_loss).reshape(1), train_clamped_ips_score=torch.tensor(clamped_ips_rl_loss).reshape(1), train_baseline_loss=torch.tensor(baseline_loss).reshape(1), diff --git a/reagent/training/reinforce.py b/reagent/training/reinforce.py index ba2ec740..53ae5096 100644 --- a/reagent/training/reinforce.py +++ b/reagent/training/reinforce.py @@ -4,7 +4,7 @@ import logging from dataclasses import dataclass, field from typing import List -import reagent.core.types as rlt +import reagent.types as rlt import torch import torch.optim from reagent.optimizer.union import Optimizer__Union diff --git a/reagent/training/reward_network_trainer.py b/reagent/training/reward_network_trainer.py index 5336ca46..013e59dc 100644 --- a/reagent/training/reward_network_trainer.py +++ b/reagent/training/reward_network_trainer.py @@ -2,12 +2,11 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. import logging -import reagent.core.types as rlt +import reagent.types as rlt import torch from reagent.core.dataclasses import field from reagent.models.base import ModelBase from reagent.optimizer.union import Optimizer__Union -from reagent.reporting.world_model_reporter import WorldModelReporter from reagent.training.trainer import Trainer @@ -30,10 +29,8 @@ class RewardNetTrainer(Trainer): self.minibatch = 0 self.loss_fn = torch.nn.MSELoss(reduction="mean") self.opt = optimizer.make_optimizer(self.reward_net.parameters()) - self.reporter = WorldModelReporter() - self.best_model = reward_net - def train(self, training_batch: rlt.PreprocessedTrainingBatch) -> None: + def train(self, training_batch: rlt.PreprocessedTrainingBatch): training_input = training_batch.training_input if isinstance(training_input, rlt.PreprocessedRankingInput): target_reward = training_input.slate_reward @@ -51,7 +48,7 @@ class RewardNetTrainer(Trainer): if self.minibatch % 10 == 0: logger.info("{}-th batch: mse_loss={}".format(self.minibatch, mse_loss)) - self.reporter.report(mse=mse_loss) + return mse_loss def warm_start_components(self): return ["reward_net"] diff --git a/reagent/training/rl_trainer_pytorch.py b/reagent/training/rl_trainer_pytorch.py index 372d322d..f43a91cb 100644 --- a/reagent/training/rl_trainer_pytorch.py +++ b/reagent/training/rl_trainer_pytorch.py @@ -27,14 +27,13 @@ class RLTrainer(Trainer): self, rl_parameters: RLParameters, use_gpu: bool, - minibatch_size: int, metrics_to_score=None, actions: Optional[List[str]] = None, evaluation_parameters: Optional[EvaluationParameters] = None, - reporter=None, + loss_reporter=None, ) -> None: - super().__init__(minibatch_size) self.minibatch = 0 + self.minibatch_size: Optional[int] = None self.minibatches_per_step: Optional[int] = None self.rl_parameters = rl_parameters self.rl_temperature = float(rl_parameters.temperature) @@ -76,8 +75,7 @@ class RLTrainer(Trainer): self.use_gpu = False self.device = torch.device("cpu") - self.reporter = reporter - self.loss_reporter = LossReporter(actions) + self.loss_reporter = loss_reporter or LossReporter(actions) self._actions = actions @property diff --git a/reagent/training/sac_trainer.py b/reagent/training/sac_trainer.py index 671c80ee..4121cfdf 100644 --- a/reagent/training/sac_trainer.py +++ b/reagent/training/sac_trainer.py @@ -5,11 +5,12 @@ import logging from typing import List, Optional import numpy as np -import reagent.core.types as rlt +import reagent.types as rlt import torch import torch.nn.functional as F from reagent.core.configuration import resolve_defaults from reagent.core.dataclasses import field +from reagent.core.tracker import observable from reagent.optimizer.union import Optimizer__Union from reagent.parameters import RLParameters from reagent.tensorboardX import SummaryWriterContext @@ -19,6 +20,17 @@ from reagent.training.rl_trainer_pytorch import RLTrainer logger = logging.getLogger(__name__) +@observable( + td_loss=torch.Tensor, + reward_loss=torch.Tensor, + logged_actions=torch.Tensor, + logged_propensities=torch.Tensor, + logged_rewards=torch.Tensor, + model_propensities=torch.Tensor, + model_rewards=torch.Tensor, + model_values=torch.Tensor, + model_action_idxs=torch.Tensor, +) class SACTrainer(RLTrainer): """ Soft Actor-Critic trainer as described in https://arxiv.org/pdf/1801.01290 @@ -68,8 +80,9 @@ class SACTrainer(RLTrainer): # alpha in the paper; controlling explore & exploit # TODO: finish """ - super().__init__(rl, use_gpu=use_gpu, minibatch_size=minibatch_size) + super().__init__(rl, use_gpu=use_gpu) + self.minibatch_size = minibatch_size self.minibatches_per_step = 1 self.q1_network = q1_network @@ -366,8 +379,9 @@ class SACTrainer(RLTrainer): SummaryWriterContext.add_histogram("kld/var", action_batch_v) SummaryWriterContext.add_scalar("kld/kld", kld) - self.reporter.report( + self.loss_reporter.report( td_loss=float(q1_loss), + reward_loss=None, logged_rewards=reward, model_values_on_logged_actions=q1_value, model_propensities=actor_output.log_prob.exp(), diff --git a/reagent/training/slate_q_trainer.py b/reagent/training/slate_q_trainer.py index 5fe862a7..ae6e9284 100644 --- a/reagent/training/slate_q_trainer.py +++ b/reagent/training/slate_q_trainer.py @@ -4,8 +4,8 @@ import logging from typing import List, Optional -import reagent.core.types as rlt import reagent.parameters as rlp +import reagent.types as rlt import torch import torch.nn.functional as F from reagent.core.dataclasses import field @@ -35,7 +35,7 @@ class SlateQTrainer(DQNTrainerBase): default_factory=lambda: rlp.EvaluationParameters(calc_cpe_in_training=False) ), ) -> None: - super().__init__(rl, use_gpu=use_gpu, minibatch_size=minibatch_size) + super().__init__(rl, use_gpu=use_gpu) self.minibatches_per_step = 1 self.minibatch_size = minibatch_size self.single_selection = single_selection @@ -148,6 +148,6 @@ class SlateQTrainer(DQNTrainerBase): if not self.single_selection: all_action_scores = all_action_scores.sum(dim=1, keepdim=True) - self.reporter.report( + self.loss_reporter.report( td_loss=td_loss, model_values_on_logged_actions=all_action_scores ) diff --git a/reagent/training/td3_trainer.py b/reagent/training/td3_trainer.py index 03ae2053..84a54931 100644 --- a/reagent/training/td3_trainer.py +++ b/reagent/training/td3_trainer.py @@ -3,7 +3,7 @@ import copy import logging -import reagent.core.types as rlt +import reagent.types as rlt import torch from reagent.core.configuration import resolve_defaults from reagent.core.dataclasses import field @@ -47,7 +47,7 @@ class TD3Trainer(RLTrainer): """ Args: TODO: fill in """ - super().__init__(rl, use_gpu=use_gpu, minibatch_size=minibatch_size) + super().__init__(rl, use_gpu=use_gpu) self.minibatch_size = minibatch_size self.minibatches_per_step = minibatches_per_step or 1 @@ -180,8 +180,9 @@ class TD3Trainer(RLTrainer): SummaryWriterContext.add_histogram(k, v.numpy()) SummaryWriterContext.add_scalar(f"{k}_mean", v.mean().item()) - self.reporter.report( + self.loss_reporter.report( td_loss=float(q1_loss), + reward_loss=None, logged_rewards=reward, model_values_on_logged_actions=q1_value, ) diff --git a/reagent/training/trainer.py b/reagent/training/trainer.py index 4fb3588a..09bb9719 100644 --- a/reagent/training/trainer.py +++ b/reagent/training/trainer.py @@ -9,10 +9,6 @@ logger = logging.getLogger(__name__) class Trainer: - def __init__(self, minibatch_size: int): - self.reporter = None - self.minibatch_size = minibatch_size - def train(self, training_batch) -> None: raise NotImplementedError() diff --git a/reagent/training/world_model/compress_model_trainer.py b/reagent/training/world_model/compress_model_trainer.py index 836708f4..cf631c12 100644 --- a/reagent/training/world_model/compress_model_trainer.py +++ b/reagent/training/world_model/compress_model_trainer.py @@ -3,7 +3,7 @@ import logging -import reagent.core.types as rlt +import reagent.types as rlt import torch import torch.nn.functional as F from reagent.models.fully_connected_network import FullyConnectedNetwork diff --git a/reagent/training/world_model/mdnrnn_trainer.py b/reagent/training/world_model/mdnrnn_trainer.py index 9be47371..a94844a5 100644 --- a/reagent/training/world_model/mdnrnn_trainer.py +++ b/reagent/training/world_model/mdnrnn_trainer.py @@ -2,16 +2,15 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. import logging -from typing import Optional +from collections import deque +from typing import Deque, Optional -import numpy as np -import reagent.core.types as rlt +import reagent.types as rlt import torch import torch.nn.functional as F from reagent.models.mdn_rnn import gmm_loss from reagent.models.world_model import MemoryNetwork from reagent.parameters import MDNRNNTrainerParameters -from reagent.reporting.world_model_reporter import WorldModelReporter from reagent.training.trainer import Trainer @@ -21,54 +20,48 @@ logger = logging.getLogger(__name__) class MDNRNNTrainer(Trainer): """ Trainer for MDN-RNN """ - def __init__(self, memory_network: MemoryNetwork, params: MDNRNNTrainerParameters): - super().__init__(params.minibatch_size) + def __init__( + self, + memory_network: MemoryNetwork, + params: MDNRNNTrainerParameters, + cum_loss_hist: int = 100, + ): self.memory_network = memory_network self.params = params self.optimizer = torch.optim.Adam( self.memory_network.mdnrnn.parameters(), lr=params.learning_rate ) self.minibatch = 0 - self.reporter = WorldModelReporter() - - def train(self, training_batch: rlt.MemoryNetworkInput) -> None: - if self.params.shuffle_training_data: - _, batch_size, _ = training_batch.next_state.float_features.size() - - training_batch = rlt.MemoryNetworkInput( - state=training_batch.state, - action=training_batch.action, - time_diff=torch.ones_like(training_batch.reward), - # shuffle the data - next_state=training_batch.next_state._replace( - float_features=training_batch.next_state.float_features[ - :, torch.randperm(batch_size), : - ] - ), - reward=training_batch.reward[:, torch.randperm(batch_size)], - not_terminal=training_batch.not_terminal[ # type: ignore - :, torch.randperm(batch_size) - ], - step=None, - ) + self.minibatch_size = params.minibatch_size + self.cum_loss: Deque[float] = deque([], maxlen=cum_loss_hist) + self.cum_bce: Deque[float] = deque([], maxlen=cum_loss_hist) + self.cum_gmm: Deque[float] = deque([], maxlen=cum_loss_hist) + self.cum_mse: Deque[float] = deque([], maxlen=cum_loss_hist) # PageHandler must use this to activate evaluator: self.calc_cpe_in_training = True + + def train(self, training_batch: rlt.MemoryNetworkInput): self.minibatch += 1 (seq_len, batch_size, state_dim) = training_batch.state.float_features.shape self.memory_network.mdnrnn.train() self.optimizer.zero_grad() - losses = self.compute_loss(training_batch, state_dim) + losses = self.get_loss(training_batch, state_dim) losses["loss"].backward() self.optimizer.step() detached_losses = {k: loss.cpu().detach().item() for k, loss in losses.items()} - self.reporter.report(**detached_losses) + self.cum_loss.append(detached_losses["loss"]) + self.cum_gmm.append(detached_losses["gmm"]) + self.cum_bce.append(detached_losses["bce"]) + self.cum_mse.append(detached_losses["mse"]) + del losses + return detached_losses - def compute_loss( + def get_loss( self, training_batch: rlt.MemoryNetworkInput, state_dim: Optional[int] = None ): """ diff --git a/reagent/training/world_model/seq2reward_trainer.py b/reagent/training/world_model/seq2reward_trainer.py index 61895b03..db5259b3 100644 --- a/reagent/training/world_model/seq2reward_trainer.py +++ b/reagent/training/world_model/seq2reward_trainer.py @@ -3,12 +3,12 @@ import logging -import reagent.core.types as rlt +import reagent.types as rlt import torch import torch.nn.functional as F from reagent.models.seq2reward_model import Seq2RewardNetwork from reagent.parameters import Seq2RewardTrainerParameters -from reagent.reporting.world_model_reporter import WorldModelReporter +from reagent.training.loss_reporter import NoOpLossReporter from reagent.training.trainer import Trainer from reagent.training.utils import gen_permutations @@ -28,7 +28,7 @@ class Seq2RewardTrainer(Trainer): self.seq2reward_network.parameters(), lr=params.learning_rate ) self.minibatch_size = self.params.batch_size - self.reporter = WorldModelReporter() + self.loss_reporter = NoOpLossReporter() # PageHandler must use this to activate evaluator: self.calc_cpe_in_training = True @@ -37,7 +37,7 @@ class Seq2RewardTrainer(Trainer): def train(self, training_batch: rlt.MemoryNetworkInput): self.optimizer.zero_grad() - loss = self.compute_loss(training_batch) + loss = self.get_loss(training_batch) loss.backward() self.optimizer.step() detached_loss = loss.cpu().detach().item() @@ -51,11 +51,10 @@ class Seq2RewardTrainer(Trainer): .mean(0) .tolist() ) - self.reporter.report(mse=detached_loss) return (detached_loss, q_values) - def compute_loss(self, training_batch: rlt.MemoryNetworkInput): + def get_loss(self, training_batch: rlt.MemoryNetworkInput): """ Compute losses: MSE(predicted_acc_reward, target_acc_reward) diff --git a/reagent/types.py b/reagent/types.py new file mode 100644 index 00000000..868930e1 --- /dev/null +++ b/reagent/types.py @@ -0,0 +1,717 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +import dataclasses +import logging + +# The dataclasses in this file should be vanilla dataclass to have minimal overhead +from dataclasses import dataclass, field +from typing import Dict, List, NamedTuple, Optional, Tuple, Union + +# Triggering registration to registries +import reagent.core.result_types # noqa +import torch +import torch.nn.functional as F +from reagent.base_dataclass import BaseDataClass +from reagent.core.configuration import param_hash +from reagent.core.dataclasses import dataclass as pydantic_dataclass +from reagent.core.fb_checker import IS_FB_ENVIRONMENT +from reagent.preprocessing.types import InputColumn + + +if IS_FB_ENVIRONMENT: + import reagent.core.fb.fb_result_types # noqa + + +class NoDuplicatedWarningLogger: + def __init__(self, logger): + self.logger = logger + self.msg = set() + + def warning(self, msg): + if msg not in self.msg: + self.logger.warning(msg) + self.msg.add(msg) + + +logger = logging.getLogger(__name__) +no_dup_logger = NoDuplicatedWarningLogger(logger) + + +def isinstance_namedtuple(x): + return isinstance(x, tuple) and hasattr(x, "_fields") + + +@dataclass +class TensorDataClass(BaseDataClass): + def __getattr__(self, attr): + if attr.startswith("__") and attr.endswith("__"): + raise AttributeError + + tensor_attr = getattr(torch.Tensor, attr, None) + + if tensor_attr is None or not callable(tensor_attr): + logger.error( + f"Attemping to call torch.Tensor.{attr} on " + f"{type(self)} (instance of TensorDataClass)." + ) + if tensor_attr is None: + raise AttributeError(f"torch.Tensor doesn't have {attr} attribute.") + else: + raise RuntimeError(f"Tensor.{attr} is not callable.") + + def continuation(*args, **kwargs): + def f(v): + # if possible, returns v.attr(*args, **kwargs). + # otws, return v + if isinstance(v, (torch.Tensor, TensorDataClass)): + return getattr(v, attr)(*args, **kwargs) + elif isinstance(v, dict): + return {kk: f(vv) for kk, vv in v.items()} + elif isinstance(v, tuple): + return tuple(f(vv) for vv in v) + return v + + return type(self)(**f(self.__dict__)) + + return continuation + + def cuda(self, *args, **kwargs): + cuda_tensor = {} + for k, v in self.__dict__.items(): # noqa F402 + if isinstance(v, torch.Tensor): + kwargs["non_blocking"] = kwargs.get("non_blocking", True) + cuda_tensor[k] = v.cuda(*args, **kwargs) + elif isinstance(v, TensorDataClass): + cuda_tensor[k] = v.cuda(*args, **kwargs) + else: + cuda_tensor[k] = v + return type(self)(**cuda_tensor) + + +# (offset, value) +IdListFeatureValue = Tuple[torch.Tensor, torch.Tensor] +# (offset, key, value) +IdScoreListFeatureValue = Tuple[torch.Tensor, torch.Tensor, torch.Tensor] +# name -> value +IdListFeature = Dict[str, IdListFeatureValue] +IdScoreListFeature = Dict[str, IdScoreListFeatureValue] +# id -> value +ServingIdListFeature = Dict[int, IdListFeatureValue] +ServingIdScoreListFeature = Dict[int, IdScoreListFeatureValue] + + +##### +# FIXME: These config types are misplaced but we need to write FBL config adapter +# if we moved them. +###### + + +@pydantic_dataclass +class IdListFeatureConfig(BaseDataClass): + name: str + # integer feature ID + feature_id: int + # name of the embedding table to use + id_mapping_name: str + + +@pydantic_dataclass +class IdScoreListFeatureConfig(BaseDataClass): + name: str + # integer feature ID + feature_id: int + # name of the embedding table to use + id_mapping_name: str + + +@pydantic_dataclass +class FloatFeatureInfo(BaseDataClass): + name: str + feature_id: int + + +@pydantic_dataclass +class IdMapping(object): + __hash__ = param_hash + + ids: List[int] = field(default_factory=list) + + def __post_init_post_parse__(self): + """ + used in preprocessing + ids list represents mapping from idx -> value + we want the reverse: from feature to embedding table indices + """ + self._id2index: Dict[int, int] = {} + + @property + def id2index(self) -> Dict[int, int]: + # pyre-fixme[16]: `IdMapping` has no attribute `_id2index`. + if not self._id2index: + self._id2index = {id: i for i, id in enumerate(self.ids)} + return self._id2index + + @property + def table_size(self): + return len(self.ids) + + +@pydantic_dataclass +class ModelFeatureConfig(BaseDataClass): + float_feature_infos: List[FloatFeatureInfo] = field(default_factory=list) + # table name -> id mapping + id_mapping_config: Dict[str, IdMapping] = field(default_factory=dict) + # id_list_feature_configs is feature_id -> list of values + id_list_feature_configs: List[IdListFeatureConfig] = field(default_factory=list) + # id_score_list_feature_configs is feature_id -> (keys -> values) + id_score_list_feature_configs: List[IdScoreListFeatureConfig] = field( + default_factory=list + ) + + def __post_init_post_parse__(self): + both_lists = self.id_list_feature_configs + self.id_score_list_feature_configs + if not self.only_dense: + # sanity check for keys in mapping config + ids = [config.feature_id for config in both_lists] + names = [config.name for config in both_lists] + assert len(ids) == len(set(ids)), f"duplicates in ids: {ids}" + assert len(names) == len(set(names)), f"duplicates in names: {names}" + assert len(ids) == len(names), f"{len(ids)} != {len(names)}" + + self._id2name = {config.feature_id: config.name for config in both_lists} + self._name2id = {config.name: config.feature_id for config in both_lists} + self._id2config = {config.feature_id: config for config in both_lists} + self._name2config = {config.name: config for config in both_lists} + + @property + def only_dense(self): + return not (self.id_list_feature_configs or self.id_score_list_feature_configs) + + @property + def id2name(self): + return self._id2name + + @property + def name2id(self): + return self._name2id + + @property + def id2config(self): + return self._id2config + + @property + def name2config(self): + return self._name2config + + +###### +# dataclasses for internal API +###### + + +@dataclass +class ValuePresence(TensorDataClass): + value: torch.Tensor + presence: Optional[torch.Tensor] + + +@dataclass +class ActorOutput(TensorDataClass): + action: torch.Tensor + log_prob: Optional[torch.Tensor] = None + squashed_mean: Optional[torch.Tensor] = None + + +@dataclass +class DocList(TensorDataClass): + # the shape is (batch_size, num_candidates, num_document_features) + float_features: torch.Tensor + # the shapes are (batch_size, num_candidates) + mask: torch.Tensor + value: torch.Tensor + + def __post_init__(self): + assert ( + len(self.float_features.shape) == 3 + ), f"Unexpected shape: {self.float_features.shape}" + + # pyre-fixme[56]: Decorator `torch.no_grad(...)` could not be called, because + # its type `no_grad` is not callable. + @torch.no_grad() + def select_slate(self, action: torch.Tensor): + row_idx = torch.repeat_interleave( + torch.arange(action.shape[0]).unsqueeze(1), action.shape[1], dim=1 + ) + mask = self.mask[row_idx, action] + # Make sure the indices are in the right range + assert mask.to(torch.bool).all() + float_features = self.float_features[row_idx, action] + value = self.value[row_idx, action] + return DocList(float_features, mask, value) + + def as_feature_data(self): + _batch_size, _slate_size, feature_dim = self.float_features.shape + return FeatureData(self.float_features.view(-1, feature_dim)) + + +@dataclass +class FeatureData(TensorDataClass): + # For dense features, shape is (batch_size, feature_dim) + float_features: torch.Tensor + id_list_features: IdListFeature = dataclasses.field(default_factory=dict) + id_score_list_features: IdScoreListFeature = dataclasses.field(default_factory=dict) + # For sequence, shape is (stack_size, batch_size, feature_dim) + stacked_float_features: Optional[torch.Tensor] = None + # For ranking algos, + candidate_docs: Optional[DocList] = None + # Experimental: sticking this here instead of putting it in float_features + # because a lot of places derive the shape of float_features from + # normalization parameters. + time_since_first: Optional[torch.Tensor] = None + + def __post_init__(self): + def usage(): + return ( + f"For sequence features, use `stacked_float_features`." + f"For document features, use `candidate_doc_float_features`." + ) + + if self.float_features.ndim == 3: + no_dup_logger.warning(f"`float_features` should be 2D.\n{usage()}") + elif self.float_features.ndim != 2: + raise ValueError( + f"float_features should be 2D; got {self.float_features.shape}.\n{usage()}" + ) + + @property + def has_float_features_only(self) -> bool: + return ( + not self.id_list_features + and self.time_since_first is None + and self.candidate_docs is None + ) + + def get_tiled_batch(self, num_tiles: int): + assert ( + self.has_float_features_only + ), f"only works for float features now: {self}" + """ + tiled_feature should be (batch_size * num_tiles, feature_dim) + forall i in [batch_size], + tiled_feature[i*num_tiles:(i+1)*num_tiles] should be feat[i] + """ + feat = self.float_features + assert ( + len(feat.shape) == 2 + ), f"Need feat shape to be (batch_size, feature_dim), got {feat.shape}." + batch_size, _ = feat.shape + # pyre-fixme[16]: `Tensor` has no attribute `repeat_interleave`. + tiled_feat = feat.repeat_interleave(repeats=num_tiles, dim=0) + return FeatureData(float_features=tiled_feat) + + +class TensorFeatureData(torch.nn.Module): + """ + Primarily for using in nn.Sequential + """ + + def forward(self, input: torch.Tensor) -> FeatureData: + assert isinstance(input, torch.Tensor) + return FeatureData(input) + + +class ServingFeatureData(NamedTuple): + float_features_with_presence: Tuple[torch.Tensor, torch.Tensor] + id_list_features: ServingIdListFeature + id_score_list_features: ServingIdScoreListFeature + + +@dataclass +class PreprocessedRankingInput(TensorDataClass): + state: FeatureData + src_seq: FeatureData + src_src_mask: torch.Tensor + tgt_in_seq: Optional[FeatureData] = None + tgt_out_seq: Optional[FeatureData] = None + tgt_tgt_mask: Optional[torch.Tensor] = None + slate_reward: Optional[torch.Tensor] = None + position_reward: Optional[torch.Tensor] = None + # all indices will be +2 to account for padding + # symbol (0) and decoder_start_symbol (1) + src_in_idx: Optional[torch.Tensor] = None + tgt_in_idx: Optional[torch.Tensor] = None + tgt_out_idx: Optional[torch.Tensor] = None + tgt_out_probs: Optional[torch.Tensor] = None + # store ground-truth target sequences + optim_tgt_in_idx: Optional[torch.Tensor] = None + optim_tgt_out_idx: Optional[torch.Tensor] = None + optim_tgt_in_seq: Optional[FeatureData] = None + optim_tgt_out_seq: Optional[FeatureData] = None + + def batch_size(self) -> int: + return self.state.float_features.size()[0] + + @classmethod + def from_tensors( + cls, + state: torch.Tensor, + src_seq: torch.Tensor, + src_src_mask: torch.Tensor, + tgt_in_seq: Optional[torch.Tensor] = None, + tgt_out_seq: Optional[torch.Tensor] = None, + tgt_tgt_mask: Optional[torch.Tensor] = None, + slate_reward: Optional[torch.Tensor] = None, + position_reward: Optional[torch.Tensor] = None, + src_in_idx: Optional[torch.Tensor] = None, + tgt_in_idx: Optional[torch.Tensor] = None, + tgt_out_idx: Optional[torch.Tensor] = None, + tgt_out_probs: Optional[torch.Tensor] = None, + optim_tgt_in_idx: Optional[torch.Tensor] = None, + optim_tgt_out_idx: Optional[torch.Tensor] = None, + optim_tgt_in_seq: Optional[torch.Tensor] = None, + optim_tgt_out_seq: Optional[torch.Tensor] = None, + **kwargs, + ): + assert isinstance(state, torch.Tensor) + assert isinstance(src_seq, torch.Tensor) + assert isinstance(src_src_mask, torch.Tensor) + assert tgt_in_seq is None or isinstance(tgt_in_seq, torch.Tensor) + assert tgt_out_seq is None or isinstance(tgt_out_seq, torch.Tensor) + assert tgt_tgt_mask is None or isinstance(tgt_tgt_mask, torch.Tensor) + assert slate_reward is None or isinstance(slate_reward, torch.Tensor) + assert position_reward is None or isinstance(position_reward, torch.Tensor) + assert src_in_idx is None or isinstance(src_in_idx, torch.Tensor) + assert tgt_in_idx is None or isinstance(tgt_in_idx, torch.Tensor) + assert tgt_out_idx is None or isinstance(tgt_out_idx, torch.Tensor) + assert tgt_out_probs is None or isinstance(tgt_out_probs, torch.Tensor) + assert optim_tgt_out_idx is None or isinstance(optim_tgt_out_idx, torch.Tensor) + assert optim_tgt_out_idx is None or isinstance(optim_tgt_out_idx, torch.Tensor) + assert optim_tgt_in_seq is None or isinstance(optim_tgt_in_seq, torch.Tensor) + assert optim_tgt_out_seq is None or isinstance(optim_tgt_out_seq, torch.Tensor) + + return cls( + state=FeatureData(float_features=state), + src_seq=FeatureData(float_features=src_seq), + src_src_mask=src_src_mask, + tgt_in_seq=FeatureData(float_features=tgt_in_seq) + if tgt_in_seq is not None + else None, + tgt_out_seq=FeatureData(float_features=tgt_out_seq) + if tgt_out_seq is not None + else None, + tgt_tgt_mask=tgt_tgt_mask, + slate_reward=slate_reward, + position_reward=position_reward, + src_in_idx=src_in_idx, + tgt_in_idx=tgt_in_idx, + tgt_out_idx=tgt_out_idx, + tgt_out_probs=tgt_out_probs, + optim_tgt_in_idx=optim_tgt_in_idx, + optim_tgt_out_idx=optim_tgt_out_idx, + optim_tgt_in_seq=FeatureData(float_features=optim_tgt_in_seq) + if optim_tgt_in_seq is not None + else None, + optim_tgt_out_seq=FeatureData(float_features=optim_tgt_out_seq) + if optim_tgt_out_seq is not None + else None, + ) + + def __post_init__(self): + if ( + isinstance(self.state, torch.Tensor) + or isinstance(self.src_seq, torch.Tensor) + or isinstance(self.tgt_in_seq, torch.Tensor) + or isinstance(self.tgt_out_seq, torch.Tensor) + or isinstance(self.optim_tgt_in_seq, torch.Tensor) + or isinstance(self.optim_tgt_out_seq, torch.Tensor) + ): + raise ValueError( + f"Use from_tensors() {type(self.state)} {type(self.src_seq)} " + f"{type(self.tgt_in_seq)} {type(self.tgt_out_seq)} " + f"{type(self.optim_tgt_in_seq)} {type(self.optim_tgt_out_seq)} " + ) + + +@dataclass +class BaseInput(TensorDataClass): + """ + Base class for all inputs, both raw and preprocessed + """ + + state: FeatureData + next_state: FeatureData + reward: torch.Tensor + time_diff: torch.Tensor + step: Optional[torch.Tensor] + not_terminal: torch.Tensor + + def batch_size(self): + return self.state.float_features.size()[0] + + @classmethod + def from_dict(cls, batch): + id_list_features = batch.get(InputColumn.STATE_ID_LIST_FEATURES, None) or {} + id_score_list_features = ( + batch.get(InputColumn.STATE_ID_SCORE_LIST_FEATURES, None) or {} + ) + next_id_list_features = ( + batch.get(InputColumn.NEXT_STATE_ID_LIST_FEATURES, None) or {} + ) + next_id_score_list_features = ( + batch.get(InputColumn.NEXT_STATE_ID_SCORE_LIST_FEATURES, None) or {} + ) + return BaseInput( + state=FeatureData( + float_features=batch[InputColumn.STATE_FEATURES], + id_list_features=id_list_features, + id_score_list_features=id_score_list_features, + ), + next_state=FeatureData( + float_features=batch[InputColumn.NEXT_STATE_FEATURES], + id_list_features=next_id_list_features, + id_score_list_features=next_id_score_list_features, + ), + reward=batch[InputColumn.REWARD], + time_diff=batch[InputColumn.TIME_DIFF], + step=batch[InputColumn.STEP], + not_terminal=batch[InputColumn.NOT_TERMINAL], + ) + + +@dataclass +class ExtraData(TensorDataClass): + mdp_id: Optional[torch.Tensor] = None + sequence_number: Optional[torch.Tensor] = None + action_probability: Optional[torch.Tensor] = None + max_num_actions: Optional[int] = None + metrics: Optional[torch.Tensor] = None + + @classmethod + def from_dict(cls, d): + return cls(**{f.name: d.get(f.name, None) for f in dataclasses.fields(cls)}) + + +@dataclass +class DiscreteDqnInput(BaseInput): + action: torch.Tensor + next_action: torch.Tensor + possible_actions_mask: torch.Tensor + possible_next_actions_mask: torch.Tensor + extras: ExtraData + + @classmethod + def from_dict(cls, batch): + base = super().from_dict(batch) + return cls( + state=base.state, + next_state=base.next_state, + reward=base.reward, + time_diff=base.time_diff, + step=base.step, + not_terminal=base.not_terminal, + action=batch[InputColumn.ACTION], + next_action=batch[InputColumn.NEXT_ACTION], + possible_actions_mask=batch[InputColumn.POSSIBLE_ACTIONS_MASK], + possible_next_actions_mask=batch[InputColumn.POSSIBLE_NEXT_ACTIONS_MASK], + extras=batch[InputColumn.EXTRAS], + ) + + +@dataclass +class SlateQInput(BaseInput): + """ + The shapes of `reward`, `reward_mask`, & `next_item_mask` are + `(batch_size, slate_size)`. + + `reward_mask` indicated whether the reward could be observed, e.g., + the item got into viewport or not. + """ + + action: torch.Tensor + next_action: torch.Tensor + reward_mask: torch.Tensor + extras: Optional[ExtraData] = None + + @classmethod + def from_dict(cls, d): + action = d["action"] + next_action = d["next_action"] + return cls( + state=FeatureData( + float_features=d["state_features"], + candidate_docs=DocList( + float_features=d["candidate_features"], + mask=d["item_mask"], + value=d["item_probability"], + ), + ), + next_state=FeatureData( + float_features=d["next_state_features"], + candidate_docs=DocList( + float_features=d["next_candidate_features"], + mask=d["next_item_mask"], + value=d["next_item_probability"], + ), + ), + action=action, + next_action=next_action, + reward=d["position_reward"], + reward_mask=d["reward_mask"], + time_diff=d["time_diff"], + not_terminal=d["not_terminal"], + step=None, + extras=ExtraData.from_dict(d), + ) + + +@dataclass +class ParametricDqnInput(BaseInput): + action: FeatureData + next_action: FeatureData + possible_actions: FeatureData + possible_actions_mask: torch.Tensor + possible_next_actions: FeatureData + possible_next_actions_mask: torch.Tensor + extras: Optional[ExtraData] = None + + @classmethod + def from_dict(cls, batch): + return cls( + state=FeatureData(float_features=batch["state_features"]), + action=FeatureData(float_features=batch["action"]), + next_state=FeatureData(float_features=batch["next_state_features"]), + next_action=FeatureData(float_features=batch["next_action"]), + possible_actions=FeatureData(float_features=batch["possible_actions"]), + possible_actions_mask=batch["possible_actions_mask"], + possible_next_actions=FeatureData( + float_features=batch["possible_next_actions"] + ), + possible_next_actions_mask=batch["possible_next_actions_mask"], + reward=batch["reward"], + not_terminal=batch["not_terminal"], + time_diff=batch["time_diff"], + step=batch["step"], + extras=batch["extras"], + ) + + +@dataclass +class PolicyNetworkInput(BaseInput): + action: FeatureData + next_action: FeatureData + extras: Optional[ExtraData] = None + + @classmethod + def from_dict(cls, batch): + return cls( + state=FeatureData(float_features=batch["state_features"]), + action=FeatureData(float_features=batch["action"]), + next_state=FeatureData(float_features=batch["next_state_features"]), + next_action=FeatureData(float_features=batch["next_action"]), + reward=batch["reward"], + not_terminal=batch["not_terminal"], + time_diff=batch["time_diff"], + step=batch["step"], + extras=batch["extras"], + ) + + def batch_size(self) -> int: + return self.state.float_features.shape[0] + + +@dataclass +class PolicyGradientInput(BaseDataClass): + state: FeatureData + action: torch.Tensor + reward: torch.Tensor + log_prob: torch.Tensor + + @classmethod + def input_prototype(cls): + num_classes = 5 + batch_size = 10 + state_dim = 3 + return cls( + state=FeatureData(float_features=torch.randn(batch_size, state_dim)), + action=F.one_hot(torch.randint(high=num_classes, size=(batch_size,))), + reward=torch.rand(batch_size), + log_prob=torch.log(torch.rand(batch_size)), + ) + + +@dataclass +class MemoryNetworkInput(BaseInput): + action: torch.Tensor + + def batch_size(self): + if len(self.state.float_features.size()) == 2: + return self.state.float_features.size()[0] + elif len(self.state.float_features.size()) == 3: + return self.state.float_features.size()[1] + else: + raise NotImplementedError() + + +@dataclass +class PreprocessedTrainingBatch(TensorDataClass): + training_input: Union[PreprocessedRankingInput] + # TODO: deplicate this and move into individual ones. + extras: ExtraData = field(default_factory=ExtraData) + + def batch_size(self): + return self.training_input.state.float_features.size()[0] + + +@dataclass +class MemoryNetworkOutput(TensorDataClass): + mus: torch.Tensor + sigmas: torch.Tensor + logpi: torch.Tensor + reward: torch.Tensor + not_terminal: torch.Tensor + last_step_lstm_hidden: torch.Tensor + last_step_lstm_cell: torch.Tensor + all_steps_lstm_hidden: torch.Tensor + + +@dataclass +class Seq2RewardOutput(TensorDataClass): + acc_reward: torch.Tensor + + +@dataclass +class DqnPolicyActionSet(TensorDataClass): + greedy: int + softmax: Optional[int] = None + greedy_act_name: Optional[str] = None + softmax_act_name: Optional[str] = None + softmax_act_prob: Optional[float] = None + + +@dataclass +class PlanningPolicyOutput(TensorDataClass): + # best action to take next + next_best_continuous_action: Optional[torch.Tensor] = None + next_best_discrete_action_one_hot: Optional[torch.Tensor] = None + next_best_discrete_action_idx: Optional[int] = None + + +@dataclass +class RankingOutput(TensorDataClass): + # a tensor of integer indices w.r.t. to possible candidates + # shape: batch_size, tgt_seq_len + ranked_tgt_out_idx: Optional[torch.Tensor] = None + # generative probability of ranked tgt sequences at each decoding step + # shape: batch_size, tgt_seq_len, candidate_size + ranked_tgt_out_probs: Optional[torch.Tensor] = None + # log probabilities of given tgt sequences are used in REINFORCE + # shape: batch_size + log_probs: Optional[torch.Tensor] = None + # encoder scores in tgt_out_idx order + encoder_scores: Optional[torch.Tensor] = None + + +@dataclass +class RewardNetworkOutput(TensorDataClass): + predicted_reward: torch.Tensor diff --git a/reagent/validators/model_validator.py b/reagent/validators/model_validator.py index ab9b1622..47a1ceb1 100644 --- a/reagent/validators/model_validator.py +++ b/reagent/validators/model_validator.py @@ -5,8 +5,8 @@ import inspect import logging from reagent.core.registry_meta import RegistryMeta -from reagent.core.rl_training_output import RLTrainingOutput -from reagent.reporting.result_registries import ValidationResult +from reagent.core.types import RLTrainingOutput +from reagent.workflow.result_registries import ValidationResult logger = logging.getLogger(__name__) @@ -25,7 +25,7 @@ class ModelValidator(metaclass=RegistryMeta): """ result = self.do_validate(training_output) # Avoid circular dependency at import time - from reagent.core.union import ValidationResult__Union + from reagent.core.types import ValidationResult__Union # We need to use inspection because the result can be a future when running on # FBL diff --git a/reagent/validators/no_validation.py b/reagent/validators/no_validation.py index 18e2ba7f..a351a131 100644 --- a/reagent/validators/no_validation.py +++ b/reagent/validators/no_validation.py @@ -2,7 +2,7 @@ from reagent.core.dataclasses import dataclass from reagent.core.result_types import NoValidationResults -from reagent.core.rl_training_output import RLTrainingOutput +from reagent.core.types import RLTrainingOutput from reagent.validators.model_validator import ModelValidator diff --git a/reagent/data_fetchers/oss_data_fetcher.py b/reagent/workflow/data_fetcher.py similarity index 77% rename from reagent/data_fetchers/oss_data_fetcher.py rename to reagent/workflow/data_fetcher.py index 4d3ccd04..e9b1f03b 100644 --- a/reagent/data_fetchers/oss_data_fetcher.py +++ b/reagent/workflow/data_fetcher.py @@ -1,14 +1,8 @@ #!/usr/bin/env python3 - import logging -from typing import Dict, List, Optional, Tuple - -import reagent.core.types as rlt - -# pyre-fixme[21]: Could not find `petastorm`. -from petastorm import make_batch_reader -from petastorm.pytorch import DataLoader, decimal_friendly_collate +from typing import List, Optional, Tuple +# pyre-fixme[21]: Could not find `pyspark`. # pyre-fixme[21]: Could not find `pyspark`. from pyspark.sql.functions import col, crc32, explode, map_keys, udf @@ -23,20 +17,7 @@ from pyspark.sql.types import ( StructField, StructType, ) -from reagent.core.types import ( - Dataset, - OssDataset, - PreprocessingOptions, - ReaderOptions, - TableSpec, -) -from reagent.data_fetchers.data_fetcher import DataFetcher -from reagent.evaluation.evaluation_data_page import EvaluationDataPage -from reagent.parameters import NormalizationParameters -from reagent.preprocessing.batch_preprocessor import BatchPreprocessor -from reagent.torch_utils import dict_to_tensor -from reagent.training import RLTrainer, SACTrainer, TD3Trainer -from reagent.workflow.identify_types_flow import identify_normalization_parameters +from reagent.core.types import Dataset, OssDataset, TableSpec from reagent.workflow.spark_utils import get_spark_session, get_table_url @@ -396,9 +377,8 @@ def rand_string(length): import random """Generate a random string of fixed length """ - r = random.SystemRandom() letters = string.ascii_lowercase - return "".join(r.choice(letters) for _ in range(length)) + return "".join(random.choice(letters) for _ in range(length)) def upload_as_parquet(df) -> Dataset: @@ -471,108 +451,3 @@ def query_data( include_possible_actions=include_possible_actions, ) return upload_as_parquet(df) - - -def collate_and_preprocess(batch_preprocessor: BatchPreprocessor, use_gpu: bool): - """ Helper for Petastorm's DataLoader to preprocess. - TODO(kaiwenw): parallelize preprocessing by using transform of Petastorm reader - Should pin memory and preprocess in reader and convert to gpu in collate_fn. - """ - - def collate_fn(batch_list: List[Dict]): - batch = decimal_friendly_collate(batch_list) - preprocessed_batch = batch_preprocessor(batch) - if use_gpu: - preprocessed_batch = preprocessed_batch.cuda() - return preprocessed_batch - - return collate_fn - - -class OssDataFetcher(DataFetcher): - def query_data(self, **kwargs): - return query_data(**kwargs) - - def query_data_parametric(self, **kwargs): - return query_data(**kwargs) - - def identify_normalization_parameters( - self, - table_spec: TableSpec, - column_name: str, - preprocessing_options: PreprocessingOptions, - seed: Optional[int] = None, - ) -> Dict[int, NormalizationParameters]: - return identify_normalization_parameters( - table_spec, column_name, preprocessing_options, seed - ) - - def get_table_row_count(self, dataset: OssDataset): - spark = get_spark_session() - return spark.read.parquet(dataset.parquet_url).count() - - def gather_and_sort_eval_data( - self, - trainer: RLTrainer, - eval_dataset: Dataset, - batch_preprocessor: BatchPreprocessor, - use_gpu: bool, - reader_options: ReaderOptions, - ) -> EvaluationDataPage: - """ Sorts, computes logged values and validates the EvaluationDataPage """ - if isinstance(trainer, (SACTrainer, TD3Trainer)): - raise NotImplementedError("TODO: Implement CPE for continuous algos") - assert ( - trainer.calc_cpe_in_training - ), "this function should only be called when this is true." - - # first read the eval_dataset as EvaluationDataPages - device = "cuda" if use_gpu else "cpu" - eval_data = None - with make_batch_reader( - eval_dataset.parquet_url, - num_epochs=1, - reader_pool_type=reader_options.petastorm_reader_pool_type, - ) as reader: - for batch in reader: - assert rlt.isinstance_namedtuple(batch) - tensor_batch = dict_to_tensor(batch._asdict(), device=device) - tdp: rlt.PreprocessedTrainingBatch = batch_preprocessor(tensor_batch) - edp = EvaluationDataPage.create_from_training_batch(tdp, trainer) - if eval_data is None: - eval_data = edp - else: - eval_data = eval_data.append(edp) - - eval_data = eval_data.sort() - eval_data = eval_data.compute_values(trainer.gamma) - eval_data.validate() - return eval_data - - def get_dataloader( - self, - dataset: Dataset, - batch_size: int, - batch_preprocessor: Optional[BatchPreprocessor], - use_gpu: bool, - reader_options: ReaderOptions, - ): - """ get petastorm loader for dataset (with preprocessor) """ - data_reader = make_batch_reader( - dataset.parquet_url, - num_epochs=1, - reader_pool_type=reader_options.petastorm_reader_pool_type, - ) - # NOTE: must be wrapped by DataLoaderWrapper to call __exit__() on end of epoch - return DataLoader( - data_reader, - batch_size=batch_size, - collate_fn=collate_and_preprocess( - batch_preprocessor=batch_preprocessor, use_gpu=use_gpu - ), - ) - - def get_post_dataloader_preprocessor( - self, reader_options: ReaderOptions, use_gpu: bool - ): - return None diff --git a/reagent/workflow/env.py b/reagent/workflow/env.py new file mode 100644 index 00000000..693585ef --- /dev/null +++ b/reagent/workflow/env.py @@ -0,0 +1,6 @@ +#!/usr/bin/env python3 + + +def get_workflow_id() -> int: + # This is just stub. You will want to replace this file. + return 987654321 diff --git a/reagent/workflow/gym_batch_rl.py b/reagent/workflow/gym_batch_rl.py index dcd01ba3..214dbba1 100644 --- a/reagent/workflow/gym_batch_rl.py +++ b/reagent/workflow/gym_batch_rl.py @@ -16,10 +16,10 @@ from reagent.gym.envs.gym import Gym from reagent.gym.policies.predictor_policies import create_predictor_policy_from_model from reagent.gym.runners.gymrunner import evaluate_for_n_episodes from reagent.gym.utils import fill_replay_buffer -from reagent.model_managers.union import ModelManager__Union from reagent.publishers.union import FileSystemPublisher, ModelPublisher__Union from reagent.replay_memory.circular_replay_buffer import ReplayBuffer from reagent.replay_memory.utils import replay_buffer_to_pre_timeline_df +from reagent.workflow.model_managers.union import ModelManager__Union from reagent.workflow.spark_utils import call_spark_class, get_spark_session diff --git a/reagent/workflow/identify_types_flow.py b/reagent/workflow/identify_types_flow.py index e77a4f31..66260865 100644 --- a/reagent/workflow/identify_types_flow.py +++ b/reagent/workflow/identify_types_flow.py @@ -3,7 +3,7 @@ from typing import Dict, List, Optional -import reagent.core.types as rlt +import reagent.types as rlt # pyre-fixme[21]: Could not find `pyspark`. # pyre-fixme[21]: Could not find `pyspark`. diff --git a/reagent/model_managers/actor_critic/__init__.py b/reagent/workflow/model_managers/actor_critic/__init__.py similarity index 100% rename from reagent/model_managers/actor_critic/__init__.py rename to reagent/workflow/model_managers/actor_critic/__init__.py diff --git a/reagent/model_managers/actor_critic/sac.py b/reagent/workflow/model_managers/actor_critic/sac.py similarity index 67% rename from reagent/model_managers/actor_critic/sac.py rename to reagent/workflow/model_managers/actor_critic/sac.py index 3f94e729..95bc4da3 100644 --- a/reagent/model_managers/actor_critic/sac.py +++ b/reagent/workflow/model_managers/actor_critic/sac.py @@ -3,19 +3,10 @@ import logging -from typing import Dict, Optional +from typing import Optional import torch from reagent.core.dataclasses import dataclass, field -from reagent.core.rl_training_output import RLTrainingOutput -from reagent.core.types import ( - Dataset, - PreprocessingOptions, - ReaderOptions, - RewardOptions, - TableSpec, -) -from reagent.model_managers.actor_critic_base import ActorCriticBase from reagent.models.base import ModelBase from reagent.net_builder.continuous_actor.gaussian_fully_connected import ( GaussianFullyConnected, @@ -29,8 +20,9 @@ from reagent.net_builder.unions import ( from reagent.net_builder.value.fully_connected import ( FullyConnected as ValueFullyConnected, ) -from reagent.parameters import NormalizationData, NormalizationKey, param_hash +from reagent.parameters import param_hash from reagent.training import SACTrainer, SACTrainerParameters +from reagent.workflow.model_managers.actor_critic_base import ActorCriticBase logger = logging.getLogger(__name__) @@ -67,28 +59,26 @@ class SAC(ActorCriticBase): def __post_init_post_parse__(self): super().__post_init_post_parse__() + self._actor_network: Optional[ModelBase] = None + self.rl_parameters = self.trainer_param.rl - def build_trainer( - self, - use_gpu: bool, - normalization_data_map: Dict[str, NormalizationData], - reward_options: RewardOptions, - ) -> SACTrainer: + def build_trainer(self) -> SACTrainer: actor_net_builder = self.actor_net_builder.value - actor_network = actor_net_builder.build_actor( - normalization_data_map[NormalizationKey.STATE], - normalization_data_map[NormalizationKey.ACTION], + # pyre-fixme[16]: `SAC` has no attribute `_actor_network`. + # pyre-fixme[16]: `SAC` has no attribute `_actor_network`. + self._actor_network = actor_net_builder.build_actor( + self.state_normalization_data, self.action_normalization_data ) critic_net_builder = self.critic_net_builder.value - q1_network = critic_net_builder.build_q_network( - normalization_data_map[NormalizationKey.STATE], - normalization_data_map[NormalizationKey.ACTION], + # pyre-fixme[16]: `SAC` has no attribute `_q1_network`. + # pyre-fixme[16]: `SAC` has no attribute `_q1_network`. + self._q1_network = critic_net_builder.build_q_network( + self.state_normalization_data, self.action_normalization_data ) q2_network = ( critic_net_builder.build_q_network( - normalization_data_map[NormalizationKey.STATE], - normalization_data_map[NormalizationKey.ACTION], + self.state_normalization_data, self.action_normalization_data ) if self.use_2_q_functions else None @@ -100,36 +90,35 @@ class SAC(ActorCriticBase): # pyre-fixme[16]: `Optional` has no attribute `value`. value_net_builder = self.value_net_builder.value value_network = value_net_builder.build_value_network( - normalization_data_map[NormalizationKey.STATE] + self.state_normalization_data ) - if use_gpu: - q1_network.cuda() + if self.use_gpu: + self._q1_network.cuda() if q2_network: q2_network.cuda() if value_network: value_network.cuda() - actor_network.cuda() + self._actor_network.cuda() trainer = SACTrainer( - actor_network=actor_network, - q1_network=q1_network, + actor_network=self._actor_network, + q1_network=self._q1_network, value_network=value_network, q2_network=q2_network, - use_gpu=use_gpu, + use_gpu=self.use_gpu, # pyre-fixme[16]: `SACTrainerParameters` has no attribute `asdict`. # pyre-fixme[16]: `SACTrainerParameters` has no attribute `asdict`. **self.trainer_param.asdict(), ) return trainer - def build_serving_module( - self, normalization_data_map: Dict[str, NormalizationData], trainer: SACTrainer - ) -> torch.nn.Module: + def build_serving_module(self) -> torch.nn.Module: net_builder = self.actor_net_builder.value + assert self._actor_network is not None return net_builder.build_serving_module( - trainer.actor_network, - normalization_data_map[NormalizationKey.STATE], - normalization_data_map[NormalizationKey.ACTION], + self._actor_network, + self.state_normalization_data, + self.action_normalization_data, serve_mean_policy=self.serve_mean_policy, ) diff --git a/reagent/model_managers/actor_critic/td3.py b/reagent/workflow/model_managers/actor_critic/td3.py similarity index 61% rename from reagent/model_managers/actor_critic/td3.py rename to reagent/workflow/model_managers/actor_critic/td3.py index 95641fbe..60b3bdaa 100644 --- a/reagent/model_managers/actor_critic/td3.py +++ b/reagent/workflow/model_managers/actor_critic/td3.py @@ -3,19 +3,10 @@ import logging -from typing import Dict, Optional +from typing import Optional import torch from reagent.core.dataclasses import dataclass, field -from reagent.core.rl_training_output import RLTrainingOutput -from reagent.core.types import ( - Dataset, - PreprocessingOptions, - ReaderOptions, - RewardOptions, - TableSpec, -) -from reagent.model_managers.actor_critic_base import ActorCriticBase from reagent.models.base import ModelBase from reagent.net_builder.continuous_actor.fully_connected import ( FullyConnected as ContinuousFullyConnected, @@ -27,13 +18,9 @@ from reagent.net_builder.unions import ( ContinuousActorNetBuilder__Union, ParametricDQNNetBuilder__Union, ) -from reagent.parameters import ( - EvaluationParameters, - NormalizationData, - NormalizationKey, - param_hash, -) +from reagent.parameters import EvaluationParameters, param_hash from reagent.training import TD3Trainer, TD3TrainerParameters +from reagent.workflow.model_managers.actor_critic_base import ActorCriticBase logger = logging.getLogger(__name__) @@ -63,56 +50,53 @@ class TD3(ActorCriticBase): def __post_init_post_parse__(self): super().__post_init_post_parse__() + self._actor_network: Optional[ModelBase] = None + self.rl_parameters = self.trainer_param.rl - def build_trainer( - self, - use_gpu: bool, - normalization_data_map: Dict[str, NormalizationData], - reward_options: RewardOptions, - ) -> TD3Trainer: + def build_trainer(self) -> TD3Trainer: actor_net_builder = self.actor_net_builder.value - actor_network = actor_net_builder.build_actor( - normalization_data_map[NormalizationKey.STATE], - normalization_data_map[NormalizationKey.ACTION], + # pyre-fixme[16]: `TD3` has no attribute `_actor_network`. + # pyre-fixme[16]: `TD3` has no attribute `_actor_network`. + self._actor_network = actor_net_builder.build_actor( + self.state_normalization_data, self.action_normalization_data ) critic_net_builder = self.critic_net_builder.value - q1_network = critic_net_builder.build_q_network( - normalization_data_map[NormalizationKey.STATE], - normalization_data_map[NormalizationKey.ACTION], + # pyre-fixme[16]: `TD3` has no attribute `_q1_network`. + # pyre-fixme[16]: `TD3` has no attribute `_q1_network`. + self._q1_network = critic_net_builder.build_q_network( + self.state_normalization_data, self.action_normalization_data ) q2_network = ( critic_net_builder.build_q_network( - normalization_data_map[NormalizationKey.STATE], - normalization_data_map[NormalizationKey.ACTION], + self.state_normalization_data, self.action_normalization_data ) if self.use_2_q_functions else None ) - if use_gpu: - q1_network.cuda() + if self.use_gpu: + self._q1_network.cuda() if q2_network: q2_network.cuda() - actor_network.cuda() + self._actor_network.cuda() trainer = TD3Trainer( - actor_network=actor_network, - q1_network=q1_network, + actor_network=self._actor_network, + q1_network=self._q1_network, q2_network=q2_network, - use_gpu=use_gpu, + use_gpu=self.use_gpu, # pyre-fixme[16]: `TD3TrainerParameters` has no attribute `asdict`. # pyre-fixme[16]: `TD3TrainerParameters` has no attribute `asdict`. **self.trainer_param.asdict(), ) return trainer - def build_serving_module( - self, normalization_data_map: Dict[str, NormalizationData], trainer: TD3Trainer - ) -> torch.nn.Module: + def build_serving_module(self) -> torch.nn.Module: net_builder = self.actor_net_builder.value + assert self._actor_network is not None return net_builder.build_serving_module( - trainer.actor_network, - normalization_data_map[NormalizationKey.STATE], - normalization_data_map[NormalizationKey.ACTION], + self._actor_network, + self.state_normalization_data, + self.action_normalization_data, ) diff --git a/reagent/model_managers/actor_critic_base.py b/reagent/workflow/model_managers/actor_critic_base.py similarity index 61% rename from reagent/model_managers/actor_critic_base.py rename to reagent/workflow/model_managers/actor_critic_base.py index cdd8d5ad..2fd347e3 100644 --- a/reagent/model_managers/actor_critic_base.py +++ b/reagent/workflow/model_managers/actor_critic_base.py @@ -5,7 +5,7 @@ import logging from typing import Dict, List, Optional, Tuple import numpy as np -import reagent.core.types as rlt +import reagent.types as rlt import torch from reagent.core.dataclasses import dataclass, field from reagent.core.types import ( @@ -13,13 +13,13 @@ from reagent.core.types import ( PreprocessingOptions, ReaderOptions, RewardOptions, + RLTrainingOutput, + RLTrainingReport, TableSpec, ) -from reagent.data_fetchers.data_fetcher import DataFetcher from reagent.evaluation.evaluator import Evaluator, get_metrics_to_score from reagent.gym.policies.policy import Policy from reagent.gym.policies.predictor_policies import create_predictor_policy_from_model -from reagent.model_managers.model_manager import ModelManager from reagent.models.base import ModelBase from reagent.parameters import EvaluationParameters, NormalizationData, NormalizationKey from reagent.preprocessing.batch_preprocessor import ( @@ -29,7 +29,11 @@ from reagent.preprocessing.batch_preprocessor import ( ) from reagent.preprocessing.normalization import get_feature_config from reagent.preprocessing.types import InputColumn -from reagent.reporting.actor_critic_reporter import ActorCriticReporter +from reagent.workflow.data_fetcher import query_data +from reagent.workflow.identify_types_flow import identify_normalization_parameters +from reagent.workflow.model_managers.model_manager import ModelManager +from reagent.workflow.reporters.actor_critic_reporter import ActorCriticReporter +from reagent.workflow.utils import train_and_evaluate_generic logger = logging.getLogger(__name__) @@ -81,18 +85,40 @@ class ActorCriticBase(ModelManager): "Please set action whitelist features in action_float_features field of " "config instead" ) + self._state_preprocessing_options = self.state_preprocessing_options + self._action_preprocessing_options = self.action_preprocessing_options + + # To be filled by property metrics_to_score + self._metrics_to_score: Optional[List[str]] = None + + # To be filled by subclasses + self._actor_network: Optional[ModelBase] = None + self._q1_network: Optional[ModelBase] = None @property def should_generate_eval_dataset(self) -> bool: - return False # CPE not supported in A/C yet + return self.eval_parameters.calc_cpe_in_training - def create_policy(self, trainer) -> Policy: + def create_policy(self, serving: bool) -> Policy: """ Create online actor critic policy. """ - return ActorPolicyWrapper(trainer.actor_network) + + if serving: + return create_predictor_policy_from_model(self.build_serving_module()) + else: + return ActorPolicyWrapper(self._actor_network) @property - def metrics_to_score(self, reward_options: RewardOptions) -> List[str]: - return get_metrics_to_score(reward_options.metric_reward_values) + def metrics_to_score(self) -> List[str]: + assert self._reward_options is not None + if self._metrics_to_score is None: + # pyre-fixme[16]: `ActorCriticBase` has no attribute `_metrics_to_score`. + # pyre-fixme[16]: `ActorCriticBase` has no attribute `_metrics_to_score`. + self._metrics_to_score = get_metrics_to_score( + # pyre-fixme[16]: `Optional` has no attribute `metric_reward_values`. + # pyre-fixme[16]: `Optional` has no attribute `metric_reward_values`. + self._reward_options.metric_reward_values + ) + return self._metrics_to_score @property def state_feature_config(self) -> rlt.ModelFeatureConfig: @@ -104,11 +130,11 @@ class ActorCriticBase(ModelManager): return get_feature_config(self.action_float_features) def run_feature_identification( - self, data_fetcher: DataFetcher, input_table_spec: TableSpec + self, input_table_spec: TableSpec ) -> Dict[str, NormalizationData]: # Run state feature identification state_preprocessing_options = ( - self.state_preprocessing_options or PreprocessingOptions() + self._state_preprocessing_options or PreprocessingOptions() ) state_features = [ ffi.feature_id for ffi in self.state_feature_config.float_feature_infos @@ -118,13 +144,13 @@ class ActorCriticBase(ModelManager): whitelist_features=state_features ) - state_normalization_parameters = data_fetcher.identify_normalization_parameters( + state_normalization_parameters = identify_normalization_parameters( input_table_spec, InputColumn.STATE_FEATURES, state_preprocessing_options ) # Run action feature identification action_preprocessing_options = ( - self.action_preprocessing_options or PreprocessingOptions() + self._action_preprocessing_options or PreprocessingOptions() ) action_features = [ ffi.feature_id for ffi in self.action_feature_config.float_feature_infos @@ -142,7 +168,7 @@ class ActorCriticBase(ModelManager): whitelist_features=action_features, feature_overrides={fid: action_feature_override for fid in action_features}, ) - action_normalization_parameters = data_fetcher.identify_normalization_parameters( + action_normalization_parameters = identify_normalization_parameters( input_table_spec, InputColumn.ACTION, action_preprocessing_options ) @@ -161,13 +187,12 @@ class ActorCriticBase(ModelManager): def query_data( self, - data_fetcher: DataFetcher, input_table_spec: TableSpec, sample_range: Optional[Tuple[float, float]], reward_options: RewardOptions, ) -> Dataset: logger.info("Starting query") - return data_fetcher.query_data( + return query_data( input_table_spec=input_table_spec, discrete_action=False, include_possible_actions=False, @@ -175,31 +200,59 @@ class ActorCriticBase(ModelManager): sample_range=sample_range, ) - def get_reporter(self): - return ActorCriticReporter() - - def build_batch_preprocessor( - self, - reader_options: ReaderOptions, - use_gpu: bool, - batch_size: int, - normalization_data_map: Dict[str, NormalizationData], - reward_options: RewardOptions, - ) -> BatchPreprocessor: + def build_batch_preprocessor(self) -> BatchPreprocessor: state_preprocessor = Preprocessor( - normalization_data_map[ - NormalizationKey.STATE - ].dense_normalization_parameters, - use_gpu=use_gpu, + self.state_normalization_data.dense_normalization_parameters, + use_gpu=self.use_gpu, ) action_preprocessor = Preprocessor( - normalization_data_map[ - NormalizationKey.ACTION - ].dense_normalization_parameters, - use_gpu=use_gpu, + self.action_normalization_data.dense_normalization_parameters, + use_gpu=self.use_gpu, ) return PolicyNetworkBatchPreprocessor( state_preprocessor=state_preprocessor, action_preprocessor=action_preprocessor, - use_gpu=use_gpu, + use_gpu=self.use_gpu, ) + + # TODO: deprecate, once we deprecate internal page handlers + def train( + self, + train_dataset: Dataset, + eval_dataset: Optional[Dataset], + num_epochs: int, + reader_options: ReaderOptions, + ) -> RLTrainingOutput: + + reporter = ActorCriticReporter() + # pyre-fixme[16]: `RLTrainer` has no attribute `add_observer`. + self.trainer.add_observer(reporter) + + evaluator = Evaluator( + action_names=None, + gamma=self.rl_parameters.gamma, + model=self.trainer, + metrics_to_score=self.metrics_to_score, + ) + # pyre-fixme[16]: `Evaluator` has no attribute `add_observer`. + evaluator.add_observer(reporter) + + batch_preprocessor = self.build_batch_preprocessor() + train_and_evaluate_generic( + train_dataset=train_dataset, + eval_dataset=eval_dataset, + # pyre-fixme[6]: Expected `RLTrainer` for 3rd param but got `Trainer`. + trainer=self.trainer, + num_epochs=num_epochs, + use_gpu=self.use_gpu, + batch_preprocessor=batch_preprocessor, + reporter=reporter, + evaluator=evaluator, + reader_options=self.reader_options, + ) + # pyre-fixme[16]: `RLTrainingReport` has no attribute `make_union_instance`. + training_report = RLTrainingReport.make_union_instance( + reporter.generate_training_report() + ) + + return RLTrainingOutput(training_report=training_report) diff --git a/reagent/model_managers/discrete/__init__.py b/reagent/workflow/model_managers/discrete/__init__.py similarity index 100% rename from reagent/model_managers/discrete/__init__.py rename to reagent/workflow/model_managers/discrete/__init__.py diff --git a/reagent/model_managers/discrete/discrete_c51dqn.py b/reagent/workflow/model_managers/discrete/discrete_c51dqn.py similarity index 71% rename from reagent/model_managers/discrete/discrete_c51dqn.py rename to reagent/workflow/model_managers/discrete/discrete_c51dqn.py index e4d71059..7eac95e6 100644 --- a/reagent/model_managers/discrete/discrete_c51dqn.py +++ b/reagent/workflow/model_managers/discrete/discrete_c51dqn.py @@ -1,16 +1,15 @@ #!/usr/bin/env python3 import logging -from typing import Dict import torch from reagent.core.dataclasses import dataclass, field -from reagent.core.types import RewardOptions -from reagent.model_managers.discrete_dqn_base import DiscreteDQNBase from reagent.net_builder.categorical_dqn.categorical import Categorical from reagent.net_builder.unions import CategoricalDQNNetBuilder__Union -from reagent.parameters import NormalizationData, NormalizationKey, param_hash +from reagent.parameters import param_hash from reagent.training import C51Trainer, C51TrainerParameters +from reagent.training.loss_reporter import NoOpLossReporter +from reagent.workflow.model_managers.discrete_dqn_base import DiscreteDQNBase logger = logging.getLogger(__name__) @@ -38,24 +37,18 @@ class DiscreteC51DQN(DiscreteDQNBase): def __post_init_post_parse__(self): super().__post_init_post_parse__() - - assert ( - len(self.trainer_param.actions) > 1 - ), "DiscreteC51DQN needs at least 2 actions" + self.rl_parameters = self.trainer_param.rl + self.action_names = self.trainer_param.actions + assert len(self.action_names) > 1, "DiscreteC51DQN needs at least 2 actions" assert ( self.trainer_param.minibatch_size % 8 == 0 ), "The minibatch size must be divisible by 8 for performance reasons." - def build_trainer( - self, - use_gpu: bool, - normalization_data_map: Dict[str, NormalizationData], - reward_options: RewardOptions, - ) -> C51Trainer: + def build_trainer(self) -> C51Trainer: net_builder = self.net_builder.value q_network = net_builder.build_q_network( - state_normalization_data=normalization_data_map[NormalizationKey.STATE], - output_dim=len(self.trainer_param.actions), + state_normalization_data=self.state_normalization_data, + output_dim=len(self.action_names), # pyre-fixme[16]: `C51TrainerParameters` has no attribute `num_atoms`. # pyre-fixme[16]: `C51TrainerParameters` has no attribute `num_atoms`. num_atoms=self.trainer_param.num_atoms, @@ -67,31 +60,35 @@ class DiscreteC51DQN(DiscreteDQNBase): qmax=self.trainer_param.qmax, ) - if use_gpu: + if self.use_gpu: q_network = q_network.cuda() q_network_target = q_network.get_target_network() + # pyre-fixme[16]: `DiscreteC51DQN` has no attribute `_q_network`. + # pyre-fixme[16]: `DiscreteC51DQN` has no attribute `_q_network`. + self._q_network = q_network + return C51Trainer( q_network=q_network, q_network_target=q_network_target, - metrics_to_score=self.metrics_to_score(reward_options), - use_gpu=use_gpu, + metrics_to_score=self.metrics_to_score, + loss_reporter=NoOpLossReporter(), + use_gpu=self.use_gpu, # pyre-fixme[16]: `C51TrainerParameters` has no attribute `asdict`. # pyre-fixme[16]: `C51TrainerParameters` has no attribute `asdict`. **self.trainer_param.asdict(), ) - def build_serving_module( - self, normalization_data_map: Dict[str, NormalizationData], trainer: C51Trainer - ) -> torch.nn.Module: + def build_serving_module(self) -> torch.nn.Module: """ Returns a TorchScript predictor module """ + assert self._q_network is not None, "_q_network was not initialized" net_builder = self.net_builder.value return net_builder.build_serving_module( - trainer.q_network, - normalization_data_map[NormalizationKey.STATE], - action_names=self.trainer_param.actions, + self._q_network, + self.state_normalization_data, + action_names=self.action_names, state_feature_config=self.state_feature_config, ) diff --git a/reagent/model_managers/discrete/discrete_dqn.py b/reagent/workflow/model_managers/discrete/discrete_dqn.py similarity index 67% rename from reagent/model_managers/discrete/discrete_dqn.py rename to reagent/workflow/model_managers/discrete/discrete_dqn.py index e85c2a57..c17a3d79 100644 --- a/reagent/model_managers/discrete/discrete_dqn.py +++ b/reagent/workflow/model_managers/discrete/discrete_dqn.py @@ -1,18 +1,16 @@ #!/usr/bin/env python3 import logging -from typing import Dict import torch from reagent.core.dataclasses import dataclass, field -from reagent.core.types import RewardOptions -from reagent.model_managers.discrete_dqn_base import DiscreteDQNBase from reagent.net_builder.discrete_dqn.dueling import Dueling from reagent.net_builder.discrete_dqn.fully_connected import FullyConnected from reagent.net_builder.unions import DiscreteDQNNetBuilder__Union -from reagent.parameters import NormalizationData, NormalizationKey, param_hash +from reagent.parameters import param_hash from reagent.training import DQNTrainer, DQNTrainerParameters -from reagent.training.trainer import Trainer +from reagent.training.loss_reporter import NoOpLossReporter +from reagent.workflow.model_managers.discrete_dqn_base import DiscreteDQNBase logger = logging.getLogger(__name__) @@ -41,32 +39,26 @@ class DiscreteDQN(DiscreteDQNBase): def __post_init_post_parse__(self): super().__post_init_post_parse__() - + self.rl_parameters = self.trainer_param.rl + self.action_names = self.trainer_param.actions assert ( - len(self.trainer_param.actions) > 1 - ), f"DiscreteDQNModel needs at least 2 actions. Got {self.trainer_param.actions}." + len(self.action_names) > 1 + ), f"DiscreteDQNModel needs at least 2 actions. Got {self.action_names}." if self.trainer_param.minibatch_size % 8 != 0: logger.warn( f"minibatch size ({self.trainer_param.minibatch_size}) " "should be divisible by 8 for performance reasons!" ) - def build_trainer( - self, - use_gpu: bool, - normalization_data_map: Dict[str, NormalizationData], - reward_options: RewardOptions, - ) -> DQNTrainer: - state_normalization_data = normalization_data_map["state"] + def build_trainer(self) -> DQNTrainer: net_builder = self.net_builder.value q_network = net_builder.build_q_network( self.state_feature_config, - state_normalization_data, - # pyre-fixme[16]: `DQNTrainerParameters` has no attribute `actions`. - len(self.trainer_param.actions), + self.state_normalization_data, + len(self.action_names), ) - if use_gpu: + if self.use_gpu: q_network = q_network.cuda() q_network_target = q_network.get_target_network() @@ -74,55 +66,60 @@ class DiscreteDQN(DiscreteDQNBase): reward_network, q_network_cpe, q_network_cpe_target = None, None, None # pyre-fixme[16]: `DQNTrainerParameters` has no attribute `evaluation`. # pyre-fixme[16]: `DQNTrainerParameters` has no attribute `evaluation`. - if self.eval_parameters.calc_cpe_in_training: + if self.trainer_param.evaluation.calc_cpe_in_training: # Metrics + reward - num_output_nodes = (len(self.metrics_to_score(reward_options)) + 1) * len( + num_output_nodes = (len(self.metrics_to_score) + 1) * len( + # pyre-fixme[16]: `DQNTrainerParameters` has no attribute `actions`. # pyre-fixme[16]: `DQNTrainerParameters` has no attribute `actions`. self.trainer_param.actions ) cpe_net_builder = self.cpe_net_builder.value reward_network = cpe_net_builder.build_q_network( - self.state_feature_config, state_normalization_data, num_output_nodes + self.state_feature_config, + self.state_normalization_data, + num_output_nodes, ) q_network_cpe = cpe_net_builder.build_q_network( - self.state_feature_config, state_normalization_data, num_output_nodes + self.state_feature_config, + self.state_normalization_data, + num_output_nodes, ) - if use_gpu: + if self.use_gpu: reward_network.cuda() q_network_cpe.cuda() q_network_cpe_target = q_network_cpe.get_target_network() + # pyre-fixme[16]: `DiscreteDQN` has no attribute `_q_network`. + # pyre-fixme[16]: `DiscreteDQN` has no attribute `_q_network`. + self._q_network = q_network trainer = DQNTrainer( q_network=q_network, q_network_target=q_network_target, reward_network=reward_network, q_network_cpe=q_network_cpe, q_network_cpe_target=q_network_cpe_target, - metrics_to_score=self.metrics_to_score(reward_options), - use_gpu=use_gpu, - evaluation=self.eval_parameters, + metrics_to_score=self.metrics_to_score, + loss_reporter=NoOpLossReporter(), + use_gpu=self.use_gpu, # pyre-fixme[16]: `DQNTrainerParameters` has no attribute `asdict`. # pyre-fixme[16]: `DQNTrainerParameters` has no attribute `asdict`. **self.trainer_param.asdict(), ) return trainer - def build_serving_module( - self, normalization_data_map: Dict[str, NormalizationData], trainer: DQNTrainer - ) -> torch.nn.Module: + def build_serving_module(self) -> torch.nn.Module: """ Returns a TorchScript predictor module """ - assert trainer.q_network is not None, "_q_network was not initialized" + assert self._q_network is not None, "_q_network was not initialized" net_builder = self.net_builder.value return net_builder.build_serving_module( - trainer.q_network, - normalization_data_map["state"], - # pyre-fixme[16]: `DQNTrainerParameters` has no attribute `actions`. - action_names=self.trainer_param.actions, + self._q_network, + self.state_normalization_data, + action_names=self.action_names, state_feature_config=self.state_feature_config, ) diff --git a/reagent/model_managers/discrete/discrete_qrdqn.py b/reagent/workflow/model_managers/discrete/discrete_qrdqn.py similarity index 71% rename from reagent/model_managers/discrete/discrete_qrdqn.py rename to reagent/workflow/model_managers/discrete/discrete_qrdqn.py index b02c7ace..e8747656 100644 --- a/reagent/model_managers/discrete/discrete_qrdqn.py +++ b/reagent/workflow/model_managers/discrete/discrete_qrdqn.py @@ -1,21 +1,19 @@ #!/usr/bin/env python3 import logging -from typing import Dict import torch from reagent.core.dataclasses import dataclass, field -from reagent.core.types import RewardOptions -from reagent.gym.policies.policy import Policy -from reagent.model_managers.discrete_dqn_base import DiscreteDQNBase from reagent.net_builder.discrete_dqn.fully_connected import FullyConnected from reagent.net_builder.quantile_dqn.dueling_quantile import DuelingQuantile from reagent.net_builder.unions import ( DiscreteDQNNetBuilder__Union, QRDQNNetBuilder__Union, ) -from reagent.parameters import NormalizationData, NormalizationKey, param_hash +from reagent.parameters import param_hash from reagent.training import QRDQNTrainer, QRDQNTrainerParameters +from reagent.training.loss_reporter import NoOpLossReporter +from reagent.workflow.model_managers.discrete_dqn_base import DiscreteDQNBase logger = logging.getLogger(__name__) @@ -43,30 +41,24 @@ class DiscreteQRDQN(DiscreteDQNBase): def __post_init_post_parse__(self): super().__post_init_post_parse__() - - assert ( - len(self.trainer_param.actions) > 1 - ), "DiscreteQRDQNModel needs at least 2 actions" + self.rl_parameters = self.trainer_param.rl + self.action_names = self.trainer_param.actions + assert len(self.action_names) > 1, "DiscreteQRDQNModel needs at least 2 actions" assert ( self.trainer_param.minibatch_size % 8 == 0 ), "The minibatch size must be divisible by 8 for performance reasons." - def build_trainer( - self, - use_gpu: bool, - normalization_data_map: Dict[str, NormalizationData], - reward_options: RewardOptions, - ) -> QRDQNTrainer: + def build_trainer(self) -> QRDQNTrainer: net_builder = self.net_builder.value q_network = net_builder.build_q_network( - normalization_data_map[NormalizationKey.STATE], - len(self.trainer_param.actions), + self.state_normalization_data, + len(self.action_names), # pyre-fixme[16]: `QRDQNTrainerParameters` has no attribute `num_atoms`. # pyre-fixme[16]: `QRDQNTrainerParameters` has no attribute `num_atoms`. num_atoms=self.trainer_param.num_atoms, ) - if use_gpu: + if self.use_gpu: q_network = q_network.cuda() q_network_target = q_network.get_target_network() @@ -74,9 +66,9 @@ class DiscreteQRDQN(DiscreteDQNBase): reward_network, q_network_cpe, q_network_cpe_target = None, None, None # pyre-fixme[16]: `QRDQNTrainerParameters` has no attribute `evaluation`. # pyre-fixme[16]: `QRDQNTrainerParameters` has no attribute `evaluation`. - if self.eval_parameters.calc_cpe_in_training: + if self.trainer_param.evaluation.calc_cpe_in_training: # Metrics + reward - num_output_nodes = (len(self.metrics_to_score(reward_options)) + 1) * len( + num_output_nodes = (len(self.metrics_to_score) + 1) * len( # pyre-fixme[16]: `QRDQNTrainerParameters` has no attribute `actions`. # pyre-fixme[16]: `QRDQNTrainerParameters` has no attribute `actions`. self.trainer_param.actions @@ -85,48 +77,47 @@ class DiscreteQRDQN(DiscreteDQNBase): cpe_net_builder = self.cpe_net_builder.value reward_network = cpe_net_builder.build_q_network( self.state_feature_config, - normalization_data_map[NormalizationKey.STATE], + self.state_normalization_data, num_output_nodes, ) q_network_cpe = cpe_net_builder.build_q_network( self.state_feature_config, - normalization_data_map[NormalizationKey.STATE], + self.state_normalization_data, num_output_nodes, ) - if use_gpu: + if self.use_gpu: reward_network.cuda() q_network_cpe.cuda() q_network_cpe_target = q_network_cpe.get_target_network() + # pyre-fixme[16]: `DiscreteQRDQN` has no attribute `_q_network`. + self._q_network = q_network trainer = QRDQNTrainer( q_network=q_network, q_network_target=q_network_target, reward_network=reward_network, - evaluation=self.eval_parameters, q_network_cpe=q_network_cpe, q_network_cpe_target=q_network_cpe_target, - metrics_to_score=self.metrics_to_score(reward_options), - use_gpu=use_gpu, + metrics_to_score=self.metrics_to_score, + loss_reporter=NoOpLossReporter(), + use_gpu=self.use_gpu, # pyre-fixme[16]: `QRDQNTrainerParameters` has no attribute `asdict`. # pyre-fixme[16]: `QRDQNTrainerParameters` has no attribute `asdict`. **self.trainer_param.asdict(), ) return trainer - def build_serving_module( - self, - normalization_data_map: Dict[str, NormalizationData], - trainer: QRDQNTrainer, - ) -> torch.nn.Module: + def build_serving_module(self) -> torch.nn.Module: """ Returns a TorchScript predictor module """ + assert self._q_network is not None, "_q_network was not initialized" net_builder = self.net_builder.value return net_builder.build_serving_module( - trainer.q_network, - normalization_data_map[NormalizationKey.STATE], - action_names=self.trainer_param.actions, + self._q_network, + self.state_normalization_data, + action_names=self.action_names, state_feature_config=self.state_feature_config, ) diff --git a/reagent/workflow/model_managers/discrete_dqn_base.py b/reagent/workflow/model_managers/discrete_dqn_base.py new file mode 100644 index 00000000..b540f00e --- /dev/null +++ b/reagent/workflow/model_managers/discrete_dqn_base.py @@ -0,0 +1,195 @@ +#!/usr/bin/env python3 + +import logging +from typing import Dict, List, Optional, Tuple + +from reagent import types as rlt +from reagent.core.dataclasses import dataclass, field +from reagent.core.types import ( + Dataset, + ModelFeatureConfigProvider__Union, + PreprocessingOptions, + ReaderOptions, + RewardOptions, + RLTrainingOutput, + RLTrainingReport, + TableSpec, +) +from reagent.evaluation.evaluator import Evaluator, get_metrics_to_score +from reagent.gym.policies.policy import Policy +from reagent.gym.policies.predictor_policies import create_predictor_policy_from_model +from reagent.gym.policies.samplers.discrete_sampler import SoftmaxActionSampler +from reagent.gym.policies.scorers.discrete_scorer import discrete_dqn_scorer +from reagent.models.base import ModelBase +from reagent.models.model_feature_config_provider import RawModelFeatureConfigProvider +from reagent.parameters import EvaluationParameters, NormalizationData, NormalizationKey +from reagent.preprocessing.batch_preprocessor import ( + BatchPreprocessor, + DiscreteDqnBatchPreprocessor, +) +from reagent.preprocessing.preprocessor import Preprocessor +from reagent.preprocessing.types import InputColumn +from reagent.workflow.data_fetcher import query_data +from reagent.workflow.identify_types_flow import identify_normalization_parameters +from reagent.workflow.model_managers.model_manager import ModelManager +from reagent.workflow.reporters.discrete_dqn_reporter import DiscreteDQNReporter +from reagent.workflow.utils import train_and_evaluate_generic + + +logger = logging.getLogger(__name__) + + +@dataclass +class DiscreteDQNBase(ModelManager): + target_action_distribution: Optional[List[float]] = None + state_feature_config_provider: ModelFeatureConfigProvider__Union = field( + # pyre-fixme[28]: Unexpected keyword argument `raw`. + # pyre-fixme[28]: Unexpected keyword argument `raw`. + default_factory=lambda: ModelFeatureConfigProvider__Union( + raw=RawModelFeatureConfigProvider(float_feature_infos=[]) + ) + ) + eval_parameters: EvaluationParameters = field(default_factory=EvaluationParameters) + preprocessing_options: Optional[PreprocessingOptions] = None + reader_options: Optional[ReaderOptions] = None + + def __post_init_post_parse__(self): + super().__init__() + self._metrics_to_score = None + self._q_network: Optional[ModelBase] = None + + def create_policy(self, serving: bool) -> Policy: + """ Create an online DiscreteDQN Policy from env. """ + if serving: + return create_predictor_policy_from_model(self.build_serving_module()) + else: + sampler = SoftmaxActionSampler(temperature=self.rl_parameters.temperature) + # pyre-fixme[16]: `RLTrainer` has no attribute `q_network`. + scorer = discrete_dqn_scorer(self.trainer.q_network) + return Policy(scorer=scorer, sampler=sampler) + + @property + def state_feature_config(self) -> rlt.ModelFeatureConfig: + return self.state_feature_config_provider.value.get_model_feature_config() + + @property + def metrics_to_score(self) -> List[str]: + assert self._reward_options is not None + if self._metrics_to_score is None: + # pyre-fixme[16]: `DiscreteDQNBase` has no attribute `_metrics_to_score`. + # pyre-fixme[16]: `DiscreteDQNBase` has no attribute `_metrics_to_score`. + self._metrics_to_score = get_metrics_to_score( + # pyre-fixme[16]: `Optional` has no attribute `metric_reward_values`. + # pyre-fixme[16]: `Optional` has no attribute `metric_reward_values`. + self._reward_options.metric_reward_values + ) + return self._metrics_to_score + + @property + def should_generate_eval_dataset(self) -> bool: + return self.eval_parameters.calc_cpe_in_training + + @property + def required_normalization_keys(self) -> List[str]: + return [NormalizationKey.STATE] + + def run_feature_identification( + self, input_table_spec: TableSpec + ) -> Dict[str, NormalizationData]: + preprocessing_options = self.preprocessing_options or PreprocessingOptions() + logger.info("Overriding whitelist_features") + state_features = [ + ffi.feature_id for ffi in self.state_feature_config.float_feature_infos + ] + preprocessing_options = preprocessing_options._replace( + whitelist_features=state_features + ) + return { + NormalizationKey.STATE: NormalizationData( + dense_normalization_parameters=identify_normalization_parameters( + input_table_spec, InputColumn.STATE_FEATURES, preprocessing_options + ) + ) + } + + def query_data( + self, + input_table_spec: TableSpec, + sample_range: Optional[Tuple[float, float]], + reward_options: RewardOptions, + ) -> Dataset: + return query_data( + input_table_spec=input_table_spec, + discrete_action=True, + actions=self.action_names, + include_possible_actions=True, + sample_range=sample_range, + custom_reward_expression=reward_options.custom_reward_expression, + multi_steps=self.multi_steps, + gamma=self.rl_parameters.gamma, + ) + + @property + def multi_steps(self) -> Optional[int]: + return self.rl_parameters.multi_steps + + def build_batch_preprocessor(self) -> BatchPreprocessor: + state_preprocessor = Preprocessor( + self.state_normalization_data.dense_normalization_parameters, + use_gpu=self.use_gpu, + ) + return DiscreteDqnBatchPreprocessor( + num_actions=len(self.action_names), + state_preprocessor=state_preprocessor, + use_gpu=self.use_gpu, + ) + + def train( + self, + train_dataset: Dataset, + eval_dataset: Optional[Dataset], + num_epochs: int, + reader_options: ReaderOptions, + ) -> RLTrainingOutput: + """ + Train the model + + Returns partially filled RLTrainingOutput. + The field that should not be filled are: + - output_path + """ + reporter = DiscreteDQNReporter( + self.trainer_param.actions, + target_action_distribution=self.target_action_distribution, + ) + # pyre-fixme[16]: `RLTrainer` has no attribute `add_observer`. + self.trainer.add_observer(reporter) + + evaluator = Evaluator( + self.action_names, + self.rl_parameters.gamma, + self.trainer, + metrics_to_score=self.metrics_to_score, + ) + # pyre-fixme[16]: `Evaluator` has no attribute `add_observer`. + evaluator.add_observer(reporter) + + batch_preprocessor = self.build_batch_preprocessor() + train_and_evaluate_generic( + train_dataset, + eval_dataset, + # pyre-fixme[6]: Expected `RLTrainer` for 3rd param but got `Trainer`. + # pyre-fixme[6]: Expected `RLTrainer` for 3rd param but got `Trainer`. + self.trainer, + num_epochs, + self.use_gpu, + batch_preprocessor, + reporter, + evaluator, + reader_options=self.reader_options, + ) + # pyre-fixme[16]: `RLTrainingReport` has no attribute `make_union_instance`. + training_report = RLTrainingReport.make_union_instance( + reporter.generate_training_report() + ) + return RLTrainingOutput(training_report=training_report) diff --git a/reagent/model_managers/model_based/__init__.py b/reagent/workflow/model_managers/model_based/__init__.py similarity index 100% rename from reagent/model_managers/model_based/__init__.py rename to reagent/workflow/model_managers/model_based/__init__.py diff --git a/reagent/model_managers/model_based/cross_entropy_method.py b/reagent/workflow/model_managers/model_based/cross_entropy_method.py similarity index 76% rename from reagent/model_managers/model_based/cross_entropy_method.py rename to reagent/workflow/model_managers/model_based/cross_entropy_method.py index dd9f1669..3efee16c 100644 --- a/reagent/model_managers/model_based/cross_entropy_method.py +++ b/reagent/workflow/model_managers/model_based/cross_entropy_method.py @@ -1,26 +1,20 @@ #!/usr/bin/env python3 import logging -from typing import Dict, Optional +from typing import Optional import numpy as np -import reagent.core.types as rlt +import reagent.types as rlt import torch from reagent.core.dataclasses import dataclass, field -from reagent.core.types import RewardOptions from reagent.gym.policies.policy import Policy -from reagent.model_managers.model_based.world_model import WorldModel -from reagent.model_managers.world_model_base import WorldModelBase from reagent.models.cem_planner import CEMPlannerNetwork -from reagent.parameters import ( - CEMTrainerParameters, - NormalizationData, - NormalizationKey, - param_hash, -) +from reagent.parameters import CEMTrainerParameters, param_hash from reagent.preprocessing.identify_types import CONTINUOUS_ACTION from reagent.preprocessing.normalization import get_num_output_features from reagent.training.cem_trainer import CEMTrainer +from reagent.workflow.model_managers.model_based.world_model import WorldModel +from reagent.workflow.model_managers.world_model_base import WorldModelBase logger = logging.getLogger(__name__) @@ -60,27 +54,31 @@ class CrossEntropyMethod(WorldModelBase): def create_policy(self, serving: bool = False) -> Policy: return CEMPolicy(self.cem_planner_network, self.discrete_action) - def build_trainer( - self, - use_gpu: bool, - normalization_data_map: Dict[str, NormalizationData], - reward_options: RewardOptions, - ) -> CEMTrainer: + def build_trainer(self) -> CEMTrainer: world_model_manager: WorldModel = WorldModel( trainer_param=self.trainer_param.mdnrnn ) + world_model_manager.initialize_trainer( + self.use_gpu, + self.reward_options, + # pyre-fixme[6]: Expected `Dict[str, + # reagent.parameters.NormalizationData]` for 3rd param but got + # `Optional[typing.Dict[str, reagent.parameters.NormalizationData]]`. + # pyre-fixme[6]: Expected `Dict[str, + # reagent.parameters.NormalizationData]` for 3rd param but got + # `Optional[typing.Dict[str, reagent.parameters.NormalizationData]]`. + self._normalization_data_map, + ) world_model_trainers = [ - world_model_manager.build_trainer( - use_gpu, normalization_data_map, reward_options - ) + world_model_manager.build_trainer() for _ in range(self.trainer_param.num_world_models) ] world_model_nets = [trainer.memory_network for trainer in world_model_trainers] terminal_effective = self.trainer_param.mdnrnn.not_terminal_loss_weight > 0 - action_normalization_parameters = normalization_data_map[ - NormalizationKey.ACTION - ].dense_normalization_parameters + action_normalization_parameters = ( + self.action_normalization_data.dense_normalization_parameters + ) sorted_action_norm_vals = list(action_normalization_parameters.values()) discrete_action = sorted_action_norm_vals[0].feature_type != CONTINUOUS_ACTION action_upper_bounds, action_lower_bounds = None, None @@ -100,14 +98,10 @@ class CrossEntropyMethod(WorldModelBase): num_elites=self.trainer_param.num_elites, plan_horizon_length=self.trainer_param.plan_horizon_length, state_dim=get_num_output_features( - normalization_data_map[ - NormalizationKey.STATE - ].dense_normalization_parameters + self.state_normalization_data.dense_normalization_parameters ), action_dim=get_num_output_features( - normalization_data_map[ - NormalizationKey.ACTION - ].dense_normalization_parameters + self.action_normalization_data.dense_normalization_parameters ), discrete_action=discrete_action, terminal_effective=terminal_effective, @@ -131,12 +125,10 @@ class CrossEntropyMethod(WorldModelBase): cem_planner_network=cem_planner_network, world_model_trainers=world_model_trainers, parameters=self.trainer_param, - use_gpu=use_gpu, + use_gpu=self.use_gpu, ) - def build_serving_module( - self, normalization_data_map: Dict[str, NormalizationData], trainer - ) -> torch.nn.Module: + def build_serving_module(self) -> torch.nn.Module: """ Returns a TorchScript predictor module """ diff --git a/reagent/model_managers/model_based/seq2reward_model.py b/reagent/workflow/model_managers/model_based/seq2reward_model.py similarity index 70% rename from reagent/model_managers/model_based/seq2reward_model.py rename to reagent/workflow/model_managers/model_based/seq2reward_model.py index 7eebbe32..b48e8a96 100644 --- a/reagent/model_managers/model_based/seq2reward_model.py +++ b/reagent/workflow/model_managers/model_based/seq2reward_model.py @@ -1,22 +1,15 @@ #!/usr/bin/env python3 import logging -from typing import Dict import torch from reagent.core.dataclasses import dataclass, field -from reagent.core.types import RewardOptions -from reagent.model_managers.world_model_base import WorldModelBase from reagent.net_builder.unions import ValueNetBuilder__Union from reagent.net_builder.value.fully_connected import FullyConnected from reagent.net_builder.value.seq2reward_rnn import Seq2RewardNetBuilder -from reagent.parameters import ( - NormalizationData, - NormalizationKey, - Seq2RewardTrainerParameters, - param_hash, -) +from reagent.parameters import Seq2RewardTrainerParameters, param_hash from reagent.training.world_model.seq2reward_trainer import Seq2RewardTrainer +from reagent.workflow.model_managers.world_model_base import WorldModelBase logger = logging.getLogger(__name__) @@ -43,26 +36,19 @@ class Seq2RewardModel(WorldModelBase): default_factory=Seq2RewardTrainerParameters ) - def build_trainer( - self, - use_gpu: bool, - normalization_data_map: Dict[str, NormalizationData], - reward_options: RewardOptions, - ) -> Seq2RewardTrainer: + def build_trainer(self) -> Seq2RewardTrainer: seq2reward_network = self.net_builder.value.build_value_network( - normalization_data_map[NormalizationKey.STATE] + self.state_normalization_data ) - if use_gpu: + if self.use_gpu: seq2reward_network = seq2reward_network.cuda() return Seq2RewardTrainer( seq2reward_network=seq2reward_network, params=self.trainer_param ) - def build_serving_module( - self, normalization_data_map: Dict[str, NormalizationData], trainer - ) -> torch.nn.Module: + def build_serving_module(self) -> torch.nn.Module: """ Returns a TorchScript predictor module """ diff --git a/reagent/model_managers/model_based/world_model.py b/reagent/workflow/model_managers/model_based/world_model.py similarity index 62% rename from reagent/model_managers/model_based/world_model.py rename to reagent/workflow/model_managers/model_based/world_model.py index e644ea5e..56b47256 100644 --- a/reagent/model_managers/model_based/world_model.py +++ b/reagent/workflow/model_managers/model_based/world_model.py @@ -1,21 +1,14 @@ #!/usr/bin/env python3 import logging -from typing import Dict import torch from reagent.core.dataclasses import dataclass, field -from reagent.core.types import RewardOptions -from reagent.model_managers.world_model_base import WorldModelBase from reagent.models.world_model import MemoryNetwork -from reagent.parameters import ( - MDNRNNTrainerParameters, - NormalizationData, - NormalizationKey, - param_hash, -) +from reagent.parameters import MDNRNNTrainerParameters, param_hash from reagent.preprocessing.normalization import get_num_output_features from reagent.training.world_model.mdnrnn_trainer import MDNRNNTrainer +from reagent.workflow.model_managers.world_model_base import WorldModelBase logger = logging.getLogger(__name__) @@ -32,31 +25,22 @@ class WorldModel(WorldModelBase): def __post_init_post_parse__(self): super().__post_init_post_parse__() - def build_trainer( - self, - use_gpu: bool, - normalization_data_map: Dict[str, NormalizationData], - reward_options: RewardOptions, - ) -> MDNRNNTrainer: + def build_trainer(self) -> MDNRNNTrainer: memory_network = MemoryNetwork( state_dim=get_num_output_features( - normalization_data_map[ - NormalizationKey.STATE - ].dense_normalization_parameters + self.state_normalization_data.dense_normalization_parameters ), action_dim=self.trainer_param.action_dim, num_hiddens=self.trainer_param.hidden_size, num_hidden_layers=self.trainer_param.num_hidden_layers, num_gaussians=self.trainer_param.num_gaussians, ) - if use_gpu: + if self.use_gpu: memory_network = memory_network.cuda() return MDNRNNTrainer(memory_network=memory_network, params=self.trainer_param) - def build_serving_module( - self, normalization_data_map: Dict[str, NormalizationData], trainer - ) -> torch.nn.Module: + def build_serving_module(self) -> torch.nn.Module: """ Returns a TorchScript predictor module """ diff --git a/reagent/workflow/model_managers/model_manager.py b/reagent/workflow/model_managers/model_manager.py new file mode 100644 index 00000000..a697ea07 --- /dev/null +++ b/reagent/workflow/model_managers/model_manager.py @@ -0,0 +1,256 @@ +#!/usr/bin/env python3 + +import abc +import dataclasses +import logging +import time +from typing import Dict, List, Optional, Tuple + +import torch +from reagent.core.registry_meta import RegistryMeta +from reagent.core.types import ( + Dataset, + OssReaderOptions, + ReaderOptions, + ResourceOptions, + RewardOptions, + RLTrainingOutput, + TableSpec, +) +from reagent.parameters import NormalizationData +from reagent.tensorboardX import summary_writer_context +from reagent.training.trainer import Trainer +from torch.utils.tensorboard import SummaryWriter + + +logger = logging.getLogger(__name__) + + +class ModelManager(metaclass=RegistryMeta): + """ + ModelManager manages how to train models. + + Each type of models can have their own config type, implemented as + `config_type()` class method. `__init__()` of the concrete class must take + this type. + + ModelManager abstracts over common phases of training, i.e.,: + 1. `run_feature_identification()` defines how to derive feature preprocessing + parameters from given data. + 2. `query_data()` massages the input table into the format expected by the trainer + 3. `initialize_trainer()` creates the trainer + 4. `train()` + 5. `build_serving_module()` builds the module for prediction + 6. `save_tainer()` saves the trainer for warmstarting + """ + + def __init__(self): + super().__init__() + # initialization is delayed to `initialize_trainer()` + self._normalization_data_map: Optional[Dict[str, NormalizationData]] = None + self._reward_options: Optional[RewardOptions] = None + self._trainer: Optional[Trainer] = None + self._use_gpu: Optional[bool] = None + + @property + def use_gpu(self) -> bool: + assert ( + self._use_gpu is not None + ), "Call initialize_trainer() to set the value first" + # pyre-fixme[7]: Expected `bool` but got `Optional[bool]`. + # pyre-fixme[7]: Expected `bool` but got `Optional[bool]`. + return self._use_gpu + + @property + def reward_options(self) -> RewardOptions: + assert self._reward_options is not None + # pyre-fixme[7]: Expected `RewardOptions` but got `Optional[RewardOptions]`. + # pyre-fixme[7]: Expected `RewardOptions` but got `Optional[RewardOptions]`. + return self._reward_options + + @reward_options.setter + def reward_options(self, reward_options: RewardOptions): + assert self._reward_options is None + self._reward_options = reward_options + + @abc.abstractmethod + def run_feature_identification( + self, input_table_spec: TableSpec + ) -> Dict[str, NormalizationData]: + """ + Derive preprocessing parameters from data. The keys of the dict should + match the keys from `required_normalization_keys()` + """ + pass + + @property + @abc.abstractmethod + def required_normalization_keys(self) -> List[str]: + """ Get the normalization keys required for current instance """ + pass + + def __getattr__(self, attr): + """ Get X_normalization_data by attribute """ + normalization_data_suffix = "_normalization_data" + if attr.endswith(normalization_data_suffix): + assert self._normalization_data_map is not None, ( + f"Trying to access {attr} but normalization_data_map " + "has not been set via `initialize_trainer`." + ) + normalization_key = attr[: -len(normalization_data_suffix)] + normalization_data = self._normalization_data_map.get( + normalization_key, None + ) + if normalization_data is None: + raise AttributeError( + f"normalization key `{normalization_key}` is unavailable. " + f"Available keys are: {self._normalization_data_map.keys()}." + ) + return normalization_data + + raise AttributeError( + f"attr {attr} not available {type(self)} (subclass of ModelManager)." + ) + + @property + @abc.abstractmethod + def should_generate_eval_dataset(self) -> bool: + pass + + @abc.abstractmethod + def query_data( + self, + input_table_spec: TableSpec, + sample_range: Optional[Tuple[float, float]], + reward_options: RewardOptions, + ) -> Dataset: + """ + Massage input table into the format expected by the trainer + """ + pass + + @property + def trainer(self) -> Trainer: + assert self._trainer is not None, "Call initialize_trainer() first" + # pyre-fixme[7]: Expected `Trainer` but got `Optional[Trainer]`. + # pyre-fixme[7]: Expected `Trainer` but got `Optional[Trainer]`. + return self._trainer + + def initialize_trainer( + self, + use_gpu: bool, + reward_options: RewardOptions, + normalization_data_map: Dict[str, NormalizationData], + warmstart_path: Optional[str] = None, + ) -> Trainer: + """ + Initialize the trainer. Subclass should not override this. Instead, + subclass should implement `required_normalization_keys()` and + `build_trainer()`. + """ + assert self._trainer is None, "Trainer was intialized" + self._use_gpu = use_gpu + self.reward_options = reward_options + # validate that we have all the required keys + for normalization_key in self.required_normalization_keys: + normalization_data = normalization_data_map.get(normalization_key, None) + assert normalization_data is not None, ( + f"NormalizationData for {normalization_key} " + "is required but not provided." + ) + # NOTE: Don't need this check in the future, for non-dense parameters + assert normalization_data.dense_normalization_parameters is not None, ( + f"Dense normalization parameters for " + f"{normalization_key} is not provided." + ) + assert ( + self._normalization_data_map is None + ), "Cannot reset self._normalization_data_map" + self._normalization_data_map = normalization_data_map + self._trainer = self.build_trainer() + if warmstart_path is not None: + trainer_state = torch.load(warmstart_path) + # pyre-fixme[16]: `Optional` has no attribute `load_state_dict`. + # pyre-fixme[16]: `Optional` has no attribute `load_state_dict`. + self._trainer.load_state_dict(trainer_state) + # pyre-fixme[7]: Expected `Trainer` but got `Optional[Trainer]`. + # pyre-fixme[7]: Expected `Trainer` but got `Optional[Trainer]`. + return self._trainer + + @abc.abstractmethod + def build_trainer(self) -> Trainer: + """ + Implement this to build the trainer, given the config + """ + pass + + def train_workflow( + self, + train_dataset: Dataset, + eval_dataset: Optional[Dataset], + normalization_data_map: Dict[str, NormalizationData], + num_epochs: int, + use_gpu: bool, + parent_workflow_id: int, + child_workflow_id: int, + reward_options: Optional[RewardOptions] = None, + reader_options: Optional[ReaderOptions] = None, + resource_options: Optional[ResourceOptions] = None, + warmstart_path: Optional[str] = None, + ) -> RLTrainingOutput: + writer = SummaryWriter() + logger.info("TensorBoard logging location is: {}".format(writer.log_dir)) + + warmstart_input_path = warmstart_path or None + self.initialize_trainer( + use_gpu=use_gpu, + # pyre-fixme[6]: Expected `RewardOptions` for 2nd param but got + # `Optional[RewardOptions]`. + # pyre-fixme[6]: Expected `RewardOptions` for 2nd param but got + # `Optional[RewardOptions]`. + reward_options=reward_options, + normalization_data_map=normalization_data_map, + warmstart_path=warmstart_input_path, + ) + + if not reader_options: + reader_options = OssReaderOptions() + + with summary_writer_context(writer): + train_output = self.train( + train_dataset, eval_dataset, num_epochs, reader_options + ) + + # TODO: make this a parameter + torchscript_output_path = f"model_{round(time.time())}.torchscript" + serving_module = self.build_serving_module() + torch.jit.save(serving_module, torchscript_output_path) + logger.info(f"Saved torchscript model to {torchscript_output_path}") + return dataclasses.replace(train_output, output_path=torchscript_output_path) + + @abc.abstractmethod + def train( + self, + train_dataset: Dataset, + eval_dataset: Optional[Dataset], + num_epochs: int, + reader_options: ReaderOptions, + ) -> RLTrainingOutput: + """ + Train the model + """ + pass + + @abc.abstractmethod + def build_serving_module(self) -> torch.nn.Module: + """ + Returns TorchScript module to be used in predictor + """ + pass + + def save_trainer(self, output_path: str) -> None: + """ + Save the trainer for warmstarting/checkpointing. + """ + trainer_state = self.trainer.state_dict() + torch.save(trainer_state, output_path) diff --git a/reagent/model_managers/parametric/__init__.py b/reagent/workflow/model_managers/parametric/__init__.py similarity index 100% rename from reagent/model_managers/parametric/__init__.py rename to reagent/workflow/model_managers/parametric/__init__.py diff --git a/reagent/workflow/model_managers/parametric/parametric_dqn.py b/reagent/workflow/model_managers/parametric/parametric_dqn.py new file mode 100644 index 00000000..59eefcc3 --- /dev/null +++ b/reagent/workflow/model_managers/parametric/parametric_dqn.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python3 + +import logging + +import torch +from reagent.core.dataclasses import dataclass, field +from reagent.net_builder.parametric_dqn.fully_connected import FullyConnected +from reagent.net_builder.unions import ParametricDQNNetBuilder__Union +from reagent.parameters import param_hash +from reagent.training import ParametricDQNTrainer, ParametricDQNTrainerParameters +from reagent.workflow.model_managers.parametric_dqn_base import ParametricDQNBase + + +logger = logging.getLogger(__name__) + + +@dataclass +class ParametricDQN(ParametricDQNBase): + __hash__ = param_hash + + trainer_param: ParametricDQNTrainerParameters = field( + default_factory=ParametricDQNTrainerParameters + ) + net_builder: ParametricDQNNetBuilder__Union = field( + # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`. + # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`. + default_factory=lambda: ParametricDQNNetBuilder__Union( + FullyConnected=FullyConnected() + ) + ) + + def __post_init_post_parse__(self): + super().__post_init_post_parse__() + self.rl_parameters = self.trainer_param.rl + + def build_trainer(self) -> ParametricDQNTrainer: + net_builder = self.net_builder.value + # pyre-fixme[16]: `ParametricDQN` has no attribute `_q_network`. + # pyre-fixme[16]: `ParametricDQN` has no attribute `_q_network`. + self._q_network = net_builder.build_q_network( + self.state_normalization_data, self.action_normalization_data + ) + # Metrics + reward + reward_output_dim = len(self.metrics_to_score) + 1 + reward_network = net_builder.build_q_network( + self.state_normalization_data, + self.action_normalization_data, + output_dim=reward_output_dim, + ) + + if self.use_gpu: + self._q_network = self._q_network.cuda() + reward_network = reward_network.cuda() + + q_network_target = self._q_network.get_target_network() + return ParametricDQNTrainer( + q_network=self._q_network, + q_network_target=q_network_target, + reward_network=reward_network, + use_gpu=self.use_gpu, + # pyre-fixme[16]: `ParametricDQNTrainerParameters` has no attribute + # `asdict`. + # pyre-fixme[16]: `ParametricDQNTrainerParameters` has no attribute + # `asdict`. + **self.trainer_param.asdict(), + ) + + def build_serving_module(self) -> torch.nn.Module: + net_builder = self.net_builder.value + assert self._q_network is not None + return net_builder.build_serving_module( + self._q_network, + self.state_normalization_data, + self.action_normalization_data, + ) diff --git a/reagent/model_managers/parametric_dqn_base.py b/reagent/workflow/model_managers/parametric_dqn_base.py similarity index 66% rename from reagent/model_managers/parametric_dqn_base.py rename to reagent/workflow/model_managers/parametric_dqn_base.py index fc309f8d..cd13ff24 100644 --- a/reagent/model_managers/parametric_dqn_base.py +++ b/reagent/workflow/model_managers/parametric_dqn_base.py @@ -3,23 +3,21 @@ import logging from typing import Dict, List, Optional, Tuple -import reagent.core.types as rlt +import reagent.types as rlt from reagent.core.dataclasses import dataclass, field -from reagent.core.rl_training_output import RLTrainingOutput from reagent.core.types import ( Dataset, PreprocessingOptions, ReaderOptions, RewardOptions, + RLTrainingOutput, TableSpec, ) -from reagent.data_fetchers.data_fetcher import DataFetcher from reagent.evaluation.evaluator import get_metrics_to_score from reagent.gym.policies.policy import Policy from reagent.gym.policies.predictor_policies import create_predictor_policy_from_model from reagent.gym.policies.samplers.discrete_sampler import SoftmaxActionSampler from reagent.gym.policies.scorers.discrete_scorer import parametric_dqn_scorer -from reagent.model_managers.model_manager import ModelManager from reagent.models.base import ModelBase from reagent.parameters import EvaluationParameters, NormalizationData, NormalizationKey from reagent.preprocessing.batch_preprocessor import BatchPreprocessor @@ -28,8 +26,8 @@ from reagent.preprocessing.normalization import ( get_num_output_features, ) from reagent.preprocessing.types import InputColumn -from reagent.reporting.parametric_dqn_reporter import ParametricDQNReporter -from reagent.training.parametric_dqn_trainer import ParametricDQNTrainer +from reagent.workflow.identify_types_flow import identify_normalization_parameters +from reagent.workflow.model_managers.model_manager import ModelManager logger = logging.getLogger(__name__) @@ -60,32 +58,32 @@ class ParametricDQNBase(ModelManager): "Please set action whitelist features in action_float_features field of " "config instead" ) + self._state_preprocessing_options = self.state_preprocessing_options + self._action_preprocessing_options = self.action_preprocessing_options + self._q_network: Optional[ModelBase] = None + self._metrics_to_score: Optional[List[str]] = None + + def create_policy(self, serving: bool) -> Policy: + """ Create an online DiscreteDQN Policy from env. """ - def create_policy(self, trainer: ParametricDQNTrainer) -> Policy: # FIXME: this only works for one-hot encoded actions - action_dim = trainer.num_gym_actions - sampler = SoftmaxActionSampler(temperature=self.trainer_param.rl.temperature) - scorer = parametric_dqn_scorer( - max_num_actions=action_dim, q_network=trainer.q_network + action_dim = get_num_output_features( + self.action_normalization_data.dense_normalization_parameters ) - return Policy(scorer=scorer, sampler=sampler) - - def create_serving_policy( - self, normalization_data_map: Dict[str, NormalizationData], trainer - ) -> Policy: - # FIXME: this only works for one-hot encoded actions - action_dim = trainer.num_gym_actions - return create_predictor_policy_from_model( - self.build_serving_module(normalization_data_map, trainer), - max_num_actions=action_dim, - ) - - def get_reporter(self): - return ParametricDQNReporter() + if serving: + return create_predictor_policy_from_model( + self.build_serving_module(), max_num_actions=action_dim + ) + else: + sampler = SoftmaxActionSampler(temperature=self.rl_parameters.temperature) + scorer = parametric_dqn_scorer( + max_num_actions=action_dim, q_network=self._q_network + ) + return Policy(scorer=scorer, sampler=sampler) @property def should_generate_eval_dataset(self) -> bool: - return False # Parametric DQN CPE not supported yet + return self.eval_parameters.calc_cpe_in_training @property def state_feature_config(self) -> rlt.ModelFeatureConfig: @@ -96,11 +94,11 @@ class ParametricDQNBase(ModelManager): return get_feature_config(self.action_float_features) def run_feature_identification( - self, data_fetcher: DataFetcher, input_table_spec: TableSpec + self, input_table_spec: TableSpec ) -> Dict[str, NormalizationData]: # Run state feature identification state_preprocessing_options = ( - self.state_preprocessing_options or PreprocessingOptions() + self._state_preprocessing_options or PreprocessingOptions() ) state_features = [ ffi.feature_id for ffi in self.state_feature_config.float_feature_infos @@ -110,13 +108,13 @@ class ParametricDQNBase(ModelManager): whitelist_features=state_features ) - state_normalization_parameters = data_fetcher.identify_normalization_parameters( + state_normalization_parameters = identify_normalization_parameters( input_table_spec, InputColumn.STATE_FEATURES, state_preprocessing_options ) # Run action feature identification action_preprocessing_options = ( - self.action_preprocessing_options or PreprocessingOptions() + self._action_preprocessing_options or PreprocessingOptions() ) action_features = [ ffi.feature_id for ffi in self.action_feature_config.float_feature_infos @@ -125,7 +123,7 @@ class ParametricDQNBase(ModelManager): action_preprocessing_options = action_preprocessing_options._replace( whitelist_features=action_features ) - action_normalization_parameters = data_fetcher.identify_normalization_parameters( + action_normalization_parameters = identify_normalization_parameters( input_table_spec, InputColumn.ACTION, action_preprocessing_options ) return { @@ -143,24 +141,26 @@ class ParametricDQNBase(ModelManager): def query_data( self, - data_fetcher: DataFetcher, input_table_spec: TableSpec, sample_range: Optional[Tuple[float, float]], reward_options: RewardOptions, ) -> Dataset: raise NotImplementedError() - def metrics_to_score(self, reward_options: RewardOptions) -> List[str]: - return get_metrics_to_score(reward_options.metric_reward_values) + @property + def metrics_to_score(self) -> List[str]: + assert self.reward_options is not None + if self._metrics_to_score is None: + # pyre-fixme[16]: `ParametricDQNBase` has no attribute `_metrics_to_score`. + # pyre-fixme[16]: `ParametricDQNBase` has no attribute `_metrics_to_score`. + self._metrics_to_score = get_metrics_to_score( + # pyre-fixme[16]: `Optional` has no attribute `metric_reward_values`. + # pyre-fixme[16]: `Optional` has no attribute `metric_reward_values`. + self._reward_options.metric_reward_values + ) + return self._metrics_to_score - def build_batch_preprocessor( - self, - reader_options: ReaderOptions, - use_gpu: bool, - batch_size: int, - normalization_data_map: Dict[str, NormalizationData], - reward_options: RewardOptions, - ) -> BatchPreprocessor: + def build_batch_preprocessor(self) -> BatchPreprocessor: raise NotImplementedError() def train( diff --git a/reagent/model_managers/ranking/__init__.py b/reagent/workflow/model_managers/ranking/__init__.py similarity index 100% rename from reagent/model_managers/ranking/__init__.py rename to reagent/workflow/model_managers/ranking/__init__.py diff --git a/reagent/model_managers/ranking/slate_q.py b/reagent/workflow/model_managers/ranking/slate_q.py similarity index 61% rename from reagent/model_managers/ranking/slate_q.py rename to reagent/workflow/model_managers/ranking/slate_q.py index cfa203b3..72372d35 100644 --- a/reagent/model_managers/ranking/slate_q.py +++ b/reagent/workflow/model_managers/ranking/slate_q.py @@ -1,17 +1,16 @@ #!/usr/bin/env python3 import logging -from typing import Dict, Optional +from typing import Optional import torch from reagent.core.dataclasses import dataclass, field -from reagent.core.types import RewardOptions -from reagent.model_managers.slate_q_base import SlateQBase from reagent.models.base import ModelBase from reagent.net_builder.parametric_dqn.fully_connected import FullyConnected from reagent.net_builder.unions import ParametricDQNNetBuilder__Union -from reagent.parameters import NormalizationData, NormalizationKey, param_hash +from reagent.parameters import param_hash from reagent.training import SlateQTrainer, SlateQTrainerParameters +from reagent.workflow.model_managers.slate_q_base import SlateQBase logger = logging.getLogger(__name__) @@ -21,6 +20,11 @@ logger = logging.getLogger(__name__) class SlateQ(SlateQBase): __hash__ = param_hash + slate_size: int = -1 + num_candidates: int = -1 + trainer_param: SlateQTrainerParameters = field( + default_factory=SlateQTrainerParameters + ) net_builder: ParametricDQNNetBuilder__Union = field( # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`. # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`. @@ -37,41 +41,32 @@ class SlateQ(SlateQBase): assert ( self.num_candidates > 0 ), f"Please set valid num_candidates (currently {self.num_candidates})" + self._q_network: Optional[ModelBase] = None + self.eval_parameters = self.trainer_param.evaluation - def build_trainer( - self, - use_gpu: bool, - normalization_data_map: Dict[str, NormalizationData], - reward_options: RewardOptions, - ) -> SlateQTrainer: + def build_trainer(self) -> SlateQTrainer: net_builder = self.net_builder.value # pyre-fixme[16]: `SlateQ` has no attribute `_q_network`. # pyre-fixme[16]: `SlateQ` has no attribute `_q_network`. - q_network = net_builder.build_q_network( - normalization_data_map[NormalizationKey.STATE], - normalization_data_map[NormalizationKey.ITEM], + self._q_network = net_builder.build_q_network( + self.state_normalization_data, self.item_normalization_data ) - if use_gpu: - q_network = q_network.cuda() + if self.use_gpu: + self._q_network = self._q_network.cuda() - q_network_target = q_network.get_target_network() + q_network_target = self._q_network.get_target_network() return SlateQTrainer( - q_network=q_network, + q_network=self._q_network, q_network_target=q_network_target, - use_gpu=use_gpu, + use_gpu=self.use_gpu, # pyre-fixme[16]: `SlateQTrainerParameters` has no attribute `asdict`. # pyre-fixme[16]: `SlateQTrainerParameters` has no attribute `asdict`. **self.trainer_param.asdict(), ) - def build_serving_module( - self, - normalization_data_map: Dict[str, NormalizationData], - trainer: SlateQTrainer, - ) -> torch.nn.Module: + def build_serving_module(self) -> torch.nn.Module: net_builder = self.net_builder.value + assert self._q_network is not None return net_builder.build_serving_module( - trainer.q_network, - normalization_data_map[NormalizationKey.STATE], - normalization_data_map[NormalizationKey.ITEM], + self._q_network, self.state_normalization_data, self.item_normalization_data ) diff --git a/reagent/model_managers/slate_q_base.py b/reagent/workflow/model_managers/slate_q_base.py similarity index 67% rename from reagent/model_managers/slate_q_base.py rename to reagent/workflow/model_managers/slate_q_base.py index df5a3ae1..e12b84c7 100644 --- a/reagent/model_managers/slate_q_base.py +++ b/reagent/workflow/model_managers/slate_q_base.py @@ -3,28 +3,26 @@ import logging from typing import Dict, List, Optional, Tuple -import reagent.core.types as rlt -from reagent.core.dataclasses import dataclass, field -from reagent.core.rl_training_output import RLTrainingOutput +import reagent.types as rlt +from reagent.core.dataclasses import dataclass from reagent.core.types import ( Dataset, PreprocessingOptions, ReaderOptions, RewardOptions, + RLTrainingOutput, TableSpec, ) -from reagent.data_fetchers.data_fetcher import DataFetcher from reagent.gym.policies.policy import Policy from reagent.gym.policies.predictor_policies import create_predictor_policy_from_model from reagent.gym.policies.samplers.top_k_sampler import TopKSampler from reagent.gym.policies.scorers.slate_q_scorer import slate_q_scorer -from reagent.model_managers.model_manager import ModelManager +from reagent.models.base import ModelBase from reagent.parameters import NormalizationData, NormalizationKey -from reagent.preprocessing.batch_preprocessor import BatchPreprocessor from reagent.preprocessing.normalization import get_feature_config from reagent.preprocessing.types import InputColumn -from reagent.reporting.ranking_model_reporter import RankingModelReporter -from reagent.training import SlateQTrainerParameters +from reagent.workflow.identify_types_flow import identify_normalization_parameters +from reagent.workflow.model_managers.model_manager import ModelManager logger = logging.getLogger(__name__) @@ -32,17 +30,12 @@ logger = logging.getLogger(__name__) @dataclass class SlateQBase(ModelManager): - slate_feature_id: int = -1 - slate_score_id: Tuple[int, int] = (-1, -1) + slate_feature_id: int + slate_score_id: Tuple[int, int] item_preprocessing_options: Optional[PreprocessingOptions] = None state_preprocessing_options: Optional[PreprocessingOptions] = None state_float_features: Optional[List[Tuple[int, str]]] = None item_float_features: Optional[List[Tuple[int, str]]] = None - slate_size: int = -1 - num_candidates: int = -1 - trainer_param: SlateQTrainerParameters = field( - default_factory=SlateQTrainerParameters - ) def __post_init_post_parse__(self): super().__init__() @@ -64,23 +57,24 @@ class SlateQBase(ModelManager): self.item_preprocessing_options is None or self.item_preprocessing_options.sequence_feature_id is None ), "Please set slate_feature_id field of config instead" + self._state_preprocessing_options = self.state_preprocessing_options + self._item_preprocessing_options = self.item_preprocessing_options + self._q_network: Optional[ModelBase] = None self.eval_parameters = self.trainer_param.evaluation - def create_policy(self, trainer) -> Policy: - scorer = slate_q_scorer( - num_candidates=self.num_candidates, q_network=trainer.q_network - ) - sampler = TopKSampler(k=self.slate_size) - return Policy(scorer=scorer, sampler=sampler) - - def create_serving_policy( - self, normalization_data_map: Dict[str, NormalizationData], trainer - ) -> Policy: - return create_predictor_policy_from_model( - self.build_serving_module(normalization_data_map, trainer), - max_num_actions=self.num_candidates, - slate_size=self.slate_size, - ) + def create_policy(self, serving: bool) -> Policy: + if serving: + return create_predictor_policy_from_model( + self.build_serving_module(), + max_num_actions=self.num_candidates, + slate_size=self.slate_size, + ) + else: + scorer = slate_q_scorer( + num_candidates=self.num_candidates, q_network=self._q_network + ) + sampler = TopKSampler(k=self.slate_size) + return Policy(scorer=scorer, sampler=sampler) @property def should_generate_eval_dataset(self) -> bool: @@ -94,14 +88,11 @@ class SlateQBase(ModelManager): def item_feature_config(self) -> rlt.ModelFeatureConfig: return get_feature_config(self.item_float_features) - def get_reporter(self): - return RankingModelReporter() - def run_feature_identification( - self, data_fetcher: DataFetcher, input_table_spec: TableSpec + self, input_table_spec: TableSpec ) -> Dict[str, NormalizationData]: state_preprocessing_options = ( - self.state_preprocessing_options or PreprocessingOptions() + self._state_preprocessing_options or PreprocessingOptions() ) state_features = [ ffi.feature_id for ffi in self.state_feature_config.float_feature_infos @@ -110,11 +101,11 @@ class SlateQBase(ModelManager): state_preprocessing_options = state_preprocessing_options._replace( whitelist_features=state_features ) - state_normalization_parameters = data_fetcher.identify_normalization_parameters( + state_normalization_parameters = identify_normalization_parameters( input_table_spec, InputColumn.STATE_FEATURES, state_preprocessing_options ) item_preprocessing_options = ( - self.item_preprocessing_options or PreprocessingOptions() + self._item_preprocessing_options or PreprocessingOptions() ) item_features = [ ffi.feature_id for ffi in self.item_feature_config.float_feature_infos @@ -123,7 +114,7 @@ class SlateQBase(ModelManager): item_preprocessing_options = item_preprocessing_options._replace( whitelist_features=item_features, sequence_feature_id=self.slate_feature_id ) - item_normalization_parameters = data_fetcher.identify_normalization_parameters( + item_normalization_parameters = identify_normalization_parameters( input_table_spec, InputColumn.STATE_SEQUENCE_FEATURES, item_preprocessing_options, @@ -141,19 +132,8 @@ class SlateQBase(ModelManager): def required_normalization_keys(self) -> List[str]: return [NormalizationKey.STATE, NormalizationKey.ITEM] - def build_batch_preprocessor( - self, - reader_options: ReaderOptions, - use_gpu: bool, - batch_size: int, - normalization_data_map: Dict[str, NormalizationData], - reward_options: RewardOptions, - ) -> BatchPreprocessor: - raise NotImplementedError("Write for OSS") - def query_data( self, - data_fetcher: DataFetcher, input_table_spec: TableSpec, sample_range: Optional[Tuple[float, float]], reward_options: RewardOptions, diff --git a/reagent/model_managers/union.py b/reagent/workflow/model_managers/union.py similarity index 86% rename from reagent/model_managers/union.py rename to reagent/workflow/model_managers/union.py index d944777a..5e002fd5 100644 --- a/reagent/model_managers/union.py +++ b/reagent/workflow/model_managers/union.py @@ -4,7 +4,7 @@ """ Register all ModelManagers. Must import them before filling union. """ from reagent.core.tagged_union import TaggedUnion -from reagent.model_managers.model_manager import ModelManager +from reagent.workflow.model_managers.model_manager import ModelManager from .actor_critic import * # noqa from .discrete import * # noqa diff --git a/reagent/model_managers/world_model_base.py b/reagent/workflow/model_managers/world_model_base.py similarity index 67% rename from reagent/model_managers/world_model_base.py rename to reagent/workflow/model_managers/world_model_base.py index 7d3228b9..a9b415f3 100644 --- a/reagent/model_managers/world_model_base.py +++ b/reagent/workflow/model_managers/world_model_base.py @@ -4,13 +4,17 @@ import logging from typing import Dict, List, Optional, Tuple from reagent.core.dataclasses import dataclass -from reagent.core.rl_training_output import RLTrainingOutput -from reagent.core.types import Dataset, ReaderOptions, RewardOptions, TableSpec -from reagent.data_fetchers.data_fetcher import DataFetcher -from reagent.model_managers.model_manager import ModelManager +from reagent.core.types import ( + Dataset, + ReaderOptions, + RewardOptions, + RLTrainingOutput, + TableSpec, +) +from reagent.gym.policies.policy import Policy from reagent.parameters import NormalizationData, NormalizationKey from reagent.preprocessing.batch_preprocessor import BatchPreprocessor -from reagent.reporting.world_model_reporter import WorldModelReporter +from reagent.workflow.model_managers.model_manager import ModelManager logger = logging.getLogger(__name__) @@ -25,6 +29,10 @@ class WorldModelBase(ModelManager): def normalization_key(cls) -> str: raise NotImplementedError() + def create_policy(self) -> Policy: + """ Create a WorldModel Policy from env. """ + raise NotImplementedError() + @property def should_generate_eval_dataset(self) -> bool: return False @@ -34,30 +42,19 @@ class WorldModelBase(ModelManager): return [NormalizationKey.STATE, NormalizationKey.ACTION] def run_feature_identification( - self, data_fetcher: DataFetcher, input_table_spec: TableSpec + self, input_table_spec: TableSpec ) -> Dict[str, NormalizationData]: raise NotImplementedError() - def get_reporter(self): - return WorldModelReporter() - def query_data( self, - data_fetcher: DataFetcher, input_table_spec: TableSpec, sample_range: Optional[Tuple[float, float]], reward_options: RewardOptions, ) -> Dataset: raise NotImplementedError() - def build_batch_preprocessor( - self, - reader_options: ReaderOptions, - use_gpu: bool, - batch_size: int, - normalization_data_map: Dict[str, NormalizationData], - reward_options: RewardOptions, - ) -> BatchPreprocessor: + def build_batch_preprocessor(self) -> BatchPreprocessor: raise NotImplementedError() def train( diff --git a/reagent/workflow/reporters/actor_critic_reporter.py b/reagent/workflow/reporters/actor_critic_reporter.py new file mode 100644 index 00000000..dc7d2788 --- /dev/null +++ b/reagent/workflow/reporters/actor_critic_reporter.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 + +import itertools +import logging +from collections import OrderedDict + +from reagent.core import aggregators as agg +from reagent.core.observers import IntervalAggregatingObserver, ValueListObserver +from reagent.workflow.reporters.reporter_base import ReporterBase +from reagent.workflow.training_reports import ActorCriticTrainingReport + + +logger = logging.getLogger(__name__) + + +class ActorCriticReporter(ReporterBase): + def __init__(self, report_interval: int = 100): + self.value_list_observers = {"cpe_results": ValueListObserver("cpe_details")} + self.aggregating_observers = OrderedDict( + (name, IntervalAggregatingObserver(report_interval, aggregator)) + for name, aggregator in itertools.chain( + [ + ("td_loss", agg.MeanAggregator("td_loss")), + ("reward_loss", agg.MeanAggregator("reward_loss")), + ("recent_rewards", agg.RecentValuesAggregator("logged_rewards")), + ], + [ + ( + f"{key}_tb", + agg.TensorBoardHistogramAndMeanAggregator(key, log_key), + ) + for key, log_key in [ + ("td_loss", "td_loss"), + ("reward_loss", "reward_loss"), + ("logged_propensities", "propensities/logged"), + ("logged_rewards", "reward/logged"), + ] + ], + ) + ) + super().__init__(self.value_list_observers, self.aggregating_observers) + + # TODO: write this for OSS + def generate_training_report(self) -> ActorCriticTrainingReport: + return ActorCriticTrainingReport() diff --git a/reagent/workflow/reporters/discrete_dqn_reporter.py b/reagent/workflow/reporters/discrete_dqn_reporter.py new file mode 100644 index 00000000..908dae06 --- /dev/null +++ b/reagent/workflow/reporters/discrete_dqn_reporter.py @@ -0,0 +1,95 @@ +#!/usr/bin/env python3 + +import itertools +import logging +from collections import OrderedDict +from typing import List, Optional + +import torch +from reagent.core import aggregators as agg +from reagent.core.observers import IntervalAggregatingObserver, ValueListObserver +from reagent.workflow.reporters.reporter_base import ReporterBase +from reagent.workflow.training_reports import DQNTrainingReport + + +logger = logging.getLogger(__name__) + + +class DiscreteDQNReporter(ReporterBase): + def __init__( + self, + actions: List[str], + report_interval: int = 100, + target_action_distribution: Optional[List[float]] = None, + recent_window_size: int = 100, + ): + self.value_list_observers = {"cpe_results": ValueListObserver("cpe_details")} + self.aggregating_observers = OrderedDict( + (name, IntervalAggregatingObserver(report_interval, aggregator)) + for name, aggregator in itertools.chain( + [ + ("td_loss", agg.MeanAggregator("td_loss")), + ("reward_loss", agg.MeanAggregator("reward_loss")), + ( + "model_values", + agg.FunctionsByActionAggregator( + "model_values", + actions, + {"mean": torch.mean, "std": torch.std}, + ), + ), + ( + "logged_action", + agg.ActionCountAggregator("logged_actions", actions), + ), + ( + "model_action", + agg.ActionCountAggregator("model_action_idxs", actions), + ), + ("recent_rewards", agg.RecentValuesAggregator("logged_rewards")), + ], + [ + ( + f"{key}_tb", + agg.TensorBoardActionCountAggregator(key, title, actions), + ) + for key, title in [ + ("logged_actions", "logged"), + ("model_action_idxs", "model"), + ] + ], + [ + ( + f"{key}_tb", + agg.TensorBoardHistogramAndMeanAggregator(key, log_key), + ) + for key, log_key in [ + ("td_loss", "td_loss"), + ("reward_loss", "reward_loss"), + ("logged_propensities", "propensities/logged"), + ("logged_rewards", "reward/logged"), + ] + ], + [ + ( + f"{key}_tb", + agg.TensorBoardActionHistogramAndMeanAggregator( + key, category, title, actions + ), + ) + for key, category, title in [ + ("model_propensities", "propensities", "model"), + ("model_rewards", "reward", "model"), + ("model_values", "value", "model"), + ] + ], + ) + ) + super().__init__(self.value_list_observers, self.aggregating_observers) + self.target_action_distribution = target_action_distribution + self.recent_window_size = recent_window_size + + # TODO: write this for OSS + def generate_training_report(self) -> DQNTrainingReport: + cpe_results = self.value_list_observers["cpe_results"].values # noqa + return DQNTrainingReport() diff --git a/reagent/workflow/reporters/parametric_dqn_reporter.py b/reagent/workflow/reporters/parametric_dqn_reporter.py new file mode 100644 index 00000000..bd0c9d82 --- /dev/null +++ b/reagent/workflow/reporters/parametric_dqn_reporter.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 + +import itertools +import logging +from collections import OrderedDict + +from reagent.core import aggregators as agg +from reagent.core.observers import IntervalAggregatingObserver, ValueListObserver +from reagent.workflow.reporters.reporter_base import ReporterBase +from reagent.workflow.training_reports import ParametricDQNTrainingReport + + +logger = logging.getLogger(__name__) + + +class ParametricDQNReporter(ReporterBase): + def __init__(self, report_interval: int = 100): + self.value_list_observers = {"cpe_results": ValueListObserver("cpe_details")} + self.aggregating_observers = OrderedDict( + (name, IntervalAggregatingObserver(report_interval, aggregator)) + for name, aggregator in itertools.chain( + [ + ("td_loss", agg.MeanAggregator("td_loss")), + ("reward_loss", agg.MeanAggregator("reward_loss")), + ("recent_rewards", agg.RecentValuesAggregator("logged_rewards")), + ], + [ + ( + f"{key}_tb", + agg.TensorBoardHistogramAndMeanAggregator(key, log_key), + ) + for key, log_key in [ + ("td_loss", "td_loss"), + ("reward_loss", "reward_loss"), + ("logged_propensities", "propensities/logged"), + ("logged_rewards", "reward/logged"), + ] + ], + ) + ) + super().__init__(self.value_list_observers, self.aggregating_observers) + + # TODO: write this for OSS + def generate_training_report(self) -> ParametricDQNTrainingReport: + return ParametricDQNTrainingReport() diff --git a/reagent/workflow/reporters/reporter_base.py b/reagent/workflow/reporters/reporter_base.py new file mode 100644 index 00000000..b5f54d92 --- /dev/null +++ b/reagent/workflow/reporters/reporter_base.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python3 + +import abc +import logging +from typing import Dict + +from reagent.core.observers import ( + CompositeObserver, + EpochEndObserver, + IntervalAggregatingObserver, + ValueListObserver, +) +from reagent.workflow.result_registries import TrainingReport + + +logger = logging.getLogger(__name__) + + +class ReporterBase(CompositeObserver): + def __init__( + self, + value_list_observers: Dict[str, ValueListObserver], + aggregating_observers: Dict[str, IntervalAggregatingObserver], + ): + epoch_end_observer = EpochEndObserver(self._epoch_end_callback) + self.last_epoch_end_num_batches: int = 0 + self.num_data_points_per_epoch = None + super().__init__( + list(value_list_observers.values()) + # pyre-fixme[6]: Expected `List[ValueListObserver]` for 1st param but + # got `List[IntervalAggregatingObserver]`. + + list(aggregating_observers.values()) + # pyre-fixme[6]: Expected `List[ValueListObserver]` for 1st param but + # got `List[EpochEndObserver]`. + + [epoch_end_observer] + ) + + def _epoch_end_callback(self, epoch: int): + logger.info(f"Epoch {epoch} ended") + + for observer in self.aggregating_observers.values(): + observer.flush() + + num_batches = len(self.td_loss.values) - self.last_epoch_end_num_batches + self.last_epoch_end_num_batches = len(self.td_loss.values) + if self.num_data_points_per_epoch is None: + self.num_data_points_per_epoch = num_batches + else: + assert self.num_data_points_per_epoch == num_batches + logger.info(f"Epoch {epoch} contains {num_batches} aggregated data points") + + def __getattr__(self, key: str): + if key in self.value_list_observers: + return self.value_list_observers[key] + return self.aggregating_observers[key].aggregator + + # TODO: write this for OSS + @abc.abstractmethod + def generate_training_report(self) -> TrainingReport: + pass diff --git a/reagent/reporting/result_registries.py b/reagent/workflow/result_registries.py similarity index 86% rename from reagent/reporting/result_registries.py rename to reagent/workflow/result_registries.py index 6b1f3343..ba72b56a 100644 --- a/reagent/reporting/result_registries.py +++ b/reagent/workflow/result_registries.py @@ -5,6 +5,10 @@ from reagent.core.dataclasses import dataclass from reagent.core.registry_meta import RegistryMeta +class TrainingReport(metaclass=RegistryMeta): + pass + + @dataclass class PublishingResult(metaclass=RegistryMeta): success: bool diff --git a/reagent/workflow/spark_utils.py b/reagent/workflow/spark_utils.py index 9afa037f..2c5a63ba 100644 --- a/reagent/workflow/spark_utils.py +++ b/reagent/workflow/spark_utils.py @@ -3,9 +3,8 @@ import logging import os import pprint -import tempfile from os.path import abspath, dirname, join -from typing import Dict +from typing import Dict, Optional import reagent @@ -34,29 +33,6 @@ ReAgent/ SPARK_JAR = join(dirname(reagent.__file__), os.pardir, SPARK_JAR_FROM_ROOT_DIR) -def create_and_return(path: str): - try: - os.mkdir(path) - except FileExistsError: - pass - return path - - -def create_and_return(path: str): - try: - os.mkdir(path) - except FileExistsError: - pass - return path - - -SPARK_DIRECTORY = "file://" + abspath( - tempfile.mkdtemp( - suffix=None, - prefix=None, - dir=create_and_return(join(tempfile.gettempdir(), "reagent_spark_warehouse")), - ) -) DEFAULT_SPARK_CONFIG = { "spark.app.name": "ReAgent", "spark.sql.session.timeZone": "UTC", @@ -65,7 +41,7 @@ DEFAULT_SPARK_CONFIG = { # use as many worker threads as possible on machine "spark.master": "local[*]", # default local warehouse for Hive - "spark.sql.warehouse.dir": SPARK_DIRECTORY, + "spark.sql.warehouse.dir": abspath("spark-warehouse"), # Set shuffle partitions to a low number, e.g. <= cores * 2 to speed # things up, otherwise the tests will use the default 200 partitions # and it will take a lot more time to complete @@ -78,16 +54,12 @@ DEFAULT_SPARK_CONFIG = { } -TEST_SPARK_SESSION = None - - -def get_spark_session(config: Dict[str, str] = DEFAULT_SPARK_CONFIG): - if TEST_SPARK_SESSION is not None: - return TEST_SPARK_SESSION +def get_spark_session(config: Optional[Dict[str, str]] = DEFAULT_SPARK_CONFIG): logger.info(f"Building with config: \n{pprint.pformat(config)}") spark = SparkSession.builder.enableHiveSupport() - for k, v in config.items(): - spark = spark.config(k, v) + if config is not None: + for k, v in config.items(): + spark = spark.config(k, v) spark = spark.getOrCreate() spark.sparkContext.setLogLevel("ERROR") return spark diff --git a/reagent/workflow/training.py b/reagent/workflow/training.py index 1a1cc1b4..c414b0c0 100644 --- a/reagent/workflow/training.py +++ b/reagent/workflow/training.py @@ -4,21 +4,21 @@ import dataclasses import logging from typing import Dict, NamedTuple, Optional, Tuple -import reagent.register # noqa import torch -from reagent.core.rl_training_output import RLTrainingOutput from reagent.core.types import ( OssReaderOptions, + ReaderOptions, RecurringPeriod, ResourceOptions, RewardOptions, + RLTrainingOutput, TableSpec, ) -from reagent.model_managers.union import ModelManager__Union from reagent.parameters import NormalizationData from reagent.publishers.union import ModelPublisher__Union -from reagent.runners.oss_batch_runner import OssBatchRunner from reagent.validators.union import ModelValidator__Union +from reagent.workflow.env import get_workflow_id +from reagent.workflow.model_managers.union import ModelManager__Union logger = logging.getLogger(__name__) @@ -30,7 +30,7 @@ def identify_and_train_network( num_epochs: int, use_gpu: Optional[bool] = None, reward_options: Optional[RewardOptions] = None, - reader_options: Optional[OssReaderOptions] = None, + reader_options: Optional[ReaderOptions] = None, resource_options: Optional[ResourceOptions] = None, warmstart_path: Optional[str] = None, validator: Optional[ModelValidator__Union] = None, @@ -40,8 +40,7 @@ def identify_and_train_network( use_gpu: bool = torch.cuda.is_available() manager = model.value - batch_runner = OssBatchRunner(use_gpu, manager, reward_options, {}, warmstart_path) - normalization_data_map = batch_runner.run_feature_identification(input_table_spec) + normalization_data_map = manager.run_feature_identification(input_table_spec) return query_and_train( input_table_spec, @@ -91,10 +90,7 @@ def get_sample_range( ) assert table_sample is not None, error_msg assert eval_table_sample is not None, error_msg - assert table_sample > 0, error_msg - assert eval_table_sample > 0, error_msg assert (eval_table_sample + table_sample) <= (100.0 + 1e-3), error_msg - assert (eval_table_sample + table_sample) >= (100.0 - 1e-3), error_msg return TrainEvalSampleRanges( train_sample_range=(0.0, table_sample), @@ -109,7 +105,7 @@ def query_and_train( num_epochs: int, use_gpu: bool, reward_options: Optional[RewardOptions] = None, - reader_options: Optional[OssReaderOptions] = None, + reader_options: Optional[ReaderOptions] = None, resource_options: Optional[ResourceOptions] = None, warmstart_path: Optional[str] = None, validator: Optional[ModelValidator__Union] = None, @@ -117,40 +113,50 @@ def query_and_train( parent_workflow_id: Optional[int] = None, recurring_period: Optional[RecurringPeriod] = None, ) -> RLTrainingOutput: + child_workflow_id = get_workflow_id() + if parent_workflow_id is None: + parent_workflow_id = child_workflow_id + logger.info("Starting query") reward_options = reward_options or RewardOptions() reader_options = reader_options or OssReaderOptions() resource_options = resource_options or ResourceOptions() manager = model.value - batch_runner = OssBatchRunner( - use_gpu, manager, reward_options, normalization_data_map, warmstart_path - ) - child_workflow_id = batch_runner.get_workflow_id() - if parent_workflow_id is None: - parent_workflow_id = child_workflow_id calc_cpe_in_training = manager.should_generate_eval_dataset sample_range_output = get_sample_range(input_table_spec, calc_cpe_in_training) - train_dataset, eval_dataset = batch_runner.query( + train_dataset = manager.query_data( input_table_spec=input_table_spec, - reader_options=reader_options, - resource_options=resource_options, + sample_range=sample_range_output.train_sample_range, + reward_options=reward_options, ) + eval_dataset = None + if calc_cpe_in_training: + eval_dataset = manager.query_data( + input_table_spec=input_table_spec, + sample_range=sample_range_output.eval_sample_range, + reward_options=reward_options, + ) logger.info("Starting training") - results = batch_runner.train( + results = manager.train_workflow( train_dataset, eval_dataset, normalization_data_map, num_epochs, - reader_options=reader_options, + use_gpu, parent_workflow_id=parent_workflow_id, + child_workflow_id=child_workflow_id, + reward_options=reward_options, + reader_options=reader_options, resource_options=resource_options, warmstart_path=warmstart_path, - validator=validator, ) + if validator is not None: + results = run_validator(validator, results) + if publisher is not None: results = run_publisher( publisher, diff --git a/reagent/workflow/training_reports.py b/reagent/workflow/training_reports.py new file mode 100644 index 00000000..3f605b9a --- /dev/null +++ b/reagent/workflow/training_reports.py @@ -0,0 +1,31 @@ +#!/usr/bin/env python3 + +from typing import Optional + +from reagent.core.dataclasses import dataclass +from reagent.evaluation.cpe import CpeEstimate +from reagent.workflow.result_registries import TrainingReport + + +@dataclass +class DQNTrainingReport(TrainingReport): + __registry_name__ = "dqn_report" + + td_loss: Optional[float] = None + mc_loss: Optional[float] = None + reward_ips: Optional[CpeEstimate] = None + reward_dm: Optional[CpeEstimate] = None + reward_dr: Optional[CpeEstimate] = None + value_sequential_dr: Optional[CpeEstimate] = None + value_weighted_dr: Optional[CpeEstimate] = None + value_magic_dr: Optional[CpeEstimate] = None + + +@dataclass +class ActorCriticTrainingReport(TrainingReport): + __registry_name__ = "actor_critic_report" + + +@dataclass +class ParametricDQNTrainingReport(TrainingReport): + __registry_name__ = "parametric_dqn_report" diff --git a/reagent/workflow/utils.py b/reagent/workflow/utils.py new file mode 100644 index 00000000..7dac7a53 --- /dev/null +++ b/reagent/workflow/utils.py @@ -0,0 +1,151 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +import logging +from typing import Dict, List, Optional + +import reagent.types as rlt + +# pyre-fixme[21]: Could not find `petastorm`. +from petastorm import make_batch_reader + +# pyre-fixme[21]: Could not find module `petastorm.pytorch`. +# pyre-fixme[21]: Could not find module `petastorm.pytorch`. +from petastorm.pytorch import DataLoader, decimal_friendly_collate +from reagent.core.tracker import Observer +from reagent.core.types import Dataset, OssReaderOptions, ReaderOptions +from reagent.evaluation.evaluation_data_page import EvaluationDataPage +from reagent.evaluation.evaluator import Evaluator +from reagent.preprocessing.batch_preprocessor import BatchPreprocessor +from reagent.torch_utils import dict_to_tensor +from reagent.training import RLTrainer, SACTrainer, TD3Trainer +from reagent.workflow.spark_utils import get_spark_session +from reagent.workflow_utils.iterators import DataLoaderWrapper, EpochIterator + + +logger = logging.getLogger(__name__) + + +def get_table_row_count(parquet_url: str): + spark = get_spark_session() + return spark.read.parquet(parquet_url).count() + + +def collate_and_preprocess(batch_preprocessor: BatchPreprocessor, use_gpu: bool): + """ Helper for Petastorm's DataLoader to preprocess. + TODO(kaiwenw): parallelize preprocessing by using transform of Petastorm reader + Should pin memory and preprocess in reader and convert to gpu in collate_fn. + """ + + def collate_fn(batch_list: List[Dict]): + batch = decimal_friendly_collate(batch_list) + preprocessed_batch = batch_preprocessor(batch) + if use_gpu: + preprocessed_batch = preprocessed_batch.cuda() + return preprocessed_batch + + return collate_fn + + +def get_petastorm_dataloader( + dataset: Dataset, + batch_size: int, + batch_preprocessor: BatchPreprocessor, + use_gpu: bool, + reader_options: ReaderOptions, +): + """ get petastorm loader for dataset (with preprocessor) """ + data_reader = make_batch_reader( + dataset.parquet_url, + num_epochs=1, + reader_pool_type=reader_options.petastorm_reader_pool_type, + ) + # NOTE: must be wrapped by DataLoaderWrapper to call __exit__() on end of epoch + return DataLoader( + data_reader, + batch_size=batch_size, + collate_fn=collate_and_preprocess( + batch_preprocessor=batch_preprocessor, use_gpu=use_gpu + ), + ) + + +def gather_eval_data( + trainer: RLTrainer, + eval_dataset: Dataset, + batch_preprocessor: BatchPreprocessor, + use_gpu: bool, + reader_options: ReaderOptions, +) -> EvaluationDataPage: + """ Sorts, computes logged values and validates the EvaluationDataPage """ + if isinstance(trainer, (SACTrainer, TD3Trainer)): + raise NotImplementedError("TODO: Implement CPE for continuous algos") + assert ( + trainer.calc_cpe_in_training + ), "this function should only be called when this is true." + + # first read the eval_dataset as EvaluationDataPages + device = "cuda" if use_gpu else "cpu" + eval_data = None + with make_batch_reader( + eval_dataset.parquet_url, + num_epochs=1, + reader_pool_type=reader_options.petastorm_reader_pool_type, + ) as reader: + for batch in reader: + assert rlt.isinstance_namedtuple(batch) + tensor_batch = dict_to_tensor(batch._asdict(), device=device) + tdp: rlt.PreprocessedTrainingBatch = batch_preprocessor(tensor_batch) + edp = EvaluationDataPage.create_from_training_batch(tdp, trainer) + if eval_data is None: + eval_data = edp + else: + eval_data = eval_data.append(edp) + + eval_data = eval_data.sort() + eval_data = eval_data.compute_values(trainer.gamma) + eval_data.validate() + return eval_data + + +def train_and_evaluate_generic( + train_dataset: Dataset, + eval_dataset: Optional[Dataset], + trainer: RLTrainer, + num_epochs: int, + use_gpu: bool, + batch_preprocessor: BatchPreprocessor, + reporter: Observer, + evaluator: Evaluator, + reader_options: Optional[ReaderOptions] = None, +) -> None: + reader_options = reader_options or OssReaderOptions() + epoch_iterator = EpochIterator(num_epochs=num_epochs) + train_dataset_size = get_table_row_count(train_dataset.parquet_url) + # pyre-fixme[16]: `EpochIterator` has no attribute `add_observer`. + for epoch in epoch_iterator.add_observer(reporter): + logger.info(f"Starting training epoch {epoch}.") + dataloader = get_petastorm_dataloader( + dataset=train_dataset, + # pyre-fixme[6]: Expected `int` for 2nd param but got `Optional[int]`. + batch_size=trainer.minibatch_size, + batch_preprocessor=batch_preprocessor, + use_gpu=use_gpu, + reader_options=reader_options, + ) + dataloader_wrapper = DataLoaderWrapper( + dataloader=dataloader, dataloader_size=train_dataset_size + ) + for batch in dataloader_wrapper: + trainer.train(batch) + + if eval_dataset is not None: + eval_data = gather_eval_data( + trainer=trainer, + eval_dataset=eval_dataset, + batch_preprocessor=batch_preprocessor, + use_gpu=use_gpu, + reader_options=reader_options, + ) + # evaluator passes cpe_details to reporter via notify_observers + evaluator.evaluate_post_training(eval_data) diff --git a/reagent/workflow_utils/iterators.py b/reagent/workflow_utils/iterators.py index 4d6fcf53..41b424b0 100644 --- a/reagent/workflow_utils/iterators.py +++ b/reagent/workflow_utils/iterators.py @@ -4,6 +4,7 @@ import logging from collections import OrderedDict +from reagent.core.tracker import observable from reagent.tensorboardX import SummaryWriterContext from torch.utils.data import IterableDataset from tqdm import tqdm @@ -13,6 +14,21 @@ logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) +@observable(epoch_start=int, epoch_end=int) +class EpochIterator: + def __init__(self, num_epochs: int): + assert num_epochs > 0 + self.num_epochs = num_epochs + + def __iter__(self): + SummaryWriterContext._reset_globals() + for epoch in range(self.num_epochs): + self.notify_observers(epoch_start=epoch) + yield epoch + self.notify_observers(epoch_end=epoch) + # TODO: flush at end of epoch? + + def get_batch_size(batch): try: return batch.batch_size() @@ -27,12 +43,7 @@ def get_batch_size(batch): class DataLoaderWrapper(IterableDataset): - def __init__( - self, - dataloader: IterableDataset, - dataloader_size: int, - post_dataloader_preprocessor=None, - ): + def __init__(self, dataloader: IterableDataset, dataloader_size: int): """ Wraps around an Iterable Dataloader to report progress bars and increase global step of SummaryWriter. At last iteration, will call dataloader.__exit__ if needed (e.g. Petastorm DataLoader). @@ -45,13 +56,10 @@ class DataLoaderWrapper(IterableDataset): self.dataloader = dataloader self.dataloader_iter = iter(dataloader) self.dataloader_size = dataloader_size - self.post_dataloader_preprocessor = post_dataloader_preprocessor def __iter__(self): t = tqdm(total=self.dataloader_size, desc="iterating dataloader") for batch in self.dataloader: - if self.post_dataloader_preprocessor is not None: - batch = self.post_dataloader_preprocessor(batch) batch_size = get_batch_size(batch) yield batch t.update(batch_size) diff --git a/reagent/workflow_utils/page_handler.py b/reagent/workflow_utils/page_handler.py new file mode 100644 index 00000000..91b27f25 --- /dev/null +++ b/reagent/workflow_utils/page_handler.py @@ -0,0 +1,283 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved + +import logging +import time +from collections import OrderedDict +from typing import Dict, List, Optional + +import numpy as np +import torch +from reagent.core.tracker import observable +from reagent.evaluation.cpe import CpeDetails +from reagent.evaluation.evaluation_data_page import EvaluationDataPage +from reagent.tensorboardX import SummaryWriterContext +from reagent.training.sac_trainer import SACTrainer +from reagent.training.td3_trainer import TD3Trainer +from reagent.types import MemoryNetworkInput, PreprocessedTrainingBatch + + +logger = logging.getLogger(__name__) + + +class PageHandler: + def __init__(self, trainer_or_evaluator): + self.trainer_or_evaluator = trainer_or_evaluator + self.results: List[Dict] = [] + self.epoch = 0 + + def refresh_results(self) -> None: + self.results: List[Dict] = [] + + def get_loss(self, loss_name="loss"): + """ See usage in get_mean_loss """ + return [float(result[loss_name]) for result in self.results] + + def get_mean_loss(self, loss_name="loss", axis=None): + """ + Get the average of a certain type of loss + + :param loss_name: possible loss names: + For world model: + 'loss' (referring to total loss), + 'bce' (loss for predicting not_terminal), + 'gmm' (loss for next state prediction), + 'mse' (loss for predicting reward) + For ranking model: + 'pg' (policy gradient loss) + 'baseline' (the baseline model's loss, usually for fitting V(s)) + 'kendall_tau' (kendall_tau coefficient between advantage and log_probs, + used in evaluation page handlers) + 'kendaull_tau_p_value' (the p-value for kendall_tau test, used in + evaluation page handlers) + :param axis: axis to perform mean function. + """ + return np.mean([result[loss_name] for result in self.results], axis=axis) + + def handle(self, tdp: PreprocessedTrainingBatch) -> None: + raise NotImplementedError() + + def finish(self) -> None: + pass + + def set_epoch(self, epoch) -> None: + self.epoch = epoch + + +# TODO: remove. +# Use new DataLoaderWrapper & EpochIterator (see OSS train_and_evaluate_generic) +@observable(epoch_end=int) +class TrainingPageHandler(PageHandler): + def handle(self, tdp: PreprocessedTrainingBatch) -> None: + SummaryWriterContext.increase_global_step() + self.trainer_or_evaluator.train(tdp) + + def finish(self) -> None: + # pyre-fixme[16]: `TrainingPageHandler` has no attribute `notify_observers`. + self.notify_observers(epoch_end=self.epoch) + self.trainer_or_evaluator.loss_reporter.flush() + self.epoch += 1 + + +# TODO: remove. +# Use new DataLoaderWrapper & EpochIterator (see OSS train_and_evaluate_generic) +class EvaluationPageHandler(PageHandler): + def __init__(self, trainer, evaluator, reporter): + self.trainer = trainer + self.evaluator = evaluator + self.evaluation_data: Optional[EvaluationDataPage] = None + self.reporter = reporter + self.results: List[CpeDetails] = [] + + def handle(self, tdp: PreprocessedTrainingBatch) -> None: + if not self.trainer.calc_cpe_in_training: + return + # TODO: Perhaps we can make an RLTrainer param to check if continuous? + if isinstance(self.trainer, (SACTrainer, TD3Trainer)): + # TODO: Implement CPE for continuous algos + edp = None + else: + edp = EvaluationDataPage.create_from_training_batch(tdp, self.trainer) + if self.evaluation_data is None: + self.evaluation_data = edp + else: + # pyre-fixme[16]: `Optional` has no attribute `append`. + self.evaluation_data = self.evaluation_data.append(edp) + + def finish(self) -> None: + if self.evaluation_data is None: + return + # Making sure the data is sorted for CPE + # pyre-fixme[16]: `Optional` has no attribute `sort`. + self.evaluation_data = self.evaluation_data.sort() + # pyre-fixme[16]: `Optional` has no attribute `compute_values`. + self.evaluation_data = self.evaluation_data.compute_values(self.trainer.gamma) + # pyre-fixme[16]: `Optional` has no attribute `validate`. + self.evaluation_data.validate() + start_time = time.time() + evaluation_details = self.evaluator.evaluate_post_training(self.evaluation_data) + self.reporter.report(evaluation_details) + self.results.append(evaluation_details) + logger.info("CPE evaluation took {} seconds.".format(time.time() - start_time)) + self.evaluation_data = None + + def get_last_cpe_results(self): + if len(self.results) == 0: + return CpeDetails() + return self.results[-1] + + +class WorldModelTrainingPageHandler(PageHandler): + def handle(self, tdp: PreprocessedTrainingBatch) -> None: + losses = self.trainer_or_evaluator.train(tdp) + self.results.append(losses) + + +class WorldModelRandomTrainingPageHandler(PageHandler): + """ Train a baseline model based on randomly shuffled data """ + + # pyre-fixme[14]: `handle` overrides method defined in `PageHandler` inconsistently. + def handle(self, training_input: MemoryNetworkInput) -> None: + _, batch_size, _ = training_input.next_state.float_features.size() + + tdp = MemoryNetworkInput( + state=training_input.state, + action=training_input.action, + time_diff=torch.ones_like(training_input.reward), + # shuffle the data + next_state=training_input.next_state._replace( + float_features=training_input.next_state.float_features[ + :, torch.randperm(batch_size), : + ] + ), + reward=training_input.reward[:, torch.randperm(batch_size)], + not_terminal=training_input.not_terminal[ # type: ignore + :, torch.randperm(batch_size) + ], + step=None, + ) + losses = self.trainer_or_evaluator.train(tdp) + self.results.append(losses) + + +class WorldModelEvaluationPageHandler(PageHandler): + # pyre-fixme[14]: `handle` overrides method defined in `PageHandler` inconsistently. + def handle(self, tdp: MemoryNetworkInput) -> None: + losses = self.trainer_or_evaluator.evaluate(tdp) + self.results.append(losses) + + +@observable(epoch_end=int) +class RankingTrainingPageHandler(PageHandler): + def __init__(self, trainer) -> None: + super().__init__(trainer) + self.policy_gradient_loss: List[float] = [] + self.baseline_loss: List[float] = [] + self.per_seq_probs: List[float] = [] + + def handle(self, tdp: PreprocessedTrainingBatch) -> None: + res_dict = self.trainer_or_evaluator.train(tdp) + self.results.append(res_dict) + + def finish(self): + self.notify_observers(epoch_end=self.epoch) + result_template = self.results[0] + if result_template and "ips_rl_loss" in result_template: + self.policy_gradient_loss.append( + float(self.get_mean_loss(loss_name="ips_rl_loss")) + ) + if result_template and "baseline_loss" in result_template: + self.baseline_loss.append( + float(self.get_mean_loss(loss_name="baseline_loss")) + ) + if result_template and "per_seq_probs" in result_template: + self.per_seq_probs.append( + float(self.get_mean_loss(loss_name="per_seq_probs")) + ) + self.refresh_results() + + +@observable(epoch_end=int) +class RankingEvaluationPageHandler(PageHandler): + def handle(self, tdp: PreprocessedTrainingBatch) -> None: + self.trainer_or_evaluator.evaluate(tdp) + + def finish(self): + eval_res = self.trainer_or_evaluator.evaluate_post_training() + self.notify_observers(epoch_end=self.epoch) # type: ignore + self.results.append(eval_res) + + +class RewardNetTrainingPageHandler(PageHandler): + def __init__(self, trainer): + super().__init__(trainer) + self.mse_loss = [] + + def handle(self, tdp: PreprocessedTrainingBatch) -> None: + mse_loss = self.trainer_or_evaluator.train(tdp) + self.results.append({"mse": mse_loss.cpu().numpy()}) + + def finish(self): + self.mse_loss.append(float(self.get_mean_loss(loss_name="mse"))) + self.refresh_results() + + +# TODO: remove. +# Use new DataLoaderWrapper & EpochIterator (see OSS train_and_evaluate_generic) +def get_actual_minibatch_size(batch, minibatch_size_preset): + try: + return batch.batch_size() + except AttributeError: + pass + if isinstance(batch, OrderedDict): + first_key = next(iter(batch.keys())) + batch_size = len(batch[first_key]) + else: + raise NotImplementedError() + return batch_size + + +# TODO: remove. +# Use new DataLoaderWrapper & EpochIterator (see OSS train_and_evaluate_generic) +def feed_pages( + data_loader, + dataset_num_rows, + epoch, + minibatch_size, + use_gpu, + page_handler, + # used before batch is handled by page_handler + post_data_loader_preprocessor=None, +): + num_rows_processed = 0 + num_rows_to_process_for_progress_tick = max(1, dataset_num_rows // 100) + last_percent_reported = -1 + + for batch in data_loader: + if post_data_loader_preprocessor: + batch = post_data_loader_preprocessor(batch) + + if use_gpu: + batch = batch.cuda() + + batch_size = get_actual_minibatch_size(batch, minibatch_size) + num_rows_processed += batch_size + + if ( + num_rows_processed // num_rows_to_process_for_progress_tick + ) != last_percent_reported: + last_percent_reported = ( + num_rows_processed // num_rows_to_process_for_progress_tick + ) + logger.info( + "Feeding page. Epoch: {}, Epoch Progress: {} of {} ({}%)".format( + epoch, + num_rows_processed, + dataset_num_rows, + (100 * num_rows_processed) // dataset_num_rows, + ) + ) + + page_handler.handle(batch) + + page_handler.finish()