mirror of
https://github.com/facebookresearch/ReAgent.git
synced 2026-05-17 12:40:39 +00:00
Get end-to-end training working
Summary: title Reviewed By: MisterTea Differential Revision: D10440251 fbshipit-source-id: 5bb237b695dad63e2ad820273e65ac65cbd19533
This commit is contained in:
committed by
Facebook Github Bot
parent
329a0c0e8e
commit
f33f3d0cce
+3
-16
@@ -65,15 +65,9 @@ RUN wget https://archive.apache.org/dist/thrift/0.11.0/thrift-0.11.0.tar.gz -O t
|
||||
make install
|
||||
|
||||
# Install Java & maven.
|
||||
# Taken from https://github.com/dockerfile/java/blob/master/oracle-java8/Dockerfile
|
||||
RUN \
|
||||
echo oracle-java8-installer shared/accepted-oracle-license-v1-1 select true | debconf-set-selections && \
|
||||
add-apt-repository -y ppa:webupd8team/java && \
|
||||
apt-get update && \
|
||||
apt-get install -y oracle-java8-installer maven && \
|
||||
rm -rf /var/lib/apt/lists/* && \
|
||||
rm -rf /var/cache/oracle-jdk8-installer
|
||||
ENV JAVA_HOME /usr/lib/jvm/java-8-oracle
|
||||
RUN apt-get install -y openjdk-8-jre && \
|
||||
apt-get install -y maven
|
||||
ENV JAVA_HOME /usr/lib/jvm/java-8-openjdk-amd64/
|
||||
|
||||
# Install Spark
|
||||
RUN wget http://www-eu.apache.org/dist/spark/spark-2.3.1/spark-2.3.1-bin-hadoop2.7.tgz && \
|
||||
@@ -108,12 +102,5 @@ ADD ./requirements.txt requirements.txt
|
||||
RUN ./install_prereqs.sh
|
||||
RUN rm install_prereqs.sh requirements.txt
|
||||
|
||||
# Install and run key tests, this should finish quickly
|
||||
RUN git clone https://github.com/facebookresearch/Horizon.git && \
|
||||
cd Horizon && \
|
||||
thrift --gen py --out . ml/rl/thrift/core.thrift && \
|
||||
pip install -e . && \
|
||||
python -m ml.rl.test.workflow.test_oss_workflows
|
||||
|
||||
# Define default command.
|
||||
CMD ["bash"]
|
||||
|
||||
+3
-15
@@ -65,15 +65,9 @@ RUN wget https://archive.apache.org/dist/thrift/0.11.0/thrift-0.11.0.tar.gz -O t
|
||||
make install
|
||||
|
||||
# Install Java & maven.
|
||||
# Taken from https://github.com/dockerfile/java/blob/master/oracle-java8/Dockerfile
|
||||
RUN \
|
||||
echo oracle-java8-installer shared/accepted-oracle-license-v1-1 select true | debconf-set-selections && \
|
||||
add-apt-repository -y ppa:webupd8team/java && \
|
||||
apt-get update && \
|
||||
apt-get install -y oracle-java8-installer maven && \
|
||||
rm -rf /var/lib/apt/lists/* && \
|
||||
rm -rf /var/cache/oracle-jdk8-installer
|
||||
ENV JAVA_HOME /usr/lib/jvm/java-8-oracle
|
||||
RUN apt-get install -y openjdk-8-jre && \
|
||||
apt-get install -y maven
|
||||
ENV JAVA_HOME /usr/lib/jvm/java-8-openjdk-amd64/
|
||||
|
||||
ENV HOME /home
|
||||
WORKDIR ${HOME}/
|
||||
@@ -106,12 +100,6 @@ ADD ./requirements.txt requirements.txt
|
||||
RUN ./install_prereqs.sh
|
||||
RUN rm install_prereqs.sh requirements.txt
|
||||
|
||||
# Install and run key tests, this should finish quickly
|
||||
RUN git clone https://github.com/facebookresearch/Horizon.git && \
|
||||
cd Horizon && \
|
||||
thrift --gen py --out . ml/rl/thrift/core.thrift && \
|
||||
pip install -e . && \
|
||||
python -m ml.rl.test.workflow.test_oss_workflows
|
||||
|
||||
# Define default command.
|
||||
CMD ["bash"]
|
||||
|
||||
+12
-2
@@ -23,10 +23,20 @@ On Linux you can build the image with specific memory allocations from command l
|
||||
docker build -t horizon:dev --memory=8g --memory-swap=8g .
|
||||
```
|
||||
|
||||
Once the Docker image is built you can start an interactive shell in the container and run the unit tests:
|
||||
Once the Docker image is built you can start an interactive shell in the container and run the unit tests. To have the ability to edit files locally and have changes be available in the Docker container, mount the local Horizon repo as a volume:
|
||||
```
|
||||
docker run -it horizon:dev
|
||||
docker run -v /<LOCAL_PATH_TO_HORIZON>/Horizon:/home/Horizon -it horizon:dev
|
||||
cd Horizon
|
||||
```
|
||||
Depending on where your local Horizon copy is, you may need to whitelist your shared path via Docker -> Preferences... -> File Sharing.
|
||||
|
||||
Run the setup file:
|
||||
```
|
||||
bash scripts/setup.sh
|
||||
```
|
||||
|
||||
Now you can run the tests:
|
||||
```
|
||||
python setup.py test
|
||||
```
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@
|
||||
"use_noisy_linear_layers": false
|
||||
},
|
||||
"run_details": {
|
||||
"num_episodes": 100,
|
||||
"num_episodes": 400,
|
||||
"max_steps": 200,
|
||||
"train_every_ts": 1,
|
||||
"train_after_ts": 1,
|
||||
|
||||
@@ -11,6 +11,7 @@ from ml.rl.test.gym.gym_predictor import (
|
||||
GymDQNPredictorPytorch,
|
||||
)
|
||||
from ml.rl.test.utils import default_normalizer
|
||||
from ml.rl.training.dqn_predictor import DQNPredictor
|
||||
|
||||
|
||||
class ModelType(enum.Enum):
|
||||
@@ -25,7 +26,7 @@ class EnvType(enum.Enum):
|
||||
|
||||
|
||||
class OpenAIGymEnvironment:
|
||||
def __init__(self, gymenv, epsilon, softmax_policy, gamma):
|
||||
def __init__(self, gymenv, epsilon=0, softmax_policy=False, gamma=0.99):
|
||||
"""
|
||||
Creates an OpenAIGymEnvironment object.
|
||||
|
||||
@@ -129,6 +130,13 @@ class OpenAIGymEnvironment:
|
||||
if test:
|
||||
return predictor.policy(next_state)[0]
|
||||
return predictor.policy(next_state, add_action_noise=True)[0]
|
||||
elif isinstance(predictor, DQNPredictor):
|
||||
# Use DQNPredictor directly - useful to test caffe2 predictor
|
||||
sparse_next_states = predictor.in_order_dense_to_sparse(next_state)
|
||||
q_values = predictor.predict(sparse_next_states)
|
||||
action_idx = max(q_values[0], key=q_values[0].get)
|
||||
action[int(action_idx)] = 1.0
|
||||
return action
|
||||
else:
|
||||
raise NotImplementedError("Unknown predictor type")
|
||||
|
||||
|
||||
@@ -160,8 +160,10 @@ def train_gym_online_rl(
|
||||
if gym_env.action_type == EnvType.DISCRETE_ACTION:
|
||||
action_index = np.argmax(action)
|
||||
next_state, reward, terminal, _ = gym_env.env.step(action_index)
|
||||
action_to_log = str(action_index)
|
||||
else:
|
||||
next_state, reward, terminal, _ = gym_env.env.step(action)
|
||||
action_to_log = action.tolist()
|
||||
next_state = gym_env.transform_state(next_state)
|
||||
|
||||
ep_timesteps += 1
|
||||
@@ -192,7 +194,7 @@ def train_gym_online_rl(
|
||||
i,
|
||||
ep_timesteps - 1,
|
||||
state.tolist(),
|
||||
action.tolist(),
|
||||
action_to_log,
|
||||
reward,
|
||||
next_state.tolist(),
|
||||
next_action.tolist(),
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import sys
|
||||
|
||||
from ml.rl.test.gym.open_ai_gym_environment import OpenAIGymEnvironment
|
||||
from ml.rl.training.dqn_predictor import DQNPredictor
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
ENV = "CartPole-v0"
|
||||
AVG_OVER_NUM_EPS = 20
|
||||
|
||||
|
||||
def main(model_path):
    """Load a saved DQN predictor and log its average reward on ``ENV``.

    :param model_path: Location handed to ``DQNPredictor.load`` (loaded with
        the "minidb" db type and ``int_features=False``).
    """
    loaded_predictor = DQNPredictor.load(model_path, "minidb", int_features=False)
    gym_environment = OpenAIGymEnvironment(gymenv=ENV)

    # run_ep_n_times yields a pair; only the first element is reported below.
    mean_reward, _mean_discounted_reward = gym_environment.run_ep_n_times(
        AVG_OVER_NUM_EPS, loaded_predictor, test=True
    )

    summary = "Achieved an average reward score of {} over {} evaluations.".format(
        mean_reward, AVG_OVER_NUM_EPS
    )
    logger.info(summary)
|
||||
|
||||
|
||||
def parse_args(args):
    """Extract the model path from an argv-style argument list.

    :param args: Full argument vector with the program name in position 0
        (for example ``sys.argv``).
    :return: The value given for ``-m`` / ``--model``.
    :raises Exception: If no model path was supplied.

    Malformed arguments (unknown options, ``-m`` without a value) are
    reported by argparse itself, which prints usage and exits.
    """
    parser = argparse.ArgumentParser(description="Read command line parameters.")
    parser.add_argument("-m", "--model", help="Path to Caffe2 model.")

    # Drop the program name; argparse only wants the real arguments.
    parsed = parser.parse_args(args[1:])

    # Validate the parsed value instead of counting raw tokens, so that the
    # equivalent "--model=<path>" and "-m<path>" spellings are accepted too.
    if parsed.model is None:
        raise Exception("Usage: python <file.py> -m <model_path>")

    return parsed.model
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: resolve the model path from the command line,
    # then hand it to main().
    model_path = parse_args(sys.argv)
    main(model_path)
|
||||
@@ -15,6 +15,7 @@ from tensorboardX import SummaryWriter
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
class CPE_Estimate(NamedTuple):
|
||||
|
||||
@@ -29,6 +29,7 @@ class RLDataset:
|
||||
data = self.rows
|
||||
if pre_timeline_format:
|
||||
data = self.pre_timeline_format_rows
|
||||
|
||||
with open(self.file_path, "w") as f:
|
||||
json.dump(data, f)
|
||||
|
||||
@@ -53,7 +54,7 @@ class RLDataset:
|
||||
"""
|
||||
|
||||
assert isinstance(state, list)
|
||||
assert isinstance(action, list)
|
||||
assert isinstance(action, (list, str))
|
||||
assert isinstance(reward, float)
|
||||
assert isinstance(next_state, list)
|
||||
assert isinstance(next_action, list)
|
||||
@@ -87,12 +88,15 @@ class RLDataset:
|
||||
# This assumes that every state feature is present in every training example.
|
||||
int_state_feature_keys = [int(k) for k in state_features.keys()]
|
||||
idx_bump = max(int_state_feature_keys) + 1
|
||||
if isinstance(action, list):
|
||||
# Parametric or continuous action domain
|
||||
action = {str(i + idx_bump): v for i, v in enumerate(action)}
|
||||
|
||||
if isinstance(possible_actions, list):
|
||||
if len(possible_actions) == 0:
|
||||
pass
|
||||
elif isinstance(possible_actions[0], int):
|
||||
# Discrete action domain
|
||||
action = str(np.argmax(action))
|
||||
possible_actions = [
|
||||
str(idx) for idx, val in enumerate(possible_actions) if val == 1
|
||||
]
|
||||
|
||||
@@ -58,6 +58,14 @@ class RLPredictor:
|
||||
def predict_net(self):
|
||||
return self._net
|
||||
|
||||
def in_order_dense_to_sparse(self, dense):
    """Turn dense rows into sparse dicts keyed by stringified position.

    Each row becomes ``{"0": row[0], "1": row[1], ...}``; feature ids are
    taken to be the in-order column indices of the row.

    :param dense: Iterable of rows, each an iterable of feature values.
    :return: A list with one ``{str(index): value}`` dict per input row.
    """
    return [
        {str(position): value for position, value in enumerate(observation)}
        for observation in dense
    ]
|
||||
|
||||
def predict(self, float_state_features, int_state_features=None):
|
||||
""" Returns values for each state
|
||||
:param float_state_features A list of feature -> float value dict examples
|
||||
|
||||
@@ -22,7 +22,6 @@ from ml.rl.workflow.training_data_reader import JSONDataset
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
|
||||
|
||||
|
||||
NORMALIZATION_BATCH_READ_SIZE = 50000
|
||||
|
||||
|
||||
@@ -50,7 +49,7 @@ def get_norm_metadata(dataset, norm_params, norm_col):
|
||||
samples_per_feature, samples = defaultdict(int), defaultdict(list)
|
||||
|
||||
while not done:
|
||||
if not batch or len(batch[norm_col]) == 0:
|
||||
if batch is None or len(batch[norm_col]) == 0:
|
||||
logger.info("No more data in training data. Breaking.")
|
||||
break
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ from tensorboardX import SummaryWriter
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
DEFAULT_NUM_SAMPLES_FOR_CPE = 5000
|
||||
|
||||
|
||||
@@ -5,7 +5,8 @@
|
||||
"addTerminalStateRow": false,
|
||||
"actionDiscrete": false,
|
||||
"inputTableName": "pendulum",
|
||||
"outputTableName": "pendulum_training_data"
|
||||
"outputTableName": "pendulum_training_data",
|
||||
"numOutputShards": 1
|
||||
},
|
||||
"query": {
|
||||
"discountFactor": 0.9,
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
{
|
||||
"training_data_path": "~/cartpole_training_data.json",
|
||||
"state_norm_data_path": "~/state_features_norm.json",
|
||||
"model_output_path": "~/",
|
||||
"training_data_path": "cartpole_training_data.json",
|
||||
"state_norm_data_path": "state_features_norm.json",
|
||||
"model_output_path": "./",
|
||||
"use_gpu": true,
|
||||
"use_all_avail_gpus": true,
|
||||
"norm_params": {
|
||||
"output_dir": "~/",
|
||||
"output_dir": "./",
|
||||
"cols_to_norm": [
|
||||
"state_features"
|
||||
],
|
||||
@@ -15,7 +15,7 @@
|
||||
"0",
|
||||
"1"
|
||||
],
|
||||
"epochs": 1,
|
||||
"epochs": 100,
|
||||
"rl": {
|
||||
"gamma": 0.99,
|
||||
"target_update_rate": 0.2,
|
||||
@@ -41,7 +41,7 @@
|
||||
"relu",
|
||||
"linear"
|
||||
],
|
||||
"minibatch_size": 512,
|
||||
"minibatch_size": 1024,
|
||||
"learning_rate": 0.001,
|
||||
"optimizer": "ADAM",
|
||||
"lr_decay": 0.999,
|
||||
|
||||
@@ -18,4 +18,4 @@
|
||||
"1"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,7 +5,8 @@
|
||||
"addTerminalStateRow": false,
|
||||
"actionDiscrete": false,
|
||||
"inputTableName": "cartpole_parametric",
|
||||
"outputTableName": "cartpole_parametric_training_data"
|
||||
"outputTableName": "cartpole_parametric_training_data",
|
||||
"numOutputShards": 1
|
||||
},
|
||||
"query": {
|
||||
"discountFactor": 0.9,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
package com.facebook.spark.rl
|
||||
|
||||
import org.slf4j.LoggerFactory
|
||||
import com.fasterxml.jackson.databind.ObjectMapper
|
||||
import com.fasterxml.jackson.module.scala.DefaultScalaModule
|
||||
import org.apache.spark.sql._
|
||||
@@ -16,6 +17,9 @@ case class QueryConfiguration(discountFactor: Double = 0.9,
|
||||
actions: Array[String] = Array())
|
||||
|
||||
object Preprocessor {
|
||||
|
||||
private val log = LoggerFactory.getLogger(this.getClass.getName)
|
||||
|
||||
def main(args: Array[String]) {
|
||||
val sparkSession = SparkSession.builder().enableHiveSupport().getOrCreate()
|
||||
sparkSession.sqlContext.udf.register("COMPUTE_EPISODE_VALUE", Udfs.getEpisodeValue[Double] _)
|
||||
@@ -62,10 +66,14 @@ object Preprocessor {
|
||||
Query.getContinuousQuery(queryConfig)
|
||||
}
|
||||
|
||||
val sqlCommand = query.concat(
|
||||
s" FROM ${timelineConfig.outputTableName} where rand() <= ${queryConfig.tableSample}")
|
||||
|
||||
log.info("Executing query: ")
|
||||
log.info(sqlCommand)
|
||||
|
||||
// Query the results
|
||||
val outputDf = sparkSession.sql(
|
||||
query.concat(
|
||||
s" FROM ${timelineConfig.outputTableName} where rand() <= ${queryConfig.tableSample}"))
|
||||
val outputDf = sparkSession.sql(sqlCommand)
|
||||
|
||||
outputDf.show()
|
||||
outputDf
|
||||
|
||||
@@ -11,7 +11,7 @@ case class TimelineConfiguration(startDs: String,
|
||||
actionDiscrete: Boolean,
|
||||
inputTableName: String,
|
||||
outputTableName: String,
|
||||
numOutputShards: Int = 1)
|
||||
numOutputShards: Int)
|
||||
|
||||
/**
|
||||
* Given table of state, action, mdp_id, sequence_number, reward, possible_next_actions
|
||||
|
||||
@@ -22,7 +22,8 @@ class TimelineTest extends PipelineTester {
|
||||
false,
|
||||
true,
|
||||
"some_rl_input",
|
||||
"some_rl_timeline")
|
||||
"some_rl_timeline",
|
||||
1)
|
||||
|
||||
// Create fake input data
|
||||
val rl_input = sparkContext
|
||||
|
||||
@@ -1 +1,3 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
python ml/rl/test/workflow/eval_cartpole.py -m $1
|
||||
|
||||
@@ -10,10 +10,12 @@ function finish {
|
||||
}
|
||||
trap finish EXIT
|
||||
|
||||
# Remove the output data
|
||||
rm -Rf cartpole_discrete_timeline
|
||||
|
||||
# Run timeline on pre-timeline data
|
||||
/usr/local/spark/bin/spark-submit \
|
||||
--class com.facebook.spark.rl.Preprocessor preprocessing/target/rl-preprocessing-1.1.jar \
|
||||
"`cat ml/rl/workflow/sample_configs/discrete_action/timeline.json`"
|
||||
|
||||
mv cartpole_discrete_timeline/part* cartpole_training_data.json
|
||||
|
||||
# Remove the output data folder
|
||||
rm -Rf cartpole_discrete_timeline
|
||||
|
||||
@@ -0,0 +1,10 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# Generate thrift specs
|
||||
thrift --gen py --out . ml/rl/thrift/core.thrift
|
||||
|
||||
# Install the current directory into python path
|
||||
pip install -e .
|
||||
|
||||
# Run workflow tests
|
||||
python -m ml.rl.test.workflow.test_oss_workflows
|
||||
@@ -1 +1,3 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
python ml/rl/workflow/dqn_workflow.py -p ml/rl/workflow/sample_configs/discrete_action/dqn_example.json
|
||||
|
||||
Reference in New Issue
Block a user