"""Network modules for continuous action spaces (tianshou.utils.net.continuous)."""

import warnings
from collections.abc import Sequence
from typing import Any

import numpy as np
import torch
from torch import nn

from tianshou.utils.net.common import (
    MLP,
    BaseActor,
    TActionShape,
    TLinearLayer,
    get_output_dim,
)

# Clamp bounds for the network-predicted log standard deviation. When
# ``conditioned_sigma`` is enabled, ActorProb / RecurrentActorProb clamp the
# sigma head's output to [SIGMA_MIN, SIGMA_MAX] before calling ``.exp()``, so
# sigma itself lies in [exp(SIGMA_MIN), exp(SIGMA_MAX)].
SIGMA_MIN = -20
SIGMA_MAX = 2


class Actor(BaseActor):
    """Deterministic actor network for continuous action spaces.

    Structure: ``preprocess_net ---> MLP ---> tanh ---> * max_action``.

    :param preprocess_net: a self-defined preprocess_net which outputs a
        flattened hidden state.
    :param action_shape: a sequence of int for the shape of action; the MLP
        head emits ``prod(action_shape)`` values.
    :param hidden_sizes: a sequence of int for constructing the MLP after
        preprocess_net. Default to an empty sequence (where the MLP contains
        only a single linear layer).
    :param max_action: the scale applied to the tanh-squashed output.
        Default to 1.
    :param device: stored on the module and passed to the MLP head.
        Default to "cpu".
    :param preprocess_net_output_dim: the output dimension of preprocess_net;
        passed to ``get_output_dim`` to size the MLP input.

    For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.

    .. seealso::

        Please refer to :class:`~tianshou.utils.net.common.Net` as an instance
        of how preprocess_net is suggested to be defined.
    """

    def __init__(
        self,
        preprocess_net: nn.Module,
        action_shape: TActionShape,
        hidden_sizes: Sequence[int] = (),
        max_action: float = 1.0,
        device: str | int | torch.device = "cpu",
        preprocess_net_output_dim: int | None = None,
    ) -> None:
        super().__init__()
        flat_act_dim = int(np.prod(action_shape))
        feature_dim = get_output_dim(preprocess_net, preprocess_net_output_dim)
        self.device = device
        # Register the preprocess net before the head so the submodule /
        # state_dict ordering stays "preprocess" then "last".
        self.preprocess = preprocess_net
        self.output_dim = flat_act_dim
        self.last = MLP(feature_dim, flat_act_dim, hidden_sizes, device=device)
        self.max_action = max_action

    def get_preprocess_net(self) -> nn.Module:
        """Return the preprocess network this actor was built with."""
        return self.preprocess

    def get_output_dim(self) -> int:
        """Return the flattened action dimension, ``prod(action_shape)``."""
        return self.output_dim

    def forward(
        self,
        obs: np.ndarray | torch.Tensor,
        state: Any = None,
        info: dict[str, Any] | None = None,
    ) -> tuple[torch.Tensor, Any]:
        """Mapping: obs -> logits -> action."""
        info = {} if info is None else info
        features, hidden = self.preprocess(obs, state)
        squashed = torch.tanh(self.last(features))
        return self.max_action * squashed, hidden
