mirror of
https://github.com/facebookresearch/ReAgent.git
synced 2026-05-17 12:40:39 +00:00
More oss fixes
Summary: Pull Request resolved: https://github.com/facebookresearch/Horizon/pull/104 Reviewed By: econti Differential Revision: D14260743 Pulled By: MisterTea fbshipit-source-id: 3ce1031e5da47a44aa7e6ab98fb5239312a769aa
This commit is contained in:
committed by
Facebook Github Bot
parent
c8192b8f76
commit
a42f1dc6a2
@@ -51,7 +51,23 @@ object Preprocessor {
|
||||
StructField("metrics", MapType(StringType, DoubleType, true))
|
||||
))
|
||||
|
||||
val inputDf = sparkSession.read.schema(schema).json(timelineConfig.inputTableName)
|
||||
var inputDf = sparkSession.read.schema(schema).json(timelineConfig.inputTableName)
|
||||
|
||||
val mapStringDoubleToLongDouble = udf(
|
||||
(r: Map[String, Double]) => r.map({ case (key, value) => (key.toLong, value) }))
|
||||
|
||||
inputDf = inputDf.withColumn("state_features",
|
||||
mapStringDoubleToLongDouble(inputDf.col("state_features")))
|
||||
if (!timelineConfig.actionDiscrete) {
|
||||
inputDf = inputDf.withColumn("action", mapStringDoubleToLongDouble(inputDf.col("action")))
|
||||
|
||||
val mapArrayStringDoubleToArrayLongDouble = udf((r: Array[Map[String, Double]]) =>
|
||||
r.map((m) => m.map({ case (key, value) => (key.toLong, value) })))
|
||||
inputDf = inputDf.withColumn(
|
||||
"possible_actions",
|
||||
mapArrayStringDoubleToArrayLongDouble(inputDf.col("possible_actions")))
|
||||
}
|
||||
|
||||
inputDf.createOrReplaceTempView(timelineConfig.inputTableName)
|
||||
|
||||
Timeline.run(sparkSession.sqlContext, timelineConfig)
|
||||
|
||||
Reference in New Issue
Block a user