mirror of
https://github.com/facebookresearch/ReAgent.git
synced 2026-05-17 12:40:39 +00:00
Back out recent refactor
Summary: Need more tests before landing the refactor diffs: D22702504 (https://github.com/facebookresearch/ReAgent/commit/1b470c489d19c33beab88b8ea2e79843d4d31f28), D23123762 (https://github.com/facebookresearch/ReAgent/commit/76829287265bc39f879f3bc1d946a1374c5e1141), D23124179 (https://github.com/facebookresearch/ReAgent/commit/b28f84aa013be00194508f52498160592cb37e9d), D23219012 (https://github.com/facebookresearch/ReAgent/commit/e404c5772ea4118105c2eb136ca96ad5ca8e01db) Back out to a version based on D23155753. Check our team diff history: https://fburl.com/diffs/ppsgazgj Reviewed By: kittipatv Differential Revision: D23270626 fbshipit-source-id: 14653066bb3924a987a54650a51241895b321c8e
This commit is contained in:
committed by
Facebook GitHub Bot
parent
e404c5772e
commit
0d294b11e5
@@ -64,7 +64,7 @@ ml.rl.training.imitator\_training module
|
||||
ml.rl.training.loss\_reporter module
|
||||
------------------------------------
|
||||
|
||||
.. automodule:: ml.rl.training.rl_reporter
|
||||
.. automodule:: ml.rl.training.loss_reporter
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
@@ -0,0 +1,2 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
+19
-125
@@ -3,100 +3,21 @@
|
||||
|
||||
import logging
|
||||
from collections import deque
|
||||
from typing import Any, Callable, Deque, Dict, List, Optional
|
||||
from typing import Callable, Deque, Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from reagent.core.tracker import Aggregator
|
||||
from reagent.tensorboardX import SummaryWriterContext
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Aggregator:
|
||||
def __init__(self, key: str, interval: Optional[int] = None):
|
||||
super().__init__()
|
||||
self.key = key
|
||||
self.iteration = 0
|
||||
self.interval = interval
|
||||
self.aggregate_epoch = interval is None
|
||||
self.intermediate_values: List[Any] = []
|
||||
|
||||
def update(self, key: str, value):
|
||||
self.intermediate_values.append(value)
|
||||
self.iteration += 1
|
||||
# pyre-fixme[6]: Expected `int` for 1st param but got `Optional[int]`.
|
||||
if self.interval and self.iteration % self.interval == 0:
|
||||
logger.info(
|
||||
f"Interval Agg. Update: {self.key}; iteration {self.iteration}; "
|
||||
f"aggregator: {self.__class__.__name__}"
|
||||
)
|
||||
self(self.key, self.intermediate_values)
|
||||
self.intermediate_values = []
|
||||
|
||||
def finish_epoch(self):
|
||||
# We need to reset iteration here to avoid aggregating on the same data multiple
|
||||
# times
|
||||
logger.info(
|
||||
f"Epoch finished. Flushing: {self.key}; "
|
||||
f"aggregator: {self.__class__.__name__}; points: {len(self.intermediate_values)}"
|
||||
)
|
||||
self.iteration = 0
|
||||
if self.aggregate_epoch:
|
||||
self(self.key, self.intermediate_values)
|
||||
self.intermediate_values = []
|
||||
|
||||
def __call__(self, key: str, values):
|
||||
assert key == self.key, f"Got {key}; expected {self.key}"
|
||||
self.aggregate(values)
|
||||
|
||||
def aggregate(self, intermediate_values):
|
||||
pass
|
||||
|
||||
def get_recent(self, count):
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_all(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class AppendAggregator(Aggregator):
    """Aggregator that keeps every value it is given in one flat list.

    Each aggregation step extends ``self.values`` with the batch of
    intermediate values it receives; nothing is reduced or dropped.
    """

    def __init__(self, key: str, interval: Optional[int] = None):
        super().__init__(key, interval)
        # Every aggregated value so far, in arrival order.
        self.values = []

    def __call__(self, key: str, values):
        """Aggregate ``values`` after checking they were reported under our key."""
        assert key == self.key, f"Got {key}; expected {self.key}"
        self.aggregate(values)

    def aggregate(self, intermediate_values):
        """Append the whole batch of intermediate values to ``self.values``."""
        self.values.extend(intermediate_values)

    def get_recent(self, count):
        """Return the last ``count`` stored values (fewer if not enough exist).

        A non-positive ``count`` yields an empty list.  Without this guard,
        ``count == 0`` would slice as ``values[-0:]`` -- i.e. ``values[0:]`` --
        and hand back every stored value instead of none.
        """
        if count <= 0:
            return []
        # Slicing an empty list already gives [], so no separate empty check.
        return self.values[-count:]

    def get_all(self):
        """Return the list of all stored values (the live list, not a copy)."""
        return self.values
|
||||
|
||||
|
||||
class TensorAggregator(Aggregator):
|
||||
def __call__(self, key: str, values, interval: Optional[int] = None):
|
||||
if len(values) == 0:
|
||||
return super().__call__(key, torch.tensor([0.0]))
|
||||
def __call__(self, key: str, values):
|
||||
# Ensure that tensor is on cpu before aggregation.
|
||||
reshaped_values = []
|
||||
for value in values:
|
||||
if isinstance(value, list):
|
||||
reshaped_values.append(torch.tensor(value))
|
||||
elif not hasattr(value, "size"):
|
||||
reshaped_values.append(torch.tensor(value).unsqueeze(0))
|
||||
elif len(value.size()) == 0:
|
||||
reshaped_values.append(value.unsqueeze(0))
|
||||
else:
|
||||
reshaped_values.append(value)
|
||||
values = torch.cat(reshaped_values, dim=0).cpu()
|
||||
values = torch.cat(values, dim=0).cpu()
|
||||
return super().__call__(key, values)
|
||||
|
||||
|
||||
@@ -114,8 +35,8 @@ def _log_histogram_and_mean(log_key, val):
|
||||
|
||||
|
||||
class TensorBoardHistogramAndMeanAggregator(TensorAggregator):
|
||||
def __init__(self, key: str, log_key: str, interval: Optional[int] = None):
|
||||
super().__init__(key, interval)
|
||||
def __init__(self, key: str, log_key: str):
|
||||
super().__init__(key)
|
||||
self.log_key = log_key
|
||||
|
||||
def aggregate(self, values):
|
||||
@@ -133,9 +54,8 @@ class TensorBoardActionHistogramAndMeanAggregator(TensorAggregator):
|
||||
title: str,
|
||||
actions: List[str],
|
||||
log_key_prefix: Optional[str] = None,
|
||||
interval: Optional[int] = None,
|
||||
):
|
||||
super().__init__(key, interval)
|
||||
super().__init__(key)
|
||||
self.log_key_prefix = log_key_prefix or f"{category}/{title}"
|
||||
self.actions = actions
|
||||
SummaryWriterContext.add_custom_scalars_multilinechart(
|
||||
@@ -157,10 +77,8 @@ class TensorBoardActionHistogramAndMeanAggregator(TensorAggregator):
|
||||
|
||||
|
||||
class TensorBoardActionCountAggregator(TensorAggregator):
|
||||
def __init__(
|
||||
self, key: str, title: str, actions: List[str], interval: Optional[int] = None
|
||||
):
|
||||
super().__init__(key, interval)
|
||||
def __init__(self, key: str, title: str, actions: List[str]):
|
||||
super().__init__(key)
|
||||
self.log_key = f"actions/{title}"
|
||||
self.actions = actions
|
||||
SummaryWriterContext.add_custom_scalars_multilinechart(
|
||||
@@ -177,8 +95,8 @@ class TensorBoardActionCountAggregator(TensorAggregator):
|
||||
|
||||
|
||||
class MeanAggregator(TensorAggregator):
|
||||
def __init__(self, key: str, interval: Optional[int] = None):
|
||||
super().__init__(key, interval)
|
||||
def __init__(self, key: str):
|
||||
super().__init__(key)
|
||||
self.values: List[float] = []
|
||||
|
||||
def aggregate(self, values):
|
||||
@@ -186,14 +104,6 @@ class MeanAggregator(TensorAggregator):
|
||||
logger.info(f"{self.key}: {mean}")
|
||||
self.values.append(mean)
|
||||
|
||||
def get_recent(self, count):
    """Return the last ``count`` recorded means (fewer if not enough exist).

    A non-positive ``count`` yields an empty list.  Without this guard,
    ``count == 0`` would slice as ``values[-0:]`` -- the same as
    ``values[0:]`` -- and return every recorded mean instead of none.
    """
    if count <= 0:
        return []
    # Slicing an empty list already gives [], so no separate empty check.
    return self.values[-count:]
|
||||
|
||||
def get_all(self):
|
||||
return self.values
|
||||
|
||||
|
||||
class FunctionsByActionAggregator(TensorAggregator):
|
||||
"""
|
||||
@@ -234,14 +144,8 @@ class FunctionsByActionAggregator(TensorAggregator):
|
||||
}
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
key: str,
|
||||
actions: List[str],
|
||||
fns: Dict[str, Callable],
|
||||
interval: Optional[int] = None,
|
||||
):
|
||||
super().__init__(key, interval)
|
||||
def __init__(self, key: str, actions: List[str], fns: Dict[str, Callable]):
|
||||
super().__init__(key)
|
||||
self.actions = actions
|
||||
self.values: Dict[str, Dict[str, List[float]]] = {
|
||||
fn: {action: [] for action in self.actions} for fn in fns
|
||||
@@ -268,8 +172,8 @@ class ActionCountAggregator(TensorAggregator):
|
||||
`len(actions) - 1`. The input is assumed to contain action index.
|
||||
"""
|
||||
|
||||
def __init__(self, key: str, actions: List[str], interval: Optional[int] = None):
|
||||
super().__init__(key, interval)
|
||||
def __init__(self, key: str, actions: List[str]):
|
||||
super().__init__(key)
|
||||
self.actions = actions
|
||||
self.values: Dict[str, List[int]] = {action: [] for action in actions}
|
||||
|
||||
@@ -286,7 +190,7 @@ class ActionCountAggregator(TensorAggregator):
|
||||
"""
|
||||
totals = np.array([sum(counts) for counts in zip(*self.values.values())])
|
||||
return {
|
||||
action: (np.array(counts) / np.clip(totals, 1, None)).tolist()
|
||||
action: (np.array(counts) / totals).tolist()
|
||||
for action, counts in self.values.items()
|
||||
}
|
||||
|
||||
@@ -294,7 +198,7 @@ class ActionCountAggregator(TensorAggregator):
|
||||
"""
|
||||
Returns the cumulative distributions in each aggregating step
|
||||
"""
|
||||
totals = max(1, sum(sum(counts) for counts in zip(*self.values.values())))
|
||||
totals = sum(sum(counts) for counts in zip(*self.values.values()))
|
||||
return {action: sum(counts) / totals for action, counts in self.values.items()}
|
||||
|
||||
|
||||
@@ -302,20 +206,10 @@ _RECENT_DEFAULT_SIZE = int(1e6)
|
||||
|
||||
|
||||
class RecentValuesAggregator(TensorAggregator):
|
||||
def __init__(
|
||||
self, key: str, size: int = _RECENT_DEFAULT_SIZE, interval: Optional[int] = None
|
||||
):
|
||||
super().__init__(key, interval)
|
||||
def __init__(self, key: str, size: int = _RECENT_DEFAULT_SIZE):
|
||||
super().__init__(key)
|
||||
self.values: Deque[float] = deque(maxlen=size)
|
||||
|
||||
def aggregate(self, values):
|
||||
flattened = torch.flatten(values).tolist()
|
||||
self.values.extend(flattened)
|
||||
|
||||
def get_recent(self, count):
    """Return the ``count`` most recent values, oldest first, as a list.

    ``self.values`` is a ``collections.deque``, which does not support
    slice indexing (``deque[-count:]`` raises ``TypeError``).  The tail is
    therefore collected by walking the deque from its right end, which
    touches only ``count`` items.  A non-positive ``count`` or an empty
    deque yields ``[]``.
    """
    from itertools import islice

    if count <= 0 or not self.values:
        return []
    # Take the newest `count` items (newest first), then restore
    # chronological order.
    newest_first = list(islice(reversed(self.values), count))
    newest_first.reverse()
    return newest_first
|
||||
|
||||
def get_all(self):
|
||||
return self.values
|
||||
|
||||
@@ -1,29 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import functools
|
||||
import importlib
|
||||
|
||||
|
||||
# ``import importlib`` on its own does not guarantee that the
# ``importlib.util`` submodule has been loaded, so import it explicitly
# before probing for the optional ``fblearner`` package.
import importlib.util

if importlib.util.find_spec("fblearner") is not None:
    import fblearner.flow.api as flow

    class AsyncWrapper:
        """Decorator that wraps a function with ``flow.typed()`` and then
        ``flow.flow_async(**kwargs)``."""

        def __init__(self, **kwargs):
            # Keyword arguments are forwarded verbatim to flow.flow_async.
            self.async_wrapper = flow.flow_async(**kwargs)
            self.type_wrapper = flow.typed()

        def __call__(self, func):
            # Apply the type wrapper first, then the async wrapper around it.
            return self.async_wrapper(self.type_wrapper(func))


else:

    def AsyncWrapper(**outer_kwargs):
        """Fallback used when ``fblearner`` is not installed.

        Accepts (and ignores) the same keyword arguments and returns a
        decorator whose wrapped function simply calls the original
        synchronously, preserving its metadata via ``functools.wraps``.
        """

        def async_wrapper_internal(func):
            @functools.wraps(func)
            def async_wrapper_repeat(*args, **kwargs):
                return func(*args, **kwargs)

            return async_wrapper_repeat

        return async_wrapper_internal
|
||||
@@ -0,0 +1,103 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, Iterable, List, Optional
|
||||
|
||||
from reagent.core.tracker import Aggregator, Observer
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CompositeObserver(Observer):
    """
    Observer that fans each update out to the child observers registered for
    the update's key
    """

    def __init__(self, observers: Iterable[Observer]):
        # key -> child observers interested in that key, in registration order
        self.observers: Dict[str, List[Observer]] = {}
        routing = self.observers
        for child in observers:
            for key in child.get_observing_keys():
                if key not in routing:
                    routing[key] = []
                routing[key].append(child)
        # The composite itself observes the union of its children's keys.
        super().__init__(list(routing))

    def update(self, key: str, value):
        """Forward ``(key, value)`` to every child registered under ``key``."""
        children = self.observers[key]
        for child in children:
            child.update(key, value)
|
||||
|
||||
|
||||
class EpochEndObserver(Observer):
    """
    Invokes a callback with the reported value (the epoch #) whenever its key
    is updated
    """

    def __init__(self, callback, key: str = "epoch_end"):
        # Observe exactly one key; by default the "epoch_end" signal.
        watched_keys = [key]
        super().__init__(observing_keys=watched_keys)
        self.callback = callback

    def update(self, key: str, value):
        """Pass ``value`` to the callback; ``key`` itself is not inspected."""
        on_epoch_end = self.callback
        on_epoch_end(value)
|
||||
|
||||
|
||||
class ValueListObserver(Observer):
    """
    Simple observer that collects every reported value into a list
    """

    def __init__(self, observing_key: str):
        watched_keys = [observing_key]
        super().__init__(observing_keys=watched_keys)
        self.observing_key = observing_key
        # Values received so far, in arrival order.
        self.values: List[Any] = []

    def update(self, key: str, value):
        """Record ``value``; ``key`` is not inspected."""
        collected = self.values
        collected.append(value)

    def reset(self):
        """Start over with a fresh, empty list (the old list is not mutated)."""
        self.values = []
|
||||
|
||||
|
||||
class IntervalAggregatingObserver(Observer):
    """
    Buffers values reported under the aggregator's key and hands the buffer to
    the aggregator every ``interval`` updates and/or when flushed
    """

    def __init__(
        self,
        interval: Optional[int],
        aggregator: Aggregator,
        observe_epoch_end: bool = True,
    ):
        self.key = aggregator.key
        # Optionally listen for "epoch_end" in addition to the aggregator's key.
        if observe_epoch_end:
            watched_keys = ["epoch_end", self.key]
        else:
            watched_keys = [self.key]
        super().__init__(observing_keys=watched_keys)
        self.iteration = 0
        self.interval = interval
        self.intermediate_values: List[Any] = []
        self.aggregator = aggregator

    def update(self, key: str, value):
        """Buffer ``value``, or flush the buffer when ``key`` is "epoch_end"."""
        if key == "epoch_end":
            self.flush()
            return

        self.intermediate_values.append(value)
        self.iteration += 1

        # A falsy interval (None or 0) disables interval-based aggregation.
        # pyre-fixme[6]: Expected `int` for 1st param but got `Optional[int]`.
        interval_reached = bool(self.interval) and self.iteration % self.interval == 0
        if not interval_reached:
            return

        logger.info(
            f"Interval Agg. Update: {self.key}; iteration {self.iteration}; "
            f"aggregator: {self.aggregator.__class__.__name__}"
        )
        self.aggregator(self.key, self.intermediate_values)
        self.intermediate_values = []

    def flush(self):
        """Aggregate whatever is buffered and restart the iteration count."""
        logger.info(
            f"Interval Agg. Flushing: {self.key}; iteration: {self.iteration}; "
            f"aggregator: {self.aggregator.__class__.__name__}; points: {len(self.intermediate_values)}"
        )
        # We need to reset iteration here to avoid aggregating on the same data
        # multiple times
        self.iteration = 0
        if not self.intermediate_values:
            return
        self.aggregator(self.key, self.intermediate_values)
        self.intermediate_values = []
|
||||
@@ -16,7 +16,7 @@ class RegistryMeta(abc.ABCMeta):
|
||||
def __init__(cls, name, bases, attrs):
|
||||
if not hasattr(cls, "REGISTRY"):
|
||||
# Put REGISTRY on cls. This only happens once on the base class
|
||||
logger.debug("Adding REGISTRY to type {}".format(name))
|
||||
logger.info("Adding REGISTRY to type {}".format(name))
|
||||
cls.REGISTRY: Dict[str, Type] = {}
|
||||
cls.REGISTRY_NAME = name
|
||||
cls.REGISTRY_FROZEN = False
|
||||
@@ -28,19 +28,12 @@ class RegistryMeta(abc.ABCMeta):
|
||||
|
||||
if not cls.__abstractmethods__ and name != cls.REGISTRY_NAME:
|
||||
# Only register fully-defined classes
|
||||
logger.info(f"Registering {name} to {cls.REGISTRY_NAME}")
|
||||
if hasattr(cls, "__registry_name__"):
|
||||
registry_name = cls.__registry_name__
|
||||
logger.info(
|
||||
f"Registering {name} with alias {registry_name} to {cls.REGISTRY_NAME}"
|
||||
)
|
||||
logger.info(f"Using {registry_name} instead of {name}")
|
||||
name = registry_name
|
||||
else:
|
||||
logger.info(f"Registering {name} to {cls.REGISTRY_NAME}")
|
||||
# assert name not in cls.REGISTRY
|
||||
# TODO: Combine FB and OSS model managers and then bring back this assert.
|
||||
# For now this works because FB model managers inherit from their OSS counterparts
|
||||
if name in cls.REGISTRY:
|
||||
logger.warning(f"Overwriting open source {name} with internal version")
|
||||
assert name not in cls.REGISTRY
|
||||
cls.REGISTRY[name] = cls
|
||||
else:
|
||||
logger.info(
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
from reagent.core.dataclasses import dataclass
|
||||
from reagent.reporting.result_registries import PublishingResult, ValidationResult
|
||||
from reagent.workflow.result_registries import PublishingResult, ValidationResult
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from reagent.core.union import (
|
||||
PublishingResult__Union,
|
||||
TrainingReport__Union,
|
||||
ValidationResult__Union,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RLTrainingOutput:
|
||||
validation_result: Optional[ValidationResult__Union] = None
|
||||
publishing_result: Optional[PublishingResult__Union] = None
|
||||
training_report: Optional[TrainingReport__Union] = None
|
||||
local_output_path: Optional[str] = None
|
||||
@@ -0,0 +1,117 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
import functools
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Observer:
    """
    Base class for observers: holds the list of keys the observer wants to be
    notified about; ``update`` is a no-op meant to be overridden
    """

    def __init__(self, observing_keys: List[str]):
        super().__init__()
        # Must be a real list (not a tuple or other iterable).
        assert isinstance(observing_keys, list)
        self.observing_keys = observing_keys

    def get_observing_keys(self) -> List[str]:
        """Return the list of keys passed at construction (same object)."""
        keys = self.observing_keys
        return keys

    def update(self, key: str, value):
        """Receive a value reported under ``key``; does nothing by default."""
        return None
|
||||
|
||||
|
||||
class Aggregator:
    """
    Base class for aggregators bound to a single key; calling the instance
    checks the key and delegates to ``aggregate``, which subclasses override
    """

    def __init__(self, key: str):
        super().__init__()
        self.key = key

    def __call__(self, key: str, values):
        """Aggregate ``values``; ``key`` must equal the key given at construction."""
        expected = self.key
        assert key == expected, f"Got {key}; expected {self.key}"
        self.aggregate(values)

    def aggregate(self, values):
        """Consume a batch of values; does nothing by default."""
        return None
|
||||
|
||||
|
||||
def observable(cls=None, **kwargs):  # noqa: C901
    """
    Decorator to mark a class as producing observable values. The names of the
    observable values are the names of keyword arguments. The values of keyword
    arguments are the types of the value. The only type acted upon is
    ``torch.Tensor``: values reported under such a key are converted to a
    detached tensor with at least one dimension before being dispatched.

    The decorated class gains three methods:

    * ``add_observer(observer)`` -- register ``observer`` for each of its
      observing keys this class declares; returns ``self`` for chaining.
    * ``add_observers(observers)`` -- register several observers; returns
      ``self``.
    * ``notify_observers(**values)`` -- dispatch each non-``None`` value to the
      observers registered for its key.

    Usable as ``@observable(key=type, ...)``, or called directly with the class
    as the first positional argument.
    """
    assert kwargs
    observable_value_types = kwargs

    def wrap(cls):
        # Refuse to clobber methods the class already defines.
        assert not hasattr(cls, "add_observer")
        assert not hasattr(cls, "notify_observers")

        original_init = cls.__init__

        @functools.wraps(original_init)
        def new_init(self, *args, **kwargs):
            original_init(self, *args, **kwargs)
            # These attribute names are reserved for the observer machinery.
            assert not hasattr(self, "_observable_value_types")
            assert not hasattr(self, "_observers")
            self._observable_value_types = observable_value_types
            self._observers = {v: [] for v in observable_value_types}

        cls.__init__ = new_init

        # NOTE: the two registration methods return ``self`` (fluent style), so
        # they must not be annotated ``-> None``.
        def add_observer(self, observer: "Observer"):
            observing_keys = observer.get_observing_keys()
            unknown_keys = [
                k for k in observing_keys if k not in self._observable_value_types
            ]
            if unknown_keys:
                logger.warning(f"{unknown_keys} cannot be observed in {type(self)}")
            for k in observing_keys:
                # Skip unknown keys and avoid registering the same observer twice.
                if k in self._observers and observer not in self._observers[k]:
                    self._observers[k].append(observer)
            return self

        cls.add_observer = add_observer

        def add_observers(self, observers: List["Observer"]):
            for observer in observers:
                self.add_observer(observer)
            return self

        cls.add_observers = add_observers

        def notify_observers(self, **kwargs):
            for key, value in kwargs.items():
                if value is None:
                    # Allow optional reporting
                    continue

                assert key in self._observers, f"Unknown key: {key}"

                # TODO: Create a generic framework for type conversion
                if self._observable_value_types[key] == torch.Tensor:
                    if not isinstance(value, torch.Tensor):
                        value = torch.tensor(value)
                    if len(value.shape) == 0:
                        # Promote scalars to shape (1,).
                        value = value.reshape(1)
                    value = value.detach()

                for observer in self._observers[key]:
                    observer.update(key, value)

        cls.notify_observers = notify_observers

        return cls

    if cls is None:
        return wrap

    return wrap(cls)
|
||||
+65
-708
@@ -1,26 +1,32 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
import dataclasses
|
||||
import logging
|
||||
|
||||
# The dataclasses in this file should be vanilla dataclass to have minimal overhead
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime as RecurringPeriod # noqa
|
||||
from typing import Dict, List, NamedTuple, Optional, Tuple, Union
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from reagent.base_dataclass import BaseDataClass
|
||||
from reagent.core.configuration import param_hash
|
||||
from reagent.core.dataclasses import dataclass as pydantic_dataclass
|
||||
from reagent.preprocessing.normalization_constants import (
|
||||
# Triggering registration to registries
|
||||
import reagent.core.result_types # noqa
|
||||
import reagent.workflow.training_reports # noqa
|
||||
from reagent.core.dataclasses import dataclass
|
||||
from reagent.core.fb_checker import IS_FB_ENVIRONMENT
|
||||
from reagent.core.tagged_union import TaggedUnion # noqa F401
|
||||
from reagent.models.model_feature_config_provider import ModelFeatureConfigProvider
|
||||
from reagent.preprocessing.normalization import (
|
||||
DEFAULT_MAX_QUANTILE_SIZE,
|
||||
DEFAULT_MAX_UNIQUE_ENUM,
|
||||
DEFAULT_NUM_SAMPLES,
|
||||
DEFAULT_QUANTILE_K2_THRESHOLD,
|
||||
)
|
||||
from reagent.preprocessing.types import InputColumn
|
||||
from reagent.types import BaseDataClass
|
||||
from reagent.workflow.result_registries import PublishingResult, ValidationResult
|
||||
from reagent.workflow.training_reports import TrainingReport
|
||||
|
||||
|
||||
if IS_FB_ENVIRONMENT:
|
||||
from reagent.fb.models.model_feature_config_builder import ( # noqa
|
||||
ConfigeratorModelFeatureConfigProvider,
|
||||
)
|
||||
import reagent.core.fb.fb_types # noqa
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -34,7 +40,7 @@ class OssDataset(Dataset):
|
||||
|
||||
|
||||
@dataclass
|
||||
class TableSpec(BaseDataClass):
|
||||
class TableSpec:
|
||||
table: str
|
||||
table_sample: Optional[float] = None
|
||||
eval_table_sample: Optional[float] = None
|
||||
@@ -44,11 +50,27 @@ class TableSpec(BaseDataClass):
|
||||
class RewardOptions:
|
||||
custom_reward_expression: Optional[str] = None
|
||||
metric_reward_values: Optional[Dict[str, float]] = None
|
||||
additional_reward_expression: Optional[str] = None
|
||||
|
||||
# for ranking
|
||||
# key: feature id in slate_reward column, value: linear coefficient
|
||||
slate_reward_values: Optional[Dict[str, float]] = None
|
||||
# key: feature id in item_reward column, value: linear coefficient
|
||||
item_reward_values: Optional[Dict[str, float]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReaderOptions:
|
||||
pass
|
||||
num_threads: int = 32
|
||||
skip_smaller_batches: bool = True
|
||||
num_workers: int = 0
|
||||
koski_logging_level: int = 2
|
||||
# distributed reader
|
||||
distributed_reader: bool = False
|
||||
distributed_master_mem: str = "20G"
|
||||
distributed_worker_mem: str = "20G"
|
||||
distributed_num_workers: int = 2
|
||||
gang_name: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -58,7 +80,10 @@ class OssReaderOptions(ReaderOptions):
|
||||
|
||||
@dataclass
|
||||
class ResourceOptions:
|
||||
pass
|
||||
cpu: Optional[int] = None
|
||||
# "-1" or "xxG" where "xx" is a positive integer
|
||||
memory: Optional[str] = "40g"
|
||||
gpu: int = 1
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -84,713 +109,45 @@ class PreprocessingOptions(BaseDataClass):
|
||||
set_missing_value_to_zero: Optional[bool] = False
|
||||
whitelist_features: Optional[List[int]] = None
|
||||
assert_whitelist_feature_coverage: bool = True
|
||||
variance_threshold: VarianceThreshold = VarianceThreshold()
|
||||
sequence_feature_id: Optional[int] = None
|
||||
|
||||
ignore_sanity_check_failure: bool = IGNORE_SANITY_CHECK_FAILURE
|
||||
ignore_sanity_check_task: bool = False
|
||||
variance_threshold: VarianceThreshold = VarianceThreshold()
|
||||
load_from_operator_id: Optional[int] = None
|
||||
skip_sanity_check: bool = False
|
||||
|
||||
# IdMappings are stored in manifold folder:
|
||||
# "tree/{namespace}/{tablename}/{ds}/{base_mapping_name}/{embedding_table_name}"
|
||||
base_mapping_name: str = "DefaultMappingName"
|
||||
sequence_feature_id: Optional[int] = None
|
||||
|
||||
### below here for preprocessing sparse features ###
|
||||
# If the number of occurrences of any raw features ids is lower than this, we
|
||||
# ignore those feature ids when constructing the IdMapping
|
||||
sparse_threshold: int = 0
|
||||
# IdMappings are stored in manifold folder:
|
||||
# "tree/{namespace}/{tablename}/{ds}/{base_mapping_name}/{embedding_table_name}"
|
||||
base_mapping_name: str = "DefaultMappingName"
|
||||
|
||||
|
||||
class NoDuplicatedWarningLogger:
    """Wraps a logger so that each distinct warning message is emitted once."""

    def __init__(self, logger):
        self.logger = logger
        # Messages that have already been forwarded to the wrapped logger.
        self.msg = set()

    def warning(self, msg):
        """Forward ``msg`` to the wrapped logger unless it was sent before."""
        already_sent = msg in self.msg
        if already_sent:
            return
        self.logger.warning(msg)
        self.msg.add(msg)
|
||||
@ModelFeatureConfigProvider.fill_union()
|
||||
class ModelFeatureConfigProvider__Union(TaggedUnion):
|
||||
pass
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
no_dup_logger = NoDuplicatedWarningLogger(logger)
|
||||
@PublishingResult.fill_union()
|
||||
class PublishingResult__Union(TaggedUnion):
|
||||
pass
|
||||
|
||||
|
||||
def isinstance_namedtuple(x):
    """Return True if ``x`` is a tuple that also exposes a ``_fields`` attribute."""
    if not isinstance(x, tuple):
        return False
    return hasattr(x, "_fields")
|
||||
@ValidationResult.fill_union()
|
||||
class ValidationResult__Union(TaggedUnion):
|
||||
pass
|
||||
|
||||
|
||||
@TrainingReport.fill_union()
|
||||
class RLTrainingReport(TaggedUnion):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
class TensorDataClass(BaseDataClass):
    """Dataclass base whose instances forward ``torch.Tensor`` method calls to
    their fields.

    Looking up an attribute that is not found normally but names a callable
    attribute of ``torch.Tensor`` returns a function.  Calling that function
    applies the method to every field that is a tensor or a
    ``TensorDataClass`` (recursing into dicts and tuples), leaves other fields
    untouched, and builds a new instance of the same class from the results.
    """

    def __getattr__(self, attr):
        # Never proxy dunder lookups; report them as missing, naming the
        # attribute so the failure is diagnosable.
        if attr.startswith("__") and attr.endswith("__"):
            raise AttributeError(attr)

        tensor_attr = getattr(torch.Tensor, attr, None)

        if tensor_attr is None or not callable(tensor_attr):
            logger.error(
                f"Attemping to call torch.Tensor.{attr} on "
                f"{type(self)} (instance of TensorDataClass)."
            )
            if tensor_attr is None:
                raise AttributeError(f"torch.Tensor doesn't have {attr} attribute.")
            else:
                raise RuntimeError(f"Tensor.{attr} is not callable.")

        def continuation(*args, **kwargs):
            def f(v):
                # if possible, returns v.attr(*args, **kwargs).
                # otws, return v
                if isinstance(v, (torch.Tensor, TensorDataClass)):
                    return getattr(v, attr)(*args, **kwargs)
                elif isinstance(v, dict):
                    return {kk: f(vv) for kk, vv in v.items()}
                elif isinstance(v, tuple):
                    return tuple(f(vv) for vv in v)
                return v

            return type(self)(**f(self.__dict__))

        return continuation

    def cuda(self, *args, **kwargs):
        """Return a new instance with tensor fields (and nested
        ``TensorDataClass`` fields) moved via ``.cuda(*args, **kwargs)``.

        ``non_blocking`` defaults to ``True`` unless the caller supplies it.
        Fields of any other type are carried over unchanged.
        """
        # Loop-invariant default, set once rather than per tensor field.
        kwargs.setdefault("non_blocking", True)
        cuda_tensor = {}
        for k, v in self.__dict__.items():  # noqa F402
            if isinstance(v, (torch.Tensor, TensorDataClass)):
                cuda_tensor[k] = v.cuda(*args, **kwargs)
            else:
                cuda_tensor[k] = v
        return type(self)(**cuda_tensor)
|
||||
|
||||
|
||||
# (offset, value)
|
||||
IdListFeatureValue = Tuple[torch.Tensor, torch.Tensor]
|
||||
# (offset, key, value)
|
||||
IdScoreListFeatureValue = Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
|
||||
# name -> value
|
||||
IdListFeature = Dict[str, IdListFeatureValue]
|
||||
IdScoreListFeature = Dict[str, IdScoreListFeatureValue]
|
||||
# id -> value
|
||||
ServingIdListFeature = Dict[int, IdListFeatureValue]
|
||||
ServingIdScoreListFeature = Dict[int, IdScoreListFeatureValue]
|
||||
|
||||
|
||||
#####
|
||||
# FIXME: These config types are misplaced but we need to write FBL config adapter
|
||||
# if we moved them.
|
||||
######
|
||||
|
||||
|
||||
@pydantic_dataclass
|
||||
class IdListFeatureConfig(BaseDataClass):
|
||||
name: str
|
||||
# integer feature ID
|
||||
feature_id: int
|
||||
# name of the embedding table to use
|
||||
id_mapping_name: str
|
||||
|
||||
|
||||
@pydantic_dataclass
|
||||
class IdScoreListFeatureConfig(BaseDataClass):
|
||||
name: str
|
||||
# integer feature ID
|
||||
feature_id: int
|
||||
# name of the embedding table to use
|
||||
id_mapping_name: str
|
||||
|
||||
|
||||
@pydantic_dataclass
|
||||
class FloatFeatureInfo(BaseDataClass):
|
||||
name: str
|
||||
feature_id: int
|
||||
|
||||
|
||||
@pydantic_dataclass
class IdMapping(object):
    """List of raw ids plus a lazily built reverse lookup (id -> position)."""

    __hash__ = param_hash

    ids: List[int] = field(default_factory=list)

    def __post_init_post_parse__(self):
        """
        used in preprocessing
        ids list represents mapping from idx -> value
        we want the reverse: from feature to embedding table indices
        """
        # Filled on first access of `id2index`.
        self._id2index: Dict[int, int] = {}

    @property
    def id2index(self) -> Dict[int, int]:
        """Mapping from each id to its position in ``ids``, built on demand."""
        # pyre-fixme[16]: `IdMapping` has no attribute `_id2index`.
        if self._id2index:
            return self._id2index
        positions = range(len(self.ids))
        self._id2index = dict(zip(self.ids, positions))
        return self._id2index

    @property
    def table_size(self):
        """Number of ids in the mapping."""
        size = len(self.ids)
        return size
|
||||
|
||||
|
||||
@pydantic_dataclass
class ModelFeatureConfig(BaseDataClass):
    """Feature configuration: float features, id mappings, and the id-list /
    id-score-list feature configs, with name/id lookup tables derived from
    the latter two after parsing."""

    float_feature_infos: List[FloatFeatureInfo] = field(default_factory=list)
    # table name -> id mapping
    id_mapping_config: Dict[str, IdMapping] = field(default_factory=dict)
    # id_list_feature_configs is feature_id -> list of values
    id_list_feature_configs: List[IdListFeatureConfig] = field(default_factory=list)
    # id_score_list_feature_configs is feature_id -> (keys -> values)
    id_score_list_feature_configs: List[IdScoreListFeatureConfig] = field(
        default_factory=list
    )

    def __post_init_post_parse__(self):
        sparse_configs = (
            self.id_list_feature_configs + self.id_score_list_feature_configs
        )
        if not self.only_dense:
            # sanity check for keys in mapping config
            ids = [cfg.feature_id for cfg in sparse_configs]
            names = [cfg.name for cfg in sparse_configs]
            assert len(ids) == len(set(ids)), f"duplicates in ids: {ids}"
            assert len(names) == len(set(names)), f"duplicates in names: {names}"
            assert len(ids) == len(names), f"{len(ids)} != {len(names)}"

        # Lookup tables over both kinds of sparse feature config; built in a
        # single pass, later entries overwrite earlier ones on key collision.
        id2name, name2id, id2config, name2config = {}, {}, {}, {}
        for cfg in sparse_configs:
            id2name[cfg.feature_id] = cfg.name
            name2id[cfg.name] = cfg.feature_id
            id2config[cfg.feature_id] = cfg
            name2config[cfg.name] = cfg
        self._id2name = id2name
        self._name2id = name2id
        self._id2config = id2config
        self._name2config = name2config

    @property
    def only_dense(self):
        """True when neither id-list nor id-score-list features are configured."""
        has_sparse = self.id_list_feature_configs or self.id_score_list_feature_configs
        return not has_sparse

    @property
    def id2name(self):
        """feature_id -> feature name."""
        return self._id2name

    @property
    def name2id(self):
        """feature name -> feature_id."""
        return self._name2id

    @property
    def id2config(self):
        """feature_id -> feature config object."""
        return self._id2config

    @property
    def name2config(self):
        """feature name -> feature config object."""
        return self._name2config
|
||||
|
||||
|
||||
######
|
||||
# dataclasses for internal API
|
||||
######
|
||||
|
||||
|
||||
@dataclass
|
||||
class ValuePresence(TensorDataClass):
|
||||
value: torch.Tensor
|
||||
presence: Optional[torch.Tensor]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActorOutput(TensorDataClass):
|
||||
action: torch.Tensor
|
||||
log_prob: Optional[torch.Tensor] = None
|
||||
squashed_mean: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
@dataclass
class DocList(TensorDataClass):
    """Candidate documents attached to a batch of states."""

    # the shape is (batch_size, num_candidates, num_document_features)
    float_features: torch.Tensor
    # the shapes are (batch_size, num_candidates)
    mask: torch.Tensor
    value: torch.Tensor

    def __post_init__(self):
        assert (
            len(self.float_features.shape) == 3
        ), f"Unexpected shape: {self.float_features.shape}"

    # pyre-fixme[56]: Decorator `torch.no_grad(...)` could not be called, because
    #  its type `no_grad` is not callable.
    @torch.no_grad()
    def select_slate(self, action: torch.Tensor):
        """Gather the candidates picked by `action` into a new `DocList`.

        Args:
            action: integer index tensor; its two leading dimensions are used
                as (batch_size, slate_size) and each entry indexes the
                candidate dimension of this `DocList`.

        Returns:
            A `DocList` holding the features, mask and value of the selected
            candidates.

        Raises:
            AssertionError: if any selected candidate is masked out.
        """
        # Build the row index on the same device as `action` so that both
        # index tensors used below live on one device.
        row_idx = torch.repeat_interleave(
            torch.arange(action.shape[0], device=action.device).unsqueeze(1),
            action.shape[1],
            dim=1,
        )
        mask = self.mask[row_idx, action]
        # Make sure the indices are in the right range
        assert mask.to(torch.bool).all()
        float_features = self.float_features[row_idx, action]
        value = self.value[row_idx, action]
        return DocList(float_features, mask, value)

    def as_feature_data(self):
        """Flatten the candidates into a 2D `FeatureData`.

        Returns:
            `FeatureData` whose `float_features` has the batch and candidate
            dimensions merged, i.e. shape (-1, feature_dim).
        """
        _batch_size, _slate_size, feature_dim = self.float_features.shape
        # `reshape` behaves like `view` for contiguous tensors but also
        # handles non-contiguous ones instead of raising.
        return FeatureData(self.float_features.reshape(-1, feature_dim))
|
||||
|
||||
|
||||
@dataclass
class FeatureData(TensorDataClass):
    """Container for the features describing a state or an action."""

    # For dense features, shape is (batch_size, feature_dim)
    float_features: torch.Tensor
    id_list_features: IdListFeature = dataclasses.field(default_factory=dict)
    id_score_list_features: IdScoreListFeature = dataclasses.field(default_factory=dict)
    # For sequence, shape is (stack_size, batch_size, feature_dim)
    stacked_float_features: Optional[torch.Tensor] = None
    # For ranking algos,
    candidate_docs: Optional[DocList] = None
    # Experimental: sticking this here instead of putting it in float_features
    # because a lot of places derive the shape of float_features from
    # normalization parameters.
    time_since_first: Optional[torch.Tensor] = None

    def __post_init__(self):
        """Validate the rank of `float_features`.

        A 3D tensor only triggers a warning; any rank other than 2 or 3
        raises `ValueError`.
        """

        def usage():
            # Point at the fields that actually exist on this class; the two
            # sentences are separated by a space so the message reads cleanly.
            return (
                "For sequence features, use `stacked_float_features`. "
                "For document features, use `candidate_docs`."
            )

        if self.float_features.ndim == 3:
            no_dup_logger.warning(f"`float_features` should be 2D.\n{usage()}")
        elif self.float_features.ndim != 2:
            raise ValueError(
                f"float_features should be 2D; got {self.float_features.shape}.\n{usage()}"
            )

    @property
    def has_float_features_only(self) -> bool:
        """Whether no id-list features, time feature or candidate docs are set."""
        # NOTE(review): `id_score_list_features` and `stacked_float_features`
        # are not checked here although `get_tiled_batch` drops them as well;
        # confirm whether that is intended.
        return (
            not self.id_list_features
            and self.time_since_first is None
            and self.candidate_docs is None
        )

    def get_tiled_batch(self, num_tiles: int):
        """Repeat every row of `float_features` `num_tiles` times.

        tiled_feature should be (batch_size * num_tiles, feature_dim)
        forall i in [batch_size],
        tiled_feature[i*num_tiles:(i+1)*num_tiles] should be feat[i]

        Args:
            num_tiles: number of consecutive copies to make of each row.

        Returns:
            A new `FeatureData` that carries only the tiled `float_features`.
        """
        assert (
            self.has_float_features_only
        ), f"only works for float features now: {self}"
        feat = self.float_features
        assert (
            len(feat.shape) == 2
        ), f"Need feat shape to be (batch_size, feature_dim), got {feat.shape}."
        # pyre-fixme[16]: `Tensor` has no attribute `repeat_interleave`.
        tiled_feat = feat.repeat_interleave(repeats=num_tiles, dim=0)
        return FeatureData(float_features=tiled_feat)
|
||||
|
||||
|
||||
class TensorFeatureData(torch.nn.Module):
|
||||
"""
|
||||
Primarily for using in nn.Sequential
|
||||
"""
|
||||
|
||||
def forward(self, input: torch.Tensor) -> FeatureData:
|
||||
assert isinstance(input, torch.Tensor)
|
||||
return FeatureData(input)
|
||||
|
||||
|
||||
class ServingFeatureData(NamedTuple):
|
||||
float_features_with_presence: Tuple[torch.Tensor, torch.Tensor]
|
||||
id_list_features: ServingIdListFeature
|
||||
id_score_list_features: ServingIdScoreListFeature
|
||||
|
||||
|
||||
@dataclass
class PreprocessedRankingInput(TensorDataClass):
    """Preprocessed input for ranking models.

    `state` and the `*_seq` fields are `FeatureData`; use `from_tensors()` to
    build an instance from raw tensors.
    """

    state: FeatureData
    src_seq: FeatureData
    src_src_mask: torch.Tensor
    tgt_in_seq: Optional[FeatureData] = None
    tgt_out_seq: Optional[FeatureData] = None
    tgt_tgt_mask: Optional[torch.Tensor] = None
    slate_reward: Optional[torch.Tensor] = None
    position_reward: Optional[torch.Tensor] = None
    # all indices will be +2 to account for padding
    # symbol (0) and decoder_start_symbol (1)
    src_in_idx: Optional[torch.Tensor] = None
    tgt_in_idx: Optional[torch.Tensor] = None
    tgt_out_idx: Optional[torch.Tensor] = None
    tgt_out_probs: Optional[torch.Tensor] = None
    # store ground-truth target sequences
    optim_tgt_in_idx: Optional[torch.Tensor] = None
    optim_tgt_out_idx: Optional[torch.Tensor] = None
    optim_tgt_in_seq: Optional[FeatureData] = None
    optim_tgt_out_seq: Optional[FeatureData] = None

    def batch_size(self) -> int:
        """Return the size of the first dimension of `state.float_features`."""
        return self.state.float_features.size()[0]

    @classmethod
    def from_tensors(
        cls,
        state: torch.Tensor,
        src_seq: torch.Tensor,
        src_src_mask: torch.Tensor,
        tgt_in_seq: Optional[torch.Tensor] = None,
        tgt_out_seq: Optional[torch.Tensor] = None,
        tgt_tgt_mask: Optional[torch.Tensor] = None,
        slate_reward: Optional[torch.Tensor] = None,
        position_reward: Optional[torch.Tensor] = None,
        src_in_idx: Optional[torch.Tensor] = None,
        tgt_in_idx: Optional[torch.Tensor] = None,
        tgt_out_idx: Optional[torch.Tensor] = None,
        tgt_out_probs: Optional[torch.Tensor] = None,
        optim_tgt_in_idx: Optional[torch.Tensor] = None,
        optim_tgt_out_idx: Optional[torch.Tensor] = None,
        optim_tgt_in_seq: Optional[torch.Tensor] = None,
        optim_tgt_out_seq: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """Build an instance from raw tensors.

        `state`, `src_seq` and the optional `tgt_in_seq`, `tgt_out_seq`,
        `optim_tgt_in_seq` and `optim_tgt_out_seq` are wrapped in
        `FeatureData`; all other tensors are stored unchanged. Extra keyword
        arguments are accepted and ignored.

        Raises:
            AssertionError: if a required argument is not a tensor, or an
                optional argument is neither `None` nor a tensor.
        """
        assert isinstance(state, torch.Tensor)
        assert isinstance(src_seq, torch.Tensor)
        assert isinstance(src_src_mask, torch.Tensor)
        assert tgt_in_seq is None or isinstance(tgt_in_seq, torch.Tensor)
        assert tgt_out_seq is None or isinstance(tgt_out_seq, torch.Tensor)
        assert tgt_tgt_mask is None or isinstance(tgt_tgt_mask, torch.Tensor)
        assert slate_reward is None or isinstance(slate_reward, torch.Tensor)
        assert position_reward is None or isinstance(position_reward, torch.Tensor)
        assert src_in_idx is None or isinstance(src_in_idx, torch.Tensor)
        assert tgt_in_idx is None or isinstance(tgt_in_idx, torch.Tensor)
        assert tgt_out_idx is None or isinstance(tgt_out_idx, torch.Tensor)
        assert tgt_out_probs is None or isinstance(tgt_out_probs, torch.Tensor)
        # Check both optim index tensors; `optim_tgt_in_idx` used to be skipped
        # because the `optim_tgt_out_idx` check was duplicated in its place.
        assert optim_tgt_in_idx is None or isinstance(optim_tgt_in_idx, torch.Tensor)
        assert optim_tgt_out_idx is None or isinstance(optim_tgt_out_idx, torch.Tensor)
        assert optim_tgt_in_seq is None or isinstance(optim_tgt_in_seq, torch.Tensor)
        assert optim_tgt_out_seq is None or isinstance(optim_tgt_out_seq, torch.Tensor)

        return cls(
            state=FeatureData(float_features=state),
            src_seq=FeatureData(float_features=src_seq),
            src_src_mask=src_src_mask,
            tgt_in_seq=FeatureData(float_features=tgt_in_seq)
            if tgt_in_seq is not None
            else None,
            tgt_out_seq=FeatureData(float_features=tgt_out_seq)
            if tgt_out_seq is not None
            else None,
            tgt_tgt_mask=tgt_tgt_mask,
            slate_reward=slate_reward,
            position_reward=position_reward,
            src_in_idx=src_in_idx,
            tgt_in_idx=tgt_in_idx,
            tgt_out_idx=tgt_out_idx,
            tgt_out_probs=tgt_out_probs,
            optim_tgt_in_idx=optim_tgt_in_idx,
            optim_tgt_out_idx=optim_tgt_out_idx,
            optim_tgt_in_seq=FeatureData(float_features=optim_tgt_in_seq)
            if optim_tgt_in_seq is not None
            else None,
            optim_tgt_out_seq=FeatureData(float_features=optim_tgt_out_seq)
            if optim_tgt_out_seq is not None
            else None,
        )

    def __post_init__(self):
        """Reject raw tensors in fields that must hold `FeatureData`."""
        if (
            isinstance(self.state, torch.Tensor)
            or isinstance(self.src_seq, torch.Tensor)
            or isinstance(self.tgt_in_seq, torch.Tensor)
            or isinstance(self.tgt_out_seq, torch.Tensor)
            or isinstance(self.optim_tgt_in_seq, torch.Tensor)
            or isinstance(self.optim_tgt_out_seq, torch.Tensor)
        ):
            raise ValueError(
                f"Use from_tensors() {type(self.state)} {type(self.src_seq)} "
                f"{type(self.tgt_in_seq)} {type(self.tgt_out_seq)} "
                f"{type(self.optim_tgt_in_seq)} {type(self.optim_tgt_out_seq)} "
            )
|
||||
|
||||
|
||||
@dataclass
class BaseInput(TensorDataClass):
    """
    Fields shared by every input type, raw or preprocessed
    """

    state: FeatureData
    next_state: FeatureData
    reward: torch.Tensor
    time_diff: torch.Tensor
    step: Optional[torch.Tensor]
    not_terminal: torch.Tensor

    def batch_size(self):
        """Return the size of the first dimension of `state.float_features`."""
        state_floats = self.state.float_features
        return state_floats.size()[0]

    @classmethod
    def from_dict(cls, batch):
        """Assemble a `BaseInput` from a batch keyed by `InputColumn` names.

        The four sparse-feature columns are optional; a missing or falsy
        entry becomes an empty dict. The result is always a plain `BaseInput`
        (never `cls`), so subclasses can call this via `super()` and copy the
        shared fields.
        """
        # Optional sparse features, defaulting to empty mappings.
        state_id_list = batch.get(InputColumn.STATE_ID_LIST_FEATURES, None) or {}
        state_id_score_list = (
            batch.get(InputColumn.STATE_ID_SCORE_LIST_FEATURES, None) or {}
        )
        next_state_id_list = (
            batch.get(InputColumn.NEXT_STATE_ID_LIST_FEATURES, None) or {}
        )
        next_state_id_score_list = (
            batch.get(InputColumn.NEXT_STATE_ID_SCORE_LIST_FEATURES, None) or {}
        )

        current = FeatureData(
            float_features=batch[InputColumn.STATE_FEATURES],
            id_list_features=state_id_list,
            id_score_list_features=state_id_score_list,
        )
        following = FeatureData(
            float_features=batch[InputColumn.NEXT_STATE_FEATURES],
            id_list_features=next_state_id_list,
            id_score_list_features=next_state_id_score_list,
        )
        return BaseInput(
            state=current,
            next_state=following,
            reward=batch[InputColumn.REWARD],
            time_diff=batch[InputColumn.TIME_DIFF],
            step=batch[InputColumn.STEP],
            not_terminal=batch[InputColumn.NOT_TERMINAL],
        )
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExtraData(TensorDataClass):
|
||||
mdp_id: Optional[torch.Tensor] = None
|
||||
sequence_number: Optional[torch.Tensor] = None
|
||||
action_probability: Optional[torch.Tensor] = None
|
||||
max_num_actions: Optional[int] = None
|
||||
metrics: Optional[torch.Tensor] = None
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, d):
|
||||
return cls(**{f.name: d.get(f.name, None) for f in dataclasses.fields(cls)})
|
||||
|
||||
|
||||
@dataclass
|
||||
class DiscreteDqnInput(BaseInput):
|
||||
action: torch.Tensor
|
||||
next_action: torch.Tensor
|
||||
possible_actions_mask: torch.Tensor
|
||||
possible_next_actions_mask: torch.Tensor
|
||||
extras: ExtraData
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, batch):
|
||||
base = super().from_dict(batch)
|
||||
return cls(
|
||||
state=base.state,
|
||||
next_state=base.next_state,
|
||||
reward=base.reward,
|
||||
time_diff=base.time_diff,
|
||||
step=base.step,
|
||||
not_terminal=base.not_terminal,
|
||||
action=batch[InputColumn.ACTION],
|
||||
next_action=batch[InputColumn.NEXT_ACTION],
|
||||
possible_actions_mask=batch[InputColumn.POSSIBLE_ACTIONS_MASK],
|
||||
possible_next_actions_mask=batch[InputColumn.POSSIBLE_NEXT_ACTIONS_MASK],
|
||||
extras=batch[InputColumn.EXTRAS],
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SlateQInput(BaseInput):
|
||||
"""
|
||||
The shapes of `reward`, `reward_mask`, & `next_item_mask` are
|
||||
`(batch_size, slate_size)`.
|
||||
|
||||
`reward_mask` indicated whether the reward could be observed, e.g.,
|
||||
the item got into viewport or not.
|
||||
"""
|
||||
|
||||
action: torch.Tensor
|
||||
next_action: torch.Tensor
|
||||
reward_mask: torch.Tensor
|
||||
extras: Optional[ExtraData] = None
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, d):
|
||||
action = d["action"]
|
||||
next_action = d["next_action"]
|
||||
return cls(
|
||||
state=FeatureData(
|
||||
float_features=d["state_features"],
|
||||
candidate_docs=DocList(
|
||||
float_features=d["candidate_features"],
|
||||
mask=d["item_mask"],
|
||||
value=d["item_probability"],
|
||||
),
|
||||
),
|
||||
next_state=FeatureData(
|
||||
float_features=d["next_state_features"],
|
||||
candidate_docs=DocList(
|
||||
float_features=d["next_candidate_features"],
|
||||
mask=d["next_item_mask"],
|
||||
value=d["next_item_probability"],
|
||||
),
|
||||
),
|
||||
action=action,
|
||||
next_action=next_action,
|
||||
reward=d["position_reward"],
|
||||
reward_mask=d["reward_mask"],
|
||||
time_diff=d["time_diff"],
|
||||
not_terminal=d["not_terminal"],
|
||||
step=None,
|
||||
extras=ExtraData.from_dict(d),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParametricDqnInput(BaseInput):
|
||||
action: FeatureData
|
||||
next_action: FeatureData
|
||||
possible_actions: FeatureData
|
||||
possible_actions_mask: torch.Tensor
|
||||
possible_next_actions: FeatureData
|
||||
possible_next_actions_mask: torch.Tensor
|
||||
extras: Optional[ExtraData] = None
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, batch):
|
||||
return cls(
|
||||
state=FeatureData(float_features=batch["state_features"]),
|
||||
action=FeatureData(float_features=batch["action"]),
|
||||
next_state=FeatureData(float_features=batch["next_state_features"]),
|
||||
next_action=FeatureData(float_features=batch["next_action"]),
|
||||
possible_actions=FeatureData(float_features=batch["possible_actions"]),
|
||||
possible_actions_mask=batch["possible_actions_mask"],
|
||||
possible_next_actions=FeatureData(
|
||||
float_features=batch["possible_next_actions"]
|
||||
),
|
||||
possible_next_actions_mask=batch["possible_next_actions_mask"],
|
||||
reward=batch["reward"],
|
||||
not_terminal=batch["not_terminal"],
|
||||
time_diff=batch["time_diff"],
|
||||
step=batch["step"],
|
||||
extras=batch["extras"],
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PolicyNetworkInput(BaseInput):
|
||||
action: FeatureData
|
||||
next_action: FeatureData
|
||||
extras: Optional[ExtraData] = None
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, batch):
|
||||
return cls(
|
||||
state=FeatureData(float_features=batch["state_features"]),
|
||||
action=FeatureData(float_features=batch["action"]),
|
||||
next_state=FeatureData(float_features=batch["next_state_features"]),
|
||||
next_action=FeatureData(float_features=batch["next_action"]),
|
||||
reward=batch["reward"],
|
||||
not_terminal=batch["not_terminal"],
|
||||
time_diff=batch["time_diff"],
|
||||
step=batch["step"],
|
||||
extras=batch["extras"],
|
||||
)
|
||||
|
||||
def batch_size(self) -> int:
|
||||
return self.state.float_features.shape[0]
|
||||
|
||||
|
||||
@dataclass
|
||||
class PolicyGradientInput(BaseDataClass):
|
||||
state: FeatureData
|
||||
action: torch.Tensor
|
||||
reward: torch.Tensor
|
||||
log_prob: torch.Tensor
|
||||
|
||||
@classmethod
|
||||
def input_prototype(cls):
|
||||
num_classes = 5
|
||||
batch_size = 10
|
||||
state_dim = 3
|
||||
return cls(
|
||||
state=FeatureData(float_features=torch.randn(batch_size, state_dim)),
|
||||
action=F.one_hot(torch.randint(high=num_classes, size=(batch_size,))),
|
||||
reward=torch.rand(batch_size),
|
||||
log_prob=torch.log(torch.rand(batch_size)),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
class MemoryNetworkInput(BaseInput):
    """`BaseInput` extended with an `action` tensor."""

    action: torch.Tensor

    def batch_size(self):
        """Return the batch dimension of `state.float_features`.

        For a 2D tensor this is dimension 0; for a 3D tensor it is
        dimension 1. Any other rank raises `NotImplementedError`.
        """
        dims = self.state.float_features.size()
        rank = len(dims)
        if rank == 2:
            return dims[0]
        if rank == 3:
            return dims[1]
        raise NotImplementedError()
|
||||
|
||||
|
||||
@dataclass
|
||||
class PreprocessedTrainingBatch(TensorDataClass):
|
||||
training_input: Union[PreprocessedRankingInput]
|
||||
# TODO: deprecate this and move into individual ones.
|
||||
extras: ExtraData = field(default_factory=ExtraData)
|
||||
|
||||
def batch_size(self):
|
||||
return self.training_input.state.float_features.size()[0]
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemoryNetworkOutput(TensorDataClass):
|
||||
mus: torch.Tensor
|
||||
sigmas: torch.Tensor
|
||||
logpi: torch.Tensor
|
||||
reward: torch.Tensor
|
||||
not_terminal: torch.Tensor
|
||||
last_step_lstm_hidden: torch.Tensor
|
||||
last_step_lstm_cell: torch.Tensor
|
||||
all_steps_lstm_hidden: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class Seq2RewardOutput(TensorDataClass):
|
||||
acc_reward: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class DqnPolicyActionSet(TensorDataClass):
|
||||
greedy: int
|
||||
softmax: Optional[int] = None
|
||||
greedy_act_name: Optional[str] = None
|
||||
softmax_act_name: Optional[str] = None
|
||||
softmax_act_prob: Optional[float] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class PlanningPolicyOutput(TensorDataClass):
|
||||
# best action to take next
|
||||
next_best_continuous_action: Optional[torch.Tensor] = None
|
||||
next_best_discrete_action_one_hot: Optional[torch.Tensor] = None
|
||||
next_best_discrete_action_idx: Optional[int] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class RankingOutput(TensorDataClass):
|
||||
# a tensor of integer indices w.r.t. to possible candidates
|
||||
# shape: batch_size, tgt_seq_len
|
||||
ranked_tgt_out_idx: Optional[torch.Tensor] = None
|
||||
# generative probability of ranked tgt sequences at each decoding step
|
||||
# shape: batch_size, tgt_seq_len, candidate_size
|
||||
ranked_tgt_out_probs: Optional[torch.Tensor] = None
|
||||
# log probabilities of given tgt sequences are used in REINFORCE
|
||||
# shape: batch_size
|
||||
log_probs: Optional[torch.Tensor] = None
|
||||
# encoder scores in tgt_out_idx order
|
||||
encoder_scores: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class RewardNetworkOutput(TensorDataClass):
|
||||
predicted_reward: torch.Tensor
|
||||
class RLTrainingOutput:
|
||||
validation_result: Optional[ValidationResult__Union] = None
|
||||
publishing_result: Optional[PublishingResult__Union] = None
|
||||
training_report: Optional[RLTrainingReport] = None
|
||||
output_path: Optional[str] = None
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
from reagent.core.fb_checker import IS_FB_ENVIRONMENT
|
||||
from reagent.core.tagged_union import TaggedUnion
|
||||
from reagent.models.model_feature_config_provider import ModelFeatureConfigProvider
|
||||
from reagent.reporting.result_registries import PublishingResult, ValidationResult
|
||||
from reagent.reporting.training_reports import TrainingReport
|
||||
|
||||
|
||||
if True: # Register modules for unions
|
||||
import reagent.reporting.oss_training_reports # noqa
|
||||
import reagent.core.result_types # noqa
|
||||
|
||||
if IS_FB_ENVIRONMENT:
|
||||
import reagent.reporting.fb.fb_training_reports # noqa
|
||||
import reagent.fb.models.model_feature_config_builder # noqa
|
||||
import reagent.core.fb.fb_result_types # noqa
|
||||
import reagent.core.fb.fb_types # noqa
|
||||
|
||||
|
||||
@ModelFeatureConfigProvider.fill_union()
|
||||
class ModelFeatureConfigProvider__Union(TaggedUnion):
|
||||
pass
|
||||
|
||||
|
||||
@PublishingResult.fill_union()
|
||||
class PublishingResult__Union(TaggedUnion):
|
||||
pass
|
||||
|
||||
|
||||
@ValidationResult.fill_union()
|
||||
class ValidationResult__Union(TaggedUnion):
|
||||
pass
|
||||
|
||||
|
||||
@TrainingReport.fill_union()
|
||||
class TrainingReport__Union(TaggedUnion):
|
||||
pass
|
||||
@@ -1,41 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
|
||||
import logging
|
||||
from typing import Dict, Optional
|
||||
|
||||
from reagent.core.types import Dataset, PreprocessingOptions, ReaderOptions, TableSpec
|
||||
from reagent.parameters import NormalizationParameters
|
||||
from reagent.preprocessing.batch_preprocessor import BatchPreprocessor
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DataFetcher:
|
||||
# TODO: T71636145 Make a more specific API for DataFetcher
|
||||
def query_data(self, **kwargs):
|
||||
raise NotImplementedError()
|
||||
|
||||
# TODO: T71636145 Make a more specific API for DataFetcher
|
||||
def query_data_parametric(self, **kwargs):
|
||||
raise NotImplementedError()
|
||||
|
||||
def identify_normalization_parameters(
|
||||
self,
|
||||
table_spec: TableSpec,
|
||||
column_name: str,
|
||||
preprocessing_options: PreprocessingOptions,
|
||||
seed: Optional[int] = None,
|
||||
) -> Dict[int, NormalizationParameters]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_dataloader(
|
||||
self,
|
||||
dataset: Dataset,
|
||||
batch_size: int,
|
||||
batch_preprocessor: Optional[BatchPreprocessor],
|
||||
use_gpu: bool,
|
||||
reader_options: ReaderOptions,
|
||||
):
|
||||
raise NotImplementedError()
|
||||
@@ -3,8 +3,8 @@
|
||||
import logging
|
||||
|
||||
import torch
|
||||
from reagent.core.types import MemoryNetworkInput
|
||||
from reagent.training.world_model.compress_model_trainer import CompressModelTrainer
|
||||
from reagent.types import MemoryNetworkInput
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -8,17 +8,8 @@ from typing import NamedTuple, Optional, cast
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from reagent.core import types as rlt
|
||||
from reagent import types as rlt
|
||||
from reagent.models.seq2slate import Seq2SlateMode, Seq2SlateTransformerNet
|
||||
from reagent.ope.estimators.sequential_estimators import (
|
||||
Action,
|
||||
ActionSpace,
|
||||
RLEstimatorInput,
|
||||
RLPolicy,
|
||||
State,
|
||||
Transition,
|
||||
ValueFunction,
|
||||
)
|
||||
from reagent.torch_utils import masked_softmax
|
||||
from reagent.training import ParametricDQNTrainer
|
||||
from reagent.training.dqn_trainer import DQNTrainer
|
||||
@@ -51,7 +42,6 @@ class EvaluationDataPage(NamedTuple):
|
||||
model_metrics_values_for_logged_action: Optional[torch.Tensor] = None
|
||||
possible_actions_state_concat: Optional[torch.Tensor] = None
|
||||
contexts: Optional[torch.Tensor] = None
|
||||
sequential_estimator_input: Optional[RLEstimatorInput] = None
|
||||
|
||||
@classmethod
|
||||
def create_from_training_batch(
|
||||
@@ -320,83 +310,6 @@ class EvaluationDataPage(NamedTuple):
|
||||
eval_action_idxs=eval_action_idxs,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def create_rl_estimator_input_from_tensors_dqn(
|
||||
trainer: DQNTrainer,
|
||||
mdp_ids: torch.Tensor,
|
||||
states: rlt.FeatureData,
|
||||
actions: rlt.FeatureData,
|
||||
propensities: torch.Tensor,
|
||||
rewards: torch.Tensor,
|
||||
):
|
||||
class DQNRLPolicy(RLPolicy):
|
||||
def __init__(self, trainer: DQNTrainer):
|
||||
super().__init__(ActionSpace(trainer.num_actions))
|
||||
self._trainer = trainer
|
||||
|
||||
def action_dist(self, state: State):
|
||||
feat_data = rlt.FeatureData(float_features=state.value.reshape(1, -1))
|
||||
# Only 1 batch
|
||||
q_values = self._trainer.get_detached_q_values(feat_data)[0][0]
|
||||
return self._action_space.distribution(
|
||||
torch.nn.Softmax(dim=0)(q_values)
|
||||
)
|
||||
|
||||
class CPEValueFunction(ValueFunction):
|
||||
def __init__(self, trainer: DQNTrainer):
|
||||
self._trainer = trainer
|
||||
|
||||
def state_action_value(self, state: State, action: Action) -> float:
|
||||
feat_data = rlt.FeatureData(float_features=state.value.reshape(1, -1))
|
||||
model_values = self._trainer.q_network_cpe(feat_data)[
|
||||
:, 0 : self._trainer.num_actions
|
||||
][0]
|
||||
return model_values[action.value].item()
|
||||
|
||||
def state_value(self, state: State) -> float:
|
||||
feat_data = rlt.FeatureData(float_features=state.value.reshape(1, -1))
|
||||
model_values = self._trainer.q_network_cpe(feat_data)[
|
||||
:, 0 : self._trainer.num_actions
|
||||
][0]
|
||||
q_values = self._trainer.get_detached_q_values(feat_data)[0][0]
|
||||
dist = torch.nn.Softmax(dim=0)(q_values)
|
||||
assert dist.shape == model_values.shape
|
||||
return torch.dot(dist, model_values).item()
|
||||
|
||||
def reset(self):
|
||||
pass
|
||||
|
||||
states_tensor = states.float_features
|
||||
logged_actions = torch.argmax(actions.float(), dim=1)
|
||||
log = []
|
||||
cur_mdp = []
|
||||
i = 0
|
||||
while i < mdp_ids.shape[0]:
|
||||
if i + 1 < mdp_ids.shape[0] and mdp_ids[i, 0] == mdp_ids[i + 1, 0]:
|
||||
cur_mdp.append(
|
||||
Transition(
|
||||
last_state=State(states_tensor[i]),
|
||||
action=Action(logged_actions[i].item()),
|
||||
action_prob=propensities[i, 0].item(),
|
||||
state=State(states_tensor[i + 1]),
|
||||
reward=rewards[i, 0].item(),
|
||||
status=Transition.Status.NORMAL,
|
||||
)
|
||||
)
|
||||
elif len(cur_mdp) > 0:
|
||||
log.append(cur_mdp)
|
||||
cur_mdp = []
|
||||
i += 1
|
||||
|
||||
# Temporary value of gamma
|
||||
return RLEstimatorInput(
|
||||
gamma=1.0,
|
||||
log=log,
|
||||
target_policy=DQNRLPolicy(trainer),
|
||||
value_function=CPEValueFunction(trainer),
|
||||
discrete_states=False,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
# pyre-fixme[56]: Decorator `torch.no_grad(...)` could not be called, because
|
||||
# its type `no_grad` is not callable.
|
||||
@@ -541,9 +454,6 @@ class EvaluationDataPage(NamedTuple):
|
||||
possible_actions_mask=possible_actions_mask,
|
||||
optimal_q_values=optimal_q_values,
|
||||
eval_action_idxs=eval_action_idxs,
|
||||
sequential_estimator_input=EvaluationDataPage.create_rl_estimator_input_from_tensors_dqn(
|
||||
trainer, mdp_ids, states, actions, propensities, rewards
|
||||
),
|
||||
)
|
||||
|
||||
def append(self, edp):
|
||||
@@ -560,15 +470,6 @@ class EvaluationDataPage(NamedTuple):
|
||||
new_edp[x] = torch.cat((t, other_t), dim=0)
|
||||
elif isinstance(t, np.ndarray):
|
||||
new_edp[x] = np.concatenate((t, other_t), axis=0)
|
||||
elif isinstance(t, RLEstimatorInput):
|
||||
t.log.extend(other_t.log)
|
||||
new_edp[x] = RLEstimatorInput(
|
||||
gamma=t.gamma,
|
||||
log=t.log,
|
||||
target_policy=t.target_policy,
|
||||
value_function=t.value_function,
|
||||
discrete_states=t.discrete_states,
|
||||
)
|
||||
else:
|
||||
raise Exception("Invalid type in training data page")
|
||||
else:
|
||||
@@ -583,10 +484,7 @@ class EvaluationDataPage(NamedTuple):
|
||||
new_edp = {}
|
||||
for x in EvaluationDataPage._fields:
|
||||
t = getattr(self, x)
|
||||
if hasattr(t, "__getitem__"):
|
||||
new_edp[x] = t[sorted_idxs] if t is not None else None
|
||||
else:
|
||||
new_edp[x] = t
|
||||
new_edp[x] = t[sorted_idxs] if t is not None else None
|
||||
|
||||
return EvaluationDataPage(**new_edp)
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ from typing import Dict, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from reagent.core import types as rlt
|
||||
from reagent.core.tracker import observable
|
||||
from reagent.evaluation.cpe import CpeDetails, CpeEstimateSet
|
||||
from reagent.evaluation.doubly_robust_estimator import DoublyRobustEstimator
|
||||
from reagent.evaluation.evaluation_data_page import EvaluationDataPage
|
||||
@@ -53,6 +53,7 @@ def get_metrics_to_score(metric_reward_values: Optional[Dict[str, float]]) -> Li
|
||||
return sorted([*metric_reward_values.keys()])
|
||||
|
||||
|
||||
@observable(cpe_details=CpeDetails)
|
||||
class Evaluator:
|
||||
NUM_J_STEPS_FOR_MAGIC_ESTIMATOR = 25
|
||||
|
||||
@@ -69,15 +70,7 @@ class Evaluator:
|
||||
gamma
|
||||
)
|
||||
|
||||
self.reporter = None
|
||||
|
||||
def evaluate(self, eval_input: rlt.TensorDataClass) -> None:
|
||||
pass
|
||||
|
||||
def finish(self):
|
||||
pass
|
||||
|
||||
def evaluate_one_shot(self, edp: EvaluationDataPage) -> CpeDetails:
|
||||
def evaluate_post_training(self, edp: EvaluationDataPage) -> CpeDetails:
|
||||
cpe_details = CpeDetails()
|
||||
|
||||
cpe_details.reward_estimates = self.score_cpe("Reward", edp)
|
||||
@@ -123,9 +116,8 @@ class Evaluator:
|
||||
cpe_details.mc_loss = float(
|
||||
F.mse_loss(edp.logged_values, edp.model_values_for_logged_action)
|
||||
)
|
||||
|
||||
assert self.reporter is not None, "Missing reporter"
|
||||
self.reporter.report(cpe_results=cpe_details)
|
||||
# pyre-fixme[16]: `Evaluator` has no attribute `notify_observers`.
|
||||
self.notify_observers(cpe_details=cpe_details)
|
||||
return cpe_details
|
||||
|
||||
def score_cpe(self, metric_name, edp: EvaluationDataPage):
|
||||
|
||||
@@ -11,6 +11,9 @@ from reagent.evaluation.cpe import (
|
||||
)
|
||||
from reagent.evaluation.evaluation_data_page import EvaluationDataPage
|
||||
from reagent.evaluation.evaluator import Evaluator
|
||||
from reagent.evaluation.weighted_sequential_doubly_robust_estimator import (
|
||||
WeightedSequentialDoublyRobustEstimator,
|
||||
)
|
||||
from reagent.ope.estimators.contextual_bandits_estimators import (
|
||||
BanditsEstimatorInput,
|
||||
DMEstimator,
|
||||
@@ -31,6 +34,10 @@ from reagent.ope.estimators.sequential_estimators import (
|
||||
MAGICEstimator,
|
||||
RLEstimator,
|
||||
RLEstimatorInput,
|
||||
RLPolicy,
|
||||
State,
|
||||
Transition,
|
||||
ValueFunction,
|
||||
)
|
||||
from reagent.ope.estimators.types import ActionSpace
|
||||
|
||||
@@ -109,6 +116,92 @@ class SequentialOPEstimatorAdapter:
|
||||
self.gamma = gamma
|
||||
self._device = device
|
||||
|
||||
class EDPSeqPolicy(RLPolicy):
|
||||
def __init__(
|
||||
self, num_actions: int, model_propensities: torch.Tensor, device=None
|
||||
):
|
||||
super().__init__(ActionSpace(num_actions), device)
|
||||
self.model_propensities = model_propensities
|
||||
|
||||
def action_dist(self, state: State) -> ActionDistribution:
|
||||
# "state" is (trajectory, step)
|
||||
return self.model_propensities[state.value]
|
||||
|
||||
class EDPValueFunc(ValueFunction):
|
||||
def __init__(
|
||||
self, model_values: torch.Tensor, target_propensities: torch.Tensor
|
||||
):
|
||||
self.model_values = model_values
|
||||
self.target_propensities = target_propensities
|
||||
|
||||
def state_action_value(self, state: State, action: Action) -> float:
|
||||
return self.model_values[state.value][action].item()
|
||||
|
||||
def state_value(self, state: State) -> float:
|
||||
return torch.dot(
|
||||
self.model_values[state.value], self.target_propensities[state.value]
|
||||
).item()
|
||||
|
||||
def reset(self):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def edp_to_rl_input(
|
||||
edp: EvaluationDataPage, gamma, device=None
|
||||
) -> RLEstimatorInput:
|
||||
assert edp.model_values is not None
|
||||
eq_len = WeightedSequentialDoublyRobustEstimator.transform_to_equal_length_trajectories(
|
||||
edp.mdp_id,
|
||||
edp.action_mask.cpu().numpy(),
|
||||
edp.logged_rewards.cpu().numpy().flatten(),
|
||||
edp.logged_propensities.cpu().numpy().flatten(),
|
||||
edp.model_propensities.cpu().numpy(),
|
||||
edp.model_values.cpu().numpy(),
|
||||
)
|
||||
|
||||
(
|
||||
actions,
|
||||
rewards,
|
||||
logged_propensities,
|
||||
target_propensities,
|
||||
estimated_q_values,
|
||||
) = (
|
||||
torch.tensor(x, dtype=torch.double, device=device, requires_grad=True)
|
||||
for x in eq_len
|
||||
)
|
||||
|
||||
num_examples = logged_propensities.shape[0]
|
||||
horizon = logged_propensities.shape[1]
|
||||
|
||||
log = []
|
||||
for traj in range(num_examples):
|
||||
log.append(
|
||||
[
|
||||
Transition(
|
||||
last_state=State((traj, i)),
|
||||
action=torch.argmax(actions[traj, i]).item(),
|
||||
action_prob=logged_propensities[traj, i].item(),
|
||||
state=State((traj, i + 1)),
|
||||
reward=rewards[traj, i].item(),
|
||||
)
|
||||
for i in range(horizon - 1)
|
||||
if actions[traj, i][torch.argmax(actions[traj, i]).item()] != 0.0
|
||||
]
|
||||
)
|
||||
|
||||
return RLEstimatorInput(
|
||||
gamma=gamma,
|
||||
log=log,
|
||||
target_policy=SequentialOPEstimatorAdapter.EDPSeqPolicy(
|
||||
actions.shape[2], target_propensities
|
||||
),
|
||||
value_function=SequentialOPEstimatorAdapter.EDPValueFunc(
|
||||
estimated_q_values, target_propensities
|
||||
),
|
||||
ground_truth=None,
|
||||
horizon=horizon,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def estimator_results_to_cpe_estimate(
|
||||
estimator_results: EstimatorResults,
|
||||
@@ -144,16 +237,8 @@ class SequentialOPEstimatorAdapter:
|
||||
)
|
||||
|
||||
def estimate(self, edp: EvaluationDataPage) -> CpeEstimate:
|
||||
est_input = edp.sequential_estimator_input
|
||||
assert est_input is not None, "EDP does not contain sequential estimator inputs"
|
||||
estimator_results = self.seq_ope_estimator.evaluate(
|
||||
RLEstimatorInput(
|
||||
gamma=self.gamma,
|
||||
log=est_input.log,
|
||||
target_policy=est_input.target_policy,
|
||||
value_function=est_input.value_function,
|
||||
discrete_states=est_input.discrete_states,
|
||||
)
|
||||
SequentialOPEstimatorAdapter.edp_to_rl_input(edp, self.gamma, self._device)
|
||||
)
|
||||
assert isinstance(estimator_results, EstimatorResults)
|
||||
return SequentialOPEstimatorAdapter.estimator_results_to_cpe_estimate(
|
||||
|
||||
@@ -7,8 +7,9 @@ from typing import Optional
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from reagent.core.types import PreprocessedTrainingBatch
|
||||
from reagent.core.tracker import observable
|
||||
from reagent.models.seq2slate import Seq2SlateMode
|
||||
from reagent.types import PreprocessedTrainingBatch
|
||||
from sklearn.metrics import (
|
||||
average_precision_score,
|
||||
dcg_score,
|
||||
@@ -28,6 +29,17 @@ class ListwiseRankingMetrics:
|
||||
cross_entropy_loss: Optional[float] = 0.0
|
||||
|
||||
|
||||
@observable(
|
||||
cross_entropy_loss=torch.Tensor,
|
||||
dcg=torch.Tensor,
|
||||
ndcg=torch.Tensor,
|
||||
mean_ap=torch.Tensor,
|
||||
auc=torch.Tensor,
|
||||
base_dcg=torch.Tensor,
|
||||
base_ndcg=torch.Tensor,
|
||||
base_map=torch.Tensor,
|
||||
base_auc=torch.Tensor,
|
||||
)
|
||||
class RankingListwiseEvaluator:
|
||||
""" Evaluate listwise ranking models on common ranking metrics """
|
||||
|
||||
@@ -43,7 +55,6 @@ class RankingListwiseEvaluator:
|
||||
self.base_map = []
|
||||
self.log_softmax = nn.LogSoftmax(dim=1)
|
||||
self.kl_loss = nn.KLDivLoss(reduction="batchmean")
|
||||
self.reporter = None
|
||||
|
||||
# pyre-fixme[56]: Decorator `torch.no_grad(...)` could not be called, because
|
||||
# its type `no_grad` is not callable.
|
||||
@@ -72,7 +83,9 @@ class RankingListwiseEvaluator:
|
||||
self.seq2slate_net.train(seq2slate_net_prev_mode)
|
||||
|
||||
if not self.calc_cpe:
|
||||
self.reporter.report_evaluation_minibatch(cross_entropy_loss=ce_loss)
|
||||
# pyre-fixme[16]: `RankingListwiseEvaluator` has no attribute
|
||||
# `notify_observers`.
|
||||
self.notify_observers(cross_entropy_loss=ce_loss)
|
||||
return
|
||||
|
||||
# shape: batch_size, tgt_seq_len
|
||||
@@ -119,7 +132,7 @@ class RankingListwiseEvaluator:
|
||||
batch_base_dcg.append(dcg_score(truth_scores, base_scores))
|
||||
batch_base_ndcg.append(ndcg_score(truth_scores, base_scores))
|
||||
|
||||
self.reporter.report_evaluation_minibatch(
|
||||
self.notify_observers(
|
||||
cross_entropy_loss=ce_loss,
|
||||
dcg=torch.mean(torch.tensor(batch_dcg)).reshape(1),
|
||||
ndcg=torch.mean(torch.tensor(batch_ndcg)).reshape(1),
|
||||
@@ -132,5 +145,5 @@ class RankingListwiseEvaluator:
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def evaluate_one_shot(self):
|
||||
def evaluate_post_training(self):
|
||||
pass
|
||||
|
||||
@@ -8,15 +8,24 @@ from typing import Optional
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from reagent.core.types import PreprocessedTrainingBatch
|
||||
from reagent.core.tracker import observable
|
||||
from reagent.evaluation.evaluation_data_page import EvaluationDataPage
|
||||
from reagent.models.seq2slate import Seq2SlateMode
|
||||
from reagent.training.ranking.seq2slate_trainer import Seq2SlateTrainer
|
||||
from reagent.types import PreprocessedTrainingBatch
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@observable(
|
||||
eval_baseline_loss=torch.Tensor,
|
||||
eval_advantages=torch.Tensor,
|
||||
logged_slate_rank_probs=torch.Tensor,
|
||||
ranked_slate_rank_probs=torch.Tensor,
|
||||
eval_data_pages_g=EvaluationDataPage,
|
||||
eval_data_pages_ng=EvaluationDataPage,
|
||||
)
|
||||
class RankingPolicyGradientEvaluator:
|
||||
""" Evaluate ranking models that are learned through policy gradient """
|
||||
|
||||
@@ -30,12 +39,13 @@ class RankingPolicyGradientEvaluator:
|
||||
self.trainer = trainer
|
||||
self.calc_cpe = calc_cpe
|
||||
self.reward_network = reward_network
|
||||
self.reporter = None
|
||||
|
||||
# Evaluate greedy/non-greedy version of the ranking model
|
||||
self.eval_data_pages_g: Optional[EvaluationDataPage] = None
|
||||
self.eval_data_pages_ng: Optional[EvaluationDataPage] = None
|
||||
|
||||
# pyre-fixme[56]: Decorator `torch.no_grad(...)` could not be called, because
|
||||
# its type `no_grad` is not callable.
|
||||
@torch.no_grad()
|
||||
def evaluate(self, eval_tdp: PreprocessedTrainingBatch) -> None:
|
||||
seq2slate_net = self.trainer.seq2slate_net
|
||||
@@ -117,7 +127,9 @@ class RankingPolicyGradientEvaluator:
|
||||
else:
|
||||
self.eval_data_pages_ng = self.eval_data_pages_ng.append(edp_ng)
|
||||
|
||||
self.reporter.report_evaluation_minibatch(
|
||||
# pyre-fixme[16]: `RankingPolicyGradientEvaluator` has no attribute
|
||||
# `notify_observers`.
|
||||
self.notify_observers(
|
||||
eval_baseline_loss=eval_baseline_loss,
|
||||
eval_advantages=eval_advantage,
|
||||
logged_slate_rank_probs=logged_slate_rank_prob,
|
||||
@@ -125,13 +137,11 @@ class RankingPolicyGradientEvaluator:
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def finish(self):
|
||||
self.reporter.report_evaluation_epoch(
|
||||
def evaluate_post_training(self):
|
||||
self.notify_observers(
|
||||
# Use ValueListObserver as aggregating_observers requires input to be Tensor
|
||||
eval_data_pages_g=self.eval_data_pages_g,
|
||||
eval_data_pages_ng=self.eval_data_pages_ng,
|
||||
)
|
||||
self.eval_data_pages_g = None
|
||||
self.eval_data_pages_ng = None
|
||||
|
||||
def evaluate_one_shot(self, edp: EvaluationDataPage):
|
||||
pass
|
||||
|
||||
@@ -6,10 +6,9 @@ import logging
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from reagent.core import types as rlt
|
||||
from reagent.core.types import PreprocessedTrainingBatch
|
||||
from reagent.evaluation.evaluation_data_page import EvaluationDataPage
|
||||
from reagent import types as rlt
|
||||
from reagent.training.reward_network_trainer import RewardNetTrainer
|
||||
from reagent.types import PreprocessedTrainingBatch
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -22,6 +21,7 @@ class RewardNetEvaluator:
|
||||
self.trainer = trainer
|
||||
self.mse_loss = []
|
||||
self.rewards = []
|
||||
self.best_model = None
|
||||
self.best_model_loss = 1e9
|
||||
|
||||
# pyre-fixme[56]: Decorator `torch.no_grad(...)` could not be called, because
|
||||
@@ -47,7 +47,7 @@ class RewardNetEvaluator:
|
||||
reward_net.train(reward_net_prev_mode)
|
||||
|
||||
@torch.no_grad()
|
||||
def finish(self):
|
||||
def evaluate_post_training(self):
|
||||
mean_mse_loss = np.mean(self.mse_loss)
|
||||
logger.info(f"Evaluation MSE={mean_mse_loss}")
|
||||
eval_res = {"mse": mean_mse_loss, "rewards": torch.cat(self.rewards)}
|
||||
@@ -56,9 +56,6 @@ class RewardNetEvaluator:
|
||||
|
||||
if mean_mse_loss < self.best_model_loss:
|
||||
self.best_model_loss = mean_mse_loss
|
||||
self.trainer.best_model = copy.deepcopy(self.trainer.reward_net)
|
||||
self.best_model = copy.deepcopy(self.trainer.reward_net)
|
||||
|
||||
return eval_res
|
||||
|
||||
def evaluate_one_shot(self, edp: EvaluationDataPage):
|
||||
pass
|
||||
|
||||
@@ -3,8 +3,8 @@
|
||||
import logging
|
||||
|
||||
import torch
|
||||
from reagent.core.types import PreprocessedTrainingBatch
|
||||
from reagent.training.world_model.seq2reward_trainer import Seq2RewardTrainer
|
||||
from reagent.types import PreprocessedTrainingBatch
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -15,13 +15,15 @@ class Seq2RewardEvaluator:
|
||||
self.trainer = trainer
|
||||
self.reward_net = self.trainer.seq2reward_network
|
||||
|
||||
# pyre-fixme[56]: Decorator `torch.no_grad(...)` could not be called, because
|
||||
# its type `no_grad` is not callable.
|
||||
@torch.no_grad()
|
||||
def evaluate(self, eval_tdp: PreprocessedTrainingBatch):
|
||||
reward_net_prev_mode = self.reward_net.training
|
||||
self.reward_net.eval()
|
||||
# pyre-fixme[6]: Expected `MemoryNetworkInput` for 1st param but got
|
||||
# `PreprocessedTrainingBatch`.
|
||||
loss = self.trainer.compute_loss(eval_tdp)
|
||||
loss = self.trainer.get_loss(eval_tdp)
|
||||
detached_loss = loss.cpu().detach().item()
|
||||
q_values = (
|
||||
self.trainer.get_Q(
|
||||
@@ -37,6 +39,3 @@ class Seq2RewardEvaluator:
|
||||
)
|
||||
self.reward_net.train(reward_net_prev_mode)
|
||||
return (detached_loss, q_values)
|
||||
|
||||
def finish(self):
|
||||
pass
|
||||
|
||||
@@ -4,28 +4,23 @@ import logging
|
||||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
from reagent.core.types import FeatureData, MemoryNetworkInput
|
||||
from reagent.reporting.world_model_reporter import (
|
||||
DebugToolsReporter,
|
||||
WorldModelReporter,
|
||||
)
|
||||
from reagent.training.world_model.mdnrnn_trainer import MDNRNNTrainer
|
||||
from reagent.types import FeatureData, MemoryNetworkInput
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorldModelLossEvaluator(object):
|
||||
class LossEvaluator(object):
|
||||
""" Evaluate losses on data pages """
|
||||
|
||||
def __init__(self, trainer: MDNRNNTrainer, state_dim: int) -> None:
|
||||
self.trainer = trainer
|
||||
self.state_dim = state_dim
|
||||
self.reporter = WorldModelReporter(1)
|
||||
|
||||
def evaluate(self, tdp: MemoryNetworkInput) -> None:
|
||||
def evaluate(self, tdp: MemoryNetworkInput) -> Dict[str, float]:
|
||||
self.trainer.memory_network.mdnrnn.eval()
|
||||
losses = self.trainer.compute_loss(tdp, state_dim=self.state_dim)
|
||||
losses = self.trainer.get_loss(tdp, state_dim=self.state_dim)
|
||||
detached_losses = {
|
||||
"loss": losses["loss"].cpu().detach().item(),
|
||||
"gmm": losses["gmm"].cpu().detach().item(),
|
||||
@@ -34,10 +29,7 @@ class WorldModelLossEvaluator(object):
|
||||
}
|
||||
del losses
|
||||
self.trainer.memory_network.mdnrnn.train()
|
||||
self.reporter.report(**detached_losses)
|
||||
|
||||
def finish(self):
|
||||
pass
|
||||
return detached_losses
|
||||
|
||||
|
||||
class FeatureImportanceEvaluator(object):
|
||||
@@ -65,7 +57,6 @@ class FeatureImportanceEvaluator(object):
|
||||
self.action_feature_num = action_feature_num
|
||||
self.sorted_action_feature_start_indices = sorted_action_feature_start_indices
|
||||
self.sorted_state_feature_start_indices = sorted_state_feature_start_indices
|
||||
self.reporter = DebugToolsReporter()
|
||||
|
||||
def evaluate(self, batch: MemoryNetworkInput):
|
||||
""" Calculate feature importance: setting each state/action feature to
|
||||
@@ -80,7 +71,7 @@ class FeatureImportanceEvaluator(object):
|
||||
state_feature_num = self.state_feature_num
|
||||
feature_importance = torch.zeros(action_feature_num + state_feature_num)
|
||||
|
||||
orig_losses = self.trainer.compute_loss(batch, state_dim=state_dim)
|
||||
orig_losses = self.trainer.get_loss(batch, state_dim=state_dim)
|
||||
orig_loss = orig_losses["loss"].cpu().detach().item()
|
||||
del orig_losses
|
||||
|
||||
@@ -124,7 +115,7 @@ class FeatureImportanceEvaluator(object):
|
||||
not_terminal=batch.not_terminal,
|
||||
step=None,
|
||||
)
|
||||
losses = self.trainer.compute_loss(new_batch, state_dim=state_dim)
|
||||
losses = self.trainer.get_loss(new_batch, state_dim=state_dim)
|
||||
feature_importance[i] = losses["loss"].cpu().detach().item() - orig_loss
|
||||
del losses
|
||||
|
||||
@@ -151,7 +142,7 @@ class FeatureImportanceEvaluator(object):
|
||||
not_terminal=batch.not_terminal,
|
||||
step=None,
|
||||
)
|
||||
losses = self.trainer.compute_loss(new_batch, state_dim=state_dim)
|
||||
losses = self.trainer.get_loss(new_batch, state_dim=state_dim)
|
||||
feature_importance[i + action_feature_num] = (
|
||||
losses["loss"].cpu().detach().item() - orig_loss
|
||||
)
|
||||
@@ -161,7 +152,6 @@ class FeatureImportanceEvaluator(object):
|
||||
logger.info(
|
||||
"**** Debug tool feature importance ****: {}".format(feature_importance)
|
||||
)
|
||||
self.reporter.report(feature_importance=feature_importance.tolist())
|
||||
return {"feature_loss_increase": feature_importance.numpy()}
|
||||
|
||||
def compute_median_feature_value(self, features):
|
||||
@@ -180,9 +170,6 @@ class FeatureImportanceEvaluator(object):
|
||||
median_feature = features.mean(dim=0)
|
||||
return median_feature
|
||||
|
||||
def finish(self):
|
||||
pass
|
||||
|
||||
|
||||
class FeatureSensitivityEvaluator(object):
|
||||
""" Evaluate state feature sensitivity caused by varying actions """
|
||||
@@ -196,7 +183,6 @@ class FeatureSensitivityEvaluator(object):
|
||||
self.trainer = trainer
|
||||
self.state_feature_num = state_feature_num
|
||||
self.sorted_state_feature_start_indices = sorted_state_feature_start_indices
|
||||
self.reporter = DebugToolsReporter()
|
||||
|
||||
def evaluate(self, batch: MemoryNetworkInput):
|
||||
""" Calculate state feature sensitivity due to actions:
|
||||
@@ -254,8 +240,4 @@ class FeatureSensitivityEvaluator(object):
|
||||
logger.info(
|
||||
"**** Debug tool feature sensitivity ****: {}".format(feature_sensitivity)
|
||||
)
|
||||
self.reporter.report(feature_sensitivity=feature_sensitivity.tolist())
|
||||
return {"feature_sensitivity": feature_sensitivity.numpy()}
|
||||
|
||||
def finish(self):
|
||||
pass
|
||||
|
||||
@@ -19,7 +19,7 @@ import random
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
import reagent.core.types as rlt
|
||||
import reagent.types as rlt
|
||||
import torch
|
||||
from reagent.core.dataclasses import dataclass
|
||||
from reagent.gym.envs.env_wrapper import EnvWrapper
|
||||
|
||||
@@ -7,7 +7,7 @@ from typing import Callable, Optional
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
import reagent.core.types as rlt
|
||||
import reagent.types as rlt
|
||||
import torch
|
||||
from gym import spaces
|
||||
from reagent.core.dataclasses import dataclass
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Optional, Tuple
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
import reagent.core.types as rlt
|
||||
import reagent.types as rlt
|
||||
import torch
|
||||
from gym import spaces
|
||||
from gym_minigrid.wrappers import ReseedWrapper
|
||||
|
||||
@@ -14,7 +14,7 @@ from typing import Optional
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
import reagent.core.types as rlt
|
||||
import reagent.types as rlt
|
||||
import torch
|
||||
from gym.spaces import Box
|
||||
from reagent.gym.envs.env_wrapper import EnvWrapper
|
||||
|
||||
@@ -5,7 +5,7 @@ import logging
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
import reagent.core.types as rlt
|
||||
import reagent.types as rlt
|
||||
from reagent.core.dataclasses import dataclass
|
||||
from reagent.gym.envs.env_wrapper import EnvWrapper
|
||||
from reagent.gym.envs.wrappers.recsim import ValueWrapper
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
import reagent.core.types as rlt
|
||||
import reagent.types as rlt
|
||||
from reagent.gym.types import Sampler, Scorer
|
||||
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
from typing import Any, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import reagent.core.types as rlt
|
||||
import reagent.types as rlt
|
||||
import torch
|
||||
from reagent.core.fb_checker import IS_FB_ENVIRONMENT
|
||||
from reagent.gym.policies import Policy
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import List, Optional
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
import reagent.core.types as rlt
|
||||
import reagent.types as rlt
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from reagent.gym.policies.policy import Policy
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
import reagent.core.types as rlt
|
||||
import reagent.types as rlt
|
||||
import torch
|
||||
from reagent.gym.types import GaussianSamplerScore, Sampler
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
|
||||
import reagent.core.types as rlt
|
||||
import reagent.types as rlt
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from reagent.gym.types import Sampler
|
||||
@@ -41,9 +41,6 @@ class SoftmaxActionSampler(Sampler):
|
||||
assert raw_action.shape == (
|
||||
batch_size,
|
||||
), f"{raw_action.shape} != ({batch_size}, )"
|
||||
assert (
|
||||
int(raw_action.max().item()) < num_actions
|
||||
), f"Invalid action: {int(raw_action.max().item())}"
|
||||
action = F.one_hot(raw_action, num_actions)
|
||||
assert action.ndim == 2
|
||||
log_prob = m.log_prob(raw_action)
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
|
||||
import reagent.core.types as rlt
|
||||
import reagent.types as rlt
|
||||
import torch
|
||||
from reagent.gym.types import Sampler
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
import reagent.core.types as rlt
|
||||
import reagent.types as rlt
|
||||
import torch
|
||||
from reagent.gym.types import GaussianSamplerScore, Scorer
|
||||
from reagent.models.base import ModelBase
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import reagent.core.types as rlt
|
||||
import reagent.types as rlt
|
||||
import torch
|
||||
from reagent.gym.preprocessors.trainer_preprocessor import get_possible_actions_for_gym
|
||||
from reagent.gym.types import Scorer
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
import reagent.core.types as rlt
|
||||
import reagent.types as rlt
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from reagent.gym.types import Scorer
|
||||
|
||||
@@ -7,7 +7,7 @@ import logging
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import reagent.core.types as rlt
|
||||
import reagent.types as rlt
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from gym import Env, spaces
|
||||
|
||||
@@ -9,7 +9,7 @@ from typing import Optional
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
import reagent.core.types as rlt
|
||||
import reagent.types as rlt
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from reagent.parameters import CONTINUOUS_TRAINING_ACTION_RANGE
|
||||
|
||||
@@ -75,5 +75,5 @@ train_every_ts: 1
|
||||
train_after_ts: 20000
|
||||
num_train_episodes: 10
|
||||
num_eval_episodes: 10
|
||||
passing_score_bar: 190
|
||||
passing_score_bar: 200
|
||||
use_gpu: false
|
||||
|
||||
@@ -17,21 +17,13 @@ from reagent.gym.agents.post_step import train_with_replay_buffer_post_step
|
||||
from reagent.gym.envs.union import Env__Union
|
||||
from reagent.gym.runners.gymrunner import evaluate_for_n_episodes, run_episode
|
||||
from reagent.gym.utils import build_normalizer, fill_replay_buffer
|
||||
from reagent.model_managers.model_manager import ModelManager
|
||||
from reagent.model_managers.union import ModelManager__Union
|
||||
from reagent.replay_memory.circular_replay_buffer import ReplayBuffer
|
||||
from reagent.tensorboardX import summary_writer_context
|
||||
from reagent.test.base.horizon_test_base import HorizonTestBase
|
||||
from reagent.workflow.model_managers.union import ModelManager__Union
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
|
||||
try:
|
||||
# Use internal runner or OSS otherwise
|
||||
from reagent.runners.fb.fb_batch_runner import FbBatchRunner as BatchRunner
|
||||
except ImportError:
|
||||
from reagent.runners.oss_batch_runner import OssBatchRunner as BatchRunner
|
||||
|
||||
|
||||
# for seeding the environment
|
||||
SEED = 0
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -116,12 +108,13 @@ def run_test(
|
||||
normalization = build_normalizer(env)
|
||||
logger.info(f"Normalization is: \n{pprint.pformat(normalization)}")
|
||||
|
||||
manager: ModelManager = model.value
|
||||
runner = BatchRunner(use_gpu, manager, RewardOptions(), normalization)
|
||||
trainer = runner.initialize_trainer()
|
||||
reporter = manager.get_reporter()
|
||||
trainer.reporter = reporter
|
||||
training_policy = manager.create_policy(trainer)
|
||||
manager = model.value
|
||||
trainer = manager.initialize_trainer(
|
||||
use_gpu=use_gpu,
|
||||
reward_options=RewardOptions(),
|
||||
normalization_data_map=normalization,
|
||||
)
|
||||
training_policy = manager.create_policy(serving=False)
|
||||
|
||||
replay_buffer = ReplayBuffer(
|
||||
replay_capacity=replay_memory_size, batch_size=trainer.minibatch_size
|
||||
@@ -172,7 +165,7 @@ def run_test(
|
||||
f"{len(train_rewards)} episodes is less than < {passing_score_bar}.\n"
|
||||
)
|
||||
|
||||
serving_policy = manager.create_serving_policy(normalization, trainer)
|
||||
serving_policy = manager.create_policy(serving=True)
|
||||
agent = Agent.create_for_env_with_serving_policy(env, serving_policy)
|
||||
|
||||
eval_rewards = evaluate_for_n_episodes(
|
||||
|
||||
@@ -17,22 +17,14 @@ from reagent.gym.envs.gym import Gym
|
||||
from reagent.gym.preprocessors import make_replay_buffer_trainer_preprocessor
|
||||
from reagent.gym.runners.gymrunner import evaluate_for_n_episodes
|
||||
from reagent.gym.utils import build_normalizer, fill_replay_buffer
|
||||
from reagent.model_managers.union import ModelManager__Union
|
||||
from reagent.replay_memory.circular_replay_buffer import ReplayBuffer
|
||||
from reagent.runners.oss_batch_runner import OssBatchRunner
|
||||
from reagent.tensorboardX import summary_writer_context
|
||||
from reagent.test.base.horizon_test_base import HorizonTestBase
|
||||
from reagent.workflow.model_managers.union import ModelManager__Union
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
try:
|
||||
# Use internal runner or OSS otherwise
|
||||
from reagent.runners.fb.fb_batch_runner import FbBatchRunner as BatchRunner
|
||||
except ImportError:
|
||||
from reagent.runners.oss_batch_runner import OssBatchRunner as BatchRunner
|
||||
|
||||
|
||||
# for seeding the environment
|
||||
SEED = 0
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -86,7 +78,7 @@ class TestGymOffline(HorizonTestBase):
|
||||
|
||||
def evaluate_cem(env, manager, num_eval_episodes: int):
|
||||
# NOTE: for CEM, serving isn't implemented
|
||||
policy = manager.create_policy()
|
||||
policy = manager.create_policy(serving=False)
|
||||
agent = Agent.create_for_env(env, policy)
|
||||
return evaluate_for_n_episodes(
|
||||
n=num_eval_episodes, env=env, agent=agent, max_steps=env.max_steps
|
||||
@@ -110,13 +102,11 @@ def run_test_offline(
|
||||
logger.info(f"Normalization is: \n{pprint.pformat(normalization)}")
|
||||
|
||||
manager = model.value
|
||||
runner = OssBatchRunner(
|
||||
use_gpu,
|
||||
manager,
|
||||
trainer = manager.initialize_trainer(
|
||||
use_gpu=use_gpu,
|
||||
reward_options=RewardOptions(),
|
||||
normalization_data_map=normalization,
|
||||
)
|
||||
trainer = runner.initialize_trainer()
|
||||
|
||||
# first fill the replay buffer to burn_in
|
||||
replay_buffer = ReplayBuffer(
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
from typing import Optional, cast
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from reagent.core.types import RewardOptions
|
||||
@@ -12,18 +12,10 @@ from reagent.gym.envs.env_wrapper import EnvWrapper
|
||||
from reagent.gym.envs.gym import Gym
|
||||
from reagent.gym.preprocessors import make_replay_buffer_trainer_preprocessor
|
||||
from reagent.gym.utils import build_normalizer, fill_replay_buffer
|
||||
from reagent.model_managers.union import ModelManager__Union
|
||||
from reagent.replay_memory.circular_replay_buffer import ReplayBuffer
|
||||
from reagent.runners.oss_batch_runner import OssBatchRunner
|
||||
from reagent.test.base.horizon_test_base import HorizonTestBase
|
||||
from reagent.training.world_model.seq2reward_trainer import Seq2RewardTrainer
|
||||
|
||||
|
||||
try:
|
||||
# Use internal runner or OSS otherwise
|
||||
from reagent.runners.fb.fb_batch_runner import FbBatchRunner as BatchRunner
|
||||
except ImportError:
|
||||
from reagent.runners.oss_batch_runner import OssBatchRunner as BatchRunner
|
||||
from reagent.workflow.model_managers.union import ModelManager__Union
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
@@ -79,8 +71,8 @@ def train_seq2reward(
|
||||
)
|
||||
preprocessed_test_batch = trainer_preprocessor(test_batch)
|
||||
adhoc_action_padding(preprocessed_test_batch, state_dim=state_dim)
|
||||
# valid_losses = trainer.get_loss(preprocessed_test_batch)
|
||||
# print_seq2reward_losses(epoch, "validation", valid_losses)
|
||||
valid_losses = trainer.get_loss(preprocessed_test_batch)
|
||||
print_seq2reward_losses(epoch, "validation", valid_losses)
|
||||
trainer.seq2reward_network.train()
|
||||
return trainer
|
||||
|
||||
@@ -117,13 +109,11 @@ def train_seq2reward_and_compute_reward_mse(
|
||||
env.seed(SEED)
|
||||
|
||||
manager = model.value
|
||||
runner = OssBatchRunner(
|
||||
use_gpu,
|
||||
manager,
|
||||
trainer = manager.initialize_trainer(
|
||||
use_gpu=use_gpu,
|
||||
reward_options=RewardOptions(),
|
||||
normalization_data_map=build_normalizer(env),
|
||||
)
|
||||
trainer = cast(Seq2RewardTrainer, runner.initialize_trainer())
|
||||
|
||||
device = "cuda" if use_gpu else "cpu"
|
||||
# pyre-fixme[6]: Expected `device` for 2nd param but got `str`.
|
||||
@@ -159,7 +149,7 @@ def train_seq2reward_and_compute_reward_mse(
|
||||
)
|
||||
preprocessed_test_batch = trainer_preprocessor(test_batch)
|
||||
adhoc_action_padding(preprocessed_test_batch, state_dim=state_dim)
|
||||
losses = trainer.compute_loss(preprocessed_test_batch)
|
||||
losses = trainer.get_loss(preprocessed_test_batch)
|
||||
detached_losses = losses.cpu().detach().item()
|
||||
trainer.seq2reward_network.train()
|
||||
return detached_losses
|
||||
|
||||
@@ -3,11 +3,11 @@
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
from typing import Dict, List, Optional, cast
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
import reagent.core.types as rlt
|
||||
import reagent.types as rlt
|
||||
import torch
|
||||
from reagent.core.types import RewardOptions
|
||||
from reagent.evaluation.world_model_evaluator import (
|
||||
@@ -21,21 +21,14 @@ from reagent.gym.envs.pomdp.state_embed_env import StateEmbedEnvironment
|
||||
from reagent.gym.preprocessors import make_replay_buffer_trainer_preprocessor
|
||||
from reagent.gym.runners.gymrunner import evaluate_for_n_episodes
|
||||
from reagent.gym.utils import build_normalizer, fill_replay_buffer
|
||||
from reagent.model_managers.union import ModelManager__Union
|
||||
from reagent.models.world_model import MemoryNetwork
|
||||
from reagent.replay_memory.circular_replay_buffer import ReplayBuffer
|
||||
from reagent.test.base.horizon_test_base import HorizonTestBase
|
||||
from reagent.training.world_model.mdnrnn_trainer import MDNRNNTrainer
|
||||
from reagent.workflow.model_managers.union import ModelManager__Union
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
try:
|
||||
# Use internal runner or OSS otherwise
|
||||
from reagent.runners.fb.fb_batch_runner import FbBatchRunner as BatchRunner
|
||||
except ImportError:
|
||||
from reagent.runners.oss_batch_runner import OssBatchRunner as BatchRunner
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.INFO)
|
||||
@@ -156,7 +149,7 @@ def train_mdnrnn(
|
||||
batch_size=batch_size
|
||||
)
|
||||
preprocessed_test_batch = trainer_preprocessor(test_batch)
|
||||
valid_losses = trainer.compute_loss(preprocessed_test_batch)
|
||||
valid_losses = trainer.get_loss(preprocessed_test_batch)
|
||||
print_mdnrnn_losses(epoch, "validation", valid_losses)
|
||||
trainer.memory_network.mdnrnn.train()
|
||||
return trainer
|
||||
@@ -178,13 +171,11 @@ def train_mdnrnn_and_compute_feature_stats(
|
||||
env.seed(SEED)
|
||||
|
||||
manager = model.value
|
||||
runner = BatchRunner(
|
||||
use_gpu,
|
||||
manager,
|
||||
trainer = manager.initialize_trainer(
|
||||
use_gpu=use_gpu,
|
||||
reward_options=RewardOptions(),
|
||||
normalization_data_map=build_normalizer(env),
|
||||
)
|
||||
trainer = cast(MDNRNNTrainer, runner.initialize_trainer())
|
||||
|
||||
device = "cuda" if use_gpu else "cpu"
|
||||
# pyre-fixme[6]: Expected `device` for 2nd param but got `str`.
|
||||
@@ -297,13 +288,11 @@ def train_mdnrnn_and_train_on_embedded_env(
|
||||
env.seed(SEED)
|
||||
|
||||
embedding_manager = embedding_model.value
|
||||
embedding_runner = BatchRunner(
|
||||
use_gpu,
|
||||
embedding_manager,
|
||||
embedding_trainer = embedding_manager.initialize_trainer(
|
||||
use_gpu=use_gpu,
|
||||
reward_options=RewardOptions(),
|
||||
normalization_data_map=build_normalizer(env),
|
||||
)
|
||||
embedding_trainer = cast(MDNRNNTrainer, embedding_runner.initialize_trainer())
|
||||
|
||||
device = "cuda" if use_gpu else "cpu"
|
||||
embedding_trainer_preprocessor = make_replay_buffer_trainer_preprocessor(
|
||||
@@ -347,14 +336,13 @@ def train_mdnrnn_and_train_on_embedded_env(
|
||||
state_max_value=state_max,
|
||||
)
|
||||
agent_manager = train_model.value
|
||||
agent_trainer = agent_manager.build_trainer(
|
||||
agent_trainer = agent_manager.initialize_trainer(
|
||||
use_gpu=use_gpu,
|
||||
reward_options=RewardOptions(),
|
||||
# pyre-fixme[6]: Expected `EnvWrapper` for 1st param but got
|
||||
# `StateEmbedEnvironment`.
|
||||
normalization_data_map=build_normalizer(embed_env),
|
||||
)
|
||||
agent_trainer.reporter = agent_manager.get_reporter()
|
||||
device = "cuda" if use_gpu else "cpu"
|
||||
agent_trainer_preprocessor = make_replay_buffer_trainer_preprocessor(
|
||||
agent_trainer,
|
||||
@@ -371,7 +359,7 @@ def train_mdnrnn_and_train_on_embedded_env(
|
||||
|
||||
# evaluate model
|
||||
rewards = []
|
||||
policy = agent_manager.create_policy(agent_trainer)
|
||||
policy = agent_manager.create_policy(serving=False)
|
||||
# pyre-fixme[6]: Expected `EnvWrapper` for 1st param but got
|
||||
# `StateEmbedEnvironment`.
|
||||
agent = Agent.create_for_env(embed_env, policy=policy, device=device)
|
||||
|
||||
@@ -9,7 +9,7 @@ from dataclasses import asdict, dataclass, field, fields
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import reagent.core.types as rlt
|
||||
import reagent.types as rlt
|
||||
import torch
|
||||
|
||||
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
import collections
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import asdict, fields, is_dataclass
|
||||
from dataclasses import asdict, dataclass, fields, is_dataclass
|
||||
from typing import Any, NamedTuple, Type, Union
|
||||
|
||||
|
||||
|
||||
@@ -1,148 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from reagent.core import types as rlt
|
||||
from reagent.core.dataclasses import dataclass, field
|
||||
from reagent.core.types import (
|
||||
Dataset,
|
||||
PreprocessingOptions,
|
||||
ReaderOptions,
|
||||
RewardOptions,
|
||||
TableSpec,
|
||||
)
|
||||
from reagent.core.union import ModelFeatureConfigProvider__Union
|
||||
from reagent.data_fetchers.data_fetcher import DataFetcher
|
||||
from reagent.evaluation.evaluator import Evaluator, get_metrics_to_score
|
||||
from reagent.gym.policies.policy import Policy
|
||||
from reagent.gym.policies.samplers.discrete_sampler import SoftmaxActionSampler
|
||||
from reagent.gym.policies.scorers.discrete_scorer import discrete_dqn_scorer
|
||||
from reagent.model_managers.model_manager import ModelManager
|
||||
from reagent.models.base import ModelBase
|
||||
from reagent.models.model_feature_config_provider import RawModelFeatureConfigProvider
|
||||
from reagent.parameters import EvaluationParameters, NormalizationData, NormalizationKey
|
||||
from reagent.preprocessing.batch_preprocessor import (
|
||||
BatchPreprocessor,
|
||||
DiscreteDqnBatchPreprocessor,
|
||||
)
|
||||
from reagent.preprocessing.preprocessor import Preprocessor
|
||||
from reagent.preprocessing.types import InputColumn
|
||||
from reagent.reporting.discrete_dqn_reporter import DiscreteDQNReporter
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DiscreteDQNBase(ModelManager):
|
||||
target_action_distribution: Optional[List[float]] = None
|
||||
state_feature_config_provider: ModelFeatureConfigProvider__Union = field(
|
||||
# pyre-fixme[28]: Unexpected keyword argument `raw`.
|
||||
# pyre-fixme[28]: Unexpected keyword argument `raw`.
|
||||
default_factory=lambda: ModelFeatureConfigProvider__Union(
|
||||
raw=RawModelFeatureConfigProvider(float_feature_infos=[])
|
||||
)
|
||||
)
|
||||
eval_parameters: EvaluationParameters = field(default_factory=EvaluationParameters)
|
||||
preprocessing_options: Optional[PreprocessingOptions] = None
|
||||
reader_options: Optional[ReaderOptions] = None
|
||||
|
||||
def __post_init_post_parse__(self):
|
||||
super().__init__()
|
||||
|
||||
def create_policy(self, trainer) -> Policy:
|
||||
""" Create an online DiscreteDQN Policy from env. """
|
||||
sampler = SoftmaxActionSampler(temperature=self.trainer_param.rl.temperature)
|
||||
scorer = discrete_dqn_scorer(trainer.q_network)
|
||||
return Policy(scorer=scorer, sampler=sampler)
|
||||
|
||||
@property
|
||||
def state_feature_config(self) -> rlt.ModelFeatureConfig:
|
||||
return self.state_feature_config_provider.value.get_model_feature_config()
|
||||
|
||||
def metrics_to_score(self, reward_options: RewardOptions) -> List[str]:
|
||||
return get_metrics_to_score(reward_options.metric_reward_values)
|
||||
|
||||
@property
|
||||
def should_generate_eval_dataset(self) -> bool:
|
||||
return self.eval_parameters.calc_cpe_in_training
|
||||
|
||||
@property
|
||||
def required_normalization_keys(self) -> List[str]:
|
||||
return [NormalizationKey.STATE]
|
||||
|
||||
def run_feature_identification(
|
||||
self, data_fetcher: DataFetcher, input_table_spec: TableSpec
|
||||
) -> Dict[str, NormalizationData]:
|
||||
preprocessing_options = self.preprocessing_options or PreprocessingOptions()
|
||||
logger.info("Overriding whitelist_features")
|
||||
state_features = [
|
||||
ffi.feature_id for ffi in self.state_feature_config.float_feature_infos
|
||||
]
|
||||
preprocessing_options = preprocessing_options._replace(
|
||||
whitelist_features=state_features
|
||||
)
|
||||
return {
|
||||
NormalizationKey.STATE: NormalizationData(
|
||||
dense_normalization_parameters=data_fetcher.identify_normalization_parameters(
|
||||
input_table_spec, InputColumn.STATE_FEATURES, preprocessing_options
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
def query_data(
|
||||
self,
|
||||
data_fetcher: DataFetcher,
|
||||
input_table_spec: TableSpec,
|
||||
sample_range: Optional[Tuple[float, float]],
|
||||
reward_options: RewardOptions,
|
||||
) -> Dataset:
|
||||
return data_fetcher.query_data(
|
||||
input_table_spec=input_table_spec,
|
||||
discrete_action=True,
|
||||
actions=self.trainer_param.actions,
|
||||
include_possible_actions=True,
|
||||
sample_range=sample_range,
|
||||
custom_reward_expression=reward_options.custom_reward_expression,
|
||||
multi_steps=self.multi_steps,
|
||||
gamma=self.trainer_param.rl.gamma,
|
||||
)
|
||||
|
||||
@property
|
||||
def multi_steps(self) -> Optional[int]:
|
||||
return self.trainer_param.rl.multi_steps
|
||||
|
||||
def build_batch_preprocessor(
|
||||
self,
|
||||
reader_options: ReaderOptions,
|
||||
use_gpu: bool,
|
||||
batch_size: int,
|
||||
normalization_data_map: Dict[str, NormalizationData],
|
||||
reward_options: RewardOptions,
|
||||
) -> BatchPreprocessor:
|
||||
state_preprocessor = Preprocessor(
|
||||
normalization_data_map[
|
||||
NormalizationKey.STATE
|
||||
].dense_normalization_parameters,
|
||||
use_gpu=use_gpu,
|
||||
)
|
||||
return DiscreteDqnBatchPreprocessor(
|
||||
num_actions=len(self.trainer_param.actions),
|
||||
state_preprocessor=state_preprocessor,
|
||||
use_gpu=use_gpu,
|
||||
)
|
||||
|
||||
def get_reporter(self):
|
||||
return DiscreteDQNReporter(
|
||||
self.trainer_param.actions,
|
||||
target_action_distribution=self.target_action_distribution,
|
||||
)
|
||||
|
||||
def get_evaluator(self, trainer, reward_options: RewardOptions):
|
||||
return Evaluator(
|
||||
self.trainer_param.actions,
|
||||
self.trainer_param.rl.gamma,
|
||||
trainer,
|
||||
metrics_to_score=self.metrics_to_score(reward_options),
|
||||
)
|
||||
@@ -1,130 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import abc
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from reagent.core.registry_meta import RegistryMeta
|
||||
from reagent.core.types import Dataset, ReaderOptions, RewardOptions, TableSpec
|
||||
from reagent.data_fetchers.data_fetcher import DataFetcher
|
||||
from reagent.gym.policies.policy import Policy
|
||||
from reagent.gym.policies.predictor_policies import create_predictor_policy_from_model
|
||||
from reagent.parameters import NormalizationData
|
||||
from reagent.preprocessing.batch_preprocessor import BatchPreprocessor
|
||||
from reagent.training.trainer import Trainer
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ModelManager(metaclass=RegistryMeta):
|
||||
"""
|
||||
ModelManager manages how to train models.
|
||||
|
||||
Each type of models can have their own config type, implemented as
|
||||
`config_type()` class method. `__init__()` of the concrete class must take
|
||||
this type.
|
||||
|
||||
ModelManager abstracts over common phases of training, i.e.,:
|
||||
1. `run_feature_identification()` defines how to derive feature preprocessing
|
||||
parameters from given data.
|
||||
2. `query_data()` massages the input table into the format expected by the trainer
|
||||
3. `initialize_trainer()` creates the trainer
|
||||
4. `train()`
|
||||
5. `build_serving_module()` builds the module for prediction
|
||||
6. `save_trainer()` saves the trainer for warmstarting
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def run_feature_identification(
|
||||
self, data_fetcher: DataFetcher, input_table_spec: TableSpec
|
||||
) -> Dict[str, NormalizationData]:
|
||||
"""
|
||||
Derive preprocessing parameters from data. The keys of the dict should
|
||||
match the keys from `required_normalization_keys()`
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def required_normalization_keys(self) -> List[str]:
|
||||
""" Get the normalization keys required for current instance """
|
||||
pass
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def should_generate_eval_dataset(self) -> bool:
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_evaluator(self, trainer, reward_options: RewardOptions):
|
||||
return None
|
||||
|
||||
@abc.abstractmethod
|
||||
def query_data(
|
||||
self,
|
||||
data_fetcher: DataFetcher,
|
||||
input_table_spec: TableSpec,
|
||||
sample_range: Optional[Tuple[float, float]],
|
||||
reward_options: RewardOptions,
|
||||
) -> Dataset:
|
||||
"""
|
||||
Massage input table into the format expected by the trainer
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_reporter(self):
|
||||
"""
|
||||
Get the reporter that displays statistics after training
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def build_batch_preprocessor(
|
||||
self,
|
||||
reader_options: ReaderOptions,
|
||||
use_gpu: bool,
|
||||
batch_size: int,
|
||||
normalization_data_map: Dict[str, NormalizationData],
|
||||
reward_options: RewardOptions,
|
||||
) -> BatchPreprocessor:
|
||||
"""
|
||||
The Batch Preprocessor is a module that transforms data to a form that can be (1) read by the trainer
|
||||
or (2) used in part of the serving module. For training, the batch preprocessor is typically run
|
||||
on reader machines in parallel so the GPUs on the trainer machines can be fully utilized.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def build_trainer(
|
||||
self,
|
||||
use_gpu: bool,
|
||||
normalization_data_map: Dict[str, NormalizationData],
|
||||
reward_options: RewardOptions,
|
||||
) -> Trainer:
|
||||
"""
|
||||
Implement this to build the trainer, given the config
|
||||
"""
|
||||
pass
|
||||
|
||||
def create_policy(self, trainer) -> Policy:
|
||||
""" Create a Policy from env. """
|
||||
raise NotImplementedError()
|
||||
|
||||
def create_serving_policy(
|
||||
self, normalization_data_map: Dict[str, NormalizationData], trainer
|
||||
) -> Policy:
|
||||
""" Create an online Policy from env. """
|
||||
return create_predictor_policy_from_model(
|
||||
self.build_serving_module(normalization_data_map, trainer)
|
||||
)
|
||||
|
||||
@abc.abstractmethod
|
||||
def build_serving_module(
|
||||
self, normalization_data_map: Dict[str, NormalizationData], trainer
|
||||
) -> torch.nn.Module:
|
||||
"""
|
||||
Returns TorchScript module to be used in predictor
|
||||
"""
|
||||
pass
|
||||
@@ -1,96 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import logging
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
from reagent.core.dataclasses import dataclass, field
|
||||
from reagent.core.types import RewardOptions
|
||||
from reagent.model_managers.parametric_dqn_base import ParametricDQNBase
|
||||
from reagent.net_builder.parametric_dqn.fully_connected import FullyConnected
|
||||
from reagent.net_builder.unions import ParametricDQNNetBuilder__Union
|
||||
from reagent.parameters import NormalizationData, NormalizationKey, param_hash
|
||||
from reagent.preprocessing.normalization import (
|
||||
get_feature_config,
|
||||
get_num_output_features,
|
||||
)
|
||||
from reagent.training import ParametricDQNTrainer, ParametricDQNTrainerParameters
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParametricDQN(ParametricDQNBase):
|
||||
__hash__ = param_hash
|
||||
|
||||
trainer_param: ParametricDQNTrainerParameters = field(
|
||||
default_factory=ParametricDQNTrainerParameters
|
||||
)
|
||||
net_builder: ParametricDQNNetBuilder__Union = field(
|
||||
# pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
|
||||
# pyre-fixme[28]: Unexpected keyword argument `FullyConnected`.
|
||||
default_factory=lambda: ParametricDQNNetBuilder__Union(
|
||||
FullyConnected=FullyConnected()
|
||||
)
|
||||
)
|
||||
|
||||
def __post_init_post_parse__(self):
|
||||
super().__post_init_post_parse__()
|
||||
|
||||
def build_trainer(
|
||||
self,
|
||||
use_gpu: bool,
|
||||
normalization_data_map: Dict[str, NormalizationData],
|
||||
reward_options: RewardOptions,
|
||||
) -> ParametricDQNTrainer:
|
||||
net_builder = self.net_builder.value
|
||||
q_network = net_builder.build_q_network(
|
||||
normalization_data_map[NormalizationKey.STATE],
|
||||
normalization_data_map[NormalizationKey.ACTION],
|
||||
)
|
||||
# Metrics + reward
|
||||
reward_output_dim = len(self.metrics_to_score(reward_options)) + 1
|
||||
reward_network = net_builder.build_q_network(
|
||||
normalization_data_map[NormalizationKey.STATE],
|
||||
normalization_data_map[NormalizationKey.ACTION],
|
||||
output_dim=reward_output_dim,
|
||||
)
|
||||
|
||||
if use_gpu:
|
||||
q_network = q_network.cuda()
|
||||
reward_network = reward_network.cuda()
|
||||
|
||||
q_network_target = q_network.get_target_network()
|
||||
trainer = ParametricDQNTrainer(
|
||||
q_network=q_network,
|
||||
q_network_target=q_network_target,
|
||||
reward_network=reward_network,
|
||||
use_gpu=use_gpu,
|
||||
# pyre-fixme[16]: `ParametricDQNTrainerParameters` has no attribute
|
||||
# `asdict`.
|
||||
# pyre-fixme[16]: `ParametricDQNTrainerParameters` has no attribute
|
||||
# `asdict`.
|
||||
**self.trainer_param.asdict(),
|
||||
)
|
||||
|
||||
# HACK: injecting num_actions to build policies for gym
|
||||
trainer.num_gym_actions = get_num_output_features(
|
||||
normalization_data_map[
|
||||
NormalizationKey.ACTION
|
||||
].dense_normalization_parameters
|
||||
)
|
||||
|
||||
return trainer
|
||||
|
||||
def build_serving_module(
|
||||
self,
|
||||
normalization_data_map: Dict[str, NormalizationData],
|
||||
trainer: ParametricDQNTrainer,
|
||||
) -> torch.nn.Module:
|
||||
net_builder = self.net_builder.value
|
||||
return net_builder.build_serving_module(
|
||||
trainer.q_network,
|
||||
normalization_data_map[NormalizationKey.STATE],
|
||||
normalization_data_map[NormalizationKey.ACTION],
|
||||
)
|
||||
@@ -5,7 +5,7 @@ import math
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
from reagent.core import types as rlt
|
||||
from reagent import types as rlt
|
||||
from reagent.models.base import ModelBase
|
||||
from reagent.models.fully_connected_network import FullyConnectedNetwork
|
||||
from reagent.parameters import CONTINUOUS_TRAINING_ACTION_RANGE
|
||||
|
||||
@@ -5,7 +5,7 @@ from copy import deepcopy
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch.nn as nn
|
||||
from reagent.core import types as rlt
|
||||
from reagent import types as rlt
|
||||
|
||||
|
||||
# add ABCMeta once https://github.com/sphinx-doc/sphinx/issues/5995 is fixed
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from reagent.core import types as rlt
|
||||
from reagent import types as rlt
|
||||
from reagent.models.base import ModelBase
|
||||
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ import numpy as np
|
||||
import scipy.stats as stats
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from reagent.core import types as rlt
|
||||
from reagent import types as rlt
|
||||
from reagent.models.base import ModelBase
|
||||
from reagent.models.world_model import MemoryNetwork
|
||||
from reagent.parameters import CONTINUOUS_TRAINING_ACTION_RANGE
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from reagent.core import types as rlt
|
||||
from reagent import types as rlt
|
||||
from reagent.models.base import ModelBase
|
||||
from reagent.models.fully_connected_network import FullyConnectedNetwork
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from reagent.core import types as rlt
|
||||
from reagent import types as rlt
|
||||
from reagent.models.base import ModelBase
|
||||
from reagent.models.fully_connected_network import FullyConnectedNetwork
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ import logging
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from reagent.core import types as rlt
|
||||
from reagent import types as rlt
|
||||
from reagent.models.base import ModelBase
|
||||
from reagent.models.critic import FullyConnectedCritic
|
||||
from reagent.models.dqn import FullyConnectedDQN
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
from reagent.core import types as rlt
|
||||
from reagent import types as rlt
|
||||
from reagent.models.base import ModelBase
|
||||
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as f
|
||||
from reagent.core import types as rlt
|
||||
from reagent import types as rlt
|
||||
from reagent.torch_utils import stack
|
||||
from torch.distributions.normal import Normal
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import abc
|
||||
|
||||
import reagent.core.types as rlt
|
||||
import reagent.types as rlt
|
||||
from reagent.core.dataclasses import dataclass
|
||||
from reagent.core.registry_meta import RegistryMeta
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from reagent.core import types as rlt
|
||||
from reagent import types as rlt
|
||||
from reagent.models.base import ModelBase
|
||||
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from reagent.core import types as rlt
|
||||
from reagent import types as rlt
|
||||
from reagent.models.base import ModelBase
|
||||
from torch.nn.parallel.distributed import DistributedDataParallel
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ import logging
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from reagent.core import types as rlt
|
||||
from reagent import types as rlt
|
||||
from reagent.models.base import ModelBase
|
||||
from reagent.models.seq2slate import (
|
||||
DECODER_START_SYMBOL,
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
import torch
|
||||
from reagent.core import types as rlt
|
||||
from reagent import types as rlt
|
||||
from reagent.models.base import ModelBase
|
||||
from reagent.models.mdn_rnn import MDNRNN
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
import abc
|
||||
from typing import List
|
||||
|
||||
import reagent.core.types as rlt
|
||||
import reagent.types as rlt
|
||||
import torch
|
||||
from reagent.core.fb_checker import IS_FB_ENVIRONMENT
|
||||
from reagent.core.registry_meta import RegistryMeta
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from typing import List
|
||||
|
||||
from reagent.core import types as rlt
|
||||
from reagent import types as rlt
|
||||
from reagent.core.dataclasses import dataclass, field
|
||||
from reagent.models.base import ModelBase
|
||||
from reagent.models.dueling_q_network import DuelingQNetwork
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from typing import List
|
||||
|
||||
from reagent.core import types as rlt
|
||||
from reagent import types as rlt
|
||||
from reagent.core.dataclasses import dataclass, field
|
||||
from reagent.models.base import ModelBase
|
||||
from reagent.models.dqn import FullyConnectedDQN
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from typing import List
|
||||
|
||||
import reagent.models as models
|
||||
from reagent.core import types as rlt
|
||||
from reagent import types as rlt
|
||||
from reagent.core.dataclasses import dataclass, field
|
||||
from reagent.net_builder.discrete_dqn_net_builder import DiscreteDQNNetBuilder
|
||||
from reagent.parameters import NormalizationData, param_hash
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
import abc
|
||||
from typing import List
|
||||
|
||||
import reagent.core.types as rlt
|
||||
import reagent.types as rlt
|
||||
import torch
|
||||
from reagent.core.fb_checker import IS_FB_ENVIRONMENT
|
||||
from reagent.core.registry_meta import RegistryMeta
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
import abc
|
||||
from typing import List
|
||||
|
||||
import reagent.core.types as rlt
|
||||
import reagent.types as rlt
|
||||
import torch
|
||||
from reagent.core.fb_checker import IS_FB_ENVIRONMENT
|
||||
from reagent.core.registry_meta import RegistryMeta
|
||||
|
||||
@@ -687,27 +687,15 @@ class NeuralDualDICE(RLEstimator):
|
||||
), "Expected all fields to be present"
|
||||
tgt_dist = input.target_policy.action_dist(t.state)
|
||||
tgt_action = tgt_dist.sample()[0]
|
||||
samples["init_state"].append(
|
||||
state.value.cpu().numpy()
|
||||
if isinstance(state.value, torch.Tensor)
|
||||
else state.value
|
||||
)
|
||||
samples["init_state"].append(state.value)
|
||||
samples["init_action"].append(
|
||||
torch.nn.functional.one_hot(
|
||||
torch.tensor(tgt_init_action.value, dtype=torch.long),
|
||||
self.action_dim,
|
||||
).float()
|
||||
)
|
||||
samples["last_state"].append(
|
||||
t.last_state.value.cpu().numpy()
|
||||
if isinstance(t.last_state.value, torch.Tensor)
|
||||
else t.last_state.value
|
||||
)
|
||||
samples["state"].append(
|
||||
t.state.value.cpu().numpy()
|
||||
if isinstance(t.state.value, torch.Tensor)
|
||||
else t.state.value
|
||||
)
|
||||
samples["last_state"].append(t.last_state.value)
|
||||
samples["state"].append(t.state.value)
|
||||
samples["log_action"].append(
|
||||
torch.nn.functional.one_hot(
|
||||
torch.tensor(t.action.value, dtype=torch.long), self.action_dim
|
||||
|
||||
@@ -58,7 +58,6 @@ class MDNRNNTrainerParameters(BaseDataClass):
|
||||
action_dim: int = 2
|
||||
action_names: List[str] = field(default_factory=lambda: [])
|
||||
multi_steps: int = 1
|
||||
shuffle_training_data: bool = False
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
||||
@@ -5,7 +5,7 @@ from enum import Enum
|
||||
from typing import Dict, Optional
|
||||
|
||||
from reagent.core.dataclasses import dataclass
|
||||
from reagent.core.types import BaseDataClass
|
||||
from reagent.types import BaseDataClass
|
||||
|
||||
|
||||
class LearningMethod(Enum):
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import reagent.core.types as rlt
|
||||
import reagent.types as rlt
|
||||
import torch
|
||||
from reagent.models.base import ModelBase
|
||||
from reagent.models.seq2slate import Seq2SlateMode, Seq2SlateTransformerNet
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import Dict
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from reagent.core import types as rlt
|
||||
from reagent import types as rlt
|
||||
from reagent.preprocessing.preprocessor import Preprocessor
|
||||
|
||||
|
||||
|
||||
@@ -7,24 +7,12 @@ from dataclasses import asdict
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import reagent.core.types as rlt
|
||||
import reagent.types as rlt
|
||||
import six
|
||||
import torch
|
||||
from reagent.parameters import NormalizationData, NormalizationParameters
|
||||
from reagent.preprocessing import identify_types
|
||||
from reagent.preprocessing.identify_types import DEFAULT_MAX_UNIQUE_ENUM, FEATURE_TYPES
|
||||
from reagent.preprocessing.normalization_constants import (
|
||||
BOX_COX_MARGIN,
|
||||
BOX_COX_MAX_STDDEV,
|
||||
DEFAULT_MAX_QUANTILE_SIZE,
|
||||
DEFAULT_NUM_SAMPLES,
|
||||
DEFAULT_QUANTILE_K2_THRESHOLD,
|
||||
EPS,
|
||||
MAX_FEATURE_VALUE,
|
||||
MIN_FEATURE_VALUE,
|
||||
MINIMUM_SAMPLES_TO_IDENTIFY,
|
||||
MISSING_VALUE,
|
||||
)
|
||||
from scipy import stats
|
||||
from scipy.stats.mstats import mquantiles
|
||||
|
||||
@@ -32,6 +20,18 @@ from scipy.stats.mstats import mquantiles
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
BOX_COX_MAX_STDDEV = 1e8
|
||||
BOX_COX_MARGIN = 1e-4
|
||||
MISSING_VALUE = -1337.1337
|
||||
DEFAULT_QUANTILE_K2_THRESHOLD = 1000.0
|
||||
MINIMUM_SAMPLES_TO_IDENTIFY = 20
|
||||
DEFAULT_MAX_QUANTILE_SIZE = 20
|
||||
DEFAULT_NUM_SAMPLES = 100000
|
||||
MAX_FEATURE_VALUE = 6.0
|
||||
MIN_FEATURE_VALUE = MAX_FEATURE_VALUE * -1
|
||||
EPS = 1e-6
|
||||
|
||||
|
||||
def no_op_feature():
|
||||
return NormalizationParameters(
|
||||
identify_types.CONTINUOUS, None, 0, 0, 1, None, None, None, None
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
from reagent.preprocessing.identify_types import ( # noqa
|
||||
DEFAULT_MAX_UNIQUE_ENUM,
|
||||
FEATURE_TYPES,
|
||||
)
|
||||
|
||||
|
||||
BOX_COX_MAX_STDDEV = 1e8
|
||||
BOX_COX_MARGIN = 1e-4
|
||||
MISSING_VALUE = -1337.1337
|
||||
DEFAULT_QUANTILE_K2_THRESHOLD = 1000.0
|
||||
MINIMUM_SAMPLES_TO_IDENTIFY = 20
|
||||
DEFAULT_MAX_QUANTILE_SIZE = 20
|
||||
DEFAULT_NUM_SAMPLES = 100000
|
||||
MAX_FEATURE_VALUE = 6.0
|
||||
MIN_FEATURE_VALUE = MAX_FEATURE_VALUE * -1
|
||||
EPS = 1e-6
|
||||
@@ -4,7 +4,7 @@
|
||||
import logging
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import reagent.core.types as rlt
|
||||
import reagent.types as rlt
|
||||
import torch
|
||||
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ import logging
|
||||
from typing import Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import reagent.core.types as rlt
|
||||
import reagent.types as rlt
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from reagent.parameters import NormalizationData
|
||||
|
||||
@@ -6,10 +6,9 @@ from typing import Optional
|
||||
|
||||
from reagent.core.dataclasses import dataclass
|
||||
from reagent.core.result_types import NoPublishingResults
|
||||
from reagent.core.rl_training_output import RLTrainingOutput
|
||||
from reagent.core.types import RecurringPeriod
|
||||
from reagent.model_managers.model_manager import ModelManager
|
||||
from reagent.core.types import RecurringPeriod, RLTrainingOutput
|
||||
from reagent.publishers.model_publisher import ModelPublisher
|
||||
from reagent.workflow.model_managers.model_manager import ModelManager
|
||||
|
||||
|
||||
try:
|
||||
@@ -73,7 +72,7 @@ if HAS_TINYDB:
|
||||
child_workflow_id: int,
|
||||
recurring_period: Optional[RecurringPeriod],
|
||||
) -> NoPublishingResults:
|
||||
path = training_output.local_output_path
|
||||
path = training_output.output_path
|
||||
assert path is not None, f"Given path is None."
|
||||
assert os.path.exists(path), f"Given path {path} doesn't exist."
|
||||
Model = Query()
|
||||
|
||||
@@ -5,10 +5,9 @@ import inspect
|
||||
from typing import Optional
|
||||
|
||||
from reagent.core.registry_meta import RegistryMeta
|
||||
from reagent.core.rl_training_output import RLTrainingOutput
|
||||
from reagent.core.types import RecurringPeriod
|
||||
from reagent.model_managers.model_manager import ModelManager
|
||||
from reagent.reporting.result_registries import PublishingResult
|
||||
from reagent.core.types import RecurringPeriod, RLTrainingOutput
|
||||
from reagent.workflow.model_managers.model_manager import ModelManager
|
||||
from reagent.workflow.result_registries import PublishingResult
|
||||
|
||||
|
||||
class ModelPublisher(metaclass=RegistryMeta):
|
||||
@@ -39,7 +38,7 @@ class ModelPublisher(metaclass=RegistryMeta):
|
||||
recurring_period,
|
||||
)
|
||||
# Avoid circular dependency at import time
|
||||
from reagent.core.union import PublishingResult__Union
|
||||
from reagent.core.types import PublishingResult__Union
|
||||
|
||||
# We need to use inspection because the result can be a future when running on
|
||||
# FBL
|
||||
|
||||
@@ -4,10 +4,9 @@ from typing import Optional
|
||||
|
||||
from reagent.core.dataclasses import dataclass
|
||||
from reagent.core.result_types import NoPublishingResults
|
||||
from reagent.core.rl_training_output import RLTrainingOutput
|
||||
from reagent.core.types import RecurringPeriod
|
||||
from reagent.model_managers.model_manager import ModelManager
|
||||
from reagent.core.types import RecurringPeriod, RLTrainingOutput
|
||||
from reagent.publishers.model_publisher import ModelPublisher
|
||||
from reagent.workflow.model_managers.model_manager import ModelManager
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -1,24 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
from reagent.core.fb_checker import IS_FB_ENVIRONMENT
|
||||
|
||||
|
||||
if True: # To prevent auto sorting of inputs
|
||||
# Triggering registration to registries
|
||||
import reagent.core.result_types # noqa
|
||||
import reagent.reporting.oss_training_reports # noqa
|
||||
from reagent.model_managers.union import * # noqa
|
||||
|
||||
if IS_FB_ENVIRONMENT:
|
||||
import reagent.core.fb.fb_result_types # noqa
|
||||
|
||||
# Register all unions
|
||||
from reagent.core.union import * # noqa
|
||||
from reagent.model_managers.union import * # noqa
|
||||
from reagent.optimizer.union import * # noqa
|
||||
from reagent.publishers.union import * # noqa
|
||||
from reagent.validators.union import * # noqa
|
||||
|
||||
if IS_FB_ENVIRONMENT:
|
||||
from reagent.model_managers.fb.union import * # noqa
|
||||
@@ -1,55 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import itertools
|
||||
import logging
|
||||
|
||||
from reagent.core import aggregators as agg
|
||||
from reagent.core.rl_training_output import RLTrainingOutput
|
||||
from reagent.core.union import TrainingReport__Union
|
||||
from reagent.reporting.oss_training_reports import OssActorCriticTrainingReport
|
||||
from reagent.reporting.reporter_base import ReporterBase
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ActorCriticReporter(ReporterBase):
|
||||
def __init__(self, report_interval: int = 100):
|
||||
aggregators = itertools.chain(
|
||||
[
|
||||
("cpe_results", agg.AppendAggregator("cpe_details")),
|
||||
("td_loss", agg.MeanAggregator("td_loss", interval=report_interval)),
|
||||
(
|
||||
"reward_loss",
|
||||
agg.MeanAggregator("reward_loss", interval=report_interval),
|
||||
),
|
||||
(
|
||||
"recent_rewards",
|
||||
agg.RecentValuesAggregator(
|
||||
"logged_rewards", interval=report_interval
|
||||
),
|
||||
),
|
||||
],
|
||||
[
|
||||
(
|
||||
f"{key}_tb",
|
||||
agg.TensorBoardHistogramAndMeanAggregator(
|
||||
key, log_key, interval=report_interval
|
||||
),
|
||||
)
|
||||
for key, log_key in [
|
||||
("td_loss", "td_loss"),
|
||||
("reward_loss", "reward_loss"),
|
||||
("logged_propensities", "propensities/logged"),
|
||||
("logged_rewards", "reward/logged"),
|
||||
]
|
||||
],
|
||||
)
|
||||
super().__init__(aggregators)
|
||||
|
||||
# TODO: T71636196 write this for OSS
|
||||
def publish(self) -> RLTrainingOutput:
|
||||
report = OssActorCriticTrainingReport()
|
||||
return RLTrainingOutput(
|
||||
training_report=TrainingReport__Union(oss_actor_critic_report=report)
|
||||
)
|
||||
@@ -1,109 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import itertools
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
from reagent.core import aggregators as agg
|
||||
from reagent.core.rl_training_output import RLTrainingOutput
|
||||
from reagent.core.union import TrainingReport__Union
|
||||
from reagent.reporting.oss_training_reports import OssDQNTrainingReport
|
||||
from reagent.reporting.reporter_base import ReporterBase
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DiscreteDQNReporter(ReporterBase):
    """Reporter that registers the aggregators used for discrete-action DQN."""

    def __init__(
        self,
        actions: List[str],
        report_interval: int = 100,
        target_action_distribution: Optional[List[float]] = None,
        recent_window_size: int = 100,
    ):
        """Build the aggregator list and hand it to ``ReporterBase``.

        Args:
            actions: Action names, forwarded to every per-action aggregator.
            report_interval: Forwarded as ``interval`` to every aggregator
                that accepts one.
            target_action_distribution: Stored on the instance unchanged.
            recent_window_size: Stored on the instance unchanged.
        """
        # Plain aggregators, keyed by the name they are registered under.
        registered = [
            ("CPE Results", agg.AppendAggregator("cpe_details")),
            ("TD Loss", agg.MeanAggregator("td_loss", interval=report_interval)),
            (
                "Reward Loss",
                agg.MeanAggregator("reward_loss", interval=report_interval),
            ),
            (
                "Model Action Values",
                agg.FunctionsByActionAggregator(
                    "model_values",
                    actions,
                    {"mean": torch.mean, "std": torch.std},
                    interval=report_interval,
                ),
            ),
            (
                "Logged Actions",
                agg.ActionCountAggregator(
                    "logged_actions", actions, interval=report_interval
                ),
            ),
            (
                "model_action",
                agg.ActionCountAggregator(
                    "model_action_idxs", actions, interval=report_interval
                ),
            ),
            (
                "Recent Logged Rewards",
                agg.RecentValuesAggregator("logged_rewards", interval=report_interval),
            ),
        ]

        # TensorBoard action counts: (reported key, chart title).
        for key, title in (
            ("logged_actions", "logged"),
            ("model_action_idxs", "model"),
        ):
            registered.append(
                (
                    f"{key}_tb",
                    agg.TensorBoardActionCountAggregator(
                        key, title, actions, interval=report_interval
                    ),
                )
            )

        # TensorBoard histogram + mean: (reported key, log key).
        for key, log_key in (
            ("td_loss", "td_loss"),
            ("reward_loss", "reward_loss"),
            ("logged_propensities", "propensities/logged"),
            ("logged_rewards", "reward/logged"),
        ):
            registered.append(
                (
                    f"{key}_tb",
                    agg.TensorBoardHistogramAndMeanAggregator(
                        key, log_key, interval=report_interval
                    ),
                )
            )

        # TensorBoard per-action histogram + mean: (reported key, category, title).
        for key, category, title in (
            ("model_propensities", "propensities", "model"),
            ("model_rewards", "reward", "model"),
            ("model_values", "value", "model"),
        ):
            registered.append(
                (
                    f"{key}_tb",
                    agg.TensorBoardActionHistogramAndMeanAggregator(
                        key, category, title, actions, interval=report_interval
                    ),
                )
            )

        super().__init__(registered)
        self.target_action_distribution = target_action_distribution
        self.recent_window_size = recent_window_size

    def publish(self) -> RLTrainingOutput:
        """Return an ``RLTrainingOutput`` wrapping an empty DQN report."""
        report = OssDQNTrainingReport()
        return RLTrainingOutput(
            training_report=TrainingReport__Union(oss_dqn_report=report)
        )
|
||||
@@ -1,62 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from reagent.core.dataclasses import dataclass
|
||||
from reagent.evaluation.cpe import CpeEstimate
|
||||
from reagent.reporting.training_reports import TrainingReport
|
||||
|
||||
|
||||
@dataclass
class OssDQNTrainingReport(TrainingReport):
    """Training report registered under the name ``oss_dqn_report``.

    Every field is optional and defaults to ``None``.
    """

    __registry_name__ = "oss_dqn_report"

    # Scalar losses.
    td_loss: Optional[float] = None
    mc_loss: Optional[float] = None

    # CPE estimates for reward.
    reward_ips: Optional[CpeEstimate] = None
    reward_dm: Optional[CpeEstimate] = None
    reward_dr: Optional[CpeEstimate] = None

    # CPE estimates for value.
    value_sequential_dr: Optional[CpeEstimate] = None
    value_weighted_dr: Optional[CpeEstimate] = None
    value_magic_dr: Optional[CpeEstimate] = None
|
||||
|
||||
|
||||
@dataclass
class OssActorCriticTrainingReport(TrainingReport):
    """Field-less training report registered as ``oss_actor_critic_report``."""

    __registry_name__ = "oss_actor_critic_report"
|
||||
|
||||
|
||||
@dataclass
class OssParametricDQNTrainingReport(TrainingReport):
    """Training report registered under the name ``oss_parametric_dqn_report``.

    Every field is optional and defaults to ``None``.
    """

    __registry_name__ = "oss_parametric_dqn_report"

    # Scalar losses.
    td_loss: Optional[float] = None
    mc_loss: Optional[float] = None

    # CPE estimates for reward.
    reward_ips: Optional[CpeEstimate] = None
    reward_dm: Optional[CpeEstimate] = None
    reward_dr: Optional[CpeEstimate] = None

    # CPE estimates for value.
    value_sequential_dr: Optional[CpeEstimate] = None
    value_weighted_dr: Optional[CpeEstimate] = None
    value_magic_dr: Optional[CpeEstimate] = None
|
||||
|
||||
|
||||
@dataclass
class OssWorldModelTrainingReport(TrainingReport):
    """Training report registered under the name ``oss_world_model_report``.

    All four fields are required lists of floats.
    """

    __registry_name__ = "oss_world_model_report"

    loss: List[float]
    gmm: List[float]
    bce: List[float]
    mse: List[float]
|
||||
|
||||
|
||||
@dataclass
class DebugToolsReport(TrainingReport):
    """Training report registered under the name ``oss_debug_tools_report``.

    Both fields are optional and default to ``None``.
    """

    __registry_name__ = "oss_debug_tools_report"

    feature_importance: Optional[List[float]] = None
    feature_sensitivity: Optional[List[float]] = None
|
||||
|
||||
|
||||
@dataclass
class OssRankingModelTrainingReport(TrainingReport):
    """Field-less report registered as ``oss_ranking_model_training_report``."""

    __registry_name__ = "oss_ranking_model_training_report"
|
||||
@@ -1,64 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import itertools
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
from reagent.core import aggregators as agg
|
||||
from reagent.core.rl_training_output import RLTrainingOutput
|
||||
from reagent.core.union import TrainingReport__Union
|
||||
from reagent.reporting.oss_training_reports import OssParametricDQNTrainingReport
|
||||
from reagent.reporting.reporter_base import ReporterBase
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ParametricDQNReporter(ReporterBase):
    """Reporter that registers the aggregators used for parametric DQN."""

    def __init__(
        self,
        report_interval: int = 100,
        target_action_distribution: Optional[List[float]] = None,
        recent_window_size: int = 100,
    ):
        """Build the aggregator list and hand it to ``ReporterBase``.

        Args:
            report_interval: Forwarded as ``interval`` to every aggregator
                that accepts one.
            target_action_distribution: Stored on the instance unchanged.
            recent_window_size: Stored on the instance unchanged.
        """
        aggregators = itertools.chain(
            [
                ("cpe_results", agg.AppendAggregator("cpe_results")),
                ("td_loss", agg.MeanAggregator("td_loss", interval=report_interval)),
                (
                    "reward_loss",
                    agg.MeanAggregator("reward_loss", interval=report_interval),
                ),
                (
                    "logged_rewards",
                    agg.RecentValuesAggregator(
                        "logged_rewards", interval=report_interval
                    ),
                ),
            ],
            [
                # Each TensorBoard aggregator is registered as "<key>_tb".
                (
                    f"{key}_tb",
                    agg.TensorBoardHistogramAndMeanAggregator(
                        key, log_key, interval=report_interval
                    ),
                )
                for key, log_key in [
                    ("td_loss", "td_loss"),
                    ("reward_loss", "reward_loss"),
                    ("logged_propensities", "propensities/logged"),
                    ("logged_rewards", "reward/logged"),
                ]
            ],
        )
        super().__init__(aggregators)
        self.target_action_distribution = target_action_distribution
        self.recent_window_size = recent_window_size

    # TODO: T71636218 write this for OSS
    def publish(self) -> RLTrainingOutput:
        """Return an ``RLTrainingOutput`` wrapping an empty parametric-DQN report.

        The collected CPE results are not yet copied into the report (see the
        TODO above), so the previously unused ``cpe_results`` local was removed.
        """
        report = OssParametricDQNTrainingReport()
        return RLTrainingOutput(
            training_report=TrainingReport__Union(oss_parametric_dqn_report=report)
        )
|
||||
@@ -1,60 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import logging
|
||||
|
||||
from reagent.core import aggregators as agg
|
||||
from reagent.core.rl_training_output import RLTrainingOutput
|
||||
from reagent.core.union import TrainingReport__Union
|
||||
from reagent.reporting.oss_training_reports import OssRankingModelTrainingReport
|
||||
from reagent.reporting.reporter_base import ReporterBase
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RankingModelReporter(ReporterBase):
    """Reporter that registers the aggregators used for ranking-model training."""

    def __init__(self, report_interval: int = 100):
        """
        For Ranking model:
            'pg' (policy gradient loss)
            'baseline' (the baseline model's loss, usually for fitting V(s))
            'kendall_tau' (kendall_tau coefficient between advantage and log_probs,
                used in evaluation page handlers)
            'kendaull_tau_p_value' (the p-value for kendall_tau test, used in
                evaluation page handlers)
        """
        # Each key gets a mean aggregator under its own name and a TensorBoard
        # aggregator under "<key>_tb"; all mean aggregators are registered first.
        metric_keys = ("pg", "baseline", "kendall_tau", "kendaull_tau_p_value")

        registered = []
        for key in metric_keys:
            registered.append((key, agg.MeanAggregator(key, interval=report_interval)))
        for key in metric_keys:
            tb_aggregator = agg.TensorBoardHistogramAndMeanAggregator(
                key, key, interval=report_interval
            )
            registered.append((f"{key}_tb", tb_aggregator))

        super().__init__(registered)

    # TODO: T71636236 write this for OSS
    def publish(self) -> RLTrainingOutput:
        """Return an ``RLTrainingOutput`` wrapping an empty ranking-model report."""
        return RLTrainingOutput(
            training_report=TrainingReport__Union(
                oss_ranking_model_training_report=OssRankingModelTrainingReport()
            )
        )
|
||||
@@ -1,59 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import logging
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
from reagent.core import aggregators as agg
|
||||
from reagent.core.rl_training_output import RLTrainingOutput
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ReporterBase:
    """Base class that fans reported values out to named aggregators.

    Aggregators are stored in registration order in ``self.aggregators`` and are
    also reachable as attributes, e.g. ``reporter.td_loss``.
    """

    def __init__(self, aggregators: "List[Tuple[str, agg.Aggregator]]"):
        """Store the ``(name, aggregator)`` pairs in an ``OrderedDict``.

        Args:
            aggregators: Any iterable of ``(name, aggregator)`` pairs.
        """
        self.aggregators = OrderedDict(aggregators)

    def _find_by_key(self, key: str):
        """Return the first aggregator whose ``key`` equals *key*, else ``None``."""
        for aggregator in self.aggregators.values():
            if aggregator.key == key:
                return aggregator
        return None

    @staticmethod
    def _maybe_average(data, average: bool):
        """Return ``None`` for empty data, its float mean if *average*, else data."""
        if len(data) == 0:
            return None
        if average:
            return float(torch.mean(torch.tensor(data)))
        return data

    def report(self, **kwargs: Any):
        """Pass each keyword value to every aggregator whose ``key`` matches it."""
        for name, value in kwargs.items():
            for aggregator in self.aggregators.values():
                if aggregator.key == name:
                    aggregator.update(name, value)

    def finish_epoch(self):
        """Call ``finish_epoch()`` on every registered aggregator."""
        for aggregator in self.aggregators.values():
            aggregator.finish_epoch()

    def publish(self) -> "RLTrainingOutput":
        """Produce the training output; the base implementation returns ``None``."""
        pass

    def get_recent(self, key: str, count: int, average: bool):
        """Return the most recent *count* values recorded for *key*.

        Returns ``None`` when no aggregator matches or nothing was recorded; a
        float mean when *average* is true; otherwise the raw values.
        """
        aggregator = self._find_by_key(key)
        if aggregator is None:
            return None
        # NOTE(review): relies on the matched aggregator exposing a nested
        # `.aggregator` with `get_recent` -- confirm against the aggregator types.
        return self._maybe_average(aggregator.aggregator.get_recent(count), average)

    def get_all(self, key: str, average: bool):
        """Return every value recorded for *key*.

        Returns ``None`` when no aggregator matches or nothing was recorded; a
        float mean when *average* is true; otherwise the raw values.
        """
        aggregator = self._find_by_key(key)
        if aggregator is None:
            return None
        # NOTE(review): relies on the matched aggregator exposing a nested
        # `.aggregator` with `get_all` -- confirm against the aggregator types.
        return self._maybe_average(aggregator.aggregator.get_all(), average)

    def __getattr__(self, key: str):
        """Expose registered aggregators as attributes.

        Raises:
            AttributeError: If *key* is not a registered aggregator name.
        """
        # __getattr__ only runs after normal lookup fails. Read the mapping via
        # __dict__ so an instance without `aggregators` (e.g. one created by
        # copy/pickle before __init__ ran) cannot recurse forever.
        aggregators = self.__dict__.get("aggregators")
        if aggregators is not None and key in aggregators:
            return aggregators[key]
        # AttributeError (not KeyError) keeps hasattr()/getattr(default) working.
        raise AttributeError(
            f"{type(self).__name__!r} object has no attribute {key!r}"
        )

    def end_epoch(self):
        """Call ``end_epoch()`` on every registered aggregator."""
        for aggregator in self.aggregators.values():
            aggregator.end_epoch()
|
||||
@@ -1,363 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
import logging
|
||||
import math
|
||||
from collections import deque
|
||||
from typing import Deque, List, NamedTuple, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from reagent.tensorboardX import SummaryWriterContext
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
LOSS_REPORT_INTERVAL = 100
|
||||
|
||||
|
||||
class BatchStats(NamedTuple):
    """Optional per-batch tensors plus helpers that log them to TensorBoard.

    Every field defaults to ``None``; logging skips fields that are ``None``.
    """

    td_loss: Optional[torch.Tensor] = None
    reward_loss: Optional[torch.Tensor] = None
    imitator_loss: Optional[torch.Tensor] = None
    logged_actions: Optional[torch.Tensor] = None
    logged_propensities: Optional[torch.Tensor] = None
    logged_rewards: Optional[torch.Tensor] = None
    logged_values: Optional[torch.Tensor] = None
    model_propensities: Optional[torch.Tensor] = None
    model_rewards: Optional[torch.Tensor] = None
    model_values: Optional[torch.Tensor] = None
    model_values_on_logged_actions: Optional[torch.Tensor] = None
    model_action_idxs: Optional[torch.Tensor] = None

    def write_summary(self, actions: List[str]):
        """Write every non-``None`` field to ``SummaryWriterContext``.

        Args:
            actions: Action names; when empty, per-action outputs are skipped
                and model tensors must be 1-D or a single column.

        Raises:
            AssertionError: If a scalar-style field is neither 1-D nor ``(N, 1)``.
            ValueError: If a model tensor's shape matches neither layout.
        """

        def _is_single_column(shape) -> bool:
            # True for shape (N,) or (N, 1).
            return len(shape) == 1 or (len(shape) == 2 and shape[1] == 1)

        # 1) Per-action occurrence counts of action indices.
        if actions:
            for field, log_key in (
                ("logged_actions", "actions/logged"),
                ("model_action_idxs", "actions/model"),
            ):
                val = getattr(self, field)
                if val is not None:
                    for i, action in enumerate(actions):
                        # pyre-fixme[16]: `SummaryWriterContext` has no attribute
                        #  `add_scalar`.
                        SummaryWriterContext.add_scalar(
                            "{}/{}".format(log_key, action), (val == i).sum().item()
                        )

        # 2) Fields logged as a single histogram + mean.
        for field, log_key in (
            ("td_loss", "td_loss"),
            ("imitator_loss", "imitator_loss"),
            ("reward_loss", "reward_loss"),
            ("logged_propensities", "propensities/logged"),
            ("logged_rewards", "reward/logged"),
            ("logged_values", "value/logged"),
            ("model_values_on_logged_actions", "value/model_logged_action"),
        ):
            val = getattr(self, field)
            if val is None:
                continue
            assert _is_single_column(val.shape), "Unexpected shape for {}: {}".format(
                field, val.shape
            )
            self._log_histogram_and_mean(log_key, val)

        # 3) Model outputs: one histogram overall, or one per action column.
        for field, log_key in (
            ("model_propensities", "propensities/model"),
            ("model_rewards", "reward/model"),
            ("model_values", "value/model"),
        ):
            val = getattr(self, field)
            if val is None:
                continue
            if _is_single_column(val.shape) and not actions:
                self._log_histogram_and_mean(log_key, val)
            elif len(val.shape) == 2 and val.shape[1] == len(actions):
                for i, action in enumerate(actions):
                    self._log_histogram_and_mean(f"{log_key}/{action}", val[:, i])
            else:
                raise ValueError(
                    "Unexpected shape for {}: {}; actions: {}".format(
                        field, val.shape, actions
                    )
                )

    def _log_histogram_and_mean(self, log_key, val):
        """Log *val* as a histogram under *log_key* and its mean under ``<log_key>/mean``.

        A ``ValueError`` from the writer is logged as a warning and re-raised.
        """
        try:
            SummaryWriterContext.add_histogram(log_key, val)
            SummaryWriterContext.add_scalar(f"{log_key}/mean", val.mean())
        except ValueError:
            logger.warning(
                f"Cannot create histogram for key: {log_key}; "
                "this is likely because you have NULL value in your input; "
                f"value: {val}"
            )
            raise

    @staticmethod
    def add_custom_scalars(action_names: Optional[List[str]]):
        """Register four per-action multi-line charts; no-op without action names."""
        if not action_names:
            return

        # (tag template, chart category, chart title), in registration order.
        chart_specs = (
            ("propensities/model/{}/mean", "propensities", "model"),
            ("propensities/logged/{}/mean", "propensities", "logged"),
            ("actions/logged/{}", "actions", "logged"),
            ("actions/model/{}", "actions", "model"),
        )
        for tag_template, category, title in chart_specs:
            SummaryWriterContext.add_custom_scalars_multilinechart(
                [tag_template.format(action_name) for action_name in action_names],
                category=category,
                title=title,
            )
|
||||
|
||||
|
||||
def merge_tensor_namedtuple_list(l, cls):
    """Concatenate a sequence of namedtuples field by field along dim 0.

    For each field of *cls*, the values must be either all ``None`` (the merged
    field is ``None``) or all tensors (the merged field is their ``torch.cat``).

    Raises:
        AssertionError: If a field is ``None`` in some items but not in others.
    """
    merged = {}
    for field_name in cls._fields:
        values = [getattr(item, field_name) for item in l]
        present = [value for value in values if value is not None]
        # A field must be populated in every item or in none of them.
        assert len(present) == 0 or len(present) == len(values)
        merged[field_name] = torch.cat(present, dim=0) if present else None
    return cls(**merged)
|
||||
|
||||
|
||||
class StatsByAction(object):
    """Per-action lists of values, appended to one snapshot at a time."""

    def __init__(self, actions):
        """Create one empty list per action name."""
        self.stats = {}
        for action in actions:
            self.stats[action] = []

    def append(self, stats):
        """Append one value per known action, using 0 for actions absent in *stats*.

        Tensor values are converted with ``.item()``.

        Raises:
            AssertionError: If *stats* contains an unknown action.
        """
        assert all(name in self.stats for name in stats)
        for name, history in self.stats.items():
            value = stats.get(name, 0)
            history.append(value.item() if isinstance(value, torch.Tensor) else value)

    def items(self):
        """Return the ``(action, values)`` view of the underlying dict."""
        return self.stats.items()

    def __len__(self):
        """Return the number of tracked actions."""
        return len(self.stats)
|
||||
|
||||
|
||||
class NoOpTrainingReporter:
    """Reporter stand-in whose ``report`` and ``flush`` do nothing."""

    def report(self, **kwargs):
        """Accept any keyword arguments and discard them."""
        return None

    def flush(self):
        """Do nothing."""
        return None
|
||||
|
||||
|
||||
class TrainingReporter(object):
    """Buffers per-batch statistics and summarizes them every few batches.

    ``report`` queues one ``BatchStats`` per call; once ``loss_report_interval``
    entries are queued, ``flush`` merges them, writes the TensorBoard summary,
    logs a text summary and updates the running histories kept on the instance.
    """

    RECENT_WINDOW_SIZE = 100

    def __init__(self, action_names: Optional[List[str]] = None):
        """
        Args:
            action_names: Action names; ``None`` means no per-action tracking.
                An empty list is rejected by an assertion.
        """
        assert action_names is None or len(action_names) > 0
        self.action_names: List[str] = action_names or []
        self.loss_report_interval = LOSS_REPORT_INTERVAL
        BatchStats.add_custom_scalars(action_names)
        self.clear()

    def clear(self):
        """Reset every history, counter and the queue of pending batch stats."""
        self.running_reward: Deque[float] = deque(maxlen=int(1e6))

        self.td_loss: List[float] = []
        self.reward_loss: List[float] = []
        self.imitator_loss: List[float] = []
        self.logged_action_q_value: List[float] = []
        self.logged_action_counts = {action: 0 for action in self.action_names}
        self.model_values = StatsByAction(self.action_names)
        self.model_value_stds = StatsByAction(self.action_names)
        self.model_action_counts = StatsByAction(self.action_names)
        self.model_action_counts_cumulative = {
            action: 0 for action in self.action_names
        }
        self.model_action_distr = StatsByAction(self.action_names)

        self.incoming_stats: List[BatchStats] = []

    @property
    def num_batches(self):
        """Number of flushes that recorded a TD loss."""
        return len(self.td_loss)

    def report(self, **kwargs):
        """Queue one batch of stats; flush when the queue reaches the interval.

        Keyword names must be ``BatchStats`` fields. Non-``None`` values are
        converted to detached CPU tensors with at least one dimension.
        """

        def _to_tensor(v):
            if v is None:
                return None
            if not isinstance(v, torch.Tensor):
                v = torch.tensor(v)
            if len(v.shape) == 0:
                # Scalars become shape (1,) so batches can be concatenated.
                v = v.reshape(1)
            return v.detach().cpu()

        kwargs = {k: _to_tensor(v) for k, v in kwargs.items()}
        batch_stats = BatchStats(**kwargs)
        self.incoming_stats.append(batch_stats)
        if len(self.incoming_stats) >= self.loss_report_interval:
            self.flush()

    @staticmethod
    def _normalize_counts(counts):
        """Turn a ``{key: count}`` dict into fractions; all zeros if the sum is 0."""
        total = float(sum(counts.values()))
        if total == 0.0:
            # Nothing counted yet: avoid ZeroDivisionError.
            return {key: 0.0 for key in counts}
        return {key: count / total for key, count in counts.items()}

    @torch.no_grad()
    def flush(self):
        """Merge queued batch stats, log them, update histories, empty the queue."""
        if not self.incoming_stats:
            logger.info("Nothing to report")
            return

        logger.info("Loss on {} batches".format(len(self.incoming_stats)))

        batch_stats = merge_tensor_namedtuple_list(self.incoming_stats, BatchStats)
        batch_stats.write_summary(self.action_names)

        # Collected one entry per log line and emitted at the end.
        log_lines = ["Loss:"]

        # td_loss is optional like every other BatchStats field; guard it the
        # same way instead of crashing on None.
        if batch_stats.td_loss is not None:
            td_loss_mean = float(batch_stats.td_loss.mean())
            self.td_loss.append(td_loss_mean)
            log_lines.append("TD LOSS: {0:.3f}".format(td_loss_mean))

        if batch_stats.logged_rewards is not None:
            flattened_rewards = torch.flatten(batch_stats.logged_rewards).tolist()
            self.running_reward.extend(flattened_rewards)

        if batch_stats.reward_loss is not None:
            reward_loss_mean = float(batch_stats.reward_loss.mean())
            self.reward_loss.append(reward_loss_mean)
            log_lines.append("REWARD LOSS: {0:.3f}".format(reward_loss_mean))

        if batch_stats.imitator_loss is not None:
            imitator_loss_mean = float(batch_stats.imitator_loss.mean())
            self.imitator_loss.append(imitator_loss_mean)
            log_lines.append("IMITATOR LOSS: {0:.3f}".format(imitator_loss_mean))

        if batch_stats.model_values is not None and self.action_names:
            self.model_values.append(
                dict(zip(self.action_names, batch_stats.model_values.mean(dim=0)))
            )
            self.model_value_stds.append(
                dict(zip(self.action_names, batch_stats.model_values.std(dim=0)))
            )

        if batch_stats.model_values_on_logged_actions is not None:
            self.logged_action_q_value.append(
                batch_stats.model_values_on_logged_actions.mean().item()
            )

        if (
            batch_stats.logged_actions is not None
            and batch_stats.model_action_idxs is not None
        ):
            logged_action_counts = {
                action: (batch_stats.logged_actions == i).sum().item()
                for i, action in enumerate(self.action_names)
            }
            model_action_counts = {
                action: (batch_stats.model_action_idxs == i).sum().item()
                for i, action in enumerate(self.action_names)
            }
            log_lines.append(
                "The distribution of logged actions : {}".format(logged_action_counts)
            )
            log_lines.append(
                "The distribution of model actions : {}".format(model_action_counts)
            )
            for action, count in logged_action_counts.items():
                self.logged_action_counts[action] += count

            self.model_action_counts.append(model_action_counts)

            for action, count in model_action_counts.items():
                self.model_action_counts_cumulative[action] += count

            self.model_action_distr.append(self._normalize_counts(model_action_counts))

        log_lines.append("Batch Evaluator Finished")
        for log_line in log_lines:
            logger.info(log_line)

        self.incoming_stats.clear()

    def get_td_loss_after_n(self, n):
        """Return the recorded TD-loss means from index *n* onward."""
        return self.td_loss[n:]

    def get_recent_td_loss(self):
        """Mean of the last ``RECENT_WINDOW_SIZE`` TD-loss entries (NaN if none)."""
        return TrainingReporter.calculate_recent_window_average(
            self.td_loss, TrainingReporter.RECENT_WINDOW_SIZE, num_entries=1
        )

    def get_recent_reward_loss(self):
        """Mean of the last ``RECENT_WINDOW_SIZE`` reward-loss entries (NaN if none)."""
        return TrainingReporter.calculate_recent_window_average(
            self.reward_loss, TrainingReporter.RECENT_WINDOW_SIZE, num_entries=1
        )

    def get_recent_imitator_loss(self):
        """Mean of the last ``RECENT_WINDOW_SIZE`` imitator-loss entries (NaN if none)."""
        return TrainingReporter.calculate_recent_window_average(
            self.imitator_loss, TrainingReporter.RECENT_WINDOW_SIZE, num_entries=1
        )

    def get_logged_action_distribution(self):
        """Fraction of logged actions per action name (all 0.0 before any count)."""
        return self._normalize_counts(self.logged_action_counts)

    def get_model_action_distribution(self):
        """Fraction of model actions per action name (all 0.0 before any count)."""
        return self._normalize_counts(self.model_action_counts_cumulative)

    def get_recent_rewards(self):
        """Return the deque of flattened logged rewards."""
        return self.running_reward

    def log_to_tensorboard(self, epoch: int) -> None:
        """Write the recent TD, reward and imitator losses as scalars for *epoch*."""

        def none_to_zero(x: Optional[float]) -> float:
            if x is None or math.isnan(x):
                return 0.0
            return x

        for name, value in [
            ("Training/td_loss", self.get_recent_td_loss()),
            ("Training/reward_loss", self.get_recent_reward_loss()),
            ("Training/imitator_loss", self.get_recent_imitator_loss()),
        ]:
            # pyre-fixme[16]: `SummaryWriterContext` has no attribute `add_scalar`.
            SummaryWriterContext.add_scalar(name, none_to_zero(value), epoch)

    @staticmethod
    def calculate_recent_window_average(arr, window_size, num_entries):
        """Average the last *window_size* rows of *arr* along axis 0.

        For an empty *arr*, logs an error and returns NaN (``num_entries == 1``)
        or a list of *num_entries* NaNs.
        """
        if len(arr) > 0:
            begin = max(0, len(arr) - window_size)
            return np.mean(np.array(arr[begin:]), axis=0)
        else:
            logger.error("Not enough samples for evaluation.")
            if num_entries == 1:
                return float("nan")
            else:
                return [float("nan")] * num_entries
|
||||
@@ -1,9 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from reagent.core.registry_meta import RegistryMeta
|
||||
|
||||
|
||||
class TrainingReport(metaclass=RegistryMeta):
    """Empty base class for training reports; uses ``RegistryMeta`` as metaclass."""
|
||||
@@ -1,95 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import itertools
|
||||
import logging
|
||||
from typing import List, Tuple
|
||||
|
||||
from reagent.core import aggregators as agg
|
||||
from reagent.core.rl_training_output import RLTrainingOutput
|
||||
from reagent.core.union import TrainingReport__Union
|
||||
from reagent.reporting.oss_training_reports import (
|
||||
DebugToolsReport,
|
||||
OssWorldModelTrainingReport,
|
||||
)
|
||||
from reagent.reporting.reporter_base import ReporterBase
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorldModelReporter(ReporterBase):
    """Reporter that registers the aggregators used for world-model training."""

    def __init__(self, report_interval: int = 10):
        """
        For world model:
            'loss' (referring to total loss),
            'bce' (loss for predicting not_terminal),
            'gmm' (loss for next state prediction),
            'mse' (loss for predicting reward)
        """
        # Each key gets a mean aggregator under its own name and a TensorBoard
        # aggregator under "<key>_tb"; all mean aggregators are registered first.
        loss_keys = ("loss", "bce", "gmm", "mse")

        aggregators: List[Tuple[str, agg.Aggregator]] = []
        for key in loss_keys:
            aggregators.append(
                (key, agg.MeanAggregator(key, interval=report_interval))
            )
        for key in loss_keys:
            tb_aggregator = agg.TensorBoardHistogramAndMeanAggregator(
                key, key, interval=report_interval
            )
            aggregators.append((f"{key}_tb", tb_aggregator))

        super().__init__(aggregators)

    def publish(self) -> RLTrainingOutput:
        """Return an ``RLTrainingOutput`` carrying the four aggregators' ``values``."""
        collected = {
            "loss": self.loss.values,
            "bce": self.bce.values,
            "gmm": self.gmm.values,
            "mse": self.mse.values,
        }
        report = OssWorldModelTrainingReport(**collected)
        return RLTrainingOutput(
            training_report=TrainingReport__Union(oss_world_model_report=report)
        )
|
||||
|
||||
|
||||
class DebugToolsReporter(ReporterBase):
    """Reporter that collects feature importance and sensitivity values."""

    def __init__(self, report_interval: int = 1):
        """
        For debug tools: feature_importance, feature_sensitivity
        """
        # `report_interval` is accepted but not used by this reporter.
        aggregators: List[Tuple[str, agg.Aggregator]] = [
            (name, agg.AppendAggregator(name))
            for name in ("feature_importance", "feature_sensitivity")
        ]
        super().__init__(aggregators)

    def publish(self) -> RLTrainingOutput:
        """Return an ``RLTrainingOutput`` with the latest value of each aggregator.

        An aggregator with no recorded values contributes an empty list.
        """
        if len(self.feature_importance.values) == 0:
            feature_importance = []
        else:
            feature_importance = self.feature_importance.values[-1]

        if len(self.feature_sensitivity.values) == 0:
            feature_sensitivity = []
        else:
            feature_sensitivity = self.feature_sensitivity.values[-1]

        return RLTrainingOutput(
            training_report=TrainingReport__Union(
                oss_debug_tools_report=DebugToolsReport(
                    feature_importance=feature_importance,
                    feature_sensitivity=feature_sensitivity,
                )
            )
        )
|
||||
@@ -1,402 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import dataclasses
|
||||
import logging
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from typing import Dict, NamedTuple, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from reagent.core.rl_training_output import RLTrainingOutput
|
||||
from reagent.core.types import (
|
||||
Dataset,
|
||||
ReaderOptions,
|
||||
RecurringPeriod,
|
||||
ResourceOptions,
|
||||
RewardOptions,
|
||||
TableSpec,
|
||||
)
|
||||
from reagent.data_fetchers.data_fetcher import DataFetcher
|
||||
from reagent.evaluation.evaluator import Evaluator
|
||||
from reagent.model_managers.model_manager import ModelManager
|
||||
from reagent.parameters import NormalizationData
|
||||
from reagent.preprocessing.batch_preprocessor import BatchPreprocessor
|
||||
from reagent.publishers.model_publisher import ModelPublisher
|
||||
from reagent.tensorboardX import SummaryWriterContext, summary_writer_context
|
||||
from reagent.training.trainer import Trainer
|
||||
from reagent.validators.model_validator import ModelValidator
|
||||
from reagent.workflow_utils.iterators import DataLoaderWrapper
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TrainEvalSampleRanges(NamedTuple):
    """Pair of ``(start, end)`` sample ranges for the train and eval datasets."""

    # Range used when querying the training dataset.
    train_sample_range: Tuple[float, float]
    # Range used when querying the evaluation dataset.
    eval_sample_range: Tuple[float, float]
|
||||
|
||||
|
||||
class BatchRunner:
|
||||
def __init__(
    self,
    use_gpu: bool,
    model_manager: ModelManager,
    data_fetcher: DataFetcher,
    reward_options: RewardOptions,
    normalization_data_map: Dict[str, NormalizationData],
    warmstart_path: Optional[str] = None,
):
    """Store the collaborators and options; no other work is done here.

    Args:
        use_gpu: Forwarded to ``model_manager.build_trainer``.
        model_manager: Builds the trainer and issues data queries.
        data_fetcher: Passed to the model manager's query methods.
        reward_options: Passed to trainer construction and data queries.
        normalization_data_map: Normalization data keyed by normalization key.
        warmstart_path: Optional path of a saved trainer state to load.
    """
    # Collaborators.
    self.model_manager = model_manager
    self.data_fetcher = data_fetcher
    # Configuration.
    self.use_gpu = use_gpu
    self.reward_options = reward_options
    self.normalization_data_map = normalization_data_map
    self.warmstart_path = warmstart_path
|
||||
|
||||
def get_workflow_id(self) -> int:
    """Return the workflow id; not implemented in this class."""
    raise NotImplementedError()
|
||||
|
||||
def initialize_trainer(self) -> Trainer:
    """Validate normalization data, build the trainer and optionally warm-start it.

    The trainer is stored on ``self.trainer`` and also returned.

    Raises:
        AssertionError: If a required normalization key is missing from
            ``normalization_data_map`` or lacks dense normalization parameters.
    """
    # validate that we have all the required keys
    for normalization_key in self.model_manager.required_normalization_keys:
        normalization_data = self.normalization_data_map.get(normalization_key)
        assert (
            normalization_data is not None
        ), f"NormalizationData for {normalization_key} is required but not provided."
        # NOTE: Don't need this check in the future, for non-dense parameters
        assert (
            normalization_data.dense_normalization_parameters is not None
        ), f"Dense normalization parameters for {normalization_key} is not provided."

    trainer = self.model_manager.build_trainer(
        self.use_gpu, self.normalization_data_map, self.reward_options
    )

    if self.warmstart_path is not None:
        # NOTE(review): torch.load unpickles the file; only point
        # warmstart_path at trusted checkpoints.
        trainer.load_state_dict(torch.load(self.warmstart_path))

    self.trainer = trainer
    return trainer
|
||||
|
||||
def save_trainer(self, trainer: Trainer, output_path: str) -> None:
    """
    Save the trainer for warmstarting/checkpointing.

    Writes ``trainer.state_dict()`` to *output_path* with ``torch.save``.
    """
    torch.save(trainer.state_dict(), output_path)
|
||||
|
||||
@staticmethod
def get_sample_range(
    input_table_spec: TableSpec, calc_cpe_in_training: bool
) -> TrainEvalSampleRanges:
    """Derive the train and eval sample ranges from a table spec.

    Without CPE, training uses ``(0, table_sample)`` (or ``(0, 100)`` when
    ``table_sample`` is ``None``) and the eval range is the empty ``(0, 0)``.
    With CPE, training uses ``(0, table_sample)`` and eval uses the last
    ``eval_table_sample`` of the 0-100 range.

    Raises:
        AssertionError: With CPE enabled, if either sample value is ``None`` or
            their sum exceeds 100 (with a 1e-3 tolerance).
    """
    table_sample = input_table_spec.table_sample
    eval_table_sample = input_table_spec.eval_table_sample

    if calc_cpe_in_training:
        error_msg = (
            "calc_cpe_in_training is set to True. "
            f"Please specify table_sample(current={table_sample}) and "
            f"eval_table_sample(current={eval_table_sample}) such that "
            "eval_table_sample + table_sample <= 100. "
            "In order to reliably calculate CPE, eval_table_sample "
            "should not be too small."
        )
        assert table_sample is not None, error_msg
        assert eval_table_sample is not None, error_msg
        assert (eval_table_sample + table_sample) <= (100.0 + 1e-3), error_msg

        train_range = (0.0, table_sample)
        eval_range = (100.0 - eval_table_sample, 100.0)
        return TrainEvalSampleRanges(
            train_sample_range=train_range, eval_sample_range=eval_range
        )

    # use all data if table sample = None
    upper_bound = 100.0 if table_sample is None else table_sample
    return TrainEvalSampleRanges(
        train_sample_range=(0.0, upper_bound),
        # eval samples will not be used
        eval_sample_range=(0.0, 0.0),
    )
|
||||
|
||||
def query(
    self,
    input_table_spec: TableSpec,
    reader_options: ReaderOptions,
    resource_options: ResourceOptions,
) -> Tuple[Dataset, Optional[Dataset]]:
    """Query the training dataset and, when CPE is enabled, the eval dataset.

    Args:
        input_table_spec: Table to query; also supplies the sample sizes
            consumed by ``get_sample_range``.
        reader_options: Not read by this method.
        resource_options: Not read by this method.

    Returns:
        ``(train_dataset, eval_dataset)``. ``eval_dataset`` is ``None`` when
        the model manager's ``should_generate_eval_dataset`` is falsy.
    """
    logger.info("Starting query")

    calc_cpe_in_training = self.model_manager.should_generate_eval_dataset
    sample_range_output = BatchRunner.get_sample_range(
        input_table_spec, calc_cpe_in_training
    )
    train_dataset = self.model_manager.query_data(
        data_fetcher=self.data_fetcher,
        input_table_spec=input_table_spec,
        sample_range=sample_range_output.train_sample_range,
        reward_options=self.reward_options,
    )
    # The eval dataset is only materialized when CPE is computed in training.
    eval_dataset: Optional[Dataset] = None
    if calc_cpe_in_training:
        eval_dataset = self.model_manager.query_data(
            data_fetcher=self.data_fetcher,
            input_table_spec=input_table_spec,
            sample_range=sample_range_output.eval_sample_range,
            reward_options=self.reward_options,
        )

    return (train_dataset, eval_dataset)
|
||||
|
||||
def run_feature_identification(
    self, input_table_spec: TableSpec
) -> Dict[str, NormalizationData]:
    """Delegate feature identification to the model manager.

    Args:
        input_table_spec: Table handed to the model manager, together with
            this runner's data fetcher.

    Returns:
        The normalization data map produced by the model manager.
    """
    normalization_data_map = self.model_manager.run_feature_identification(
        self.data_fetcher, input_table_spec
    )
    return normalization_data_map
|
||||
|
||||
def train(
    self,
    train_dataset: Dataset,
    eval_dataset: Optional[Dataset],
    normalization_data_map: Dict[str, NormalizationData],
    num_epochs: int,
    reader_options: ReaderOptions,
    resource_options: Optional[ResourceOptions] = None,
    warmstart_path: Optional[str] = None,
    validator: Optional[ModelValidator] = None,
    parent_workflow_id: Optional[int] = None,
    recurring_period: Optional[RecurringPeriod] = None,
) -> RLTrainingOutput:
    """Run the training workflow and, optionally, validate its output.

    Args:
        train_dataset: Dataset to train on.
        eval_dataset: Dataset for evaluation; may be ``None`` (it is
            forwarded to ``train_workflow``, which accepts ``None``).
        normalization_data_map: Not read by this method.
        num_epochs: Number of training epochs.
        reader_options: Forwarded to ``train_workflow``.
        resource_options: Defaults to ``ResourceOptions()`` when falsy.
        warmstart_path: Not read by this method.
        validator: When given, its result is attached via ``run_validator``.
        parent_workflow_id: Defaults to this runner's own workflow id.
        recurring_period: Not read by this method.

    Returns:
        The training output, with ``validation_result`` filled in when a
        validator was supplied.
    """
    logger.info(f"{reader_options}")
    child_workflow_id = self.get_workflow_id()
    if parent_workflow_id is None:
        # A run without a parent is treated as its own parent.
        parent_workflow_id = child_workflow_id

    resource_options = resource_options or ResourceOptions()

    logger.info("Starting training")
    results = self.train_workflow(
        train_dataset,
        eval_dataset,
        num_epochs,
        parent_workflow_id=parent_workflow_id,
        child_workflow_id=child_workflow_id,
        reader_options=reader_options,
        resource_options=resource_options,
    )

    if validator is not None:
        results = self.run_validator(validator, results)

    return results
|
||||
|
||||
def run_validator(
    self, model_validator: ModelValidator, training_output: RLTrainingOutput
) -> RLTrainingOutput:
    """Validate a training output and return a copy carrying the result.

    Args:
        model_validator: Validator whose ``validate`` is called on the output.
        training_output: Output that must not already hold a
            ``validation_result``.

    Returns:
        A copy of ``training_output`` with ``validation_result`` set.

    Raises:
        AssertionError: If ``training_output.validation_result`` is already
            set.
    """
    # The message must read the same field the condition checks; referencing
    # a different attribute name would fail while building the message.
    assert (
        training_output.validation_result is None
    ), f"validation_result was set to {training_output.validation_result}"
    validation_result = model_validator.validate(training_output)
    return dataclasses.replace(training_output, validation_result=validation_result)
|
||||
|
||||
def run_publisher(
    self,
    model_publisher: ModelPublisher,
    training_output: RLTrainingOutput,
    recurring_workflow_id: int,
    child_workflow_id: int,
    recurring_period: Optional[RecurringPeriod],
) -> RLTrainingOutput:
    """Publish a trained model and return a copy carrying the result.

    Args:
        model_publisher: Publisher whose ``publish`` is invoked.
        training_output: Output that must not already hold a
            ``publishing_result``.
        recurring_workflow_id: Forwarded to the publisher.
        child_workflow_id: Forwarded to the publisher.
        recurring_period: Forwarded to the publisher.

    Returns:
        A copy of ``training_output`` with ``publishing_result`` set.

    Raises:
        AssertionError: If ``training_output.publishing_result`` is already
            set.
    """
    # The message must read the same field the condition checks; referencing
    # a different attribute name would fail while building the message.
    assert (
        training_output.publishing_result is None
    ), f"publishing_result was set to {training_output.publishing_result}"
    publishing_result = model_publisher.publish(
        self.model_manager,
        training_output,
        recurring_workflow_id,
        child_workflow_id,
        recurring_period,
    )
    return dataclasses.replace(training_output, publishing_result=publishing_result)
|
||||
|
||||
def train_workflow(
    self,
    train_dataset: Dataset,
    eval_dataset: Optional[Dataset],
    num_epochs: int,
    parent_workflow_id: int,
    child_workflow_id: int,
    reader_options: ReaderOptions,
    resource_options: Optional[ResourceOptions] = None,
) -> RLTrainingOutput:
    """Train under a TensorBoard writer, then export a TorchScript module.

    Args:
        train_dataset: Dataset to train on.
        eval_dataset: Optional dataset for evaluation.
        num_epochs: Number of training epochs.
        parent_workflow_id: Not read by this method.
        child_workflow_id: Not read by this method.
        reader_options: Forwarded to ``_train``.
        resource_options: Not read by this method.

    Returns:
        The training output with ``local_output_path`` pointing at the saved
        ``model_<unix-seconds>.torchscript`` file.
    """
    tensorboard_writer = SummaryWriter()
    logger.info(
        "TensorBoard logging location is: {}".format(tensorboard_writer.log_dir)
    )

    trainer = self.initialize_trainer()

    with summary_writer_context(tensorboard_writer):
        train_output: RLTrainingOutput = self._train(
            train_dataset, eval_dataset, num_epochs, reader_options, trainer
        )

    # The file name is stamped with the current time rounded to whole seconds.
    torchscript_output_path = f"model_{round(time.time())}.torchscript"
    torch.jit.save(
        self.model_manager.build_serving_module(
            self.normalization_data_map, trainer
        ),
        torchscript_output_path,
    )
    logger.info(f"Saved torchscript model to {torchscript_output_path}")
    return dataclasses.replace(
        train_output, local_output_path=torchscript_output_path
    )
|
||||
|
||||
def _train(
    self,
    train_dataset: Dataset,
    eval_dataset: Optional[Dataset],
    num_epochs: int,
    reader_options: ReaderOptions,
    trainer: Trainer,
) -> RLTrainingOutput:
    """Wire reporter, evaluator and batch preprocessor, then run training.

    Args:
        train_dataset: Dataset to train on.
        eval_dataset: Optional dataset for evaluation.
        num_epochs: Number of training epochs.
        reader_options: Used to build the batch preprocessor and forwarded to
            the training loop.
        trainer: Trainer to run; its ``reporter`` attribute is overwritten.

    Returns:
        Whatever ``train_and_evaluate_generic`` returns.
    """
    manager = self.model_manager

    # Trainer and evaluator share one reporter instance.
    shared_reporter = manager.get_reporter()
    trainer.reporter = shared_reporter

    evaluator = manager.get_evaluator(trainer, self.reward_options)
    if evaluator is not None:
        evaluator.reporter = shared_reporter

    batch_preprocessor = manager.build_batch_preprocessor(
        reader_options,
        self.use_gpu,
        trainer.minibatch_size,
        self.normalization_data_map,
        self.reward_options,
    )
    return self.train_and_evaluate_generic(
        train_dataset,
        eval_dataset,
        trainer,
        num_epochs,
        self.use_gpu,
        batch_preprocessor,
        evaluator,
        reader_options,
    )
|
||||
|
||||
def run_on_dataset_batches(
    self,
    run_on_batch_fn,
    dataset: Dataset,
    minibatch_size: int,
    batch_preprocessor: BatchPreprocessor,
    use_gpu: bool,
    reader_options: ReaderOptions,
    dataset_size: Optional[int] = None,
) -> torch.utils.data.DataLoader:
    """Call ``run_on_batch_fn`` on every batch of ``dataset``.

    Args:
        run_on_batch_fn: A function f that expects batches; called once per
            batch yielded by the wrapped data loader.
        dataset: Dataset to iterate over.
        minibatch_size: Batch size handed to the data fetcher.
        batch_preprocessor: Preprocessor handed to the data fetcher.
        use_gpu: Forwarded to the data fetcher.
        reader_options: Forwarded to the data fetcher.
        dataset_size: Row count of ``dataset``; looked up via the data
            fetcher when omitted.

    Returns:
        The data loader that was iterated. If it exposes ``destroy_session``,
        that has already been called by the time this returns.

    Raises:
        AssertionError: If the dataset size is None or not positive.
    """
    logger.info(f"{reader_options}")
    if dataset_size is None:
        dataset_size = self.data_fetcher.get_table_row_count(dataset)
    assert dataset_size is not None
    assert dataset_size > 0, f"{dataset_size} is expected to be positive"

    @contextmanager
    def cleanup_dataloader_session(data_loader):
        # Guarantees the loader's session is torn down even if a batch fails.
        try:
            yield data_loader
        finally:
            logger.info("Closing data loader")
            if hasattr(data_loader, "destroy_session"):
                logger.info("Closing DistributedDataLoader")
                data_loader.destroy_session()

    _dataloader = self.data_fetcher.get_dataloader(
        dataset=dataset,
        batch_size=minibatch_size,
        batch_preprocessor=batch_preprocessor,
        use_gpu=use_gpu,
        reader_options=reader_options,
    )
    with cleanup_dataloader_session(_dataloader) as dataloader:
        post_dataloader_preprocessor = (
            self.data_fetcher.get_post_dataloader_preprocessor(
                reader_options=reader_options, use_gpu=use_gpu
            )
        )
        dataloader_wrapper = DataLoaderWrapper(
            dataloader=dataloader,
            dataloader_size=dataset_size,
            post_dataloader_preprocessor=post_dataloader_preprocessor,
        )
        for batch in dataloader_wrapper:
            run_on_batch_fn(batch)
    return dataloader
|
||||
|
||||
def train_and_evaluate_generic(
    self,
    train_dataset: Dataset,
    eval_dataset: Optional[Dataset],
    trainer: Trainer,
    num_epochs: int,
    use_gpu: bool,
    batch_preprocessor: BatchPreprocessor,
    evaluator: Optional[Evaluator],
    reader_options: ReaderOptions,
    sort_eval_data: bool = True,
) -> RLTrainingOutput:
    """Train for ``num_epochs`` epochs, evaluating after each when possible.

    Evaluation runs only when both ``eval_dataset`` and ``evaluator`` are
    given. With ``sort_eval_data`` the eval data is gathered and sorted by the
    data fetcher and evaluated in one shot; otherwise the evaluator is fed
    batch by batch.

    Args:
        train_dataset: Dataset to train on.
        eval_dataset: Optional dataset for evaluation.
        trainer: Trainer whose ``train`` is called per batch and whose
            ``reporter`` is finished/published per epoch.
        num_epochs: Number of epochs; must be positive.
        use_gpu: Forwarded to the data loading helpers.
        batch_preprocessor: Forwarded to the data loading helpers.
        evaluator: Optional evaluator.
        reader_options: Forwarded to the data loading helpers.
        sort_eval_data: Selects one-shot (True) or batched (False) evaluation.

    Returns:
        The report published by the trainer's reporter after the last epoch.

    Raises:
        AssertionError: If ``num_epochs`` is not positive.
    """
    logger.info(f"{reader_options}")
    assert num_epochs > 0, f"Epoch should be positive, got {num_epochs}"

    fetcher = self.data_fetcher
    train_dataset_size = fetcher.get_table_row_count(train_dataset)
    # The eval row count is only needed for batched evaluation.
    if not sort_eval_data and eval_dataset is not None:
        eval_dataset_size = fetcher.get_table_row_count(eval_dataset)

    for epoch in range(num_epochs):
        SummaryWriterContext._reset_globals()
        logger.info(f"Starting training epoch {epoch}.")
        data_loader = self.run_on_dataset_batches(
            run_on_batch_fn=trainer.train,
            dataset=train_dataset,
            minibatch_size=trainer.minibatch_size,
            batch_preprocessor=batch_preprocessor,
            use_gpu=use_gpu,
            reader_options=reader_options,
            dataset_size=train_dataset_size,
        )
        if eval_dataset is not None and evaluator is not None:
            if not sort_eval_data:
                logger.info(
                    f"Starting evaluation epoch {epoch} by running on batches"
                )
                data_loader = self.run_on_dataset_batches(
                    run_on_batch_fn=evaluator.evaluate,
                    dataset=eval_dataset,
                    minibatch_size=trainer.minibatch_size,
                    batch_preprocessor=batch_preprocessor,
                    use_gpu=use_gpu,
                    reader_options=reader_options,
                    dataset_size=eval_dataset_size,
                )
                evaluator.finish()
            else:
                logger.info(
                    f"Starting evaluation epoch {epoch} by sorting and one shot"
                )
                eval_data = fetcher.gather_and_sort_eval_data(
                    trainer=trainer,
                    eval_dataset=eval_dataset,
                    batch_preprocessor=batch_preprocessor,
                    use_gpu=use_gpu,
                    reader_options=reader_options,
                )
                evaluator.evaluate_one_shot(eval_data)
                evaluator.finish()
        trainer.reporter.finish_epoch()
        report = trainer.reporter.publish()

    # Only the most recently created loader is shut down here.
    if hasattr(data_loader, "shutdown"):
        data_loader.shutdown()
    return report
|
||||
@@ -1,39 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
import logging
|
||||
import random
|
||||
from typing import Dict, Optional
|
||||
|
||||
from reagent.core.types import RewardOptions
|
||||
from reagent.data_fetchers.oss_data_fetcher import OssDataFetcher
|
||||
from reagent.model_managers.model_manager import ModelManager
|
||||
from reagent.parameters import NormalizationData
|
||||
from reagent.runners.batch_runner import BatchRunner
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OssBatchRunner(BatchRunner):
    """``BatchRunner`` backed by an ``OssDataFetcher`` with a random workflow id."""

    def __init__(
        self,
        use_gpu: bool,
        model_manager: ModelManager,
        reward_options: RewardOptions,
        normalization_data_map: Dict[str, NormalizationData],
        warmstart_path: Optional[str] = None,
    ):
        """Initialize the base runner with a fresh ``OssDataFetcher``.

        Args:
            use_gpu: Forwarded to ``BatchRunner``.
            model_manager: Forwarded to ``BatchRunner``.
            reward_options: Forwarded to ``BatchRunner``.
            normalization_data_map: Forwarded to ``BatchRunner``.
            warmstart_path: Forwarded to ``BatchRunner``.
        """
        data_fetcher = OssDataFetcher()
        super().__init__(
            use_gpu,
            model_manager,
            data_fetcher,
            reward_options,
            normalization_data_map,
            warmstart_path,
        )
        # Generate a random workflow id for this batch runner
        self.workflow_id = random.randint(1000, 10000000)

    def get_workflow_id(self) -> int:
        """Return the workflow id drawn at construction time."""
        return self.workflow_id
|
||||
@@ -0,0 +1,49 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
|
||||
import unittest
|
||||
|
||||
from reagent.core.observers import ValueListObserver
|
||||
from reagent.core.tracker import observable
|
||||
|
||||
|
||||
class TestObservable(unittest.TestCase):
    """Tests for the ``observable`` class decorator."""

    def test_observable(self):
        """Decorated classes keep their constructor and notify each observer once."""

        @observable(td_loss=float, str_val=str)
        class DummyClass:
            def __init__(self, a, b, c=10):
                super().__init__()
                self.a = a
                self.b = b
                self.c = c

            def do_something(self, i):
                self.notify_observers(td_loss=i, str_val="not_used")

        instance = DummyClass(1, 2)
        self.assertIsInstance(instance, DummyClass)
        self.assertEqual(instance.a, 1)
        self.assertEqual(instance.b, 2)
        self.assertEqual(instance.c, 10)

        observers = [ValueListObserver("td_loss") for _i in range(3)]
        instance.add_observers(observers)
        # Adding twice should not result in double update
        instance.add_observer(observers[0])

        for i in range(10):
            instance.do_something(float(i))

        for observer in observers:
            self.assertEqual(observer.values, [float(i) for i in range(10)])

    def test_no_observable_values(self):
        """Decorating without any observable values must raise AssertionError."""
        # assertRaises fails the test when nothing is raised; a bare
        # try/except/pass would let that case pass silently.
        with self.assertRaises(AssertionError):

            @observable()
            class NoObservableValues:
                pass
|
||||
@@ -8,7 +8,7 @@ from typing import Optional
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from reagent.core import types as rlt
|
||||
from reagent import types as rlt
|
||||
from reagent.evaluation.doubly_robust_estimator import DoublyRobustEstimator
|
||||
from reagent.evaluation.evaluation_data_page import EvaluationDataPage
|
||||
from reagent.evaluation.ope_adapter import OPEstimatorAdapter
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
import logging
|
||||
import random
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from reagent.core import types as rlt
|
||||
from reagent import types as rlt
|
||||
from reagent.evaluation.evaluation_data_page import EvaluationDataPage
|
||||
from reagent.evaluation.ope_adapter import OPEstimatorAdapter
|
||||
from reagent.evaluation.ope_adapter import (
|
||||
OPEstimatorAdapter,
|
||||
SequentialOPEstimatorAdapter,
|
||||
)
|
||||
from reagent.ope.estimators.contextual_bandits_estimators import (
|
||||
DMEstimator,
|
||||
DoublyRobustEstimator,
|
||||
@@ -13,6 +17,20 @@ from reagent.ope.estimators.contextual_bandits_estimators import (
|
||||
SwitchDREstimator,
|
||||
SwitchEstimator,
|
||||
)
|
||||
from reagent.ope.estimators.sequential_estimators import (
|
||||
DoublyRobustEstimator as SeqDREstimator,
|
||||
EpsilonGreedyRLPolicy,
|
||||
RandomRLPolicy,
|
||||
RLEstimatorInput,
|
||||
)
|
||||
from reagent.ope.estimators.types import Action, ActionSpace
|
||||
from reagent.ope.test.envs import PolicyLogGenerator
|
||||
from reagent.ope.test.gridworld import GridWorld, NoiseGridWorldModel
|
||||
from reagent.ope.trainers.rl_tabular_trainers import (
|
||||
DPTrainer,
|
||||
DPValueFunction,
|
||||
TabularPolicy,
|
||||
)
|
||||
from reagent.test.evaluation.test_evaluation_data_page import (
|
||||
FakeSeq2SlateRewardNetwork,
|
||||
FakeSeq2SlateTransformerNet,
|
||||
@@ -22,6 +40,56 @@ from reagent.test.evaluation.test_evaluation_data_page import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def rlestimator_input_to_edp(
    input: RLEstimatorInput, num_actions: int
) -> EvaluationDataPage:
    """Flatten an ``RLEstimatorInput`` log into an ``EvaluationDataPage``.

    Every transition of every MDP in ``input.log`` becomes one row. For each
    row the logged action is one-hot encoded over ``range(num_actions)``, and
    the target policy and value function are queried at the transition's
    ``last_state`` for each action index.

    Args:
        input: Estimator input providing ``log``, ``target_policy`` and
            ``value_function``.
        num_actions: Number of action indices to enumerate per row.

    Returns:
        A page whose ``sequence_number``, ``model_rewards`` and
        ``model_rewards_for_logged_action`` are empty tensors.

    Raises:
        AssertionError: If a transition lacks an action or last state, or if
            ``input.value_function`` is None while transitions exist.
    """
    mdp_ids = []
    logged_propensities = []
    logged_rewards = []
    action_mask = []
    model_propensities = []
    model_values = []
    action_indices = range(num_actions)

    for mdp in input.log:
        # The id is the number of rows emitted so far, not the MDP's position.
        mdp_id = len(mdp_ids)
        for transition in mdp:
            mdp_ids.append(mdp_id)
            logged_propensities.append(transition.action_prob)
            logged_rewards.append(transition.reward)

            assert transition.action is not None
            one_hot = []
            for idx in action_indices:
                one_hot.append(1 if idx == transition.action.value else 0)
            action_mask.append(one_hot)

            assert transition.last_state is not None
            propensity_row = []
            for idx in action_indices:
                propensity_row.append(
                    input.target_policy(transition.last_state)[Action(idx)]
                )
            model_propensities.append(propensity_row)

            assert input.value_function is not None
            value_row = []
            for idx in action_indices:
                value_row.append(
                    input.value_function(transition.last_state, Action(idx))
                )
            model_values.append(value_row)

    mdp_id_column = torch.tensor(mdp_ids).reshape(len(mdp_ids), 1)
    propensity_column = torch.tensor(logged_propensities).reshape(
        (len(logged_propensities), 1)
    )
    reward_column = torch.tensor(logged_rewards).reshape((len(logged_rewards), 1))

    return EvaluationDataPage(
        mdp_id=mdp_id_column,
        logged_propensities=propensity_column,
        logged_rewards=reward_column,
        action_mask=torch.tensor(action_mask),
        model_propensities=torch.tensor(model_propensities),
        model_values=torch.tensor(model_values),
        sequence_number=torch.tensor([]),
        model_rewards=torch.tensor([]),
        model_rewards_for_logged_action=torch.tensor([]),
    )
|
||||
|
||||
|
||||
class TestOPEModuleAlgs(unittest.TestCase):
|
||||
GAMMA = 0.9
|
||||
CPE_PASS_BAR = 1.0
|
||||
@@ -30,6 +98,98 @@ class TestOPEModuleAlgs(unittest.TestCase):
|
||||
NOISE_EPSILON = 0.3
|
||||
EPISODES = 2
|
||||
|
||||
def test_gridworld_sequential_adapter(self):
    """
    Create a gridworld environment, logging policy, and target policy.
    Evaluates the target policy using the direct OPE sequential doubly robust
    estimator, then transforms the log into an evaluation data page which is
    passed to the OPE adapter.

    This test is meant to verify the adaptation of EDPs into RLEstimatorInputs
    as employed by ReAgent, since ReAgent provides EDPs to Evaluators. Going
    from EDP -> RLEstimatorInput is more involved than RLEstimatorInput -> EDP
    since the EDP does not store the state at each timestep in each MDP, only
    the corresponding logged outputs & model outputs. Thus, the adapter must
    do some tricks to represent these timesteps as states so the OPE module
    can extract the correct outputs.

    Note that there is some randomness in the model outputs since the model is
    purposefully noisy. However, the same target policy is being evaluated on
    the same logged walks through the gridworld, so the two results should be
    close in value (within 1).
    """
    random.seed(0)
    np.random.seed(0)
    torch.random.manual_seed(0)

    device = torch.device("cuda") if torch.cuda.is_available() else None

    gridworld = GridWorld.from_grid(
        [
            ["s", "0", "0", "0", "0"],
            ["0", "0", "0", "W", "0"],
            ["0", "0", "0", "0", "0"],
            ["0", "W", "0", "0", "0"],
            ["0", "0", "0", "0", "g"],
        ],
        max_horizon=TestOPEModuleAlgs.MAX_HORIZON,
    )

    action_space = ActionSpace(4)
    opt_policy = TabularPolicy(action_space)
    trainer = DPTrainer(gridworld, opt_policy)
    # The return value is not kept: the value function used below is rebuilt
    # from the noisy model. The call itself is kept — presumably it trains
    # opt_policy in place; verify against DPTrainer.
    trainer.train(gamma=TestOPEModuleAlgs.GAMMA)

    behavior_policy = RandomRLPolicy(action_space)
    target_policy = EpsilonGreedyRLPolicy(
        opt_policy, TestOPEModuleAlgs.NOISE_EPSILON
    )
    model = NoiseGridWorldModel(
        gridworld,
        action_space,
        epsilon=TestOPEModuleAlgs.NOISE_EPSILON,
        max_horizon=TestOPEModuleAlgs.MAX_HORIZON,
    )
    value_func = DPValueFunction(target_policy, model, TestOPEModuleAlgs.GAMMA)
    ground_truth = DPValueFunction(
        target_policy, gridworld, TestOPEModuleAlgs.GAMMA
    )

    log = []
    log_generator = PolicyLogGenerator(gridworld, behavior_policy)
    num_episodes = TestOPEModuleAlgs.EPISODES
    for state in gridworld.states:
        for _ in range(num_episodes):
            log.append(log_generator.generate_log(state))

    estimator_input = RLEstimatorInput(
        gamma=TestOPEModuleAlgs.GAMMA,
        log=log,
        target_policy=target_policy,
        value_function=value_func,
        ground_truth=ground_truth,
    )

    edp = rlestimator_input_to_edp(estimator_input, len(model.action_space))

    dr_estimator = SeqDREstimator(
        weight_clamper=None, weighted=False, device=device
    )

    module_results = SequentialOPEstimatorAdapter.estimator_results_to_cpe_estimate(
        dr_estimator.evaluate(estimator_input)
    )
    adapter_results = SequentialOPEstimatorAdapter(
        dr_estimator, TestOPEModuleAlgs.GAMMA, device=device
    ).estimate(edp)

    # Messages go through msg=; written after the call with a comma they would
    # only form a discarded tuple and never be reported.
    self.assertAlmostEqual(
        adapter_results.raw,
        module_results.raw,
        delta=TestOPEModuleAlgs.CPE_PASS_BAR,
        msg=f"OPE adapter results differed too much from underlying module (Diff: {abs(adapter_results.raw - module_results.raw)} > {TestOPEModuleAlgs.CPE_PASS_BAR})",
    )
    self.assertLess(
        adapter_results.raw,
        TestOPEModuleAlgs.CPE_MAX_VALUE,
        msg=f"OPE adapter results are too large ({adapter_results.raw} > {TestOPEModuleAlgs.CPE_MAX_VALUE})",
    )
|
||||
|
||||
def test_seq2slate_eval_data_page(self):
|
||||
"""
|
||||
Create 3 slate ranking logs and evaluate using Direct Method, Inverse
|
||||
|
||||
@@ -8,7 +8,7 @@ from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from reagent.core import types as rlt
|
||||
from reagent import types as rlt
|
||||
from reagent.models.base import ModelBase
|
||||
from reagent.test.models.test_utils import check_save_load
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ import unittest
|
||||
import numpy.testing as npt
|
||||
import torch
|
||||
import torch.nn.init as init
|
||||
from reagent.core import types as rlt
|
||||
from reagent import types as rlt
|
||||
from reagent.models.bcq import BatchConstrainedDQN
|
||||
from reagent.models.dqn import FullyConnectedDQN
|
||||
from reagent.models.fully_connected_network import FullyConnectedNetwork
|
||||
|
||||
@@ -43,9 +43,7 @@ class TestNoSoftUpdteEmbedding(unittest.TestCase):
|
||||
self.assertEqual(1, len(params))
|
||||
param = params[0].detach().numpy()
|
||||
|
||||
trainer = RLTrainer(
|
||||
rl_parameters=RLParameters(), minibatch_size=1024, use_gpu=False
|
||||
)
|
||||
trainer = RLTrainer(rl_parameters=RLParameters(), use_gpu=False)
|
||||
trainer._soft_update(model, target_model, 0.1)
|
||||
|
||||
target_params = list(target_model.parameters())
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user