class Critic(nn.Module):
    """Simple critic network for continuous action spaces.

    It creates a Q-value estimator with structure
    ``(obs, act) ---> preprocess_net ---> MLP ---> 1 (q value)``.

    :param preprocess_net: a self-defined preprocess_net which outputs a
        flattened hidden state.
    :param hidden_sizes: a sequence of int for constructing the MLP after
        preprocess_net. Default to an empty sequence (where the MLP contains
        only a single linear layer).
    :param device: the device that observations/actions are converted to in
        ``forward`` and that is passed to the MLP head. Default to "cpu".
    :param preprocess_net_output_dim: the output dimension of preprocess_net.
    :param linear_layer: use this module as linear layer. Default to nn.Linear.
    :param flatten_input: whether to flatten input data for the last layer.
        Default to True.

    For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.

    .. seealso::

        Please refer to :class:`~tianshou.utils.net.common.Net` as an instance
        of how preprocess_net is suggested to be defined.
    """

    def __init__(
        self,
        preprocess_net: nn.Module,
        hidden_sizes: Sequence[int] = (),
        device: str | int | torch.device = "cpu",
        preprocess_net_output_dim: int | None = None,
        linear_layer: TLinearLayer = nn.Linear,
        flatten_input: bool = True,
    ) -> None:
        super().__init__()
        self.device = device
        self.preprocess = preprocess_net
        self.output_dim = 1
        input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim)
        self.last = MLP(
            input_dim,
            1,
            hidden_sizes,
            device=self.device,
            linear_layer=linear_layer,
            flatten_input=flatten_input,
        )

    def forward(
        self,
        obs: np.ndarray | torch.Tensor,
        act: np.ndarray | torch.Tensor | None = None,
        info: dict[str, Any] | None = None,
    ) -> torch.Tensor:
        """Mapping: (s, a) -> logits -> Q(s, a).

        :param obs: observations; converted to a float32 tensor on
            ``self.device`` and flattened from dim 1 onwards.
        :param act: optional actions; when given, they are converted and
            flattened the same way and concatenated to ``obs`` along dim 1
            before being passed to the preprocess net.
        :param info: unused; kept for interface compatibility.
        :return: the output of the MLP head (built with ``output_dim=1``).
        """
        obs = torch.as_tensor(
            obs,
            device=self.device,
            dtype=torch.float32,
        ).flatten(1)
        if act is not None:
            act = torch.as_tensor(
                act,
                device=self.device,
                dtype=torch.float32,
            ).flatten(1)
            obs = torch.cat([obs, act], dim=1)
        # The preprocess net's hidden state is not needed by a feed-forward critic.
        logits, _ = self.preprocess(obs)
        return self.last(logits)
class ActorProb(BaseActor):
    """Gaussian actor network for continuous action spaces.

    It maps observations to the mean ``mu`` and standard deviation ``sigma``
    (one value per action dimension) of a Gaussian over actions, with
    structure ``preprocess_net ---> MLP ---> (mu, sigma)``.

    :param preprocess_net: a self-defined preprocess_net which outputs a
        flattened hidden state.
    :param action_shape: a sequence of int for the shape of action.
    :param hidden_sizes: a sequence of int for constructing the MLP after
        preprocess_net. Default to an empty sequence (where the MLP contains
        only a single linear layer).
    :param max_action: the scale applied to ``tanh(mu)``. Default to 1. It is
        reset to 1 (with a warning) when ``unbounded`` is True.
    :param device: the device passed to the MLP heads. Default to "cpu".
    :param unbounded: if True, ``mu`` is returned as produced by the MLP; if
        False (default), ``mu`` is squashed with tanh and scaled by
        ``max_action``.
    :param conditioned_sigma: True when sigma is calculated from the input,
        False when sigma is an independent parameter. Default to False.
    :param preprocess_net_output_dim: the output dimension of preprocess_net.

    For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.

    .. seealso::

        Please refer to :class:`~tianshou.utils.net.common.Net` as an instance
        of how preprocess_net is suggested to be defined.
    """

    # TODO: force kwargs, adjust downstream code
    def __init__(
        self,
        preprocess_net: nn.Module,
        action_shape: TActionShape,
        hidden_sizes: Sequence[int] = (),
        max_action: float = 1.0,
        device: str | int | torch.device = "cpu",
        unbounded: bool = False,
        conditioned_sigma: bool = False,
        preprocess_net_output_dim: int | None = None,
    ) -> None:
        super().__init__()
        if unbounded and not np.isclose(max_action, 1.0):
            warnings.warn("Note that max_action input will be discarded when unbounded is True.")
            max_action = 1.0
        self.preprocess = preprocess_net
        self.device = device
        self.output_dim = int(np.prod(action_shape))
        input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim)
        self.mu = MLP(input_dim, self.output_dim, hidden_sizes, device=self.device)
        self._c_sigma = conditioned_sigma
        if conditioned_sigma:
            self.sigma = MLP(
                input_dim,
                self.output_dim,
                hidden_sizes,
                device=self.device,
            )
        else:
            # Input-independent log-sigma, broadcast over the batch in forward().
            self.sigma_param = nn.Parameter(torch.zeros(self.output_dim, 1))
        self.max_action = max_action
        self._unbounded = unbounded

    def get_preprocess_net(self) -> nn.Module:
        """Return the preprocess network this actor was built with."""
        return self.preprocess

    def get_output_dim(self) -> int:
        """Return the flattened action dimension, ``prod(action_shape)``."""
        return self.output_dim

    def forward(
        self,
        obs: np.ndarray | torch.Tensor,
        state: Any = None,
        info: dict[str, Any] | None = None,
    ) -> tuple[tuple[torch.Tensor, torch.Tensor], Any]:
        """Mapping: obs -> logits -> (mu, sigma).

        :param obs: observations, passed to the preprocess net together with
            ``state``.
        :param state: hidden state forwarded to the preprocess net.
        :param info: unused; kept for interface compatibility.
        :return: ``((mu, sigma), hidden)`` where ``hidden`` is the state
            returned by the preprocess net.
        """
        logits, hidden = self.preprocess(obs, state)
        mu = self.mu(logits)
        if not self._unbounded:
            mu = self.max_action * torch.tanh(mu)
        if self._c_sigma:
            sigma = torch.clamp(self.sigma(logits), min=SIGMA_MIN, max=SIGMA_MAX).exp()
        else:
            # Reshape the (output_dim, 1) parameter so it lines up with dim 1 of
            # mu, then broadcast it to mu's full shape.
            shape = [1] * len(mu.shape)
            shape[1] = -1
            sigma = (self.sigma_param.view(shape) + torch.zeros_like(mu)).exp()
        # Return the preprocess net's (possibly updated) hidden state, as Actor
        # does, instead of echoing the incoming ``state``; otherwise a recurrent
        # preprocess net could never advance its state between calls.
        return (mu, sigma), hidden
