bcq


class BCQPolicy(*, actor_perturbation: Module, actor_perturbation_optim: Optimizer, critic: Module, critic_optim: Optimizer, action_space: Space, vae: VAE, vae_optim: Optimizer, critic2: Module | None = None, critic2_optim: Optimizer | None = None, device: str | device = 'cpu', gamma: float = 0.99, tau: float = 0.005, lmbda: float = 0.75, forward_sampled_times: int = 100, num_sampled_action: int = 10, observation_space: Space | None = None, action_scaling: bool = False, action_bound_method: Literal['clip', 'tanh'] | None = 'clip', lr_scheduler: LRScheduler | MultipleLRSchedulers | None = None)

Implementation of the Batch-Constrained deep Q-learning (BCQ) algorithm. arXiv:1812.02900.

Parameters:
  • actor_perturbation – the perturbation network, which takes a state-action pair and outputs a perturbed action (s, a -> perturbed a).

  • actor_perturbation_optim – the optimizer for the actor perturbation network.

  • critic – the first critic network.

  • critic_optim – the optimizer for the first critic network.

  • critic2 – the second critic network.

  • critic2_optim – the optimizer for the second critic network.

  • vae – the VAE network, which generates actions similar to those in the batch.

  • vae_optim – the optimizer for the VAE network.

  • device – which device to create this model on.

  • gamma – discount factor, in [0, 1].

  • tau – parameter for the soft update of the target networks.

  • lmbda – weighting parameter for clipped double Q-learning: the target mixes the minimum and maximum of the two target critics (see the sketch under learn() below).

  • forward_sampled_times – the number of actions sampled in the forward function. The policy samples this many actions and takes the one with the highest Q-value.

  • num_sampled_action – the number of actions sampled when computing the target Q-value. The algorithm samples these actions with the VAE and perturbs each one before evaluating the target Q.

  • observation_space – Env’s observation space.

  • action_scaling – if True, scale the action from [-1, 1] to the range of action_space. Only used if the action_space is continuous.

  • action_bound_method – method to bound action to range [-1, 1]. Only used if the action_space is continuous.

  • lr_scheduler – if not None, will be called in policy.update().

See also

Please refer to BasePolicy for a more detailed explanation.
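
Below is a minimal construction sketch. Only the BCQPolicy keyword arguments follow the signature documented above; the helper classes (MLP, Net, Perturbation, Critic, VAE from tianshou.utils.net), their constructor arguments, the environment, and all hyperparameters are assumptions modelled on typical Tianshou offline-RL examples and may differ between versions.

   import gymnasium as gym
   import numpy as np
   import torch

   from tianshou.policy import BCQPolicy
   from tianshou.utils.net.common import MLP, Net
   from tianshou.utils.net.continuous import VAE, Critic, Perturbation

   env = gym.make("Pendulum-v1")
   state_dim = int(np.prod(env.observation_space.shape))
   action_dim = int(np.prod(env.action_space.shape))
   max_action = float(env.action_space.high[0])
   device = "cuda" if torch.cuda.is_available() else "cpu"

   # Perturbation network: (s, a) -> perturbed a, bounded by phi * max_action.
   net_a = MLP(input_dim=state_dim + action_dim, output_dim=action_dim,
               hidden_sizes=(256, 256), device=device)
   actor = Perturbation(net_a, max_action=max_action, device=device, phi=0.05).to(device)
   actor_optim = torch.optim.Adam(actor.parameters(), lr=1e-3)

   # Two critics for clipped double Q-learning.
   def make_critic():
       net_c = Net(state_dim, action_dim, hidden_sizes=(256, 256),
                   concat=True, device=device)
       critic = Critic(net_c, device=device).to(device)
       return critic, torch.optim.Adam(critic.parameters(), lr=1e-3)

   critic1, critic1_optim = make_critic()
   critic2, critic2_optim = make_critic()

   # VAE that reconstructs batch actions; BCQ samples candidate actions from its decoder.
   latent_dim = action_dim * 2
   vae = VAE(
       MLP(input_dim=state_dim + action_dim, output_dim=750, device=device),         # encoder
       MLP(input_dim=state_dim + latent_dim, output_dim=action_dim, device=device),  # decoder
       hidden_dim=750, latent_dim=latent_dim, max_action=max_action, device=device,
   ).to(device)
   vae_optim = torch.optim.Adam(vae.parameters(), lr=1e-3)

   policy = BCQPolicy(
       actor_perturbation=actor, actor_perturbation_optim=actor_optim,
       critic=critic1, critic_optim=critic1_optim,
       critic2=critic2, critic2_optim=critic2_optim,
       vae=vae, vae_optim=vae_optim,
       action_space=env.action_space, device=device,
   )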

forward(batch: ObsBatchProtocol, state: dict | BatchProtocol | ndarray | None = None, **kwargs: Any) → ActBatchProtocol

Compute action over the given batch data.
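
Conceptually, the policy repeats each observation forward_sampled_times times, decodes that many candidate actions from the VAE, perturbs them with the perturbation actor, and returns the candidate with the highest Q-value. The standalone sketch below illustrates this for a single observation; vae_decode, perturb and q_value are hypothetical placeholders for the policy's internal networks, not part of the API.

   def bcq_select_action(obs, vae_decode, perturb, q_value, forward_sampled_times=100):
       """Pick the highest-valued perturbed VAE action for one observation."""
       # Give every candidate action its own copy of the observation.
       obs_rep = obs.unsqueeze(0).repeat(forward_sampled_times, 1)
       candidates = vae_decode(obs_rep)           # actions close to the batch data
       candidates = perturb(obs_rep, candidates)  # apply the learned perturbation
       q = q_value(obs_rep, candidates)           # Q(s, a) for every candidate
       return candidates[q.argmax()]              # greedy over the sampled set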

learn(batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) → TBCQTrainingStats

Update policy with a given batch of data.

Returns:

A dataclass object containing the data that needs to be logged (e.g., the losses).

Note

To distinguish between the collecting, updating, and testing states, you can check the policy state via self.training and self.updating. Please refer to States for policy for a more detailed explanation.

Warning

If you use torch.distributions.Normal and torch.distributions.Categorical to calculate the log_prob, please be careful about the shape: the Categorical distribution gives a “[batch_size]” shape while the Normal distribution gives a “[batch_size, 1]” shape. The auto-broadcasting of numerical operations on torch tensors will amplify this error.
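
The core of the update is the target-Q computation, in which gamma, lmbda and num_sampled_action appear. The sketch below is illustrative only and is not the library code; the callables passed in are hypothetical stand-ins for the VAE decoder, the target perturbation actor and the two target critics.

   import torch

   def bcq_target_q(next_obs, reward, done, vae_decode, perturb_targ,
                    critic1_targ, critic2_targ,
                    gamma=0.99, lmbda=0.75, num_sampled_action=10):
       """Illustrative target computation; shapes: next_obs (B, obs_dim), reward/done (B,)."""
       batch_size = next_obs.shape[0]
       # Repeat each next observation and draw perturbed VAE actions for it.
       obs_rep = next_obs.repeat_interleave(num_sampled_action, dim=0)
       act = perturb_targ(obs_rep, vae_decode(obs_rep))
       q1 = critic1_targ(obs_rep, act)
       q2 = critic2_targ(obs_rep, act)
       # Clipped double Q-learning: lmbda weights the pessimistic estimate.
       q = lmbda * torch.min(q1, q2) + (1.0 - lmbda) * torch.max(q1, q2)
       # Per state, keep the best of the num_sampled_action candidates.
       q = q.reshape(batch_size, num_sampled_action).max(dim=1).values
       return reward + gamma * (1.0 - done) * q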

sync_weight() → None

Soft-update the weights of the target networks.
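
The update follows the usual Polyak averaging rule, theta_target <- tau * theta + (1 - tau) * theta_target. A generic sketch of that rule is shown below; sync_weight itself takes no arguments and updates all target networks internally.

   import torch

   @torch.no_grad()
   def soft_update(target: torch.nn.Module, source: torch.nn.Module, tau: float = 0.005) -> None:
       """theta_target <- tau * theta_source + (1 - tau) * theta_target."""
       for tgt, src in zip(target.parameters(), source.parameters()):
           tgt.mul_(1.0 - tau).add_(tau * src)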

train(mode: bool = True) → Self

Set the module in training mode, except for the target network.
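
This mirrors a common pattern: the trainable networks follow the requested mode while the frozen target copies stay in eval mode. The sketch below shows the general override pattern with hypothetical attribute names; it is not the actual implementation.

   import torch.nn as nn

   class PolicyWithTargets(nn.Module):
       """Hypothetical container: `online` is trainable, `target` is a frozen copy."""

       def __init__(self, online: nn.Module, target: nn.Module) -> None:
           super().__init__()
           self.online = online
           self.target = target
           self.target.eval()  # the target network never leaves eval mode

       def train(self, mode: bool = True) -> "PolicyWithTargets":
           # Only the online networks follow the requested mode.
           self.training = mode
           self.online.train(mode)
           return self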

class BCQTrainingStats(*, train_time: float = 0.0, smoothed_loss: dict = <factory>, actor_loss: float, critic1_loss: float, critic2_loss: float, vae_loss: float)

actor_loss: float

critic1_loss: float

critic2_loss: float

vae_loss: float
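
These fields can be read off the stats object returned by an update step. A hedged usage sketch, assuming a policy and an offline replay buffer have already been constructed as in the example near the top of this page:

   # `policy` and `buffer` (a ReplayBuffer of offline data) are assumed to exist.
   stats = policy.update(sample_size=256, buffer=buffer)
   print(
       f"actor: {stats.actor_loss:.3f}  critic1: {stats.critic1_loss:.3f}  "
       f"critic2: {stats.critic2_loss:.3f}  vae: {stats.vae_loss:.3f}"
   )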