Save episode length as temp view

Summary: Caching the computation to improve speed

Reviewed By: czxttkl

Differential Revision: D14018487

fbshipit-source-id: a9d8590884169c95ec78aa7c417aaf18aeb6f835
This commit is contained in:
Kittipat Virochsiri
2019-02-09 17:18:50 -08:00
committed by Facebook Github Bot
parent 6700c365e7
commit c313b8b4f9
@@ -124,23 +124,31 @@ object Timeline {
config.actionDiscrete)
Timeline.createTrainingTable(sqlContext, config.outputTableName, config.actionDiscrete)
config.outlierEpisodeLengthPercentile.foreach { percentile =>
sqlContext.sql(s"""
SELECT mdp_id, COUNT(1) mdp_length
FROM ${config.inputTableName}
WHERE ds BETWEEN '${config.startDs}' AND '${config.endDs}'
GROUP BY 1
""").createOrReplaceTempView("episode_length")
}
val sourceTable = Timeline.mdpLengthThreshold(sqlContext, config) match {
case Some(threshold) => s"""
WITH a AS (${Timeline.getEpisodeLengthStatement(config)}),
b AS (SELECT mdp_id FROM a WHERE mdp_length < ${threshold}),
WITH a AS (SELECT mdp_id FROM episode_length WHERE mdp_length < ${threshold}),
source_table AS (
SELECT
c.mdp_id,
c.state_features,
c.action,
c.action_probability,
c.reward,
c.sequence_number,
c.possible_actions,
c.metrics
FROM b JOIN ${config.inputTableName} c
WHERE b.mdp_id = c.mdp_id
AND c.ds BETWEEN '${config.startDs}' AND '${config.endDs}'
b.mdp_id,
b.state_features,
b.action,
b.action_probability,
b.reward,
b.sequence_number,
b.possible_actions,
b.metrics
FROM ${config.inputTableName} b JOIN a
WHERE a.mdp_id = b.mdp_id
AND b.ds BETWEEN '${config.startDs}' AND '${config.endDs}'
)
""".stripMargin
case None => s"""
@@ -245,30 +253,21 @@ object Timeline {
sqlContext.sql(insertCommand)
}
def getEpisodeLengthStatement(config: TimelineConfiguration): String = s"""
SELECT mdp_id, COUNT(1) mdp_length
FROM ${config.inputTableName}
WHERE ds BETWEEN '${config.startDs}' AND '${config.endDs}'
GROUP BY 1
"""
def mdpLengthThreshold(sqlContext: SQLContext, config: TimelineConfiguration): Option[Int] =
config.outlierEpisodeLengthPercentile match {
case Some(percentile) => {
val episodeLengthStatement = Timeline.getEpisodeLengthStatement(config)
config.outlierEpisodeLengthPercentile.flatMap { percentile =>
{
val df = sqlContext.sql(s"""
WITH a AS (${episodeLengthStatement}),
b AS (
SELECT CAST(${config.percentileFunction}(mdp_length, ${percentile}) AS INT) pct FROM a
WITH a AS (
SELECT CAST(${config.percentileFunction}(mdp_length, ${percentile}) AS INT) pct FROM episode_length
),
c AS (
b AS (
SELECT
count(1) as mdp_count,
count(IF(a.mdp_length >= b.pct, 1, NULL)) as outlier_count
FROM a, b
count(IF(episode_length.mdp_length >= a.pct, 1, NULL)) as outlier_count
FROM episode_length, a
)
SELECT b.pct, c.mdp_count, c.outlier_count
FROM b, c
SELECT a.pct, b.mdp_count, b.outlier_count
FROM a, b
""")
val res = df.first
val pct_val = res.getAs[Int]("pct")
@@ -284,7 +283,6 @@ object Timeline {
} else
Some(pct_val)
}
case None => None
}
def validateOrDestroyTrainingTable(sqlContext: SQLContext,