mirror of
https://github.com/facebookresearch/ReAgent.git
synced 2026-05-17 12:40:39 +00:00
Save episode length as temp view
Summary: Caching the computation to improve speed Reviewed By: czxttkl Differential Revision: D14018487 fbshipit-source-id: a9d8590884169c95ec78aa7c417aaf18aeb6f835
This commit is contained in:
committed by
Facebook Github Bot
parent
6700c365e7
commit
c313b8b4f9
@@ -124,23 +124,31 @@ object Timeline {
|
||||
config.actionDiscrete)
|
||||
Timeline.createTrainingTable(sqlContext, config.outputTableName, config.actionDiscrete)
|
||||
|
||||
config.outlierEpisodeLengthPercentile.foreach { percentile =>
|
||||
sqlContext.sql(s"""
|
||||
SELECT mdp_id, COUNT(1) mdp_length
|
||||
FROM ${config.inputTableName}
|
||||
WHERE ds BETWEEN '${config.startDs}' AND '${config.endDs}'
|
||||
GROUP BY 1
|
||||
""").createOrReplaceTempView("episode_length")
|
||||
}
|
||||
|
||||
val sourceTable = Timeline.mdpLengthThreshold(sqlContext, config) match {
|
||||
case Some(threshold) => s"""
|
||||
WITH a AS (${Timeline.getEpisodeLengthStatement(config)}),
|
||||
b AS (SELECT mdp_id FROM a WHERE mdp_length < ${threshold}),
|
||||
WITH a AS (SELECT mdp_id FROM episode_length WHERE mdp_length < ${threshold}),
|
||||
source_table AS (
|
||||
SELECT
|
||||
c.mdp_id,
|
||||
c.state_features,
|
||||
c.action,
|
||||
c.action_probability,
|
||||
c.reward,
|
||||
c.sequence_number,
|
||||
c.possible_actions,
|
||||
c.metrics
|
||||
FROM b JOIN ${config.inputTableName} c
|
||||
WHERE b.mdp_id = c.mdp_id
|
||||
AND c.ds BETWEEN '${config.startDs}' AND '${config.endDs}'
|
||||
b.mdp_id,
|
||||
b.state_features,
|
||||
b.action,
|
||||
b.action_probability,
|
||||
b.reward,
|
||||
b.sequence_number,
|
||||
b.possible_actions,
|
||||
b.metrics
|
||||
FROM ${config.inputTableName} b JOIN a
|
||||
WHERE a.mdp_id = b.mdp_id
|
||||
AND b.ds BETWEEN '${config.startDs}' AND '${config.endDs}'
|
||||
)
|
||||
""".stripMargin
|
||||
case None => s"""
|
||||
@@ -245,30 +253,21 @@ object Timeline {
|
||||
sqlContext.sql(insertCommand)
|
||||
}
|
||||
|
||||
def getEpisodeLengthStatement(config: TimelineConfiguration): String = s"""
|
||||
SELECT mdp_id, COUNT(1) mdp_length
|
||||
FROM ${config.inputTableName}
|
||||
WHERE ds BETWEEN '${config.startDs}' AND '${config.endDs}'
|
||||
GROUP BY 1
|
||||
"""
|
||||
|
||||
def mdpLengthThreshold(sqlContext: SQLContext, config: TimelineConfiguration): Option[Int] =
|
||||
config.outlierEpisodeLengthPercentile match {
|
||||
case Some(percentile) => {
|
||||
val episodeLengthStatement = Timeline.getEpisodeLengthStatement(config)
|
||||
config.outlierEpisodeLengthPercentile.flatMap { percentile =>
|
||||
{
|
||||
val df = sqlContext.sql(s"""
|
||||
WITH a AS (${episodeLengthStatement}),
|
||||
b AS (
|
||||
SELECT CAST(${config.percentileFunction}(mdp_length, ${percentile}) AS INT) pct FROM a
|
||||
WITH a AS (
|
||||
SELECT CAST(${config.percentileFunction}(mdp_length, ${percentile}) AS INT) pct FROM episode_length
|
||||
),
|
||||
c AS (
|
||||
b AS (
|
||||
SELECT
|
||||
count(1) as mdp_count,
|
||||
count(IF(a.mdp_length >= b.pct, 1, NULL)) as outlier_count
|
||||
FROM a, b
|
||||
count(IF(episode_length.mdp_length >= a.pct, 1, NULL)) as outlier_count
|
||||
FROM episode_length, a
|
||||
)
|
||||
SELECT b.pct, c.mdp_count, c.outlier_count
|
||||
FROM b, c
|
||||
SELECT a.pct, b.mdp_count, b.outlier_count
|
||||
FROM a, b
|
||||
""")
|
||||
val res = df.first
|
||||
val pct_val = res.getAs[Int]("pct")
|
||||
@@ -284,7 +283,6 @@ object Timeline {
|
||||
} else
|
||||
Some(pct_val)
|
||||
}
|
||||
case None => None
|
||||
}
|
||||
|
||||
def validateOrDestroyTrainingTable(sqlContext: SQLContext,
|
||||
|
||||
Reference in New Issue
Block a user