types#


class ActBatchProtocol(*args, **kwargs)[source]#

Simplest batch, containing only the action. Useful, e.g., for a random policy.

act: Tensor | ndarray#
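
A minimal sketch of how such a protocol is typically used (the import path tianshou.data.types is assumed here): a plain Batch that carries only an act field can be cast to ActBatchProtocol, purely for the benefit of static type checkers.

from typing import cast

import numpy as np
from tianshou.data import Batch
from tianshou.data.types import ActBatchProtocol  # assumed import path

# Build a plain Batch holding only actions; the cast changes nothing at
# runtime, it only tells the type checker which fields are guaranteed.
act_batch = cast(ActBatchProtocol, Batch(act=np.array([0, 2, 1])))
print(act_batch.act)
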
class ActStateBatchProtocol(*args, **kwargs)[source]#

Contains the action and the state (which can be None); useful for policies that support RNNs.

state: dict | BatchProtocol | ndarray | None#
class BatchWithAdvantagesProtocol(*args, **kwargs)[source]#

Contains estimated advantages and values.

Returns are usually computed by adding the value estimates to the GAE advantages.

adv: Tensor#
v_s: Tensor#
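
A small sketch of the relation above (hypothetical tensors; the import path tianshou.data.types is assumed): returns can be recovered from GAE advantages by adding the value estimates.

from typing import cast

import torch
from tianshou.data import Batch
from tianshou.data.types import BatchWithAdvantagesProtocol  # assumed import path

adv = torch.tensor([0.5, -0.2, 1.0])  # hypothetical GAE advantages
v_s = torch.tensor([1.0, 2.0, 0.5])   # hypothetical value estimates
batch = cast(BatchWithAdvantagesProtocol, Batch(adv=adv, v_s=v_s))

# Returns are usually the GAE advantages plus the value estimates.
returns = batch.adv + batch.v_s
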
class BatchWithReturnsProtocol(*args, **kwargs)[source]#

Adds a returns field, usually computed with GAE.

returns: Tensor | ndarray#
class DistBatchProtocol(*args, **kwargs)[source]#

Contains dist instances for actions (created by dist_fn).

Usually categorical or normal.

dist: Distribution#
class DistLogProbBatchProtocol(*args, **kwargs)[source]#

Contains dist objects that can be sampled from, together with the log_prob of the taken action.

log_prob: Tensor#
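
A sketch of how dist and log_prob relate, using a Categorical distribution purely for illustration (the import path tianshou.data.types is assumed, and storing a Distribution inside a Batch is assumed to be supported by the installed version):

from typing import cast

import torch
from torch.distributions import Categorical
from tianshou.data import Batch
from tianshou.data.types import DistLogProbBatchProtocol  # assumed import path

# A categorical action distribution over 4 discrete actions, batch size 2.
dist = Categorical(logits=torch.randn(2, 4))
act = dist.sample()  # sample actions from the distribution
batch = cast(
    DistLogProbBatchProtocol,
    Batch(dist=dist, log_prob=dist.log_prob(act)),  # assumes Batch accepts a Distribution
)
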
class FQFBatchProtocol(*args, **kwargs)[source]#

Model outputs, fractions, and quantiles_tau, specific to the FQF algorithm.

fractions: Tensor#
quantiles_tau: Tensor#
class ImitationBatchProtocol(*args, **kwargs)[source]#

Similar to other batches, but contains imitation_logits and q_value fields.

imitation_logits: Tensor#
q_value: Tensor#
state: dict | Batch | ndarray | None#
class LogpOldProtocol(*args, **kwargs)[source]#

Contains logp_old, often needed for importance weights, in particular in PPO.

Builds on batches that contain advantages and values.

logp_old: Tensor#
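
A sketch of why logp_old is needed (hypothetical tensors, not tied to any particular Tianshou API): in PPO-style updates, the importance ratio is the exponentiated difference between the current and the old log-probabilities of the taken actions.

import torch

# Hypothetical log-probabilities under the current policy and under the
# policy that collected the data (logp_old), plus hypothetical advantages.
logp = torch.tensor([-0.9, -1.2, -0.3])
logp_old = torch.tensor([-1.0, -1.0, -0.5])
adv = torch.tensor([0.5, -0.2, 1.0])

# PPO-style importance ratio and clipped surrogate objective.
ratio = (logp - logp_old).exp()
clipped = torch.clamp(ratio, 1 - 0.2, 1 + 0.2)
surrogate = torch.min(ratio * adv, clipped * adv).mean()
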
class ModelOutputBatchProtocol(*args, **kwargs)[source]#

In addition to state and action, contains the model output (logits).

logits: Tensor#
class ObsBatchProtocol(*args, **kwargs)[source]#

Observations of an environment that a policy can turn into actions.

Typically used inside a policy's forward method.

info: Tensor | ndarray#
obs: Tensor | ndarray | BatchProtocol#
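
A sketch of the kind of batch a policy's forward receives (hypothetical data; the import path tianshou.data.types is assumed):

from typing import cast

import numpy as np
from tianshou.data import Batch
from tianshou.data.types import ObsBatchProtocol  # assumed import path

# Two hypothetical observations plus per-step info, as forward would see them.
obs_batch = cast(
    ObsBatchProtocol,
    Batch(obs=np.zeros((2, 4), dtype=np.float32), info=np.array([{}, {}])),
)
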
class PrioBatchProtocol(*args, **kwargs)[source]#

Contains weights that can be used for prioritized replay.

weight: ndarray | Tensor#
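
A sketch of how the weight field is typically consumed (hypothetical tensors): the importance-sampling weights from prioritized replay scale the per-sample loss.

import torch

# Hypothetical per-sample importance-sampling weights and TD errors.
weight = torch.tensor([0.7, 1.0, 0.4])
td_error = torch.tensor([0.1, -0.5, 0.3])

# Weighted mean-squared TD loss, as used together with prioritized replay.
loss = (weight * td_error.pow(2)).mean()
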
class QuantileRegressionBatchProtocol(*args, **kwargs)[source]#

Contains taus for algorithms using quantile regression.

See e.g. https://arxiv.org/abs/1806.06923

taus: Tensor#
class RecurrentStateBatch(*args, **kwargs)[source]#

Used by RNNs in policies; contains hidden and cell fields.

cell: Tensor#
hidden: Tensor#
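
A sketch of the hidden/cell pair for an LSTM-backed policy (shapes are illustrative; the import path tianshou.data.types is assumed):

from typing import cast

import torch
from tianshou.data import Batch
from tianshou.data.types import RecurrentStateBatch  # assumed import path

# Hypothetical LSTM state: 1 layer, batch size 2, hidden size 8.
state = cast(
    RecurrentStateBatch,
    Batch(hidden=torch.zeros(1, 2, 8), cell=torch.zeros(1, 2, 8)),
)
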
class RolloutBatchProtocol(*args, **kwargs)[source]#

Typically, the outcome of sampling from a replay buffer.

act: Tensor | ndarray#
obs_next: Tensor | ndarray | BatchProtocol#
rew: ndarray#
terminated: Tensor | ndarray#
truncated: Tensor | ndarray#
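
A sketch of a rollout batch of the shape a replay-buffer sample has (hypothetical data; the import path tianshou.data.types is assumed):

from typing import cast

import numpy as np
from tianshou.data import Batch
from tianshou.data.types import RolloutBatchProtocol  # assumed import path

# Two hypothetical transitions with the fields listed above.
rollout = cast(
    RolloutBatchProtocol,
    Batch(
        obs=np.zeros((2, 4), dtype=np.float32),
        act=np.array([0, 1]),
        rew=np.array([1.0, 0.0]),
        terminated=np.array([False, True]),
        truncated=np.array([False, False]),
        obs_next=np.zeros((2, 4), dtype=np.float32),
        info=np.array([{}, {}]),
    ),
)
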