More oss fixes

Summary: Pull Request resolved: https://github.com/facebookresearch/Horizon/pull/104

Reviewed By: econti

Differential Revision: D14260743

Pulled By: MisterTea

fbshipit-source-id: 3ce1031e5da47a44aa7e6ab98fb5239312a769aa
This commit is contained in:
Jason Gauci
2019-03-01 19:50:29 -08:00
committed by Facebook Github Bot
parent c8192b8f76
commit a42f1dc6a2
12 changed files with 157 additions and 54 deletions
@@ -51,7 +51,23 @@ object Preprocessor {
StructField("metrics", MapType(StringType, DoubleType, true))
))
val inputDf = sparkSession.read.schema(schema).json(timelineConfig.inputTableName)
var inputDf = sparkSession.read.schema(schema).json(timelineConfig.inputTableName)
val mapStringDoubleToLongDouble = udf(
(r: Map[String, Double]) => r.map({ case (key, value) => (key.toLong, value) }))
inputDf = inputDf.withColumn("state_features",
mapStringDoubleToLongDouble(inputDf.col("state_features")))
if (!timelineConfig.actionDiscrete) {
inputDf = inputDf.withColumn("action", mapStringDoubleToLongDouble(inputDf.col("action")))
val mapArrayStringDoubleToArrayLongDouble = udf((r: Array[Map[String, Double]]) =>
r.map((m) => m.map({ case (key, value) => (key.toLong, value) })))
inputDf = inputDf.withColumn(
"possible_actions",
mapArrayStringDoubleToArrayLongDouble(inputDf.col("possible_actions")))
}
inputDf.createOrReplaceTempView(timelineConfig.inputTableName)
Timeline.run(sparkSession.sqlContext, timelineConfig)