Source code for tianshou.policy.imitation.cql

from typing import Any, Dict, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_

from tianshou.data import Batch, ReplayBuffer, to_torch
from tianshou.policy import SACPolicy
from tianshou.utils.net.continuous import ActorProb


class CQLPolicy(SACPolicy):
    """Implementation of CQL algorithm. arXiv:2006.04779.

    :param ActorProb actor: the actor network following the rules in
        :class:`~tianshou.policy.BasePolicy`. (s -> a)
    :param torch.optim.Optimizer actor_optim: the optimizer for actor network.
    :param torch.nn.Module critic1: the first critic network. (s, a -> Q(s, a))
    :param torch.optim.Optimizer critic1_optim: the optimizer for the first
        critic network.
    :param torch.nn.Module critic2: the second critic network. (s, a -> Q(s, a))
    :param torch.optim.Optimizer critic2_optim: the optimizer for the second
        critic network.
    :param float cql_alpha_lr: the learning rate of cql_log_alpha. Default to 1e-4.
    :param float cql_weight: the weight of the conservative (CQL) loss term.
        Default to 1.0.
    :param float tau: parameter for the soft update of the target networks.
        Default to 0.005.
    :param float gamma: discount factor, in [0, 1]. Default to 0.99.
    :param (float, torch.Tensor, torch.optim.Optimizer) or float alpha: entropy
        regularization coefficient. Default to 0.2.
        If a tuple (target_entropy, log_alpha, alpha_optim) is provided, then
        alpha is automatically tuned.
    :param float temperature: the temperature used in the log-sum-exp term.
        Default to 1.0.
    :param bool with_lagrange: whether to automatically tune the CQL weight via
        the Lagrangian dual, i.e. CQL(Lagrange). Default to True.
    :param float lagrange_threshold: the value of tau in CQL(Lagrange).
        Default to 10.0.
    :param float min_action: the minimum value of each dimension of action.
        Default to -1.0.
    :param float max_action: the maximum value of each dimension of action.
        Default to 1.0.
    :param int num_repeat_actions: the number of times the action is repeated
        when calculating the log-sum-exp. Default to 10.
    :param float alpha_min: lower bound for clipping cql_alpha. Default to 0.0.
    :param float alpha_max: upper bound for clipping cql_alpha. Default to 1e6.
    :param float clip_grad: gradient clipping threshold for the critic networks.
        Default to 1.0.
    :param Union[str, torch.device] device: which device to create this model on.
        Default to "cpu".

    .. seealso::

        Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
        explanation.
    """

    def __init__(
        self,
        actor: ActorProb,
        actor_optim: torch.optim.Optimizer,
        critic1: torch.nn.Module,
        critic1_optim: torch.optim.Optimizer,
        critic2: torch.nn.Module,
        critic2_optim: torch.optim.Optimizer,
        cql_alpha_lr: float = 1e-4,
        cql_weight: float = 1.0,
        tau: float = 0.005,
        gamma: float = 0.99,
        alpha: Union[float, Tuple[float, torch.Tensor, torch.optim.Optimizer]] = 0.2,
        temperature: float = 1.0,
        with_lagrange: bool = True,
        lagrange_threshold: float = 10.0,
        min_action: float = -1.0,
        max_action: float = 1.0,
        num_repeat_actions: int = 10,
        alpha_min: float = 0.0,
        alpha_max: float = 1e6,
        clip_grad: float = 1.0,
        device: Union[str, torch.device] = "cpu",
        **kwargs: Any
    ) -> None:
        super().__init__(
            actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
            tau, gamma, alpha, **kwargs
        )
        # SACPolicy already defines _target_entropy, _log_alpha and _alpha_optim.
        self.device = device
        self.temperature = temperature
        self.with_lagrange = with_lagrange
        self.lagrange_threshold = lagrange_threshold

        self.cql_weight = cql_weight

        # Create cql_log_alpha directly on the target device so that the
        # optimizer updates the same leaf tensor that is used in the loss.
        self.cql_log_alpha = torch.tensor([0.0], requires_grad=True, device=device)
        self.cql_alpha_optim = torch.optim.Adam([self.cql_log_alpha], lr=cql_alpha_lr)

        self.min_action = min_action
        self.max_action = max_action

        self.num_repeat_actions = num_repeat_actions

        self.alpha_min = alpha_min
        self.alpha_max = alpha_max
        self.clip_grad = clip_grad
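
For reference, a minimal construction sketch (not part of the module): it assumes the ``Net``, ``ActorProb`` and ``Critic`` helpers from ``tianshou.utils.net``, and the observation/action dimensions, hidden sizes and learning rates are illustrative placeholders rather than tuned settings::

    import torch
    from tianshou.policy import CQLPolicy
    from tianshou.utils.net.common import Net
    from tianshou.utils.net.continuous import ActorProb, Critic

    state_dim, action_dim, device = 17, 6, "cpu"   # placeholder dimensions

    net_a = Net(state_dim, hidden_sizes=[256, 256], device=device)
    actor = ActorProb(
        net_a, action_dim, device=device, unbounded=True, conditioned_sigma=True
    ).to(device)
    actor_optim = torch.optim.Adam(actor.parameters(), lr=1e-4)

    def build_critic():
        # critic takes the concatenated (state, action) as input
        net_c = Net(
            state_dim, action_dim, hidden_sizes=[256, 256],
            concat=True, device=device,
        )
        return Critic(net_c, device=device).to(device)

    critic1, critic2 = build_critic(), build_critic()
    critic1_optim = torch.optim.Adam(critic1.parameters(), lr=3e-4)
    critic2_optim = torch.optim.Adam(critic2.parameters(), lr=3e-4)

    policy = CQLPolicy(
        actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
        cql_weight=1.0, temperature=1.0, with_lagrange=True, device=device,
    )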
    def train(self, mode: bool = True) -> "CQLPolicy":
        """Set the module in training mode, except for the target network."""
        self.training = mode
        self.actor.train(mode)
        self.critic1.train(mode)
        self.critic2.train(mode)
        return self
    def sync_weight(self) -> None:
        """Soft-update the weight for the target network."""
        self.soft_update(self.critic1_old, self.critic1, self.tau)
        self.soft_update(self.critic2_old, self.critic2, self.tau)
    def actor_pred(self, obs: torch.Tensor) -> \
            Tuple[torch.Tensor, torch.Tensor]:
        batch = Batch(obs=obs, info=None)
        obs_result = self(batch)
        return obs_result.act, obs_result.log_prob
    def calc_actor_loss(self, obs: torch.Tensor) -> \
            Tuple[torch.Tensor, torch.Tensor]:
        act_pred, log_pi = self.actor_pred(obs)
        q1 = self.critic1(obs, act_pred)
        q2 = self.critic2(obs, act_pred)
        min_Q = torch.min(q1, q2)
        self._alpha: Union[float, torch.Tensor]
        actor_loss = (self._alpha * log_pi - min_Q).mean()
        # actor_loss.shape: (), log_pi.shape: (batch_size, 1)
        return actor_loss, log_pi
    def calc_pi_values(self, obs_pi: torch.Tensor, obs_to_pred: torch.Tensor) -> \
            Tuple[torch.Tensor, torch.Tensor]:
        act_pred, log_pi = self.actor_pred(obs_pi)

        q1 = self.critic1(obs_to_pred, act_pred)
        q2 = self.critic2(obs_to_pred, act_pred)

        return q1 - log_pi.detach(), q2 - log_pi.detach()
    def calc_random_values(self, obs: torch.Tensor, act: torch.Tensor) -> \
            Tuple[torch.Tensor, torch.Tensor]:
        random_value1 = self.critic1(obs, act)
        random_log_prob1 = np.log(0.5**act.shape[-1])

        random_value2 = self.critic2(obs, act)
        random_log_prob2 = np.log(0.5**act.shape[-1])

        return random_value1 - random_log_prob1, random_value2 - random_log_prob2
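
The constant ``np.log(0.5 ** act.shape[-1])`` above is the log-density of a uniform proposal over ``[-1, 1]^d`` (each of the ``d`` action dimensions contributes a density of 1/2, matching the default ``min_action``/``max_action`` range); subtracting it from the Q-values is the importance-sampling correction CQL uses when estimating the log-sum-exp. A quick sanity check of the constant, with a purely illustrative action dimension::

    import numpy as np

    act_dim = 6                              # illustrative action dimension
    log_density = np.log(0.5 ** act_dim)     # equals act_dim * np.log(0.5)
    assert np.isclose(log_density, act_dim * np.log(0.5))  # roughly -4.159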
    def process_fn(
        self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray
    ) -> Batch:
        return batch
    def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
        batch: Batch = to_torch(  # type: ignore
            batch, dtype=torch.float, device=self.device,
        )
        obs, act, rew, obs_next = batch.obs, batch.act, batch.rew, batch.obs_next
        batch_size = obs.shape[0]

        # compute actor loss and update actor
        actor_loss, log_pi = self.calc_actor_loss(obs)
        self.actor_optim.zero_grad()
        actor_loss.backward()
        self.actor_optim.step()

        # compute alpha loss
        if self._is_auto_alpha:
            log_pi = log_pi + self._target_entropy
            alpha_loss = -(self._log_alpha * log_pi.detach()).mean()
            self._alpha_optim.zero_grad()
            # update log_alpha
            alpha_loss.backward()
            self._alpha_optim.step()
            # update alpha
            self._alpha = self._log_alpha.detach().exp()

        # compute target_Q
        with torch.no_grad():
            act_next, new_log_pi = self.actor_pred(obs_next)

            target_Q1 = self.critic1_old(obs_next, act_next)
            target_Q2 = self.critic2_old(obs_next, act_next)

            target_Q = torch.min(target_Q1, target_Q2) - self._alpha * new_log_pi

            target_Q = \
                rew + self._gamma * (1 - batch.done) * target_Q.flatten()
            # shape: (batch_size,)

        # compute critic loss
        current_Q1 = self.critic1(obs, act).flatten()
        current_Q2 = self.critic2(obs, act).flatten()
        # shape: (batch_size,)

        critic1_loss = F.mse_loss(current_Q1, target_Q)
        critic2_loss = F.mse_loss(current_Q2, target_Q)

        # CQL regularizer: sample uniform random actions and policy actions
        random_actions = torch.FloatTensor(
            batch_size * self.num_repeat_actions, act.shape[-1]
        ).uniform_(self.min_action, self.max_action).to(self.device)

        tmp_obs = obs.unsqueeze(1) \
            .repeat(1, self.num_repeat_actions, 1) \
            .view(batch_size * self.num_repeat_actions, obs.shape[-1])
        tmp_obs_next = obs_next.unsqueeze(1) \
            .repeat(1, self.num_repeat_actions, 1) \
            .view(batch_size * self.num_repeat_actions, obs.shape[-1])
        # tmp_obs & tmp_obs_next: (batch_size * num_repeat, state_dim)

        current_pi_value1, current_pi_value2 = self.calc_pi_values(tmp_obs, tmp_obs)
        next_pi_value1, next_pi_value2 = self.calc_pi_values(tmp_obs_next, tmp_obs)

        random_value1, random_value2 = self.calc_random_values(tmp_obs, random_actions)

        # reshape to (batch_size, num_repeat, 1) so that the log-sum-exp below
        # runs over all sampled actions of each state
        current_pi_value1 = current_pi_value1.reshape(
            batch_size, self.num_repeat_actions, 1)
        current_pi_value2 = current_pi_value2.reshape(
            batch_size, self.num_repeat_actions, 1)
        next_pi_value1 = next_pi_value1.reshape(
            batch_size, self.num_repeat_actions, 1)
        next_pi_value2 = next_pi_value2.reshape(
            batch_size, self.num_repeat_actions, 1)
        random_value1 = random_value1.reshape(
            batch_size, self.num_repeat_actions, 1)
        random_value2 = random_value2.reshape(
            batch_size, self.num_repeat_actions, 1)

        # cat q values
        cat_q1 = torch.cat([random_value1, current_pi_value1, next_pi_value1], 1)
        cat_q2 = torch.cat([random_value2, current_pi_value2, next_pi_value2], 1)
        # shape: (batch_size, 3 * num_repeat, 1)

        cql1_scaled_loss = \
            torch.logsumexp(cat_q1 / self.temperature, dim=1).mean() * \
            self.cql_weight * self.temperature - current_Q1.mean() * \
            self.cql_weight
        cql2_scaled_loss = \
            torch.logsumexp(cat_q2 / self.temperature, dim=1).mean() * \
            self.cql_weight * self.temperature - current_Q2.mean() * \
            self.cql_weight
        # both are scalars

        if self.with_lagrange:
            cql_alpha = torch.clamp(
                self.cql_log_alpha.exp(),
                self.alpha_min,
                self.alpha_max,
            )
            cql1_scaled_loss = \
                cql_alpha * (cql1_scaled_loss - self.lagrange_threshold)
            cql2_scaled_loss = \
                cql_alpha * (cql2_scaled_loss - self.lagrange_threshold)

            self.cql_alpha_optim.zero_grad()
            cql_alpha_loss = -(cql1_scaled_loss + cql2_scaled_loss) * 0.5
            cql_alpha_loss.backward(retain_graph=True)
            self.cql_alpha_optim.step()

        critic1_loss = critic1_loss + cql1_scaled_loss
        critic2_loss = critic2_loss + cql2_scaled_loss

        # update critic
        self.critic1_optim.zero_grad()
        critic1_loss.backward(retain_graph=True)
        # clip grad to prevent exploding gradients (doesn't seem strictly necessary)
        clip_grad_norm_(self.critic1.parameters(), self.clip_grad)
        self.critic1_optim.step()

        self.critic2_optim.zero_grad()
        critic2_loss.backward()
        clip_grad_norm_(self.critic2.parameters(), self.clip_grad)
        self.critic2_optim.step()

        self.sync_weight()

        result = {
            "loss/actor": actor_loss.item(),
            "loss/critic1": critic1_loss.item(),
            "loss/critic2": critic2_loss.item(),
        }
        if self._is_auto_alpha:
            result["loss/alpha"] = alpha_loss.item()
            result["alpha"] = self._alpha.item()  # type: ignore
        if self.with_lagrange:
            result["loss/cql_alpha"] = cql_alpha_loss.item()
            result["cql_alpha"] = cql_alpha.item()
        return result
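
A minimal offline-training sketch, assuming ``policy`` is a ``CQLPolicy`` instance (e.g. built as above) and ``buffer`` is a :class:`~tianshou.data.ReplayBuffer` already filled with the logged dataset (dataset loading and evaluation are omitted); each ``update`` call samples a batch, runs ``process_fn`` and ``learn``, and returns the loss dictionary built above::

    # hypothetical training loop: `policy` and `buffer` are assumed to exist
    for step in range(100_000):
        losses = policy.update(256, buffer)  # sample 256 transitions, run learn()
        if step % 1_000 == 0:
            print(step, losses["loss/actor"], losses["loss/critic1"])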