# Source code for tianshou.data.types

from typing import Protocol

import numpy as np
import torch

from tianshou.data import Batch
from tianshou.data.batch import BatchProtocol, arr_type


class ObsBatchProtocol(BatchProtocol, Protocol):
    """Observations of an environment that a policy can turn into actions.

    Typically used inside a policy's forward method.
    """

    # Observation(s); a nested BatchProtocol for dict-like observation spaces.
    obs: arr_type | BatchProtocol
    # Per-step info returned by the environment.
    info: arr_type
class RolloutBatchProtocol(ObsBatchProtocol, Protocol):
    """Typically, the outcome of sampling from a replay buffer."""

    # Observation following `obs` after taking `act`.
    obs_next: arr_type | BatchProtocol
    # Action(s) taken.
    act: arr_type
    # Reward(s) received; always a numpy array here.
    rew: np.ndarray
    # Episode-end flags, split per the gymnasium convention.
    terminated: arr_type
    truncated: arr_type
class BatchWithReturnsProtocol(RolloutBatchProtocol, Protocol):
    """With added returns, usually computed with GAE."""

    # Estimated returns for each transition.
    returns: arr_type
class PrioBatchProtocol(RolloutBatchProtocol, Protocol):
    """Contains weights that can be used for prioritized replay."""

    # Importance-sampling weights for prioritized replay.
    weight: np.ndarray | torch.Tensor
class RecurrentStateBatch(BatchProtocol, Protocol):
    """Used by RNNs in policies, contains `hidden` and `cell` fields."""

    # Hidden state of the recurrent network.
    hidden: torch.Tensor
    # Cell state (e.g., for LSTMs).
    cell: torch.Tensor
class ActBatchProtocol(BatchProtocol, Protocol):
    """Simplest batch, just containing the action. Useful e.g., for random policy."""

    # Action(s) selected by the policy.
    act: arr_type
class ActStateBatchProtocol(ActBatchProtocol, Protocol):
    """Contains action and state (which can be None), useful for policies that can support RNNs."""

    # Hidden state carried across forward calls; None for stateless policies.
    state: dict | BatchProtocol | np.ndarray | None
class ModelOutputBatchProtocol(ActStateBatchProtocol, Protocol):
    """In addition to state and action, contains model output: (logits)."""

    # Raw model output before any action-distribution transformation.
    logits: torch.Tensor
class FQFBatchProtocol(ModelOutputBatchProtocol, Protocol):
    """Model outputs, fractions and quantiles_tau - specific to the FQF model."""

    # Quantile fractions proposed by the fraction network.
    fractions: torch.Tensor
    # Quantile values evaluated at the tau midpoints.
    quantiles_tau: torch.Tensor
class BatchWithAdvantagesProtocol(BatchWithReturnsProtocol, Protocol):
    """Contains estimated advantages and values.

    Returns are usually computed from GAE of advantages by adding the value.
    """

    # Advantage estimates (e.g., from GAE).
    adv: torch.Tensor
    # Value function estimates V(s).
    v_s: torch.Tensor
class DistBatchProtocol(ModelOutputBatchProtocol, Protocol):
    """Contains dist instances for actions (created by dist_fn).

    Usually categorical or normal.
    """

    # Action distribution produced from the model output.
    dist: torch.distributions.Distribution
class DistLogProbBatchProtocol(DistBatchProtocol, Protocol):
    """Contains dist objects that can be sampled from and log_prob of taken action."""

    # Log-probability of the taken action under `dist`.
    log_prob: torch.Tensor
class LogpOldProtocol(BatchWithAdvantagesProtocol, Protocol):
    """Contains logp_old, often needed for importance weights, in particular in PPO.

    Builds on batches that contain advantages and values.
    """

    # Log-probabilities under the behavior (old) policy.
    logp_old: torch.Tensor
class QuantileRegressionBatchProtocol(ModelOutputBatchProtocol, Protocol):
    """Contains taus for algorithms using quantile regression.

    See e.g. https://arxiv.org/abs/1806.06923
    """

    # Quantile fraction samples (taus).
    taus: torch.Tensor
class ImitationBatchProtocol(ActBatchProtocol, Protocol):
    """Similar to other batches, but contains imitation_logits and q_value fields."""

    # NOTE(review): annotated with the concrete `Batch` here, while the parallel
    # ActStateBatchProtocol.state uses `BatchProtocol` — possibly an oversight
    # upstream; kept as-is to preserve the declared typing.
    state: dict | Batch | np.ndarray | None
    # Q-value estimates for the actions.
    q_value: torch.Tensor
    # Logits from the imitation head.
    imitation_logits: torch.Tensor