Add parametric DQN & DDPG query to scala timeline

Reviewed By: MisterTea

Differential Revision: D9838545

fbshipit-source-id: 77de84640e9cc8355bfe49de01640570f3c8b8c5
This commit is contained in:
Edoardo Conti
2018-09-16 21:18:43 -07:00
committed by Facebook Github Bot
parent e50d699f89
commit 469801c31f
4 changed files with 59 additions and 1 deletions
@@ -54,7 +54,11 @@ object Preprocessor {
inputDf.createOrReplaceTempView(timelineConfig.inputTableName)
Timeline.run(sparkSession.sqlContext, timelineConfig)
val query = Query.getDiscreteQuery(queryConfig)
val query = if (timelineConfig.actionDiscrete) {
Query.getDiscreteQuery(queryConfig)
} else {
Query.getContinuousQuery(queryConfig)
}
// Query the results
val outputDf = sparkSession.sql(
@@ -62,4 +62,28 @@ object Query {
""").stripMargin
return query
}
def getContinuousQuery(config: QueryConfiguration): String = {
val rewardTimelineCol = if (config.useNonOrdinalRewardTimeline) {
"reward_timeline"
} else {
"reward_timeline_ordinal"
}
return s"""
SELECT
mdp_id,
sequence_number,
state_features,
action,
action_probability as propensity,
reward,
next_state_features,
next_action,
time_diff,
possible_actions,
possible_next_actions,
COMPUTE_EPISODE_VALUE(${config.discountFactor}, ${rewardTimelineCol}) as episode_value
""".stripMargin
}
}