diff --git a/docker/cpu/Dockerfile b/docker/cpu/Dockerfile index a9e32e06..a1a07d38 100644 --- a/docker/cpu/Dockerfile +++ b/docker/cpu/Dockerfile @@ -65,15 +65,9 @@ RUN wget https://archive.apache.org/dist/thrift/0.11.0/thrift-0.11.0.tar.gz -O t make install # Install Java & maven. -# Taken from https://github.com/dockerfile/java/blob/master/oracle-java8/Dockerfile -RUN \ - echo oracle-java8-installer shared/accepted-oracle-license-v1-1 select true | debconf-set-selections && \ - add-apt-repository -y ppa:webupd8team/java && \ - apt-get update && \ - apt-get install -y oracle-java8-installer maven && \ - rm -rf /var/lib/apt/lists/* && \ - rm -rf /var/cache/oracle-jdk8-installer -ENV JAVA_HOME /usr/lib/jvm/java-8-oracle +RUN apt-get install -y openjdk-8-jre && \ + apt-get install -y maven +ENV JAVA_HOME /usr/lib/jvm/java-8-openjdk-amd64/ # Install Spark RUN wget http://www-eu.apache.org/dist/spark/spark-2.3.1/spark-2.3.1-bin-hadoop2.7.tgz && \ @@ -108,12 +102,5 @@ ADD ./requirements.txt requirements.txt RUN ./install_prereqs.sh RUN rm install_prereqs.sh requirements.txt -# Install and run key tests, this should finish quickly -RUN git clone https://github.com/facebookresearch/Horizon.git && \ - cd Horizon && \ - thrift --gen py --out . ml/rl/thrift/core.thrift && \ - pip install -e . && \ - python -m ml.rl.test.workflow.test_oss_workflows - # Define default command. CMD ["bash"] diff --git a/docker/cuda/Dockerfile b/docker/cuda/Dockerfile index d0cec603..c6624f28 100644 --- a/docker/cuda/Dockerfile +++ b/docker/cuda/Dockerfile @@ -65,15 +65,9 @@ RUN wget https://archive.apache.org/dist/thrift/0.11.0/thrift-0.11.0.tar.gz -O t make install # Install Java & maven. -# Taken from https://github.com/dockerfile/java/blob/master/oracle-java8/Dockerfile -RUN \ - echo oracle-java8-installer shared/accepted-oracle-license-v1-1 select true | debconf-set-selections && \ - add-apt-repository -y ppa:webupd8team/java && \ - apt-get update && \ - apt-get install -y oracle-java8-installer maven && \ - rm -rf /var/lib/apt/lists/* && \ - rm -rf /var/cache/oracle-jdk8-installer -ENV JAVA_HOME /usr/lib/jvm/java-8-oracle +RUN apt-get install -y openjdk-8-jre && \ + apt-get install -y maven +ENV JAVA_HOME /usr/lib/jvm/java-8-openjdk-amd64/ ENV HOME /home WORKDIR ${HOME}/ @@ -106,12 +100,6 @@ ADD ./requirements.txt requirements.txt RUN ./install_prereqs.sh RUN rm install_prereqs.sh requirements.txt -# Install and run key tests, this should finish quickly -RUN git clone https://github.com/facebookresearch/Horizon.git && \ - cd Horizon && \ - thrift --gen py --out . ml/rl/thrift/core.thrift && \ - pip install -e . && \ - python -m ml.rl.test.workflow.test_oss_workflows # Define default command. CMD ["bash"] diff --git a/docs/installation.md b/docs/installation.md index d1f05320..94a8f37b 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -23,10 +23,20 @@ On Linux you can build the image with specific memory allocations from command l docker build -t horizon:dev --memory=8g --memory-swap=8g . ``` -Once the Docker image is built you can start an interactive shell in the container and run the unit tests: +Once the Docker image is built you can start an interactive shell in the container and run the unit tests. To have the ability to edit files locally and have changes be available in the Docker container, mount the local Horizon repo as a volume: ``` -docker run -it horizon:dev +docker run -v //Horizon:/home/Horizon -it horizon:dev cd Horizon +``` +Depending on where your local Horizon copy is, you may need to white list your shared path via Docker -> Preferences... -> File Sharing. + +Run the setup file: +``` +bash scripts/setup.sh +``` + +Now you can run the tests: +``` python setup.py test ``` diff --git a/ml/rl/test/gym/discrete_dqn_cartpole_v0_100_eps.json b/ml/rl/test/gym/discrete_dqn_cartpole_v0_100_eps.json index 69f46885..d90c09d0 100644 --- a/ml/rl/test/gym/discrete_dqn_cartpole_v0_100_eps.json +++ b/ml/rl/test/gym/discrete_dqn_cartpole_v0_100_eps.json @@ -35,7 +35,7 @@ "use_noisy_linear_layers": false }, "run_details": { - "num_episodes": 100, + "num_episodes": 400, "max_steps": 200, "train_every_ts": 1, "train_after_ts": 1, diff --git a/ml/rl/test/gym/open_ai_gym_environment.py b/ml/rl/test/gym/open_ai_gym_environment.py index 3c648b43..4f173d87 100644 --- a/ml/rl/test/gym/open_ai_gym_environment.py +++ b/ml/rl/test/gym/open_ai_gym_environment.py @@ -11,6 +11,7 @@ from ml.rl.test.gym.gym_predictor import ( GymDQNPredictorPytorch, ) from ml.rl.test.utils import default_normalizer +from ml.rl.training.dqn_predictor import DQNPredictor class ModelType(enum.Enum): @@ -25,7 +26,7 @@ class EnvType(enum.Enum): class OpenAIGymEnvironment: - def __init__(self, gymenv, epsilon, softmax_policy, gamma): + def __init__(self, gymenv, epsilon=0, softmax_policy=False, gamma=0.99): """ Creates an OpenAIGymEnvironment object. @@ -129,6 +130,13 @@ class OpenAIGymEnvironment: if test: return predictor.policy(next_state)[0] return predictor.policy(next_state, add_action_noise=True)[0] + elif isinstance(predictor, DQNPredictor): + # Use DQNPredictor directly - useful to test caffe2 predictor + sparse_next_states = predictor.in_order_dense_to_sparse(next_state) + q_values = predictor.predict(sparse_next_states) + action_idx = max(q_values[0], key=q_values[0].get) + action[int(action_idx)] = 1.0 + return action else: raise NotImplementedError("Unknown predictor type") diff --git a/ml/rl/test/gym/run_gym.py b/ml/rl/test/gym/run_gym.py index d5a914c0..c5e83ed2 100644 --- a/ml/rl/test/gym/run_gym.py +++ b/ml/rl/test/gym/run_gym.py @@ -160,8 +160,10 @@ def train_gym_online_rl( if gym_env.action_type == EnvType.DISCRETE_ACTION: action_index = np.argmax(action) next_state, reward, terminal, _ = gym_env.env.step(action_index) + action_to_log = str(action_index) else: next_state, reward, terminal, _ = gym_env.env.step(action) + action_to_log = action.tolist() next_state = gym_env.transform_state(next_state) ep_timesteps += 1 @@ -192,7 +194,7 @@ def train_gym_online_rl( i, ep_timesteps - 1, state.tolist(), - action.tolist(), + action_to_log, reward, next_state.tolist(), next_action.tolist(), diff --git a/ml/rl/test/workflow/eval_cartpole.py b/ml/rl/test/workflow/eval_cartpole.py new file mode 100644 index 00000000..73d8a190 --- /dev/null +++ b/ml/rl/test/workflow/eval_cartpole.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 + +import argparse +import logging +import sys + +from ml.rl.test.gym.open_ai_gym_environment import OpenAIGymEnvironment +from ml.rl.training.dqn_predictor import DQNPredictor + + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + +ENV = "CartPole-v0" +AVG_OVER_NUM_EPS = 20 + + +def main(model_path): + predictor = DQNPredictor.load(model_path, "minidb", int_features=False) + + env = OpenAIGymEnvironment(gymenv=ENV) + + avg_rewards, avg_discounted_rewards = env.run_ep_n_times( + AVG_OVER_NUM_EPS, predictor, test=True + ) + + logger.info( + "Achieved an average reward score of {} over {} evaluations.".format( + avg_rewards, AVG_OVER_NUM_EPS + ) + ) + + +def parse_args(args): + if len(args) != 3: + raise Exception("Usage: python -m ") + + parser = argparse.ArgumentParser(description="Read command line parameters.") + parser.add_argument("-m", "--model", help="Path to Caffe2 model.") + args = parser.parse_args(args[1:]) + return args.model + + +if __name__ == "__main__": + model_path = parse_args(sys.argv) + main(model_path) diff --git a/ml/rl/training/evaluator.py b/ml/rl/training/evaluator.py index 6b479354..3fa34f99 100644 --- a/ml/rl/training/evaluator.py +++ b/ml/rl/training/evaluator.py @@ -15,6 +15,7 @@ from tensorboardX import SummaryWriter logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) class CPE_Estimate(NamedTuple): diff --git a/ml/rl/training/rl_dataset.py b/ml/rl/training/rl_dataset.py index 7fb8e127..705ca03a 100644 --- a/ml/rl/training/rl_dataset.py +++ b/ml/rl/training/rl_dataset.py @@ -29,6 +29,7 @@ class RLDataset: data = self.rows if pre_timeline_format: data = self.pre_timeline_format_rows + with open(self.file_path, "w") as f: json.dump(data, f) @@ -53,7 +54,7 @@ class RLDataset: """ assert isinstance(state, list) - assert isinstance(action, list) + assert isinstance(action, (list, str)) assert isinstance(reward, float) assert isinstance(next_state, list) assert isinstance(next_action, list) @@ -87,12 +88,15 @@ class RLDataset: # This assumes that every state feature is present in every training example. int_state_feature_keys = [int(k) for k in state_features.keys()] idx_bump = max(int_state_feature_keys) + 1 + if isinstance(action, list): + # Parametric or continuous action domain + action = {str(i + idx_bump): v for i, v in enumerate(action)} + if isinstance(possible_actions, list): if len(possible_actions) == 0: pass elif isinstance(possible_actions[0], int): # Discrete action domain - action = str(np.argmax(action)) possible_actions = [ str(idx) for idx, val in enumerate(possible_actions) if val == 1 ] diff --git a/ml/rl/training/rl_predictor_pytorch.py b/ml/rl/training/rl_predictor_pytorch.py index f119fe7b..294b72a3 100644 --- a/ml/rl/training/rl_predictor_pytorch.py +++ b/ml/rl/training/rl_predictor_pytorch.py @@ -58,6 +58,14 @@ class RLPredictor: def predict_net(self): return self._net + def in_order_dense_to_sparse(self, dense): + """Convert dense observation to sparse observation assuming in order + feature ids.""" + sparse = [] + for row in dense: + sparse.append({str(k): v for k, v in enumerate(row)}) + return sparse + def predict(self, float_state_features, int_state_features=None): """ Returns values for each state :param float_state_features A list of feature -> float value dict examples diff --git a/ml/rl/workflow/create_normalization_metadata.py b/ml/rl/workflow/create_normalization_metadata.py index 464752c0..2c59edaa 100644 --- a/ml/rl/workflow/create_normalization_metadata.py +++ b/ml/rl/workflow/create_normalization_metadata.py @@ -22,7 +22,6 @@ from ml.rl.workflow.training_data_reader import JSONDataset logger = logging.getLogger(__name__) logging.basicConfig(stream=sys.stdout, level=logging.INFO) - NORMALIZATION_BATCH_READ_SIZE = 50000 @@ -50,7 +49,7 @@ def get_norm_metadata(dataset, norm_params, norm_col): samples_per_feature, samples = defaultdict(int), defaultdict(list) while not done: - if not batch or len(batch[norm_col]) == 0: + if batch is None or len(batch[norm_col]) == 0: logger.info("No more data in training data. Breaking.") break diff --git a/ml/rl/workflow/dqn_workflow.py b/ml/rl/workflow/dqn_workflow.py index ba855953..2111b910 100644 --- a/ml/rl/workflow/dqn_workflow.py +++ b/ml/rl/workflow/dqn_workflow.py @@ -34,7 +34,7 @@ from tensorboardX import SummaryWriter logger = logging.getLogger(__name__) -logging.basicConfig(stream=sys.stdout, level=logging.INFO) +logger.setLevel(logging.INFO) DEFAULT_NUM_SAMPLES_FOR_CPE = 5000 diff --git a/ml/rl/workflow/sample_configs/continuous_action/timeline.json b/ml/rl/workflow/sample_configs/continuous_action/timeline.json index 141b378b..bd427727 100644 --- a/ml/rl/workflow/sample_configs/continuous_action/timeline.json +++ b/ml/rl/workflow/sample_configs/continuous_action/timeline.json @@ -5,7 +5,8 @@ "addTerminalStateRow": false, "actionDiscrete": false, "inputTableName": "pendulum", - "outputTableName": "pendulum_training_data" + "outputTableName": "pendulum_training_data", + "numOutputShards": 1 }, "query": { "discountFactor": 0.9, diff --git a/ml/rl/workflow/sample_configs/discrete_action/dqn_example.json b/ml/rl/workflow/sample_configs/discrete_action/dqn_example.json index 1a2bf777..429be5ff 100644 --- a/ml/rl/workflow/sample_configs/discrete_action/dqn_example.json +++ b/ml/rl/workflow/sample_configs/discrete_action/dqn_example.json @@ -1,11 +1,11 @@ { - "training_data_path": "~/cartpole_training_data.json", - "state_norm_data_path": "~/state_features_norm.json", - "model_output_path": "~/", + "training_data_path": "cartpole_training_data.json", + "state_norm_data_path": "state_features_norm.json", + "model_output_path": "./", "use_gpu": true, "use_all_avail_gpus": true, "norm_params": { - "output_dir": "~/", + "output_dir": "./", "cols_to_norm": [ "state_features" ], @@ -15,7 +15,7 @@ "0", "1" ], - "epochs": 1, + "epochs": 100, "rl": { "gamma": 0.99, "target_update_rate": 0.2, @@ -41,7 +41,7 @@ "relu", "linear" ], - "minibatch_size": 512, + "minibatch_size": 1024, "learning_rate": 0.001, "optimizer": "ADAM", "lr_decay": 0.999, diff --git a/ml/rl/workflow/sample_configs/discrete_action/timeline.json b/ml/rl/workflow/sample_configs/discrete_action/timeline.json index 69ccc92c..9dde5c4e 100644 --- a/ml/rl/workflow/sample_configs/discrete_action/timeline.json +++ b/ml/rl/workflow/sample_configs/discrete_action/timeline.json @@ -18,4 +18,4 @@ "1" ] } -} \ No newline at end of file +} diff --git a/ml/rl/workflow/sample_configs/parametric_action/timeline.json b/ml/rl/workflow/sample_configs/parametric_action/timeline.json index fbf70386..0250dec1 100644 --- a/ml/rl/workflow/sample_configs/parametric_action/timeline.json +++ b/ml/rl/workflow/sample_configs/parametric_action/timeline.json @@ -5,7 +5,8 @@ "addTerminalStateRow": false, "actionDiscrete": false, "inputTableName": "cartpole_parametric", - "outputTableName": "cartpole_parametric_training_data" + "outputTableName": "cartpole_parametric_training_data", + "numOutputShards": 1 }, "query": { "discountFactor": 0.9, diff --git a/preprocessing/src/main/scala/com/facebook/spark/rl/Preprocessor.scala b/preprocessing/src/main/scala/com/facebook/spark/rl/Preprocessor.scala index a2efa5d8..6e44677e 100644 --- a/preprocessing/src/main/scala/com/facebook/spark/rl/Preprocessor.scala +++ b/preprocessing/src/main/scala/com/facebook/spark/rl/Preprocessor.scala @@ -1,6 +1,7 @@ // Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. package com.facebook.spark.rl +import org.slf4j.LoggerFactory import com.fasterxml.jackson.databind.ObjectMapper import com.fasterxml.jackson.module.scala.DefaultScalaModule import org.apache.spark.sql._ @@ -16,6 +17,9 @@ case class QueryConfiguration(discountFactor: Double = 0.9, actions: Array[String] = Array()) object Preprocessor { + + private val log = LoggerFactory.getLogger(this.getClass.getName) + def main(args: Array[String]) { val sparkSession = SparkSession.builder().enableHiveSupport().getOrCreate() sparkSession.sqlContext.udf.register("COMPUTE_EPISODE_VALUE", Udfs.getEpisodeValue[Double] _) @@ -62,10 +66,14 @@ object Preprocessor { Query.getContinuousQuery(queryConfig) } + val sqlCommand = query.concat( + s" FROM ${timelineConfig.outputTableName} where rand() <= ${queryConfig.tableSample}") + + log.info("Executing query: ") + log.info(sqlCommand) + // Query the results - val outputDf = sparkSession.sql( - query.concat( - s" FROM ${timelineConfig.outputTableName} where rand() <= ${queryConfig.tableSample}")) + val outputDf = sparkSession.sql(sqlCommand) outputDf.show() outputDf diff --git a/preprocessing/src/main/scala/com/facebook/spark/rl/Timeline.scala b/preprocessing/src/main/scala/com/facebook/spark/rl/Timeline.scala index 6a271716..a7e55ceb 100644 --- a/preprocessing/src/main/scala/com/facebook/spark/rl/Timeline.scala +++ b/preprocessing/src/main/scala/com/facebook/spark/rl/Timeline.scala @@ -11,7 +11,7 @@ case class TimelineConfiguration(startDs: String, actionDiscrete: Boolean, inputTableName: String, outputTableName: String, - numOutputShards: Int = 1) + numOutputShards: Int) /** * Given table of state, action, mdp_id, sequence_number, reward, possible_next_actions diff --git a/preprocessing/src/test/scala/com/facebook/spark/rl/TimelineTest.scala b/preprocessing/src/test/scala/com/facebook/spark/rl/TimelineTest.scala index 0c18348d..f525fabe 100644 --- a/preprocessing/src/test/scala/com/facebook/spark/rl/TimelineTest.scala +++ b/preprocessing/src/test/scala/com/facebook/spark/rl/TimelineTest.scala @@ -22,7 +22,8 @@ class TimelineTest extends PipelineTester { false, true, "some_rl_input", - "some_rl_timeline") + "some_rl_timeline", + 1) // Create fake input data val rl_input = sparkContext diff --git a/scripts/eval.sh b/scripts/eval.sh index f1f641af..20b94293 100755 --- a/scripts/eval.sh +++ b/scripts/eval.sh @@ -1 +1,3 @@ #!/usr/bin/env bash + +python ml/rl/test/workflow/eval_cartpole.py -m $1 diff --git a/scripts/run_timeline.sh b/scripts/run_timeline.sh index 8eb7ab60..7672923a 100755 --- a/scripts/run_timeline.sh +++ b/scripts/run_timeline.sh @@ -10,10 +10,12 @@ function finish { } trap finish EXIT -# Remove the output data -rm -Rf cartpole_discrete_timeline - # Run timelime on pre-timeline data /usr/local/spark/bin/spark-submit \ --class com.facebook.spark.rl.Preprocessor preprocessing/target/rl-preprocessing-1.1.jar \ "`cat ml/rl/workflow/sample_configs/discrete_action/timeline.json`" + +mv cartpole_discrete_timeline/part* cartpole_training_data.json + +# Remove the output data folder +rm -Rf cartpole_discrete_timeline diff --git a/scripts/setup.sh b/scripts/setup.sh new file mode 100644 index 00000000..d6809243 --- /dev/null +++ b/scripts/setup.sh @@ -0,0 +1,10 @@ +#!/usr/bin/env bash + +# Generate thrift specs +thrift --gen py --out . ml/rl/thrift/core.thrift + +# Install the current directory into python path +pip install -e . + +# Run workflow tests +python -m ml.rl.test.workflow.test_oss_workflows diff --git a/scripts/train.sh b/scripts/train.sh index f1f641af..ed038de8 100755 --- a/scripts/train.sh +++ b/scripts/train.sh @@ -1 +1,3 @@ #!/usr/bin/env bash + +python ml/rl/workflow/dqn_workflow.py -p ml/rl/workflow/sample_configs/discrete_action/dqn_example.json