mirror of
https://github.com/facebookresearch/ReAgent.git
synced 2026-05-17 12:40:39 +00:00
Gym post step (#232)
Summary: - Use PostStep for Gym instead of ReplayBufferAdd/Train Fns, for cleaner interface - Refactoring gym tests and replay buffer - Introduced and used flake8 and isort Pull Request resolved: https://github.com/facebookresearch/ReAgent/pull/232 Reviewed By: kittipatv Differential Revision: D21078739 Pulled By: kaiwenw fbshipit-source-id: 5d83edb343e70bd0bac7e5553e9c88cbb7e6a672
This commit is contained in:
committed by
Facebook GitHub Bot
parent
a424fb81ee
commit
8a172c73ab
@@ -0,0 +1,9 @@
|
||||
[settings]
|
||||
multi_line_output=3
|
||||
include_trailing_comma=True
|
||||
force_grid_wrap=0
|
||||
use_parentheses=True
|
||||
line_length=88
|
||||
lines_after_imports=2
|
||||
reverse_relative=True
|
||||
default_section=THIRDPARTY
|
||||
+27
-40
@@ -5,8 +5,11 @@ from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from reagent.gym.policies.policy import Policy
|
||||
from reagent.gym.types import ActionPreprocessor, ReplayBufferAddFn, ReplayBufferTrainFn
|
||||
from reagent.replay_memory.circular_replay_buffer import ReplayBuffer
|
||||
from reagent.gym.types import ActionPreprocessor, PostStep
|
||||
|
||||
|
||||
def no_op(*args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
class Agent:
|
||||
@@ -14,9 +17,7 @@ class Agent:
|
||||
self,
|
||||
policy: Policy,
|
||||
action_preprocessor: ActionPreprocessor,
|
||||
replay_buffer: Optional[ReplayBuffer] = None,
|
||||
replay_buffer_add_fn: Optional[ReplayBufferAddFn] = None,
|
||||
replay_buffer_train_fn: Optional[ReplayBufferTrainFn] = None,
|
||||
post_transition_callback: PostStep = no_op,
|
||||
):
|
||||
"""
|
||||
The Agent orchestrates the interactions on our RL components, given
|
||||
@@ -24,20 +25,16 @@ class Agent:
|
||||
|
||||
Args:
|
||||
policy: Policy that acts given preprocessed input
|
||||
replay_buffer: if provided, inserts each experience via the
|
||||
replay_buffer_add_fn
|
||||
replay_buffer_add_fn: fn of the form
|
||||
(replay_buffer, obs, action, r, t) -> void
|
||||
which adds an experience into the given replay buffer
|
||||
replay_buffer_train_fn: called in poststep after adding experience.
|
||||
Performs training steps based on replay buffer samples.
|
||||
action_preprocessor: preprocesses action for environment
|
||||
post_step: called after env.step(action).
|
||||
Default post_step is to do nothing.
|
||||
"""
|
||||
self.policy = policy
|
||||
self.action_preprocessor = action_preprocessor
|
||||
self.replay_buffer = replay_buffer
|
||||
self.replay_buffer_add_fn = replay_buffer_add_fn
|
||||
self.replay_buffer_train_fn = replay_buffer_train_fn
|
||||
self.post_transition_callback = post_transition_callback
|
||||
self._reset_internal_states()
|
||||
|
||||
def _reset_internal_states(self):
|
||||
# intermediate state between act and post_step
|
||||
self._obs: Any = None
|
||||
self._actor_output: Any = None
|
||||
@@ -47,30 +44,20 @@ class Agent:
|
||||
self, obs: Any, possible_actions_mask: Optional[torch.Tensor] = None
|
||||
) -> Any:
|
||||
actor_output = self.policy.act(obs, possible_actions_mask)
|
||||
if self.replay_buffer:
|
||||
self._obs = obs
|
||||
self._actor_output = actor_output
|
||||
self._possible_actions_mask = possible_actions_mask
|
||||
action_for_env = self.action_preprocessor(actor_output)
|
||||
return action_for_env
|
||||
|
||||
def post_step(self, reward: float, terminal: bool, *args):
|
||||
# store intermediate states for post_step
|
||||
self._obs = obs
|
||||
self._actor_output = actor_output
|
||||
self._possible_actions_mask = possible_actions_mask
|
||||
|
||||
# return action for the environment
|
||||
return self.action_preprocessor(actor_output)
|
||||
|
||||
def post_step(self, reward: float, terminal: bool):
|
||||
""" to be called after step(action) """
|
||||
if self.replay_buffer:
|
||||
assert self._obs is not None
|
||||
assert self._actor_output is not None
|
||||
assert self.replay_buffer_add_fn is not None
|
||||
assert self.replay_buffer_train_fn is not None
|
||||
self.replay_buffer_add_fn(
|
||||
self.replay_buffer,
|
||||
self._obs,
|
||||
self._actor_output,
|
||||
reward,
|
||||
terminal,
|
||||
self._possible_actions_mask,
|
||||
)
|
||||
self._obs = None
|
||||
self._actor_output = None
|
||||
self._possible_actions_mask = None
|
||||
|
||||
self.replay_buffer_train_fn(self.replay_buffer)
|
||||
assert self._obs is not None
|
||||
assert self._actor_output is not None
|
||||
self.post_transition_callback(
|
||||
self._obs, self._actor_output, reward, terminal, self._possible_actions_mask
|
||||
)
|
||||
self._reset_internal_states()
|
||||
|
||||
@@ -0,0 +1,104 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
import reagent.types as rlt
|
||||
import torch
|
||||
from reagent.gym.types import ActionPreprocessor, PostStep, TrainerPreprocessor
|
||||
from reagent.replay_memory.circular_replay_buffer import ReplayBuffer
|
||||
from reagent.training.rl_dataset import RLDataset
|
||||
from reagent.training.trainer import Trainer
|
||||
|
||||
|
||||
def train_with_replay_buffer_post_step(
|
||||
replay_buffer: ReplayBuffer,
|
||||
trainer: Trainer,
|
||||
trainer_preprocessor: TrainerPreprocessor,
|
||||
training_freq: int,
|
||||
batch_size: int,
|
||||
replay_burnin: Optional[int] = None,
|
||||
) -> PostStep:
|
||||
""" Called in post_step of agent to train based on replay buffer (RB).
|
||||
Args:
|
||||
trainer: responsible for having a .train method to train the model
|
||||
trainer_preprocessor: format RB output for trainer.train
|
||||
training_freq: how many steps in between trains
|
||||
batch_size: how big of a batch to sample
|
||||
replay_burnin: optional requirement for minimum size of RB before
|
||||
training begins. (i.e. burn in this many frames)
|
||||
"""
|
||||
_num_steps = 0
|
||||
size_req = batch_size
|
||||
if replay_burnin is not None:
|
||||
size_req = max(size_req, replay_burnin)
|
||||
|
||||
def post_step(
|
||||
obs: Any,
|
||||
actor_output: rlt.ActorOutput,
|
||||
reward: float,
|
||||
terminal: bool,
|
||||
possible_actions_mask: Optional[torch.Tensor],
|
||||
) -> None:
|
||||
nonlocal _num_steps
|
||||
|
||||
action = actor_output.action.numpy()
|
||||
log_prob = actor_output.log_prob.numpy()
|
||||
if possible_actions_mask is None:
|
||||
possible_actions_mask = torch.ones_like(actor_output.action).to(torch.bool)
|
||||
possible_actions_mask = possible_actions_mask.numpy()
|
||||
replay_buffer.add(
|
||||
obs, action, reward, terminal, possible_actions_mask, log_prob
|
||||
)
|
||||
|
||||
if replay_buffer.size >= size_req and _num_steps % training_freq == 0:
|
||||
train_batch = replay_buffer.sample_transition_batch(batch_size=batch_size)
|
||||
preprocessed_batch = trainer_preprocessor(train_batch)
|
||||
trainer.train(preprocessed_batch)
|
||||
_num_steps += 1
|
||||
return
|
||||
|
||||
return post_step
|
||||
|
||||
|
||||
def log_data_post_step(
|
||||
dataset: RLDataset, action_preprocessor: ActionPreprocessor, mdp_id: str
|
||||
) -> PostStep:
|
||||
sequence_number = 0
|
||||
|
||||
def post_step(
|
||||
obs: Any,
|
||||
actor_output: rlt.ActorOutput,
|
||||
reward: float,
|
||||
terminal: bool,
|
||||
possible_actions_mask: Optional[torch.Tensor],
|
||||
) -> None:
|
||||
""" log data into dataset """
|
||||
nonlocal sequence_number
|
||||
|
||||
if possible_actions_mask is None:
|
||||
possible_actions_mask = torch.ones_like(actor_output.action).to(torch.bool)
|
||||
|
||||
if terminal:
|
||||
possible_actions_mask = torch.zeros_like(actor_output.action).to(torch.bool)
|
||||
|
||||
# timeline operator expects str for disc and map<str, double> for cts
|
||||
# TODO: case for cts
|
||||
action = str(action_preprocessor(actor_output))
|
||||
|
||||
dataset.insert_pre_timeline_format(
|
||||
mdp_id=None,
|
||||
sequence_number=sequence_number,
|
||||
state=obs,
|
||||
action=action,
|
||||
reward=reward,
|
||||
possible_actions=None,
|
||||
time_diff=1,
|
||||
action_probability=actor_output.log_prob.exp().item(),
|
||||
possible_actions_mask=possible_actions_mask,
|
||||
)
|
||||
sequence_number += 1
|
||||
return
|
||||
|
||||
return post_step
|
||||
@@ -1,25 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
import reagent.types as rlt
|
||||
import torch
|
||||
from reagent.replay_memory.circular_replay_buffer import ReplayBuffer
|
||||
|
||||
|
||||
def replay_buffer_add_fn(
|
||||
replay_buffer: ReplayBuffer,
|
||||
obs: Any,
|
||||
actor_output: rlt.ActorOutput,
|
||||
reward: float,
|
||||
terminal: bool,
|
||||
possible_actions_mask: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
""" Simply adds transition into buffer after converting to numpy """
|
||||
action = actor_output.action.numpy()
|
||||
log_prob = actor_output.log_prob.numpy()
|
||||
if possible_actions_mask is None:
|
||||
possible_actions_mask = torch.ones_like(actor_output.action).to(torch.bool)
|
||||
possible_actions_mask = possible_actions_mask.numpy()
|
||||
replay_buffer.add(obs, action, reward, terminal, possible_actions_mask, log_prob)
|
||||
@@ -1,42 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from reagent.gym.types import ReplayBufferTrainFn, TrainerPreprocessor
|
||||
from reagent.replay_memory.circular_replay_buffer import ReplayBuffer
|
||||
from reagent.training.trainer import Trainer
|
||||
|
||||
|
||||
def replay_buffer_train_fn(
|
||||
trainer: Trainer,
|
||||
trainer_preprocessor: TrainerPreprocessor,
|
||||
training_freq: int,
|
||||
batch_size: int,
|
||||
replay_burnin: Optional[int] = None,
|
||||
) -> ReplayBufferTrainFn:
|
||||
""" Called in post_step of agent to train based on replay buffer (RB).
|
||||
Args:
|
||||
trainer: responsible for having a .train method to train the model
|
||||
trainer_preprocessor: format RB output for trainer.train
|
||||
training_freq: how many steps in between trains
|
||||
batch_size: how big of a batch to sample
|
||||
replay_burnin: optional requirement for minimum size of RB before
|
||||
training begins. (i.e. burn in this many frames)
|
||||
"""
|
||||
_num_steps = 0
|
||||
size_req = batch_size
|
||||
if replay_burnin is not None:
|
||||
size_req = max(size_req, replay_burnin)
|
||||
|
||||
def train(replay_buffer: ReplayBuffer) -> None:
|
||||
""" To be called in post step """
|
||||
nonlocal _num_steps, size_req
|
||||
if replay_buffer.size >= size_req and _num_steps % training_freq == 0:
|
||||
train_batch = replay_buffer.sample_transition_batch(batch_size=batch_size)
|
||||
preprocessed_batch = trainer_preprocessor(train_batch)
|
||||
trainer.train(preprocessed_batch)
|
||||
_num_steps += 1
|
||||
return
|
||||
|
||||
return train
|
||||
@@ -10,7 +10,8 @@ class EnvFactory:
|
||||
@staticmethod
|
||||
def make(name: str) -> gym.Env:
|
||||
env = gym.make(name)
|
||||
env = ReseedWrapper(env)
|
||||
if name.startswith("MiniGrid-"):
|
||||
# Wrap in minigrid simplifier
|
||||
env = SimpleObsWrapper(ReseedWrapper(env))
|
||||
env = SimpleObsWrapper(env)
|
||||
return env
|
||||
|
||||
@@ -20,12 +20,12 @@ class SimpleObsWrapper(gym.core.ObservationWrapper):
|
||||
low=0,
|
||||
high=1,
|
||||
shape=(self.env.width * self.env.height * NUM_DIRECTIONS,),
|
||||
dtype="uint8",
|
||||
dtype="float32",
|
||||
)
|
||||
|
||||
def observation(self, obs):
|
||||
retval = np.zeros(
|
||||
(self.env.width * self.env.height * NUM_DIRECTIONS,), dtype=np.uint8
|
||||
(self.env.width * self.env.height * NUM_DIRECTIONS,), dtype=np.float32
|
||||
)
|
||||
retval[
|
||||
self.env.agent_pos[0] * self.env.height * NUM_DIRECTIONS
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
import reagent.types as rlt
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from reagent.gym.policies.policy import Policy
|
||||
|
||||
|
||||
class DiscreteRandomPolicy(Policy):
|
||||
def __init__(self, num_actions):
|
||||
""" Random actor for accumulating random offline data. """
|
||||
self.num_actions = num_actions
|
||||
self.default_weights = torch.ones(num_actions)
|
||||
|
||||
def act(
|
||||
self, obs: Any, possible_actions_mask: Optional[torch.Tensor] = None
|
||||
) -> rlt.ActorOutput:
|
||||
""" Act randomly regardless of the observation. """
|
||||
weights = self.default_weights
|
||||
if possible_actions_mask:
|
||||
assert possible_actions_mask.shape == self.default_weights.shape
|
||||
weights = weights * possible_actions_mask
|
||||
|
||||
# sample a random action
|
||||
m = torch.distributions.Categorical(weights)
|
||||
raw_action = m.sample()
|
||||
action = F.one_hot(raw_action, self.num_actions).squeeze(0)
|
||||
log_prob = m.log_prob(raw_action).float().squeeze(0)
|
||||
return rlt.ActorOutput(action=action, log_prob=log_prob)
|
||||
@@ -30,8 +30,6 @@ class SoftmaxActionSampler(Sampler):
|
||||
def sample_action(
|
||||
self, scores: torch.Tensor, possible_actions_mask: Optional[torch.Tensor] = None
|
||||
) -> rlt.ActorOutput:
|
||||
# TODO: temp hack, convert to single instead of batched
|
||||
scores = scores.unsqueeze(0)
|
||||
assert scores.dim() == 2, (
|
||||
"scores dim is %d" % scores.dim()
|
||||
) # batch_size x num_actions
|
||||
|
||||
@@ -1,2 +1,28 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
|
||||
from .action_preprocessors.action_preprocessor import (
|
||||
argmax_action_preprocessor,
|
||||
numpy_action_preprocessor,
|
||||
)
|
||||
from .policy_preprocessors.policy_preprocessor import (
|
||||
numpy_policy_preprocessor,
|
||||
tiled_numpy_policy_preprocessor,
|
||||
)
|
||||
from .trainer_preprocessors.trainer_preprocessor import (
|
||||
discrete_dqn_trainer_preprocessor,
|
||||
parametric_dqn_trainer_preprocessor,
|
||||
sac_trainer_preprocessor,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"numpy_action_preprocessor",
|
||||
"argmax_action_preprocessor",
|
||||
"numpy_policy_preprocessor",
|
||||
"tiled_numpy_policy_preprocessor",
|
||||
"discrete_dqn_trainer_preprocessor",
|
||||
"parametric_dqn_trainer_preprocessor",
|
||||
"sac_trainer_preprocessor",
|
||||
]
|
||||
|
||||
@@ -11,7 +11,7 @@ import numpy as np
|
||||
import reagent.types as rlt
|
||||
|
||||
|
||||
def discrete_action_preprocessor(actor_output: rlt.ActorOutput) -> np.array:
|
||||
def argmax_action_preprocessor(actor_output: rlt.ActorOutput) -> np.array:
|
||||
""" Simply reverses the one-hot encoding and convert to numpy """
|
||||
action = actor_output.action
|
||||
assert action.dim() == 1, "action has dim %d" % action.dim()
|
||||
@@ -19,6 +19,6 @@ def discrete_action_preprocessor(actor_output: rlt.ActorOutput) -> np.array:
|
||||
return idx
|
||||
|
||||
|
||||
def continuous_action_preprocessor(actor_output: rlt.ActorOutput) -> np.array:
|
||||
def numpy_action_preprocessor(actor_output: rlt.ActorOutput) -> np.array:
|
||||
""" Simply identity map """
|
||||
return actor_output.action.numpy()
|
||||
|
||||
@@ -16,7 +16,12 @@ def numpy_policy_preprocessor(device: str = "cpu") -> PolicyPreprocessor:
|
||||
device = torch.device(device)
|
||||
|
||||
def preprocess_obs(obs: np.array) -> rlt.PreprocessedState:
|
||||
return rlt.PreprocessedState.from_tensor(torch.tensor(obs).float().to(device))
|
||||
# convert to batch of one
|
||||
assert (
|
||||
obs.ndim == 1
|
||||
), f"Expect single obs of dim 1, got obs with shape {obs.shape}"
|
||||
state = torch.tensor(obs).float().unsqueeze(0).to(device)
|
||||
return rlt.PreprocessedState.from_tensor(state)
|
||||
|
||||
return preprocess_obs
|
||||
|
||||
@@ -31,7 +36,7 @@ def tiled_numpy_policy_preprocessor(
|
||||
tiled_state = torch.repeat_interleave(
|
||||
obs.unsqueeze(0), repeats=num_actions, axis=0
|
||||
)
|
||||
actions = torch.eye(num_actions)
|
||||
actions = torch.eye(num_actions).to(device)
|
||||
ts_size = tiled_state.size(0)
|
||||
a_size = actions.size(0)
|
||||
assert (
|
||||
|
||||
@@ -9,7 +9,7 @@ from reagent.gym.agents.agent import Agent
|
||||
|
||||
def run_episode(env: Env, agent: Agent, max_steps: Optional[int] = None) -> float:
|
||||
"""
|
||||
Return total reward
|
||||
Return sum of rewards from episode.
|
||||
"""
|
||||
ep_reward = 0.0
|
||||
obs = env.reset()
|
||||
|
||||
@@ -0,0 +1,38 @@
|
||||
env: CartPole-v0
|
||||
model:
|
||||
DiscreteDQN:
|
||||
trainer_param:
|
||||
actions:
|
||||
- 4
|
||||
- 5
|
||||
rl:
|
||||
gamma: 0.99
|
||||
epsilon: 0.05
|
||||
target_update_rate: 0.1
|
||||
maxq_learning: true
|
||||
temperature: 1.0
|
||||
softmax_policy: false
|
||||
q_network_loss: mse
|
||||
double_q_learning: true
|
||||
minibatch_size: 512
|
||||
minibatches_per_step: 1
|
||||
optimizer:
|
||||
optimizer: ADAM
|
||||
learning_rate: 0.01
|
||||
l2_decay: 0
|
||||
evaluation:
|
||||
calc_cpe_in_training: false
|
||||
net_builder:
|
||||
FullyConnected:
|
||||
sizes:
|
||||
- 128
|
||||
- 64
|
||||
activations:
|
||||
- relu
|
||||
- relu
|
||||
replay_memory_size: 20000
|
||||
train_every_ts: 3
|
||||
train_after_ts: 1
|
||||
num_episodes: 50
|
||||
max_steps: 200
|
||||
last_score_bar: 100.0
|
||||
@@ -0,0 +1,38 @@
|
||||
env: MiniGrid-Empty-5x5-v0
|
||||
model:
|
||||
DiscreteDQN:
|
||||
trainer_param:
|
||||
actions:
|
||||
- 101
|
||||
- 102
|
||||
- 103
|
||||
- 104
|
||||
- 105
|
||||
- 106
|
||||
- 107
|
||||
rl:
|
||||
gamma: 0.99
|
||||
epsilon: 0.05
|
||||
target_update_rate: 0.1
|
||||
maxq_learning: true
|
||||
temperature: 0.01
|
||||
softmax_policy: false
|
||||
q_network_loss: mse
|
||||
double_q_learning: true
|
||||
minibatch_size: 512
|
||||
minibatches_per_step: 1
|
||||
optimizer:
|
||||
optimizer: ADAM
|
||||
learning_rate: 0.1
|
||||
evaluation:
|
||||
calc_cpe_in_training: false
|
||||
net_builder:
|
||||
FullyConnected:
|
||||
sizes: []
|
||||
activations: []
|
||||
replay_memory_size: 2000
|
||||
train_every_ts: 3
|
||||
train_after_ts: 1
|
||||
num_episodes: 100
|
||||
max_steps: 2000
|
||||
last_score_bar: 0.9
|
||||
+105
-335
@@ -4,7 +4,6 @@
|
||||
Environments that require short training and evaluation time (<=10min)
|
||||
can be tested in this file.
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
@@ -14,376 +13,147 @@ from typing import Optional
|
||||
import gym
|
||||
import numpy as np
|
||||
import torch
|
||||
from reagent.core.configuration import make_config_class
|
||||
from reagent.gym.agents.agent import Agent
|
||||
from reagent.gym.agents.replay_buffer_add_fn import replay_buffer_add_fn
|
||||
from reagent.gym.agents.replay_buffer_train_fn import replay_buffer_train_fn
|
||||
from reagent.gym.agents.post_step import train_with_replay_buffer_post_step
|
||||
from reagent.gym.envs.env_factory import EnvFactory
|
||||
from reagent.gym.policies.policy import Policy
|
||||
from reagent.gym.policies.samplers.continuous_sampler import GaussianSampler
|
||||
from reagent.gym.policies.samplers.discrete_sampler import SoftmaxActionSampler
|
||||
from reagent.gym.policies.scorers.continuous_scorer import sac_scorer
|
||||
from reagent.gym.policies.scorers.discrete_scorer import (
|
||||
discrete_dqn_scorer,
|
||||
parametric_dqn_scorer,
|
||||
)
|
||||
from reagent.gym.preprocessors.action_preprocessors.action_preprocessor import (
|
||||
continuous_action_preprocessor,
|
||||
discrete_action_preprocessor,
|
||||
)
|
||||
from reagent.gym.preprocessors.policy_preprocessors.policy_preprocessor import (
|
||||
numpy_policy_preprocessor,
|
||||
tiled_numpy_policy_preprocessor,
|
||||
)
|
||||
from reagent.gym.preprocessors.trainer_preprocessors.trainer_preprocessor import (
|
||||
discrete_dqn_trainer_preprocessor,
|
||||
parametric_dqn_trainer_preprocessor,
|
||||
sac_trainer_preprocessor,
|
||||
)
|
||||
from reagent.gym.runners.gymrunner import run_episode
|
||||
from reagent.json_serialize import from_json
|
||||
from reagent.replay_memory.circular_replay_buffer import ReplayBuffer, ReplayElement
|
||||
from reagent.parameters import NormalizationData
|
||||
from reagent.replay_memory.circular_replay_buffer import ReplayBuffer
|
||||
from reagent.tensorboardX import SummaryWriterContext
|
||||
from reagent.test.gym.open_ai_gym_environment import OpenAIGymEnvironment
|
||||
from reagent.test.gym.run_gym import OpenAiGymParameters, create_trainer
|
||||
from reagent.test.base.utils import only_continuous_normalizer
|
||||
from reagent.workflow.model_managers.union import ModelManager__Union
|
||||
from reagent.workflow.types import RewardOptions
|
||||
from ruamel.yaml import YAML
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
curr_dir = os.path.dirname(__file__)
|
||||
|
||||
|
||||
DISCRETE_DQN_CARTPOLE_JSON = os.path.join(
|
||||
curr_dir, "configs/discrete_dqn_cartpole_v0.json"
|
||||
)
|
||||
DISCRETE_DQN_CARTPOLE_NUM_EPISODES = 150
|
||||
PARAMETRIC_DQN_CARTPOLE_JSON = os.path.join(
|
||||
curr_dir, "configs/parametric_dqn_cartpole_v0.json"
|
||||
)
|
||||
PARAMETRIC_DQN_CARTPOLE_NUM_EPISODES = 150
|
||||
CARTPOLE_SCORE_BAR = 200
|
||||
|
||||
OPEN_GRIDWORLD_JSON = os.path.join(curr_dir, "configs/open_gridworld.json")
|
||||
OPEN_GRIDWORLD_NUM_EPISODES = 100
|
||||
OPEN_GRIDWORLD_SCORE_BAR = 0.9
|
||||
|
||||
SAC_PENDULUM_JSON = os.path.join(curr_dir, "configs/sac_pendulum_v0.json")
|
||||
SAC_PENDULUM_NUM_EPISODES = 50
|
||||
# Though maximal score is 0, we set lower bar to let tests finish in time
|
||||
PENDULUM_SCORE_BAR = -750
|
||||
|
||||
SEED = 0
|
||||
DISCRETE_DQN_CARTPOLE_CONFIG = os.path.join(
|
||||
curr_dir, "configs/discrete_dqn_cartpole_online.yaml"
|
||||
)
|
||||
|
||||
OPEN_GRIDWORLD_CONFIG = os.path.join(curr_dir, "configs/open_gridworld.yaml")
|
||||
|
||||
|
||||
def extract_config(config_path: str) -> OpenAiGymParameters:
|
||||
with open(config_path, "r") as f:
|
||||
json_data = json.loads(f.read())
|
||||
json_data["evaluation"] = {
|
||||
"calc_cpe_in_training": False
|
||||
} # Slow without disabling
|
||||
json_data["use_gpu"] = False
|
||||
|
||||
return from_json(json_data, OpenAiGymParameters)
|
||||
def build_normalizer(env):
|
||||
if isinstance(env.observation_space, gym.spaces.Box):
|
||||
assert (
|
||||
len(env.observation_space.shape) == 1
|
||||
), f"{env.observation_space} not supported."
|
||||
return {
|
||||
"state": NormalizationData(
|
||||
dense_normalization_parameters=only_continuous_normalizer(
|
||||
list(range(env.observation_space.shape[0])),
|
||||
env.observation_space.low,
|
||||
env.observation_space.high,
|
||||
)
|
||||
)
|
||||
}
|
||||
elif isinstance(env.observation_space, gym.spaces.Dict):
|
||||
# assuming env.observation_space is image
|
||||
return None
|
||||
else:
|
||||
raise NotImplementedError(f"{env.observation_space} not supported")
|
||||
|
||||
|
||||
def build_trainer(config):
|
||||
return create_trainer(config, OpenAIGymEnvironment(config.env))
|
||||
def run_test(
|
||||
env: str,
|
||||
model: ModelManager__Union,
|
||||
replay_memory_size: int,
|
||||
train_every_ts: int,
|
||||
train_after_ts: int,
|
||||
num_episodes: int,
|
||||
max_steps: Optional[int],
|
||||
last_score_bar: float,
|
||||
):
|
||||
env = EnvFactory.make(env)
|
||||
normalization = build_normalizer(env)
|
||||
logger.info(f"Normalization is {normalization}")
|
||||
|
||||
manager = model.value
|
||||
trainer = manager.initialize_trainer(
|
||||
use_gpu=False,
|
||||
reward_options=RewardOptions(),
|
||||
normalization_data_map=normalization,
|
||||
)
|
||||
|
||||
policy = manager.create_policy(env)
|
||||
replay_buffer = ReplayBuffer.create_from_env(
|
||||
env=env,
|
||||
replay_memory_size=replay_memory_size,
|
||||
batch_size=trainer.minibatch_size,
|
||||
)
|
||||
|
||||
post_step = train_with_replay_buffer_post_step(
|
||||
replay_buffer=replay_buffer,
|
||||
trainer=trainer,
|
||||
trainer_preprocessor=manager.create_trainer_preprocessor(),
|
||||
training_freq=train_every_ts,
|
||||
batch_size=trainer.minibatch_size,
|
||||
replay_burnin=train_after_ts,
|
||||
)
|
||||
|
||||
agent = Agent(
|
||||
policy=policy,
|
||||
action_preprocessor=manager.create_action_preprocessor(),
|
||||
post_transition_callback=post_step,
|
||||
)
|
||||
|
||||
def run(env: gym.Env, agent: Agent, num_episodes: int, max_steps: Optional[int] = None):
|
||||
reward_history = []
|
||||
for i in range(num_episodes):
|
||||
print(f"Starting episode {i}")
|
||||
ep_reward = run_episode(env, agent, max_steps)
|
||||
print(f"Finished episode {i} with reward {ep_reward}")
|
||||
logger.info(f"running episode {i}")
|
||||
ep_reward = run_episode(env=env, agent=agent, max_steps=max_steps)
|
||||
reward_history.append(ep_reward)
|
||||
|
||||
assert reward_history[-1] >= last_score_bar, (
|
||||
f"reward after {len(reward_history)} episodes is {reward_history[-1]},"
|
||||
f"less than < {last_score_bar}...\n"
|
||||
f"Full reward history: {reward_history}"
|
||||
)
|
||||
|
||||
return reward_history
|
||||
|
||||
|
||||
def run_discrete_dqn_cartpole(config):
|
||||
trainer = build_trainer(config)
|
||||
num_episodes = DISCRETE_DQN_CARTPOLE_NUM_EPISODES
|
||||
env = EnvFactory.make(config.env)
|
||||
wrapped_env = OpenAIGymEnvironment(config.env)
|
||||
action_shape = np.array(wrapped_env.actions).shape
|
||||
action_type = np.int32
|
||||
replay_buffer = ReplayBuffer(
|
||||
observation_shape=env.reset().shape,
|
||||
stack_size=1,
|
||||
replay_capacity=config.max_replay_memory_size,
|
||||
batch_size=trainer.minibatch_size,
|
||||
observation_dtype=np.float32,
|
||||
action_shape=action_shape,
|
||||
action_dtype=action_type,
|
||||
reward_shape=(),
|
||||
reward_dtype=np.float32,
|
||||
extra_storage_types=[
|
||||
ReplayElement("possible_actions_mask", action_shape, action_type),
|
||||
ReplayElement("log_prob", (), np.float32),
|
||||
],
|
||||
)
|
||||
def run_from_config(path):
|
||||
yaml = YAML(typ="safe")
|
||||
with open(path, "r") as f:
|
||||
config_dict = yaml.load(f.read())
|
||||
|
||||
actions = wrapped_env.actions
|
||||
normalization = wrapped_env.normalization
|
||||
policy = Policy(
|
||||
scorer=discrete_dqn_scorer(trainer.q_network),
|
||||
sampler=SoftmaxActionSampler(temperature=0.01),
|
||||
policy_preprocessor=numpy_policy_preprocessor(),
|
||||
)
|
||||
agent = Agent(
|
||||
policy=policy,
|
||||
action_preprocessor=discrete_action_preprocessor,
|
||||
replay_buffer=replay_buffer,
|
||||
replay_buffer_add_fn=replay_buffer_add_fn,
|
||||
replay_buffer_train_fn=replay_buffer_train_fn(
|
||||
trainer=trainer,
|
||||
trainer_preprocessor=discrete_dqn_trainer_preprocessor(
|
||||
len(actions), normalization
|
||||
),
|
||||
training_freq=config.run_details.train_every_ts,
|
||||
batch_size=trainer.minibatch_size,
|
||||
replay_burnin=config.run_details.train_after_ts,
|
||||
),
|
||||
)
|
||||
@make_config_class(run_test)
|
||||
class ConfigClass:
|
||||
pass
|
||||
|
||||
reward_history = run(
|
||||
env=env,
|
||||
agent=agent,
|
||||
num_episodes=num_episodes,
|
||||
max_steps=config.run_details.max_steps,
|
||||
)
|
||||
return reward_history
|
||||
|
||||
|
||||
def run_parametric_dqn_cartpole(config):
|
||||
trainer = build_trainer(config)
|
||||
num_episodes = PARAMETRIC_DQN_CARTPOLE_NUM_EPISODES
|
||||
env = EnvFactory.make(config.env)
|
||||
wrapped_env = OpenAIGymEnvironment(config.env)
|
||||
action_shape = np.array(wrapped_env.actions).shape
|
||||
action_type = np.float32
|
||||
replay_buffer = ReplayBuffer(
|
||||
observation_shape=env.reset().shape,
|
||||
stack_size=1,
|
||||
replay_capacity=config.max_replay_memory_size,
|
||||
batch_size=trainer.minibatch_size,
|
||||
observation_dtype=np.float32,
|
||||
action_shape=action_shape,
|
||||
action_dtype=action_type,
|
||||
reward_shape=(),
|
||||
reward_dtype=np.float32,
|
||||
extra_storage_types=[
|
||||
ReplayElement("possible_actions_mask", action_shape, action_type),
|
||||
ReplayElement("log_prob", (), np.float32),
|
||||
],
|
||||
)
|
||||
|
||||
actions = wrapped_env.actions
|
||||
normalization = wrapped_env.normalization
|
||||
|
||||
policy = Policy(
|
||||
scorer=parametric_dqn_scorer(len(actions), trainer.q_network),
|
||||
sampler=SoftmaxActionSampler(temperature=0.01),
|
||||
policy_preprocessor=tiled_numpy_policy_preprocessor(len(actions)),
|
||||
)
|
||||
agent = Agent(
|
||||
policy=policy,
|
||||
action_preprocessor=discrete_action_preprocessor,
|
||||
replay_buffer=replay_buffer,
|
||||
replay_buffer_add_fn=replay_buffer_add_fn,
|
||||
replay_buffer_train_fn=replay_buffer_train_fn(
|
||||
trainer=trainer,
|
||||
trainer_preprocessor=parametric_dqn_trainer_preprocessor(
|
||||
len(actions), normalization
|
||||
),
|
||||
training_freq=config.run_details.train_every_ts,
|
||||
batch_size=trainer.minibatch_size,
|
||||
replay_burnin=config.run_details.train_after_ts,
|
||||
),
|
||||
)
|
||||
|
||||
reward_history = run(
|
||||
env=env,
|
||||
agent=agent,
|
||||
num_episodes=num_episodes,
|
||||
max_steps=config.run_details.max_steps,
|
||||
)
|
||||
return reward_history
|
||||
|
||||
|
||||
def run_open_gridworld(config):
|
||||
trainer = build_trainer(config)
|
||||
num_episodes = OPEN_GRIDWORLD_NUM_EPISODES
|
||||
env = EnvFactory.make(config.env)
|
||||
wrapped_env = OpenAIGymEnvironment(config.env)
|
||||
action_shape = np.array(wrapped_env.actions).shape
|
||||
action_type = np.int32
|
||||
replay_buffer = ReplayBuffer(
|
||||
observation_shape=env.reset().shape,
|
||||
stack_size=1,
|
||||
replay_capacity=config.max_replay_memory_size,
|
||||
batch_size=trainer.minibatch_size,
|
||||
observation_dtype=np.float32,
|
||||
action_shape=action_shape,
|
||||
action_dtype=action_type,
|
||||
reward_shape=(),
|
||||
reward_dtype=np.float32,
|
||||
extra_storage_types=[
|
||||
ReplayElement("possible_actions_mask", action_shape, action_type),
|
||||
ReplayElement("log_prob", (), np.float32),
|
||||
],
|
||||
)
|
||||
|
||||
actions = wrapped_env.actions
|
||||
normalization = wrapped_env.normalization
|
||||
policy = Policy(
|
||||
scorer=discrete_dqn_scorer(trainer.q_network),
|
||||
sampler=SoftmaxActionSampler(temperature=0.01),
|
||||
policy_preprocessor=numpy_policy_preprocessor(),
|
||||
)
|
||||
agent = Agent(
|
||||
policy=policy,
|
||||
action_preprocessor=discrete_action_preprocessor,
|
||||
replay_buffer=replay_buffer,
|
||||
replay_buffer_add_fn=replay_buffer_add_fn,
|
||||
replay_buffer_train_fn=replay_buffer_train_fn(
|
||||
trainer=trainer,
|
||||
trainer_preprocessor=discrete_dqn_trainer_preprocessor(
|
||||
len(actions), normalization
|
||||
),
|
||||
training_freq=config.run_details.train_every_ts,
|
||||
batch_size=trainer.minibatch_size,
|
||||
replay_burnin=config.run_details.train_after_ts,
|
||||
),
|
||||
)
|
||||
|
||||
reward_history = run(
|
||||
env=env,
|
||||
agent=agent,
|
||||
num_episodes=num_episodes,
|
||||
max_steps=config.run_details.max_steps,
|
||||
)
|
||||
return reward_history
|
||||
|
||||
|
||||
def run_sac_pendulum(config):
|
||||
trainer = build_trainer(config)
|
||||
num_episodes = SAC_PENDULUM_NUM_EPISODES
|
||||
env = EnvFactory.make(config.env)
|
||||
action_shape = (1,)
|
||||
action_type = np.float32
|
||||
replay_buffer = ReplayBuffer(
|
||||
observation_shape=env.reset().shape,
|
||||
stack_size=1,
|
||||
replay_capacity=config.max_replay_memory_size,
|
||||
batch_size=trainer.minibatch_size,
|
||||
observation_dtype=np.float32,
|
||||
action_shape=action_shape,
|
||||
action_dtype=action_type,
|
||||
reward_shape=(),
|
||||
reward_dtype=np.float32,
|
||||
extra_storage_types=[
|
||||
ReplayElement("possible_actions_mask", action_shape, action_type),
|
||||
ReplayElement("log_prob", (), np.float32),
|
||||
],
|
||||
)
|
||||
|
||||
policy = Policy(
|
||||
scorer=sac_scorer(trainer.actor_network),
|
||||
sampler=GaussianSampler(
|
||||
trainer.actor_network,
|
||||
trainer.min_action_range_tensor_serving,
|
||||
trainer.max_action_range_tensor_serving,
|
||||
trainer.min_action_range_tensor_training,
|
||||
trainer.max_action_range_tensor_training,
|
||||
),
|
||||
policy_preprocessor=numpy_policy_preprocessor(),
|
||||
)
|
||||
agent = Agent(
|
||||
policy=policy,
|
||||
action_preprocessor=continuous_action_preprocessor,
|
||||
replay_buffer=replay_buffer,
|
||||
replay_buffer_add_fn=replay_buffer_add_fn,
|
||||
replay_buffer_train_fn=replay_buffer_train_fn(
|
||||
trainer=trainer,
|
||||
trainer_preprocessor=sac_trainer_preprocessor(),
|
||||
training_freq=config.run_details.train_every_ts,
|
||||
batch_size=trainer.minibatch_size,
|
||||
replay_burnin=config.run_details.train_after_ts,
|
||||
),
|
||||
)
|
||||
|
||||
reward_history = run(
|
||||
env=env,
|
||||
agent=agent,
|
||||
num_episodes=num_episodes,
|
||||
max_steps=config.run_details.max_steps,
|
||||
)
|
||||
return reward_history
|
||||
config = ConfigClass(**config_dict)
|
||||
return run_test(**config.asdict())
|
||||
|
||||
|
||||
class TestGym(unittest.TestCase):
|
||||
def setUp(self):
|
||||
logging.getLogger().setLevel(logging.INFO)
|
||||
SummaryWriterContext._reset_globals()
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
np.random.seed(SEED)
|
||||
torch.manual_seed(SEED)
|
||||
random.seed(SEED)
|
||||
|
||||
def test_discrete_dqn_cartpole(self):
|
||||
config = extract_config(DISCRETE_DQN_CARTPOLE_JSON)
|
||||
self.assertTrue(config.model_type == "pytorch_discrete_dqn")
|
||||
reward_history = run_discrete_dqn_cartpole(config)
|
||||
self.assertTrue(
|
||||
reward_history[-1] >= CARTPOLE_SCORE_BAR,
|
||||
"reward after %d episodes is %f < %f...\nFull reward history: %s"
|
||||
% (
|
||||
len(reward_history),
|
||||
reward_history[-1],
|
||||
CARTPOLE_SCORE_BAR,
|
||||
reward_history,
|
||||
),
|
||||
)
|
||||
reward_history = run_from_config(DISCRETE_DQN_CARTPOLE_CONFIG)
|
||||
logger.info(f"Discrete DQN passes, with reward_history={reward_history}.")
|
||||
|
||||
def test_open_gridworld(self):
|
||||
config = extract_config(OPEN_GRIDWORLD_JSON)
|
||||
self.assertTrue(config.model_type == "pytorch_discrete_dqn")
|
||||
reward_history = run_open_gridworld(config)
|
||||
self.assertTrue(
|
||||
reward_history[-1] >= OPEN_GRIDWORLD_SCORE_BAR,
|
||||
"reward after %d episodes is %f < %f...\nFull reward history: %s"
|
||||
% (
|
||||
len(reward_history),
|
||||
reward_history[-1],
|
||||
OPEN_GRIDWORLD_SCORE_BAR,
|
||||
reward_history,
|
||||
),
|
||||
)
|
||||
reward_history = run_from_config(OPEN_GRIDWORLD_CONFIG)
|
||||
logger.info(f"Open GridWorld passes, with reward_history={reward_history}.")
|
||||
|
||||
@unittest.skip("Skipping since training takes more than 10 min.")
|
||||
@unittest.skip("To be implemented...")
|
||||
def test_parametric_dqn_cartpole(self):
|
||||
config = extract_config(PARAMETRIC_DQN_CARTPOLE_JSON)
|
||||
self.assertTrue(config.model_type == "pytorch_parametric_dqn")
|
||||
reward_history = run_parametric_dqn_cartpole(config)
|
||||
self.assertTrue(
|
||||
reward_history[-1] >= CARTPOLE_SCORE_BAR,
|
||||
"reward after %d episodes is %f < %f\nFull reward history: %s"
|
||||
% (
|
||||
len(reward_history),
|
||||
reward_history[-1],
|
||||
CARTPOLE_SCORE_BAR,
|
||||
reward_history,
|
||||
),
|
||||
)
|
||||
raise NotImplementedError("TODO: make model manager for PDQN")
|
||||
|
||||
@unittest.skip("Skipping since training takes more than 10 min.")
|
||||
@unittest.skip("To be implemented...")
|
||||
def test_sac_pendulum(self):
|
||||
config = extract_config(SAC_PENDULUM_JSON)
|
||||
self.assertTrue(config.model_type == "soft_actor_critic")
|
||||
reward_history = run_sac_pendulum(config)
|
||||
self.assertTrue(
|
||||
reward_history[-1] >= PENDULUM_SCORE_BAR,
|
||||
"reward after %d episodes is %f < %f\nFull reward history: %s"(
|
||||
len(reward_history),
|
||||
reward_history[-1],
|
||||
PENDULUM_SCORE_BAR,
|
||||
reward_history,
|
||||
),
|
||||
)
|
||||
raise NotImplementedError("TODO: make model manager for SAC")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
+4
-19
@@ -8,7 +8,6 @@ from typing import Any, Callable, Optional
|
||||
import numpy as np
|
||||
import reagent.types as rlt
|
||||
import torch
|
||||
from reagent.replay_memory.circular_replay_buffer import ReplayBuffer
|
||||
|
||||
|
||||
class Sampler(ABC):
|
||||
@@ -43,24 +42,10 @@ PolicyPreprocessor = Callable[[Any], Any]
|
||||
ActionPreprocessor = Callable[[rlt.ActorOutput], np.array]
|
||||
|
||||
|
||||
ObservationType = Any
|
||||
RewardType = float
|
||||
TerminalType = bool
|
||||
PossibleActionsMaskType = Optional[torch.Tensor]
|
||||
ReplayBufferAddFn = Callable[
|
||||
[
|
||||
ReplayBuffer,
|
||||
ObservationType,
|
||||
rlt.ActorOutput,
|
||||
RewardType,
|
||||
TerminalType,
|
||||
PossibleActionsMaskType,
|
||||
],
|
||||
None,
|
||||
]
|
||||
|
||||
# Called in post_step of Agent to train on sampled batch from RB
|
||||
ReplayBufferTrainFn = Callable[[ReplayBuffer], None]
|
||||
""" Called after env.step(action)
|
||||
Args: (state, actor_output, reward, terminal, possible_actions_mask)
|
||||
"""
|
||||
PostStep = Callable[[Any, rlt.ActorOutput, float, bool, Optional[torch.Tensor]], None]
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -34,7 +34,9 @@ import os
|
||||
import pickle
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
from gym.spaces import Box, Discrete
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -196,6 +198,53 @@ class ReplayBuffer(object):
|
||||
def size(self) -> int:
|
||||
return self._num_valid_indices
|
||||
|
||||
@classmethod
|
||||
def create_from_env(
|
||||
cls,
|
||||
env: gym.Env,
|
||||
replay_memory_size: int,
|
||||
batch_size: int,
|
||||
stack_size: int = 1,
|
||||
store_possible_actions_mask: bool = True,
|
||||
store_log_prob: bool = True,
|
||||
):
|
||||
assert isinstance(
|
||||
env.observation_space, Box
|
||||
), f"observation space has type {type(env.observation_space)}"
|
||||
if isinstance(env.action_space, Box):
|
||||
actions_type = env.action_space.dtype # type: ignore
|
||||
actions_shape = env.action_space.shape # type: ignore
|
||||
elif isinstance(env.action_space, Discrete):
|
||||
# TODO: don't store one-hot encoded actions in RB.
|
||||
actions_type = env.action_space.dtype # type: ignore
|
||||
actions_shape = (env.action_space.n,) # type: ignore
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"env.action_space {type(env.action_space)} not supported."
|
||||
)
|
||||
|
||||
extra_storage_types = []
|
||||
if store_possible_actions_mask:
|
||||
extra_storage_types.append(
|
||||
ReplayElement("possible_actions_mask", actions_shape, np.int64)
|
||||
)
|
||||
|
||||
if store_log_prob:
|
||||
extra_storage_types.append(ReplayElement("log_prob", (), np.float32))
|
||||
|
||||
return cls(
|
||||
stack_size=stack_size,
|
||||
replay_capacity=replay_memory_size,
|
||||
batch_size=batch_size,
|
||||
observation_shape=env.observation_space.shape, # type: ignore
|
||||
observation_dtype=env.observation_space.dtype, # type: ignore
|
||||
action_shape=actions_shape,
|
||||
action_dtype=actions_type,
|
||||
reward_shape=(),
|
||||
reward_dtype=np.float32,
|
||||
extra_storage_types=extra_storage_types,
|
||||
)
|
||||
|
||||
def set_index_valid_status(self, idx: int, is_valid: bool):
|
||||
old_valid = self._is_index_valid[idx]
|
||||
if not old_valid and is_valid:
|
||||
|
||||
@@ -280,9 +280,7 @@ class TestQueryData(ReagentSQLTestBase):
|
||||
)
|
||||
query_data(
|
||||
input_table_spec=ts,
|
||||
states=[0, 1, 4, 5, 6],
|
||||
actions=["L", "R", "U", "D"],
|
||||
metrics=["reward"],
|
||||
custom_reward_expression=custom_reward_expression,
|
||||
multi_steps=multi_steps,
|
||||
gamma=gamma,
|
||||
|
||||
+238
-154
@@ -2,7 +2,7 @@
|
||||
import logging
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from pyspark.sql.functions import col, crc32, udf
|
||||
from pyspark.sql.functions import col, crc32, explode, map_keys, udf
|
||||
from pyspark.sql.types import (
|
||||
ArrayType,
|
||||
BooleanType,
|
||||
@@ -22,7 +22,9 @@ logger = logging.getLogger(__name__)
|
||||
MAX_UINT32 = 4294967295
|
||||
|
||||
|
||||
def calc_custom_reward(sqlCtx, df, custom_reward_expression: str):
|
||||
def calc_custom_reward(df, custom_reward_expression: str):
|
||||
sqlCtx = get_spark_session()
|
||||
# create a temporary table for running sql
|
||||
temp_table_name = "_tmp_calc_reward_df"
|
||||
temp_reward_name = "_tmp_reward_col"
|
||||
df.createOrReplaceTempView(temp_table_name)
|
||||
@@ -33,28 +35,36 @@ def calc_custom_reward(sqlCtx, df, custom_reward_expression: str):
|
||||
return df.drop("reward").withColumnRenamed(temp_reward_name, "reward")
|
||||
|
||||
|
||||
def calc_reward_multi_steps(sqlCtx, df, multi_steps: int, gamma: float):
|
||||
def calc_reward_multi_steps(df, multi_steps: int, gamma: float):
|
||||
# assumes df[reward] is array[float] and 1 <= len(df[reward]) <= multi_steps
|
||||
# computes r_0 + gamma * (r_1 + gamma * (r_2 + ... ))
|
||||
expr = f"AGGREGATE(REVERSE(reward), FLOAT(0), (s, x) -> FLOAT({gamma}) * s + x)"
|
||||
return calc_custom_reward(sqlCtx, df, expr)
|
||||
return calc_custom_reward(df, expr)
|
||||
|
||||
|
||||
def perform_preprocessing(
|
||||
sqlCtx,
|
||||
table_spec: TableSpec,
|
||||
states: List[int],
|
||||
actions: List[str],
|
||||
metrics: List[str],
|
||||
def set_reward_col_as_reward(
|
||||
df,
|
||||
custom_reward_expression: Optional[str] = None,
|
||||
sample_range: Optional[Tuple[float, float]] = None,
|
||||
multi_steps: Optional[int] = None,
|
||||
gamma: Optional[float] = None,
|
||||
):
|
||||
""" Perform preprocessing of given dataframe df.
|
||||
Preprocessing steps include calculating the reward,
|
||||
performing sparse-to-dense for mapped columns like state_features
|
||||
and metrics, and subsampling based on sample_range.
|
||||
If multi_steps is set (with gamma), then we assume multi_steps RL setting.
|
||||
# after this, reward column should be set to be the reward
|
||||
if custom_reward_expression is not None:
|
||||
df = calc_custom_reward(df, custom_reward_expression)
|
||||
elif multi_steps is not None:
|
||||
assert gamma is not None
|
||||
df = calc_reward_multi_steps(df, multi_steps, gamma)
|
||||
return df
|
||||
|
||||
|
||||
def hash_mdp_id_and_subsample(df, sample_range: Optional[Tuple[float, float]] = None):
|
||||
""" Since mdp_id is a string but Pytorch Tensors do not store strings,
|
||||
we hash them with crc32, which is treated as a cryptographic hash
|
||||
(with range [0, MAX_UINT32-1]). We also perform an optional subsampling
|
||||
based on this hash value.
|
||||
NOTE: we're assuming no collisions in this hash! Otherwise, two mdp_ids
|
||||
can be indistinguishable after the hash.
|
||||
TODO: change this to a deterministic subsample.
|
||||
"""
|
||||
if sample_range:
|
||||
assert (
|
||||
@@ -63,138 +73,182 @@ def perform_preprocessing(
|
||||
and sample_range[1] <= 100.0
|
||||
), f"{sample_range} is invalid."
|
||||
|
||||
df = sqlCtx.sql(f"SELECT * FROM {table_spec.table_name}")
|
||||
|
||||
# after this, reward column should be set to be the reward now
|
||||
if custom_reward_expression is not None:
|
||||
df = calc_custom_reward(sqlCtx, df, custom_reward_expression)
|
||||
elif multi_steps is not None:
|
||||
assert gamma is not None
|
||||
df = calc_reward_multi_steps(sqlCtx, df, multi_steps, gamma)
|
||||
# assume single step case reward is already a column
|
||||
|
||||
def get_step(next_col):
|
||||
""" get step count """
|
||||
if multi_steps is not None:
|
||||
return min(len(next_col), multi_steps)
|
||||
else:
|
||||
return 1
|
||||
|
||||
get_step_udf = udf(get_step, LongType())
|
||||
df = df.withColumn("step", get_step_udf("next_state_features"))
|
||||
|
||||
def make_next_udf(return_type):
|
||||
""" return udf to get next item, provided item type """
|
||||
|
||||
def get_next(next_col):
|
||||
""" generic function to get the next item """
|
||||
if multi_steps is not None:
|
||||
step = min(len(next_col), multi_steps)
|
||||
return next_col[step - 1]
|
||||
else:
|
||||
return next_col
|
||||
|
||||
return udf(get_next, return_type)
|
||||
|
||||
df = df.withColumn("time_diff", make_next_udf(LongType())("time_diff"))
|
||||
|
||||
def make_sparse2dense(df, col_name: str, possible_keys: List):
|
||||
""" Given a list of possible keys, convert sparse map to dense array.
|
||||
In our example, both value_type is assumed to be a float.
|
||||
"""
|
||||
output_type = StructType(
|
||||
[
|
||||
StructField("presence", ArrayType(BooleanType()), False),
|
||||
StructField("dense", ArrayType(FloatType()), False),
|
||||
]
|
||||
)
|
||||
|
||||
def sparse2dense(map_col):
|
||||
assert isinstance(
|
||||
map_col, dict
|
||||
), f"{map_col} has type {type(map_col)} and is not a dict."
|
||||
presence = []
|
||||
dense = []
|
||||
for key in possible_keys:
|
||||
val = map_col.get(key, None)
|
||||
if val is not None:
|
||||
presence.append(True)
|
||||
dense.append(float(val))
|
||||
else:
|
||||
presence.append(False)
|
||||
dense.append(0.0)
|
||||
return presence, dense
|
||||
|
||||
sparse2dense_udf = udf(sparse2dense, output_type)
|
||||
df = df.withColumn(col_name, sparse2dense_udf(col_name))
|
||||
df = df.withColumn(f"{col_name}_presence", col(f"{col_name}.presence"))
|
||||
df = df.withColumn(col_name, col(f"{col_name}.dense"))
|
||||
return df
|
||||
|
||||
df = make_sparse2dense(df, "state_features", states)
|
||||
|
||||
next_map_udf = make_next_udf(MapType(LongType(), FloatType()))
|
||||
df = df.withColumn("next_state_features", next_map_udf("next_state_features"))
|
||||
df = make_sparse2dense(df, "next_state_features", states)
|
||||
|
||||
df = df.withColumn("metrics", next_map_udf("metrics"))
|
||||
df = make_sparse2dense(df, "metrics", metrics)
|
||||
|
||||
def where(arr: List[str]):
|
||||
""" locate the index of item in arr, len(arr) if not found. """
|
||||
|
||||
def find(item: str):
|
||||
for i, arr_item in enumerate(arr):
|
||||
if arr_item == item:
|
||||
return i
|
||||
return len(arr)
|
||||
|
||||
return find
|
||||
|
||||
where_udf = udf(where(actions), LongType())
|
||||
df = df.withColumn("action", where_udf("action"))
|
||||
df = df.withColumn(
|
||||
"next_action", where_udf(make_next_udf(LongType())("next_action"))
|
||||
)
|
||||
|
||||
def get_not_terminal(next_action):
|
||||
""" terminal state iff next_action is "" (i.e. onehot len(actions))"""
|
||||
return next_action < len(actions)
|
||||
|
||||
get_not_terminal_udf = udf(get_not_terminal, BooleanType())
|
||||
df = df.withColumn("not_terminal", get_not_terminal_udf("next_action"))
|
||||
|
||||
def onehot(arr: List[str]):
|
||||
""" one-hot encode elements of arr depending on their existence in target """
|
||||
|
||||
def encode(target: List[str]):
|
||||
result = [0] * len(arr)
|
||||
for i, arr_item in enumerate(arr):
|
||||
if arr_item in target:
|
||||
result[i] = 1
|
||||
return result
|
||||
|
||||
return encode
|
||||
|
||||
onehot_udf = udf(onehot(actions), ArrayType(LongType()))
|
||||
df = df.withColumn("possible_actions_mask", onehot_udf("possible_actions"))
|
||||
df = df.withColumn(
|
||||
"possible_next_actions_mask",
|
||||
onehot_udf(make_next_udf(ArrayType(LongType()))("possible_next_actions")),
|
||||
)
|
||||
|
||||
# assuming use_seq_num_diff_as_time_diff = False for now
|
||||
df = df.withColumn("sequence_number", col("sequence_number_ordinal"))
|
||||
|
||||
# crc32 is treated as a cryptographic hash with range [0, MAX_UINT32-1]
|
||||
# Note: we're assuming no collisions!
|
||||
df = df.withColumn("mdp_id", crc32(col("mdp_id")))
|
||||
if sample_range:
|
||||
lower_bound = sample_range[0] / 100.0 * MAX_UINT32
|
||||
upper_bound = sample_range[1] / 100.0 * MAX_UINT32
|
||||
df = df.filter((lower_bound <= col("mdp_id")) & (col("mdp_id") <= upper_bound))
|
||||
return df
|
||||
|
||||
# select all the relevant columns and perform type conversions
|
||||
|
||||
def make_sparse2dense(df, col_name: str, possible_keys: List):
|
||||
""" Given a list of possible keys, convert sparse map to dense array.
|
||||
In our example, both value_type is assumed to be a float.
|
||||
"""
|
||||
output_type = StructType(
|
||||
[
|
||||
StructField("presence", ArrayType(BooleanType()), False),
|
||||
StructField("dense", ArrayType(FloatType()), False),
|
||||
]
|
||||
)
|
||||
|
||||
def sparse2dense(map_col):
|
||||
assert isinstance(
|
||||
map_col, dict
|
||||
), f"{map_col} has type {type(map_col)} and is not a dict."
|
||||
presence = []
|
||||
dense = []
|
||||
for key in possible_keys:
|
||||
val = map_col.get(key, None)
|
||||
if val is not None:
|
||||
presence.append(True)
|
||||
dense.append(float(val))
|
||||
else:
|
||||
presence.append(False)
|
||||
dense.append(0.0)
|
||||
return presence, dense
|
||||
|
||||
sparse2dense_udf = udf(sparse2dense, output_type)
|
||||
df = df.withColumn(col_name, sparse2dense_udf(col_name))
|
||||
df = df.withColumn(f"{col_name}_presence", col(f"{col_name}.presence"))
|
||||
df = df.withColumn(col_name, col(f"{col_name}.dense"))
|
||||
return df
|
||||
|
||||
|
||||
#################################################
|
||||
# Below are some UDFs we use for preprocessing. #
|
||||
#################################################
|
||||
|
||||
|
||||
def make_get_step_udf(multi_steps: Optional[int]):
|
||||
""" Get step count by taking length of next_states_features array. """
|
||||
|
||||
def get_step(col: List):
|
||||
return 1 if multi_steps is None else min(len(col), multi_steps)
|
||||
|
||||
return udf(get_step, LongType())
|
||||
|
||||
|
||||
def make_next_udf(multi_steps: Optional[int], return_type):
|
||||
""" Generic udf to get next (after multi_steps) item, provided item type. """
|
||||
|
||||
def get_next(next_col):
|
||||
return (
|
||||
next_col
|
||||
if multi_steps is None
|
||||
else next_col[min(len(next_col), multi_steps) - 1]
|
||||
)
|
||||
|
||||
return udf(get_next, return_type)
|
||||
|
||||
|
||||
def make_where_udf(arr: List[str]):
|
||||
""" Return index of item in arr, and len(arr) if not found. """
|
||||
|
||||
def find(item: str):
|
||||
for i, arr_item in enumerate(arr):
|
||||
if arr_item == item:
|
||||
return i
|
||||
return len(arr)
|
||||
|
||||
return udf(find, LongType())
|
||||
|
||||
|
||||
def make_not_terminal_udf(actions: List[str]):
|
||||
""" Return true iff next_action is terminal (i.e. idx = len(actions)). """
|
||||
|
||||
def get_not_terminal(next_action):
|
||||
return next_action < len(actions)
|
||||
|
||||
return udf(get_not_terminal, BooleanType())
|
||||
|
||||
|
||||
def make_existence_bitvector_udf(arr: List[str]):
|
||||
""" one-hot encode elements of target depending on their existence in arr. """
|
||||
|
||||
default = [0] * len(arr)
|
||||
|
||||
def encode(target: List[str]):
|
||||
bitvec = default.copy()
|
||||
for i, arr_item in enumerate(arr):
|
||||
if arr_item in target:
|
||||
bitvec[i] = 1
|
||||
return bitvec
|
||||
|
||||
return udf(encode, ArrayType(LongType()))
|
||||
|
||||
|
||||
def perform_preprocessing(
|
||||
df,
|
||||
states: List[int],
|
||||
actions: List[str],
|
||||
metrics: List[str],
|
||||
multi_steps: Optional[int] = None,
|
||||
):
|
||||
""" Perform (1) sparse-to-dense, (2) preprocessing for actions,
|
||||
and (3) other miscellaneous columns.
|
||||
|
||||
(1) For each column of type Map, w/ name X, output two columns.
|
||||
Map values are assumed to be scalar. This process is called sparse-to-dense.
|
||||
X = {"state_features", "next_state_features", "metrics"}.
|
||||
(a) Replace column X with a dense repesentation of the inputted (sparse) map.
|
||||
Dense representation is to concatenate map values into a list.
|
||||
(b) Create new column X_presence, which is a list of same length as (a) and
|
||||
the ith entry is 1 iff the key was present in the original map.
|
||||
|
||||
(2) Inputted actions and possible_actions are strings, which isn't supported
|
||||
for PyTorch Tensors. Here, we represent them with LongType.
|
||||
(a) action and next_action are strings, so simply return their position
|
||||
in the action_space (as given by argument actions).
|
||||
(b) possible_actions and possible_next_actions are list of strs, so
|
||||
return an existence bitvector of length len(actions), where ith
|
||||
index is true iff actions[i] was in the list.
|
||||
|
||||
(3) Miscellaneous columns are step, time_diff, sequence_number, not_terminal
|
||||
"""
|
||||
|
||||
# step refers to n in n-step RL; special case when approaching terminal
|
||||
df = df.withColumn("step", make_get_step_udf(multi_steps)("next_state_features"))
|
||||
|
||||
# take the next time_diff
|
||||
next_long_udf = make_next_udf(multi_steps, LongType())
|
||||
df = df.withColumn("time_diff", next_long_udf("time_diff"))
|
||||
|
||||
# sparse-to-dense of states and metrics
|
||||
next_map_udf = make_next_udf(multi_steps, MapType(LongType(), FloatType()))
|
||||
df = df.withColumn("next_state_features", next_map_udf("next_state_features"))
|
||||
df = df.withColumn("metrics", next_map_udf("metrics"))
|
||||
df = make_sparse2dense(df, "state_features", states)
|
||||
df = make_sparse2dense(df, "next_state_features", states)
|
||||
df = make_sparse2dense(df, "metrics", metrics)
|
||||
|
||||
# turn string actions into indices
|
||||
where_udf = make_where_udf(actions)
|
||||
df = df.withColumn("action", where_udf("action"))
|
||||
df = df.withColumn("next_action", where_udf(next_long_udf("next_action")))
|
||||
|
||||
# turn List[str] possible_actions into existence bitvectors
|
||||
next_long_arr_udf = make_next_udf(multi_steps, ArrayType(LongType()))
|
||||
existence_bitvector_udf = make_existence_bitvector_udf(actions)
|
||||
df = df.withColumn(
|
||||
"possible_actions_mask", existence_bitvector_udf("possible_actions")
|
||||
)
|
||||
df = df.withColumn(
|
||||
"possible_next_actions_mask",
|
||||
existence_bitvector_udf(next_long_arr_udf("possible_next_actions")),
|
||||
)
|
||||
|
||||
# calculate not_terminal
|
||||
not_terminal_udf = make_not_terminal_udf(actions)
|
||||
df = df.withColumn("not_terminal", not_terminal_udf("next_action"))
|
||||
|
||||
# assuming use_seq_num_diff_as_time_diff = False for now
|
||||
df = df.withColumn("sequence_number", col("sequence_number_ordinal"))
|
||||
return df
|
||||
|
||||
|
||||
def select_relevant_columns(df):
|
||||
""" Select all the relevant columns and perform type conversions. """
|
||||
return df.select(
|
||||
col("reward").cast(FloatType()),
|
||||
col("state_features").cast(ArrayType(FloatType())),
|
||||
@@ -216,30 +270,60 @@ def perform_preprocessing(
|
||||
)
|
||||
|
||||
|
||||
def get_distinct_keys(df, col_name, is_col_arr_map=False):
|
||||
""" Return list of distinct keys.
|
||||
Set is_col_arr_map to be true if column is an array of Maps.
|
||||
Otherwise, assume column is a Map.
|
||||
"""
|
||||
if is_col_arr_map:
|
||||
df = df.select(explode(col_name).alias(col_name))
|
||||
df = df.select(explode(map_keys(col_name)))
|
||||
return df.distinct().rdd.flatMap(lambda x: x).collect()
|
||||
|
||||
|
||||
def infer_states_names(df, multi_steps: Optional[int]):
|
||||
""" Infer possible state names from states and next state features. """
|
||||
state_keys = get_distinct_keys(df, "state_features")
|
||||
next_states_is_col_arr_map = not (multi_steps is None)
|
||||
next_state_keys = get_distinct_keys(
|
||||
df, "next_state_features", is_col_arr_map=next_states_is_col_arr_map
|
||||
)
|
||||
return sorted(set(state_keys) | set(next_state_keys))
|
||||
|
||||
|
||||
def infer_metrics_names(df, multi_steps: Optional[int]):
|
||||
""" Infer possible metrics names.
|
||||
Assume in multi-step case, metrics is an array of maps.
|
||||
"""
|
||||
is_col_arr_map = not (multi_steps is None)
|
||||
return sorted(get_distinct_keys(df, "metrics", is_col_arr_map=is_col_arr_map))
|
||||
|
||||
|
||||
def query_data(
|
||||
input_table_spec: TableSpec,
|
||||
states: List[int],
|
||||
actions: List[str],
|
||||
metrics: List[str],
|
||||
custom_reward_expression: Optional[str] = None,
|
||||
sample_range: Optional[Tuple[float, float]] = None,
|
||||
multi_steps: Optional[int] = None,
|
||||
gamma: Optional[float] = None,
|
||||
) -> Dataset:
|
||||
""" Perform reward calculation, hashing mdp + subsampling and
|
||||
other preprocessing such as sparse2dense.
|
||||
"""
|
||||
sqlCtx = get_spark_session()
|
||||
# performs rewards preprocessing, sparse2dense
|
||||
preprocessed_df = perform_preprocessing(
|
||||
sqlCtx,
|
||||
table_spec=input_table_spec,
|
||||
states=states,
|
||||
actions=actions,
|
||||
metrics=metrics,
|
||||
df = sqlCtx.sql(f"SELECT * FROM {input_table_spec.table_name}")
|
||||
states = infer_states_names(df, multi_steps)
|
||||
metrics = infer_metrics_names(df, multi_steps)
|
||||
df = set_reward_col_as_reward(
|
||||
df,
|
||||
custom_reward_expression=custom_reward_expression,
|
||||
sample_range=sample_range,
|
||||
multi_steps=multi_steps,
|
||||
gamma=gamma,
|
||||
)
|
||||
preprocessed_df.write.mode("overwrite").parquet(
|
||||
input_table_spec.output_dataset.parquet_url
|
||||
df = hash_mdp_id_and_subsample(df, sample_range=sample_range)
|
||||
df = perform_preprocessing(
|
||||
df, states=states, actions=actions, metrics=metrics, multi_steps=multi_steps
|
||||
)
|
||||
df = select_relevant_columns(df)
|
||||
df.write.mode("overwrite").parquet(input_table_spec.output_dataset.parquet_url)
|
||||
return input_table_spec.output_dataset
|
||||
|
||||
@@ -8,7 +8,11 @@ from reagent.core.dataclasses import dataclass, field
|
||||
from reagent.evaluation.evaluator import Evaluator, get_metrics_to_score
|
||||
from reagent.models.base import ModelBase
|
||||
from reagent.parameters import NormalizationData
|
||||
from reagent.preprocessing.batch_preprocessor import BatchPreprocessor
|
||||
from reagent.preprocessing.batch_preprocessor import (
|
||||
BatchPreprocessor,
|
||||
DiscreteDqnBatchPreprocessor,
|
||||
)
|
||||
from reagent.preprocessing.preprocessor import Preprocessor
|
||||
from reagent.workflow.data_fetcher import query_data
|
||||
from reagent.workflow.identify_types_flow import identify_normalization_parameters
|
||||
from reagent.workflow.model_managers.model_manager import ModelManager
|
||||
@@ -31,6 +35,18 @@ from reagent.workflow_utils.page_handler import (
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
from reagent.gym.policies.policy import Policy
|
||||
from reagent.gym.policies.samplers.discrete_sampler import SoftmaxActionSampler
|
||||
from reagent.gym.policies.scorers.discrete_scorer import discrete_dqn_scorer
|
||||
from reagent.gym.preprocessors import (
|
||||
argmax_action_preprocessor,
|
||||
discrete_dqn_trainer_preprocessor,
|
||||
numpy_policy_preprocessor,
|
||||
)
|
||||
except ImportError:
|
||||
logger.info(f"Using {__file__} without reagent.gym.")
|
||||
|
||||
|
||||
class DiscreteNormalizationParameterKeys:
|
||||
STATE = "state"
|
||||
@@ -54,6 +70,26 @@ class DiscreteDQNBase(ModelManager):
|
||||
def normalization_key(cls) -> str:
|
||||
return DiscreteNormalizationParameterKeys.STATE
|
||||
|
||||
def create_policy(self, env) -> Policy:
|
||||
""" Create an online DiscreteDQN Policy from env.
|
||||
Args:
|
||||
env: gym.Env is the environment to run on.
|
||||
"""
|
||||
sampler = SoftmaxActionSampler(temperature=self.rl_parameters.temperature)
|
||||
scorer = discrete_dqn_scorer(self.trainer.q_network)
|
||||
policy_preprocessor = numpy_policy_preprocessor()
|
||||
return Policy(
|
||||
scorer=scorer, sampler=sampler, policy_preprocessor=policy_preprocessor
|
||||
)
|
||||
|
||||
def create_action_preprocessor(self):
|
||||
return argmax_action_preprocessor
|
||||
|
||||
def create_trainer_preprocessor(self):
|
||||
return discrete_dqn_trainer_preprocessor(
|
||||
len(self.action_names), self.state_normalization_parameters
|
||||
)
|
||||
|
||||
@property
|
||||
def metrics_to_score(self) -> List[str]:
|
||||
assert self.reward_options is not None
|
||||
@@ -108,18 +144,13 @@ class DiscreteDQNBase(ModelManager):
|
||||
reward_options: RewardOptions,
|
||||
eval_dataset: bool,
|
||||
) -> Dataset:
|
||||
# sort is set to False because EvaluationPageHandler sort the data anyway
|
||||
return query_data(
|
||||
input_table_spec,
|
||||
self.action_names,
|
||||
self.rl_parameters.use_seq_num_diff_as_time_diff,
|
||||
input_table_spec=input_table_spec,
|
||||
actions=self.action_names,
|
||||
sample_range=sample_range,
|
||||
metric_reward_values=reward_options.metric_reward_values,
|
||||
custom_reward_expression=reward_options.custom_reward_expression,
|
||||
additional_reward_expression=reward_options.additional_reward_expression,
|
||||
multi_steps=self.multi_steps,
|
||||
gamma=self.rl_parameters.gamma,
|
||||
sort=False,
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -127,7 +158,12 @@ class DiscreteDQNBase(ModelManager):
|
||||
return self.rl_parameters.multi_steps
|
||||
|
||||
def build_batch_preprocessor(self) -> BatchPreprocessor:
|
||||
raise NotImplementedError
|
||||
return DiscreteDqnBatchPreprocessor(
|
||||
state_preprocessor=Preprocessor(
|
||||
normalization_parameters=self.state_normalization_parameters,
|
||||
use_gpu=self.use_gpu,
|
||||
)
|
||||
)
|
||||
|
||||
def train(
|
||||
self, train_dataset: Dataset, eval_dataset: Optional[Dataset], num_epochs: int
|
||||
@@ -135,7 +171,7 @@ class DiscreteDQNBase(ModelManager):
|
||||
"""
|
||||
Train the model
|
||||
|
||||
Returns partially filled RLTrainningOutput. The field that should not be filled
|
||||
Returns partially filled RLTrainingOutput. The field that should not be filled
|
||||
are:
|
||||
- output_path
|
||||
- warmstart_output_path
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import abc
|
||||
from typing import Dict, Optional, Tuple, Type
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
import reagent.types as rlt
|
||||
import torch # @manual
|
||||
from reagent.core.registry_meta import RegistryMeta
|
||||
from reagent.parameters import NormalizationData, NormalizationParameters
|
||||
@@ -148,7 +147,7 @@ class ModelManager(metaclass=RegistryMeta):
|
||||
reward_options: RewardOptions,
|
||||
normalization_data_map: Dict[str, NormalizationData],
|
||||
warmstart_path: Optional[str] = None,
|
||||
) -> None:
|
||||
) -> RLTrainer:
|
||||
"""
|
||||
Initialize the trainer. Subclass should not override this. Instead,
|
||||
subclass should implement `_set_normalization_parameters()` and
|
||||
@@ -162,6 +161,7 @@ class ModelManager(metaclass=RegistryMeta):
|
||||
if warmstart_path is not None:
|
||||
trainer_state = torch.load(warmstart_path)
|
||||
self._trainer.load_state_dict(trainer_state)
|
||||
return self._trainer
|
||||
|
||||
@abc.abstractmethod
|
||||
def build_trainer(self) -> RLTrainer:
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
click==7.0
|
||||
gym[classic_control,box2d,atari]
|
||||
gym-minigrid
|
||||
numpy==1.17.2
|
||||
pandas==0.25.0
|
||||
pydantic==1.4
|
||||
|
||||
@@ -1,7 +1,3 @@
|
||||
[tool:pytest]
|
||||
addopts = --verbose -d
|
||||
python_files = reagent/test/*.py reagent/test/**/*.py
|
||||
|
||||
[metadata]
|
||||
name = reagent
|
||||
version = 0.1
|
||||
@@ -28,4 +24,32 @@ install_requires =
|
||||
xgboost==0.90
|
||||
|
||||
[options.extras_require]
|
||||
gym = gym[classic_control,box2d,atari]
|
||||
gym =
|
||||
gym[classic_control,box2d,atari]
|
||||
gym_minigrid
|
||||
|
||||
|
||||
###########
|
||||
# Linting #
|
||||
###########
|
||||
|
||||
[flake8]
|
||||
# E203: black and flake8 disagree on whitespace before ':'
|
||||
# W503: black and flake8 disagree on how to place operators
|
||||
ignore = E203, W503
|
||||
max-line-length = 88
|
||||
exclude =
|
||||
.git,__pycache__,docs
|
||||
|
||||
[coverage:report]
|
||||
omit =
|
||||
serving/*
|
||||
|
||||
|
||||
###########
|
||||
# Testing #
|
||||
###########
|
||||
|
||||
[tool:pytest]
|
||||
addopts = --verbose -d
|
||||
python_files = reagent/test/*.py reagent/test/**/*.py reagent/gym/tests/**/*.py
|
||||
|
||||
@@ -15,5 +15,5 @@ deps =
|
||||
pytest==5.3
|
||||
spark-testing-base==0.10.0
|
||||
commands =
|
||||
pytest --junitxml={envlogdir}/junit-{envname}.xml -n auto --tx 2*popen -m "not serial"
|
||||
pytest --junitxml={envlogdir}/junit-{envname}-serial.xml -n0 --tx 1*popen -m "serial"
|
||||
pytest --junitxml={envlogdir}/junit-{envname}.xml -n4 --tx popen -m "not serial"
|
||||
pytest --junitxml={envlogdir}/junit-{envname}-serial.xml -n0 --tx popen -m "serial"
|
||||
|
||||
Reference in New Issue
Block a user