redq#


class REDQPolicy(*, actor: Module, actor_optim: Optimizer, critic: Module, critic_optim: Optimizer, action_space: Box, ensemble_size: int = 10, subset_size: int = 2, tau: float = 0.005, gamma: float = 0.99, alpha: float | tuple[float, Tensor, Optimizer] = 0.2, estimation_step: int = 1, actor_delay: int = 20, exploration_noise: BaseNoise | Literal['default'] | None = None, deterministic_eval: bool = True, target_mode: Literal['mean', 'min'] = 'min', action_scaling: bool = True, action_bound_method: Literal['clip'] | None = 'clip', observation_space: Space | None = None, lr_scheduler: LRScheduler | MultipleLRSchedulers | None = None)[source]#

Implementation of REDQ (Randomized Ensembled Double Q-Learning). arXiv:2101.05982.

Parameters:
  • actor – The actor network following the rules in BasePolicy. (s -> model_output)

  • actor_optim – The optimizer for actor network.

  • critic – The critic network. (s, a -> Q(s, a))

  • critic_optim – The optimizer for critic network.

  • action_space – Env’s action space.

  • ensemble_size – Number of sub-networks in the critic ensemble.

  • subset_size – Number of critics randomly sampled from the ensemble when computing the target Q-value (the in-target minimization subset); see the sketch after this parameter list.

  • tau – Parameter controlling the soft (Polyak) update of the target networks.

  • gamma – Discount factor, in [0, 1].

  • alpha – Entropy regularization coefficient. If a tuple (target_entropy, log_alpha, alpha_optim) is provided, then alpha is automatically tuned; see the construction sketch below for an example.

  • exploration_noise – The exploration noise added to the action. Pass 'default' to use GaussianNoise(sigma=0.1); None (the default) disables exploration noise.

  • estimation_step – The number of steps to look ahead (n-step return estimation).

  • actor_delay – Number of critic updates before an actor update.

  • observation_space – Env’s observation space.

  • action_scaling – If True, scale the action from [-1, 1] to the range of action_space. Only used if the action space is continuous.

  • action_bound_method – Method to bound the action to the range [-1, 1]. Only used if the action space is continuous.

  • lr_scheduler – If not None, it will be called in policy.update().
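
The interaction of ensemble_size, subset_size, target_mode, gamma and alpha can be sketched as follows. This is an illustrative reimplementation of REDQ's randomized in-target minimization, not the library's exact code; target_critics, obs_next, act_next, log_prob, rew and done are placeholder inputs.

```python
import torch

def redq_target(target_critics, obs_next, act_next, log_prob,
                rew, done, gamma=0.99, alpha=0.2,
                subset_size=2, target_mode="min"):
    # Randomly choose `subset_size` critics out of the ensemble
    # (`target_critics` is assumed to hold `ensemble_size` target Q-networks).
    idx = torch.randperm(len(target_critics))[:subset_size]
    qs = torch.stack([target_critics[i](obs_next, act_next) for i in idx])
    # Aggregate the subset: pessimistic "min" (default) or "mean".
    q = qs.min(dim=0).values if target_mode == "min" else qs.mean(dim=0)
    # SAC-style entropy-regularized bootstrap target.
    return rew + gamma * (1.0 - done) * (q - alpha * log_prob)
```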

See also

Please refer to BasePolicy for more detailed explanation.
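
A minimal construction sketch based on the signature above; build_actor, build_ensemble_critic and env are hypothetical placeholders, and only the keyword arguments come from this page.

```python
import numpy as np
import torch
from tianshou.policy import REDQPolicy

# Hypothetical helpers: build_actor() returns an (s -> model_output) network,
# build_ensemble_critic() an (s, a -> Q(s, a)) ensemble with `ensemble_size`
# sub-networks, and `env` is a gymnasium environment.
actor = build_actor()
critic = build_ensemble_critic(ensemble_size=10)

# Optional automatic entropy tuning: pass a
# (target_entropy, log_alpha, alpha_optim) tuple as `alpha`.
target_entropy = -float(np.prod(env.action_space.shape))
log_alpha = torch.zeros(1, requires_grad=True)
alpha_optim = torch.optim.Adam([log_alpha], lr=3e-4)

policy = REDQPolicy(
    actor=actor,
    actor_optim=torch.optim.Adam(actor.parameters(), lr=1e-3),
    critic=critic,
    critic_optim=torch.optim.Adam(critic.parameters(), lr=1e-3),
    action_space=env.action_space,
    ensemble_size=10,
    subset_size=2,
    tau=0.005,
    gamma=0.99,
    alpha=(target_entropy, log_alpha, alpha_optim),
    actor_delay=20,
    target_mode="min",
)
```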

forward(batch: ObsBatchProtocol, state: dict | Batch | ndarray | None = None, **kwargs: Any) Batch[source]#

Compute action over the given batch data.

Returns:

A Batch which has 2 keys:

  • act – the action.

  • state – the hidden state.

See also

Please refer to forward() for more detailed explanation.
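
A hedged usage sketch: wrap raw observations in a tianshou Batch and read act (and state) from the result; obs_dim and the surrounding setup are placeholders.

```python
import numpy as np
from tianshou.data import Batch

# obs_dim is a placeholder for the environment's observation dimension.
obs = np.zeros((1, obs_dim), dtype=np.float32)
result = policy(Batch(obs=obs, info={}))  # calls forward() under the hood
action = result.act    # the action
hidden = result.state  # the hidden state (typically None for feed-forward actors)
```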

property is_auto_alpha: bool#
learn(batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) TREDQTrainingStats[source]#

Update policy with a given batch of data.

Returns:

A dataclass object, including the data needed to be logged (e.g., loss).

Note

In order to distinguish the collecting state, updating state and testing state, you can check the policy state via self.training and self.updating. Please refer to States for policy for more detailed explanation.

Warning

If you use torch.distributions.Normal and torch.distributions.Categorical to calculate the log_prob, please be careful about the shape: the Categorical distribution gives a “[batch_size]” shape while the Normal distribution gives a “[batch_size, 1]” shape. The auto-broadcasting of numerical operations on torch tensors will amplify this error; see the sketch below.
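
A small illustration of the shape pitfall above, using plain torch.distributions (independent of this class):

```python
import torch
from torch.distributions import Categorical, Normal

batch_size = 4
# Categorical over 3 classes: log_prob of a [batch_size] action -> [batch_size].
cat = Categorical(probs=torch.ones(batch_size, 3) / 3)
print(cat.log_prob(torch.zeros(batch_size, dtype=torch.long)).shape)  # torch.Size([4])

# Normal with [batch_size, 1] parameters: log_prob -> [batch_size, 1].
norm = Normal(torch.zeros(batch_size, 1), torch.ones(batch_size, 1))
print(norm.log_prob(torch.zeros(batch_size, 1)).shape)  # torch.Size([4, 1])

# Mixing the two broadcasts [4] against [4, 1] into [4, 4] -- a silent bug.
```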

sync_weight() None[source]#

Soft-update the weights of the target networks; see the sketch below.
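
The soft update controlled by tau can be sketched as follows (an illustration of the Polyak-averaging rule, not necessarily the library's exact code):

```python
import torch

def soft_update(target: torch.nn.Module, source: torch.nn.Module, tau: float = 0.005) -> None:
    """Polyak-average source parameters into the target network in place."""
    with torch.no_grad():
        for tgt_param, src_param in zip(target.parameters(), source.parameters()):
            tgt_param.mul_(1.0 - tau).add_(tau * src_param)
```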

class REDQTrainingStats(alpha: float | None = None, alpha_loss: float | None = None, *, train_time: float = 0.0, smoothed_loss: dict = <factory>, actor_loss: float, critic_loss: float)[source]#

A data structure for storing loss statistics of the REDQ learn step.

alpha: float | None = None#
alpha_loss: float | None = None#
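
A hedged sketch of consuming these statistics after an update; buffer is a hypothetical, pre-filled replay buffer and the exact set of logged fields may vary across versions:

```python
# `buffer` is a hypothetical, pre-filled tianshou ReplayBuffer.
stats = policy.update(sample_size=256, buffer=buffer)  # calls learn() internally
print(stats.actor_loss, stats.critic_loss)
if stats.alpha is not None:  # set when alpha is auto-tuned
    print(stats.alpha, stats.alpha_loss)
```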