Source code for tianshou.policy.imitation.td3_bc

from dataclasses import dataclass
from typing import Any, Literal, TypeVar

import gymnasium as gym
import torch
import torch.nn.functional as F

from tianshou.data import to_torch_as
from tianshou.data.types import RolloutBatchProtocol
from tianshou.exploration import BaseNoise, GaussianNoise
from tianshou.policy import TD3Policy
from tianshou.policy.base import TLearningRateScheduler
from tianshou.policy.modelfree.td3 import TD3TrainingStats


@dataclass(kw_only=True)
class TD3BCTrainingStats(TD3TrainingStats):
    """Training statistics reported by :class:`TD3BCPolicy`.

    Declares no fields of its own: every field is inherited from
    :class:`TD3TrainingStats`. It exists so TD3+BC results carry a distinct type.
    """


# Type variable bound to TD3BCTrainingStats; TD3BCPolicy is generic over it so that
# subclasses can narrow the stats type returned by ``learn``.
TTD3BCTrainingStats = TypeVar("TTD3BCTrainingStats", bound=TD3BCTrainingStats)
class TD3BCPolicy(TD3Policy[TTD3BCTrainingStats]):
    """Implementation of TD3+BC. arXiv:2106.06860.

    :param actor: the actor network following the rules in
        :class:`~tianshou.policy.BasePolicy`. (s -> actions)
    :param actor_optim: the optimizer for actor network.
    :param critic: the first critic network. (s, a -> Q(s, a))
    :param critic_optim: the optimizer for the first critic network.
    :param action_space: Env's action space. Should be gym.spaces.Box.
    :param critic2: the second critic network. (s, a -> Q(s, a)).
        If None, use the same network as critic (via deepcopy).
    :param critic2_optim: the optimizer for the second critic network.
        If None, clone critic_optim to use for critic2.parameters().
    :param tau: param for soft update of the target network.
    :param gamma: discount factor, in [0, 1].
    :param exploration_noise: add noise to action for exploration.
        This is useful when solving "hard exploration" problems.
        "default" (the default) is equivalent to GaussianNoise(sigma=0.1); a new
        noise object is created for every policy instance.
    :param policy_noise: the noise used in updating policy network.
    :param update_actor_freq: the update frequency of actor network.
    :param noise_clip: the clipping range used in updating policy network.
    :param alpha: the value of alpha, which controls the weight for TD3 learning
        relative to behavior cloning.
    :param estimation_step: forwarded unchanged to :class:`TD3Policy`
        (presumably the n-step look-ahead used for the critic target; see TD3Policy).
    :param observation_space: Env's observation space.
    :param action_scaling: if True, scale the action from [-1, 1] to the range
        of action_space. Only used if the action_space is continuous.
    :param action_bound_method: method to bound action to range [-1, 1].
        Only used if the action_space is continuous.
    :param lr_scheduler: a learning rate scheduler that adjusts the learning rate
        in optimizer in each policy.update()

    .. seealso::

        Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
        explanation.
    """

    def __init__(
        self,
        *,
        actor: torch.nn.Module,
        actor_optim: torch.optim.Optimizer,
        critic: torch.nn.Module,
        critic_optim: torch.optim.Optimizer,
        action_space: gym.Space,
        critic2: torch.nn.Module | None = None,
        critic2_optim: torch.optim.Optimizer | None = None,
        tau: float = 0.005,
        gamma: float = 0.99,
        exploration_noise: BaseNoise | Literal["default"] | None = "default",
        policy_noise: float = 0.2,
        update_actor_freq: int = 2,
        noise_clip: float = 0.5,
        # TODO: same name as alpha in SAC and REDQ, which also inherit from DDPGPolicy. Rename?
        alpha: float = 2.5,
        estimation_step: int = 1,
        observation_space: gym.Space | None = None,
        action_scaling: bool = True,
        action_bound_method: Literal["clip"] | None = "clip",
        lr_scheduler: TLearningRateScheduler | None = None,
    ) -> None:
        if exploration_noise == "default":
            # Build the noise per instance. A GaussianNoise default in the signature
            # would be created once at definition time and shared by every policy.
            exploration_noise = GaussianNoise(sigma=0.1)
        super().__init__(
            actor=actor,
            actor_optim=actor_optim,
            critic=critic,
            critic_optim=critic_optim,
            action_space=action_space,
            critic2=critic2,
            critic2_optim=critic2_optim,
            tau=tau,
            gamma=gamma,
            exploration_noise=exploration_noise,
            policy_noise=policy_noise,
            noise_clip=noise_clip,
            update_actor_freq=update_actor_freq,
            estimation_step=estimation_step,
            action_scaling=action_scaling,
            action_bound_method=action_bound_method,
            observation_space=observation_space,
            lr_scheduler=lr_scheduler,
        )
        # Weight of the TD3 (Q-maximization) term relative to the BC term in learn().
        self.alpha = alpha

    def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TTD3BCTrainingStats:  # type: ignore
        """Perform one TD3+BC update step on ``batch``.

        Every call updates both critics. Only every ``update_actor_freq``-th call
        does the following:

        - updates the actor with ``-lambda * Q(s, pi(s)).mean()`` plus the MSE
          between ``pi(s)`` and the batch actions;
        - calls ``sync_weight()``.

        ``actor_loss`` in the returned stats is the most recently computed actor
        loss, which may come from an earlier call.
        """
        # critic 1&2
        td1, critic1_loss = self._mse_optimizer(batch, self.critic, self.critic_optim)
        td2, critic2_loss = self._mse_optimizer(batch, self.critic2, self.critic2_optim)
        batch.weight = (td1 + td2) / 2.0  # prio-buffer
        # actor (delayed update)
        if self._cnt % self.update_actor_freq == 0:
            act = self(batch, eps=0.0).act
            q_value = self.critic(batch.obs, act)
            # Scale-normalize the Q term by its mean magnitude. The value is detached,
            # so lambda acts as a constant weight and receives no gradient.
            lmbda = self.alpha / q_value.abs().mean().detach()
            # TD3 term plus behavior cloning towards the dataset actions.
            actor_loss = -lmbda * q_value.mean() + F.mse_loss(act, to_torch_as(batch.act, act))
            self.actor_optim.zero_grad()
            actor_loss.backward()
            self._last = actor_loss.item()
            self.actor_optim.step()
            self.sync_weight()
        self._cnt += 1

        return TD3BCTrainingStats(  # type: ignore[return-value]
            actor_loss=self._last,
            critic1_loss=critic1_loss.item(),
            critic2_loss=critic2_loss.item(),
        )