Files
ReAgent/reagent/preprocessing/batch_preprocessor.py
Zhengxing Chen 0d294b11e5 Back out recent refactor
Summary:
Need more tests before landing the refactor diffs: D22702504 (https://github.com/facebookresearch/ReAgent/commit/1b470c489d19c33beab88b8ea2e79843d4d31f28), D23123762 (https://github.com/facebookresearch/ReAgent/commit/76829287265bc39f879f3bc1d946a1374c5e1141), D23124179 (https://github.com/facebookresearch/ReAgent/commit/b28f84aa013be00194508f52498160592cb37e9d), D23219012 (https://github.com/facebookresearch/ReAgent/commit/e404c5772ea4118105c2eb136ca96ad5ca8e01db)

Back out to a version based on D23155753.

Check our team diff history: https://fburl.com/diffs/ppsgazgj

Reviewed By: kittipatv

Differential Revision: D23270626

fbshipit-source-id: 14653066bb3924a987a54650a51241895b321c8e
2020-08-21 15:59:42 -07:00

156 lines
6.4 KiB
Python

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import Dict
import torch
import torch.nn as nn
import torch.nn.functional as F
from reagent import types as rlt
from reagent.preprocessing.preprocessor import Preprocessor
class BatchPreprocessor(nn.Module):
    """Common parent type for the batch preprocessors in this module.

    Adds no state or behavior on top of ``nn.Module``; concrete subclasses
    supply their own ``__init__`` and ``forward``.
    """
def batch_to_device(batch: Dict[str, torch.Tensor], device: torch.device):
    """Return a new dict with every value of ``batch`` moved to ``device``.

    Args:
        batch: mapping from field name to tensor. The mapping itself is not
            modified.
        device: target device handed to each value's ``.to()`` call.

    Returns:
        A fresh dict with the same keys, in the same order, whose values are
        the results of ``value.to(device)``.
    """
    return {name: tensor.to(device) for name, tensor in batch.items()}
class DiscreteDqnBatchPreprocessor(BatchPreprocessor):
    """Convert a raw dict batch into an ``rlt.DiscreteDqnInput``.

    The batch dict is expected to contain the keys read in ``forward``:
    ``state_features``, ``state_features_presence``, ``next_state_features``,
    ``next_state_features_presence``, ``action``, ``next_action``, ``reward``,
    ``time_diff``, ``step``, ``possible_actions_mask``,
    ``possible_next_actions_mask``, ``mdp_id``, ``sequence_number`` and
    ``action_probability``.
    """

    def __init__(
        self, num_actions: int, state_preprocessor: Preprocessor, use_gpu: bool
    ):
        """Store the configuration used by ``forward``.

        Args:
            num_actions: number of classes used when one-hot encoding actions.
            state_preprocessor: callable applied to (features, presence) pairs
                for both the current and the next state.
            use_gpu: when true, batches are moved to ``torch.device("cuda")``;
                otherwise to ``torch.device("cpu")``.
        """
        super().__init__()
        self.num_actions = num_actions
        self.state_preprocessor = state_preprocessor
        self.device = torch.device("cuda" if use_gpu else "cpu")

    def forward(self, batch: Dict[str, torch.Tensor]) -> rlt.DiscreteDqnInput:
        """Move ``batch`` to the configured device and assemble the DQN input.

        Args:
            batch: mapping from field name to tensor (see the class docstring
                for the keys that are read).

        Returns:
            An ``rlt.DiscreteDqnInput`` built from the preprocessed states,
            one-hot encoded actions and the remaining batch fields.
        """
        on_device = batch_to_device(batch, self.device)
        n_actions = self.num_actions

        def as_column(key: str) -> torch.Tensor:
            # Insert a new dimension at index 1.
            # Assumes the field is a 1-D per-sample tensor -- TODO confirm.
            return on_device[key].unsqueeze(1)

        state_out = self.state_preprocessor(
            on_device["state_features"], on_device["state_features_presence"]
        )
        next_state_out = self.state_preprocessor(
            on_device["next_state_features"],
            on_device["next_state_features_presence"],
        )

        # A transition is treated as non-terminal when the maximum of the
        # next-action mask along dim 1 is non-zero, i.e. at least one next
        # action is possible (assumes a 0/1 mask -- TODO confirm).
        next_mask = on_device["possible_next_actions_mask"]
        mask_max, _ = next_mask.max(dim=1)
        not_terminal = mask_max.float()

        action_one_hot = F.one_hot(on_device["action"].to(torch.int64), n_actions)

        # next_action may carry the value num_actions when no next action is
        # available. Encode with one extra class and drop that last column so
        # such rows come out as all zeros.
        next_action_padded = F.one_hot(
            on_device["next_action"].to(torch.int64), n_actions + 1
        )
        next_action_one_hot = next_action_padded[:, :n_actions]

        extras = rlt.ExtraData(
            mdp_id=as_column("mdp_id"),
            sequence_number=as_column("sequence_number"),
            action_probability=as_column("action_probability"),
        )
        return rlt.DiscreteDqnInput(
            state=rlt.FeatureData(state_out),
            next_state=rlt.FeatureData(next_state_out),
            action=action_one_hot,
            next_action=next_action_one_hot,
            reward=as_column("reward"),
            time_diff=as_column("time_diff"),
            step=as_column("step"),
            not_terminal=not_terminal.unsqueeze(1),
            possible_actions_mask=on_device["possible_actions_mask"],
            possible_next_actions_mask=next_mask,
            extras=extras,
        )
class ParametricDqnBatchPreprocessor(BatchPreprocessor):
    """Convert a raw dict batch into an ``rlt.ParametricDqnInput``.

    The batch dict is expected to contain the keys read in ``forward``:
    ``state_features``, ``state_features_presence``, ``next_state_features``,
    ``next_state_features_presence``, ``action``, ``action_presence``,
    ``next_action``, ``next_action_presence``, ``reward``, ``time_diff``,
    ``step``, ``not_terminal``, ``possible_actions``,
    ``possible_actions_mask``, ``possible_next_actions``,
    ``possible_next_actions_mask``, ``mdp_id``, ``sequence_number`` and
    ``action_probability``.
    """

    def __init__(
        self,
        state_preprocessor: Preprocessor,
        action_preprocessor: Preprocessor,
        use_gpu: bool,
    ):
        """Store the preprocessors and pick the target device.

        Args:
            state_preprocessor: callable applied to (features, presence) pairs
                for the current and next state.
            action_preprocessor: callable applied to (features, presence)
                pairs for the current and next action.
            use_gpu: when true, batches are moved to ``torch.device("cuda")``;
                otherwise to ``torch.device("cpu")``.
        """
        super().__init__()
        self.state_preprocessor = state_preprocessor
        self.action_preprocessor = action_preprocessor
        self.device = torch.device("cuda" if use_gpu else "cpu")

    def forward(self, batch: Dict[str, torch.Tensor]) -> rlt.ParametricDqnInput:
        """Move ``batch`` to the configured device and assemble the DQN input.

        Args:
            batch: mapping from field name to tensor (see the class docstring
                for the keys that are read).

        Returns:
            An ``rlt.ParametricDqnInput`` built from the preprocessed states
            and actions plus the remaining batch fields.
        """
        on_device = batch_to_device(batch, self.device)

        def as_column(key: str) -> torch.Tensor:
            # Insert a new dimension at index 1.
            # Assumes the field is a 1-D per-sample tensor -- TODO confirm.
            return on_device[key].unsqueeze(1)

        # Run the preprocessors first: states, then actions.
        state_out = self.state_preprocessor(
            on_device["state_features"], on_device["state_features_presence"]
        )
        next_state_out = self.state_preprocessor(
            on_device["next_state_features"],
            on_device["next_state_features_presence"],
        )
        action_out = self.action_preprocessor(
            on_device["action"], on_device["action_presence"]
        )
        next_action_out = self.action_preprocessor(
            on_device["next_action"], on_device["next_action_presence"]
        )

        extras = rlt.ExtraData(
            mdp_id=as_column("mdp_id"),
            sequence_number=as_column("sequence_number"),
            action_probability=as_column("action_probability"),
        )
        return rlt.ParametricDqnInput(
            state=rlt.FeatureData(state_out),
            next_state=rlt.FeatureData(next_state_out),
            action=rlt.FeatureData(action_out),
            next_action=rlt.FeatureData(next_action_out),
            reward=as_column("reward"),
            time_diff=as_column("time_diff"),
            step=as_column("step"),
            not_terminal=as_column("not_terminal"),
            possible_actions=on_device["possible_actions"],
            possible_actions_mask=on_device["possible_actions_mask"],
            possible_next_actions=on_device["possible_next_actions"],
            possible_next_actions_mask=on_device["possible_next_actions_mask"],
            extras=extras,
        )
class PolicyNetworkBatchPreprocessor(BatchPreprocessor):
    """Convert a raw dict batch into an ``rlt.PolicyNetworkInput``.

    The batch dict is expected to contain the keys read in ``forward``:
    ``state_features``, ``state_features_presence``, ``next_state_features``,
    ``next_state_features_presence``, ``action``, ``action_presence``,
    ``next_action``, ``next_action_presence``, ``reward``, ``time_diff``,
    ``step``, ``not_terminal``, ``mdp_id``, ``sequence_number`` and
    ``action_probability``.
    """

    def __init__(
        self,
        state_preprocessor: Preprocessor,
        action_preprocessor: Preprocessor,
        use_gpu: bool,
    ):
        """Store the preprocessors and pick the target device.

        Args:
            state_preprocessor: callable applied to (features, presence) pairs
                for the current and next state.
            action_preprocessor: callable applied to (features, presence)
                pairs for the current and next action.
            use_gpu: when true, batches are moved to ``torch.device("cuda")``;
                otherwise to ``torch.device("cpu")``.
        """
        super().__init__()
        self.state_preprocessor = state_preprocessor
        self.action_preprocessor = action_preprocessor
        self.device = torch.device("cuda" if use_gpu else "cpu")

    def forward(self, batch: Dict[str, torch.Tensor]) -> rlt.PolicyNetworkInput:
        """Move ``batch`` to the configured device and assemble the input.

        Args:
            batch: mapping from field name to tensor (see the class docstring
                for the keys that are read).

        Returns:
            An ``rlt.PolicyNetworkInput`` built from the preprocessed states
            and actions plus the remaining batch fields.
        """
        on_device = batch_to_device(batch, self.device)

        def as_column(key: str) -> torch.Tensor:
            # Insert a new dimension at index 1.
            # Assumes the field is a 1-D per-sample tensor -- TODO confirm.
            return on_device[key].unsqueeze(1)

        # Run the preprocessors first: states, then actions.
        state_out = self.state_preprocessor(
            on_device["state_features"], on_device["state_features_presence"]
        )
        next_state_out = self.state_preprocessor(
            on_device["next_state_features"],
            on_device["next_state_features_presence"],
        )
        action_out = self.action_preprocessor(
            on_device["action"], on_device["action_presence"]
        )
        next_action_out = self.action_preprocessor(
            on_device["next_action"], on_device["next_action_presence"]
        )

        extras = rlt.ExtraData(
            mdp_id=as_column("mdp_id"),
            sequence_number=as_column("sequence_number"),
            action_probability=as_column("action_probability"),
        )
        return rlt.PolicyNetworkInput(
            state=rlt.FeatureData(state_out),
            next_state=rlt.FeatureData(next_state_out),
            action=rlt.FeatureData(action_out),
            next_action=rlt.FeatureData(next_action_out),
            reward=as_column("reward"),
            time_diff=as_column("time_diff"),
            step=as_column("step"),
            not_terminal=as_column("not_terminal"),
            extras=extras,
        )