mirror of
https://github.com/facebookresearch/ReAgent.git
synced 2026-05-17 12:40:39 +00:00
Tune SAC and CRR Models. Initial support for batch gym training (#470)
Summary: Pull Request resolved: https://github.com/facebookresearch/ReAgent/pull/470 Reviewed By: czxttkl Differential Revision: D28093192 fbshipit-source-id: 6b260c3e8d49c8b302e40066e2be49a0bfe96688
This commit is contained in:
@@ -100,7 +100,7 @@ commands:
|
||||
name: Run script
|
||||
command: |
|
||||
# gather data and store as pickle
|
||||
coverage run ./reagent/workflow/cli.py run reagent.workflow.gym_batch_rl.offline_gym "$CONFIG"
|
||||
coverage run ./reagent/workflow/cli.py run reagent.workflow.gym_batch_rl.offline_gym_random "$CONFIG"
|
||||
# run through timeline operator
|
||||
coverage run --append ./reagent/workflow/cli.py run reagent.workflow.gym_batch_rl.timeline_operator "$CONFIG"
|
||||
# train on logged data
|
||||
|
||||
+2
-2
@@ -37,7 +37,7 @@ To train a batch RL model, run the following commands:
|
||||
# set the config
|
||||
export CONFIG=reagent/workflow/sample_configs/discrete_dqn_cartpole_offline.yaml
|
||||
# gather some random transitions (can replace with your own)
|
||||
./reagent/workflow/cli.py run reagent.workflow.gym_batch_rl.offline_gym $CONFIG
|
||||
./reagent/workflow/cli.py run reagent.workflow.gym_batch_rl.offline_gym_random $CONFIG
|
||||
# convert data to timeline format
|
||||
./reagent/workflow/cli.py run reagent.workflow.gym_batch_rl.timeline_operator $CONFIG
|
||||
# train model based on timeline data
|
||||
@@ -92,7 +92,7 @@ In particular, the following Click command runs 150 episodes of ``CartPole-v0``
|
||||
|
||||
.. code-block::
|
||||
|
||||
./reagent/workflow/cli.py run reagent.workflow.gym_batch_rl.offline_gym $CONFIG
|
||||
./reagent/workflow/cli.py run reagent.workflow.gym_batch_rl.offline_gym_random $CONFIG
|
||||
|
||||
The command essentially performs the following pseudo-code:
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ from reagent.gym.policies.scorers.discrete_scorer import (
|
||||
parametric_dqn_serving_scorer,
|
||||
)
|
||||
from reagent.gym.policies.scorers.slate_q_scorer import slate_q_serving_scorer
|
||||
from reagent.models.actor import LOG_PROB_MIN, LOG_PROB_MAX
|
||||
|
||||
|
||||
if IS_FB_ENVIRONMENT:
|
||||
@@ -116,6 +117,10 @@ class ActorPredictorPolicy(Policy):
|
||||
def act(
|
||||
self, obs: Any, possible_actions_mask: Optional[np.ndarray] = None
|
||||
) -> rlt.ActorOutput:
|
||||
action = self.predictor(obs).cpu()
|
||||
# TODO: return log_probs as well
|
||||
return rlt.ActorOutput(action=action)
|
||||
output = self.predictor(obs)
|
||||
if isinstance(output, tuple):
|
||||
action, log_prob = output
|
||||
log_prob = log_prob.clamp(LOG_PROB_MIN, LOG_PROB_MAX)
|
||||
return rlt.ActorOutput(action=action.cpu(), log_prob=log_prob.cpu())
|
||||
else:
|
||||
return rlt.ActorOutput(action=output.cpu())
|
||||
|
||||
@@ -13,7 +13,7 @@ from reagent.gym.policies.policy import Policy
|
||||
from reagent.gym.policies.scorers.discrete_scorer import apply_possible_actions_mask
|
||||
|
||||
|
||||
def make_random_policy_for_env(env: gym.Env):
|
||||
def make_random_policy_for_env(env: gym.Env) -> Policy:
|
||||
if isinstance(env.action_space, gym.spaces.Discrete):
|
||||
# discrete action space
|
||||
return DiscreteRandomPolicy.create_for_env(env)
|
||||
|
||||
@@ -13,7 +13,7 @@ model:
|
||||
minibatches_per_step: 1
|
||||
optimizer:
|
||||
AdamW:
|
||||
lr: 0.003
|
||||
lr: 0.001
|
||||
amsgrad: true
|
||||
net_builder:
|
||||
FullyConnected:
|
||||
@@ -28,7 +28,7 @@ model:
|
||||
replay_memory_size: 100000
|
||||
train_every_ts: 1
|
||||
train_after_ts: 20000
|
||||
num_train_episodes: 80
|
||||
num_train_episodes: 90
|
||||
num_eval_episodes: 20
|
||||
passing_score_bar: 100.0
|
||||
use_gpu: false
|
||||
|
||||
@@ -6,8 +6,9 @@ model:
|
||||
trainer_param:
|
||||
rl:
|
||||
gamma: 0.99
|
||||
target_update_rate: 0.01
|
||||
target_update_rate: 0.005
|
||||
softmax_policy: true
|
||||
entropy_temperature: 0.3
|
||||
crr_config:
|
||||
exponent_beta: 1.0
|
||||
exponent_clamp: 20.0
|
||||
@@ -54,4 +55,4 @@ num_eval_episodes: 20
|
||||
# Though maximal score is 0, we set lower bar to let tests finish in time
|
||||
passing_score_bar: -500
|
||||
use_gpu: false
|
||||
minibatch_size: 1024
|
||||
minibatch_size: 256
|
||||
|
||||
@@ -8,7 +8,7 @@ model:
|
||||
gamma: 0.99
|
||||
target_update_rate: 0.005
|
||||
softmax_policy: true
|
||||
entropy_temperature: 0.1
|
||||
entropy_temperature: 0.3
|
||||
q_network_optimizer:
|
||||
Adam:
|
||||
lr: 0.001
|
||||
|
||||
@@ -21,6 +21,7 @@ from reagent.gym.datasets.replay_buffer_dataset import ReplayBufferDataset
|
||||
from reagent.gym.envs import Env__Union
|
||||
from reagent.gym.envs.env_wrapper import EnvWrapper
|
||||
from reagent.gym.policies.policy import Policy
|
||||
from reagent.gym.policies.random_policies import make_random_policy_for_env
|
||||
from reagent.gym.runners.gymrunner import evaluate_for_n_episodes, run_episode
|
||||
from reagent.gym.types import PostEpisode, PostStep
|
||||
from reagent.gym.utils import build_normalizer, fill_replay_buffer
|
||||
@@ -239,8 +240,13 @@ def run_test_replay_buffer(
|
||||
device = torch.device("cuda") if use_gpu else torch.device("cpu")
|
||||
# first fill the replay buffer using random policy
|
||||
train_after_ts = max(train_after_ts, minibatch_size)
|
||||
random_policy = make_random_policy_for_env(env)
|
||||
agent = Agent.create_for_env(env, policy=random_policy)
|
||||
fill_replay_buffer(
|
||||
env=env, replay_buffer=replay_buffer, desired_size=train_after_ts
|
||||
env=env,
|
||||
replay_buffer=replay_buffer,
|
||||
desired_size=train_after_ts,
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
agent = Agent.create_for_env(env, policy=training_policy, device=device)
|
||||
|
||||
@@ -12,6 +12,7 @@ from parameterized import parameterized
|
||||
from reagent.core.tensorboardX import summary_writer_context
|
||||
from reagent.gym.agents.agent import Agent
|
||||
from reagent.gym.envs import Gym
|
||||
from reagent.gym.policies.random_policies import make_random_policy_for_env
|
||||
from reagent.gym.preprocessors import make_replay_buffer_trainer_preprocessor
|
||||
from reagent.gym.runners.gymrunner import evaluate_for_n_episodes
|
||||
from reagent.gym.utils import build_normalizer, fill_replay_buffer
|
||||
@@ -110,8 +111,13 @@ def run_test_offline(
|
||||
replay_capacity=replay_memory_size, batch_size=minibatch_size
|
||||
)
|
||||
# always fill full RB
|
||||
random_policy = make_random_policy_for_env(env)
|
||||
agent = Agent.create_for_env(env, policy=random_policy)
|
||||
fill_replay_buffer(
|
||||
env=env, replay_buffer=replay_buffer, desired_size=replay_memory_size
|
||||
env=env,
|
||||
replay_buffer=replay_buffer,
|
||||
desired_size=replay_memory_size,
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
device = torch.device("cuda") if use_gpu else None
|
||||
|
||||
@@ -16,6 +16,7 @@ from reagent.evaluation.world_model_evaluator import (
|
||||
from reagent.gym.agents.agent import Agent
|
||||
from reagent.gym.envs import EnvWrapper, Gym
|
||||
from reagent.gym.envs.pomdp.state_embed_env import StateEmbedEnvironment
|
||||
from reagent.gym.policies.random_policies import make_random_policy_for_env
|
||||
from reagent.gym.preprocessors import make_replay_buffer_trainer_preprocessor
|
||||
from reagent.gym.runners.gymrunner import evaluate_for_n_episodes
|
||||
from reagent.gym.utils import build_normalizer, fill_replay_buffer
|
||||
@@ -122,7 +123,9 @@ def train_mdnrnn(
|
||||
stack_size=seq_len,
|
||||
return_everything_as_stack=True,
|
||||
)
|
||||
fill_replay_buffer(env, train_replay_buffer, num_train_transitions)
|
||||
random_policy = make_random_policy_for_env(env)
|
||||
agent = Agent.create_for_env(env, policy=random_policy)
|
||||
fill_replay_buffer(env, train_replay_buffer, num_train_transitions, agent)
|
||||
num_batch_per_epoch = train_replay_buffer.size // batch_size
|
||||
|
||||
logger.info("Made RBs, starting to train now!")
|
||||
@@ -180,7 +183,9 @@ def train_mdnrnn_and_compute_feature_stats(
|
||||
stack_size=seq_len,
|
||||
return_everything_as_stack=True,
|
||||
)
|
||||
fill_replay_buffer(env, test_replay_buffer, num_test_transitions)
|
||||
random_policy = make_random_policy_for_env(env)
|
||||
agent = Agent.create_for_env(env, policy=random_policy)
|
||||
fill_replay_buffer(env, test_replay_buffer, num_test_transitions, agent)
|
||||
|
||||
if saved_mdnrnn_path is None:
|
||||
# train from scratch
|
||||
@@ -248,8 +253,13 @@ def create_embed_rl_dataset(
|
||||
embed_rb = ReplayBuffer(
|
||||
replay_capacity=num_state_embed_transitions, batch_size=batch_size, stack_size=1
|
||||
)
|
||||
random_policy = make_random_policy_for_env(env)
|
||||
agent = Agent.create_for_env(env, policy=random_policy)
|
||||
fill_replay_buffer(
|
||||
env=embed_env, replay_buffer=embed_rb, desired_size=num_state_embed_transitions
|
||||
env=embed_env,
|
||||
replay_buffer=embed_rb,
|
||||
desired_size=num_state_embed_transitions,
|
||||
agent=agent,
|
||||
)
|
||||
batch = embed_rb.sample_transition_batch(batch_size=num_state_embed_transitions)
|
||||
state_min = min(batch.state.min(), batch.next_state.min()).item()
|
||||
|
||||
+11
-10
@@ -38,8 +38,10 @@ except ImportError:
|
||||
HAS_RECSIM = False
|
||||
|
||||
|
||||
def fill_replay_buffer(env, replay_buffer: ReplayBuffer, desired_size: int):
|
||||
"""Fill replay buffer with random transitions until size reaches desired_size."""
|
||||
def fill_replay_buffer(
|
||||
env, replay_buffer: ReplayBuffer, desired_size: int, agent: Agent
|
||||
):
|
||||
"""Fill replay buffer with transitions until size reaches desired_size."""
|
||||
assert (
|
||||
0 < desired_size and desired_size <= replay_buffer._replay_capacity
|
||||
), f"It's not true that 0 < {desired_size} <= {replay_buffer._replay_capacity}."
|
||||
@@ -48,18 +50,15 @@ def fill_replay_buffer(env, replay_buffer: ReplayBuffer, desired_size: int):
|
||||
f"(more than desired_size = {desired_size})"
|
||||
)
|
||||
logger.info(
|
||||
f" Starting to fill replay buffer using random policy to size: {desired_size}."
|
||||
f" Starting to fill replay buffer using policy to size: {desired_size}."
|
||||
)
|
||||
random_policy = make_random_policy_for_env(env)
|
||||
post_step = add_replay_buffer_post_step(replay_buffer, env=env)
|
||||
agent.post_transition_callback = post_step
|
||||
|
||||
agent = Agent.create_for_env(
|
||||
env, policy=random_policy, post_transition_callback=post_step
|
||||
)
|
||||
max_episode_steps = env.max_steps
|
||||
with tqdm(
|
||||
total=desired_size - replay_buffer.size,
|
||||
desc=f"Filling replay buffer from {replay_buffer.size} to size {desired_size} using random policy",
|
||||
desc=f"Filling replay buffer from {replay_buffer.size} to size {desired_size}",
|
||||
) as pbar:
|
||||
mdp_id = 0
|
||||
while replay_buffer.size < desired_size:
|
||||
@@ -155,7 +154,7 @@ def build_normalizer(env: EnvWrapper) -> Dict[str, NormalizationData]:
|
||||
|
||||
|
||||
def create_df_from_replay_buffer(
|
||||
env: gym.Env,
|
||||
env,
|
||||
problem_domain: ProblemDomain,
|
||||
desired_size: int,
|
||||
multi_steps: Optional[int],
|
||||
@@ -177,7 +176,9 @@ def create_df_from_replay_buffer(
|
||||
update_horizon=update_horizon,
|
||||
return_as_timeline_format=return_as_timeline_format,
|
||||
)
|
||||
fill_replay_buffer(env, replay_buffer, desired_size)
|
||||
random_policy = make_random_policy_for_env(env)
|
||||
agent = Agent.create_for_env(env, policy=random_policy)
|
||||
fill_replay_buffer(env, replay_buffer, desired_size, agent)
|
||||
|
||||
batch = replay_buffer.sample_all_valid_transitions()
|
||||
n = batch.state.shape[0]
|
||||
|
||||
@@ -13,6 +13,9 @@ from reagent.models.fully_connected_network import FullyConnectedNetwork
|
||||
from torch.distributions import Dirichlet
|
||||
from torch.distributions.normal import Normal
|
||||
|
||||
LOG_PROB_MIN = -2.0
|
||||
LOG_PROB_MAX = 2.0
|
||||
|
||||
|
||||
class StochasticActor(ModelBase):
|
||||
def __init__(self, scorer, sampler):
|
||||
@@ -86,7 +89,7 @@ class FullyConnectedActor(ModelBase):
|
||||
# TODO: log prob is affected by clamping, how to handle that?
|
||||
log_prob = (
|
||||
self.noise_dist.log_prob(noise).to(action.device).sum(dim=1).view(-1, 1)
|
||||
)
|
||||
).clamp(LOG_PROB_MIN, LOG_PROB_MAX)
|
||||
action = (action + noise.to(action.device)).clamp(
|
||||
*CONTINUOUS_TRAINING_ACTION_RANGE
|
||||
)
|
||||
@@ -136,7 +139,6 @@ class GaussianFullyConnectedActor(ModelBase):
|
||||
# used to calculate log-prob
|
||||
self.const = math.log(math.sqrt(2 * math.pi))
|
||||
self.eps = 1e-6
|
||||
self._log_min_max = (-20.0, 2.0)
|
||||
|
||||
def input_prototype(self):
|
||||
return rlt.FeatureData(torch.randn(1, self.state_dim))
|
||||
@@ -174,7 +176,7 @@ class GaussianFullyConnectedActor(ModelBase):
|
||||
loc = self.loc_layer_norm(loc)
|
||||
scale_log = self.scale_layer_norm(scale_log)
|
||||
|
||||
scale_log = scale_log.clamp(*self._log_min_max)
|
||||
scale_log = scale_log.clamp(LOG_PROB_MIN, LOG_PROB_MAX)
|
||||
return loc, scale_log
|
||||
|
||||
def _squash_raw_action(self, raw_action: torch.Tensor) -> torch.Tensor:
|
||||
@@ -289,9 +291,5 @@ class DirichletFullyConnectedActor(ModelBase):
|
||||
# ONNX can't export Dirichlet()
|
||||
action = torch._sample_dirichlet(concentration)
|
||||
|
||||
if not self.training:
|
||||
# ONNX doesn't like reshape either..
|
||||
return rlt.ActorOutput(action=action)
|
||||
|
||||
log_prob = Dirichlet(concentration).log_prob(action)
|
||||
return rlt.ActorOutput(action=action, log_prob=log_prob.unsqueeze(dim=1))
|
||||
|
||||
@@ -313,7 +313,6 @@ class ActorWithPreprocessor(ModelBase):
|
||||
state_with_presence[0], state_with_presence[1]
|
||||
)
|
||||
state_feature_vector = rlt.FeatureData(preprocessed_state)
|
||||
# TODO: include log_prob in the output
|
||||
model_output = self.model(state_feature_vector)
|
||||
if self.serve_mean_policy:
|
||||
assert (
|
||||
@@ -326,7 +325,7 @@ class ActorWithPreprocessor(ModelBase):
|
||||
if self.action_postprocessor:
|
||||
# pyre-fixme[29]: `Optional[Postprocessor]` is not a function.
|
||||
action = self.action_postprocessor(action)
|
||||
return action
|
||||
return (action, model_output.log_prob)
|
||||
|
||||
def input_prototype(self):
|
||||
return (self.state_preprocessor.input_prototype(),)
|
||||
@@ -351,9 +350,8 @@ class ActorPredictorWrapper(torch.jit.ScriptModule):
|
||||
@torch.jit.script_method
|
||||
def forward(
|
||||
self, state_with_presence: Tuple[torch.Tensor, torch.Tensor]
|
||||
) -> torch.Tensor:
|
||||
action = self.actor_with_preprocessor(state_with_presence)
|
||||
return action
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
return self.actor_with_preprocessor(state_with_presence)
|
||||
|
||||
|
||||
class RankingActorWithPreprocessor(ModelBase):
|
||||
|
||||
@@ -193,7 +193,7 @@ class TestPredictorWrapper(unittest.TestCase):
|
||||
)
|
||||
wrapper = ActorPredictorWrapper(actor_with_preprocessor)
|
||||
input_prototype = actor_with_preprocessor.input_prototype()
|
||||
action = wrapper(*input_prototype)
|
||||
action, _log_prob = wrapper(*input_prototype)
|
||||
self.assertEqual(action.shape, (1, len(action_normalization_parameters)))
|
||||
|
||||
expected_output = postprocessor(
|
||||
|
||||
@@ -54,9 +54,15 @@ class ReAgentLightningModule(pl.LightningModule):
|
||||
def reporter(self):
|
||||
return self._reporter
|
||||
|
||||
def set_clean_stop(self, clean_stop: bool):
|
||||
if clean_stop:
|
||||
self._cleanly_stopped = torch.ones(1)
|
||||
else:
|
||||
self._cleanly_stopped = torch.zeros(1)
|
||||
|
||||
def increase_next_stopping_epochs(self, num_epochs: int):
|
||||
self._next_stopping_epoch += num_epochs
|
||||
self._cleanly_stopped[0] = torch.zeros(1)
|
||||
self.set_clean_stop(False)
|
||||
return self
|
||||
|
||||
def train_step_gen(self, training_batch, batch_idx: int):
|
||||
|
||||
@@ -12,11 +12,11 @@ from reagent.core.configuration import resolve_defaults
|
||||
from reagent.core.dataclasses import dataclass
|
||||
from reagent.core.dataclasses import field
|
||||
from reagent.core.parameters import RLParameters
|
||||
from reagent.models.actor import LOG_PROB_MIN, LOG_PROB_MAX
|
||||
from reagent.optimizer import Optimizer__Union, SoftUpdate
|
||||
from reagent.training.reagent_lightning_module import ReAgentLightningModule
|
||||
from reagent.training.rl_trainer_pytorch import RLTrainerMixin
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -229,8 +229,7 @@ class SACTrainer(RLTrainerMixin, ReAgentLightningModule):
|
||||
|
||||
log_prob_a = self.actor_network.get_log_prob(
|
||||
training_batch.next_state, next_state_actor_output.action
|
||||
)
|
||||
log_prob_a = log_prob_a.clamp(-20.0, 20.0)
|
||||
).clamp(LOG_PROB_MIN, LOG_PROB_MAX)
|
||||
next_state_value -= self.entropy_temperature * log_prob_a
|
||||
|
||||
if self.gamma > 0.0:
|
||||
@@ -263,7 +262,7 @@ class SACTrainer(RLTrainerMixin, ReAgentLightningModule):
|
||||
q2_actor_value = self.q2_network(*state_actor_action)
|
||||
min_q_actor_value = torch.min(q1_actor_value, q2_actor_value)
|
||||
|
||||
actor_log_prob = actor_output.log_prob
|
||||
actor_log_prob = actor_output.log_prob.clamp(LOG_PROB_MIN, LOG_PROB_MAX)
|
||||
|
||||
if not self.backprop_through_log_prob:
|
||||
actor_log_prob = actor_log_prob.detach()
|
||||
@@ -309,7 +308,10 @@ class SACTrainer(RLTrainerMixin, ReAgentLightningModule):
|
||||
alpha_loss = -(
|
||||
(
|
||||
self.log_alpha
|
||||
* (actor_output.log_prob + self.target_entropy).detach()
|
||||
* (
|
||||
actor_output.log_prob.clamp(LOG_PROB_MIN, LOG_PROB_MAX)
|
||||
+ self.target_entropy
|
||||
).detach()
|
||||
).mean()
|
||||
)
|
||||
yield alpha_loss
|
||||
@@ -327,8 +329,7 @@ class SACTrainer(RLTrainerMixin, ReAgentLightningModule):
|
||||
log_prob_a = torch.zeros_like(min_q_actor_value)
|
||||
target_value = min_q_actor_value
|
||||
else:
|
||||
log_prob_a = actor_output.log_prob
|
||||
log_prob_a = log_prob_a.clamp(-20.0, 20.0)
|
||||
log_prob_a = actor_output.log_prob.clamp(LOG_PROB_MIN, LOG_PROB_MAX)
|
||||
target_value = min_q_actor_value - self.entropy_temperature * log_prob_a
|
||||
|
||||
value_loss = F.mse_loss(state_value, target_value.detach())
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
import dataclasses
|
||||
import importlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
@@ -58,7 +59,8 @@ def select_relevant_params(config_dict, ConfigClass):
|
||||
@reagent.command(short_help="Run the workflow with config file")
|
||||
@click.argument("workflow")
|
||||
@click.argument("config_file", type=click.File("r"))
|
||||
def run(workflow, config_file):
|
||||
@click.option("--extra-options", default=None)
|
||||
def run(workflow, config_file, extra_options):
|
||||
|
||||
func, ConfigClass = _load_func_and_config_class(workflow)
|
||||
|
||||
@@ -70,6 +72,8 @@ def run(workflow, config_file):
|
||||
yaml = YAML(typ="safe")
|
||||
config_dict = yaml.load(config_file.read())
|
||||
assert config_dict is not None, "failed to read yaml file"
|
||||
if extra_options is not None:
|
||||
config_dict.update(json.loads(extra_options))
|
||||
config_dict = select_relevant_params(config_dict, ConfigClass)
|
||||
config = ConfigClass(**config_dict)
|
||||
func(**config.asdict())
|
||||
|
||||
@@ -14,6 +14,7 @@ from reagent.data.spark_utils import call_spark_class, get_spark_session
|
||||
from reagent.gym.agents.agent import Agent
|
||||
from reagent.gym.envs import Gym
|
||||
from reagent.gym.policies.predictor_policies import create_predictor_policy_from_model
|
||||
from reagent.gym.policies.random_policies import make_random_policy_for_env
|
||||
from reagent.gym.runners.gymrunner import evaluate_for_n_episodes
|
||||
from reagent.gym.utils import fill_replay_buffer
|
||||
from reagent.model_managers.union import ModelManager__Union
|
||||
@@ -34,7 +35,7 @@ def initialize_seed(seed: Optional[int] = None):
|
||||
torch.manual_seed(seed)
|
||||
|
||||
|
||||
def offline_gym(
|
||||
def offline_gym_random(
|
||||
env_name: str,
|
||||
pkl_path: str,
|
||||
num_train_transitions: int,
|
||||
@@ -42,14 +43,46 @@ def offline_gym(
|
||||
seed: int = 1,
|
||||
):
|
||||
"""
|
||||
Generate samples from a DiscreteRandomPolicy on the Gym environment and
|
||||
Generate samples from a random Policy on the Gym environment and
|
||||
saves results in a pandas df parquet.
|
||||
"""
|
||||
initialize_seed(seed)
|
||||
env = Gym(env_name=env_name)
|
||||
random_policy = make_random_policy_for_env(env)
|
||||
agent = Agent.create_for_env(env, policy=random_policy)
|
||||
return _offline_gym(env, agent, pkl_path, num_train_transitions, max_steps, seed)
|
||||
|
||||
|
||||
def offline_gym_predictor(
|
||||
env_name: str,
|
||||
model: ModelManager__Union,
|
||||
publisher: ModelPublisher__Union,
|
||||
pkl_path: str,
|
||||
num_train_transitions: int,
|
||||
max_steps: Optional[int],
|
||||
module_name: str = "default_model",
|
||||
seed: int = 1,
|
||||
):
|
||||
"""
|
||||
Generate samples from a trained Policy on the Gym environment and
|
||||
saves results in a pandas df parquet.
|
||||
"""
|
||||
env = Gym(env_name=env_name)
|
||||
agent = make_agent_from_model(env, model, publisher, module_name)
|
||||
return _offline_gym(env, agent, pkl_path, num_train_transitions, max_steps, seed)
|
||||
|
||||
|
||||
def _offline_gym(
|
||||
env: Gym,
|
||||
agent: Agent,
|
||||
pkl_path: str,
|
||||
num_train_transitions: int,
|
||||
max_steps: Optional[int],
|
||||
seed: int = 1,
|
||||
):
|
||||
initialize_seed(seed)
|
||||
|
||||
replay_buffer = ReplayBuffer(replay_capacity=num_train_transitions, batch_size=1)
|
||||
fill_replay_buffer(env, replay_buffer, num_train_transitions)
|
||||
fill_replay_buffer(env, replay_buffer, num_train_transitions, agent)
|
||||
if isinstance(env.action_space, gym.spaces.Discrete):
|
||||
is_discrete_action = True
|
||||
else:
|
||||
@@ -90,6 +123,27 @@ def timeline_operator(pkl_path: str, input_table_spec: TableSpec):
|
||||
call_spark_class(spark, class_name="Timeline", args=json.dumps(arg))
|
||||
|
||||
|
||||
def make_agent_from_model(
|
||||
env: Gym,
|
||||
model: ModelManager__Union,
|
||||
publisher: ModelPublisher__Union,
|
||||
module_name: str,
|
||||
):
|
||||
publisher_manager = publisher.value
|
||||
assert isinstance(
|
||||
publisher_manager, FileSystemPublisher
|
||||
), f"publishing manager is type {type(publisher_manager)}, not FileSystemPublisher"
|
||||
module_names = model.value.serving_module_names()
|
||||
assert module_name in module_names, f"{module_name} not in {module_names}"
|
||||
torchscript_path = publisher_manager.get_latest_published_model(
|
||||
model.value, module_name
|
||||
)
|
||||
jit_model = torch.jit.load(torchscript_path)
|
||||
policy = create_predictor_policy_from_model(jit_model)
|
||||
agent = Agent.create_for_env_with_serving_policy(env, policy)
|
||||
return agent
|
||||
|
||||
|
||||
def evaluate_gym(
|
||||
env_name: str,
|
||||
model: ModelManager__Union,
|
||||
@@ -100,26 +154,17 @@ def evaluate_gym(
|
||||
max_steps: Optional[int] = None,
|
||||
):
|
||||
initialize_seed(1)
|
||||
publisher_manager = publisher.value
|
||||
assert isinstance(
|
||||
publisher_manager, FileSystemPublisher
|
||||
), f"publishing manager is type {type(publisher_manager)}, not FileSystemPublisher"
|
||||
env = Gym(env_name=env_name)
|
||||
module_names = model.value.serving_module_names()
|
||||
assert module_name in module_names, f"{module_name} not in {module_names}"
|
||||
torchscript_path = publisher_manager.get_latest_published_model(
|
||||
model.value, module_name
|
||||
)
|
||||
jit_model = torch.jit.load(torchscript_path)
|
||||
policy = create_predictor_policy_from_model(jit_model)
|
||||
agent = Agent.create_for_env_with_serving_policy(env, policy)
|
||||
agent = make_agent_from_model(env, model, publisher, module_name)
|
||||
|
||||
rewards = evaluate_for_n_episodes(
|
||||
n=num_eval_episodes, env=env, agent=agent, max_steps=max_steps
|
||||
)
|
||||
avg_reward = np.mean(rewards)
|
||||
logger.info(
|
||||
f"Average reward over {num_eval_episodes} is {avg_reward}.\n"
|
||||
f"List of rewards: {rewards}"
|
||||
f"List of rewards: {rewards}\n"
|
||||
f"Passing score bar: {passing_score_bar}"
|
||||
)
|
||||
assert (
|
||||
avg_reward >= passing_score_bar
|
||||
|
||||
@@ -11,6 +11,7 @@ model:
|
||||
rl:
|
||||
gamma: 0.9
|
||||
target_update_rate: 0.5
|
||||
softmax_policy: true
|
||||
entropy_temperature: 0.01
|
||||
q_network_optimizer:
|
||||
Adam:
|
||||
@@ -21,7 +22,6 @@ model:
|
||||
actor_network_optimizer:
|
||||
Adam:
|
||||
lr: 0.001
|
||||
alpha_optimizer: null
|
||||
actor_net_builder:
|
||||
GaussianFullyConnected:
|
||||
sizes:
|
||||
@@ -55,7 +55,7 @@ model:
|
||||
calc_cpe_in_training: false
|
||||
|
||||
num_train_transitions: 40000 # approx. 200 episodes
|
||||
max_steps: 200
|
||||
max_steps: 1000
|
||||
seed: 42
|
||||
num_epochs: 80
|
||||
publisher:
|
||||
@@ -65,3 +65,4 @@ num_eval_episodes: 30
|
||||
passing_score_bar: -1000
|
||||
reader_options:
|
||||
minibatch_size: 1024
|
||||
warmstart_path: test_warmstart
|
||||
|
||||
@@ -142,4 +142,8 @@ def train_eval_lightning(
|
||||
)
|
||||
trainer.fit(trainer_module, datamodule=datamodule)
|
||||
trainer.test()
|
||||
if checkpoint_path is not None:
|
||||
# Overwrite the warmstart path with the new model
|
||||
trainer_module.set_clean_stop(True)
|
||||
trainer.save_checkpoint(checkpoint_path)
|
||||
return trainer
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -x -e
|
||||
|
||||
rm -f /tmp/file_system_publisher
|
||||
rm -Rf test_warmstart model_* pl_log* runs
|
||||
|
||||
CONFIG=reagent/workflow/sample_configs/sac_pendulum_offline.yaml
|
||||
|
||||
python ./reagent/workflow/cli.py run reagent.workflow.gym_batch_rl.offline_gym_random "$CONFIG"
|
||||
rm -Rf spark-warehouse derby.log metastore_db
|
||||
python ./reagent/workflow/cli.py run reagent.workflow.gym_batch_rl.timeline_operator "$CONFIG"
|
||||
python ./reagent/workflow/cli.py run reagent.workflow.training.identify_and_train_network "$CONFIG"
|
||||
python ./reagent/workflow/cli.py run reagent.workflow.gym_batch_rl.evaluate_gym "$CONFIG"
|
||||
|
||||
for _ in {0..30}
|
||||
do
|
||||
python ./reagent/workflow/cli.py run reagent.workflow.gym_batch_rl.offline_gym_predictor "$CONFIG"
|
||||
rm -Rf spark-warehouse derby.log metastore_db
|
||||
python ./reagent/workflow/cli.py run reagent.workflow.gym_batch_rl.timeline_operator "$CONFIG"
|
||||
python ./reagent/workflow/cli.py run reagent.workflow.training.identify_and_train_network "$CONFIG"
|
||||
python ./reagent/workflow/cli.py run reagent.workflow.gym_batch_rl.evaluate_gym "$CONFIG"
|
||||
done
|
||||
Reference in New Issue
Block a user