Get end-to-end training working

Summary: title

Reviewed By: MisterTea

Differential Revision: D10440251

fbshipit-source-id: 5bb237b695dad63e2ad820273e65ac65cbd19533
This commit is contained in:
Edoardo Conti
2018-10-18 13:39:51 -07:00
committed by Facebook Github Bot
parent 329a0c0e8e
commit f33f3d0cce
23 changed files with 138 additions and 58 deletions
+3 -16
View File
@@ -65,15 +65,9 @@ RUN wget https://archive.apache.org/dist/thrift/0.11.0/thrift-0.11.0.tar.gz -O t
make install
# Install Java & maven.
# Taken from https://github.com/dockerfile/java/blob/master/oracle-java8/Dockerfile
RUN \
echo oracle-java8-installer shared/accepted-oracle-license-v1-1 select true | debconf-set-selections && \
add-apt-repository -y ppa:webupd8team/java && \
apt-get update && \
apt-get install -y oracle-java8-installer maven && \
rm -rf /var/lib/apt/lists/* && \
rm -rf /var/cache/oracle-jdk8-installer
ENV JAVA_HOME /usr/lib/jvm/java-8-oracle
RUN apt-get install -y openjdk-8-jre && \
apt-get install -y maven
ENV JAVA_HOME /usr/lib/jvm/java-8-openjdk-amd64/
# Install Spark
RUN wget http://www-eu.apache.org/dist/spark/spark-2.3.1/spark-2.3.1-bin-hadoop2.7.tgz && \
@@ -108,12 +102,5 @@ ADD ./requirements.txt requirements.txt
RUN ./install_prereqs.sh
RUN rm install_prereqs.sh requirements.txt
# Install and run key tests, this should finish quickly
RUN git clone https://github.com/facebookresearch/Horizon.git && \
cd Horizon && \
thrift --gen py --out . ml/rl/thrift/core.thrift && \
pip install -e . && \
python -m ml.rl.test.workflow.test_oss_workflows
# Define default command.
CMD ["bash"]
+3 -15
View File
@@ -65,15 +65,9 @@ RUN wget https://archive.apache.org/dist/thrift/0.11.0/thrift-0.11.0.tar.gz -O t
make install
# Install Java & maven.
# Taken from https://github.com/dockerfile/java/blob/master/oracle-java8/Dockerfile
RUN \
echo oracle-java8-installer shared/accepted-oracle-license-v1-1 select true | debconf-set-selections && \
add-apt-repository -y ppa:webupd8team/java && \
apt-get update && \
apt-get install -y oracle-java8-installer maven && \
rm -rf /var/lib/apt/lists/* && \
rm -rf /var/cache/oracle-jdk8-installer
ENV JAVA_HOME /usr/lib/jvm/java-8-oracle
RUN apt-get install -y openjdk-8-jre && \
apt-get install -y maven
ENV JAVA_HOME /usr/lib/jvm/java-8-openjdk-amd64/
ENV HOME /home
WORKDIR ${HOME}/
@@ -106,12 +100,6 @@ ADD ./requirements.txt requirements.txt
RUN ./install_prereqs.sh
RUN rm install_prereqs.sh requirements.txt
# Install and run key tests, this should finish quickly
RUN git clone https://github.com/facebookresearch/Horizon.git && \
cd Horizon && \
thrift --gen py --out . ml/rl/thrift/core.thrift && \
pip install -e . && \
python -m ml.rl.test.workflow.test_oss_workflows
# Define default command.
CMD ["bash"]
+12 -2
View File
@@ -23,10 +23,20 @@ On Linux you can build the image with specific memory allocations from command l
docker build -t horizon:dev --memory=8g --memory-swap=8g .
```
Once the Docker image is built you can start an interactive shell in the container and run the unit tests:
Once the Docker image is built you can start an interactive shell in the container and run the unit tests. To have the ability to edit files locally and have changes be available in the Docker container, mount the local Horizon repo as a volume:
```
docker run -it horizon:dev
docker run -v /<LOCAL_PATH_TO_HORIZON>/Horizon:/home/Horizon -it horizon:dev
cd Horizon
```
Depending on where your local Horizon copy is, you may need to white list your shared path via Docker -> Preferences... -> File Sharing.
Run the setup file:
```
bash scripts/setup.sh
```
Now you can run the tests:
```
python setup.py test
```
@@ -35,7 +35,7 @@
"use_noisy_linear_layers": false
},
"run_details": {
"num_episodes": 100,
"num_episodes": 400,
"max_steps": 200,
"train_every_ts": 1,
"train_after_ts": 1,
+9 -1
View File
@@ -11,6 +11,7 @@ from ml.rl.test.gym.gym_predictor import (
GymDQNPredictorPytorch,
)
from ml.rl.test.utils import default_normalizer
from ml.rl.training.dqn_predictor import DQNPredictor
class ModelType(enum.Enum):
@@ -25,7 +26,7 @@ class EnvType(enum.Enum):
class OpenAIGymEnvironment:
def __init__(self, gymenv, epsilon, softmax_policy, gamma):
def __init__(self, gymenv, epsilon=0, softmax_policy=False, gamma=0.99):
"""
Creates an OpenAIGymEnvironment object.
@@ -129,6 +130,13 @@ class OpenAIGymEnvironment:
if test:
return predictor.policy(next_state)[0]
return predictor.policy(next_state, add_action_noise=True)[0]
elif isinstance(predictor, DQNPredictor):
# Use DQNPredictor directly - useful to test caffe2 predictor
sparse_next_states = predictor.in_order_dense_to_sparse(next_state)
q_values = predictor.predict(sparse_next_states)
action_idx = max(q_values[0], key=q_values[0].get)
action[int(action_idx)] = 1.0
return action
else:
raise NotImplementedError("Unknown predictor type")
+3 -1
View File
@@ -160,8 +160,10 @@ def train_gym_online_rl(
if gym_env.action_type == EnvType.DISCRETE_ACTION:
action_index = np.argmax(action)
next_state, reward, terminal, _ = gym_env.env.step(action_index)
action_to_log = str(action_index)
else:
next_state, reward, terminal, _ = gym_env.env.step(action)
action_to_log = action.tolist()
next_state = gym_env.transform_state(next_state)
ep_timesteps += 1
@@ -192,7 +194,7 @@ def train_gym_online_rl(
i,
ep_timesteps - 1,
state.tolist(),
action.tolist(),
action_to_log,
reward,
next_state.tolist(),
next_action.tolist(),
+46
View File
@@ -0,0 +1,46 @@
#!/usr/bin/env python3
import argparse
import logging
import sys
from ml.rl.test.gym.open_ai_gym_environment import OpenAIGymEnvironment
from ml.rl.training.dqn_predictor import DQNPredictor
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
ENV = "CartPole-v0"
AVG_OVER_NUM_EPS = 20
def main(model_path):
predictor = DQNPredictor.load(model_path, "minidb", int_features=False)
env = OpenAIGymEnvironment(gymenv=ENV)
avg_rewards, avg_discounted_rewards = env.run_ep_n_times(
AVG_OVER_NUM_EPS, predictor, test=True
)
logger.info(
"Achieved an average reward score of {} over {} evaluations.".format(
avg_rewards, AVG_OVER_NUM_EPS
)
)
def parse_args(args):
if len(args) != 3:
raise Exception("Usage: python <file.py> -m <parameters_file>")
parser = argparse.ArgumentParser(description="Read command line parameters.")
parser.add_argument("-m", "--model", help="Path to Caffe2 model.")
args = parser.parse_args(args[1:])
return args.model
if __name__ == "__main__":
model_path = parse_args(sys.argv)
main(model_path)
+1
View File
@@ -15,6 +15,7 @@ from tensorboardX import SummaryWriter
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class CPE_Estimate(NamedTuple):
+6 -2
View File
@@ -29,6 +29,7 @@ class RLDataset:
data = self.rows
if pre_timeline_format:
data = self.pre_timeline_format_rows
with open(self.file_path, "w") as f:
json.dump(data, f)
@@ -53,7 +54,7 @@ class RLDataset:
"""
assert isinstance(state, list)
assert isinstance(action, list)
assert isinstance(action, (list, str))
assert isinstance(reward, float)
assert isinstance(next_state, list)
assert isinstance(next_action, list)
@@ -87,12 +88,15 @@ class RLDataset:
# This assumes that every state feature is present in every training example.
int_state_feature_keys = [int(k) for k in state_features.keys()]
idx_bump = max(int_state_feature_keys) + 1
if isinstance(action, list):
# Parametric or continuous action domain
action = {str(i + idx_bump): v for i, v in enumerate(action)}
if isinstance(possible_actions, list):
if len(possible_actions) == 0:
pass
elif isinstance(possible_actions[0], int):
# Discrete action domain
action = str(np.argmax(action))
possible_actions = [
str(idx) for idx, val in enumerate(possible_actions) if val == 1
]
+8
View File
@@ -58,6 +58,14 @@ class RLPredictor:
def predict_net(self):
return self._net
def in_order_dense_to_sparse(self, dense):
"""Convert dense observation to sparse observation assuming in order
feature ids."""
sparse = []
for row in dense:
sparse.append({str(k): v for k, v in enumerate(row)})
return sparse
def predict(self, float_state_features, int_state_features=None):
""" Returns values for each state
:param float_state_features A list of feature -> float value dict examples
@@ -22,7 +22,6 @@ from ml.rl.workflow.training_data_reader import JSONDataset
logger = logging.getLogger(__name__)
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
NORMALIZATION_BATCH_READ_SIZE = 50000
@@ -50,7 +49,7 @@ def get_norm_metadata(dataset, norm_params, norm_col):
samples_per_feature, samples = defaultdict(int), defaultdict(list)
while not done:
if not batch or len(batch[norm_col]) == 0:
if batch is None or len(batch[norm_col]) == 0:
logger.info("No more data in training data. Breaking.")
break
+1 -1
View File
@@ -34,7 +34,7 @@ from tensorboardX import SummaryWriter
logger = logging.getLogger(__name__)
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger.setLevel(logging.INFO)
DEFAULT_NUM_SAMPLES_FOR_CPE = 5000
@@ -5,7 +5,8 @@
"addTerminalStateRow": false,
"actionDiscrete": false,
"inputTableName": "pendulum",
"outputTableName": "pendulum_training_data"
"outputTableName": "pendulum_training_data",
"numOutputShards": 1
},
"query": {
"discountFactor": 0.9,
@@ -1,11 +1,11 @@
{
"training_data_path": "~/cartpole_training_data.json",
"state_norm_data_path": "~/state_features_norm.json",
"model_output_path": "~/",
"training_data_path": "cartpole_training_data.json",
"state_norm_data_path": "state_features_norm.json",
"model_output_path": "./",
"use_gpu": true,
"use_all_avail_gpus": true,
"norm_params": {
"output_dir": "~/",
"output_dir": "./",
"cols_to_norm": [
"state_features"
],
@@ -15,7 +15,7 @@
"0",
"1"
],
"epochs": 1,
"epochs": 100,
"rl": {
"gamma": 0.99,
"target_update_rate": 0.2,
@@ -41,7 +41,7 @@
"relu",
"linear"
],
"minibatch_size": 512,
"minibatch_size": 1024,
"learning_rate": 0.001,
"optimizer": "ADAM",
"lr_decay": 0.999,
@@ -18,4 +18,4 @@
"1"
]
}
}
}
@@ -5,7 +5,8 @@
"addTerminalStateRow": false,
"actionDiscrete": false,
"inputTableName": "cartpole_parametric",
"outputTableName": "cartpole_parametric_training_data"
"outputTableName": "cartpole_parametric_training_data",
"numOutputShards": 1
},
"query": {
"discountFactor": 0.9,
@@ -1,6 +1,7 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
package com.facebook.spark.rl
import org.slf4j.LoggerFactory
import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.module.scala.DefaultScalaModule
import org.apache.spark.sql._
@@ -16,6 +17,9 @@ case class QueryConfiguration(discountFactor: Double = 0.9,
actions: Array[String] = Array())
object Preprocessor {
private val log = LoggerFactory.getLogger(this.getClass.getName)
def main(args: Array[String]) {
val sparkSession = SparkSession.builder().enableHiveSupport().getOrCreate()
sparkSession.sqlContext.udf.register("COMPUTE_EPISODE_VALUE", Udfs.getEpisodeValue[Double] _)
@@ -62,10 +66,14 @@ object Preprocessor {
Query.getContinuousQuery(queryConfig)
}
val sqlCommand = query.concat(
s" FROM ${timelineConfig.outputTableName} where rand() <= ${queryConfig.tableSample}")
log.info("Executing query: ")
log.info(sqlCommand)
// Query the results
val outputDf = sparkSession.sql(
query.concat(
s" FROM ${timelineConfig.outputTableName} where rand() <= ${queryConfig.tableSample}"))
val outputDf = sparkSession.sql(sqlCommand)
outputDf.show()
outputDf
@@ -11,7 +11,7 @@ case class TimelineConfiguration(startDs: String,
actionDiscrete: Boolean,
inputTableName: String,
outputTableName: String,
numOutputShards: Int = 1)
numOutputShards: Int)
/**
* Given table of state, action, mdp_id, sequence_number, reward, possible_next_actions
@@ -22,7 +22,8 @@ class TimelineTest extends PipelineTester {
false,
true,
"some_rl_input",
"some_rl_timeline")
"some_rl_timeline",
1)
// Create fake input data
val rl_input = sparkContext
+2
View File
@@ -1 +1,3 @@
#!/usr/bin/env bash
python ml/rl/test/workflow/eval_cartpole.py -m $1
+5 -3
View File
@@ -10,10 +10,12 @@ function finish {
}
trap finish EXIT
# Remove the output data
rm -Rf cartpole_discrete_timeline
# Run timelime on pre-timeline data
/usr/local/spark/bin/spark-submit \
--class com.facebook.spark.rl.Preprocessor preprocessing/target/rl-preprocessing-1.1.jar \
"`cat ml/rl/workflow/sample_configs/discrete_action/timeline.json`"
mv cartpole_discrete_timeline/part* cartpole_training_data.json
# Remove the output data folder
rm -Rf cartpole_discrete_timeline
+10
View File
@@ -0,0 +1,10 @@
#!/usr/bin/env bash
# Generate thrift specs
thrift --gen py --out . ml/rl/thrift/core.thrift
# Install the current directory into python path
pip install -e .
# Run workflow tests
python -m ml.rl.test.workflow.test_oss_workflows
+2
View File
@@ -1 +1,3 @@
#!/usr/bin/env bash
python ml/rl/workflow/dqn_workflow.py -p ml/rl/workflow/sample_configs/discrete_action/dqn_example.json