[docs] class RecurrentActorProb(nn.Module): """Recurrent version of ActorProb. For advanced usage (how to customize the network), please refer to :ref:`build_the_network`. """ def __init__( self, layer_num: int, state_shape: Sequence[int], action_shape: Sequence[int], hidden_layer_size: int = 128, max_action: float = 1.0, device: str | int | torch.device = "cpu", unbounded: bool = False, conditioned_sigma: bool = False, ) -> None: super().__init__() if unbounded and not np.isclose(max_action, 1.0): warnings.warn("Note that max_action input will be discarded when unbounded is True.") max_action = 1.0 self.device = device self.nn = nn.LSTM( input_size=int(np.prod(state_shape)), hidden_size=hidden_layer_size, num_layers=layer_num, batch_first=True, ) output_dim = int(np.prod(action_shape)) self.mu = nn.Linear(hidden_layer_size, output_dim) self._c_sigma = conditioned_sigma if conditioned_sigma: self.sigma = nn.Linear(hidden_layer_size, output_dim) else: self.sigma_param = nn.Parameter(torch.zeros(output_dim, 1)) self.max_action = max_action self._unbounded = unbounded
[docs] def forward( self, obs: np.ndarray | torch.Tensor, state: dict[str, torch.Tensor] | None = None, info: dict[str, Any] | None = None, ) -> tuple[tuple[torch.Tensor, torch.Tensor], dict[str, torch.Tensor]]: """Almost the same as :class:`~tianshou.utils.net.common.Recurrent`.""" if info is None: info = {} obs = torch.as_tensor( obs, device=self.device, dtype=torch.float32, ) # obs [bsz, len, dim] (training) or [bsz, dim] (evaluation) # In short, the tensor's shape in training phase is longer than which # in evaluation phase. if len(obs.shape) == 2: obs = obs.unsqueeze(-2) self.nn.flatten_parameters() if state is None: obs, (hidden, cell) = self.nn(obs) else: # we store the stack data in [bsz, len, ...] format # but pytorch rnn needs [len, bsz, ...] obs, (hidden, cell) = self.nn( obs, ( state["hidden"].transpose(0, 1).contiguous(), state["cell"].transpose(0, 1).contiguous(), ), ) logits = obs[:, -1] mu = self.mu(logits) if not self._unbounded: mu = self.max_action * torch.tanh(mu) if self._c_sigma: sigma = torch.clamp(self.sigma(logits), min=SIGMA_MIN, max=SIGMA_MAX).exp() else: shape = [1] * len(mu.shape) shape[1] = -1 sigma = (self.sigma_param.view(shape) + torch.zeros_like(mu)).exp() # please ensure the first dim is batch size: [bsz, len, ...] return (mu, sigma), { "hidden": hidden.transpose(0, 1).detach(), "cell": cell.transpose(0, 1).detach(), }
[docs] class RecurrentCritic(nn.Module): """Recurrent version of Critic. For advanced usage (how to customize the network), please refer to :ref:`build_the_network`. """ def __init__( self, layer_num: int, state_shape: Sequence[int], action_shape: Sequence[int] = [0], device: str | int | torch.device = "cpu", hidden_layer_size: int = 128, ) -> None: super().__init__() self.state_shape = state_shape self.action_shape = action_shape self.device = device self.nn = nn.LSTM( input_size=int(np.prod(state_shape)), hidden_size=hidden_layer_size, num_layers=layer_num, batch_first=True, ) self.fc2 = nn.Linear(hidden_layer_size + int(np.prod(action_shape)), 1)
[docs] def forward( self, obs: np.ndarray | torch.Tensor, act: np.ndarray | torch.Tensor | None = None, info: dict[str, Any] | None = None, ) -> torch.Tensor: """Almost the same as :class:`~tianshou.utils.net.common.Recurrent`.""" if info is None: info = {} obs = torch.as_tensor( obs, device=self.device, dtype=torch.float32, ) # obs [bsz, len, dim] (training) or [bsz, dim] (evaluation) # In short, the tensor's shape in training phase is longer than which # in evaluation phase. assert len(obs.shape) == 3 self.nn.flatten_parameters() obs, (hidden, cell) = self.nn(obs) obs = obs[:, -1] if act is not None: act = torch.as_tensor( act, device=self.device, dtype=torch.float32, ) obs = torch.cat([obs, act], dim=1) return self.fc2(obs)
[docs] class Perturbation(nn.Module): """Implementation of perturbation network in BCQ algorithm. Given a state and action, it can generate perturbed action. :param preprocess_net: a self-defined preprocess_net which output a flattened hidden state. :param max_action: the maximum value of each dimension of action. :param device: which device to create this model on. Default to cpu. :param phi: max perturbation parameter for BCQ. Default to 0.05. For advanced usage (how to customize the network), please refer to :ref:`build_the_network`. .. seealso:: You can refer to `examples/offline/offline_bcq.py` to see how to use it. """ def __init__( self, preprocess_net: nn.Module, max_action: float, device: str | int | torch.device = "cpu", phi: float = 0.05, ): # preprocess_net: input_dim=state_dim+action_dim, output_dim=action_dim super().__init__() self.preprocess_net = preprocess_net self.device = device self.max_action = max_action self.phi = phi
[docs] def forward(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor: # preprocess_net logits = self.preprocess_net(torch.cat([state, action], -1))[0] noise = self.phi * self.max_action * torch.tanh(logits) # clip to [-max_action, max_action] return (noise + action).clamp(-self.max_action, self.max_action)
[docs] class VAE(nn.Module): """Implementation of VAE. It models the distribution of action. Given a state, it can generate actions similar to those in batch. It is used in BCQ algorithm. :param encoder: the encoder in VAE. Its input_dim must be state_dim + action_dim, and output_dim must be hidden_dim. :param decoder: the decoder in VAE. Its input_dim must be state_dim + latent_dim, and output_dim must be action_dim. :param hidden_dim: the size of the last linear-layer in encoder. :param latent_dim: the size of latent layer. :param max_action: the maximum value of each dimension of action. :param device: which device to create this model on. Default to "cpu". For advanced usage (how to customize the network), please refer to :ref:`build_the_network`. .. seealso:: You can refer to `examples/offline/offline_bcq.py` to see how to use it. """ def __init__( self, encoder: nn.Module, decoder: nn.Module, hidden_dim: int, latent_dim: int, max_action: float, device: str | torch.device = "cpu", ): super().__init__() self.encoder = encoder self.mean = nn.Linear(hidden_dim, latent_dim) self.log_std = nn.Linear(hidden_dim, latent_dim) self.decoder = decoder self.max_action = max_action self.latent_dim = latent_dim self.device = device
[docs] def forward( self, state: torch.Tensor, action: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # [state, action] -> z , [state, z] -> action latent_z = self.encoder(torch.cat([state, action], -1)) # shape of z: (state.shape[:-1], hidden_dim) mean = self.mean(latent_z) # Clamped for numerical stability log_std = self.log_std(latent_z).clamp(-4, 15) std = torch.exp(log_std) # shape of mean, std: (state.shape[:-1], latent_dim) latent_z = mean + std * torch.randn_like(std) # (state.shape[:-1], latent_dim) reconstruction = self.decode(state, latent_z) # (state.shape[:-1], action_dim) return reconstruction, mean, std
[docs] def decode( self, state: torch.Tensor, latent_z: torch.Tensor | None = None, ) -> torch.Tensor: # decode(state) -> action if latent_z is None: # state.shape[0] may be batch_size # latent vector clipped to [-0.5, 0.5] latent_z = ( torch.randn(state.shape[:-1] + (self.latent_dim,)).to(self.device).clamp(-0.5, 0.5) ) # decode z with state! return self.max_action * torch.tanh(self.decoder(torch.cat([state, latent_z], -1)))