Files
ReAgent/reagent/gym/utils.py
T
Jason Gauci 39385e8d83 Tune SAC and CRR Models. Initial support for batch gym training (#470)
Summary: Pull Request resolved: https://github.com/facebookresearch/ReAgent/pull/470

Reviewed By: czxttkl

Differential Revision: D28093192

fbshipit-source-id: 6b260c3e8d49c8b302e40066e2be49a0bfe96688
2021-05-18 09:27:48 -07:00

409 lines
15 KiB
Python

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import logging
import random
from typing import Dict, List, Optional
import gym
import numpy as np
import pandas as pd
import torch # @manual
import torch.nn.functional as F
from gym import spaces
from reagent.core.parameters import NormalizationData, NormalizationKey, ProblemDomain
from reagent.gym.agents.agent import Agent
from reagent.gym.agents.post_step import add_replay_buffer_post_step
from reagent.gym.envs import EnvWrapper
from reagent.gym.normalizers import (
only_continuous_normalizer,
discrete_action_normalizer,
only_continuous_action_normalizer,
)
from reagent.gym.policies.random_policies import make_random_policy_for_env
from reagent.gym.runners.gymrunner import run_episode
from reagent.replay_memory import ReplayBuffer
from tqdm import tqdm
logger = logging.getLogger(__name__)

# Fixed seed handed to set_seed() by create_df_from_replay_buffer.
SEED = 0

# RecSim support is optional: remember whether its env class is importable so
# build_normalizer can special-case it without a hard dependency.
try:
    from reagent.gym.envs import RecSim  # noqa
except ImportError:
    HAS_RECSIM = False
else:
    HAS_RECSIM = True
def fill_replay_buffer(
    env, replay_buffer: ReplayBuffer, desired_size: int, agent: Agent
):
    """Roll out ``agent`` in ``env`` until ``replay_buffer`` holds ``desired_size`` items.

    Side effect: ``agent.post_transition_callback`` is overwritten with a
    callback (from ``add_replay_buffer_post_step``) bound to ``replay_buffer``.
    Episodes are run back to back with ``mdp_id`` = 0, 1, 2, ...; each
    episode's ``max_steps`` is the remaining room minus one, further capped by
    ``env.max_steps`` when that is not None. Filling stops early as soon as an
    episode fails to grow the buffer.

    Raises:
        AssertionError: if ``desired_size`` is not within
            ``(0, replay_buffer._replay_capacity]``, or if the buffer already
            holds ``desired_size`` or more items.
    """
    capacity = replay_buffer._replay_capacity
    assert (
        0 < desired_size <= capacity
    ), f"It's not true that 0 < {desired_size} <= {capacity}."
    assert replay_buffer.size < desired_size, (
        f"Replay buffer already has {replay_buffer.size} elements. "
        f"(more than desired_size = {desired_size})"
    )
    logger.info(
        f" Starting to fill replay buffer using policy to size: {desired_size}."
    )

    # Route every transition the agent observes into the replay buffer.
    agent.post_transition_callback = add_replay_buffer_post_step(
        replay_buffer, env=env
    )
    episode_cap = env.max_steps

    with tqdm(
        total=desired_size - replay_buffer.size,
        desc=f"Filling replay buffer from {replay_buffer.size} to size {desired_size}",
    ) as progress:
        episode_id = 0
        while replay_buffer.size < desired_size:
            size_before = replay_buffer.size
            step_limit = desired_size - size_before - 1
            if episode_cap is not None:
                step_limit = min(episode_cap, step_limit)
            run_episode(env=env, agent=agent, mdp_id=episode_id, max_steps=step_limit)

            growth = replay_buffer.size - size_before
            # NOTE: growth is deliberately not asserted to be non-negative:
            # that invariant does not hold for samples with seq_len > 1, which
            # is a known bug still to be fixed.
            progress.update(n=growth)
            episode_id += 1
            if growth <= 0:
                # The buffer stopped growing; give up rather than loop forever.
                break

    if replay_buffer.size < desired_size:
        logger.info(
            f"Stopped early and filled replay buffer to size: {replay_buffer.size}."
        )
    else:
        logger.info(f"Successfully filled replay buffer to size: {replay_buffer.size}!")
def build_state_normalizer(env: EnvWrapper):
    """Return dense normalization parameters for ``env``'s observation space.

    A 1-D ``spaces.Box`` gets one continuous feature per dimension, bounded by
    the box's ``low``/``high``. A ``spaces.Dict`` yields None (it is assumed
    to be an image observation). Any other space is rejected.

    Raises:
        AssertionError: for a Box with more than one dimension.
        NotImplementedError: for any other observation-space type.
    """
    obs_space = env.observation_space
    if isinstance(obs_space, spaces.Box):
        assert (
            len(obs_space.shape) == 1
        ), f"{obs_space.shape} has dim > 1, and is not supported."
        feature_ids = list(range(obs_space.shape[0]))
        return only_continuous_normalizer(feature_ids, obs_space.low, obs_space.high)
    if isinstance(obs_space, spaces.Dict):
        # Treated as an image observation: nothing to normalize densely.
        return None
    raise NotImplementedError(f"{obs_space} not supported")
def build_action_normalizer(env: EnvWrapper):
    """Return dense normalization parameters for ``env``'s action space.

    ``spaces.Discrete(n)`` maps to a discrete normalizer over ids ``0..n-1``.
    A 1-D ``spaces.Box`` maps to a continuous normalizer bounded by the box's
    ``low``/``high``.

    Raises:
        AssertionError: for a Box with more than one dimension.
        NotImplementedError: for any other action-space type.
    """
    space = env.action_space
    if isinstance(space, spaces.Discrete):
        return discrete_action_normalizer(list(range(space.n)))
    if not isinstance(space, spaces.Box):
        raise NotImplementedError(f"{space} not supported.")

    assert len(space.shape) == 1, f"Box action shape {space.shape} not supported."
    feature_ids = list(range(space.shape[0]))
    return only_continuous_action_normalizer(
        feature_ids,
        min_value=space.low,
        max_value=space.high,
    )
def build_normalizer(env: EnvWrapper) -> Dict[str, NormalizationData]:
    """Return the normalization data for ``env``, keyed by ``NormalizationKey``.

    If the env exposes ``normalization_data``, that is returned as-is.
    Otherwise it is derived from the env's spaces:

    * RecSim envs get STATE (sized from the ``"user"`` observation) and ITEM
      (sized from document ``"0"`` under ``"doc"``) entries.
    * Every other env gets STATE and ACTION entries built by
      ``build_state_normalizer`` and ``build_action_normalizer``.
    """
    try:
        return env.normalization_data
    except AttributeError:
        # TODO: make this a property of EnvWrapper?
        # pyre-fixme[16]: Module `envs` has no attribute `RecSim`.
        is_recsim = HAS_RECSIM and isinstance(env, RecSim)
        if not is_recsim:
            return {
                NormalizationKey.STATE: NormalizationData(
                    dense_normalization_parameters=build_state_normalizer(env)
                ),
                NormalizationKey.ACTION: NormalizationData(
                    dense_normalization_parameters=build_action_normalizer(env)
                ),
            }

        user_dim = env.observation_space["user"].shape[0]
        doc_dim = env.observation_space["doc"]["0"].shape[0]
        return {
            NormalizationKey.STATE: NormalizationData(
                dense_normalization_parameters=only_continuous_normalizer(
                    list(range(user_dim))
                )
            ),
            NormalizationKey.ITEM: NormalizationData(
                dense_normalization_parameters=only_continuous_normalizer(
                    list(range(doc_dim))
                )
            ),
        }
def _discrete_feat_transform(elem) -> str:
    """Render a single-element tensor as ``str`` (query data expects str format)."""
    return str(elem.item())


def _continuous_feat_transform(elem: torch.Tensor) -> Dict[int, float]:
    """Convert a 1-D tensor to the sparse ``{index: value}`` dict query data expects."""
    assert isinstance(elem, torch.Tensor), f"{type(elem)} isn't tensor"
    assert len(elem.shape) == 1, f"{elem.shape} isn't 1-dimensional"
    return {i: s.item() for i, s in enumerate(elem)}


def _make_parametric_feat_transform(one_hot_dim: int):
    """Return a transform that one-hot encodes a discrete action index of width
    ``one_hot_dim`` and then applies ``_continuous_feat_transform``."""

    def transform(elem) -> Dict[int, float]:
        # F.one_hot only accepts integer index tensors; casting through int()
        # keeps this working if actions were stored with a float dtype.
        elem_tensor = torch.tensor(int(elem.item()))
        one_hot_feat = F.one_hot(elem_tensor, one_hot_dim).float()
        return _continuous_feat_transform(one_hot_feat)

    return transform


def _action_columns(env, problem_domain, batch, is_multi_steps: bool):
    """Return the ``(action, next_action)`` columns encoded for ``problem_domain``.

    Terminal rows of ``next_action`` are replaced by an empty placeholder of
    the domain's encoding ("" for str actions, {} for dict actions).
    """
    if problem_domain == ProblemDomain.DISCRETE_ACTION:
        # Discrete actions are str.
        transform = _discrete_feat_transform
        terminal_placeholder = ""
    elif problem_domain == ProblemDomain.PARAMETRIC_ACTION:
        # Parametric actions are one-hot Dict[int, float].
        assert isinstance(env.action_space, gym.spaces.Discrete)
        transform = _make_parametric_feat_transform(env.action_space.n)
        terminal_placeholder = {}
    elif problem_domain == ProblemDomain.CONTINUOUS_ACTION:
        # Continuous actions are Dict[int, float].
        transform = _continuous_feat_transform
        terminal_placeholder = {}
    elif problem_domain == ProblemDomain.MDN_RNN:
        transform = _discrete_feat_transform
        terminal_placeholder = ""
        # MDN-RNN next actions are always in multi-step (timeline) format.
        is_multi_steps = True
    else:
        raise NotImplementedError(f"model type: {problem_domain}.")

    action = feature_transform(batch.action, transform)
    next_action = feature_transform(
        batch.next_action,
        transform,
        is_next_with_multi_steps=is_multi_steps,
        replace_when_terminal=terminal_placeholder,
        terminal=batch.terminal,
    )
    return action, next_action


def _possible_actions(env, problem_domain):
    """Return the possible-actions list for discrete-action domains, else None.

    PARAMETRIC_ACTION yields one-hot dicts ``{i: 1.0}``; DISCRETE_ACTION and
    MDN_RNN yield ``str(i)``, for every ``i`` in the Discrete action space.
    """
    if problem_domain == ProblemDomain.PARAMETRIC_ACTION:
        assert isinstance(env.action_space, gym.spaces.Discrete)
        return [{i: 1.0} for i in range(env.action_space.n)]
    if problem_domain in (ProblemDomain.DISCRETE_ACTION, ProblemDomain.MDN_RNN):
        assert isinstance(env.action_space, gym.spaces.Discrete)
        return [str(i) for i in range(env.action_space.n)]
    return None


def create_df_from_replay_buffer(
    env,
    problem_domain: ProblemDomain,
    desired_size: int,
    multi_steps: Optional[int],
    ds: str,
) -> pd.DataFrame:
    """Generate a shuffled training DataFrame from random-policy rollouts in ``env``.

    All RNGs are seeded with ``SEED``. A fresh ``ReplayBuffer`` of capacity
    ``desired_size`` is filled by a random policy. Every valid transition is
    then converted into one row in the query-data format (str discrete
    actions, sparse ``{index: value}`` dicts for dense features). The row
    ordering of mdp_id / sequence_number is validated before the rows are
    shuffled.

    Args:
        env: environment to roll out (presumably a reagent ``EnvWrapper``).
        problem_domain: selects how actions (and possible actions) are encoded.
        desired_size: replay-buffer capacity / number of transitions to collect.
        multi_steps: None for single-step rows. Otherwise the buffer's update
            horizon, with "next" columns holding per-step lists (timeline format).
        ds: value copied into every row of the ``"ds"`` column.

    Returns:
        The shuffled DataFrame.

    Raises:
        NotImplementedError: if ``problem_domain`` is not supported.
        AssertionError: if ``problem_domain`` is MDN_RNN and ``multi_steps`` is
            None, or the buffer/row validation fails.
    """
    # Validate arguments before the (potentially expensive) rollout.
    if problem_domain not in (
        ProblemDomain.DISCRETE_ACTION,
        ProblemDomain.PARAMETRIC_ACTION,
        ProblemDomain.CONTINUOUS_ACTION,
        ProblemDomain.MDN_RNN,
    ):
        raise NotImplementedError(f"model type: {problem_domain}.")
    if problem_domain == ProblemDomain.MDN_RNN:
        assert multi_steps is not None, "MDN_RNN requires multi_steps to be set."

    # fill the replay buffer
    set_seed(env, SEED)
    is_multi_steps = multi_steps is not None
    replay_buffer = ReplayBuffer(
        replay_capacity=desired_size,
        batch_size=1,
        update_horizon=multi_steps if is_multi_steps else 1,
        return_as_timeline_format=is_multi_steps,
    )
    random_policy = make_random_policy_for_env(env)
    agent = Agent.create_for_env(env, policy=random_policy)
    fill_replay_buffer(env, replay_buffer, desired_size, agent)

    batch = replay_buffer.sample_all_valid_transitions()
    n = batch.state.shape[0]
    logger.info(f"Creating df of size {n}.")

    state_features = feature_transform(batch.state, _continuous_feat_transform)
    next_state_features = feature_transform(
        batch.next_state,
        _continuous_feat_transform,
        is_next_with_multi_steps=is_multi_steps,
    )
    action, next_action = _action_columns(env, problem_domain, batch, is_multi_steps)

    if is_multi_steps:
        time_diff = [[1] * len(ns) for ns in next_state_features]
        reward = [reward_list.tolist() for reward_list in batch.reward]
        metrics = [
            [{"reward": r.item()} for r in reward_list] for reward_list in batch.reward
        ]
    else:
        time_diff = [1] * n
        reward = batch.reward.squeeze(1).tolist()
        metrics = [{"reward": r} for r in reward]

    # TODO(T67265031): change this to int
    mdp_id = [str(i.item()) for i in batch.mdp_id]
    sequence_number = batch.sequence_number.squeeze(1).tolist()
    # Product data numbers sequence_number_ordinal from 1, so shift by one
    # to stay consistent with it.
    sequence_number_ordinal = (batch.sequence_number.squeeze(1) + 1).tolist()
    action_probability = batch.log_prob.exp().squeeze(1).tolist()

    df_dict = {
        "state_features": state_features,
        "next_state_features": next_state_features,
        "action": action,
        "next_action": next_action,
        "reward": reward,
        "action_probability": action_probability,
        "metrics": metrics,
        "time_diff": time_diff,
        "mdp_id": mdp_id,
        "sequence_number": sequence_number,
        "sequence_number_ordinal": sequence_number_ordinal,
        "ds": [ds] * n,
    }

    possible_actions = _possible_actions(env, problem_domain)
    if possible_actions is not None:

        def pa_transform(_):
            return possible_actions

        # range(n) and time_diff are fillers that only provide the right
        # (single-step or per-step) shape for the possible-actions columns.
        df_dict["possible_actions"] = feature_transform(range(n), pa_transform)
        df_dict["possible_next_actions"] = feature_transform(
            time_diff,
            pa_transform,
            is_next_with_multi_steps=is_multi_steps,
            replace_when_terminal=[],
            terminal=batch.terminal,
        )

    df = pd.DataFrame(df_dict)
    # Validate before shuffling, while rows are still in rollout order.
    validate_mdp_ids_seq_nums(df)
    # Shuffle rows (np.random was seeded by set_seed above).
    df = df.reindex(np.random.permutation(df.index))
    return df
def set_seed(env: gym.Env, seed: int):
    """Seed NumPy, ``random``, torch, ``env`` and ``env.action_space`` with ``seed``."""
    for seed_global_rng in (np.random.seed, random.seed, torch.manual_seed):
        seed_global_rng(seed)
    env.seed(seed)
    # The action space keeps its own RNG (used when sampling random actions).
    env.action_space.seed(seed)
def feature_transform(
    features,
    single_elem_transform,
    is_next_with_multi_steps=False,
    replace_when_terminal=None,
    terminal=None,
):
    """Apply ``single_elem_transform`` to every row of a batch of features.

    ``features`` is iterated row by row. Normally each row is a single feature
    and becomes ``single_elem_transform(row)``. With
    ``is_next_with_multi_steps=True``, each row is itself a sequence of
    features and becomes a list with every element transformed.

    When ``terminal`` is given, ``terminal[i]`` is tested for truthiness for
    row ``i``, and ``replace_when_terminal`` (which must then be non-None)
    stands in on terminal rows. It replaces the whole row in single-step mode,
    or only the row's last element in multi-step mode. Replaced values are
    never passed to ``single_elem_transform``.
    """
    if terminal is not None:
        assert replace_when_terminal is not None

    def hits_terminal(row_index):
        return terminal is not None and bool(terminal[row_index])

    def convert_row(row_index, row):
        if not is_next_with_multi_steps:
            if hits_terminal(row_index):
                return replace_when_terminal
            return single_elem_transform(row)
        if hits_terminal(row_index):
            # Keep all steps but the last, which becomes the placeholder.
            kept = [single_elem_transform(step) for step in row[:-1]]
            kept.append(replace_when_terminal)
            return kept
        return [single_elem_transform(step) for step in row]

    return [convert_row(row_index, row) for row_index, row in enumerate(features)]
def validate_mdp_ids_seq_nums(df):
    """Check that each MDP is one contiguous run of rows with consecutive
    sequence numbers.

    Args:
        df: anything indexable by column name (e.g. a ``pd.DataFrame``) with
            ``"mdp_id"`` and ``"sequence_number"`` columns, in row order.

    Raises:
        AssertionError: if, inside a run of equal ``mdp_id``s, a sequence
            number is not exactly one more than the previous row's, or if an
            ``mdp_id`` appears in more than one separate run.
    """
    mdp_ids = list(df["mdp_id"])
    sequence_numbers = list(df["sequence_number"])
    prev_mdp_id, prev_seq_num = None, None
    mdp_count = 0
    for mdp_id, seq_num in zip(mdp_ids, sequence_numbers):
        if prev_mdp_id is None or mdp_id != prev_mdp_id:
            # Start of a new run of rows for an MDP.
            mdp_count += 1
            prev_mdp_id = mdp_id
        else:
            assert seq_num == prev_seq_num + 1, (
                f"For mdp_id {mdp_id}, got sequence_number {seq_num} after "
                f"{prev_seq_num} (expected {prev_seq_num + 1}). "
                f"Sequence numbers must increase by 1 within an mdp_id.\n"
                f"Zip(mdp_id, seq_num): "
                f"{list(zip(mdp_ids, sequence_numbers))}"
            )
        prev_seq_num = seq_num
    # Each distinct mdp_id must form exactly one run; more runs than distinct
    # ids means some MDP's rows were split apart.
    unique_mdp_count = len(set(mdp_ids))
    assert (
        unique_mdp_count == mdp_count
    ), f"MDPs are broken up. {unique_mdp_count} vs {mdp_count}"