diff --git a/preprocessing/src/main/scala/com/facebook/spark/rl/Timeline.scala b/preprocessing/src/main/scala/com/facebook/spark/rl/Timeline.scala index 74c5bcee..290adb6e 100644 --- a/preprocessing/src/main/scala/com/facebook/spark/rl/Timeline.scala +++ b/preprocessing/src/main/scala/com/facebook/spark/rl/Timeline.scala @@ -124,23 +124,31 @@ object Timeline { config.actionDiscrete) Timeline.createTrainingTable(sqlContext, config.outputTableName, config.actionDiscrete) + config.outlierEpisodeLengthPercentile.foreach { percentile => + sqlContext.sql(s""" + SELECT mdp_id, COUNT(1) mdp_length + FROM ${config.inputTableName} + WHERE ds BETWEEN '${config.startDs}' AND '${config.endDs}' + GROUP BY 1 + """).createOrReplaceTempView("episode_length") + } + val sourceTable = Timeline.mdpLengthThreshold(sqlContext, config) match { case Some(threshold) => s""" - WITH a AS (${Timeline.getEpisodeLengthStatement(config)}), - b AS (SELECT mdp_id FROM a WHERE mdp_length < ${threshold}), + WITH a AS (SELECT mdp_id FROM episode_length WHERE mdp_length < ${threshold}), source_table AS ( SELECT - c.mdp_id, - c.state_features, - c.action, - c.action_probability, - c.reward, - c.sequence_number, - c.possible_actions, - c.metrics - FROM b JOIN ${config.inputTableName} c - WHERE b.mdp_id = c.mdp_id - AND c.ds BETWEEN '${config.startDs}' AND '${config.endDs}' + b.mdp_id, + b.state_features, + b.action, + b.action_probability, + b.reward, + b.sequence_number, + b.possible_actions, + b.metrics + FROM ${config.inputTableName} b JOIN a + WHERE a.mdp_id = b.mdp_id + AND b.ds BETWEEN '${config.startDs}' AND '${config.endDs}' ) """.stripMargin case None => s""" @@ -245,30 +253,21 @@ object Timeline { sqlContext.sql(insertCommand) } - def getEpisodeLengthStatement(config: TimelineConfiguration): String = s""" - SELECT mdp_id, COUNT(1) mdp_length - FROM ${config.inputTableName} - WHERE ds BETWEEN '${config.startDs}' AND '${config.endDs}' - GROUP BY 1 - """ - def mdpLengthThreshold(sqlContext: SQLContext, config: TimelineConfiguration): Option[Int] = - config.outlierEpisodeLengthPercentile match { - case Some(percentile) => { - val episodeLengthStatement = Timeline.getEpisodeLengthStatement(config) + config.outlierEpisodeLengthPercentile.flatMap { percentile => + { val df = sqlContext.sql(s""" - WITH a AS (${episodeLengthStatement}), - b AS ( - SELECT CAST(${config.percentileFunction}(mdp_length, ${percentile}) AS INT) pct FROM a + WITH a AS ( + SELECT CAST(${config.percentileFunction}(mdp_length, ${percentile}) AS INT) pct FROM episode_length ), - c AS ( + b AS ( SELECT count(1) as mdp_count, - count(IF(a.mdp_length >= b.pct, 1, NULL)) as outlier_count - FROM a, b + count(IF(episode_length.mdp_length >= a.pct, 1, NULL)) as outlier_count + FROM episode_length, a ) - SELECT b.pct, c.mdp_count, c.outlier_count - FROM b, c + SELECT a.pct, b.mdp_count, b.outlier_count + FROM a, b """) val res = df.first val pct_val = res.getAs[Int]("pct") @@ -284,7 +283,6 @@ object Timeline { } else Some(pct_val) } - case None => None } def validateOrDestroyTrainingTable(sqlContext: SQLContext,