mirror of
https://github.com/facebookresearch/ReAgent.git
synced 2026-05-17 12:40:39 +00:00
Support customized reward types in Spark Timeline Operator
Summary: As titled.
Reviewed By: kittipatv
Differential Revision: D19356811
fbshipit-source-id: 820315a0808401b20eeaf26ed012b5bc93ff6fde
This commit is contained in:
committed by
Facebook Github Bot
parent
798a32ecb1
commit
9a75e3dcfd
@@ -20,6 +20,21 @@ object Constants {
|
||||
"metrics"
|
||||
);
|
||||
|
||||
val RANKING_DATA_COLUMN_NAMES = Array(
|
||||
"ds",
|
||||
"mdp_id",
|
||||
"sequence_number",
|
||||
"slate_reward",
|
||||
"item_reward",
|
||||
"action",
|
||||
"action_probability",
|
||||
"state_features",
|
||||
"state_sequence_features",
|
||||
"next_action",
|
||||
"next_state_features",
|
||||
"next_state_sequence_features"
|
||||
);
|
||||
|
||||
val SPARSE_DATA_COLUMN_NAMES = Array(
|
||||
"state_id_list_features",
|
||||
"state_id_score_list_features",
|
||||
@@ -33,4 +48,17 @@ object Constants {
|
||||
"next_action_id_list_features",
|
||||
"next_action_id_score_list_features"
|
||||
);
|
||||
|
||||
val DEFAULT_REWARD_COLUMNS = List[String](
|
||||
"reward",
|
||||
"metrics"
|
||||
);
|
||||
|
||||
val DEFAULT_REWARD_TYPES = Map(
|
||||
"reward" -> "double",
|
||||
"metrics" -> "map<string,double>"
|
||||
);
|
||||
|
||||
val DEFAULT_EXTRA_FEATURE_COLUMNS = List[String]()
|
||||
|
||||
}
|
||||
|
||||
@@ -8,68 +8,67 @@ object Helper {
|
||||
|
||||
private val log = LoggerFactory.getLogger(this.getClass.getName)
|
||||
|
||||
def outputTableIsValid(sqlContext: SQLContext,
|
||||
tableName: String,
|
||||
actionDiscrete: Boolean,
|
||||
extraFeatureColumnTypes: Map[String, String] = Map()): Boolean = {
|
||||
val totalColumns = Constants.TRAINING_DATA_COLUMN_NAMES.size + 2 * extraFeatureColumnTypes.size
|
||||
def next_step_col_name(col_name: String): String =
|
||||
if (col_name == "possible_actions") "possible_next_actions" else "next_" + col_name
|
||||
|
||||
def next_step_col_type(col_type: String, next_step_col_is_arr: Boolean): String =
|
||||
if (next_step_col_is_arr) s"array<${col_type}>" else col_type
|
||||
|
||||
def outputTableIsValid(
|
||||
sqlContext: SQLContext,
|
||||
tableName: String,
|
||||
actionDataType: String = "string",
|
||||
rewardTypes: Map[String, String] = Constants.DEFAULT_REWARD_TYPES,
|
||||
timelineJoinTypes: Map[String, String] = Map("possible_actions" -> "array<string>"),
|
||||
next_step_col_is_arr: Boolean = false
|
||||
): Boolean = {
|
||||
// check column types
|
||||
var actionType = "string";
|
||||
var possibleActionType = "array<string>";
|
||||
if (!actionDiscrete) {
|
||||
actionType = "map<bigint,double>"
|
||||
possibleActionType = "array<map<bigint,double>>"
|
||||
}
|
||||
val dt = sqlContext.sparkSession.catalog
|
||||
.listColumns(tableName)
|
||||
.collect
|
||||
.map(column => column.name -> column.dataType)
|
||||
.toMap
|
||||
|
||||
val nextActionDataType = this.next_step_col_type(actionDataType, next_step_col_is_arr)
|
||||
(
|
||||
dt.size == totalColumns &&
|
||||
actionType == dt.getOrElse("action", "") &&
|
||||
possibleActionType == dt.getOrElse("possible_actions", "") &&
|
||||
extraFeatureColumnTypes.filter {
|
||||
case (k, v) => (v == dt.getOrElse(k, "") && v == dt.getOrElse(s"next_${k}", ""))
|
||||
}.size == extraFeatureColumnTypes.size
|
||||
actionDataType == dt.getOrElse("action", "") &&
|
||||
nextActionDataType == dt.getOrElse("next_action", "") &&
|
||||
rewardTypes.filter { case (k, v) => (v == dt.getOrElse(k, "")) }.size == rewardTypes.size &&
|
||||
timelineJoinTypes.filter {
|
||||
case (k, v) =>
|
||||
(v == dt.getOrElse(k, "") &&
|
||||
this.next_step_col_type(v, next_step_col_is_arr) == dt.getOrElse(
|
||||
this.next_step_col_name(k),
|
||||
""
|
||||
))
|
||||
}.size == timelineJoinTypes.size
|
||||
)
|
||||
}
|
||||
|
||||
def getDataTypes(sqlContext: SQLContext,
|
||||
tableName: String,
|
||||
columnNames: List[String]): Map[String, String] = {
|
||||
// null check is required because jackson doesn't care about default values
|
||||
val notNullColumnNames = Option(columnNames).getOrElse(List[String]())
|
||||
def getDataTypes(
|
||||
sqlContext: SQLContext,
|
||||
tableName: String,
|
||||
columnNames: List[String]
|
||||
): Map[String, String] = {
|
||||
val dt = sqlContext.sparkSession.catalog
|
||||
.listColumns(tableName)
|
||||
.collect
|
||||
.filter(column => notNullColumnNames.contains(column.name))
|
||||
.filter(column => columnNames.contains(column.name))
|
||||
.map(column => column.name -> column.dataType)
|
||||
.toMap
|
||||
assert(dt.size == notNullColumnNames.size)
|
||||
assert(dt.size == columnNames.size)
|
||||
dt
|
||||
}
|
||||
|
||||
def validateOrDestroyTrainingTable(sqlContext: SQLContext,
|
||||
tableName: String,
|
||||
actionDiscrete: Boolean,
|
||||
extraFeatureTypes: Map[String, String] = Map()): Unit =
|
||||
def destroyTrainingTable(
|
||||
sqlContext: SQLContext,
|
||||
tableName: String
|
||||
): Unit =
|
||||
try {
|
||||
// Validate the schema and destroy the output table if it doesn't match
|
||||
var validTable = Helper.outputTableIsValid(
|
||||
sqlContext,
|
||||
tableName,
|
||||
actionDiscrete,
|
||||
extraFeatureTypes
|
||||
)
|
||||
if (!validTable) {
|
||||
val dropTableCommand = s"""
|
||||
DROP TABLE ${tableName}
|
||||
"""
|
||||
sqlContext.sql(dropTableCommand);
|
||||
}
|
||||
val dropTableCommand = s"""
|
||||
DROP TABLE ${tableName}
|
||||
"""
|
||||
sqlContext.sql(dropTableCommand);
|
||||
} catch {
|
||||
case e: org.apache.spark.sql.catalyst.analysis.NoSuchTableException => {}
|
||||
case e: Throwable => log.error(e.toString())
|
||||
|
||||
@@ -5,15 +5,17 @@ import org.slf4j.LoggerFactory
|
||||
import org.apache.spark.sql._
|
||||
import org.apache.spark.sql.functions.udf
|
||||
|
||||
case class MultiStepTimelineConfiguration(startDs: String,
|
||||
endDs: String,
|
||||
addTerminalStateRow: Boolean,
|
||||
actionDiscrete: Boolean,
|
||||
inputTableName: String,
|
||||
outputTableName: String,
|
||||
evalTableName: String,
|
||||
numOutputShards: Int,
|
||||
steps: Int)
|
||||
case class MultiStepTimelineConfiguration(
|
||||
startDs: String,
|
||||
endDs: String,
|
||||
addTerminalStateRow: Boolean,
|
||||
actionDiscrete: Boolean,
|
||||
inputTableName: String,
|
||||
outputTableName: String,
|
||||
evalTableName: String,
|
||||
numOutputShards: Int,
|
||||
steps: Int
|
||||
)
|
||||
|
||||
/**
|
||||
* Given table of state, action, mdp_id, sequence_number, reward, possible_next_actions
|
||||
@@ -121,15 +123,21 @@ object MultiStepTimeline {
|
||||
if (config.addTerminalStateRow) {
|
||||
terminalJoin = "LEFT OUTER";
|
||||
}
|
||||
|
||||
val actionDataType =
|
||||
Helper.getDataTypes(sqlContext, config.inputTableName, List("action"))("action")
|
||||
log.info("action column data type:" + s"${actionDataType}")
|
||||
assert(Set("string", "map<bigint,double>").contains(actionDataType))
|
||||
val actionDiscrete = actionDataType == "string"
|
||||
|
||||
var sortActionMethod = "UDF_SORT_ID";
|
||||
var sortPossibleActionMethod = "UDF_SORT_ARRAY_ID";
|
||||
if (!config.actionDiscrete) {
|
||||
if (!actionDiscrete) {
|
||||
sortActionMethod = "UDF_SORT_MAP";
|
||||
sortPossibleActionMethod = "UDF_SORT_ARRAY_MAP";
|
||||
}
|
||||
|
||||
Helper.validateOrDestroyTrainingTable(sqlContext, config.outputTableName, config.actionDiscrete)
|
||||
MultiStepTimeline.createTrainingTable(sqlContext, config.outputTableName, config.actionDiscrete)
|
||||
MultiStepTimeline.createTrainingTable(sqlContext, config.outputTableName, actionDiscrete)
|
||||
MultiStepTimeline.registerUDFs(sqlContext)
|
||||
|
||||
val sqlCommand = s"""
|
||||
@@ -275,9 +283,11 @@ object MultiStepTimeline {
|
||||
sqlContext.sql(sqlCommand)
|
||||
}
|
||||
|
||||
def createTrainingTable(sqlContext: SQLContext,
|
||||
tableName: String,
|
||||
actionDiscrete: Boolean): Unit = {
|
||||
def createTrainingTable(
|
||||
sqlContext: SQLContext,
|
||||
tableName: String,
|
||||
actionDiscrete: Boolean
|
||||
): Unit = {
|
||||
var actionType = "STRING";
|
||||
var possibleActionType = "ARRAY<STRING>";
|
||||
if (!actionDiscrete) {
|
||||
|
||||
@@ -11,18 +11,22 @@ import org.apache.spark.sql.types._
|
||||
|
||||
case class ExtraFeatureColumn(columnName: String, columnType: String)
|
||||
|
||||
case class TimelineConfiguration(startDs: String,
|
||||
endDs: String,
|
||||
addTerminalStateRow: Boolean,
|
||||
actionDiscrete: Boolean,
|
||||
inputTableName: String,
|
||||
outputTableName: String,
|
||||
evalTableName: String,
|
||||
numOutputShards: Int,
|
||||
outlierEpisodeLengthPercentile: Option[Double] = None,
|
||||
percentileFunction: String = "percentile_approx",
|
||||
extraFeatureColumns: List[String] = List(),
|
||||
timeWindowLimit: Option[Long] = None)
|
||||
case class TimelineConfiguration(
|
||||
startDs: String,
|
||||
endDs: String,
|
||||
addTerminalStateRow: Boolean,
|
||||
actionDiscrete: Boolean,
|
||||
inputTableName: String,
|
||||
outputTableName: String,
|
||||
evalTableName: String,
|
||||
numOutputShards: Int,
|
||||
includePossibleActions: Boolean = true,
|
||||
outlierEpisodeLengthPercentile: Option[Double] = None,
|
||||
percentileFunction: String = "percentile_approx",
|
||||
rewardColumns: List[String] = Constants.DEFAULT_REWARD_COLUMNS,
|
||||
extraFeatureColumns: List[String] = Constants.DEFAULT_EXTRA_FEATURE_COLUMNS,
|
||||
timeWindowLimit: Option[Long] = None
|
||||
)
|
||||
|
||||
/**
|
||||
* Given table of state, action, mdp_id, sequence_number, reward, possible_next_actions
|
||||
@@ -124,19 +128,31 @@ object Timeline {
|
||||
filterTerminal = "";
|
||||
}
|
||||
|
||||
val extraFeatureColumnDataTypes =
|
||||
Helper.getDataTypes(sqlContext, config.inputTableName, config.extraFeatureColumns)
|
||||
val actionDataType =
|
||||
Helper.getDataTypes(sqlContext, config.inputTableName, List("action"))("action")
|
||||
log.info("action column data type:" + s"${actionDataType}")
|
||||
|
||||
Helper.validateOrDestroyTrainingTable(sqlContext,
|
||||
config.outputTableName,
|
||||
config.actionDiscrete,
|
||||
extraFeatureColumnDataTypes)
|
||||
var timelineJoinColumns = config.extraFeatureColumns
|
||||
if (config.includePossibleActions) {
|
||||
timelineJoinColumns = "possible_actions" :: timelineJoinColumns
|
||||
}
|
||||
|
||||
val rewardColumnDataTypes =
|
||||
Helper.getDataTypes(sqlContext, config.inputTableName, config.rewardColumns)
|
||||
log.info("reward columns:" + s"${config.rewardColumns}")
|
||||
log.info("reward column types:" + s"${rewardColumnDataTypes}")
|
||||
|
||||
val timelineJoinColumnDataTypes =
|
||||
Helper.getDataTypes(sqlContext, config.inputTableName, timelineJoinColumns)
|
||||
log.info("timeline join column columns:" + s"${timelineJoinColumns}")
|
||||
log.info("timeline join column types:" + s"${timelineJoinColumnDataTypes}")
|
||||
|
||||
Timeline.createTrainingTable(
|
||||
sqlContext,
|
||||
config.outputTableName,
|
||||
config.actionDiscrete,
|
||||
extraFeatureColumnDataTypes
|
||||
actionDataType,
|
||||
rewardColumnDataTypes,
|
||||
timelineJoinColumnDataTypes
|
||||
)
|
||||
|
||||
config.outlierEpisodeLengthPercentile.foreach { percentile =>
|
||||
@@ -165,7 +181,10 @@ object Timeline {
|
||||
}
|
||||
.getOrElse("WHERE")
|
||||
|
||||
val extraSourceColumns = extraFeatureColumnDataTypes.foldLeft("") {
|
||||
val rewardSourceColumns = rewardColumnDataTypes.foldLeft("") {
|
||||
case (acc, (k, v)) => s"${acc}, a.${k}"
|
||||
}
|
||||
val timelineSourceColumns = timelineJoinColumnDataTypes.foldLeft("") {
|
||||
case (acc, (k, v)) => s"${acc}, a.${k}"
|
||||
}
|
||||
|
||||
@@ -192,13 +211,11 @@ object Timeline {
|
||||
SELECT
|
||||
a.mdp_id,
|
||||
a.state_features,
|
||||
a.action,
|
||||
a.action_probability,
|
||||
a.reward,
|
||||
a.sequence_number,
|
||||
a.possible_actions,
|
||||
a.metrics
|
||||
${extraSourceColumns}
|
||||
a.action
|
||||
${rewardSourceColumns},
|
||||
a.sequence_number
|
||||
${timelineSourceColumns}
|
||||
FROM ${config.inputTableName} a
|
||||
${joinClause}
|
||||
a.ds BETWEEN '${config.startDs}' AND '${config.endDs}'
|
||||
@@ -212,7 +229,10 @@ object Timeline {
|
||||
}
|
||||
.getOrElse("source_table")
|
||||
|
||||
val extraFeatureQuery = extraFeatureColumnDataTypes.foldLeft("") {
|
||||
val rewardColumnsQuery = rewardColumnDataTypes.foldLeft("") {
|
||||
case (acc, (k, v)) => s"${acc}, ${k}"
|
||||
}
|
||||
val timelineJoinColumnsQuery = timelineJoinColumnDataTypes.foldLeft("") {
|
||||
case (acc, (k, v)) =>
|
||||
s"""
|
||||
${acc},
|
||||
@@ -223,7 +243,7 @@ object Timeline {
|
||||
ORDER BY
|
||||
mdp_id,
|
||||
sequence_number
|
||||
) AS next_${k}
|
||||
) AS ${Helper.next_step_col_name(k)}
|
||||
"""
|
||||
}
|
||||
|
||||
@@ -233,15 +253,6 @@ object Timeline {
|
||||
mdp_id,
|
||||
state_features,
|
||||
action,
|
||||
action_probability,
|
||||
reward,
|
||||
LEAD(state_features) OVER (
|
||||
PARTITION BY
|
||||
mdp_id
|
||||
ORDER BY
|
||||
mdp_id,
|
||||
sequence_number
|
||||
) AS next_state_features,
|
||||
LEAD(action) OVER (
|
||||
PARTITION BY
|
||||
mdp_id
|
||||
@@ -249,6 +260,15 @@ object Timeline {
|
||||
mdp_id,
|
||||
sequence_number
|
||||
) AS next_action,
|
||||
action_probability
|
||||
${rewardColumnsQuery},
|
||||
LEAD(state_features) OVER (
|
||||
PARTITION BY
|
||||
mdp_id
|
||||
ORDER BY
|
||||
mdp_id,
|
||||
sequence_number
|
||||
) AS next_state_features,
|
||||
sequence_number,
|
||||
ROW_NUMBER() OVER (
|
||||
PARTITION BY
|
||||
@@ -270,16 +290,8 @@ object Timeline {
|
||||
ORDER BY
|
||||
mdp_id,
|
||||
sequence_number
|
||||
) AS time_since_first,
|
||||
possible_actions,
|
||||
LEAD(possible_actions) OVER (
|
||||
PARTITION BY
|
||||
mdp_id
|
||||
ORDER BY
|
||||
mdp_id,
|
||||
sequence_number
|
||||
) AS possible_next_actions,
|
||||
metrics${extraFeatureQuery}
|
||||
) AS time_since_first
|
||||
${timelineJoinColumnsQuery}
|
||||
FROM ${sourceTableName}
|
||||
${filterTerminal}
|
||||
CLUSTER BY HASH(mdp_id, sequence_number)
|
||||
@@ -290,18 +302,26 @@ object Timeline {
|
||||
log.info("Done with query")
|
||||
|
||||
// Handle nulls in output present when terminal states are present
|
||||
val nextAction = df("next_action")
|
||||
val possibleNextActions = df("possible_next_actions")
|
||||
|
||||
df = df.withColumn("next_state_features", coalesce(df("next_state_features"), Udfs.emptyMap()))
|
||||
if (config.actionDiscrete) {
|
||||
val handle_cols = timelineJoinColumnDataTypes.++(
|
||||
Map(
|
||||
"action" -> actionDataType,
|
||||
"state_features" -> "map<bigint,double>"
|
||||
)
|
||||
)
|
||||
for ((col_name, col_type) <- handle_cols) {
|
||||
val next_col_name = Helper.next_step_col_name(col_name)
|
||||
val empty_placeholder = col_type match {
|
||||
case "string" => Udfs.emptyStr()
|
||||
case "array<string>" => Udfs.emptyArrOfStr()
|
||||
case "map<bigint,double>" => Udfs.emptyMap()
|
||||
case "array<map<bigint,double>>" => Udfs.emptyArrOfMap()
|
||||
case "array<bigint>" => Udfs.emptyArrOfLong()
|
||||
case "map<bigint,array<bigint>>" => Udfs.emptyMapOfIds()
|
||||
case "map<bigint,map<bigint,double>>" => Udfs.emptyMapOfMap()
|
||||
case "map<bigint,array<map<bigint,double>>>" => Udfs.emptyMapOfArrOfMap()
|
||||
}
|
||||
df = df
|
||||
.withColumn("next_action", coalesce(nextAction, Udfs.emptyStr()))
|
||||
.withColumn("possible_next_actions", coalesce(possibleNextActions, Udfs.emptyArrOfStr()))
|
||||
} else {
|
||||
df = df
|
||||
.withColumn("next_action", coalesce(nextAction, Udfs.emptyMap()))
|
||||
.withColumn("possible_next_actions", coalesce(possibleNextActions, Udfs.emptyArrOfMap()))
|
||||
.withColumn(next_col_name, coalesce(df(next_col_name), empty_placeholder))
|
||||
}
|
||||
|
||||
val finalTableName = "finalTable"
|
||||
@@ -344,26 +364,26 @@ object Timeline {
|
||||
val expected_outlier_percent = 1.0 - percentile
|
||||
if (abs(outlier_percent - expected_outlier_percent) / expected_outlier_percent > 0.1) {
|
||||
log.warn(
|
||||
s"Outlier percent mismatch; expected: ${expected_outlier_percent}; got ${outlier_percent}")
|
||||
s"Outlier percent mismatch; expected: ${expected_outlier_percent}; got ${outlier_percent}"
|
||||
)
|
||||
None
|
||||
} else
|
||||
Some(pct_val)
|
||||
}
|
||||
}
|
||||
|
||||
def createTrainingTable(sqlContext: SQLContext,
|
||||
tableName: String,
|
||||
actionDiscrete: Boolean,
|
||||
extraFeatureColumnDataTypes: Map[String, String] = Map()): Unit = {
|
||||
var actionType = "STRING";
|
||||
var possibleActionType = "ARRAY<STRING>";
|
||||
if (!actionDiscrete) {
|
||||
actionType = "MAP<BIGINT, DOUBLE>"
|
||||
possibleActionType = "ARRAY<MAP<BIGINT,DOUBLE>>"
|
||||
def createTrainingTable(
|
||||
sqlContext: SQLContext,
|
||||
tableName: String,
|
||||
actionDataType: String,
|
||||
rewardColumnDataTypes: Map[String, String] = Map("reward" -> "double"),
|
||||
timelineJoinColumnDataTypes: Map[String, String] = Map()
|
||||
): Unit = {
|
||||
val rewardColumns = rewardColumnDataTypes.foldLeft("") {
|
||||
case (acc, (k, v)) => s"${acc}, ${k} ${v}"
|
||||
}
|
||||
|
||||
val extraFeatureColumns = extraFeatureColumnDataTypes.foldLeft("") {
|
||||
case (acc, (k, v)) => s"${acc}, ${k} ${v}, next_${k} ${v}"
|
||||
val timelineJoinColumns = timelineJoinColumnDataTypes.foldLeft("") {
|
||||
case (acc, (k, v)) => s"${acc}, ${k} ${v}, ${Helper.next_step_col_name(k)} ${v}"
|
||||
}
|
||||
|
||||
val sqlCommand = s"""
|
||||
@@ -371,19 +391,17 @@ CREATE TABLE IF NOT EXISTS ${tableName} (
|
||||
mdp_id STRING,
|
||||
state_features MAP < BIGINT,
|
||||
DOUBLE >,
|
||||
action ${actionType},
|
||||
action_probability DOUBLE,
|
||||
reward DOUBLE,
|
||||
action ${actionDataType},
|
||||
next_action ${actionDataType},
|
||||
action_probability DOUBLE
|
||||
${rewardColumns},
|
||||
next_state_features MAP < BIGINT,
|
||||
DOUBLE >,
|
||||
next_action ${actionType},
|
||||
sequence_number BIGINT,
|
||||
sequence_number_ordinal BIGINT,
|
||||
time_diff BIGINT,
|
||||
time_since_first BIGINT,
|
||||
possible_actions ${possibleActionType},
|
||||
possible_next_actions ${possibleActionType},
|
||||
metrics MAP< STRING, DOUBLE>${extraFeatureColumns}
|
||||
time_since_first BIGINT
|
||||
${timelineJoinColumns}
|
||||
) PARTITIONED BY (ds STRING) TBLPROPERTIES ('RETENTION'='30')
|
||||
""".stripMargin
|
||||
sqlContext.sql(sqlCommand);
|
||||
|
||||
@@ -18,7 +18,11 @@ object Udfs {
|
||||
x.dropRight(1)
|
||||
|
||||
val emptyMap = udf(() => Map.empty[Long, Double])
|
||||
val emptyMapOfIds = udf(() => Map.empty[Long, Seq[Long]])
|
||||
val emptyMapOfMap = udf(() => Map.empty[Long, Map[Long, Double]])
|
||||
val emptyMapOfArrOfMap = udf(() => Map.empty[Long, Seq[Map[Long, Double]]])
|
||||
val emptyStr = udf(() => "")
|
||||
val emptyArrOfLong = udf(() => Array.empty[Long])
|
||||
val emptyArrOfStr = udf(() => Array.empty[String])
|
||||
val emptyArrOfDbl = udf(() => Array.empty[Double])
|
||||
val emptyArrOfMap = udf(() => Array.empty[Map[Long, Double]])
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user