Source code for tianshou.policy.imitation.bcq

import copy
from typing import Any, Dict, Optional, Union

import numpy as np
import torch
import torch.nn.functional as F

from tianshou.data import Batch, to_torch
from tianshou.policy import BasePolicy
from tianshou.utils.net.continuous import VAE


class BCQPolicy(BasePolicy):
    """Implementation of BCQ algorithm. arXiv:1812.02900.

    :param Perturbation actor: the actor perturbation. (s, a -> perturbed a)
    :param torch.optim.Optimizer actor_optim: the optimizer for actor network.
    :param torch.nn.Module critic1: the first critic network. (s, a -> Q(s, a))
    :param torch.optim.Optimizer critic1_optim: the optimizer for the first
        critic network.
    :param torch.nn.Module critic2: the second critic network. (s, a -> Q(s, a))
    :param torch.optim.Optimizer critic2_optim: the optimizer for the second
        critic network.
    :param VAE vae: the VAE network, generating actions similar to those in batch.
        (s, a -> generated a)
    :param torch.optim.Optimizer vae_optim: the optimizer for the VAE network.
    :param Union[str, torch.device] device: which device to create this model on.
        Default to "cpu".
    :param float gamma: discount factor, in [0, 1]. Default to 0.99.
    :param float tau: param for soft update of the target network.
        Default to 0.005.
    :param float lmbda: param for Clipped Double Q-learning. Default to 0.75.
    :param int forward_sampled_times: the number of sampled actions in forward
        function. The policy samples many actions and takes the action with the
        max value. Default to 100.
    :param int num_sampled_action: the number of sampled actions in calculating
        target Q. The algorithm samples several actions using VAE, and perturbs
        each action to get the target Q. Default to 10.

    .. seealso::

        Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
        explanation.
    """

    def __init__(
        self,
        actor: torch.nn.Module,
        actor_optim: torch.optim.Optimizer,
        critic1: torch.nn.Module,
        critic1_optim: torch.optim.Optimizer,
        critic2: torch.nn.Module,
        critic2_optim: torch.optim.Optimizer,
        vae: VAE,
        vae_optim: torch.optim.Optimizer,
        device: Union[str, torch.device] = "cpu",
        gamma: float = 0.99,
        tau: float = 0.005,
        lmbda: float = 0.75,
        forward_sampled_times: int = 100,
        num_sampled_action: int = 10,
        **kwargs: Any
    ) -> None:
        # actor is Perturbation!
        super().__init__(**kwargs)

        # Each learned network (except the VAE) gets a deep-copied target twin
        # that is only ever changed through ``sync_weight``.
        self.actor = actor
        self.actor_target = copy.deepcopy(self.actor)
        self.actor_optim = actor_optim

        self.critic1 = critic1
        self.critic1_target = copy.deepcopy(self.critic1)
        self.critic1_optim = critic1_optim

        self.critic2 = critic2
        self.critic2_target = copy.deepcopy(self.critic2)
        self.critic2_optim = critic2_optim

        self.vae = vae
        self.vae_optim = vae_optim

        self.gamma = gamma
        self.tau = tau
        self.lmbda = lmbda
        self.device = device
        self.forward_sampled_times = forward_sampled_times
        self.num_sampled_action = num_sampled_action

    def train(self, mode: bool = True) -> "BCQPolicy":
        """Set the module in training mode, except for the target network.

        Only the actor and the two critics are switched; the target copies are
        left in whatever mode they already had.

        :param bool mode: ``True`` for training mode, ``False`` for evaluation.
        :return: the policy itself, so that calls can be chained.
        """
        self.training = mode
        self.actor.train(mode)
        self.critic1.train(mode)
        self.critic2.train(mode)
        return self

    def forward(
        self,
        batch: Batch,
        state: Optional[Union[dict, Batch, np.ndarray]] = None,
        **kwargs: Any,
    ) -> Batch:
        """Compute action over the given batch data.

        For every observation, ``forward_sampled_times`` candidate actions are
        decoded by the VAE and perturbed by the actor; the candidate with the
        highest ``critic1`` value is returned.

        :param Batch batch: input data; only ``batch.obs`` is read.
        :param state: unused, kept for interface compatibility.
        :return: a Batch whose ``act`` is a numpy array with one action per
            observation.
        """
        # There is "obs" in the Batch
        # obs_group: several groups. Each group has a state.
        obs_group: torch.Tensor = to_torch(  # type: ignore
            batch.obs, device=self.device
        )
        act_group = []
        # The result leaves this method as a numpy array, so no gradient can
        # ever flow through it; skip building the autograd graph altogether.
        with torch.no_grad():
            for obs in obs_group:
                # now obs is (state_dim)
                obs = obs.reshape(1, -1).repeat(self.forward_sampled_times, 1)
                # now obs is (forward_sampled_times, state_dim)
                # decode(obs) generates action and actor perturbs it
                act = self.actor(obs, self.vae.decode(obs))
                # now action is (forward_sampled_times, action_dim)
                q1 = self.critic1(obs, act)
                # q1 is (forward_sampled_times, 1)
                max_indice = q1.argmax(0)
                act_group.append(
                    act[max_indice].detach().cpu().numpy().flatten()
                )
        return Batch(act=np.array(act_group))

    def sync_weight(self) -> None:
        """Soft-update the weight for the target network."""
        self.soft_update(self.critic1_target, self.critic1, self.tau)
        self.soft_update(self.critic2_target, self.critic2, self.tau)
        self.soft_update(self.actor_target, self.actor, self.tau)

    def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
        """Run one update of the VAE, both critics and the perturbation actor.

        The steps are, in order: a VAE step (reconstruction + KL loss), a step
        for both critics against a shared Clipped Double Q target, an actor step
        that maximises ``critic1`` on perturbed VAE samples, and finally a soft
        update of all target networks.

        :param Batch batch: transitions with ``obs``, ``act``, ``rew``, ``done``
            and ``obs_next``; converted to float tensors on ``self.device``.
        :return: a dict with the scalar losses under the keys ``loss/actor``,
            ``loss/critic1``, ``loss/critic2`` and ``loss/vae``.
        """
        # batch: obs, act, rew, done, obs_next. (numpy array)
        # (batch_size, state_dim)
        batch = to_torch(  # type: ignore
            batch, dtype=torch.float, device=self.device
        )
        obs, act = batch.obs, batch.act
        batch_size = obs.shape[0]

        # ---- VAE training ----
        # mean, std: (state.shape[0], latent_dim)
        recon, mean, std = self.vae(obs, act)
        recon_loss = F.mse_loss(act, recon)
        # (....) is D_KL( N(mu, sigma) || N(0,1) )
        KL_loss = (-torch.log(std) + (std.pow(2) + mean.pow(2) - 1) / 2).mean()
        vae_loss = recon_loss + KL_loss / 2

        self.vae_optim.zero_grad()
        vae_loss.backward()
        self.vae_optim.step()

        # ---- critic training ----
        with torch.no_grad():
            # repeat num_sampled_action times
            obs_next = batch.obs_next.repeat_interleave(
                self.num_sampled_action, dim=0
            )
            # now obs_next: (num_sampled_action * batch_size, state_dim)

            # Actions are sampled by the VAE and then perturbed by the *target*
            # actor, as BCQ prescribes. Using the raw VAE samples here would
            # leave ``actor_target`` without any effect on learning.
            act_next = self.actor_target(obs_next, self.vae.decode(obs_next))
            # now act_next: (num_sampled_action * batch_size, action_dim)

            target_Q1 = self.critic1_target(obs_next, act_next)
            target_Q2 = self.critic2_target(obs_next, act_next)

            # Clipped Double Q-learning
            target_Q = (
                self.lmbda * torch.min(target_Q1, target_Q2)
                + (1 - self.lmbda) * torch.max(target_Q1, target_Q2)
            )
            # now target_Q: (num_sampled_action * batch_size, 1)

            # the max value of Q over the sampled actions of each transition
            target_Q = target_Q.reshape(batch_size, -1).max(dim=1)[0].reshape(-1, 1)
            # now target_Q: (batch_size, 1)

            target_Q = (
                batch.rew.reshape(-1, 1)
                + (1 - batch.done).reshape(-1, 1) * self.gamma * target_Q
            )

        current_Q1 = self.critic1(obs, act)
        current_Q2 = self.critic2(obs, act)

        critic1_loss = F.mse_loss(current_Q1, target_Q)
        critic2_loss = F.mse_loss(current_Q2, target_Q)

        self.critic1_optim.zero_grad()
        self.critic2_optim.zero_grad()
        critic1_loss.backward()
        critic2_loss.backward()
        self.critic1_optim.step()
        self.critic2_optim.step()

        # ---- actor (perturbation) training ----
        sampled_act = self.vae.decode(obs)
        perturbed_act = self.actor(obs, sampled_act)

        # max: the actor is trained to maximise critic1's value
        actor_loss = -self.critic1(obs, perturbed_act).mean()

        self.actor_optim.zero_grad()
        actor_loss.backward()
        self.actor_optim.step()

        # update target network
        self.sync_weight()

        result = {
            "loss/actor": actor_loss.item(),
            "loss/critic1": critic1_loss.item(),
            "loss/critic2": critic2_loss.item(),
            "loss/vae": vae_loss.item(),
        }
        return result