mirror of
https://github.com/facebookresearch/ReAgent.git
synced 2026-05-17 12:40:39 +00:00
Add parametric DQN & DDPG query to scala timeline
Reviewed By: MisterTea Differential Revision: D9838545 fbshipit-source-id: 77de84640e9cc8355bfe49de01640570f3c8b8c5
This commit is contained in:
committed by
Facebook Github Bot
parent
e50d699f89
commit
469801c31f
@@ -54,7 +54,11 @@ object Preprocessor {
|
||||
inputDf.createOrReplaceTempView(timelineConfig.inputTableName)
|
||||
|
||||
Timeline.run(sparkSession.sqlContext, timelineConfig)
|
||||
val query = Query.getDiscreteQuery(queryConfig)
|
||||
val query = if (timelineConfig.actionDiscrete) {
|
||||
Query.getDiscreteQuery(queryConfig)
|
||||
} else {
|
||||
Query.getContinuousQuery(queryConfig)
|
||||
}
|
||||
|
||||
// Query the results
|
||||
val outputDf = sparkSession.sql(
|
||||
|
||||
@@ -62,4 +62,28 @@ object Query {
|
||||
""").stripMargin
|
||||
return query
|
||||
}
|
||||
|
||||
def getContinuousQuery(config: QueryConfiguration): String = {
|
||||
val rewardTimelineCol = if (config.useNonOrdinalRewardTimeline) {
|
||||
"reward_timeline"
|
||||
} else {
|
||||
"reward_timeline_ordinal"
|
||||
}
|
||||
|
||||
return s"""
|
||||
SELECT
|
||||
mdp_id,
|
||||
sequence_number,
|
||||
state_features,
|
||||
action,
|
||||
action_probability as propensity,
|
||||
reward,
|
||||
next_state_features,
|
||||
next_action,
|
||||
time_diff,
|
||||
possible_actions,
|
||||
possible_next_actions,
|
||||
COMPUTE_EPISODE_VALUE(${config.discountFactor}, ${rewardTimelineCol}) as episode_value
|
||||
""".stripMargin
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user