From 39385e8d83d8a5ff13b8a2190d1c2080a90db2fa Mon Sep 17 00:00:00 2001 From: Jason Gauci Date: Tue, 18 May 2021 09:27:48 -0700 Subject: [PATCH] Tune SAC and CRR Models. Initial support for batch gym training (#470) Summary: Pull Request resolved: https://github.com/facebookresearch/ReAgent/pull/470 Reviewed By: czxttkl Differential Revision: D28093192 fbshipit-source-id: 6b260c3e8d49c8b302e40066e2be49a0bfe96688 --- .circleci/config.yml | 2 +- docs/usage.rst | 4 +- reagent/gym/policies/predictor_policies.py | 11 ++- reagent/gym/policies/random_policies.py | 2 +- .../parametric_dqn_cartpole_online.yaml | 4 +- .../continuous_crr_pendulum_online.yaml | 5 +- .../configs/pendulum/sac_pendulum_online.yaml | 2 +- reagent/gym/tests/test_gym.py | 8 +- reagent/gym/tests/test_gym_offline.py | 8 +- reagent/gym/tests/test_world_model.py | 16 +++- reagent/gym/utils.py | 21 ++--- reagent/models/actor.py | 12 ++- reagent/prediction/predictor_wrapper.py | 8 +- .../test/prediction/test_predictor_wrapper.py | 2 +- reagent/training/reagent_lightning_module.py | 8 +- reagent/training/sac_trainer.py | 15 ++-- reagent/workflow/cli.py | 6 +- reagent/workflow/gym_batch_rl.py | 79 +++++++++++++++---- .../sample_configs/sac_pendulum_offline.yaml | 5 +- reagent/workflow/utils.py | 4 + scripts/recurring_training_sac_offline.sh | 23 ++++++ 21 files changed, 177 insertions(+), 68 deletions(-) create mode 100644 scripts/recurring_training_sac_offline.sh diff --git a/.circleci/config.yml b/.circleci/config.yml index 6ce1630a..15bddda3 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -100,7 +100,7 @@ commands: name: Run script command: | # gather data and store as pickle - coverage run ./reagent/workflow/cli.py run reagent.workflow.gym_batch_rl.offline_gym "$CONFIG" + coverage run ./reagent/workflow/cli.py run reagent.workflow.gym_batch_rl.offline_gym_random "$CONFIG" # run through timeline operator coverage run --append ./reagent/workflow/cli.py run reagent.workflow.gym_batch_rl.timeline_operator "$CONFIG" # train on logged data diff --git a/docs/usage.rst b/docs/usage.rst index bf80181b..f761f679 100644 --- a/docs/usage.rst +++ b/docs/usage.rst @@ -37,7 +37,7 @@ To train a batch RL model, run the following commands: # set the config export CONFIG=reagent/workflow/sample_configs/discrete_dqn_cartpole_offline.yaml # gather some random transitions (can replace with your own) - ./reagent/workflow/cli.py run reagent.workflow.gym_batch_rl.offline_gym $CONFIG + ./reagent/workflow/cli.py run reagent.workflow.gym_batch_rl.offline_gym_random $CONFIG # convert data to timeline format ./reagent/workflow/cli.py run reagent.workflow.gym_batch_rl.timeline_operator $CONFIG # train model based on timeline data @@ -92,7 +92,7 @@ In particular, the following Click command runs 150 episodes of ``CartPole-v0`` .. code-block:: - ./reagent/workflow/cli.py run reagent.workflow.gym_batch_rl.offline_gym $CONFIG + ./reagent/workflow/cli.py run reagent.workflow.gym_batch_rl.offline_gym_random $CONFIG The command essentially performs the following pseudo-code: diff --git a/reagent/gym/policies/predictor_policies.py b/reagent/gym/policies/predictor_policies.py index 4e15d46d..e4bfdd45 100644 --- a/reagent/gym/policies/predictor_policies.py +++ b/reagent/gym/policies/predictor_policies.py @@ -19,6 +19,7 @@ from reagent.gym.policies.scorers.discrete_scorer import ( parametric_dqn_serving_scorer, ) from reagent.gym.policies.scorers.slate_q_scorer import slate_q_serving_scorer +from reagent.models.actor import LOG_PROB_MIN, LOG_PROB_MAX if IS_FB_ENVIRONMENT: @@ -116,6 +117,10 @@ class ActorPredictorPolicy(Policy): def act( self, obs: Any, possible_actions_mask: Optional[np.ndarray] = None ) -> rlt.ActorOutput: - action = self.predictor(obs).cpu() - # TODO: return log_probs as well - return rlt.ActorOutput(action=action) + output = self.predictor(obs) + if isinstance(output, tuple): + action, log_prob = output + log_prob = log_prob.clamp(LOG_PROB_MIN, LOG_PROB_MAX) + return rlt.ActorOutput(action=action.cpu(), log_prob=log_prob.cpu()) + else: + return rlt.ActorOutput(action=output.cpu()) diff --git a/reagent/gym/policies/random_policies.py b/reagent/gym/policies/random_policies.py index d9280c8a..92e7de92 100644 --- a/reagent/gym/policies/random_policies.py +++ b/reagent/gym/policies/random_policies.py @@ -13,7 +13,7 @@ from reagent.gym.policies.policy import Policy from reagent.gym.policies.scorers.discrete_scorer import apply_possible_actions_mask -def make_random_policy_for_env(env: gym.Env): +def make_random_policy_for_env(env: gym.Env) -> Policy: if isinstance(env.action_space, gym.spaces.Discrete): # discrete action space return DiscreteRandomPolicy.create_for_env(env) diff --git a/reagent/gym/tests/configs/cartpole/parametric_dqn_cartpole_online.yaml b/reagent/gym/tests/configs/cartpole/parametric_dqn_cartpole_online.yaml index 811676bc..898d8f2f 100644 --- a/reagent/gym/tests/configs/cartpole/parametric_dqn_cartpole_online.yaml +++ b/reagent/gym/tests/configs/cartpole/parametric_dqn_cartpole_online.yaml @@ -13,7 +13,7 @@ model: minibatches_per_step: 1 optimizer: AdamW: - lr: 0.003 + lr: 0.001 amsgrad: true net_builder: FullyConnected: @@ -28,7 +28,7 @@ model: replay_memory_size: 100000 train_every_ts: 1 train_after_ts: 20000 -num_train_episodes: 80 +num_train_episodes: 90 num_eval_episodes: 20 passing_score_bar: 100.0 use_gpu: false diff --git a/reagent/gym/tests/configs/pendulum/continuous_crr_pendulum_online.yaml b/reagent/gym/tests/configs/pendulum/continuous_crr_pendulum_online.yaml index 58ade0d0..ec5ffd72 100644 --- a/reagent/gym/tests/configs/pendulum/continuous_crr_pendulum_online.yaml +++ b/reagent/gym/tests/configs/pendulum/continuous_crr_pendulum_online.yaml @@ -6,8 +6,9 @@ model: trainer_param: rl: gamma: 0.99 - target_update_rate: 0.01 + target_update_rate: 0.005 softmax_policy: true + entropy_temperature: 0.3 crr_config: exponent_beta: 1.0 exponent_clamp: 20.0 @@ -54,4 +55,4 @@ num_eval_episodes: 20 # Though maximal score is 0, we set lower bar to let tests finish in time passing_score_bar: -500 use_gpu: false -minibatch_size: 1024 +minibatch_size: 256 diff --git a/reagent/gym/tests/configs/pendulum/sac_pendulum_online.yaml b/reagent/gym/tests/configs/pendulum/sac_pendulum_online.yaml index 0d08c31a..8d4be5c1 100644 --- a/reagent/gym/tests/configs/pendulum/sac_pendulum_online.yaml +++ b/reagent/gym/tests/configs/pendulum/sac_pendulum_online.yaml @@ -8,7 +8,7 @@ model: gamma: 0.99 target_update_rate: 0.005 softmax_policy: true - entropy_temperature: 0.1 + entropy_temperature: 0.3 q_network_optimizer: Adam: lr: 0.001 diff --git a/reagent/gym/tests/test_gym.py b/reagent/gym/tests/test_gym.py index 9dc81626..3eb15b54 100644 --- a/reagent/gym/tests/test_gym.py +++ b/reagent/gym/tests/test_gym.py @@ -21,6 +21,7 @@ from reagent.gym.datasets.replay_buffer_dataset import ReplayBufferDataset from reagent.gym.envs import Env__Union from reagent.gym.envs.env_wrapper import EnvWrapper from reagent.gym.policies.policy import Policy +from reagent.gym.policies.random_policies import make_random_policy_for_env from reagent.gym.runners.gymrunner import evaluate_for_n_episodes, run_episode from reagent.gym.types import PostEpisode, PostStep from reagent.gym.utils import build_normalizer, fill_replay_buffer @@ -239,8 +240,13 @@ def run_test_replay_buffer( device = torch.device("cuda") if use_gpu else torch.device("cpu") # first fill the replay buffer using random policy train_after_ts = max(train_after_ts, minibatch_size) + random_policy = make_random_policy_for_env(env) + agent = Agent.create_for_env(env, policy=random_policy) fill_replay_buffer( - env=env, replay_buffer=replay_buffer, desired_size=train_after_ts + env=env, + replay_buffer=replay_buffer, + desired_size=train_after_ts, + agent=agent, ) agent = Agent.create_for_env(env, policy=training_policy, device=device) diff --git a/reagent/gym/tests/test_gym_offline.py b/reagent/gym/tests/test_gym_offline.py index 1b164bca..35036e6b 100644 --- a/reagent/gym/tests/test_gym_offline.py +++ b/reagent/gym/tests/test_gym_offline.py @@ -12,6 +12,7 @@ from parameterized import parameterized from reagent.core.tensorboardX import summary_writer_context from reagent.gym.agents.agent import Agent from reagent.gym.envs import Gym +from reagent.gym.policies.random_policies import make_random_policy_for_env from reagent.gym.preprocessors import make_replay_buffer_trainer_preprocessor from reagent.gym.runners.gymrunner import evaluate_for_n_episodes from reagent.gym.utils import build_normalizer, fill_replay_buffer @@ -110,8 +111,13 @@ def run_test_offline( replay_capacity=replay_memory_size, batch_size=minibatch_size ) # always fill full RB + random_policy = make_random_policy_for_env(env) + agent = Agent.create_for_env(env, policy=random_policy) fill_replay_buffer( - env=env, replay_buffer=replay_buffer, desired_size=replay_memory_size + env=env, + replay_buffer=replay_buffer, + desired_size=replay_memory_size, + agent=agent, ) device = torch.device("cuda") if use_gpu else None diff --git a/reagent/gym/tests/test_world_model.py b/reagent/gym/tests/test_world_model.py index 6f766ddf..e4727fec 100644 --- a/reagent/gym/tests/test_world_model.py +++ b/reagent/gym/tests/test_world_model.py @@ -16,6 +16,7 @@ from reagent.evaluation.world_model_evaluator import ( from reagent.gym.agents.agent import Agent from reagent.gym.envs import EnvWrapper, Gym from reagent.gym.envs.pomdp.state_embed_env import StateEmbedEnvironment +from reagent.gym.policies.random_policies import make_random_policy_for_env from reagent.gym.preprocessors import make_replay_buffer_trainer_preprocessor from reagent.gym.runners.gymrunner import evaluate_for_n_episodes from reagent.gym.utils import build_normalizer, fill_replay_buffer @@ -122,7 +123,9 @@ def train_mdnrnn( stack_size=seq_len, return_everything_as_stack=True, ) - fill_replay_buffer(env, train_replay_buffer, num_train_transitions) + random_policy = make_random_policy_for_env(env) + agent = Agent.create_for_env(env, policy=random_policy) + fill_replay_buffer(env, train_replay_buffer, num_train_transitions, agent) num_batch_per_epoch = train_replay_buffer.size // batch_size logger.info("Made RBs, starting to train now!") @@ -180,7 +183,9 @@ def train_mdnrnn_and_compute_feature_stats( stack_size=seq_len, return_everything_as_stack=True, ) - fill_replay_buffer(env, test_replay_buffer, num_test_transitions) + random_policy = make_random_policy_for_env(env) + agent = Agent.create_for_env(env, policy=random_policy) + fill_replay_buffer(env, test_replay_buffer, num_test_transitions, agent) if saved_mdnrnn_path is None: # train from scratch @@ -248,8 +253,13 @@ def create_embed_rl_dataset( embed_rb = ReplayBuffer( replay_capacity=num_state_embed_transitions, batch_size=batch_size, stack_size=1 ) + random_policy = make_random_policy_for_env(env) + agent = Agent.create_for_env(env, policy=random_policy) fill_replay_buffer( - env=embed_env, replay_buffer=embed_rb, desired_size=num_state_embed_transitions + env=embed_env, + replay_buffer=embed_rb, + desired_size=num_state_embed_transitions, + agent=agent, ) batch = embed_rb.sample_transition_batch(batch_size=num_state_embed_transitions) state_min = min(batch.state.min(), batch.next_state.min()).item() diff --git a/reagent/gym/utils.py b/reagent/gym/utils.py index 0432a02e..588aec8c 100644 --- a/reagent/gym/utils.py +++ b/reagent/gym/utils.py @@ -38,8 +38,10 @@ except ImportError: HAS_RECSIM = False -def fill_replay_buffer(env, replay_buffer: ReplayBuffer, desired_size: int): - """Fill replay buffer with random transitions until size reaches desired_size.""" +def fill_replay_buffer( + env, replay_buffer: ReplayBuffer, desired_size: int, agent: Agent +): + """Fill replay buffer with transitions until size reaches desired_size.""" assert ( 0 < desired_size and desired_size <= replay_buffer._replay_capacity ), f"It's not true that 0 < {desired_size} <= {replay_buffer._replay_capacity}." @@ -48,18 +50,15 @@ def fill_replay_buffer(env, replay_buffer: ReplayBuffer, desired_size: int): f"(more than desired_size = {desired_size})" ) logger.info( - f" Starting to fill replay buffer using random policy to size: {desired_size}." + f" Starting to fill replay buffer using policy to size: {desired_size}." ) - random_policy = make_random_policy_for_env(env) post_step = add_replay_buffer_post_step(replay_buffer, env=env) + agent.post_transition_callback = post_step - agent = Agent.create_for_env( - env, policy=random_policy, post_transition_callback=post_step - ) max_episode_steps = env.max_steps with tqdm( total=desired_size - replay_buffer.size, - desc=f"Filling replay buffer from {replay_buffer.size} to size {desired_size} using random policy", + desc=f"Filling replay buffer from {replay_buffer.size} to size {desired_size}", ) as pbar: mdp_id = 0 while replay_buffer.size < desired_size: @@ -155,7 +154,7 @@ def build_normalizer(env: EnvWrapper) -> Dict[str, NormalizationData]: def create_df_from_replay_buffer( - env: gym.Env, + env, problem_domain: ProblemDomain, desired_size: int, multi_steps: Optional[int], @@ -177,7 +176,9 @@ def create_df_from_replay_buffer( update_horizon=update_horizon, return_as_timeline_format=return_as_timeline_format, ) - fill_replay_buffer(env, replay_buffer, desired_size) + random_policy = make_random_policy_for_env(env) + agent = Agent.create_for_env(env, policy=random_policy) + fill_replay_buffer(env, replay_buffer, desired_size, agent) batch = replay_buffer.sample_all_valid_transitions() n = batch.state.shape[0] diff --git a/reagent/models/actor.py b/reagent/models/actor.py index 506fe0c0..f6a02dbc 100644 --- a/reagent/models/actor.py +++ b/reagent/models/actor.py @@ -13,6 +13,9 @@ from reagent.models.fully_connected_network import FullyConnectedNetwork from torch.distributions import Dirichlet from torch.distributions.normal import Normal +LOG_PROB_MIN = -2.0 +LOG_PROB_MAX = 2.0 + class StochasticActor(ModelBase): def __init__(self, scorer, sampler): @@ -86,7 +89,7 @@ class FullyConnectedActor(ModelBase): # TODO: log prob is affected by clamping, how to handle that? log_prob = ( self.noise_dist.log_prob(noise).to(action.device).sum(dim=1).view(-1, 1) - ) + ).clamp(LOG_PROB_MIN, LOG_PROB_MAX) action = (action + noise.to(action.device)).clamp( *CONTINUOUS_TRAINING_ACTION_RANGE ) @@ -136,7 +139,6 @@ class GaussianFullyConnectedActor(ModelBase): # used to calculate log-prob self.const = math.log(math.sqrt(2 * math.pi)) self.eps = 1e-6 - self._log_min_max = (-20.0, 2.0) def input_prototype(self): return rlt.FeatureData(torch.randn(1, self.state_dim)) @@ -174,7 +176,7 @@ class GaussianFullyConnectedActor(ModelBase): loc = self.loc_layer_norm(loc) scale_log = self.scale_layer_norm(scale_log) - scale_log = scale_log.clamp(*self._log_min_max) + scale_log = scale_log.clamp(LOG_PROB_MIN, LOG_PROB_MAX) return loc, scale_log def _squash_raw_action(self, raw_action: torch.Tensor) -> torch.Tensor: @@ -289,9 +291,5 @@ class DirichletFullyConnectedActor(ModelBase): # ONNX can't export Dirichlet() action = torch._sample_dirichlet(concentration) - if not self.training: - # ONNX doesn't like reshape either.. - return rlt.ActorOutput(action=action) - log_prob = Dirichlet(concentration).log_prob(action) return rlt.ActorOutput(action=action, log_prob=log_prob.unsqueeze(dim=1)) diff --git a/reagent/prediction/predictor_wrapper.py b/reagent/prediction/predictor_wrapper.py index 9b09caed..fa5c2070 100644 --- a/reagent/prediction/predictor_wrapper.py +++ b/reagent/prediction/predictor_wrapper.py @@ -313,7 +313,6 @@ class ActorWithPreprocessor(ModelBase): state_with_presence[0], state_with_presence[1] ) state_feature_vector = rlt.FeatureData(preprocessed_state) - # TODO: include log_prob in the output model_output = self.model(state_feature_vector) if self.serve_mean_policy: assert ( @@ -326,7 +325,7 @@ class ActorWithPreprocessor(ModelBase): if self.action_postprocessor: # pyre-fixme[29]: `Optional[Postprocessor]` is not a function. action = self.action_postprocessor(action) - return action + return (action, model_output.log_prob) def input_prototype(self): return (self.state_preprocessor.input_prototype(),) @@ -351,9 +350,8 @@ class ActorPredictorWrapper(torch.jit.ScriptModule): @torch.jit.script_method def forward( self, state_with_presence: Tuple[torch.Tensor, torch.Tensor] - ) -> torch.Tensor: - action = self.actor_with_preprocessor(state_with_presence) - return action + ) -> Tuple[torch.Tensor, torch.Tensor]: + return self.actor_with_preprocessor(state_with_presence) class RankingActorWithPreprocessor(ModelBase): diff --git a/reagent/test/prediction/test_predictor_wrapper.py b/reagent/test/prediction/test_predictor_wrapper.py index c186c9e3..50a209df 100644 --- a/reagent/test/prediction/test_predictor_wrapper.py +++ b/reagent/test/prediction/test_predictor_wrapper.py @@ -193,7 +193,7 @@ class TestPredictorWrapper(unittest.TestCase): ) wrapper = ActorPredictorWrapper(actor_with_preprocessor) input_prototype = actor_with_preprocessor.input_prototype() - action = wrapper(*input_prototype) + action, _log_prob = wrapper(*input_prototype) self.assertEqual(action.shape, (1, len(action_normalization_parameters))) expected_output = postprocessor( diff --git a/reagent/training/reagent_lightning_module.py b/reagent/training/reagent_lightning_module.py index 352e10e8..c5138a4d 100644 --- a/reagent/training/reagent_lightning_module.py +++ b/reagent/training/reagent_lightning_module.py @@ -54,9 +54,15 @@ class ReAgentLightningModule(pl.LightningModule): def reporter(self): return self._reporter + def set_clean_stop(self, clean_stop: bool): + if clean_stop: + self._cleanly_stopped = torch.ones(1) + else: + self._cleanly_stopped = torch.zeros(1) + def increase_next_stopping_epochs(self, num_epochs: int): self._next_stopping_epoch += num_epochs - self._cleanly_stopped[0] = torch.zeros(1) + self.set_clean_stop(False) return self def train_step_gen(self, training_batch, batch_idx: int): diff --git a/reagent/training/sac_trainer.py b/reagent/training/sac_trainer.py index be0947f0..e1ccd932 100644 --- a/reagent/training/sac_trainer.py +++ b/reagent/training/sac_trainer.py @@ -12,11 +12,11 @@ from reagent.core.configuration import resolve_defaults from reagent.core.dataclasses import dataclass from reagent.core.dataclasses import field from reagent.core.parameters import RLParameters +from reagent.models.actor import LOG_PROB_MIN, LOG_PROB_MAX from reagent.optimizer import Optimizer__Union, SoftUpdate from reagent.training.reagent_lightning_module import ReAgentLightningModule from reagent.training.rl_trainer_pytorch import RLTrainerMixin - logger = logging.getLogger(__name__) @@ -229,8 +229,7 @@ class SACTrainer(RLTrainerMixin, ReAgentLightningModule): log_prob_a = self.actor_network.get_log_prob( training_batch.next_state, next_state_actor_output.action - ) - log_prob_a = log_prob_a.clamp(-20.0, 20.0) + ).clamp(LOG_PROB_MIN, LOG_PROB_MAX) next_state_value -= self.entropy_temperature * log_prob_a if self.gamma > 0.0: @@ -263,7 +262,7 @@ class SACTrainer(RLTrainerMixin, ReAgentLightningModule): q2_actor_value = self.q2_network(*state_actor_action) min_q_actor_value = torch.min(q1_actor_value, q2_actor_value) - actor_log_prob = actor_output.log_prob + actor_log_prob = actor_output.log_prob.clamp(LOG_PROB_MIN, LOG_PROB_MAX) if not self.backprop_through_log_prob: actor_log_prob = actor_log_prob.detach() @@ -309,7 +308,10 @@ class SACTrainer(RLTrainerMixin, ReAgentLightningModule): alpha_loss = -( ( self.log_alpha - * (actor_output.log_prob + self.target_entropy).detach() + * ( + actor_output.log_prob.clamp(LOG_PROB_MIN, LOG_PROB_MAX) + + self.target_entropy + ).detach() ).mean() ) yield alpha_loss @@ -327,8 +329,7 @@ class SACTrainer(RLTrainerMixin, ReAgentLightningModule): log_prob_a = torch.zeros_like(min_q_actor_value) target_value = min_q_actor_value else: - log_prob_a = actor_output.log_prob - log_prob_a = log_prob_a.clamp(-20.0, 20.0) + log_prob_a = actor_output.log_prob.clamp(LOG_PROB_MIN, LOG_PROB_MAX) target_value = min_q_actor_value - self.entropy_temperature * log_prob_a value_loss = F.mse_loss(state_value, target_value.detach()) diff --git a/reagent/workflow/cli.py b/reagent/workflow/cli.py index 03effd79..dded7368 100755 --- a/reagent/workflow/cli.py +++ b/reagent/workflow/cli.py @@ -4,6 +4,7 @@ import dataclasses import importlib +import json import logging import os import sys @@ -58,7 +59,8 @@ def select_relevant_params(config_dict, ConfigClass): @reagent.command(short_help="Run the workflow with config file") @click.argument("workflow") @click.argument("config_file", type=click.File("r")) -def run(workflow, config_file): +@click.option("--extra-options", default=None) +def run(workflow, config_file, extra_options): func, ConfigClass = _load_func_and_config_class(workflow) @@ -70,6 +72,8 @@ def run(workflow, config_file): yaml = YAML(typ="safe") config_dict = yaml.load(config_file.read()) assert config_dict is not None, "failed to read yaml file" + if extra_options is not None: + config_dict.update(json.loads(extra_options)) config_dict = select_relevant_params(config_dict, ConfigClass) config = ConfigClass(**config_dict) func(**config.asdict()) diff --git a/reagent/workflow/gym_batch_rl.py b/reagent/workflow/gym_batch_rl.py index 3cb24c1c..f8b85e9a 100644 --- a/reagent/workflow/gym_batch_rl.py +++ b/reagent/workflow/gym_batch_rl.py @@ -14,6 +14,7 @@ from reagent.data.spark_utils import call_spark_class, get_spark_session from reagent.gym.agents.agent import Agent from reagent.gym.envs import Gym from reagent.gym.policies.predictor_policies import create_predictor_policy_from_model +from reagent.gym.policies.random_policies import make_random_policy_for_env from reagent.gym.runners.gymrunner import evaluate_for_n_episodes from reagent.gym.utils import fill_replay_buffer from reagent.model_managers.union import ModelManager__Union @@ -34,7 +35,7 @@ def initialize_seed(seed: Optional[int] = None): torch.manual_seed(seed) -def offline_gym( +def offline_gym_random( env_name: str, pkl_path: str, num_train_transitions: int, @@ -42,14 +43,46 @@ def offline_gym( seed: int = 1, ): """ - Generate samples from a DiscreteRandomPolicy on the Gym environment and + Generate samples from a random Policy on the Gym environment and saves results in a pandas df parquet. """ - initialize_seed(seed) env = Gym(env_name=env_name) + random_policy = make_random_policy_for_env(env) + agent = Agent.create_for_env(env, policy=random_policy) + return _offline_gym(env, agent, pkl_path, num_train_transitions, max_steps, seed) + + +def offline_gym_predictor( + env_name: str, + model: ModelManager__Union, + publisher: ModelPublisher__Union, + pkl_path: str, + num_train_transitions: int, + max_steps: Optional[int], + module_name: str = "default_model", + seed: int = 1, +): + """ + Generate samples from a trained Policy on the Gym environment and + saves results in a pandas df parquet. + """ + env = Gym(env_name=env_name) + agent = make_agent_from_model(env, model, publisher, module_name) + return _offline_gym(env, agent, pkl_path, num_train_transitions, max_steps, seed) + + +def _offline_gym( + env: Gym, + agent: Agent, + pkl_path: str, + num_train_transitions: int, + max_steps: Optional[int], + seed: int = 1, +): + initialize_seed(seed) replay_buffer = ReplayBuffer(replay_capacity=num_train_transitions, batch_size=1) - fill_replay_buffer(env, replay_buffer, num_train_transitions) + fill_replay_buffer(env, replay_buffer, num_train_transitions, agent) if isinstance(env.action_space, gym.spaces.Discrete): is_discrete_action = True else: @@ -90,6 +123,27 @@ def timeline_operator(pkl_path: str, input_table_spec: TableSpec): call_spark_class(spark, class_name="Timeline", args=json.dumps(arg)) +def make_agent_from_model( + env: Gym, + model: ModelManager__Union, + publisher: ModelPublisher__Union, + module_name: str, +): + publisher_manager = publisher.value + assert isinstance( + publisher_manager, FileSystemPublisher + ), f"publishing manager is type {type(publisher_manager)}, not FileSystemPublisher" + module_names = model.value.serving_module_names() + assert module_name in module_names, f"{module_name} not in {module_names}" + torchscript_path = publisher_manager.get_latest_published_model( + model.value, module_name + ) + jit_model = torch.jit.load(torchscript_path) + policy = create_predictor_policy_from_model(jit_model) + agent = Agent.create_for_env_with_serving_policy(env, policy) + return agent + + def evaluate_gym( env_name: str, model: ModelManager__Union, @@ -100,26 +154,17 @@ def evaluate_gym( max_steps: Optional[int] = None, ): initialize_seed(1) - publisher_manager = publisher.value - assert isinstance( - publisher_manager, FileSystemPublisher - ), f"publishing manager is type {type(publisher_manager)}, not FileSystemPublisher" env = Gym(env_name=env_name) - module_names = model.value.serving_module_names() - assert module_name in module_names, f"{module_name} not in {module_names}" - torchscript_path = publisher_manager.get_latest_published_model( - model.value, module_name - ) - jit_model = torch.jit.load(torchscript_path) - policy = create_predictor_policy_from_model(jit_model) - agent = Agent.create_for_env_with_serving_policy(env, policy) + agent = make_agent_from_model(env, model, publisher, module_name) + rewards = evaluate_for_n_episodes( n=num_eval_episodes, env=env, agent=agent, max_steps=max_steps ) avg_reward = np.mean(rewards) logger.info( f"Average reward over {num_eval_episodes} is {avg_reward}.\n" - f"List of rewards: {rewards}" + f"List of rewards: {rewards}\n" + f"Passing score bar: {passing_score_bar}" ) assert ( avg_reward >= passing_score_bar diff --git a/reagent/workflow/sample_configs/sac_pendulum_offline.yaml b/reagent/workflow/sample_configs/sac_pendulum_offline.yaml index 67beec9a..86d4979b 100644 --- a/reagent/workflow/sample_configs/sac_pendulum_offline.yaml +++ b/reagent/workflow/sample_configs/sac_pendulum_offline.yaml @@ -11,6 +11,7 @@ model: rl: gamma: 0.9 target_update_rate: 0.5 + softmax_policy: true entropy_temperature: 0.01 q_network_optimizer: Adam: @@ -21,7 +22,6 @@ model: actor_network_optimizer: Adam: lr: 0.001 - alpha_optimizer: null actor_net_builder: GaussianFullyConnected: sizes: @@ -55,7 +55,7 @@ model: calc_cpe_in_training: false num_train_transitions: 40000 # approx. 200 episodes -max_steps: 200 +max_steps: 1000 seed: 42 num_epochs: 80 publisher: @@ -65,3 +65,4 @@ num_eval_episodes: 30 passing_score_bar: -1000 reader_options: minibatch_size: 1024 +warmstart_path: test_warmstart diff --git a/reagent/workflow/utils.py b/reagent/workflow/utils.py index 0f4a19d5..fc9a5958 100644 --- a/reagent/workflow/utils.py +++ b/reagent/workflow/utils.py @@ -142,4 +142,8 @@ def train_eval_lightning( ) trainer.fit(trainer_module, datamodule=datamodule) trainer.test() + if checkpoint_path is not None: + # Overwrite the warmstart path with the new model + trainer_module.set_clean_stop(True) + trainer.save_checkpoint(checkpoint_path) return trainer diff --git a/scripts/recurring_training_sac_offline.sh b/scripts/recurring_training_sac_offline.sh new file mode 100644 index 00000000..443b2649 --- /dev/null +++ b/scripts/recurring_training_sac_offline.sh @@ -0,0 +1,23 @@ +#!/bin/bash + +set -x -e + +rm -f /tmp/file_system_publisher +rm -Rf test_warmstart model_* pl_log* runs + +CONFIG=reagent/workflow/sample_configs/sac_pendulum_offline.yaml + +python ./reagent/workflow/cli.py run reagent.workflow.gym_batch_rl.offline_gym_random "$CONFIG" +rm -Rf spark-warehouse derby.log metastore_db +python ./reagent/workflow/cli.py run reagent.workflow.gym_batch_rl.timeline_operator "$CONFIG" +python ./reagent/workflow/cli.py run reagent.workflow.training.identify_and_train_network "$CONFIG" +python ./reagent/workflow/cli.py run reagent.workflow.gym_batch_rl.evaluate_gym "$CONFIG" + +for _ in {0..30} +do +python ./reagent/workflow/cli.py run reagent.workflow.gym_batch_rl.offline_gym_predictor "$CONFIG" +rm -Rf spark-warehouse derby.log metastore_db +python ./reagent/workflow/cli.py run reagent.workflow.gym_batch_rl.timeline_operator "$CONFIG" +python ./reagent/workflow/cli.py run reagent.workflow.training.identify_and_train_network "$CONFIG" +python ./reagent/workflow/cli.py run reagent.workflow.gym_batch_rl.evaluate_gym "$CONFIG" +done