Back out recent refactor

Summary:
Need more tests before landing the refactor diffs: D22702504 (https://github.com/facebookresearch/ReAgent/commit/1b470c489d19c33beab88b8ea2e79843d4d31f28), D23123762 (https://github.com/facebookresearch/ReAgent/commit/76829287265bc39f879f3bc1d946a1374c5e1141), D23124179 (https://github.com/facebookresearch/ReAgent/commit/b28f84aa013be00194508f52498160592cb37e9d), D23219012 (https://github.com/facebookresearch/ReAgent/commit/e404c5772ea4118105c2eb136ca96ad5ca8e01db)

Back out to a version based on D23155753.

Check our team diff history: https://fburl.com/diffs/ppsgazgj

Reviewed By: kittipatv

Differential Revision: D23270626

fbshipit-source-id: 14653066bb3924a987a54650a51241895b321c8e
This commit is contained in:
Zhengxing Chen
2020-08-21 15:58:02 -07:00
committed by Facebook GitHub Bot
parent e404c5772e
commit 0d294b11e5
171 changed files with 3445 additions and 3863 deletions
+1 -1
View File
@@ -64,7 +64,7 @@ ml.rl.training.imitator\_training module
ml.rl.training.loss\_reporter module
------------------------------------
.. automodule:: ml.rl.training.rl_reporter
.. automodule:: ml.rl.training.loss_reporter
:members:
:undoc-members:
:show-inheritance:
+2
View File
@@ -0,0 +1,2 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+19 -125
View File
@@ -3,100 +3,21 @@
import logging
from collections import deque
from typing import Any, Callable, Deque, Dict, List, Optional
from typing import Callable, Deque, Dict, List, Optional
import numpy as np
import torch
from reagent.core.tracker import Aggregator
from reagent.tensorboardX import SummaryWriterContext
logger = logging.getLogger(__name__)
# Base aggregator: buffers the values passed to `update` and hands the buffered
# list to `aggregate` (through `__call__`) either every `interval` updates or,
# when no interval is given, from `finish_epoch`.
# NOTE(review): indentation is lost in this view, so block nesting below is
# inferred from the statements themselves.
class Aggregator:
def __init__(self, key: str, interval: Optional[int] = None):
super().__init__()
self.key = key
self.iteration = 0
self.interval = interval
# With no interval configured, aggregation is deferred to `finish_epoch`.
self.aggregate_epoch = interval is None
self.intermediate_values: List[Any] = []
# Buffer one value. The `key` argument is not read here; `self.key` is what
# gets forwarded when the buffer is flushed.
def update(self, key: str, value):
self.intermediate_values.append(value)
self.iteration += 1
# The truthiness test on `self.interval` skips both None and 0, which also
# avoids a modulo by zero.
# pyre-fixme[6]: Expected `int` for 1st param but got `Optional[int]`.
if self.interval and self.iteration % self.interval == 0:
logger.info(
f"Interval Agg. Update: {self.key}; iteration {self.iteration}; "
f"aggregator: {self.__class__.__name__}"
)
self(self.key, self.intermediate_values)
self.intermediate_values = []
# Reset the per-epoch iteration counter and, in epoch mode, aggregate what
# has been buffered.
def finish_epoch(self):
# We need to reset iteration here to avoid aggregating on the same data multiple
# times
logger.info(
f"Epoch finished. Flushing: {self.key}; "
f"aggregator: {self.__class__.__name__}; points: {len(self.intermediate_values)}"
)
self.iteration = 0
if self.aggregate_epoch:
self(self.key, self.intermediate_values)
# NOTE(review): cannot tell from this view whether the reset below is inside
# the `if` above or runs unconditionally - confirm against the real file.
self.intermediate_values = []
# Calling the aggregator checks the key and delegates to `aggregate`.
def __call__(self, key: str, values):
assert key == self.key, f"Got {key}; expected {self.key}"
self.aggregate(values)
# Hook for subclasses; the base implementation does nothing.
def aggregate(self, intermediate_values):
pass
# Accessors that subclasses are expected to provide.
def get_recent(self, count):
raise NotImplementedError()
def get_all(self):
raise NotImplementedError()
class AppendAggregator(Aggregator):
    """
    Aggregator that appends every aggregated batch of values to one flat list.
    """

    def __init__(self, key: str, interval: Optional[int] = None):
        super().__init__(key, interval)
        self.values = []

    def __call__(self, key: str, values):
        assert key == self.key, f"Got {key}; expected {self.key}"
        self.aggregate(values)

    def aggregate(self, intermediate_values):
        """Extend the stored list with `intermediate_values`."""
        self.values.extend(intermediate_values)

    def get_recent(self, count):
        """
        Return the last `count` stored values (fewer if not enough are stored).

        A non-positive `count` yields an empty list: `values[-0:]` would
        otherwise be the *entire* list rather than "no recent values".
        """
        if count <= 0:
            return []
        # Slicing an empty list is already [], so no separate emptiness check.
        return self.values[-count:]

    def get_all(self):
        """Return the underlying list of all stored values (not a copy)."""
        return self.values
class TensorAggregator(Aggregator):
def __call__(self, key: str, values, interval: Optional[int] = None):
if len(values) == 0:
return super().__call__(key, torch.tensor([0.0]))
def __call__(self, key: str, values):
# Ensure that tensor is on cpu before aggregation.
reshaped_values = []
for value in values:
if isinstance(value, list):
reshaped_values.append(torch.tensor(value))
elif not hasattr(value, "size"):
reshaped_values.append(torch.tensor(value).unsqueeze(0))
elif len(value.size()) == 0:
reshaped_values.append(value.unsqueeze(0))
else:
reshaped_values.append(value)
values = torch.cat(reshaped_values, dim=0).cpu()
values = torch.cat(values, dim=0).cpu()
return super().__call__(key, values)
@@ -114,8 +35,8 @@ def _log_histogram_and_mean(log_key, val):
class TensorBoardHistogramAndMeanAggregator(TensorAggregator):
def __init__(self, key: str, log_key: str, interval: Optional[int] = None):
super().__init__(key, interval)
def __init__(self, key: str, log_key: str):
super().__init__(key)
self.log_key = log_key
def aggregate(self, values):
@@ -133,9 +54,8 @@ class TensorBoardActionHistogramAndMeanAggregator(TensorAggregator):
title: str,
actions: List[str],
log_key_prefix: Optional[str] = None,
interval: Optional[int] = None,
):
super().__init__(key, interval)
super().__init__(key)
self.log_key_prefix = log_key_prefix or f"{category}/{title}"
self.actions = actions
SummaryWriterContext.add_custom_scalars_multilinechart(
@@ -157,10 +77,8 @@ class TensorBoardActionHistogramAndMeanAggregator(TensorAggregator):
class TensorBoardActionCountAggregator(TensorAggregator):
def __init__(
self, key: str, title: str, actions: List[str], interval: Optional[int] = None
):
super().__init__(key, interval)
def __init__(self, key: str, title: str, actions: List[str]):
super().__init__(key)
self.log_key = f"actions/{title}"
self.actions = actions
SummaryWriterContext.add_custom_scalars_multilinechart(
@@ -177,8 +95,8 @@ class TensorBoardActionCountAggregator(TensorAggregator):
class MeanAggregator(TensorAggregator):
def __init__(self, key: str, interval: Optional[int] = None):
super().__init__(key, interval)
def __init__(self, key: str):
super().__init__(key)
self.values: List[float] = []
def aggregate(self, values):
@@ -186,14 +104,6 @@ class MeanAggregator(TensorAggregator):
logger.info(f"{self.key}: {mean}")
self.values.append(mean)
def get_recent(self, count):
if len(self.values) == 0:
return []
return self.values[-count:]
def get_all(self):
return self.values
class FunctionsByActionAggregator(TensorAggregator):
"""
@@ -234,14 +144,8 @@ class FunctionsByActionAggregator(TensorAggregator):
}
"""
def __init__(
self,
key: str,
actions: List[str],
fns: Dict[str, Callable],
interval: Optional[int] = None,
):
super().__init__(key, interval)
def __init__(self, key: str, actions: List[str], fns: Dict[str, Callable]):
super().__init__(key)
self.actions = actions
self.values: Dict[str, Dict[str, List[float]]] = {
fn: {action: [] for action in self.actions} for fn in fns
@@ -268,8 +172,8 @@ class ActionCountAggregator(TensorAggregator):
`len(actions) - 1`. The input is assumed to contain action index.
"""
def __init__(self, key: str, actions: List[str], interval: Optional[int] = None):
super().__init__(key, interval)
def __init__(self, key: str, actions: List[str]):
super().__init__(key)
self.actions = actions
self.values: Dict[str, List[int]] = {action: [] for action in actions}
@@ -286,7 +190,7 @@ class ActionCountAggregator(TensorAggregator):
"""
totals = np.array([sum(counts) for counts in zip(*self.values.values())])
return {
action: (np.array(counts) / np.clip(totals, 1, None)).tolist()
action: (np.array(counts) / totals).tolist()
for action, counts in self.values.items()
}
@@ -294,7 +198,7 @@ class ActionCountAggregator(TensorAggregator):
"""
Returns the cumulative distributions in each aggregating step
"""
totals = max(1, sum(sum(counts) for counts in zip(*self.values.values())))
totals = sum(sum(counts) for counts in zip(*self.values.values()))
return {action: sum(counts) / totals for action, counts in self.values.items()}
@@ -302,20 +206,10 @@ _RECENT_DEFAULT_SIZE = int(1e6)
class RecentValuesAggregator(TensorAggregator):
def __init__(
self, key: str, size: int = _RECENT_DEFAULT_SIZE, interval: Optional[int] = None
):
super().__init__(key, interval)
def __init__(self, key: str, size: int = _RECENT_DEFAULT_SIZE):
super().__init__(key)
self.values: Deque[float] = deque(maxlen=size)
def aggregate(self, values):
flattened = torch.flatten(values).tolist()
self.values.extend(flattened)
def get_recent(self, count):
if len(self.values) == 0:
return []
return self.values[-count:]
def get_all(self):
return self.values
-29
View File
@@ -1,29 +0,0 @@
#!/usr/bin/env python3
import functools
import importlib
# `import importlib` alone does not guarantee that the `importlib.util`
# submodule has been loaded, so import it explicitly before `find_spec`.
import importlib.util

if importlib.util.find_spec("fblearner") is not None:
    import fblearner.flow.api as flow

    class AsyncWrapper:
        """
        Decorator object: applies `flow.typed()` to the function and then the
        wrapper built by `flow.flow_async(**kwargs)`.
        """

        def __init__(self, **kwargs):
            self.async_wrapper = flow.flow_async(**kwargs)
            self.type_wrapper = flow.typed()

        def __call__(self, func):
            return self.async_wrapper(self.type_wrapper(func))


else:

    def AsyncWrapper(**outer_kwargs):
        """
        Fallback used when `fblearner` is not importable: the keyword options
        are accepted and ignored, and the decorated function is simply called
        through a `functools.wraps` pass-through.
        """

        def async_wrapper_internal(func):
            @functools.wraps(func)
            def async_wrapper_repeat(*args, **kwargs):
                return func(*args, **kwargs)

            return async_wrapper_repeat

        return async_wrapper_internal
+103
View File
@@ -0,0 +1,103 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import logging
from typing import Any, Dict, Iterable, List, Optional
from reagent.core.tracker import Aggregator, Observer
logger = logging.getLogger(__name__)
class CompositeObserver(Observer):
    """
    Observer that fans each update out to the child observers registered for
    the updated key.
    """

    def __init__(self, observers: Iterable[Observer]):
        # key -> child observers interested in that key, in registration order
        self.observers: Dict[str, List[Observer]] = {}
        routing = self.observers
        for child in observers:
            for observed_key in child.get_observing_keys():
                if observed_key not in routing:
                    routing[observed_key] = []
                routing[observed_key].append(child)
        # The composite itself observes the union of its children's keys.
        super().__init__(list(routing))

    def update(self, key: str, value):
        """Forward `(key, value)` to every child registered for `key`."""
        children = self.observers[key]
        for child in children:
            child.update(key, value)
class EpochEndObserver(Observer):
    """
    Observer of a single key (default "epoch_end") that passes each observed
    value to a callback.
    """

    def __init__(self, callback, key: str = "epoch_end"):
        observed = [key]
        super().__init__(observing_keys=observed)
        self.callback = callback

    def update(self, key: str, value):
        """Invoke the callback with `value`; the key itself is not passed on."""
        on_epoch_end = self.callback
        on_epoch_end(value)
class ValueListObserver(Observer):
    """
    Observer that collects every value seen for one key into a list.
    """

    def __init__(self, observing_key: str):
        super().__init__(observing_keys=[observing_key])
        self.values: List[Any] = []
        self.observing_key = observing_key

    def update(self, key: str, value):
        """Record `value`; the key is not stored."""
        self.values.append(value)

    def reset(self):
        """Drop everything collected so far by starting a fresh list."""
        self.values = []
class IntervalAggregatingObserver(Observer):
    """
    Buffers values observed under the aggregator's key and passes the buffer to
    the aggregator every `interval` updates, and on "epoch_end" if observed.
    """

    def __init__(
        self,
        interval: Optional[int],
        aggregator: Aggregator,
        observe_epoch_end: bool = True,
    ):
        self.key = aggregator.key
        observed_keys = [self.key]
        if observe_epoch_end:
            # "epoch_end" goes first, matching the original key order.
            observed_keys.insert(0, "epoch_end")
        super().__init__(observing_keys=observed_keys)
        self.iteration = 0
        self.interval = interval
        self.intermediate_values: List[Any] = []
        self.aggregator = aggregator

    def update(self, key: str, value):
        """Buffer `value`, or flush when the key is "epoch_end"."""
        if key == "epoch_end":
            self.flush()
            return

        self.intermediate_values.append(value)
        self.iteration += 1
        # A falsy interval (None or 0) disables interval-based aggregation.
        # pyre-fixme[6]: Expected `int` for 1st param but got `Optional[int]`.
        if not self.interval or self.iteration % self.interval != 0:
            return

        logger.info(
            f"Interval Agg. Update: {self.key}; iteration {self.iteration}; "
            f"aggregator: {self.aggregator.__class__.__name__}"
        )
        self.aggregator(self.key, self.intermediate_values)
        self.intermediate_values = []

    def flush(self):
        """Aggregate whatever is buffered and restart the iteration count."""
        # The iteration counter is reset so the same data is not aggregated
        # more than once.
        logger.info(
            f"Interval Agg. Flushing: {self.key}; iteration: {self.iteration}; "
            f"aggregator: {self.aggregator.__class__.__name__}; points: {len(self.intermediate_values)}"
        )
        self.iteration = 0
        pending = self.intermediate_values
        if pending:
            self.aggregator(self.key, pending)
            self.intermediate_values = []
+4 -11
View File
@@ -16,7 +16,7 @@ class RegistryMeta(abc.ABCMeta):
def __init__(cls, name, bases, attrs):
if not hasattr(cls, "REGISTRY"):
# Put REGISTRY on cls. This only happens once on the base class
logger.debug("Adding REGISTRY to type {}".format(name))
logger.info("Adding REGISTRY to type {}".format(name))
cls.REGISTRY: Dict[str, Type] = {}
cls.REGISTRY_NAME = name
cls.REGISTRY_FROZEN = False
@@ -28,19 +28,12 @@ class RegistryMeta(abc.ABCMeta):
if not cls.__abstractmethods__ and name != cls.REGISTRY_NAME:
# Only register fully-defined classes
logger.info(f"Registering {name} to {cls.REGISTRY_NAME}")
if hasattr(cls, "__registry_name__"):
registry_name = cls.__registry_name__
logger.info(
f"Registering {name} with alias {registry_name} to {cls.REGISTRY_NAME}"
)
logger.info(f"Using {registry_name} instead of {name}")
name = registry_name
else:
logger.info(f"Registering {name} to {cls.REGISTRY_NAME}")
# assert name not in cls.REGISTRY
# TODO: Combine FB and OSS model managers and then bring back this assert.
# For now this works because FB model managers inherit from their OSS counterparts
if name in cls.REGISTRY:
logger.warning(f"Overwriting open source {name} with internal version")
assert name not in cls.REGISTRY
cls.REGISTRY[name] = cls
else:
logger.info(
+1 -1
View File
@@ -2,7 +2,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from reagent.core.dataclasses import dataclass
from reagent.reporting.result_registries import PublishingResult, ValidationResult
from reagent.workflow.result_registries import PublishingResult, ValidationResult
@dataclass
-19
View File
@@ -1,19 +0,0 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from dataclasses import dataclass
from typing import Optional
from reagent.core.union import (
PublishingResult__Union,
TrainingReport__Union,
ValidationResult__Union,
)
@dataclass
class RLTrainingOutput:
    """
    Bundle of optional results of a training run; every field defaults to None.
    """

    # Tagged-union wrappers for the validation / publishing outcomes
    validation_result: Optional[ValidationResult__Union] = None
    publishing_result: Optional[PublishingResult__Union] = None
    # Tagged-union wrapper for the training report
    training_report: Optional[TrainingReport__Union] = None
    # Path string for locally written output, if any
    local_output_path: Optional[str] = None
+117
View File
@@ -0,0 +1,117 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import functools
import logging
from typing import List
import torch
logger = logging.getLogger(__name__)
class Observer:
    """
    Base observer: remembers which keys it wants to be updated about.
    """

    def __init__(self, observing_keys: List[str]):
        super().__init__()
        # Must be a real list, not just any iterable.
        assert isinstance(observing_keys, list)
        self.observing_keys = observing_keys

    def get_observing_keys(self) -> List[str]:
        """Return the list of keys given at construction."""
        keys = self.observing_keys
        return keys

    def update(self, key: str, value):
        """Receive one observed value; the base implementation ignores it."""
        return None
class Aggregator:
    """
    Base aggregator bound to one key; calling it validates the key and then
    delegates to `aggregate`.
    """

    def __init__(self, key: str):
        super().__init__()
        self.key = key

    def __call__(self, key: str, values):
        expected = self.key
        assert key == expected, f"Got {key}; expected {self.key}"
        self.aggregate(values)

    def aggregate(self, values):
        """Hook for subclasses; the base implementation does nothing."""
        return None
def observable(cls=None, **kwargs):  # noqa: C901
    """
    Class decorator that marks a class as producing observable values.

    The names of the keyword arguments are the names of the observable values;
    the keyword values are their declared types. The declared type is only
    consulted for values declared as `torch.Tensor`, which are coerced to a
    detached tensor with at least one dimension before being dispatched.

    The decorated class gains three methods:
      * `add_observer(observer)` / `add_observers(observers)` - register
        observers; both return `self` so calls can be chained.
      * `notify_observers(**values)` - dispatch values to the observers
        registered for each key; `None` values are skipped.

    At least one keyword argument is required.
    """
    assert kwargs
    observable_value_types = kwargs

    def wrap(cls):
        # Refuse to silently clobber methods the class already defines.
        assert not hasattr(cls, "add_observer")
        assert not hasattr(cls, "add_observers")
        assert not hasattr(cls, "notify_observers")

        original_init = cls.__init__

        @functools.wraps(original_init)
        def new_init(self, *args, **kwargs):
            original_init(self, *args, **kwargs)
            assert not hasattr(self, "_observable_value_types")
            assert not hasattr(self, "_observers")
            self._observable_value_types = observable_value_types
            self._observers = {v: [] for v in observable_value_types}

        cls.__init__ = new_init

        def add_observer(self, observer: "Observer"):
            """Register `observer` for each of its keys this class produces."""
            observing_keys = observer.get_observing_keys()
            unknown_keys = [
                k for k in observing_keys if k not in self._observable_value_types
            ]
            if unknown_keys:
                logger.warning(f"{unknown_keys} cannot be observed in {type(self)}")
            for k in observing_keys:
                # Unknown keys are skipped; an observer is registered at most
                # once per key.
                if k in self._observers and observer not in self._observers[k]:
                    self._observers[k].append(observer)
            return self

        cls.add_observer = add_observer

        def add_observers(self, observers: List["Observer"]):
            """Register several observers; returns `self` for chaining."""
            for observer in observers:
                self.add_observer(observer)
            return self

        cls.add_observers = add_observers

        def notify_observers(self, **kwargs):
            """Send each non-None keyword value to the observers of its key."""
            for key, value in kwargs.items():
                if value is None:
                    # Allow optional reporting
                    continue

                assert key in self._observers, f"Unknown key: {key}"

                # TODO: Create a generic framework for type conversion
                if self._observable_value_types[key] == torch.Tensor:
                    if not isinstance(value, torch.Tensor):
                        value = torch.tensor(value)
                    if len(value.shape) == 0:
                        # Promote 0-d tensors to shape (1,)
                        value = value.reshape(1)
                    value = value.detach()

                for observer in self._observers[key]:
                    observer.update(key, value)

        cls.notify_observers = notify_observers

        return cls

    if cls is None:
        return wrap

    return wrap(cls)
+65 -708
View File
@@ -1,26 +1,32 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import dataclasses
import logging
# The dataclasses in this file should be vanilla dataclass to have minimal overhead
from dataclasses import dataclass, field
from datetime import datetime as RecurringPeriod # noqa
from typing import Dict, List, NamedTuple, Optional, Tuple, Union
from typing import Dict, List, Optional
import torch
import torch.nn.functional as F
from reagent.base_dataclass import BaseDataClass
from reagent.core.configuration import param_hash
from reagent.core.dataclasses import dataclass as pydantic_dataclass
from reagent.preprocessing.normalization_constants import (
# Triggering registration to registries
import reagent.core.result_types # noqa
import reagent.workflow.training_reports # noqa
from reagent.core.dataclasses import dataclass
from reagent.core.fb_checker import IS_FB_ENVIRONMENT
from reagent.core.tagged_union import TaggedUnion # noqa F401
from reagent.models.model_feature_config_provider import ModelFeatureConfigProvider
from reagent.preprocessing.normalization import (
DEFAULT_MAX_QUANTILE_SIZE,
DEFAULT_MAX_UNIQUE_ENUM,
DEFAULT_NUM_SAMPLES,
DEFAULT_QUANTILE_K2_THRESHOLD,
)
from reagent.preprocessing.types import InputColumn
from reagent.types import BaseDataClass
from reagent.workflow.result_registries import PublishingResult, ValidationResult
from reagent.workflow.training_reports import TrainingReport
if IS_FB_ENVIRONMENT:
from reagent.fb.models.model_feature_config_builder import ( # noqa
ConfigeratorModelFeatureConfigProvider,
)
import reagent.core.fb.fb_types # noqa
@dataclass
@@ -34,7 +40,7 @@ class OssDataset(Dataset):
@dataclass
class TableSpec(BaseDataClass):
class TableSpec:
table: str
table_sample: Optional[float] = None
eval_table_sample: Optional[float] = None
@@ -44,11 +50,27 @@ class TableSpec(BaseDataClass):
class RewardOptions:
custom_reward_expression: Optional[str] = None
metric_reward_values: Optional[Dict[str, float]] = None
additional_reward_expression: Optional[str] = None
# for ranking
# key: feature id in slate_reward column, value: linear coefficient
slate_reward_values: Optional[Dict[str, float]] = None
# key: feature id in item_reward column, value: linear coefficient
item_reward_values: Optional[Dict[str, float]] = None
@dataclass
class ReaderOptions:
pass
num_threads: int = 32
skip_smaller_batches: bool = True
num_workers: int = 0
koski_logging_level: int = 2
# distributed reader
distributed_reader: bool = False
distributed_master_mem: str = "20G"
distributed_worker_mem: str = "20G"
distributed_num_workers: int = 2
gang_name: str = ""
@dataclass
@@ -58,7 +80,10 @@ class OssReaderOptions(ReaderOptions):
@dataclass
class ResourceOptions:
pass
cpu: Optional[int] = None
# "-1" or "xxG" where "xx" is a positive integer
memory: Optional[str] = "40g"
gpu: int = 1
@dataclass
@@ -84,713 +109,45 @@ class PreprocessingOptions(BaseDataClass):
set_missing_value_to_zero: Optional[bool] = False
whitelist_features: Optional[List[int]] = None
assert_whitelist_feature_coverage: bool = True
variance_threshold: VarianceThreshold = VarianceThreshold()
sequence_feature_id: Optional[int] = None
ignore_sanity_check_failure: bool = IGNORE_SANITY_CHECK_FAILURE
ignore_sanity_check_task: bool = False
variance_threshold: VarianceThreshold = VarianceThreshold()
load_from_operator_id: Optional[int] = None
skip_sanity_check: bool = False
# IdMappings are stored in manifold folder:
# "tree/{namespace}/{tablename}/{ds}/{base_mapping_name}/{embedding_table_name}"
base_mapping_name: str = "DefaultMappingName"
sequence_feature_id: Optional[int] = None
### below here for preprocessing sparse features ###
# If the number of occurrences of any raw features ids is lower than this, we
# ignore those feature ids when constructing the IdMapping
sparse_threshold: int = 0
# IdMappings are stored in manifold folder:
# "tree/{namespace}/{tablename}/{ds}/{base_mapping_name}/{embedding_table_name}"
base_mapping_name: str = "DefaultMappingName"
class NoDuplicatedWarningLogger:
    """
    Thin wrapper around a logger that emits each distinct warning message only
    the first time it is seen.
    """

    def __init__(self, logger):
        self.logger = logger
        # Messages already emitted through this wrapper
        self.msg = set()

    def warning(self, msg):
        """Forward `msg` to the wrapped logger unless it was already emitted."""
        already_emitted = msg in self.msg
        if already_emitted:
            return
        self.logger.warning(msg)
        self.msg.add(msg)
@ModelFeatureConfigProvider.fill_union()
class ModelFeatureConfigProvider__Union(TaggedUnion):
pass
logger = logging.getLogger(__name__)
no_dup_logger = NoDuplicatedWarningLogger(logger)
@PublishingResult.fill_union()
class PublishingResult__Union(TaggedUnion):
pass
def isinstance_namedtuple(x):
    """Return True when `x` is a tuple that also exposes a `_fields` attribute."""
    if not isinstance(x, tuple):
        return False
    return hasattr(x, "_fields")
@ValidationResult.fill_union()
class ValidationResult__Union(TaggedUnion):
pass
@TrainingReport.fill_union()
class RLTrainingReport(TaggedUnion):
pass
@dataclass
class TensorDataClass(BaseDataClass):
    """
    Dataclass base that forwards unknown callable attributes to the
    `torch.Tensor` method of the same name, applied to every tensor-like field.
    """

    def __getattr__(self, attr):
        # Never proxy dunder lookups.
        is_dunder = attr.startswith("__") and attr.endswith("__")
        if is_dunder:
            raise AttributeError

        tensor_attr = getattr(torch.Tensor, attr, None)
        if tensor_attr is None or not callable(tensor_attr):
            logger.error(
                f"Attemping to call torch.Tensor.{attr} on "
                f"{type(self)} (instance of TensorDataClass)."
            )
            if tensor_attr is None:
                raise AttributeError(f"torch.Tensor doesn't have {attr} attribute.")
            raise RuntimeError(f"Tensor.{attr} is not callable.")

        def continuation(*args, **kwargs):
            def apply(field_value):
                # Call `attr` on tensors and nested TensorDataClass objects,
                # recurse into dicts and tuples, and pass anything else through.
                if isinstance(field_value, (torch.Tensor, TensorDataClass)):
                    method = getattr(field_value, attr)
                    return method(*args, **kwargs)
                if isinstance(field_value, dict):
                    return {name: apply(item) for name, item in field_value.items()}
                if isinstance(field_value, tuple):
                    return tuple(apply(item) for item in field_value)
                return field_value

            converted = apply(self.__dict__)
            return type(self)(**converted)

        return continuation

    def cuda(self, *args, **kwargs):
        """
        Build a new instance of the same type with `.cuda(*args, **kwargs)`
        called on tensor and TensorDataClass fields; other fields are reused.
        """
        moved = {}
        for name, field_value in self.__dict__.items():
            if isinstance(field_value, torch.Tensor):
                # Default to non-blocking copies unless the caller chose.
                kwargs.setdefault("non_blocking", True)
                moved[name] = field_value.cuda(*args, **kwargs)
            elif isinstance(field_value, TensorDataClass):
                moved[name] = field_value.cuda(*args, **kwargs)
            else:
                moved[name] = field_value
        return type(self)(**moved)
# Value of one id-list feature: a pair of tensors, (offset, value).
IdListFeatureValue = Tuple[torch.Tensor, torch.Tensor]
# Value of one id-score-list feature: a triple of tensors, (offset, key, value).
IdScoreListFeatureValue = Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
# Feature collections keyed by a string name: name -> value.
IdListFeature = Dict[str, IdListFeatureValue]
IdScoreListFeature = Dict[str, IdScoreListFeatureValue]
# Serving-side feature collections keyed by an integer id: id -> value.
ServingIdListFeature = Dict[int, IdListFeatureValue]
ServingIdScoreListFeature = Dict[int, IdScoreListFeatureValue]
#####
# FIXME: These config types are misplaced, but moving them would require writing
# an FBL config adapter.
######
@pydantic_dataclass
class IdListFeatureConfig(BaseDataClass):
    """Configuration entry describing one id-list feature."""

    # Feature name
    name: str
    # Integer feature ID
    feature_id: int
    # Name of the embedding table to use
    id_mapping_name: str
@pydantic_dataclass
class IdScoreListFeatureConfig(BaseDataClass):
    """Configuration entry describing one id-score-list feature."""

    # Feature name
    name: str
    # Integer feature ID
    feature_id: int
    # Name of the embedding table to use
    id_mapping_name: str
@pydantic_dataclass
class FloatFeatureInfo(BaseDataClass):
    """Name and integer ID of one float feature."""

    name: str
    feature_id: int
@pydantic_dataclass
class IdMapping(object):
    """
    Ordered list of raw ids; the position of an id in `ids` is its index.
    """

    __hash__ = param_hash

    ids: List[int] = field(default_factory=list)

    def __post_init_post_parse__(self):
        """
        Prepare the cache for the reverse lookup used in preprocessing.

        `ids` maps index -> raw id; the cache holds the reverse direction,
        raw id -> index.
        """
        self._id2index: Dict[int, int] = {}

    @property
    def id2index(self) -> Dict[int, int]:
        """Reverse mapping raw id -> index, built on first non-empty access."""
        # pyre-fixme[16]: `IdMapping` has no attribute `_id2index`.
        cached = self._id2index
        if cached:
            return cached
        self._id2index = dict(zip(self.ids, range(len(self.ids))))
        return self._id2index

    @property
    def table_size(self):
        """Number of ids in the mapping."""
        size = len(self.ids)
        return size
# Feature configuration of a model: float features plus the id-list and
# id-score-list (sparse) feature configs, with lookup tables between feature
# ids, names and configs built after parsing.
# NOTE(review): indentation is lost in this view, so block nesting below is
# inferred from the statements themselves.
@pydantic_dataclass
class ModelFeatureConfig(BaseDataClass):
float_feature_infos: List[FloatFeatureInfo] = field(default_factory=list)
# table name -> id mapping
id_mapping_config: Dict[str, IdMapping] = field(default_factory=dict)
# id_list_feature_configs is feature_id -> list of values
id_list_feature_configs: List[IdListFeatureConfig] = field(default_factory=list)
# id_score_list_feature_configs is feature_id -> (keys -> values)
id_score_list_feature_configs: List[IdScoreListFeatureConfig] = field(
default_factory=list
)
# Validates the sparse feature configs and builds the id/name lookup tables.
def __post_init_post_parse__(self):
# Both kinds of sparse config share one id space and one name space.
both_lists = self.id_list_feature_configs + self.id_score_list_feature_configs
if not self.only_dense:
# sanity check for keys in mapping config
ids = [config.feature_id for config in both_lists]
names = [config.name for config in both_lists]
assert len(ids) == len(set(ids)), f"duplicates in ids: {ids}"
assert len(names) == len(set(names)), f"duplicates in names: {names}"
assert len(ids) == len(names), f"{len(ids)} != {len(names)}"
# NOTE(review): cannot tell from this view whether the four tables below are
# built inside the `if` above or unconditionally - confirm against the real
# file (the properties further down read them without any guard).
self._id2name = {config.feature_id: config.name for config in both_lists}
self._name2id = {config.name: config.feature_id for config in both_lists}
self._id2config = {config.feature_id: config for config in both_lists}
self._name2config = {config.name: config for config in both_lists}
# True when neither id-list nor id-score-list features are configured.
@property
def only_dense(self):
return not (self.id_list_feature_configs or self.id_score_list_feature_configs)
# Read-only accessors for the tables built in `__post_init_post_parse__`.
@property
def id2name(self):
return self._id2name
@property
def name2id(self):
return self._name2id
@property
def id2config(self):
return self._id2config
@property
def name2config(self):
return self._name2config
######
# dataclasses for internal API
######
@dataclass
class ValuePresence(TensorDataClass):
    """A value tensor paired with an optional presence tensor."""

    value: torch.Tensor
    # No default: callers must pass it explicitly, even when it is None.
    presence: Optional[torch.Tensor]
@dataclass
class ActorOutput(TensorDataClass):
    """Action tensor plus two optional companion tensors."""

    action: torch.Tensor
    # Both optional tensors default to None.
    log_prob: Optional[torch.Tensor] = None
    squashed_mean: Optional[torch.Tensor] = None
@dataclass
class DocList(TensorDataClass):
    """Candidate documents of a batch: features, mask and value per candidate."""

    # the shape is (batch_size, num_candidates, num_document_features)
    float_features: torch.Tensor
    # the shapes are (batch_size, num_candidates)
    mask: torch.Tensor
    value: torch.Tensor

    def __post_init__(self):
        feature_rank = len(self.float_features.shape)
        assert (
            feature_rank == 3
        ), f"Unexpected shape: {self.float_features.shape}"

    # pyre-fixme[56]: Decorator `torch.no_grad(...)` could not be called, because
    #  its type `no_grad` is not callable.
    @torch.no_grad()
    def select_slate(self, action: torch.Tensor):
        """
        Gather the candidates picked by `action` (candidate indices per row)
        into a new DocList; every picked candidate must be unmasked.
        """
        num_rows, slate_size = action.shape[0], action.shape[1]
        # row_idx[i, j] == i, so (row_idx, action) indexes candidate action[i, j]
        # of row i.
        row_idx = torch.repeat_interleave(
            torch.arange(num_rows).unsqueeze(1), slate_size, dim=1
        )
        picked_mask = self.mask[row_idx, action]
        # Make sure the indices are in the right range
        assert picked_mask.to(torch.bool).all()
        picked_features = self.float_features[row_idx, action]
        picked_value = self.value[row_idx, action]
        return DocList(picked_features, picked_mask, picked_value)

    def as_feature_data(self):
        """Flatten the first two dimensions and wrap the result in FeatureData."""
        _batch_size, _slate_size, feature_dim = self.float_features.shape
        flattened = self.float_features.view(-1, feature_dim)
        return FeatureData(flattened)
@dataclass
class FeatureData(TensorDataClass):
    """
    Features of one batch: dense float features plus optional sparse,
    sequence, document and time features.
    """

    # For dense features, shape is (batch_size, feature_dim)
    float_features: torch.Tensor
    id_list_features: IdListFeature = dataclasses.field(default_factory=dict)
    id_score_list_features: IdScoreListFeature = dataclasses.field(default_factory=dict)
    # For sequence, shape is (stack_size, batch_size, feature_dim)
    stacked_float_features: Optional[torch.Tensor] = None
    # For ranking algos,
    candidate_docs: Optional[DocList] = None
    # Experimental: sticking this here instead of putting it in float_features
    # because a lot of places derive the shape of float_features from
    # normalization parameters.
    time_since_first: Optional[torch.Tensor] = None

    def __post_init__(self):
        """
        Check the rank of `float_features`: 2D is accepted, 3D is tolerated
        with a (de-duplicated) warning, anything else raises ValueError.
        """

        def usage():
            # Point at the fields that actually exist on this class.
            return (
                "For sequence features, use `stacked_float_features`. "
                "For document features, use `candidate_docs`."
            )

        if self.float_features.ndim == 3:
            no_dup_logger.warning(f"`float_features` should be 2D.\n{usage()}")
        elif self.float_features.ndim != 2:
            raise ValueError(
                f"float_features should be 2D; got {self.float_features.shape}.\n{usage()}"
            )

    @property
    def has_float_features_only(self) -> bool:
        """
        True when there are no id-list features, no `time_since_first` and no
        `candidate_docs`.
        """
        # NOTE(review): `id_score_list_features` and `stacked_float_features`
        # are not consulted here - confirm whether that is intended.
        return (
            not self.id_list_features
            and self.time_since_first is None
            and self.candidate_docs is None
        )

    def get_tiled_batch(self, num_tiles: int):
        """
        Return a FeatureData whose float features repeat each row `num_tiles`
        times consecutively.

        tiled_feature should be (batch_size * num_tiles, feature_dim)
        forall i in [batch_size],
        tiled_feature[i*num_tiles:(i+1)*num_tiles] should be feat[i]
        """
        assert (
            self.has_float_features_only
        ), f"only works for float features now: {self}"
        feat = self.float_features
        assert (
            len(feat.shape) == 2
        ), f"Need feat shape to be (batch_size, feature_dim), got {feat.shape}."
        # pyre-fixme[16]: `Tensor` has no attribute `repeat_interleave`.
        tiled_feat = feat.repeat_interleave(repeats=num_tiles, dim=0)
        return FeatureData(float_features=tiled_feat)
class TensorFeatureData(torch.nn.Module):
    """
    Module that wraps a raw tensor into FeatureData; mainly meant to be used
    as a stage inside nn.Sequential.
    """

    def forward(self, input: torch.Tensor) -> FeatureData:
        is_tensor = isinstance(input, torch.Tensor)
        assert is_tensor
        # `float_features` is the first FeatureData field.
        return FeatureData(float_features=input)
class ServingFeatureData(NamedTuple):
    """Serving-time feature bundle, as a plain named tuple."""

    # A pair of tensors; by the field name, float feature values and their
    # presence - TODO confirm the order against the producer.
    float_features_with_presence: Tuple[torch.Tensor, torch.Tensor]
    # Sparse features keyed by integer feature id
    id_list_features: ServingIdListFeature
    id_score_list_features: ServingIdScoreListFeature
@dataclass
class PreprocessedRankingInput(TensorDataClass):
state: FeatureData
src_seq: FeatureData
src_src_mask: torch.Tensor
tgt_in_seq: Optional[FeatureData] = None
tgt_out_seq: Optional[FeatureData] = None
tgt_tgt_mask: Optional[torch.Tensor] = None
slate_reward: Optional[torch.Tensor] = None
position_reward: Optional[torch.Tensor] = None
# all indices will be +2 to account for padding
# symbol (0) and decoder_start_symbol (1)
src_in_idx: Optional[torch.Tensor] = None
tgt_in_idx: Optional[torch.Tensor] = None
tgt_out_idx: Optional[torch.Tensor] = None
tgt_out_probs: Optional[torch.Tensor] = None
# store ground-truth target sequences
optim_tgt_in_idx: Optional[torch.Tensor] = None
optim_tgt_out_idx: Optional[torch.Tensor] = None
optim_tgt_in_seq: Optional[FeatureData] = None
optim_tgt_out_seq: Optional[FeatureData] = None
def batch_size(self) -> int:
return self.state.float_features.size()[0]
@classmethod
def from_tensors(
cls,
state: torch.Tensor,
src_seq: torch.Tensor,
src_src_mask: torch.Tensor,
tgt_in_seq: Optional[torch.Tensor] = None,
tgt_out_seq: Optional[torch.Tensor] = None,
tgt_tgt_mask: Optional[torch.Tensor] = None,
slate_reward: Optional[torch.Tensor] = None,
position_reward: Optional[torch.Tensor] = None,
src_in_idx: Optional[torch.Tensor] = None,
tgt_in_idx: Optional[torch.Tensor] = None,
tgt_out_idx: Optional[torch.Tensor] = None,
tgt_out_probs: Optional[torch.Tensor] = None,
optim_tgt_in_idx: Optional[torch.Tensor] = None,
optim_tgt_out_idx: Optional[torch.Tensor] = None,
optim_tgt_in_seq: Optional[torch.Tensor] = None,
optim_tgt_out_seq: Optional[torch.Tensor] = None,
**kwargs,
):
assert isinstance(state, torch.Tensor)
assert isinstance(src_seq, torch.Tensor)
assert isinstance(src_src_mask, torch.Tensor)
assert tgt_in_seq is None or isinstance(tgt_in_seq, torch.Tensor)
assert tgt_out_seq is None or isinstance(tgt_out_seq, torch.Tensor)
assert tgt_tgt_mask is None or isinstance(tgt_tgt_mask, torch.Tensor)
assert slate_reward is None or isinstance(slate_reward, torch.Tensor)
assert position_reward is None or isinstance(position_reward, torch.Tensor)
assert src_in_idx is None or isinstance(src_in_idx, torch.Tensor)
assert tgt_in_idx is None or isinstance(tgt_in_idx, torch.Tensor)
assert tgt_out_idx is None or isinstance(tgt_out_idx, torch.Tensor)
assert tgt_out_probs is None or isinstance(tgt_out_probs, torch.Tensor)
assert optim_tgt_out_idx is None or isinstance(optim_tgt_out_idx, torch.Tensor)
assert optim_tgt_out_idx is None or isinstance(optim_tgt_out_idx, torch.Tensor)
assert optim_tgt_in_seq is None or isinstance(optim_tgt_in_seq, torch.Tensor)
assert optim_tgt_out_seq is None or isinstance(optim_tgt_out_seq, torch.Tensor)
return cls(
state=FeatureData(float_features=state),
src_seq=FeatureData(float_features=src_seq),
src_src_mask=src_src_mask,
tgt_in_seq=FeatureData(float_features=tgt_in_seq)
if tgt_in_seq is not None
else None,
tgt_out_seq=FeatureData(float_features=tgt_out_seq)
if tgt_out_seq is not None
else None,
tgt_tgt_mask=tgt_tgt_mask,
slate_reward=slate_reward,
position_reward=position_reward,
src_in_idx=src_in_idx,
tgt_in_idx=tgt_in_idx,
tgt_out_idx=tgt_out_idx,
tgt_out_probs=tgt_out_probs,
optim_tgt_in_idx=optim_tgt_in_idx,
optim_tgt_out_idx=optim_tgt_out_idx,
optim_tgt_in_seq=FeatureData(float_features=optim_tgt_in_seq)
if optim_tgt_in_seq is not None
else None,
optim_tgt_out_seq=FeatureData(float_features=optim_tgt_out_seq)
if optim_tgt_out_seq is not None
else None,
)
def __post_init__(self):
    """Refuse raw tensors in the fields that must be wrapped.

    Raises:
        ValueError: if any of ``state``, ``src_seq``, ``tgt_in_seq``,
            ``tgt_out_seq``, ``optim_tgt_in_seq`` or ``optim_tgt_out_seq``
            is a bare ``torch.Tensor``. The message names ``from_tensors()``
            and lists the type found in each of those fields.
    """
    wrapped_fields = (
        self.state,
        self.src_seq,
        self.tgt_in_seq,
        self.tgt_out_seq,
        self.optim_tgt_in_seq,
        self.optim_tgt_out_seq,
    )
    # Nothing to complain about unless at least one field is a raw tensor.
    if not any(isinstance(value, torch.Tensor) for value in wrapped_fields):
        return
    raise ValueError(
        f"Use from_tensors() {type(self.state)} {type(self.src_seq)} "
        f"{type(self.tgt_in_seq)} {type(self.tgt_out_seq)} "
        f"{type(self.optim_tgt_in_seq)} {type(self.optim_tgt_out_seq)} "
    )
@dataclass
class BaseInput(TensorDataClass):
    """
    Common fields shared by every training input, raw or preprocessed.
    """

    state: FeatureData
    next_state: FeatureData
    reward: torch.Tensor
    time_diff: torch.Tensor
    step: Optional[torch.Tensor]
    not_terminal: torch.Tensor

    def batch_size(self):
        """Return the first dimension of the state's float features."""
        dims = self.state.float_features.size()
        return dims[0]

    @classmethod
    def from_dict(cls, batch):
        """Build a ``BaseInput`` from a mapping keyed by ``InputColumn`` names.

        The sparse (id-list / id-score-list) columns are optional: a missing
        or falsy entry is replaced by an empty dict. All other columns are
        read with plain indexing, so a missing one raises ``KeyError``.

        Note that this always constructs ``BaseInput`` itself rather than
        ``cls``, so a subclass calling it gets only the base fields back.
        """

        def sparse_or_empty(column):
            # Absent or falsy sparse features collapse to an empty dict.
            return batch.get(column, None) or {}

        state_id_lists = sparse_or_empty(InputColumn.STATE_ID_LIST_FEATURES)
        state_id_score_lists = sparse_or_empty(
            InputColumn.STATE_ID_SCORE_LIST_FEATURES
        )
        next_id_lists = sparse_or_empty(InputColumn.NEXT_STATE_ID_LIST_FEATURES)
        next_id_score_lists = sparse_or_empty(
            InputColumn.NEXT_STATE_ID_SCORE_LIST_FEATURES
        )

        current = FeatureData(
            float_features=batch[InputColumn.STATE_FEATURES],
            id_list_features=state_id_lists,
            id_score_list_features=state_id_score_lists,
        )
        following = FeatureData(
            float_features=batch[InputColumn.NEXT_STATE_FEATURES],
            id_list_features=next_id_lists,
            id_score_list_features=next_id_score_lists,
        )
        return BaseInput(
            state=current,
            next_state=following,
            reward=batch[InputColumn.REWARD],
            time_diff=batch[InputColumn.TIME_DIFF],
            step=batch[InputColumn.STEP],
            not_terminal=batch[InputColumn.NOT_TERMINAL],
        )
@dataclass
class ExtraData(TensorDataClass):
    """Optional side information for a batch; every field defaults to ``None``."""

    mdp_id: Optional[torch.Tensor] = None
    sequence_number: Optional[torch.Tensor] = None
    action_probability: Optional[torch.Tensor] = None
    max_num_actions: Optional[int] = None
    metrics: Optional[torch.Tensor] = None

    @classmethod
    def from_dict(cls, d):
        """Build an instance from the entries of ``d`` named after the fields.

        Keys of ``d`` that are not declared fields are ignored; declared
        fields with no matching key are set to ``None``.
        """
        kwargs = {}
        for spec in dataclasses.fields(cls):
            kwargs[spec.name] = d.get(spec.name, None)
        return cls(**kwargs)
@dataclass
class DiscreteDqnInput(BaseInput):
    """Base input extended with discrete actions, action masks and extras."""

    action: torch.Tensor
    next_action: torch.Tensor
    possible_actions_mask: torch.Tensor
    possible_next_actions_mask: torch.Tensor
    extras: ExtraData

    @classmethod
    def from_dict(cls, batch):
        """Build the input from a mapping keyed by ``InputColumn`` names.

        The shared fields come from the parent's ``from_dict``; the
        DQN-specific columns are read directly from ``batch``.
        """
        parent = super().from_dict(batch)
        shared = {
            name: getattr(parent, name)
            for name in (
                "state",
                "next_state",
                "reward",
                "time_diff",
                "step",
                "not_terminal",
            )
        }
        return cls(
            **shared,
            action=batch[InputColumn.ACTION],
            next_action=batch[InputColumn.NEXT_ACTION],
            possible_actions_mask=batch[InputColumn.POSSIBLE_ACTIONS_MASK],
            possible_next_actions_mask=batch[InputColumn.POSSIBLE_NEXT_ACTIONS_MASK],
            extras=batch[InputColumn.EXTRAS],
        )
@dataclass
class SlateQInput(BaseInput):
    """
    Training input for SlateQ.

    `reward`, `reward_mask`, & `next_item_mask` have shape
    `(batch_size, slate_size)`.
    `reward_mask` indicates whether the reward could be observed, e.g.,
    whether the item got into the viewport or not.
    """

    action: torch.Tensor
    next_action: torch.Tensor
    reward_mask: torch.Tensor
    extras: Optional[ExtraData] = None

    @classmethod
    def from_dict(cls, d):
        """Build the input from a mapping keyed by plain column-name strings.

        ``step`` is always set to ``None``; ``extras`` is derived from ``d``
        through ``ExtraData.from_dict``.
        """
        chosen = d["action"]
        next_chosen = d["next_action"]

        state_features = d["state_features"]
        current_docs = DocList(
            float_features=d["candidate_features"],
            mask=d["item_mask"],
            value=d["item_probability"],
        )
        current = FeatureData(
            float_features=state_features, candidate_docs=current_docs
        )

        next_state_features = d["next_state_features"]
        following_docs = DocList(
            float_features=d["next_candidate_features"],
            mask=d["next_item_mask"],
            value=d["next_item_probability"],
        )
        following = FeatureData(
            float_features=next_state_features, candidate_docs=following_docs
        )

        return cls(
            state=current,
            next_state=following,
            action=chosen,
            next_action=next_chosen,
            reward=d["position_reward"],
            reward_mask=d["reward_mask"],
            time_diff=d["time_diff"],
            not_terminal=d["not_terminal"],
            step=None,
            extras=ExtraData.from_dict(d),
        )
@dataclass
class ParametricDqnInput(BaseInput):
    """Base input extended with feature-valued actions and candidate actions."""

    action: FeatureData
    next_action: FeatureData
    possible_actions: FeatureData
    possible_actions_mask: torch.Tensor
    possible_next_actions: FeatureData
    possible_next_actions_mask: torch.Tensor
    extras: Optional[ExtraData] = None

    @classmethod
    def from_dict(cls, batch):
        """Build the input from a mapping keyed by plain column-name strings.

        State and action columns are wrapped as the ``float_features`` of a
        ``FeatureData``; masks, reward, ``not_terminal``, ``time_diff``,
        ``step`` and ``extras`` are passed through untouched.
        """

        def dense(key):
            # Wrap one column of the batch as float features.
            return FeatureData(float_features=batch[key])

        return cls(
            state=dense("state_features"),
            action=dense("action"),
            next_state=dense("next_state_features"),
            next_action=dense("next_action"),
            possible_actions=dense("possible_actions"),
            possible_actions_mask=batch["possible_actions_mask"],
            possible_next_actions=dense("possible_next_actions"),
            possible_next_actions_mask=batch["possible_next_actions_mask"],
            reward=batch["reward"],
            not_terminal=batch["not_terminal"],
            time_diff=batch["time_diff"],
            step=batch["step"],
            extras=batch["extras"],
        )
@dataclass
class PolicyNetworkInput(BaseInput):
    """Base input extended with feature-valued current and next actions."""

    action: FeatureData
    next_action: FeatureData
    extras: Optional[ExtraData] = None

    @classmethod
    def from_dict(cls, batch):
        """Build the input from a mapping keyed by plain column-name strings.

        State and action columns are wrapped as the ``float_features`` of a
        ``FeatureData``; the remaining columns are passed through untouched.
        """

        def dense(key):
            # Wrap one column of the batch as float features.
            return FeatureData(float_features=batch[key])

        return cls(
            state=dense("state_features"),
            action=dense("action"),
            next_state=dense("next_state_features"),
            next_action=dense("next_action"),
            reward=batch["reward"],
            not_terminal=batch["not_terminal"],
            time_diff=batch["time_diff"],
            step=batch["step"],
            extras=batch["extras"],
        )

    def batch_size(self) -> int:
        """Return the first dimension of the state's float features."""
        features = self.state.float_features
        return features.shape[0]
@dataclass
class PolicyGradientInput(BaseDataClass):
    """State, action, reward and log-probability fields for policy gradient."""

    state: FeatureData
    action: torch.Tensor
    reward: torch.Tensor
    log_prob: torch.Tensor

    @classmethod
    def input_prototype(cls):
        """Return a randomly filled example instance.

        The example uses a batch of 10, a 3-dimensional state and one-hot
        actions drawn from 5 classes.
        """
        num_classes = 5
        batch_size = 10
        state_dim = 3

        # Draw in the same order as the fields so the RNG stream is consumed
        # state -> action -> reward -> log_prob.
        example_state = FeatureData(
            float_features=torch.randn(batch_size, state_dim)
        )
        action_indices = torch.randint(high=num_classes, size=(batch_size,))
        example_action = F.one_hot(action_indices)
        example_reward = torch.rand(batch_size)
        example_log_prob = torch.log(torch.rand(batch_size))

        return cls(
            state=example_state,
            action=example_action,
            reward=example_reward,
            log_prob=example_log_prob,
        )
@dataclass
class MemoryNetworkInput(BaseInput):
    """Base input extended with a tensor-valued action."""

    action: torch.Tensor

    def batch_size(self):
        """Return the batch dimension of the state's float features.

        A 2-dimensional tensor reports its first dimension; a 3-dimensional
        tensor reports its second dimension.

        Raises:
            NotImplementedError: for any other number of dimensions.
        """
        dims = self.state.float_features.size()
        rank = len(dims)
        if rank == 2:
            return dims[0]
        if rank == 3:
            return dims[1]
        raise NotImplementedError()
@dataclass
class PreprocessedTrainingBatch(TensorDataClass):
    """Pairs a preprocessed ranking input with its extra data."""

    training_input: Union[PreprocessedRankingInput]
    # TODO: deprecate this and move into individual ones.
    extras: ExtraData = field(default_factory=ExtraData)

    def batch_size(self):
        """Return the first dimension of the training input's state features."""
        state_features = self.training_input.state.float_features
        return state_features.size()[0]
# Output record of a memory network; declares tensor-valued fields only.
# NOTE(review): the names suggest mixture parameters (mus / sigmas / logpi)
# plus LSTM hidden and cell states — confirm shapes and meaning against the
# network that produces this object.
@dataclass
class MemoryNetworkOutput(TensorDataClass):
mus: torch.Tensor
sigmas: torch.Tensor
logpi: torch.Tensor
reward: torch.Tensor
not_terminal: torch.Tensor
last_step_lstm_hidden: torch.Tensor
last_step_lstm_cell: torch.Tensor
all_steps_lstm_hidden: torch.Tensor
# Output record with a single tensor field, `acc_reward`.
# NOTE(review): presumably the accumulated reward predicted by a seq2reward
# model — confirm against the producer.
@dataclass
class Seq2RewardOutput(TensorDataClass):
acc_reward: torch.Tensor
# Actions chosen by a DQN policy. `greedy` is required; the softmax choice,
# both action names and the softmax probability are optional and default to
# None.
# NOTE(review): `greedy` / `softmax` look like action indices — confirm.
@dataclass
class DqnPolicyActionSet(TensorDataClass):
greedy: int
softmax: Optional[int] = None
greedy_act_name: Optional[str] = None
softmax_act_name: Optional[str] = None
softmax_act_prob: Optional[float] = None
# Output of a planning policy. All fields are optional and default to None,
# so a producer can fill in only the representation it computes (continuous
# action, one-hot discrete action, or discrete action index).
@dataclass
class PlanningPolicyOutput(TensorDataClass):
# best action to take next
next_best_continuous_action: Optional[torch.Tensor] = None
next_best_discrete_action_one_hot: Optional[torch.Tensor] = None
next_best_discrete_action_idx: Optional[int] = None
# Output of a ranking model. Every field is optional and defaults to None.
@dataclass
class RankingOutput(TensorDataClass):
# a tensor of integer indices w.r.t. possible candidates
# shape: batch_size, tgt_seq_len
ranked_tgt_out_idx: Optional[torch.Tensor] = None
# generative probability of ranked tgt sequences at each decoding step
# shape: batch_size, tgt_seq_len, candidate_size
ranked_tgt_out_probs: Optional[torch.Tensor] = None
# log probabilities of given tgt sequences are used in REINFORCE
# shape: batch_size
log_probs: Optional[torch.Tensor] = None
# encoder scores in tgt_out_idx order
encoder_scores: Optional[torch.Tensor] = None
# Output record of a reward network: a single tensor of predicted rewards.
@dataclass
class RewardNetworkOutput(TensorDataClass):
predicted_reward: torch.Tensor
# Result bundle of an RL training run: validation result, publishing result,
# training report and output path. All four default to None.
# NOTE(review): no decorator is visible on this class in this view, and the
# *__Union / RLTrainingReport types are not defined in the visible part of
# the file — confirm both against the full source.
class RLTrainingOutput:
validation_result: Optional[ValidationResult__Union] = None
publishing_result: Optional[PublishingResult__Union] = None
training_report: Optional[RLTrainingReport] = None
output_path: Optional[str] = None
-39
View File
@@ -1,39 +0,0 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from reagent.core.fb_checker import IS_FB_ENVIRONMENT
from reagent.core.tagged_union import TaggedUnion
from reagent.models.model_feature_config_provider import ModelFeatureConfigProvider
from reagent.reporting.result_registries import PublishingResult, ValidationResult
from reagent.reporting.training_reports import TrainingReport
if True: # Register modules for unions
import reagent.reporting.oss_training_reports # noqa
import reagent.core.result_types # noqa
if IS_FB_ENVIRONMENT:
import reagent.reporting.fb.fb_training_reports # noqa
import reagent.fb.models.model_feature_config_builder # noqa
import reagent.core.fb.fb_result_types # noqa
import reagent.core.fb.fb_types # noqa
# Tagged union for ModelFeatureConfigProvider. The class body is empty; its
# contents come from the `fill_union()` decorator.
# NOTE(review): presumably one member per registered provider — confirm
# against `fill_union`.
@ModelFeatureConfigProvider.fill_union()
class ModelFeatureConfigProvider__Union(TaggedUnion):
pass
# Tagged union for PublishingResult. The class body is empty; its contents
# come from the `fill_union()` decorator.
# NOTE(review): presumably one member per registered result type — confirm.
@PublishingResult.fill_union()
class PublishingResult__Union(TaggedUnion):
pass
# Tagged union for ValidationResult. The class body is empty; its contents
# come from the `fill_union()` decorator.
# NOTE(review): presumably one member per registered result type — confirm.
@ValidationResult.fill_union()
class ValidationResult__Union(TaggedUnion):
pass
# Tagged union for TrainingReport. The class body is empty; its contents
# come from the `fill_union()` decorator.
# NOTE(review): presumably one member per registered report type — confirm.
@TrainingReport.fill_union()
class TrainingReport__Union(TaggedUnion):
pass
View File
-41
View File
@@ -1,41 +0,0 @@
#!/usr/bin/env python3
import logging
from typing import Dict, Optional
from reagent.core.types import Dataset, PreprocessingOptions, ReaderOptions, TableSpec
from reagent.parameters import NormalizationParameters
from reagent.preprocessing.batch_preprocessor import BatchPreprocessor
logger = logging.getLogger(__name__)
# Interface for data access. Every method in this class raises
# NotImplementedError, so a concrete fetcher has to override whichever
# operations it supports.
class DataFetcher:
# TODO: T71636145 Make a more specific API for DataFetcher
# Accepts arbitrary keyword arguments; not implemented in this base class.
def query_data(self, **kwargs):
raise NotImplementedError()
# TODO: T71636145 Make a more specific API for DataFetcher
# Accepts arbitrary keyword arguments; not implemented in this base class.
def query_data_parametric(self, **kwargs):
raise NotImplementedError()
# Declared to return a mapping from int keys to NormalizationParameters for
# one column of a table; not implemented in this base class.
# NOTE(review): the int keys are presumably feature ids — confirm.
def identify_normalization_parameters(
self,
table_spec: TableSpec,
column_name: str,
preprocessing_options: PreprocessingOptions,
seed: Optional[int] = None,
) -> Dict[int, NormalizationParameters]:
raise NotImplementedError()
# Takes a dataset, batch size, optional batch preprocessor, GPU flag and
# reader options; not implemented in this base class.
def get_dataloader(
self,
dataset: Dataset,
batch_size: int,
batch_preprocessor: Optional[BatchPreprocessor],
use_gpu: bool,
reader_options: ReaderOptions,
):
raise NotImplementedError()
@@ -3,8 +3,8 @@
import logging
import torch
from reagent.core.types import MemoryNetworkInput
from reagent.training.world_model.compress_model_trainer import CompressModelTrainer
from reagent.types import MemoryNetworkInput
logger = logging.getLogger(__name__)
+2 -104
View File
@@ -8,17 +8,8 @@ from typing import NamedTuple, Optional, cast
import numpy as np
import torch
import torch.nn as nn
from reagent.core import types as rlt
from reagent import types as rlt
from reagent.models.seq2slate import Seq2SlateMode, Seq2SlateTransformerNet
from reagent.ope.estimators.sequential_estimators import (
Action,
ActionSpace,
RLEstimatorInput,
RLPolicy,
State,
Transition,
ValueFunction,
)
from reagent.torch_utils import masked_softmax
from reagent.training import ParametricDQNTrainer
from reagent.training.dqn_trainer import DQNTrainer
@@ -51,7 +42,6 @@ class EvaluationDataPage(NamedTuple):
model_metrics_values_for_logged_action: Optional[torch.Tensor] = None
possible_actions_state_concat: Optional[torch.Tensor] = None
contexts: Optional[torch.Tensor] = None
sequential_estimator_input: Optional[RLEstimatorInput] = None
@classmethod
def create_from_training_batch(
@@ -320,83 +310,6 @@ class EvaluationDataPage(NamedTuple):
eval_action_idxs=eval_action_idxs,
)
# Build an RLEstimatorInput from batched DQN tensors: a log of per-MDP
# transition lists, a target policy derived from the trainer's Q-values, and
# a value function derived from the trainer's CPE Q-network.
# NOTE(review): `actions` is annotated as rlt.FeatureData but is used below
# via `.float()` and `torch.argmax(..., dim=1)`, i.e. like a 2-D tensor —
# confirm the intended type.
@staticmethod
def create_rl_estimator_input_from_tensors_dqn(
trainer: DQNTrainer,
mdp_ids: torch.Tensor,
states: rlt.FeatureData,
actions: rlt.FeatureData,
propensities: torch.Tensor,
rewards: torch.Tensor,
):
# Target policy: softmax over the trainer's detached Q-values for a state.
class DQNRLPolicy(RLPolicy):
def __init__(self, trainer: DQNTrainer):
super().__init__(ActionSpace(trainer.num_actions))
self._trainer = trainer
def action_dist(self, state: State):
# The single state is reshaped into a batch of one row.
feat_data = rlt.FeatureData(float_features=state.value.reshape(1, -1))
# Only 1 batch
q_values = self._trainer.get_detached_q_values(feat_data)[0][0]
return self._action_space.distribution(
torch.nn.Softmax(dim=0)(q_values)
)
# Value function backed by the trainer's CPE Q-network; only the first
# `num_actions` output columns of the single-row result are used.
class CPEValueFunction(ValueFunction):
def __init__(self, trainer: DQNTrainer):
self._trainer = trainer
def state_action_value(self, state: State, action: Action) -> float:
feat_data = rlt.FeatureData(float_features=state.value.reshape(1, -1))
model_values = self._trainer.q_network_cpe(feat_data)[
:, 0 : self._trainer.num_actions
][0]
return model_values[action.value].item()
def state_value(self, state: State) -> float:
feat_data = rlt.FeatureData(float_features=state.value.reshape(1, -1))
model_values = self._trainer.q_network_cpe(feat_data)[
:, 0 : self._trainer.num_actions
][0]
# State value = CPE values weighted by the softmax of the Q-values.
q_values = self._trainer.get_detached_q_values(feat_data)[0][0]
dist = torch.nn.Softmax(dim=0)(q_values)
assert dist.shape == model_values.shape
return torch.dot(dist, model_values).item()
def reset(self):
pass
states_tensor = states.float_features
# The logged action index is the position of the largest entry in each row.
logged_actions = torch.argmax(actions.float(), dim=1)
log = []
cur_mdp = []
i = 0
# Walk the rows in order. Row i yields a transition to row i + 1 only when
# both rows carry the same mdp id; otherwise (id change or last row) the
# transitions collected so far are flushed as one trajectory, and row i
# itself yields no transition. Empty trajectories are never appended.
while i < mdp_ids.shape[0]:
if i + 1 < mdp_ids.shape[0] and mdp_ids[i, 0] == mdp_ids[i + 1, 0]:
cur_mdp.append(
Transition(
last_state=State(states_tensor[i]),
action=Action(logged_actions[i].item()),
action_prob=propensities[i, 0].item(),
state=State(states_tensor[i + 1]),
reward=rewards[i, 0].item(),
status=Transition.Status.NORMAL,
)
)
elif len(cur_mdp) > 0:
log.append(cur_mdp)
cur_mdp = []
i += 1
# Temporary value of gamma
return RLEstimatorInput(
gamma=1.0,
log=log,
target_policy=DQNRLPolicy(trainer),
value_function=CPEValueFunction(trainer),
discrete_states=False,
)
@classmethod
# pyre-fixme[56]: Decorator `torch.no_grad(...)` could not be called, because
# its type `no_grad` is not callable.
@@ -541,9 +454,6 @@ class EvaluationDataPage(NamedTuple):
possible_actions_mask=possible_actions_mask,
optimal_q_values=optimal_q_values,
eval_action_idxs=eval_action_idxs,
sequential_estimator_input=EvaluationDataPage.create_rl_estimator_input_from_tensors_dqn(
trainer, mdp_ids, states, actions, propensities, rewards
),
)
def append(self, edp):
@@ -560,15 +470,6 @@ class EvaluationDataPage(NamedTuple):
new_edp[x] = torch.cat((t, other_t), dim=0)
elif isinstance(t, np.ndarray):
new_edp[x] = np.concatenate((t, other_t), axis=0)
elif isinstance(t, RLEstimatorInput):
t.log.extend(other_t.log)
new_edp[x] = RLEstimatorInput(
gamma=t.gamma,
log=t.log,
target_policy=t.target_policy,
value_function=t.value_function,
discrete_states=t.discrete_states,
)
else:
raise Exception("Invalid type in training data page")
else:
@@ -583,10 +484,7 @@ class EvaluationDataPage(NamedTuple):
new_edp = {}
for x in EvaluationDataPage._fields:
t = getattr(self, x)
if hasattr(t, "__getitem__"):
new_edp[x] = t[sorted_idxs] if t is not None else None
else:
new_edp[x] = t
new_edp[x] = t[sorted_idxs] if t is not None else None
return EvaluationDataPage(**new_edp)
+5 -13
View File
@@ -7,7 +7,7 @@ from typing import Dict, List, Optional
import torch
import torch.nn.functional as F
from reagent.core import types as rlt
from reagent.core.tracker import observable
from reagent.evaluation.cpe import CpeDetails, CpeEstimateSet
from reagent.evaluation.doubly_robust_estimator import DoublyRobustEstimator
from reagent.evaluation.evaluation_data_page import EvaluationDataPage
@@ -53,6 +53,7 @@ def get_metrics_to_score(metric_reward_values: Optional[Dict[str, float]]) -> Li
return sorted([*metric_reward_values.keys()])
@observable(cpe_details=CpeDetails)
class Evaluator:
NUM_J_STEPS_FOR_MAGIC_ESTIMATOR = 25
@@ -69,15 +70,7 @@ class Evaluator:
gamma
)
self.reporter = None
def evaluate(self, eval_input: rlt.TensorDataClass) -> None:
pass
def finish(self):
pass
def evaluate_one_shot(self, edp: EvaluationDataPage) -> CpeDetails:
def evaluate_post_training(self, edp: EvaluationDataPage) -> CpeDetails:
cpe_details = CpeDetails()
cpe_details.reward_estimates = self.score_cpe("Reward", edp)
@@ -123,9 +116,8 @@ class Evaluator:
cpe_details.mc_loss = float(
F.mse_loss(edp.logged_values, edp.model_values_for_logged_action)
)
assert self.reporter is not None, "Missing reporter"
self.reporter.report(cpe_results=cpe_details)
# pyre-fixme[16]: `Evaluator` has no attribute `notify_observers`.
self.notify_observers(cpe_details=cpe_details)
return cpe_details
def score_cpe(self, metric_name, edp: EvaluationDataPage):
+94 -9
View File
@@ -11,6 +11,9 @@ from reagent.evaluation.cpe import (
)
from reagent.evaluation.evaluation_data_page import EvaluationDataPage
from reagent.evaluation.evaluator import Evaluator
from reagent.evaluation.weighted_sequential_doubly_robust_estimator import (
WeightedSequentialDoublyRobustEstimator,
)
from reagent.ope.estimators.contextual_bandits_estimators import (
BanditsEstimatorInput,
DMEstimator,
@@ -31,6 +34,10 @@ from reagent.ope.estimators.sequential_estimators import (
MAGICEstimator,
RLEstimator,
RLEstimatorInput,
RLPolicy,
State,
Transition,
ValueFunction,
)
from reagent.ope.estimators.types import ActionSpace
@@ -109,6 +116,92 @@ class SequentialOPEstimatorAdapter:
self.gamma = gamma
self._device = device
# Policy that looks action distributions up in a precomputed tensor of model
# propensities instead of evaluating a model.
class EDPSeqPolicy(RLPolicy):
def __init__(
self, num_actions: int, model_propensities: torch.Tensor, device=None
):
super().__init__(ActionSpace(num_actions), device)
self.model_propensities = model_propensities
# NOTE(review): the return annotation is ActionDistribution, but the value
# returned is whatever indexing `model_propensities` yields — confirm that
# callers accept that.
def action_dist(self, state: State) -> ActionDistribution:
# "state" is (trajectory, step)
return self.model_propensities[state.value]
# Value function that reads from precomputed model values and target
# propensities, both indexed by `state.value`.
class EDPValueFunc(ValueFunction):
def __init__(
self, model_values: torch.Tensor, target_propensities: torch.Tensor
):
self.model_values = model_values
self.target_propensities = target_propensities
# NOTE(review): `action` is used directly as an index here (not
# `action.value`) — confirm what type callers actually pass.
def state_action_value(self, state: State, action: Action) -> float:
return self.model_values[state.value][action].item()
# Dot product of the model values and the target propensities for the state.
def state_value(self, state: State) -> float:
return torch.dot(
self.model_values[state.value], self.target_propensities[state.value]
).item()
# Nothing to reset: all values are precomputed.
def reset(self):
pass
# Convert an EvaluationDataPage into an RLEstimatorInput. The page is first
# turned into equal-length trajectories; the resulting arrays are then used to
# build the transition log, the target policy and the value function.
# Requires `edp.model_values` to be set (asserted below).
@staticmethod
def edp_to_rl_input(
edp: EvaluationDataPage, gamma, device=None
) -> RLEstimatorInput:
assert edp.model_values is not None
eq_len = WeightedSequentialDoublyRobustEstimator.transform_to_equal_length_trajectories(
edp.mdp_id,
edp.action_mask.cpu().numpy(),
edp.logged_rewards.cpu().numpy().flatten(),
edp.logged_propensities.cpu().numpy().flatten(),
edp.model_propensities.cpu().numpy(),
edp.model_values.cpu().numpy(),
)
# Each returned array becomes a double tensor on `device`; the generator is
# unpacked positionally into the five names on the left.
(
actions,
rewards,
logged_propensities,
target_propensities,
estimated_q_values,
) = (
torch.tensor(x, dtype=torch.double, device=device, requires_grad=True)
for x in eq_len
)
num_examples = logged_propensities.shape[0]
horizon = logged_propensities.shape[1]
log = []
# One transition list per trajectory. States are identified by the pair
# (trajectory index, step index) rather than by feature values. A step is
# kept only if the largest entry of its action row is non-zero.
# NOTE(review): all-zero action rows are presumably padding added by the
# equal-length transform — confirm.
for traj in range(num_examples):
log.append(
[
Transition(
last_state=State((traj, i)),
action=torch.argmax(actions[traj, i]).item(),
action_prob=logged_propensities[traj, i].item(),
state=State((traj, i + 1)),
reward=rewards[traj, i].item(),
)
for i in range(horizon - 1)
if actions[traj, i][torch.argmax(actions[traj, i]).item()] != 0.0
]
)
# The number of actions is taken from the last dimension of `actions`.
return RLEstimatorInput(
gamma=gamma,
log=log,
target_policy=SequentialOPEstimatorAdapter.EDPSeqPolicy(
actions.shape[2], target_propensities
),
value_function=SequentialOPEstimatorAdapter.EDPValueFunc(
estimated_q_values, target_propensities
),
ground_truth=None,
horizon=horizon,
)
@staticmethod
def estimator_results_to_cpe_estimate(
estimator_results: EstimatorResults,
@@ -144,16 +237,8 @@ class SequentialOPEstimatorAdapter:
)
def estimate(self, edp: EvaluationDataPage) -> CpeEstimate:
est_input = edp.sequential_estimator_input
assert est_input is not None, "EDP does not contain sequential estimator inputs"
estimator_results = self.seq_ope_estimator.evaluate(
RLEstimatorInput(
gamma=self.gamma,
log=est_input.log,
target_policy=est_input.target_policy,
value_function=est_input.value_function,
discrete_states=est_input.discrete_states,
)
SequentialOPEstimatorAdapter.edp_to_rl_input(edp, self.gamma, self._device)
)
assert isinstance(estimator_results, EstimatorResults)
return SequentialOPEstimatorAdapter.estimator_results_to_cpe_estimate(
@@ -7,8 +7,9 @@ from typing import Optional
import numpy as np
import torch
import torch.nn as nn
from reagent.core.types import PreprocessedTrainingBatch
from reagent.core.tracker import observable
from reagent.models.seq2slate import Seq2SlateMode
from reagent.types import PreprocessedTrainingBatch
from sklearn.metrics import (
average_precision_score,
dcg_score,
@@ -28,6 +29,17 @@ class ListwiseRankingMetrics:
cross_entropy_loss: Optional[float] = 0.0
@observable(
cross_entropy_loss=torch.Tensor,
dcg=torch.Tensor,
ndcg=torch.Tensor,
mean_ap=torch.Tensor,
auc=torch.Tensor,
base_dcg=torch.Tensor,
base_ndcg=torch.Tensor,
base_map=torch.Tensor,
base_auc=torch.Tensor,
)
class RankingListwiseEvaluator:
""" Evaluate listwise ranking models on common ranking metrics """
@@ -43,7 +55,6 @@ class RankingListwiseEvaluator:
self.base_map = []
self.log_softmax = nn.LogSoftmax(dim=1)
self.kl_loss = nn.KLDivLoss(reduction="batchmean")
self.reporter = None
# pyre-fixme[56]: Decorator `torch.no_grad(...)` could not be called, because
# its type `no_grad` is not callable.
@@ -72,7 +83,9 @@ class RankingListwiseEvaluator:
self.seq2slate_net.train(seq2slate_net_prev_mode)
if not self.calc_cpe:
self.reporter.report_evaluation_minibatch(cross_entropy_loss=ce_loss)
# pyre-fixme[16]: `RankingListwiseEvaluator` has no attribute
# `notify_observers`.
self.notify_observers(cross_entropy_loss=ce_loss)
return
# shape: batch_size, tgt_seq_len
@@ -119,7 +132,7 @@ class RankingListwiseEvaluator:
batch_base_dcg.append(dcg_score(truth_scores, base_scores))
batch_base_ndcg.append(ndcg_score(truth_scores, base_scores))
self.reporter.report_evaluation_minibatch(
self.notify_observers(
cross_entropy_loss=ce_loss,
dcg=torch.mean(torch.tensor(batch_dcg)).reshape(1),
ndcg=torch.mean(torch.tensor(batch_ndcg)).reshape(1),
@@ -132,5 +145,5 @@ class RankingListwiseEvaluator:
)
@torch.no_grad()
def evaluate_one_shot(self):
def evaluate_post_training(self):
pass
@@ -8,15 +8,24 @@ from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from reagent.core.types import PreprocessedTrainingBatch
from reagent.core.tracker import observable
from reagent.evaluation.evaluation_data_page import EvaluationDataPage
from reagent.models.seq2slate import Seq2SlateMode
from reagent.training.ranking.seq2slate_trainer import Seq2SlateTrainer
from reagent.types import PreprocessedTrainingBatch
logger = logging.getLogger(__name__)
@observable(
eval_baseline_loss=torch.Tensor,
eval_advantages=torch.Tensor,
logged_slate_rank_probs=torch.Tensor,
ranked_slate_rank_probs=torch.Tensor,
eval_data_pages_g=EvaluationDataPage,
eval_data_pages_ng=EvaluationDataPage,
)
class RankingPolicyGradientEvaluator:
""" Evaluate ranking models that are learned through policy gradient """
@@ -30,12 +39,13 @@ class RankingPolicyGradientEvaluator:
self.trainer = trainer
self.calc_cpe = calc_cpe
self.reward_network = reward_network
self.reporter = None
# Evaluate greedy/non-greedy version of the ranking model
self.eval_data_pages_g: Optional[EvaluationDataPage] = None
self.eval_data_pages_ng: Optional[EvaluationDataPage] = None
# pyre-fixme[56]: Decorator `torch.no_grad(...)` could not be called, because
# its type `no_grad` is not callable.
@torch.no_grad()
def evaluate(self, eval_tdp: PreprocessedTrainingBatch) -> None:
seq2slate_net = self.trainer.seq2slate_net
@@ -117,7 +127,9 @@ class RankingPolicyGradientEvaluator:
else:
self.eval_data_pages_ng = self.eval_data_pages_ng.append(edp_ng)
self.reporter.report_evaluation_minibatch(
# pyre-fixme[16]: `RankingPolicyGradientEvaluator` has no attribute
# `notify_observers`.
self.notify_observers(
eval_baseline_loss=eval_baseline_loss,
eval_advantages=eval_advantage,
logged_slate_rank_probs=logged_slate_rank_prob,
@@ -125,13 +137,11 @@ class RankingPolicyGradientEvaluator:
)
@torch.no_grad()
def finish(self):
self.reporter.report_evaluation_epoch(
def evaluate_post_training(self):
self.notify_observers(
# Use ValueListObserver as aggregating_observers requires input to be Tensor
eval_data_pages_g=self.eval_data_pages_g,
eval_data_pages_ng=self.eval_data_pages_ng,
)
self.eval_data_pages_g = None
self.eval_data_pages_ng = None
def evaluate_one_shot(self, edp: EvaluationDataPage):
pass
+5 -8
View File
@@ -6,10 +6,9 @@ import logging
import numpy as np
import torch
import torch.nn.functional as F
from reagent.core import types as rlt
from reagent.core.types import PreprocessedTrainingBatch
from reagent.evaluation.evaluation_data_page import EvaluationDataPage
from reagent import types as rlt
from reagent.training.reward_network_trainer import RewardNetTrainer
from reagent.types import PreprocessedTrainingBatch
logger = logging.getLogger(__name__)
@@ -22,6 +21,7 @@ class RewardNetEvaluator:
self.trainer = trainer
self.mse_loss = []
self.rewards = []
self.best_model = None
self.best_model_loss = 1e9
# pyre-fixme[56]: Decorator `torch.no_grad(...)` could not be called, because
@@ -47,7 +47,7 @@ class RewardNetEvaluator:
reward_net.train(reward_net_prev_mode)
@torch.no_grad()
def finish(self):
def evaluate_post_training(self):
mean_mse_loss = np.mean(self.mse_loss)
logger.info(f"Evaluation MSE={mean_mse_loss}")
eval_res = {"mse": mean_mse_loss, "rewards": torch.cat(self.rewards)}
@@ -56,9 +56,6 @@ class RewardNetEvaluator:
if mean_mse_loss < self.best_model_loss:
self.best_model_loss = mean_mse_loss
self.trainer.best_model = copy.deepcopy(self.trainer.reward_net)
self.best_model = copy.deepcopy(self.trainer.reward_net)
return eval_res
def evaluate_one_shot(self, edp: EvaluationDataPage):
pass
+4 -5
View File
@@ -3,8 +3,8 @@
import logging
import torch
from reagent.core.types import PreprocessedTrainingBatch
from reagent.training.world_model.seq2reward_trainer import Seq2RewardTrainer
from reagent.types import PreprocessedTrainingBatch
logger = logging.getLogger(__name__)
@@ -15,13 +15,15 @@ class Seq2RewardEvaluator:
self.trainer = trainer
self.reward_net = self.trainer.seq2reward_network
# pyre-fixme[56]: Decorator `torch.no_grad(...)` could not be called, because
# its type `no_grad` is not callable.
@torch.no_grad()
def evaluate(self, eval_tdp: PreprocessedTrainingBatch):
reward_net_prev_mode = self.reward_net.training
self.reward_net.eval()
# pyre-fixme[6]: Expected `MemoryNetworkInput` for 1st param but got
# `PreprocessedTrainingBatch`.
loss = self.trainer.compute_loss(eval_tdp)
loss = self.trainer.get_loss(eval_tdp)
detached_loss = loss.cpu().detach().item()
q_values = (
self.trainer.get_Q(
@@ -37,6 +39,3 @@ class Seq2RewardEvaluator:
)
self.reward_net.train(reward_net_prev_mode)
return (detached_loss, q_values)
def finish(self):
pass
+8 -26
View File
@@ -4,28 +4,23 @@ import logging
from typing import Dict, List
import torch
from reagent.core.types import FeatureData, MemoryNetworkInput
from reagent.reporting.world_model_reporter import (
DebugToolsReporter,
WorldModelReporter,
)
from reagent.training.world_model.mdnrnn_trainer import MDNRNNTrainer
from reagent.types import FeatureData, MemoryNetworkInput
logger = logging.getLogger(__name__)
class WorldModelLossEvaluator(object):
class LossEvaluator(object):
""" Evaluate losses on data pages """
def __init__(self, trainer: MDNRNNTrainer, state_dim: int) -> None:
self.trainer = trainer
self.state_dim = state_dim
self.reporter = WorldModelReporter(1)
def evaluate(self, tdp: MemoryNetworkInput) -> None:
def evaluate(self, tdp: MemoryNetworkInput) -> Dict[str, float]:
self.trainer.memory_network.mdnrnn.eval()
losses = self.trainer.compute_loss(tdp, state_dim=self.state_dim)
losses = self.trainer.get_loss(tdp, state_dim=self.state_dim)
detached_losses = {
"loss": losses["loss"].cpu().detach().item(),
"gmm": losses["gmm"].cpu().detach().item(),
@@ -34,10 +29,7 @@ class WorldModelLossEvaluator(object):
}
del losses
self.trainer.memory_network.mdnrnn.train()
self.reporter.report(**detached_losses)
def finish(self):
pass
return detached_losses
class FeatureImportanceEvaluator(object):
@@ -65,7 +57,6 @@ class FeatureImportanceEvaluator(object):
self.action_feature_num = action_feature_num
self.sorted_action_feature_start_indices = sorted_action_feature_start_indices
self.sorted_state_feature_start_indices = sorted_state_feature_start_indices
self.reporter = DebugToolsReporter()
def evaluate(self, batch: MemoryNetworkInput):
""" Calculate feature importance: setting each state/action feature to
@@ -80,7 +71,7 @@ class FeatureImportanceEvaluator(object):
state_feature_num = self.state_feature_num
feature_importance = torch.zeros(action_feature_num + state_feature_num)
orig_losses = self.trainer.compute_loss(batch, state_dim=state_dim)
orig_losses = self.trainer.get_loss(batch, state_dim=state_dim)
orig_loss = orig_losses["loss"].cpu().detach().item()
del orig_losses
@@ -124,7 +115,7 @@ class FeatureImportanceEvaluator(object):
not_terminal=batch.not_terminal,
step=None,
)
losses = self.trainer.compute_loss(new_batch, state_dim=state_dim)
losses = self.trainer.get_loss(new_batch, state_dim=state_dim)
feature_importance[i] = losses["loss"].cpu().detach().item() - orig_loss
del losses
@@ -151,7 +142,7 @@ class FeatureImportanceEvaluator(object):
not_terminal=batch.not_terminal,
step=None,
)
losses = self.trainer.compute_loss(new_batch, state_dim=state_dim)
losses = self.trainer.get_loss(new_batch, state_dim=state_dim)
feature_importance[i + action_feature_num] = (
losses["loss"].cpu().detach().item() - orig_loss
)
@@ -161,7 +152,6 @@ class FeatureImportanceEvaluator(object):
logger.info(
"**** Debug tool feature importance ****: {}".format(feature_importance)
)
self.reporter.report(feature_importance=feature_importance.tolist())
return {"feature_loss_increase": feature_importance.numpy()}
def compute_median_feature_value(self, features):
@@ -180,9 +170,6 @@ class FeatureImportanceEvaluator(object):
median_feature = features.mean(dim=0)
return median_feature
def finish(self):
pass
class FeatureSensitivityEvaluator(object):
""" Evaluate state feature sensitivity caused by varying actions """
@@ -196,7 +183,6 @@ class FeatureSensitivityEvaluator(object):
self.trainer = trainer
self.state_feature_num = state_feature_num
self.sorted_state_feature_start_indices = sorted_state_feature_start_indices
self.reporter = DebugToolsReporter()
def evaluate(self, batch: MemoryNetworkInput):
""" Calculate state feature sensitivity due to actions:
@@ -254,8 +240,4 @@ class FeatureSensitivityEvaluator(object):
logger.info(
"**** Debug tool feature sensitivity ****: {}".format(feature_sensitivity)
)
self.reporter.report(feature_sensitivity=feature_sensitivity.tolist())
return {"feature_sensitivity": feature_sensitivity.numpy()}
def finish(self):
pass
+1 -1
View File
@@ -19,7 +19,7 @@ import random
import gym
import numpy as np
import reagent.core.types as rlt
import reagent.types as rlt
import torch
from reagent.core.dataclasses import dataclass
from reagent.gym.envs.env_wrapper import EnvWrapper
+1 -1
View File
@@ -7,7 +7,7 @@ from typing import Callable, Optional
import gym
import numpy as np
import reagent.core.types as rlt
import reagent.types as rlt
import torch
from gym import spaces
from reagent.core.dataclasses import dataclass
+1 -1
View File
@@ -5,7 +5,7 @@ from typing import Optional, Tuple
import gym
import numpy as np
import reagent.core.types as rlt
import reagent.types as rlt
import torch
from gym import spaces
from gym_minigrid.wrappers import ReseedWrapper
+1 -1
View File
@@ -14,7 +14,7 @@ from typing import Optional
import gym
import numpy as np
import reagent.core.types as rlt
import reagent.types as rlt
import torch
from gym.spaces import Box
from reagent.gym.envs.env_wrapper import EnvWrapper
+1 -1
View File
@@ -5,7 +5,7 @@ import logging
import gym
import numpy as np
import reagent.core.types as rlt
import reagent.types as rlt
from reagent.core.dataclasses import dataclass
from reagent.gym.envs.env_wrapper import EnvWrapper
from reagent.gym.envs.wrappers.recsim import ValueWrapper
+1 -1
View File
@@ -4,7 +4,7 @@
from typing import Any, Optional
import numpy as np
import reagent.core.types as rlt
import reagent.types as rlt
from reagent.gym.types import Sampler, Scorer
+1 -1
View File
@@ -4,7 +4,7 @@
from typing import Any, Optional, Tuple, Union
import numpy as np
import reagent.core.types as rlt
import reagent.types as rlt
import torch
from reagent.core.fb_checker import IS_FB_ENVIRONMENT
from reagent.gym.policies import Policy
+1 -1
View File
@@ -5,7 +5,7 @@ from typing import List, Optional
import gym
import numpy as np
import reagent.core.types as rlt
import reagent.types as rlt
import torch
import torch.nn.functional as F
from reagent.gym.policies.policy import Policy
@@ -1,7 +1,7 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import reagent.core.types as rlt
import reagent.types as rlt
import torch
from reagent.gym.types import GaussianSamplerScore, Sampler
@@ -2,7 +2,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import reagent.core.types as rlt
import reagent.types as rlt
import torch
import torch.nn.functional as F
from reagent.gym.types import Sampler
@@ -41,9 +41,6 @@ class SoftmaxActionSampler(Sampler):
assert raw_action.shape == (
batch_size,
), f"{raw_action.shape} != ({batch_size}, )"
assert (
int(raw_action.max().item()) < num_actions
), f"Invalid action: {int(raw_action.max().item())}"
action = F.one_hot(raw_action, num_actions)
assert action.ndim == 2
log_prob = m.log_prob(raw_action)
@@ -2,7 +2,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import reagent.core.types as rlt
import reagent.types as rlt
import torch
from reagent.gym.types import Sampler
@@ -1,7 +1,7 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import reagent.core.types as rlt
import reagent.types as rlt
import torch
from reagent.gym.types import GaussianSamplerScore, Scorer
from reagent.models.base import ModelBase
@@ -4,7 +4,7 @@
from typing import Optional, Tuple
import numpy as np
import reagent.core.types as rlt
import reagent.types as rlt
import torch
from reagent.gym.preprocessors.trainer_preprocessor import get_possible_actions_for_gym
from reagent.gym.types import Scorer
@@ -1,7 +1,7 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import reagent.core.types as rlt
import reagent.types as rlt
import torch
import torch.nn.functional as F
from reagent.gym.types import Scorer
@@ -7,7 +7,7 @@ import logging
from typing import List, Optional, Tuple
import numpy as np
import reagent.core.types as rlt
import reagent.types as rlt
import torch
import torch.nn.functional as F
from gym import Env, spaces
@@ -9,7 +9,7 @@ from typing import Optional
import gym
import numpy as np
import reagent.core.types as rlt
import reagent.types as rlt
import torch
import torch.nn.functional as F
from reagent.parameters import CONTINUOUS_TRAINING_ACTION_RANGE
@@ -75,5 +75,5 @@ train_every_ts: 1
train_after_ts: 20000
num_train_episodes: 10
num_eval_episodes: 10
passing_score_bar: 190
passing_score_bar: 200
use_gpu: false
+9 -16
View File
@@ -17,21 +17,13 @@ from reagent.gym.agents.post_step import train_with_replay_buffer_post_step
from reagent.gym.envs.union import Env__Union
from reagent.gym.runners.gymrunner import evaluate_for_n_episodes, run_episode
from reagent.gym.utils import build_normalizer, fill_replay_buffer
from reagent.model_managers.model_manager import ModelManager
from reagent.model_managers.union import ModelManager__Union
from reagent.replay_memory.circular_replay_buffer import ReplayBuffer
from reagent.tensorboardX import summary_writer_context
from reagent.test.base.horizon_test_base import HorizonTestBase
from reagent.workflow.model_managers.union import ModelManager__Union
from torch.utils.tensorboard import SummaryWriter
try:
# Use internal runner or OSS otherwise
from reagent.runners.fb.fb_batch_runner import FbBatchRunner as BatchRunner
except ImportError:
from reagent.runners.oss_batch_runner import OssBatchRunner as BatchRunner
# for seeding the environment
SEED = 0
logger = logging.getLogger(__name__)
@@ -116,12 +108,13 @@ def run_test(
normalization = build_normalizer(env)
logger.info(f"Normalization is: \n{pprint.pformat(normalization)}")
manager: ModelManager = model.value
runner = BatchRunner(use_gpu, manager, RewardOptions(), normalization)
trainer = runner.initialize_trainer()
reporter = manager.get_reporter()
trainer.reporter = reporter
training_policy = manager.create_policy(trainer)
manager = model.value
trainer = manager.initialize_trainer(
use_gpu=use_gpu,
reward_options=RewardOptions(),
normalization_data_map=normalization,
)
training_policy = manager.create_policy(serving=False)
replay_buffer = ReplayBuffer(
replay_capacity=replay_memory_size, batch_size=trainer.minibatch_size
@@ -172,7 +165,7 @@ def run_test(
f"{len(train_rewards)} episodes is less than < {passing_score_bar}.\n"
)
serving_policy = manager.create_serving_policy(normalization, trainer)
serving_policy = manager.create_policy(serving=True)
agent = Agent.create_for_env_with_serving_policy(env, serving_policy)
eval_rewards = evaluate_for_n_episodes(
+4 -14
View File
@@ -17,22 +17,14 @@ from reagent.gym.envs.gym import Gym
from reagent.gym.preprocessors import make_replay_buffer_trainer_preprocessor
from reagent.gym.runners.gymrunner import evaluate_for_n_episodes
from reagent.gym.utils import build_normalizer, fill_replay_buffer
from reagent.model_managers.union import ModelManager__Union
from reagent.replay_memory.circular_replay_buffer import ReplayBuffer
from reagent.runners.oss_batch_runner import OssBatchRunner
from reagent.tensorboardX import summary_writer_context
from reagent.test.base.horizon_test_base import HorizonTestBase
from reagent.workflow.model_managers.union import ModelManager__Union
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
try:
# Use internal runner or OSS otherwise
from reagent.runners.fb.fb_batch_runner import FbBatchRunner as BatchRunner
except ImportError:
from reagent.runners.oss_batch_runner import OssBatchRunner as BatchRunner
# for seeding the environment
SEED = 0
logger = logging.getLogger(__name__)
@@ -86,7 +78,7 @@ class TestGymOffline(HorizonTestBase):
def evaluate_cem(env, manager, num_eval_episodes: int):
# NOTE: for CEM, serving isn't implemented
policy = manager.create_policy()
policy = manager.create_policy(serving=False)
agent = Agent.create_for_env(env, policy)
return evaluate_for_n_episodes(
n=num_eval_episodes, env=env, agent=agent, max_steps=env.max_steps
@@ -110,13 +102,11 @@ def run_test_offline(
logger.info(f"Normalization is: \n{pprint.pformat(normalization)}")
manager = model.value
runner = OssBatchRunner(
use_gpu,
manager,
trainer = manager.initialize_trainer(
use_gpu=use_gpu,
reward_options=RewardOptions(),
normalization_data_map=normalization,
)
trainer = runner.initialize_trainer()
# first fill the replay buffer to burn_in
replay_buffer = ReplayBuffer(
+7 -17
View File
@@ -4,7 +4,7 @@
import logging
import os
import unittest
from typing import Optional, cast
from typing import Optional
import torch
from reagent.core.types import RewardOptions
@@ -12,18 +12,10 @@ from reagent.gym.envs.env_wrapper import EnvWrapper
from reagent.gym.envs.gym import Gym
from reagent.gym.preprocessors import make_replay_buffer_trainer_preprocessor
from reagent.gym.utils import build_normalizer, fill_replay_buffer
from reagent.model_managers.union import ModelManager__Union
from reagent.replay_memory.circular_replay_buffer import ReplayBuffer
from reagent.runners.oss_batch_runner import OssBatchRunner
from reagent.test.base.horizon_test_base import HorizonTestBase
from reagent.training.world_model.seq2reward_trainer import Seq2RewardTrainer
try:
# Use internal runner or OSS otherwise
from reagent.runners.fb.fb_batch_runner import FbBatchRunner as BatchRunner
except ImportError:
from reagent.runners.oss_batch_runner import OssBatchRunner as BatchRunner
from reagent.workflow.model_managers.union import ModelManager__Union
logging.basicConfig(level=logging.INFO)
@@ -79,8 +71,8 @@ def train_seq2reward(
)
preprocessed_test_batch = trainer_preprocessor(test_batch)
adhoc_action_padding(preprocessed_test_batch, state_dim=state_dim)
# valid_losses = trainer.get_loss(preprocessed_test_batch)
# print_seq2reward_losses(epoch, "validation", valid_losses)
valid_losses = trainer.get_loss(preprocessed_test_batch)
print_seq2reward_losses(epoch, "validation", valid_losses)
trainer.seq2reward_network.train()
return trainer
@@ -117,13 +109,11 @@ def train_seq2reward_and_compute_reward_mse(
env.seed(SEED)
manager = model.value
runner = OssBatchRunner(
use_gpu,
manager,
trainer = manager.initialize_trainer(
use_gpu=use_gpu,
reward_options=RewardOptions(),
normalization_data_map=build_normalizer(env),
)
trainer = cast(Seq2RewardTrainer, runner.initialize_trainer())
device = "cuda" if use_gpu else "cpu"
# pyre-fixme[6]: Expected `device` for 2nd param but got `str`.
@@ -159,7 +149,7 @@ def train_seq2reward_and_compute_reward_mse(
)
preprocessed_test_batch = trainer_preprocessor(test_batch)
adhoc_action_padding(preprocessed_test_batch, state_dim=state_dim)
losses = trainer.compute_loss(preprocessed_test_batch)
losses = trainer.get_loss(preprocessed_test_batch)
detached_losses = losses.cpu().detach().item()
trainer.seq2reward_network.train()
return detached_losses
+10 -22
View File
@@ -3,11 +3,11 @@
import logging
import os
import unittest
from typing import Dict, List, Optional, cast
from typing import Dict, List, Optional
import gym
import numpy as np
import reagent.core.types as rlt
import reagent.types as rlt
import torch
from reagent.core.types import RewardOptions
from reagent.evaluation.world_model_evaluator import (
@@ -21,21 +21,14 @@ from reagent.gym.envs.pomdp.state_embed_env import StateEmbedEnvironment
from reagent.gym.preprocessors import make_replay_buffer_trainer_preprocessor
from reagent.gym.runners.gymrunner import evaluate_for_n_episodes
from reagent.gym.utils import build_normalizer, fill_replay_buffer
from reagent.model_managers.union import ModelManager__Union
from reagent.models.world_model import MemoryNetwork
from reagent.replay_memory.circular_replay_buffer import ReplayBuffer
from reagent.test.base.horizon_test_base import HorizonTestBase
from reagent.training.world_model.mdnrnn_trainer import MDNRNNTrainer
from reagent.workflow.model_managers.union import ModelManager__Union
from tqdm import tqdm
try:
# Use internal runner or OSS otherwise
from reagent.runners.fb.fb_batch_runner import FbBatchRunner as BatchRunner
except ImportError:
from reagent.runners.oss_batch_runner import OssBatchRunner as BatchRunner
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
@@ -156,7 +149,7 @@ def train_mdnrnn(
batch_size=batch_size
)
preprocessed_test_batch = trainer_preprocessor(test_batch)
valid_losses = trainer.compute_loss(preprocessed_test_batch)
valid_losses = trainer.get_loss(preprocessed_test_batch)
print_mdnrnn_losses(epoch, "validation", valid_losses)
trainer.memory_network.mdnrnn.train()
return trainer
@@ -178,13 +171,11 @@ def train_mdnrnn_and_compute_feature_stats(
env.seed(SEED)
manager = model.value
runner = BatchRunner(
use_gpu,
manager,
trainer = manager.initialize_trainer(
use_gpu=use_gpu,
reward_options=RewardOptions(),
normalization_data_map=build_normalizer(env),
)
trainer = cast(MDNRNNTrainer, runner.initialize_trainer())
device = "cuda" if use_gpu else "cpu"
# pyre-fixme[6]: Expected `device` for 2nd param but got `str`.
@@ -297,13 +288,11 @@ def train_mdnrnn_and_train_on_embedded_env(
env.seed(SEED)
embedding_manager = embedding_model.value
embedding_runner = BatchRunner(
use_gpu,
embedding_manager,
embedding_trainer = embedding_manager.initialize_trainer(
use_gpu=use_gpu,
reward_options=RewardOptions(),
normalization_data_map=build_normalizer(env),
)
embedding_trainer = cast(MDNRNNTrainer, embedding_runner.initialize_trainer())
device = "cuda" if use_gpu else "cpu"
embedding_trainer_preprocessor = make_replay_buffer_trainer_preprocessor(
@@ -347,14 +336,13 @@ def train_mdnrnn_and_train_on_embedded_env(
state_max_value=state_max,
)
agent_manager = train_model.value
agent_trainer = agent_manager.build_trainer(
agent_trainer = agent_manager.initialize_trainer(
use_gpu=use_gpu,
reward_options=RewardOptions(),
# pyre-fixme[6]: Expected `EnvWrapper` for 1st param but got
# `StateEmbedEnvironment`.
normalization_data_map=build_normalizer(embed_env),
)
agent_trainer.reporter = agent_manager.get_reporter()
device = "cuda" if use_gpu else "cpu"
agent_trainer_preprocessor = make_replay_buffer_trainer_preprocessor(
agent_trainer,
@@ -371,7 +359,7 @@ def train_mdnrnn_and_train_on_embedded_env(
# evaluate model
rewards = []
policy = agent_manager.create_policy(agent_trainer)
policy = agent_manager.create_policy(serving=False)
# pyre-fixme[6]: Expected `EnvWrapper` for 1st param but got
# `StateEmbedEnvironment`.
agent = Agent.create_for_env(embed_env, policy=policy, device=device)
+1 -1
View File
@@ -9,7 +9,7 @@ from dataclasses import asdict, dataclass, field, fields
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import reagent.core.types as rlt
import reagent.types as rlt
import torch
+2 -1
View File
@@ -1,9 +1,10 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import collections
import json
import logging
from dataclasses import asdict, fields, is_dataclass
from dataclasses import asdict, dataclass, fields, is_dataclass
from typing import Any, NamedTuple, Type, Union
-148
View File
@@ -1,148 +0,0 @@
#!/usr/bin/env python3
import logging
from typing import Dict, List, Optional, Tuple
from reagent.core import types as rlt
from reagent.core.dataclasses import dataclass, field
from reagent.core.types import (
Dataset,
PreprocessingOptions,
ReaderOptions,
RewardOptions,
TableSpec,
)
from reagent.core.union import ModelFeatureConfigProvider__Union
from reagent.data_fetchers.data_fetcher import DataFetcher
from reagent.evaluation.evaluator import Evaluator, get_metrics_to_score
from reagent.gym.policies.policy import Policy
from reagent.gym.policies.samplers.discrete_sampler import SoftmaxActionSampler
from reagent.gym.policies.scorers.discrete_scorer import discrete_dqn_scorer
from reagent.model_managers.model_manager import ModelManager
from reagent.models.base import ModelBase
from reagent.models.model_feature_config_provider import RawModelFeatureConfigProvider
from reagent.parameters import EvaluationParameters, NormalizationData, NormalizationKey
from reagent.preprocessing.batch_preprocessor import (
BatchPreprocessor,
DiscreteDqnBatchPreprocessor,
)
from reagent.preprocessing.preprocessor import Preprocessor
from reagent.preprocessing.types import InputColumn
from reagent.reporting.discrete_dqn_reporter import DiscreteDQNReporter
logger = logging.getLogger(__name__)
@dataclass
class DiscreteDQNBase(ModelManager):
    """Shared ModelManager plumbing for discrete-action DQN model managers.

    The methods below read ``self.trainer_param`` (``actions``, ``rl.gamma``,
    ``rl.temperature``, ``rl.multi_steps``). That field is not declared on
    this class, so it must be supplied by the concrete subclass.
    """

    # Forwarded unchanged to DiscreteDQNReporter in get_reporter().
    target_action_distribution: Optional[List[float]] = None
    # Provider of the state ModelFeatureConfig; the default is a raw provider
    # with an empty float-feature list.
    state_feature_config_provider: ModelFeatureConfigProvider__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `raw`.
        default_factory=lambda: ModelFeatureConfigProvider__Union(
            raw=RawModelFeatureConfigProvider(float_feature_infos=[])
        )
    )
    # `calc_cpe_in_training` on this object drives should_generate_eval_dataset.
    eval_parameters: EvaluationParameters = field(default_factory=EvaluationParameters)
    # Base options for feature identification; a default PreprocessingOptions
    # is used when this is unset.
    preprocessing_options: Optional[PreprocessingOptions] = None
    reader_options: Optional[ReaderOptions] = None

    def __post_init_post_parse__(self):
        """Run the parent initializer once the dataclass fields are populated."""
        super().__init__()

    def create_policy(self, trainer) -> Policy:
        """Build a Policy that scores with ``trainer.q_network`` and samples via softmax.

        The softmax temperature is taken from ``self.trainer_param.rl.temperature``.
        """
        return Policy(
            sampler=SoftmaxActionSampler(
                temperature=self.trainer_param.rl.temperature
            ),
            scorer=discrete_dqn_scorer(trainer.q_network),
        )

    @property
    def state_feature_config(self) -> rlt.ModelFeatureConfig:
        """Model feature config obtained from the configured provider."""
        provider = self.state_feature_config_provider.value
        return provider.get_model_feature_config()

    def metrics_to_score(self, reward_options: RewardOptions) -> List[str]:
        """Return the metric names derived from ``reward_options.metric_reward_values``."""
        metric_reward_values = reward_options.metric_reward_values
        return get_metrics_to_score(metric_reward_values)

    @property
    def should_generate_eval_dataset(self) -> bool:
        """Mirror of ``eval_parameters.calc_cpe_in_training``."""
        return self.eval_parameters.calc_cpe_in_training

    @property
    def required_normalization_keys(self) -> List[str]:
        """Only the state normalization key is required."""
        return [NormalizationKey.STATE]

    def run_feature_identification(
        self, data_fetcher: DataFetcher, input_table_spec: TableSpec
    ) -> Dict[str, NormalizationData]:
        """Identify state normalization parameters from the input table.

        The feature whitelist of the preprocessing options is always replaced
        with the feature ids listed in ``state_feature_config``.

        Returns:
            A dict with the single key ``NormalizationKey.STATE``.
        """
        options = self.preprocessing_options
        if not options:
            options = PreprocessingOptions()
        logger.info("Overriding whitelist_features")
        whitelist = [
            info.feature_id for info in self.state_feature_config.float_feature_infos
        ]
        options = options._replace(whitelist_features=whitelist)
        dense_params = data_fetcher.identify_normalization_parameters(
            input_table_spec, InputColumn.STATE_FEATURES, options
        )
        return {
            NormalizationKey.STATE: NormalizationData(
                dense_normalization_parameters=dense_params
            )
        }

    def query_data(
        self,
        data_fetcher: DataFetcher,
        input_table_spec: TableSpec,
        sample_range: Optional[Tuple[float, float]],
        reward_options: RewardOptions,
    ) -> Dataset:
        """Delegate to ``data_fetcher.query_data`` with discrete-action settings.

        Actions, gamma and multi-step count come from ``self.trainer_param``;
        the custom reward expression comes from ``reward_options``.
        """
        params = self.trainer_param
        return data_fetcher.query_data(
            input_table_spec=input_table_spec,
            discrete_action=True,
            actions=params.actions,
            include_possible_actions=True,
            sample_range=sample_range,
            custom_reward_expression=reward_options.custom_reward_expression,
            multi_steps=self.multi_steps,
            gamma=params.rl.gamma,
        )

    @property
    def multi_steps(self) -> Optional[int]:
        """Shortcut for ``self.trainer_param.rl.multi_steps``."""
        return self.trainer_param.rl.multi_steps

    def build_batch_preprocessor(
        self,
        reader_options: ReaderOptions,
        use_gpu: bool,
        batch_size: int,
        normalization_data_map: Dict[str, NormalizationData],
        reward_options: RewardOptions,
    ) -> BatchPreprocessor:
        """Build a DiscreteDqnBatchPreprocessor around a state Preprocessor.

        ``reader_options``, ``batch_size`` and ``reward_options`` are accepted
        for interface compatibility but are not read here.
        """
        state_norm = normalization_data_map[NormalizationKey.STATE]
        preprocessor = Preprocessor(
            state_norm.dense_normalization_parameters, use_gpu=use_gpu
        )
        return DiscreteDqnBatchPreprocessor(
            num_actions=len(self.trainer_param.actions),
            state_preprocessor=preprocessor,
            use_gpu=use_gpu,
        )

    def get_reporter(self):
        """Create a DiscreteDQNReporter for the configured actions."""
        action_names = self.trainer_param.actions
        return DiscreteDQNReporter(
            action_names,
            target_action_distribution=self.target_action_distribution,
        )

    def get_evaluator(self, trainer, reward_options: RewardOptions):
        """Create an Evaluator for ``trainer`` using the configured actions and gamma."""
        params = self.trainer_param
        return Evaluator(
            params.actions,
            params.rl.gamma,
            trainer,
            metrics_to_score=self.metrics_to_score(reward_options),
        )
-130
View File
@@ -1,130 +0,0 @@
#!/usr/bin/env python3
import abc
import logging
from typing import Dict, List, Optional, Tuple
import torch
from reagent.core.registry_meta import RegistryMeta
from reagent.core.types import Dataset, ReaderOptions, RewardOptions, TableSpec
from reagent.data_fetchers.data_fetcher import DataFetcher
from reagent.gym.policies.policy import Policy
from reagent.gym.policies.predictor_policies import create_predictor_policy_from_model
from reagent.parameters import NormalizationData
from reagent.preprocessing.batch_preprocessor import BatchPreprocessor
from reagent.training.trainer import Trainer
logger = logging.getLogger(__name__)
class ModelManager(metaclass=RegistryMeta):
    """
    Interface describing how one family of models is prepared, trained and served.

    Concrete managers provide the hooks declared on this class:

    1. `run_feature_identification()` derives feature preprocessing
       parameters from given data; its keys match `required_normalization_keys`.
    2. `query_data()` massages the input table into the format expected by
       the trainer.
    3. `build_batch_preprocessor()` builds the module that transforms data for
       the trainer (or for part of the serving module).
    4. `build_trainer()` builds the trainer from the config.
    5. `get_reporter()` / `get_evaluator()` supply reporting and evaluation
       helpers.
    6. `build_serving_module()` builds the TorchScript module for prediction;
       `create_serving_policy()` wraps it in a Policy.
    """

    @abc.abstractmethod
    def run_feature_identification(
        self, data_fetcher: DataFetcher, input_table_spec: TableSpec
    ) -> Dict[str, NormalizationData]:
        """
        Derive preprocessing parameters from data.

        The keys of the returned dict should match the keys from
        `required_normalization_keys()`.
        """

    @property
    @abc.abstractmethod
    def required_normalization_keys(self) -> List[str]:
        """Normalization keys required by the current instance."""

    @property
    @abc.abstractmethod
    def should_generate_eval_dataset(self) -> bool:
        """Whether an evaluation dataset is needed; must be overridden."""
        raise NotImplementedError()

    def get_evaluator(self, trainer, reward_options: RewardOptions):
        """Return an evaluator for `trainer`; the base implementation returns None."""
        return None

    @abc.abstractmethod
    def query_data(
        self,
        data_fetcher: DataFetcher,
        input_table_spec: TableSpec,
        sample_range: Optional[Tuple[float, float]],
        reward_options: RewardOptions,
    ) -> Dataset:
        """
        Massage the input table into the format expected by the trainer.
        """

    @abc.abstractmethod
    def get_reporter(self):
        """
        Return the reporter that displays statistics after training.
        """

    @abc.abstractmethod
    def build_batch_preprocessor(
        self,
        reader_options: ReaderOptions,
        use_gpu: bool,
        batch_size: int,
        normalization_data_map: Dict[str, NormalizationData],
        reward_options: RewardOptions,
    ) -> BatchPreprocessor:
        """
        Build the Batch Preprocessor.

        The Batch Preprocessor is a module that transforms data to a form that
        can be (1) read by the trainer or (2) used in part of the serving
        module. For training, it is typically run on reader machines in
        parallel so the GPUs on the trainer machines can be fully utilized.
        """

    @abc.abstractmethod
    def build_trainer(
        self,
        use_gpu: bool,
        normalization_data_map: Dict[str, NormalizationData],
        reward_options: RewardOptions,
    ) -> Trainer:
        """
        Build the trainer, given the config. Subclasses implement this.
        """

    def create_policy(self, trainer) -> Policy:
        """Create a Policy for `trainer`; not implemented in the base class."""
        raise NotImplementedError()

    def create_serving_policy(
        self, normalization_data_map: Dict[str, NormalizationData], trainer
    ) -> Policy:
        """Wrap the module from `build_serving_module()` in a predictor Policy."""
        serving_module = self.build_serving_module(normalization_data_map, trainer)
        return create_predictor_policy_from_model(serving_module)

    @abc.abstractmethod
    def build_serving_module(
        self, normalization_data_map: Dict[str, NormalizationData], trainer
    ) -> torch.nn.Module:
        """
        Return the TorchScript module to be used in the predictor.
        """
@@ -1,96 +0,0 @@
#!/usr/bin/env python3
import logging
from typing import Dict
import torch
from reagent.core.dataclasses import dataclass, field
from reagent.core.types import RewardOptions
from reagent.model_managers.parametric_dqn_base import ParametricDQNBase
from reagent.net_builder.parametric_dqn.fully_connected import FullyConnected
from reagent.net_builder.unions import ParametricDQNNetBuilder__Union
from reagent.parameters import NormalizationData, NormalizationKey, param_hash
from reagent.preprocessing.normalization import (
get_feature_config,
get_num_output_features,
)
from reagent.training import ParametricDQNTrainer, ParametricDQNTrainerParameters
logger = logging.getLogger(__name__)
@dataclass
class ParametricDQN(ParametricDQNBase):
    """Model manager that wires a ParametricDQNTrainer to its networks."""

    __hash__ = param_hash

    # Keyword arguments for ParametricDQNTrainer (expanded via `asdict()`).
    trainer_param: ParametricDQNTrainerParameters = field(
        default_factory=ParametricDQNTrainerParameters
    )
    # Builder for the Q-network, reward network and serving module; defaults
    # to the FullyConnected builder.
    net_builder: ParametricDQNNetBuilder__Union = field(
        # pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
        default_factory=lambda: ParametricDQNNetBuilder__Union(
            FullyConnected=FullyConnected()
        )
    )

    def __post_init_post_parse__(self):
        """Defer to the parent's post-parse hook."""
        super().__post_init_post_parse__()

    def build_trainer(
        self,
        use_gpu: bool,
        normalization_data_map: Dict[str, NormalizationData],
        reward_options: RewardOptions,
    ) -> ParametricDQNTrainer:
        """Build the Q-network, reward network and target network, then the trainer.

        Args:
            use_gpu: when true, both networks are moved with ``.cuda()`` before
                the target network is derived.
            normalization_data_map: must contain the STATE and ACTION keys.
            reward_options: used to count the metrics the reward network scores.

        Returns:
            A ParametricDQNTrainer with ``num_gym_actions`` attached.
        """
        builder = self.net_builder.value
        state_norm = normalization_data_map[NormalizationKey.STATE]
        action_norm = normalization_data_map[NormalizationKey.ACTION]

        q_net = builder.build_q_network(state_norm, action_norm)
        # One output per scored metric, plus one for the reward itself.
        num_reward_outputs = len(self.metrics_to_score(reward_options)) + 1
        reward_net = builder.build_q_network(
            state_norm, action_norm, output_dim=num_reward_outputs
        )

        if use_gpu:
            q_net = q_net.cuda()
            reward_net = reward_net.cuda()

        target_net = q_net.get_target_network()
        # pyre-fixme[16]: `ParametricDQNTrainerParameters` has no attribute
        #  `asdict`.
        trainer_kwargs = self.trainer_param.asdict()
        dqn_trainer = ParametricDQNTrainer(
            q_network=q_net,
            q_network_target=target_net,
            reward_network=reward_net,
            use_gpu=use_gpu,
            **trainer_kwargs,
        )

        # HACK: injecting num_actions to build policies for gym
        dqn_trainer.num_gym_actions = get_num_output_features(
            action_norm.dense_normalization_parameters
        )
        return dqn_trainer

    def build_serving_module(
        self,
        normalization_data_map: Dict[str, NormalizationData],
        trainer: ParametricDQNTrainer,
    ) -> torch.nn.Module:
        """Build the serving module from the trainer's Q-network.

        Delegates to the net builder, passing the STATE and ACTION
        normalization data.
        """
        builder = self.net_builder.value
        return builder.build_serving_module(
            trainer.q_network,
            normalization_data_map[NormalizationKey.STATE],
            normalization_data_map[NormalizationKey.ACTION],
        )
+1 -1
View File
@@ -5,7 +5,7 @@ import math
from typing import List, Optional
import torch
from reagent.core import types as rlt
from reagent import types as rlt
from reagent.models.base import ModelBase
from reagent.models.fully_connected_network import FullyConnectedNetwork
from reagent.parameters import CONTINUOUS_TRAINING_ACTION_RANGE
+1 -1
View File
@@ -5,7 +5,7 @@ from copy import deepcopy
from typing import Any, Optional
import torch.nn as nn
from reagent.core import types as rlt
from reagent import types as rlt
# add ABCMeta once https://github.com/sphinx-doc/sphinx/issues/5995 is fixed
+1 -1
View File
@@ -3,7 +3,7 @@
import torch
import torch.nn.functional as F
from reagent.core import types as rlt
from reagent import types as rlt
from reagent.models.base import ModelBase
+1 -1
View File
@@ -17,7 +17,7 @@ import numpy as np
import scipy.stats as stats
import torch
import torch.nn as nn
from reagent.core import types as rlt
from reagent import types as rlt
from reagent.models.base import ModelBase
from reagent.models.world_model import MemoryNetwork
from reagent.parameters import CONTINUOUS_TRAINING_ACTION_RANGE
+1 -1
View File
@@ -4,7 +4,7 @@
from typing import List
import torch
from reagent.core import types as rlt
from reagent import types as rlt
from reagent.models.base import ModelBase
from reagent.models.fully_connected_network import FullyConnectedNetwork
+1 -1
View File
@@ -4,7 +4,7 @@
from typing import Optional
import torch
from reagent.core import types as rlt
from reagent import types as rlt
from reagent.models.base import ModelBase
from reagent.models.fully_connected_network import FullyConnectedNetwork
+1 -1
View File
@@ -5,7 +5,7 @@ import logging
from typing import List, Optional, Tuple
import torch
from reagent.core import types as rlt
from reagent import types as rlt
from reagent.models.base import ModelBase
from reagent.models.critic import FullyConnectedCritic
from reagent.models.dqn import FullyConnectedDQN
+1 -1
View File
@@ -4,7 +4,7 @@
from typing import Dict, List
import torch
from reagent.core import types as rlt
from reagent import types as rlt
from reagent.models.base import ModelBase
+1 -1
View File
@@ -8,7 +8,7 @@ import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as f
from reagent.core import types as rlt
from reagent import types as rlt
from reagent.torch_utils import stack
from torch.distributions.normal import Normal
@@ -2,7 +2,7 @@
import abc
import reagent.core.types as rlt
import reagent.types as rlt
from reagent.core.dataclasses import dataclass
from reagent.core.registry_meta import RegistryMeta
+1 -1
View File
@@ -3,7 +3,7 @@
import torch
import torch.nn as nn
from reagent.core import types as rlt
from reagent import types as rlt
from reagent.models.base import ModelBase
+1 -1
View File
@@ -10,7 +10,7 @@ import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from reagent.core import types as rlt
from reagent import types as rlt
from reagent.models.base import ModelBase
from torch.nn.parallel.distributed import DistributedDataParallel
+1 -1
View File
@@ -6,7 +6,7 @@ import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from reagent.core import types as rlt
from reagent import types as rlt
from reagent.models.base import ModelBase
from reagent.models.seq2slate import (
DECODER_START_SYMBOL,
+1 -1
View File
@@ -2,7 +2,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import torch
from reagent.core import types as rlt
from reagent import types as rlt
from reagent.models.base import ModelBase
from reagent.models.mdn_rnn import MDNRNN
@@ -3,7 +3,7 @@
import abc
from typing import List
import reagent.core.types as rlt
import reagent.types as rlt
import torch
from reagent.core.fb_checker import IS_FB_ENVIRONMENT
from reagent.core.registry_meta import RegistryMeta
+1 -1
View File
@@ -2,7 +2,7 @@
from typing import List
from reagent.core import types as rlt
from reagent import types as rlt
from reagent.core.dataclasses import dataclass, field
from reagent.models.base import ModelBase
from reagent.models.dueling_q_network import DuelingQNetwork
@@ -2,7 +2,7 @@
from typing import List
from reagent.core import types as rlt
from reagent import types as rlt
from reagent.core.dataclasses import dataclass, field
from reagent.models.base import ModelBase
from reagent.models.dqn import FullyConnectedDQN
@@ -3,7 +3,7 @@
from typing import List
import reagent.models as models
from reagent.core import types as rlt
from reagent import types as rlt
from reagent.core.dataclasses import dataclass, field
from reagent.net_builder.discrete_dqn_net_builder import DiscreteDQNNetBuilder
from reagent.parameters import NormalizationData, param_hash
@@ -3,7 +3,7 @@
import abc
from typing import List
import reagent.core.types as rlt
import reagent.types as rlt
import torch
from reagent.core.fb_checker import IS_FB_ENVIRONMENT
from reagent.core.registry_meta import RegistryMeta
@@ -3,7 +3,7 @@
import abc
from typing import List
import reagent.core.types as rlt
import reagent.types as rlt
import torch
from reagent.core.fb_checker import IS_FB_ENVIRONMENT
from reagent.core.registry_meta import RegistryMeta
@@ -687,27 +687,15 @@ class NeuralDualDICE(RLEstimator):
), "Expected all fields to be present"
tgt_dist = input.target_policy.action_dist(t.state)
tgt_action = tgt_dist.sample()[0]
samples["init_state"].append(
state.value.cpu().numpy()
if isinstance(state.value, torch.Tensor)
else state.value
)
samples["init_state"].append(state.value)
samples["init_action"].append(
torch.nn.functional.one_hot(
torch.tensor(tgt_init_action.value, dtype=torch.long),
self.action_dim,
).float()
)
samples["last_state"].append(
t.last_state.value.cpu().numpy()
if isinstance(t.last_state.value, torch.Tensor)
else t.last_state.value
)
samples["state"].append(
t.state.value.cpu().numpy()
if isinstance(t.state.value, torch.Tensor)
else t.state.value
)
samples["last_state"].append(t.last_state.value)
samples["state"].append(t.state.value)
samples["log_action"].append(
torch.nn.functional.one_hot(
torch.tensor(t.action.value, dtype=torch.long), self.action_dim
-1
View File
@@ -58,7 +58,6 @@ class MDNRNNTrainerParameters(BaseDataClass):
action_dim: int = 2
action_names: List[str] = field(default_factory=lambda: [])
multi_steps: int = 1
shuffle_training_data: bool = False
@dataclass(frozen=True)
+1 -1
View File
@@ -5,7 +5,7 @@ from enum import Enum
from typing import Dict, Optional
from reagent.core.dataclasses import dataclass
from reagent.core.types import BaseDataClass
from reagent.types import BaseDataClass
class LearningMethod(Enum):
+1 -1
View File
@@ -4,7 +4,7 @@
import logging
from typing import Dict, List, Optional, Tuple
import reagent.core.types as rlt
import reagent.types as rlt
import torch
from reagent.models.base import ModelBase
from reagent.models.seq2slate import Seq2SlateMode, Seq2SlateTransformerNet
+1 -1
View File
@@ -6,7 +6,7 @@ from typing import Dict
import torch
import torch.nn as nn
import torch.nn.functional as F
from reagent.core import types as rlt
from reagent import types as rlt
from reagent.preprocessing.preprocessor import Preprocessor
+13 -13
View File
@@ -7,24 +7,12 @@ from dataclasses import asdict
from typing import Dict, List, Optional, Tuple
import numpy as np
import reagent.core.types as rlt
import reagent.types as rlt
import six
import torch
from reagent.parameters import NormalizationData, NormalizationParameters
from reagent.preprocessing import identify_types
from reagent.preprocessing.identify_types import DEFAULT_MAX_UNIQUE_ENUM, FEATURE_TYPES
from reagent.preprocessing.normalization_constants import (
BOX_COX_MARGIN,
BOX_COX_MAX_STDDEV,
DEFAULT_MAX_QUANTILE_SIZE,
DEFAULT_NUM_SAMPLES,
DEFAULT_QUANTILE_K2_THRESHOLD,
EPS,
MAX_FEATURE_VALUE,
MIN_FEATURE_VALUE,
MINIMUM_SAMPLES_TO_IDENTIFY,
MISSING_VALUE,
)
from scipy import stats
from scipy.stats.mstats import mquantiles
@@ -32,6 +20,18 @@ from scipy.stats.mstats import mquantiles
logger = logging.getLogger(__name__)
BOX_COX_MAX_STDDEV = 1e8
BOX_COX_MARGIN = 1e-4
MISSING_VALUE = -1337.1337
DEFAULT_QUANTILE_K2_THRESHOLD = 1000.0
MINIMUM_SAMPLES_TO_IDENTIFY = 20
DEFAULT_MAX_QUANTILE_SIZE = 20
DEFAULT_NUM_SAMPLES = 100000
MAX_FEATURE_VALUE = 6.0
MIN_FEATURE_VALUE = MAX_FEATURE_VALUE * -1
EPS = 1e-6
def no_op_feature():
return NormalizationParameters(
identify_types.CONTINUOUS, None, 0, 0, 1, None, None, None, None
@@ -1,19 +0,0 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from reagent.preprocessing.identify_types import ( # noqa
DEFAULT_MAX_UNIQUE_ENUM,
FEATURE_TYPES,
)
# Shared numeric constants for reagent.preprocessing normalization code.
# NOTE(review): the consumers of these values live in other modules, so any
# meaning beyond the names themselves should be confirmed against their usage.
BOX_COX_MAX_STDDEV = 1e8
BOX_COX_MARGIN = 1e-4
MISSING_VALUE = -1337.1337
DEFAULT_QUANTILE_K2_THRESHOLD = 1000.0
MINIMUM_SAMPLES_TO_IDENTIFY = 20
DEFAULT_MAX_QUANTILE_SIZE = 20
DEFAULT_NUM_SAMPLES = 100000
MAX_FEATURE_VALUE = 6.0
# Derived from MAX_FEATURE_VALUE: always its negation.
MIN_FEATURE_VALUE = MAX_FEATURE_VALUE * -1
EPS = 1e-6
+1 -1
View File
@@ -4,7 +4,7 @@
import logging
from typing import Dict, Tuple
import reagent.core.types as rlt
import reagent.types as rlt
import torch
+1 -1
View File
@@ -5,7 +5,7 @@ import logging
from typing import Callable, Dict, List, Optional, Tuple
import numpy as np
import reagent.core.types as rlt
import reagent.types as rlt
import torch
import torch.nn.functional as F
from reagent.parameters import NormalizationData
+3 -4
View File
@@ -6,10 +6,9 @@ from typing import Optional
from reagent.core.dataclasses import dataclass
from reagent.core.result_types import NoPublishingResults
from reagent.core.rl_training_output import RLTrainingOutput
from reagent.core.types import RecurringPeriod
from reagent.model_managers.model_manager import ModelManager
from reagent.core.types import RecurringPeriod, RLTrainingOutput
from reagent.publishers.model_publisher import ModelPublisher
from reagent.workflow.model_managers.model_manager import ModelManager
try:
@@ -73,7 +72,7 @@ if HAS_TINYDB:
child_workflow_id: int,
recurring_period: Optional[RecurringPeriod],
) -> NoPublishingResults:
path = training_output.local_output_path
path = training_output.output_path
assert path is not None, f"Given path is None."
assert os.path.exists(path), f"Given path {path} doesn't exist."
Model = Query()
+4 -5
View File
@@ -5,10 +5,9 @@ import inspect
from typing import Optional
from reagent.core.registry_meta import RegistryMeta
from reagent.core.rl_training_output import RLTrainingOutput
from reagent.core.types import RecurringPeriod
from reagent.model_managers.model_manager import ModelManager
from reagent.reporting.result_registries import PublishingResult
from reagent.core.types import RecurringPeriod, RLTrainingOutput
from reagent.workflow.model_managers.model_manager import ModelManager
from reagent.workflow.result_registries import PublishingResult
class ModelPublisher(metaclass=RegistryMeta):
@@ -39,7 +38,7 @@ class ModelPublisher(metaclass=RegistryMeta):
recurring_period,
)
# Avoid circular dependency at import time
from reagent.core.union import PublishingResult__Union
from reagent.core.types import PublishingResult__Union
# We need to use inspection because the result can be a future when running on
# FBL
+2 -3
View File
@@ -4,10 +4,9 @@ from typing import Optional
from reagent.core.dataclasses import dataclass
from reagent.core.result_types import NoPublishingResults
from reagent.core.rl_training_output import RLTrainingOutput
from reagent.core.types import RecurringPeriod
from reagent.model_managers.model_manager import ModelManager
from reagent.core.types import RecurringPeriod, RLTrainingOutput
from reagent.publishers.model_publisher import ModelPublisher
from reagent.workflow.model_managers.model_manager import ModelManager
@dataclass
-24
View File
@@ -1,24 +0,0 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from reagent.core.fb_checker import IS_FB_ENVIRONMENT
if True: # To prevent auto sorting of inputs
# Triggering registration to registries
import reagent.core.result_types # noqa
import reagent.reporting.oss_training_reports # noqa
from reagent.model_managers.union import * # noqa
if IS_FB_ENVIRONMENT:
import reagent.core.fb.fb_result_types # noqa
# Register all unions
from reagent.core.union import * # noqa
from reagent.model_managers.union import * # noqa
from reagent.optimizer.union import * # noqa
from reagent.publishers.union import * # noqa
from reagent.validators.union import * # noqa
if IS_FB_ENVIRONMENT:
from reagent.model_managers.fb.union import * # noqa
View File
@@ -1,55 +0,0 @@
#!/usr/bin/env python3
import itertools
import logging
from reagent.core import aggregators as agg
from reagent.core.rl_training_output import RLTrainingOutput
from reagent.core.union import TrainingReport__Union
from reagent.reporting.oss_training_reports import OssActorCriticTrainingReport
from reagent.reporting.reporter_base import ReporterBase
logger = logging.getLogger(__name__)
class ActorCriticReporter(ReporterBase):
    """Reporter wiring up the aggregators used for actor-critic training.

    Four aggregators are registered under ``cpe_results``, ``td_loss``,
    ``reward_loss`` and ``recent_rewards``; four more TensorBoard
    histogram-and-mean aggregators are registered under ``<key>_tb``.
    """

    def __init__(self, report_interval: int = 100):
        """Build the aggregator table.

        Args:
            report_interval: forwarded as ``interval`` to every aggregator
                that accepts one.
        """
        # Aggregators are constructed in the same order they are registered.
        summary_aggregators = [
            ("cpe_results", agg.AppendAggregator("cpe_details")),
            ("td_loss", agg.MeanAggregator("td_loss", interval=report_interval)),
            (
                "reward_loss",
                agg.MeanAggregator("reward_loss", interval=report_interval),
            ),
            (
                "recent_rewards",
                agg.RecentValuesAggregator(
                    "logged_rewards", interval=report_interval
                ),
            ),
        ]
        # (reported key, TensorBoard log key) pairs.
        tensorboard_keys = [
            ("td_loss", "td_loss"),
            ("reward_loss", "reward_loss"),
            ("logged_propensities", "propensities/logged"),
            ("logged_rewards", "reward/logged"),
        ]
        tensorboard_aggregators = [
            (
                f"{key}_tb",
                agg.TensorBoardHistogramAndMeanAggregator(
                    key, log_key, interval=report_interval
                ),
            )
            for key, log_key in tensorboard_keys
        ]
        super().__init__(summary_aggregators + tensorboard_aggregators)

    # TODO: T71636196 write this for OSS
    def publish(self) -> RLTrainingOutput:
        """Return a training output wrapping an empty actor-critic report."""
        empty_report = OssActorCriticTrainingReport()
        union = TrainingReport__Union(oss_actor_critic_report=empty_report)
        return RLTrainingOutput(training_report=union)
-109
View File
@@ -1,109 +0,0 @@
#!/usr/bin/env python3
import itertools
import logging
from typing import List, Optional
import torch
from reagent.core import aggregators as agg
from reagent.core.rl_training_output import RLTrainingOutput
from reagent.core.union import TrainingReport__Union
from reagent.reporting.oss_training_reports import OssDQNTrainingReport
from reagent.reporting.reporter_base import ReporterBase
logger = logging.getLogger(__name__)
class DiscreteDQNReporter(ReporterBase):
    """Reporter wiring up the aggregators used for discrete-action DQN training.

    Besides the aggregator table it stores ``target_action_distribution`` and
    ``recent_window_size`` on the instance; this class does not read them
    itself.
    """

    def __init__(
        self,
        actions: List[str],
        report_interval: int = 100,
        target_action_distribution: Optional[List[float]] = None,
        recent_window_size: int = 100,
    ):
        """Build the aggregator table.

        Args:
            actions: action names handed to the per-action aggregators.
            report_interval: forwarded as ``interval`` to every aggregator
                that accepts one.
            target_action_distribution: stored unchanged on the instance.
            recent_window_size: stored unchanged on the instance.
        """
        # Group 1: plain (non-TensorBoard) aggregators.
        summary_aggregators = [
            ("CPE Results", agg.AppendAggregator("cpe_details")),
            ("TD Loss", agg.MeanAggregator("td_loss", interval=report_interval)),
            (
                "Reward Loss",
                agg.MeanAggregator("reward_loss", interval=report_interval),
            ),
            (
                "Model Action Values",
                agg.FunctionsByActionAggregator(
                    "model_values",
                    actions,
                    {"mean": torch.mean, "std": torch.std},
                    interval=report_interval,
                ),
            ),
            (
                "Logged Actions",
                agg.ActionCountAggregator(
                    "logged_actions", actions, interval=report_interval
                ),
            ),
            (
                "model_action",
                agg.ActionCountAggregator(
                    "model_action_idxs", actions, interval=report_interval
                ),
            ),
            (
                "Recent Logged Rewards",
                agg.RecentValuesAggregator(
                    "logged_rewards", interval=report_interval
                ),
            ),
        ]

        # Group 2: TensorBoard action counts, as (reported key, title).
        action_count_specs = [
            ("logged_actions", "logged"),
            ("model_action_idxs", "model"),
        ]
        action_count_aggregators = [
            (
                f"{key}_tb",
                agg.TensorBoardActionCountAggregator(
                    key, title, actions, interval=report_interval
                ),
            )
            for key, title in action_count_specs
        ]

        # Group 3: TensorBoard histogram + mean, as (reported key, log key).
        histogram_specs = [
            ("td_loss", "td_loss"),
            ("reward_loss", "reward_loss"),
            ("logged_propensities", "propensities/logged"),
            ("logged_rewards", "reward/logged"),
        ]
        histogram_aggregators = [
            (
                f"{key}_tb",
                agg.TensorBoardHistogramAndMeanAggregator(
                    key, log_key, interval=report_interval
                ),
            )
            for key, log_key in histogram_specs
        ]

        # Group 4: per-action TensorBoard histogram + mean, as
        # (reported key, category, title).
        per_action_specs = [
            ("model_propensities", "propensities", "model"),
            ("model_rewards", "reward", "model"),
            ("model_values", "value", "model"),
        ]
        per_action_aggregators = [
            (
                f"{key}_tb",
                agg.TensorBoardActionHistogramAndMeanAggregator(
                    key, category, title, actions, interval=report_interval
                ),
            )
            for key, category, title in per_action_specs
        ]

        super().__init__(
            summary_aggregators
            + action_count_aggregators
            + histogram_aggregators
            + per_action_aggregators
        )
        self.target_action_distribution = target_action_distribution
        self.recent_window_size = recent_window_size

    def publish(self) -> RLTrainingOutput:
        """Return a training output wrapping an empty DQN report."""
        union = TrainingReport__Union(oss_dqn_report=OssDQNTrainingReport())
        return RLTrainingOutput(training_report=union)
-62
View File
@@ -1,62 +0,0 @@
#!/usr/bin/env python3
from typing import List, Optional
from reagent.core.dataclasses import dataclass
from reagent.evaluation.cpe import CpeEstimate
from reagent.reporting.training_reports import TrainingReport
@dataclass
class OssDQNTrainingReport(TrainingReport):
    """Training report for DQN runs; every field is optional and defaults to None."""

    # Registry name declared for this report type.
    __registry_name__ = "oss_dqn_report"
    # Scalar losses.
    td_loss: Optional[float] = None
    mc_loss: Optional[float] = None
    # Reward estimates, each a CpeEstimate.
    reward_ips: Optional[CpeEstimate] = None
    reward_dm: Optional[CpeEstimate] = None
    reward_dr: Optional[CpeEstimate] = None
    # Value estimates, each a CpeEstimate.
    value_sequential_dr: Optional[CpeEstimate] = None
    value_weighted_dr: Optional[CpeEstimate] = None
    value_magic_dr: Optional[CpeEstimate] = None
@dataclass
class OssActorCriticTrainingReport(TrainingReport):
    """Training report for actor-critic runs; it currently carries no fields."""

    # Registry name declared for this report type.
    __registry_name__ = "oss_actor_critic_report"
@dataclass
class OssParametricDQNTrainingReport(TrainingReport):
    """Training report for parametric DQN runs; every field defaults to None."""

    # Registry name declared for this report type.
    __registry_name__ = "oss_parametric_dqn_report"
    # Scalar losses.
    td_loss: Optional[float] = None
    mc_loss: Optional[float] = None
    # Reward estimates, each a CpeEstimate.
    reward_ips: Optional[CpeEstimate] = None
    reward_dm: Optional[CpeEstimate] = None
    reward_dr: Optional[CpeEstimate] = None
    # Value estimates, each a CpeEstimate.
    value_sequential_dr: Optional[CpeEstimate] = None
    value_weighted_dr: Optional[CpeEstimate] = None
    value_magic_dr: Optional[CpeEstimate] = None
@dataclass
class OssWorldModelTrainingReport(TrainingReport):
    """Training report for world-model runs; all four loss lists are required."""

    # Registry name declared for this report type.
    __registry_name__ = "oss_world_model_report"
    # Lists of float loss values; none of them has a default.
    loss: List[float]
    gmm: List[float]
    bce: List[float]
    mse: List[float]
@dataclass
class DebugToolsReport(TrainingReport):
    """Report holding debug-tool outputs; both fields default to None."""

    # Registry name declared for this report type.
    __registry_name__ = "oss_debug_tools_report"
    # Optional lists of floats, one per debug tool.
    feature_importance: Optional[List[float]] = None
    feature_sensitivity: Optional[List[float]] = None
@dataclass
class OssRankingModelTrainingReport(TrainingReport):
    """Training report for ranking-model runs; it currently carries no fields."""

    # Registry name declared for this report type.
    __registry_name__ = "oss_ranking_model_training_report"
@@ -1,64 +0,0 @@
#!/usr/bin/env python3
import itertools
import logging
from typing import List, Optional
from reagent.core import aggregators as agg
from reagent.core.rl_training_output import RLTrainingOutput
from reagent.core.union import TrainingReport__Union
from reagent.reporting.oss_training_reports import OssParametricDQNTrainingReport
from reagent.reporting.reporter_base import ReporterBase
logger = logging.getLogger(__name__)
class ParametricDQNReporter(ReporterBase):
    """Reporter wiring up the aggregators used for parametric DQN training.

    Besides the aggregator table it stores ``target_action_distribution`` and
    ``recent_window_size`` on the instance; this class does not read them
    itself.
    """

    def __init__(
        self,
        report_interval: int = 100,
        target_action_distribution: Optional[List[float]] = None,
        recent_window_size: int = 100,
    ):
        """Build the aggregator table.

        Args:
            report_interval: forwarded as ``interval`` to every aggregator
                that accepts one.
            target_action_distribution: stored unchanged on the instance.
            recent_window_size: stored unchanged on the instance.
        """
        aggregators = itertools.chain(
            [
                # NOTE(review): the other reporters aggregate the key
                # "cpe_details" here; confirm "cpe_results" is the key the
                # trainer actually reports.
                ("cpe_results", agg.AppendAggregator("cpe_results")),
                ("td_loss", agg.MeanAggregator("td_loss", interval=report_interval)),
                (
                    "reward_loss",
                    agg.MeanAggregator("reward_loss", interval=report_interval),
                ),
                (
                    "logged_rewards",
                    agg.RecentValuesAggregator(
                        "logged_rewards", interval=report_interval
                    ),
                ),
            ],
            [
                (
                    f"{key}_tb",
                    agg.TensorBoardHistogramAndMeanAggregator(
                        key, log_key, interval=report_interval
                    ),
                )
                for key, log_key in [
                    ("td_loss", "td_loss"),
                    ("reward_loss", "reward_loss"),
                    ("logged_propensities", "propensities/logged"),
                    ("logged_rewards", "reward/logged"),
                ]
            ],
        )
        super().__init__(aggregators)
        self.target_action_distribution = target_action_distribution
        self.recent_window_size = recent_window_size

    # TODO: T71636218 write this for OSS
    def publish(self) -> RLTrainingOutput:
        """Return a training output wrapping an empty parametric-DQN report.

        The aggregated CPE results are not copied into the report yet (see the
        TODO above), so the unused read of them was dropped.
        """
        report = OssParametricDQNTrainingReport()
        return RLTrainingOutput(
            training_report=TrainingReport__Union(oss_parametric_dqn_report=report)
        )
@@ -1,60 +0,0 @@
#!/usr/bin/env python3
import logging
from reagent.core import aggregators as agg
from reagent.core.rl_training_output import RLTrainingOutput
from reagent.core.union import TrainingReport__Union
from reagent.reporting.oss_training_reports import OssRankingModelTrainingReport
from reagent.reporting.reporter_base import ReporterBase
logger = logging.getLogger(__name__)
class RankingModelReporter(ReporterBase):
    """Reporter wiring up the aggregators used for ranking-model training."""

    def __init__(self, report_interval: int = 100):
        """
        Register a mean aggregator and a TensorBoard histogram-and-mean
        aggregator (under ``<key>_tb``) for each ranking metric:

        'pg' (policy gradient loss)
        'baseline' (the baseline model's loss, usually for fitting V(s))
        'kendall_tau' (kendall_tau coefficient between advantage and log_probs,
            used in evaluation page handlers)
        'kendaull_tau_p_value' (the p-value for kendall_tau test, used in
            evaluation page handlers; the key really is spelled this way)
        """
        metric_keys = ["pg", "baseline", "kendall_tau", "kendaull_tau_p_value"]

        aggregators = []
        # All mean aggregators are built first, then the TensorBoard ones.
        for key in metric_keys:
            aggregators.append(
                (key, agg.MeanAggregator(key, interval=report_interval))
            )
        for key in metric_keys:
            # The TensorBoard log key equals the reported key for every metric.
            aggregators.append(
                (
                    f"{key}_tb",
                    agg.TensorBoardHistogramAndMeanAggregator(
                        key, key, interval=report_interval
                    ),
                )
            )
        super().__init__(aggregators)

    # TODO: T71636236 write this for OSS
    def publish(self) -> RLTrainingOutput:
        """Return a training output wrapping an empty ranking-model report."""
        empty_report = OssRankingModelTrainingReport()
        union = TrainingReport__Union(oss_ranking_model_training_report=empty_report)
        return RLTrainingOutput(training_report=union)
-59
View File
@@ -1,59 +0,0 @@
#!/usr/bin/env python3
import logging
from collections import OrderedDict
from typing import Any, Dict, List, Tuple
import torch
from reagent.core import aggregators as agg
from reagent.core.rl_training_output import RLTrainingOutput
logger = logging.getLogger(__name__)
class ReporterBase:
    """Route reported values to a table of named aggregators.

    Aggregators are stored in registration order in ``self.aggregators`` and
    are also reachable as attributes (``reporter.td_loss``) through
    ``__getattr__``.
    """

    def __init__(self, aggregators: "List[Tuple[str, agg.Aggregator]]"):
        """Store ``aggregators`` (an iterable of ``(name, aggregator)`` pairs)."""
        self.aggregators = OrderedDict(aggregators)

    def report(self, **kwargs: Any):
        """Forward each ``name=value`` to every aggregator whose ``key`` is ``name``.

        Names that no aggregator listens to are silently ignored.
        """
        for name, value in kwargs.items():
            for aggregator in self.aggregators.values():
                if aggregator.key == name:
                    aggregator.update(name, value)

    def finish_epoch(self):
        """Call ``finish_epoch()`` on every aggregator."""
        for aggregator in self.aggregators.values():
            aggregator.finish_epoch()

    def publish(self) -> "RLTrainingOutput":
        """Produce the training output; the base implementation returns None."""
        pass

    @staticmethod
    def _summarize(values, average: bool):
        """Return None for empty ``values``, their mean if ``average``, else ``values``."""
        if len(values) == 0:
            return None
        if average:
            return float(torch.mean(torch.tensor(values)))
        return values

    def get_recent(self, key: str, count: int, average: bool):
        """Return the ``count`` most recent values of the first aggregator keyed ``key``.

        Returns None when no aggregator matches or it has no values; returns
        the mean as a float when ``average`` is true.
        """
        for aggregator in self.aggregators.values():
            if aggregator.key == key:
                # NOTE(review): this reads an inner ``aggregator`` attribute of
                # the registered aggregator -- confirm the aggregator classes
                # actually expose one.
                return self._summarize(
                    aggregator.aggregator.get_recent(count), average
                )
        return None

    def get_all(self, key: str, average: bool):
        """Return all values of the first aggregator keyed ``key``.

        Returns None when no aggregator matches or it has no values; returns
        the mean as a float when ``average`` is true.
        """
        for aggregator in self.aggregators.values():
            if aggregator.key == key:
                # NOTE(review): same inner ``aggregator`` attribute as in
                # get_recent -- confirm it exists.
                return self._summarize(aggregator.aggregator.get_all(), average)
        return None

    def __getattr__(self, key: str):
        """Expose registered aggregators as attributes.

        Only called when normal attribute lookup fails. Raises AttributeError
        (not KeyError) for unknown names so ``hasattr``/``getattr`` defaults
        work, and reads the table through ``__dict__`` so a missing
        ``aggregators`` attribute cannot trigger infinite recursion.
        """
        aggregators = self.__dict__.get("aggregators")
        if aggregators is None or key not in aggregators:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {key!r}"
            )
        return aggregators[key]

    def end_epoch(self):
        """Call ``end_epoch()`` on every aggregator."""
        for aggregator in self.aggregators.values():
            aggregator.end_epoch()
-363
View File
@@ -1,363 +0,0 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import logging
import math
from collections import deque
from typing import Deque, List, NamedTuple, Optional
import numpy as np
import torch
from reagent.tensorboardX import SummaryWriterContext
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
LOSS_REPORT_INTERVAL = 100
class BatchStats(NamedTuple):
    """Optional per-batch tensors plus helpers that write them to TensorBoard.

    Every field defaults to None; fields that are None are skipped when the
    summary is written.
    """

    td_loss: Optional[torch.Tensor] = None
    reward_loss: Optional[torch.Tensor] = None
    imitator_loss: Optional[torch.Tensor] = None
    logged_actions: Optional[torch.Tensor] = None
    logged_propensities: Optional[torch.Tensor] = None
    logged_rewards: Optional[torch.Tensor] = None
    logged_values: Optional[torch.Tensor] = None
    model_propensities: Optional[torch.Tensor] = None
    model_rewards: Optional[torch.Tensor] = None
    model_values: Optional[torch.Tensor] = None
    model_values_on_logged_actions: Optional[torch.Tensor] = None
    model_action_idxs: Optional[torch.Tensor] = None

    def write_summary(self, actions: List[str]):
        """Write all non-None fields through ``SummaryWriterContext``.

        Args:
            actions: action names; index ``i`` in the action tensors and
                column ``i`` in the per-action tensors map to ``actions[i]``.

        Raises:
            AssertionError: a single-column field is neither 1-D nor (N, 1).
            ValueError: a per-action field is neither single-column (with no
                actions given) nor (N, len(actions)).
        """

        def is_single_column(tensor):
            # True for shape (N,) or (N, 1).
            dims = tensor.shape
            return len(dims) == 1 or (len(dims) == 2 and dims[1] == 1)

        # 1) Per-action counts; only meaningful when action names are known.
        if actions:
            count_fields = [
                ("logged_actions", "actions/logged"),
                ("model_action_idxs", "actions/model"),
            ]
            for field, log_key in count_fields:
                val = getattr(self, field)
                if val is not None:
                    for i, action in enumerate(actions):
                        # pyre-fixme[16]: `SummaryWriterContext` has no attribute
                        #  `add_scalar`.
                        SummaryWriterContext.add_scalar(
                            "{}/{}".format(log_key, action), (val == i).sum().item()
                        )

        # 2) Fields that must always be a single column.
        column_fields = [
            ("td_loss", "td_loss"),
            ("imitator_loss", "imitator_loss"),
            ("reward_loss", "reward_loss"),
            ("logged_propensities", "propensities/logged"),
            ("logged_rewards", "reward/logged"),
            ("logged_values", "value/logged"),
            ("model_values_on_logged_actions", "value/model_logged_action"),
        ]
        for field, log_key in column_fields:
            val = getattr(self, field)
            if val is None:
                continue
            assert is_single_column(val), "Unexpected shape for {}: {}".format(
                field, val.shape
            )
            self._log_histogram_and_mean(log_key, val)

        # 3) Fields that are a single column without actions, or one column
        #    per action otherwise.
        per_action_fields = [
            ("model_propensities", "propensities/model"),
            ("model_rewards", "reward/model"),
            ("model_values", "value/model"),
        ]
        for field, log_key in per_action_fields:
            val = getattr(self, field)
            if val is None:
                continue
            if is_single_column(val) and not actions:
                self._log_histogram_and_mean(log_key, val)
                continue
            if len(val.shape) == 2 and val.shape[1] == len(actions):
                for i, action in enumerate(actions):
                    self._log_histogram_and_mean(f"{log_key}/{action}", val[:, i])
                continue
            raise ValueError(
                "Unexpected shape for {}: {}; actions: {}".format(
                    field, val.shape, actions
                )
            )

    def _log_histogram_and_mean(self, log_key, val):
        """Write ``val`` as a histogram under ``log_key`` and its mean under ``log_key/mean``.

        A ValueError from the writer is logged as a warning and re-raised.
        """
        try:
            SummaryWriterContext.add_histogram(log_key, val)
            SummaryWriterContext.add_scalar(f"{log_key}/mean", val.mean())
        except ValueError:
            logger.warning(
                f"Cannot create histogram for key: {log_key}; "
                "this is likely because you have NULL value in your input; "
                f"value: {val}"
            )
            raise

    @staticmethod
    def add_custom_scalars(action_names: Optional[List[str]]):
        """Register four multi-line charts, one line per action name.

        Does nothing when ``action_names`` is None or empty.
        """
        if not action_names:
            return

        # (scalar-name template, chart category, chart title), in the order
        # the charts are registered.
        chart_specs = [
            ("propensities/model/{}/mean", "propensities", "model"),
            ("propensities/logged/{}/mean", "propensities", "logged"),
            ("actions/logged/{}", "actions", "logged"),
            ("actions/model/{}", "actions", "model"),
        ]
        for template, category, title in chart_specs:
            SummaryWriterContext.add_custom_scalars_multilinechart(
                [template.format(action_name) for action_name in action_names],
                category=category,
                title=title,
            )
def merge_tensor_namedtuple_list(l, cls):
    """Concatenate a list of namedtuples of tensors field by field.

    Args:
        l: namedtuple instances that all have the fields of ``cls``.
        cls: namedtuple class used to read the field names and build the result.

    Returns:
        A ``cls`` instance in which each field is the concatenation along
        dim 0 of that field over ``l``, or None when the field is None in
        every element (or ``l`` is empty).

    Raises:
        AssertionError: a field is None in some elements but not in others.
    """
    merged = {}
    for name in cls._fields:
        values = [getattr(item, name) for item in l]
        present = [value for value in values if value is not None]
        # A field must be set on every element or on none of them.
        assert len(present) == 0 or len(present) == len(values)
        if present:
            merged[name] = torch.cat(present, dim=0)
        else:
            merged[name] = None
    return cls(**merged)
class StatsByAction(object):
    """Keep one growing list of values per action name."""

    def __init__(self, actions):
        """Start an empty history for every name in ``actions``."""
        self.stats = {}
        for action in actions:
            self.stats[action] = []

    def append(self, stats):
        """Append one value per known action taken from the mapping ``stats``.

        Actions missing from ``stats`` get 0; tensor values are converted with
        ``.item()``. Asserts that ``stats`` has no unknown action.
        """
        assert all(name in self.stats for name in stats)
        for action, history in self.stats.items():
            value = stats.get(action, 0)
            if isinstance(value, torch.Tensor):
                value = value.item()
            history.append(value)

    def items(self):
        """Return the ``(action, history)`` items view."""
        return self.stats.items()

    def __len__(self):
        """Return the number of tracked actions."""
        return len(self.stats)
class NoOpTrainingReporter:
    """Reporter stand-in whose methods accept calls and do nothing."""

    def report(self, **kwargs):
        """Accept any keyword arguments and discard them."""
        return None

    def flush(self):
        """Do nothing; this reporter never buffers anything."""
        return None
class TrainingReporter(object):
    """Buffer per-batch training statistics and periodically summarize them.

    ``report`` buffers one ``BatchStats`` per call; once
    ``loss_report_interval`` of them have accumulated, ``flush`` merges them,
    writes a TensorBoard summary, logs a text summary and updates the running
    histories kept on the instance.
    """

    RECENT_WINDOW_SIZE = 100

    def __init__(self, action_names: Optional[List[str]] = None):
        """Create the reporter.

        Args:
            action_names: names of the discrete actions, or None when there
                are none. An empty list is rejected by the assertion.
        """
        assert action_names is None or len(action_names) > 0
        self.action_names: List[str] = action_names or []
        self.loss_report_interval = LOSS_REPORT_INTERVAL
        BatchStats.add_custom_scalars(action_names)
        self.clear()

    def clear(self):
        """Reset every history and the buffer of unflushed batches."""
        self.running_reward: Deque[float] = deque(maxlen=int(1e6))

        self.td_loss: List[float] = []
        self.reward_loss: List[float] = []
        self.imitator_loss: List[float] = []
        self.logged_action_q_value: List[float] = []
        self.logged_action_counts = {action: 0 for action in self.action_names}
        self.model_values = StatsByAction(self.action_names)
        self.model_value_stds = StatsByAction(self.action_names)
        self.model_action_counts = StatsByAction(self.action_names)
        self.model_action_counts_cumulative = {
            action: 0 for action in self.action_names
        }
        self.model_action_distr = StatsByAction(self.action_names)

        self.incoming_stats: List[BatchStats] = []

    @property
    def num_batches(self):
        """Number of flushes that recorded a TD loss."""
        return len(self.td_loss)

    def report(self, **kwargs):
        """Buffer one batch of statistics; flush once the interval is reached.

        Keyword names must be ``BatchStats`` fields. Non-None values are
        converted to detached CPU tensors with at least one dimension.
        """

        def _to_tensor(v):
            if v is None:
                return None
            if not isinstance(v, torch.Tensor):
                v = torch.tensor(v)
            if len(v.shape) == 0:
                # Scalars become 1-element vectors so they can be concatenated.
                v = v.reshape(1)
            return v.detach().cpu()

        kwargs = {k: _to_tensor(v) for k, v in kwargs.items()}
        batch_stats = BatchStats(**kwargs)
        self.incoming_stats.append(batch_stats)
        if len(self.incoming_stats) >= self.loss_report_interval:
            self.flush()

    @torch.no_grad()
    def flush(self):
        """Merge the buffered batches, log them and update the histories."""
        if not len(self.incoming_stats):
            logger.info("Nothing to report")
            return

        logger.info("Loss on {} batches".format(len(self.incoming_stats)))

        batch_stats = merge_tensor_namedtuple_list(self.incoming_stats, BatchStats)
        batch_stats.write_summary(self.action_names)

        print_details = "Loss:\n"

        # td_loss is Optional like every other field: guard it the same way as
        # the other losses instead of crashing when it was never reported.
        if batch_stats.td_loss is not None:
            td_loss_mean = float(batch_stats.td_loss.mean())
            self.td_loss.append(td_loss_mean)
            print_details = print_details + "TD LOSS: {0:.3f}\n".format(td_loss_mean)

        if batch_stats.logged_rewards is not None:
            flattened_rewards = torch.flatten(batch_stats.logged_rewards).tolist()
            self.running_reward.extend(flattened_rewards)

        if batch_stats.reward_loss is not None:
            reward_loss_mean = float(batch_stats.reward_loss.mean())
            self.reward_loss.append(reward_loss_mean)
            print_details = print_details + "REWARD LOSS: {0:.3f}\n".format(
                reward_loss_mean
            )

        if batch_stats.imitator_loss is not None:
            imitator_loss_mean = float(batch_stats.imitator_loss.mean())
            self.imitator_loss.append(imitator_loss_mean)
            print_details = print_details + "IMITATOR LOSS: {0:.3f}\n".format(
                imitator_loss_mean
            )

        if batch_stats.model_values is not None and self.action_names:
            self.model_values.append(
                dict(zip(self.action_names, batch_stats.model_values.mean(dim=0)))
            )
            self.model_value_stds.append(
                dict(zip(self.action_names, batch_stats.model_values.std(dim=0)))
            )

        if batch_stats.model_values_on_logged_actions is not None:
            self.logged_action_q_value.append(
                batch_stats.model_values_on_logged_actions.mean().item()
            )

        if (
            batch_stats.logged_actions is not None
            and batch_stats.model_action_idxs is not None
        ):
            logged_action_counts = {
                action: (batch_stats.logged_actions == i).sum().item()
                for i, action in enumerate(self.action_names)
            }
            model_action_counts = {
                action: (batch_stats.model_action_idxs == i).sum().item()
                for i, action in enumerate(self.action_names)
            }
            print_details += "The distribution of logged actions : {}\n".format(
                logged_action_counts
            )
            print_details += "The distribution of model actions : {}\n".format(
                model_action_counts
            )
            for action, count in logged_action_counts.items():
                self.logged_action_counts[action] += count

            self.model_action_counts.append(model_action_counts)

            for action, count in model_action_counts.items():
                self.model_action_counts_cumulative[action] += count

            self.model_action_distr.append(
                self._normalize_counts(model_action_counts)
            )

        print_details += "Batch Evaluator Finished"
        for print_detail in print_details.split("\n"):
            logger.info(print_detail)

        self.incoming_stats.clear()

    @staticmethod
    def _normalize_counts(counts):
        """Turn ``{action: count}`` into fractions; all 0.0 when the total is 0."""
        total = 1.0 * sum(counts.values())
        return {k: (v / total if total else 0.0) for k, v in counts.items()}

    def get_td_loss_after_n(self, n):
        """Return the recorded TD losses from index ``n`` onward."""
        return self.td_loss[n:]

    def get_recent_td_loss(self):
        """Mean of the most recent ``RECENT_WINDOW_SIZE`` TD losses (NaN if none)."""
        return TrainingReporter.calculate_recent_window_average(
            self.td_loss, TrainingReporter.RECENT_WINDOW_SIZE, num_entries=1
        )

    def get_recent_reward_loss(self):
        """Mean of the most recent ``RECENT_WINDOW_SIZE`` reward losses (NaN if none)."""
        return TrainingReporter.calculate_recent_window_average(
            self.reward_loss, TrainingReporter.RECENT_WINDOW_SIZE, num_entries=1
        )

    def get_recent_imitator_loss(self):
        """Mean of the most recent ``RECENT_WINDOW_SIZE`` imitator losses (NaN if none)."""
        return TrainingReporter.calculate_recent_window_average(
            self.imitator_loss, TrainingReporter.RECENT_WINDOW_SIZE, num_entries=1
        )

    def get_logged_action_distribution(self):
        """Fraction of logged actions per action name (0.0 each before any count)."""
        return self._normalize_counts(self.logged_action_counts)

    def get_model_action_distribution(self):
        """Fraction of model actions per action name (0.0 each before any count)."""
        return self._normalize_counts(self.model_action_counts_cumulative)

    def get_recent_rewards(self):
        """Return the deque of flattened logged rewards."""
        return self.running_reward

    def log_to_tensorboard(self, epoch: int) -> None:
        """Write the recent loss averages as scalars for ``epoch``; NaN/None become 0."""

        def none_to_zero(x: Optional[float]) -> float:
            if x is None or math.isnan(x):
                return 0.0
            return x

        for name, value in [
            ("Training/td_loss", self.get_recent_td_loss()),
            ("Training/reward_loss", self.get_recent_reward_loss()),
            ("Training/imitator_loss", self.get_recent_imitator_loss()),
        ]:
            # pyre-fixme[16]: `SummaryWriterContext` has no attribute `add_scalar`.
            SummaryWriterContext.add_scalar(name, none_to_zero(value), epoch)

    @staticmethod
    def calculate_recent_window_average(arr, window_size, num_entries):
        """Average the last ``window_size`` entries of ``arr`` along axis 0.

        For an empty ``arr`` an error is logged and NaN is returned: a single
        float when ``num_entries == 1``, otherwise a list of ``num_entries``
        NaNs.
        """
        if len(arr) > 0:
            begin = max(0, len(arr) - window_size)
            return np.mean(np.array(arr[begin:]), axis=0)
        else:
            logger.error("Not enough samples for evaluation.")
            if num_entries == 1:
                return float("nan")
            else:
                return [float("nan")] * num_entries
-9
View File
@@ -1,9 +0,0 @@
#!/usr/bin/env python3
from typing import Optional
from reagent.core.registry_meta import RegistryMeta
class TrainingReport(metaclass=RegistryMeta):
    """Base class for training reports; it defines no fields or methods itself."""

    pass
-95
View File
@@ -1,95 +0,0 @@
#!/usr/bin/env python3
import itertools
import logging
from typing import List, Tuple
from reagent.core import aggregators as agg
from reagent.core.rl_training_output import RLTrainingOutput
from reagent.core.union import TrainingReport__Union
from reagent.reporting.oss_training_reports import (
DebugToolsReport,
OssWorldModelTrainingReport,
)
from reagent.reporting.reporter_base import ReporterBase
logger = logging.getLogger(__name__)
class WorldModelReporter(ReporterBase):
    """Reporter wiring up the aggregators used for world-model training."""

    def __init__(self, report_interval: int = 10):
        """
        Register a mean aggregator and a TensorBoard histogram-and-mean
        aggregator (under ``<key>_tb``) for each world-model loss:

        'loss' (referring to total loss),
        'bce' (loss for predicting not_terminal),
        'gmm' (loss for next state prediction),
        'mse' (loss for predicting reward)
        """
        loss_keys = ("loss", "bce", "gmm", "mse")

        # All mean aggregators are built first, then the TensorBoard ones.
        aggregators: List[Tuple[str, agg.Aggregator]] = [
            (key, agg.MeanAggregator(key, interval=report_interval))
            for key in loss_keys
        ]
        # The TensorBoard log key equals the reported key for every loss.
        aggregators.extend(
            (
                f"{key}_tb",
                agg.TensorBoardHistogramAndMeanAggregator(
                    key, key, interval=report_interval
                ),
            )
            for key in loss_keys
        )
        super().__init__(aggregators)

    def publish(self) -> RLTrainingOutput:
        """Return a training output carrying the aggregated loss histories."""
        loss_values = self.loss.values
        bce_values = self.bce.values
        gmm_values = self.gmm.values
        mse_values = self.mse.values
        report = OssWorldModelTrainingReport(
            loss=loss_values, bce=bce_values, gmm=gmm_values, mse=mse_values
        )
        union = TrainingReport__Union(oss_world_model_report=report)
        return RLTrainingOutput(training_report=union)
class DebugToolsReporter(ReporterBase):
    """Reporter collecting the outputs of the debug tools."""

    def __init__(self, report_interval: int = 1):
        """
        For debug tools: feature_importance, feature_sensitivity

        ``report_interval`` is accepted but not used by this reporter.
        """
        aggregators: List[Tuple[str, agg.Aggregator]] = [
            (key, agg.AppendAggregator(key))
            for key in ("feature_importance", "feature_sensitivity")
        ]
        super().__init__(aggregators)

    def publish(self) -> RLTrainingOutput:
        """Return a training output with the latest value of each debug metric.

        A metric that has no recorded values is published as an empty list.
        """

        def latest(collected):
            # Last appended entry, or [] when nothing was appended.
            if len(collected) == 0:
                return []
            return collected[-1]

        importance = latest(self.feature_importance.values)
        sensitivity = latest(self.feature_sensitivity.values)
        report = DebugToolsReport(
            feature_importance=importance, feature_sensitivity=sensitivity
        )
        union = TrainingReport__Union(oss_debug_tools_report=report)
        return RLTrainingOutput(training_report=union)
-402
View File
@@ -1,402 +0,0 @@
#!/usr/bin/env python3
import dataclasses
import logging
import time
from contextlib import contextmanager
from typing import Dict, NamedTuple, Optional, Tuple
import torch
from reagent.core.rl_training_output import RLTrainingOutput
from reagent.core.types import (
Dataset,
ReaderOptions,
RecurringPeriod,
ResourceOptions,
RewardOptions,
TableSpec,
)
from reagent.data_fetchers.data_fetcher import DataFetcher
from reagent.evaluation.evaluator import Evaluator
from reagent.model_managers.model_manager import ModelManager
from reagent.parameters import NormalizationData
from reagent.preprocessing.batch_preprocessor import BatchPreprocessor
from reagent.publishers.model_publisher import ModelPublisher
from reagent.tensorboardX import SummaryWriterContext, summary_writer_context
from reagent.training.trainer import Trainer
from reagent.validators.model_validator import ModelValidator
from reagent.workflow_utils.iterators import DataLoaderWrapper
from torch.utils.tensorboard import SummaryWriter
logger = logging.getLogger(__name__)
class TrainEvalSampleRanges(NamedTuple):
    """Pair of (start, end) sample ranges for the training and evaluation queries."""

    # (start, end) range passed to the training-data query.
    train_sample_range: Tuple[float, float]
    # (start, end) range passed to the evaluation-data query.
    eval_sample_range: Tuple[float, float]
class BatchRunner:
def __init__(
self,
use_gpu: bool,
model_manager: ModelManager,
data_fetcher: DataFetcher,
reward_options: RewardOptions,
normalization_data_map: Dict[str, NormalizationData],
warmstart_path: Optional[str] = None,
):
self.use_gpu = use_gpu
self.model_manager = model_manager
self.data_fetcher = data_fetcher
self.normalization_data_map = normalization_data_map
self.reward_options = reward_options
self.warmstart_path = warmstart_path
def get_workflow_id(self) -> int:
raise NotImplementedError()
def initialize_trainer(self) -> Trainer:
# validate that we have all the required keys
for normalization_key in self.model_manager.required_normalization_keys:
normalization_data = self.normalization_data_map.get(
normalization_key, None
)
assert normalization_data is not None, (
f"NormalizationData for {normalization_key} "
"is required but not provided."
)
# NOTE: Don't need this check in the future, for non-dense parameters
assert normalization_data.dense_normalization_parameters is not None, (
f"Dense normalization parameters for "
f"{normalization_key} is not provided."
)
trainer = self.model_manager.build_trainer(
self.use_gpu, self.normalization_data_map, self.reward_options
)
if self.warmstart_path is not None:
trainer_state = torch.load(self.warmstart_path)
trainer.load_state_dict(trainer_state)
self.trainer = trainer
return trainer
def save_trainer(self, trainer: Trainer, output_path: str) -> None:
"""
Save the trainer for warmstarting/checkpointing.
"""
trainer_state = trainer.state_dict()
torch.save(trainer_state, output_path)
@staticmethod
def get_sample_range(
input_table_spec: TableSpec, calc_cpe_in_training: bool
) -> TrainEvalSampleRanges:
table_sample = input_table_spec.table_sample
eval_table_sample = input_table_spec.eval_table_sample
if not calc_cpe_in_training:
# use all data if table sample = None
if table_sample is None:
train_sample_range = (0.0, 100.0)
else:
train_sample_range = (0.0, table_sample)
return TrainEvalSampleRanges(
train_sample_range=train_sample_range,
# eval samples will not be used
eval_sample_range=(0.0, 0.0),
)
error_msg = (
"calc_cpe_in_training is set to True. "
f"Please specify table_sample(current={table_sample}) and "
f"eval_table_sample(current={eval_table_sample}) such that "
"eval_table_sample + table_sample <= 100. "
"In order to reliably calculate CPE, eval_table_sample "
"should not be too small."
)
assert table_sample is not None, error_msg
assert eval_table_sample is not None, error_msg
assert (eval_table_sample + table_sample) <= (100.0 + 1e-3), error_msg
return TrainEvalSampleRanges(
train_sample_range=(0.0, table_sample),
eval_sample_range=(100.0 - eval_table_sample, 100.0),
)
def query(
    self,
    input_table_spec: TableSpec,
    reader_options: ReaderOptions,
    resource_options: ResourceOptions,
) -> Tuple[Dataset, Optional[Dataset]]:
    """Query the training dataset and, when needed, the evaluation dataset.

    Args:
        input_table_spec: Table to query; also supplies the sample ranges.
        reader_options: Accepted for interface compatibility; not used here.
        resource_options: Accepted for interface compatibility; not used here.

    Returns:
        ``(train_dataset, eval_dataset)``. ``eval_dataset`` is ``None`` when
        ``self.model_manager.should_generate_eval_dataset`` is falsy.
    """
    logger.info("Starting query")

    calc_cpe_in_training = self.model_manager.should_generate_eval_dataset
    sample_range_output = BatchRunner.get_sample_range(
        input_table_spec, calc_cpe_in_training
    )

    train_dataset = self.model_manager.query_data(
        data_fetcher=self.data_fetcher,
        input_table_spec=input_table_spec,
        sample_range=sample_range_output.train_sample_range,
        reward_options=self.reward_options,
    )

    # The eval split is only fetched when CPE is computed during training.
    eval_dataset = None
    if calc_cpe_in_training:
        eval_dataset = self.model_manager.query_data(
            data_fetcher=self.data_fetcher,
            input_table_spec=input_table_spec,
            sample_range=sample_range_output.eval_sample_range,
            reward_options=self.reward_options,
        )
    return (train_dataset, eval_dataset)
def run_feature_identification(
    self, input_table_spec: TableSpec
) -> Dict[str, NormalizationData]:
    """Delegate feature identification to the model manager.

    Args:
        input_table_spec: Table handed to the model manager.

    Returns:
        Whatever ``model_manager.run_feature_identification`` returns for this
        runner's data fetcher and the given table spec.
    """
    normalization_data_map = self.model_manager.run_feature_identification(
        self.data_fetcher, input_table_spec
    )
    return normalization_data_map
def train(
    self,
    train_dataset: Dataset,
    eval_dataset: Dataset,
    normalization_data_map: Dict[str, NormalizationData],
    num_epochs: int,
    reader_options: ReaderOptions,
    resource_options: Optional[ResourceOptions] = None,
    warmstart_path: Optional[str] = None,
    validator: Optional[ModelValidator] = None,
    parent_workflow_id: Optional[int] = None,
    recurring_period: Optional[RecurringPeriod] = None,
) -> RLTrainingOutput:
    """Run the training workflow and optionally validate its output.

    ``normalization_data_map``, ``warmstart_path`` and ``recurring_period``
    are part of the signature but are not read by this method.

    Args:
        train_dataset: Dataset forwarded to ``train_workflow``.
        eval_dataset: Dataset forwarded to ``train_workflow``.
        num_epochs: Number of epochs forwarded to ``train_workflow``.
        reader_options: Logged, then forwarded to ``train_workflow``.
        resource_options: Replaced by ``ResourceOptions()`` when falsy.
        validator: When given, applied to the results via ``run_validator``.
        parent_workflow_id: Defaults to this runner's own workflow id.

    Returns:
        The training output, with a validation result when ``validator`` is
        provided.
    """
    logger.info(f"{reader_options}")

    child_workflow_id = self.get_workflow_id()
    parent_workflow_id = (
        child_workflow_id if parent_workflow_id is None else parent_workflow_id
    )

    if not resource_options:
        resource_options = ResourceOptions()

    logger.info("Starting training")
    results = self.train_workflow(
        train_dataset,
        eval_dataset,
        num_epochs,
        parent_workflow_id=parent_workflow_id,
        child_workflow_id=child_workflow_id,
        reader_options=reader_options,
        resource_options=resource_options,
    )

    if validator is None:
        return results
    return self.run_validator(validator, results)
def run_validator(
    self, model_validator: ModelValidator, training_output: RLTrainingOutput
) -> RLTrainingOutput:
    """Validate a training output and attach the validation result.

    Args:
        model_validator: Object whose ``validate(training_output)`` is called.
        training_output: Output that must not have been validated yet.

    Returns:
        A copy of ``training_output`` with ``validation_result`` filled in.

    Raises:
        AssertionError: If ``training_output.validation_result`` is already set.
    """
    # The message has to read the same attribute the check uses; referencing a
    # different attribute would raise AttributeError and hide the real failure.
    assert (
        training_output.validation_result is None
    ), f"validation_result was set to {training_output.validation_result}"
    validation_result = model_validator.validate(training_output)
    return dataclasses.replace(training_output, validation_result=validation_result)
def run_publisher(
    self,
    model_publisher: ModelPublisher,
    training_output: RLTrainingOutput,
    recurring_workflow_id: int,
    child_workflow_id: int,
    recurring_period: Optional[RecurringPeriod],
) -> RLTrainingOutput:
    """Publish a trained model and attach the publishing result.

    Args:
        model_publisher: Object whose ``publish(...)`` is called.
        training_output: Output that must not have been published yet.
        recurring_workflow_id: Forwarded to the publisher.
        child_workflow_id: Forwarded to the publisher.
        recurring_period: Forwarded to the publisher.

    Returns:
        A copy of ``training_output`` with ``publishing_result`` filled in.

    Raises:
        AssertionError: If ``training_output.publishing_result`` is already set.
    """
    # The message has to read the same attribute the check uses; referencing a
    # different attribute would raise AttributeError and hide the real failure.
    assert (
        training_output.publishing_result is None
    ), f"publishing_result was set to {training_output.publishing_result}"
    publishing_result = model_publisher.publish(
        self.model_manager,
        training_output,
        recurring_workflow_id,
        child_workflow_id,
        recurring_period,
    )
    return dataclasses.replace(training_output, publishing_result=publishing_result)
def train_workflow(
self,
train_dataset: Dataset,
eval_dataset: Optional[Dataset],
num_epochs: int,
parent_workflow_id: int,
child_workflow_id: int,
reader_options: ReaderOptions,
resource_options: Optional[ResourceOptions] = None,
) -> RLTrainingOutput:
"""
Train with a fresh trainer, export the serving module, and return the output.

parent_workflow_id, child_workflow_id and resource_options are accepted but
not referenced in this body.

Returns a copy of the training output whose local_output_path is the path of
the saved TorchScript file.
"""
# TensorBoard writer created with default arguments; its log dir is logged below.
writer = SummaryWriter()
logger.info("TensorBoard logging location is: {}".format(writer.log_dir))
trainer = self.initialize_trainer()
# Training runs under the summary-writer context for `writer`.
with summary_writer_context(writer):
train_output: RLTrainingOutput = self._train(
train_dataset, eval_dataset, num_epochs, reader_options, trainer
)
# Relative file name made unique by the rounded current Unix time.
torchscript_output_path = f"model_{round(time.time())}.torchscript"
serving_module = self.model_manager.build_serving_module(
self.normalization_data_map, trainer
)
torch.jit.save(serving_module, torchscript_output_path)
logger.info(f"Saved torchscript model to {torchscript_output_path}")
# dataclasses.replace returns a new object; train_output itself is not mutated.
return dataclasses.replace(
train_output, local_output_path=torchscript_output_path
)
def _train(
    self,
    train_dataset: Dataset,
    eval_dataset: Optional[Dataset],
    num_epochs: int,
    reader_options: ReaderOptions,
    trainer: Trainer,
) -> RLTrainingOutput:
    """Wire up reporter, evaluator and preprocessor, then run training.

    The reporter obtained from the model manager is attached to ``trainer``
    and, when the model manager supplies one, to the evaluator as well.

    Returns:
        The result of ``train_and_evaluate_generic``.
    """
    shared_reporter = self.model_manager.get_reporter()
    trainer.reporter = shared_reporter

    cpe_evaluator = self.model_manager.get_evaluator(trainer, self.reward_options)
    if cpe_evaluator is not None:
        cpe_evaluator.reporter = shared_reporter

    preprocessor = self.model_manager.build_batch_preprocessor(
        reader_options,
        self.use_gpu,
        trainer.minibatch_size,
        self.normalization_data_map,
        self.reward_options,
    )

    training_output = self.train_and_evaluate_generic(
        train_dataset,
        eval_dataset,
        trainer,
        num_epochs,
        self.use_gpu,
        preprocessor,
        cpe_evaluator,
        reader_options,
    )
    return training_output
def run_on_dataset_batches(
    self,
    run_on_batch_fn,
    dataset: Dataset,
    minibatch_size: int,
    batch_preprocessor: BatchPreprocessor,
    use_gpu: bool,
    reader_options: ReaderOptions,
    dataset_size: Optional[int] = None,
) -> torch.utils.data.DataLoader:
    """Run ``run_on_batch_fn`` on every batch of ``dataset``.

    run_on_batch_fn is a function f that expects batches.

    Args:
        run_on_batch_fn: Callable invoked once per batch.
        dataset: Dataset to iterate over.
        minibatch_size: Batch size passed to the data fetcher's dataloader.
        batch_preprocessor: Preprocessor passed to the dataloader.
        use_gpu: Forwarded to the dataloader and the post-dataloader
            preprocessor.
        reader_options: Logged, then forwarded to the data fetcher.
        dataset_size: Row count; looked up through the data fetcher when
            ``None``.

    Returns:
        The dataloader that was iterated. Its ``destroy_session`` (when it has
        one) has already been called by the time this method returns.

    Raises:
        AssertionError: If the dataset size is ``None`` or not positive.
    """
    logger.info(f"{reader_options}")
    if dataset_size is None:
        dataset_size = self.data_fetcher.get_table_row_count(dataset)
    assert dataset_size is not None
    assert dataset_size > 0, f"{dataset_size} is expected to be positive"

    @contextmanager
    def cleanup_dataloader_session(data_loader):
        # Guarantee the session is torn down even if a batch callback raises.
        try:
            yield data_loader
        finally:
            logger.info("Closing data loader")
            if hasattr(data_loader, "destroy_session"):
                logger.info("Closing DistributedDataLoader")
                data_loader.destroy_session()

    _dataloader = self.data_fetcher.get_dataloader(
        dataset=dataset,
        batch_size=minibatch_size,
        batch_preprocessor=batch_preprocessor,
        use_gpu=use_gpu,
        reader_options=reader_options,
    )
    with cleanup_dataloader_session(_dataloader) as dataloader:
        post_dataloader_preprocessor = (
            self.data_fetcher.get_post_dataloader_preprocessor(
                reader_options=reader_options, use_gpu=use_gpu
            )
        )
        dataloader_wrapper = DataLoaderWrapper(
            dataloader=dataloader,
            dataloader_size=dataset_size,
            post_dataloader_preprocessor=post_dataloader_preprocessor,
        )
        for batch in dataloader_wrapper:
            run_on_batch_fn(batch)
    return dataloader
def train_and_evaluate_generic(
self,
train_dataset: Dataset,
eval_dataset: Optional[Dataset],
trainer: Trainer,
num_epochs: int,
use_gpu: bool,
batch_preprocessor: BatchPreprocessor,
evaluator: Optional[Evaluator],
reader_options: ReaderOptions,
sort_eval_data: bool = True,
) -> RLTrainingOutput:
"""
Train for num_epochs epochs, optionally evaluating, and return the report.

Evaluation happens only when both eval_dataset and evaluator are given.
With sort_eval_data the eval data is gathered through the data fetcher and
evaluated in one shot; otherwise the evaluator is run batch by batch.

Returns the value produced by trainer.reporter.publish().
Raises AssertionError when num_epochs is not positive.
"""
logger.info(f"{reader_options}")
assert num_epochs > 0, f"Epoch should be positive, got {num_epochs}"
# Row counts are looked up once, up front, and reused for every epoch.
train_dataset_size = self.data_fetcher.get_table_row_count(train_dataset)
# The eval row count is only needed by the batch-wise evaluation path.
if eval_dataset is not None and not sort_eval_data:
eval_dataset_size = self.data_fetcher.get_table_row_count(eval_dataset)
for epoch in range(num_epochs):
SummaryWriterContext._reset_globals()
logger.info(f"Starting training epoch {epoch}.")
# Training pass: trainer.train is called on each batch.
data_loader = self.run_on_dataset_batches(
run_on_batch_fn=trainer.train,
dataset=train_dataset,
minibatch_size=trainer.minibatch_size,
batch_preprocessor=batch_preprocessor,
use_gpu=use_gpu,
reader_options=reader_options,
dataset_size=train_dataset_size,
)
if eval_dataset is not None and evaluator is not None:
if sort_eval_data:
logger.info(
f"Starting evaluation epoch {epoch} by sorting and one shot"
)
eval_data = self.data_fetcher.gather_and_sort_eval_data(
trainer=trainer,
eval_dataset=eval_dataset,
batch_preprocessor=batch_preprocessor,
use_gpu=use_gpu,
reader_options=reader_options,
)
evaluator.evaluate_one_shot(eval_data)
evaluator.finish()
else:
logger.info(
f"Starting evaluation epoch {epoch} by running on batches"
)
# NOTE: this rebinds data_loader, so the shutdown check below sees the
# eval loader rather than the training loader on this path.
data_loader = self.run_on_dataset_batches(
run_on_batch_fn=evaluator.evaluate,
dataset=eval_dataset,
minibatch_size=trainer.minibatch_size,
batch_preprocessor=batch_preprocessor,
use_gpu=use_gpu,
reader_options=reader_options,
dataset_size=eval_dataset_size,
)
evaluator.finish()
# Close out reporting and publish the report that is returned to the caller.
trainer.reporter.finish_epoch()
report = trainer.reporter.publish()
# Not every loader exposes shutdown(); call it only when present.
if hasattr(data_loader, "shutdown"):
data_loader.shutdown()
return report
-39
View File
@@ -1,39 +0,0 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import logging
import random
from typing import Dict, Optional
from reagent.core.types import RewardOptions
from reagent.data_fetchers.oss_data_fetcher import OssDataFetcher
from reagent.model_managers.model_manager import ModelManager
from reagent.parameters import NormalizationData
from reagent.runners.batch_runner import BatchRunner
logger = logging.getLogger(__name__)
class OssBatchRunner(BatchRunner):
    """Batch runner that uses ``OssDataFetcher`` and a random workflow id."""

    def __init__(
        self,
        use_gpu: bool,
        model_manager: ModelManager,
        reward_options: RewardOptions,
        normalization_data_map: Dict[str, NormalizationData],
        warmstart_path: Optional[str] = None,
    ):
        """Build the runner with a fresh ``OssDataFetcher``.

        All arguments are forwarded positionally to ``BatchRunner.__init__``,
        with the data fetcher inserted after ``model_manager``.
        """
        data_fetcher = OssDataFetcher()
        super().__init__(
            use_gpu,
            model_manager,
            data_fetcher,
            reward_options,
            normalization_data_map,
            warmstart_path,
        )
        # Generate a random workflow id for this batch runner
        lowest_id, highest_id = 1000, 10000000
        self.workflow_id = random.randint(lowest_id, highest_id)

    def get_workflow_id(self) -> int:
        """Return the workflow id chosen at construction time."""
        return self.workflow_id
+49
View File
@@ -0,0 +1,49 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import unittest
from reagent.core.observers import ValueListObserver
from reagent.core.tracker import observable
class TestObservable(unittest.TestCase):
    """Tests for the ``observable`` class decorator."""

    def test_observable(self):
        """Decorated classes stay usable and notify each observer once per update."""

        @observable(td_loss=float, str_val=str)
        class DummyClass:
            def __init__(self, a, b, c=10):
                super().__init__()
                self.a = a
                self.b = b
                self.c = c

            def do_something(self, i):
                self.notify_observers(td_loss=i, str_val="not_used")

        instance = DummyClass(1, 2)
        self.assertIsInstance(instance, DummyClass)
        self.assertEqual(instance.a, 1)
        self.assertEqual(instance.b, 2)
        self.assertEqual(instance.c, 10)

        observers = [ValueListObserver("td_loss") for _i in range(3)]
        instance.add_observers(observers)
        # Adding twice should not result in double update
        instance.add_observer(observers[0])

        for i in range(10):
            instance.do_something(float(i))

        expected_values = [float(i) for i in range(10)]
        for observer in observers:
            self.assertEqual(observer.values, expected_values)

    def test_no_observable_values(self):
        """Decorating without any observable values must raise AssertionError."""
        # assertRaises fails the test when nothing is raised; a bare
        # try/except-pass would succeed either way and verify nothing.
        with self.assertRaises(AssertionError):

            @observable()
            class NoObservableValues:
                pass
@@ -8,7 +8,7 @@ from typing import Optional
import numpy as np
import torch
import torch.nn as nn
from reagent.core import types as rlt
from reagent import types as rlt
from reagent.evaluation.doubly_robust_estimator import DoublyRobustEstimator
from reagent.evaluation.evaluation_data_page import EvaluationDataPage
from reagent.evaluation.ope_adapter import OPEstimatorAdapter
+162 -2
View File
@@ -1,11 +1,15 @@
import logging
import random
import unittest
import numpy as np
import torch
from reagent.core import types as rlt
from reagent import types as rlt
from reagent.evaluation.evaluation_data_page import EvaluationDataPage
from reagent.evaluation.ope_adapter import OPEstimatorAdapter
from reagent.evaluation.ope_adapter import (
OPEstimatorAdapter,
SequentialOPEstimatorAdapter,
)
from reagent.ope.estimators.contextual_bandits_estimators import (
DMEstimator,
DoublyRobustEstimator,
@@ -13,6 +17,20 @@ from reagent.ope.estimators.contextual_bandits_estimators import (
SwitchDREstimator,
SwitchEstimator,
)
from reagent.ope.estimators.sequential_estimators import (
DoublyRobustEstimator as SeqDREstimator,
EpsilonGreedyRLPolicy,
RandomRLPolicy,
RLEstimatorInput,
)
from reagent.ope.estimators.types import Action, ActionSpace
from reagent.ope.test.envs import PolicyLogGenerator
from reagent.ope.test.gridworld import GridWorld, NoiseGridWorldModel
from reagent.ope.trainers.rl_tabular_trainers import (
DPTrainer,
DPValueFunction,
TabularPolicy,
)
from reagent.test.evaluation.test_evaluation_data_page import (
FakeSeq2SlateRewardNetwork,
FakeSeq2SlateTransformerNet,
@@ -22,6 +40,56 @@ from reagent.test.evaluation.test_evaluation_data_page import (
logger = logging.getLogger(__name__)
def rlestimator_input_to_edp(
    input: RLEstimatorInput, num_actions: int
) -> EvaluationDataPage:
    """Flatten the logged MDPs of ``input`` into an ``EvaluationDataPage``.

    Every transition of every MDP in ``input.log`` becomes one row. For each
    row the logged propensity, the logged reward, a one-hot mask of the logged
    action, and the target policy's propensities and the value function's
    values for actions ``0 .. num_actions - 1`` are recorded.

    Args:
        input: Estimator input providing ``log``, ``target_policy`` and
            ``value_function``.
        num_actions: Number of discrete actions to enumerate per row.

    Returns:
        An ``EvaluationDataPage`` whose ``sequence_number``, ``model_rewards``
        and ``model_rewards_for_logged_action`` are empty tensors.

    Raises:
        AssertionError: If a transition has no action or last state, or if
            ``input.value_function`` is ``None``.
    """
    mdp_ids = []
    logged_propensities = []
    logged_rewards = []
    action_mask = []
    model_propensities = []
    model_values = []

    for episode in input.log:
        # The id is the number of rows collected before this MDP started.
        episode_id = len(mdp_ids)
        for transition in episode:
            mdp_ids.append(episode_id)
            logged_propensities.append(transition.action_prob)
            logged_rewards.append(transition.reward)

            assert transition.action is not None
            one_hot_row = [
                1 if idx == transition.action.value else 0
                for idx in range(num_actions)
            ]
            action_mask.append(one_hot_row)

            assert transition.last_state is not None
            propensity_row = [
                input.target_policy(transition.last_state)[Action(idx)]
                for idx in range(num_actions)
            ]
            model_propensities.append(propensity_row)

            assert input.value_function is not None
            value_row = [
                input.value_function(transition.last_state, Action(idx))
                for idx in range(num_actions)
            ]
            model_values.append(value_row)

    # Per-row scalars are stored as column vectors of shape (rows, 1).
    mdp_id_tensor = torch.tensor(mdp_ids).reshape(len(mdp_ids), 1)
    propensity_tensor = torch.tensor(logged_propensities).reshape(
        (len(logged_propensities), 1)
    )
    reward_tensor = torch.tensor(logged_rewards).reshape((len(logged_rewards), 1))

    return EvaluationDataPage(
        mdp_id=mdp_id_tensor,
        logged_propensities=propensity_tensor,
        logged_rewards=reward_tensor,
        action_mask=torch.tensor(action_mask),
        model_propensities=torch.tensor(model_propensities),
        model_values=torch.tensor(model_values),
        sequence_number=torch.tensor([]),
        model_rewards=torch.tensor([]),
        model_rewards_for_logged_action=torch.tensor([]),
    )
class TestOPEModuleAlgs(unittest.TestCase):
GAMMA = 0.9
CPE_PASS_BAR = 1.0
@@ -30,6 +98,98 @@ class TestOPEModuleAlgs(unittest.TestCase):
NOISE_EPSILON = 0.3
EPISODES = 2
def test_gridworld_sequential_adapter(self):
    """
    Create a gridworld environment, logging policy, and target policy
    Evaluates target policy using the direct OPE sequential doubly robust estimator,
    then transforms the log into an evaluation data page which is passed to the ope adapter.

    This test is meant to verify the adaptation of EDPs into RLEstimatorInputs as employed
    by ReAgent since ReAgent provides EDPs to Evaluators. Going from EDP -> RLEstimatorInput
    is more involved than RLEstimatorInput -> EDP since the EDP does not store the state
    at each timestep in each MDP, only the corresponding logged outputs & model outputs.
    Thus, the adapter must do some tricks to represent these timesteps as states so the
    ope module can extract the correct outputs.

    Note that there is some randomness in the model outputs since the model is purposefully
    noisy. However, the same target policy is being evaluated on the same logged walks through
    the gridworld, so the two results should be close in value (within 1).
    """
    # Seed every RNG in use so the noisy model and the logged walks are reproducible.
    random.seed(0)
    np.random.seed(0)
    torch.random.manual_seed(0)

    device = torch.device("cuda") if torch.cuda.is_available() else None

    gridworld = GridWorld.from_grid(
        [
            ["s", "0", "0", "0", "0"],
            ["0", "0", "0", "W", "0"],
            ["0", "0", "0", "0", "0"],
            ["0", "W", "0", "0", "0"],
            ["0", "0", "0", "0", "g"],
        ],
        max_horizon=TestOPEModuleAlgs.MAX_HORIZON,
    )

    action_space = ActionSpace(4)
    opt_policy = TabularPolicy(action_space)
    trainer = DPTrainer(gridworld, opt_policy)
    # The returned value is not needed: the value function used below is
    # rebuilt from the noisy model. The call itself is kept unchanged.
    trainer.train(gamma=TestOPEModuleAlgs.GAMMA)

    behavior_policy = RandomRLPolicy(action_space)
    target_policy = EpsilonGreedyRLPolicy(
        opt_policy, TestOPEModuleAlgs.NOISE_EPSILON
    )
    model = NoiseGridWorldModel(
        gridworld,
        action_space,
        epsilon=TestOPEModuleAlgs.NOISE_EPSILON,
        max_horizon=TestOPEModuleAlgs.MAX_HORIZON,
    )
    value_func = DPValueFunction(target_policy, model, TestOPEModuleAlgs.GAMMA)
    ground_truth = DPValueFunction(
        target_policy, gridworld, TestOPEModuleAlgs.GAMMA
    )

    log = []
    log_generator = PolicyLogGenerator(gridworld, behavior_policy)
    num_episodes = TestOPEModuleAlgs.EPISODES
    for state in gridworld.states:
        for _ in range(num_episodes):
            log.append(log_generator.generate_log(state))

    estimator_input = RLEstimatorInput(
        gamma=TestOPEModuleAlgs.GAMMA,
        log=log,
        target_policy=target_policy,
        value_function=value_func,
        ground_truth=ground_truth,
    )

    edp = rlestimator_input_to_edp(estimator_input, len(model.action_space))

    dr_estimator = SeqDREstimator(
        weight_clamper=None, weighted=False, device=device
    )

    module_results = SequentialOPEstimatorAdapter.estimator_results_to_cpe_estimate(
        dr_estimator.evaluate(estimator_input)
    )
    adapter_results = SequentialOPEstimatorAdapter(
        dr_estimator, TestOPEModuleAlgs.GAMMA, device=device
    ).estimate(edp)

    # Failure messages go through msg=; a trailing ", f'...'" after the call
    # only builds a discarded tuple and the message is never shown.
    self.assertAlmostEqual(
        adapter_results.raw,
        module_results.raw,
        delta=TestOPEModuleAlgs.CPE_PASS_BAR,
        msg=(
            "OPE adapter results differed too much from underlying module "
            f"(Diff: {abs(adapter_results.raw - module_results.raw)} "
            f"> {TestOPEModuleAlgs.CPE_PASS_BAR})"
        ),
    )
    self.assertLess(
        adapter_results.raw,
        TestOPEModuleAlgs.CPE_MAX_VALUE,
        msg=(
            "OPE adapter results are too large "
            f"({adapter_results.raw} > {TestOPEModuleAlgs.CPE_MAX_VALUE})"
        ),
    )
def test_seq2slate_eval_data_page(self):
"""
Create 3 slate ranking logs and evaluate using Direct Method, Inverse
+1 -1
View File
@@ -8,7 +8,7 @@ from typing import Any
import torch
import torch.nn as nn
from reagent.core import types as rlt
from reagent import types as rlt
from reagent.models.base import ModelBase
from reagent.test.models.test_utils import check_save_load
+1 -1
View File
@@ -7,7 +7,7 @@ import unittest
import numpy.testing as npt
import torch
import torch.nn.init as init
from reagent.core import types as rlt
from reagent import types as rlt
from reagent.models.bcq import BatchConstrainedDQN
from reagent.models.dqn import FullyConnectedDQN
from reagent.models.fully_connected_network import FullyConnectedNetwork
@@ -43,9 +43,7 @@ class TestNoSoftUpdteEmbedding(unittest.TestCase):
self.assertEqual(1, len(params))
param = params[0].detach().numpy()
trainer = RLTrainer(
rl_parameters=RLParameters(), minibatch_size=1024, use_gpu=False
)
trainer = RLTrainer(rl_parameters=RLParameters(), use_gpu=False)
trainer._soft_update(model, target_model, 0.1)
target_params = list(target_model.parameters())

Some files were not shown because too many files have changed in this diff Show More