tianshou.utils

class tianshou.utils.MovAvg(size: int = 100)[source]

Bases: object

Class for computing a moving average. Infinite and NaN values are excluded automatically. Usage:

>>> stat = MovAvg(size=66)
>>> stat.add(torch.tensor(5))
5.0
>>> stat.add(float('inf'))  # inf is ignored, so the average is unchanged
5.0
>>> stat.add([6, 7, 8])
6.5
>>> stat.get()
6.5
>>> print(f'{stat.mean():.2f}±{stat.std():.2f}')
6.50±1.12

add(x: Union[float, list, numpy.ndarray, torch.Tensor]) → float[source]

Add a value to the MovAvg and return the updated average. You can add a torch.Tensor with only one element, a Python scalar, or a list or numpy.ndarray of scalars.

get() → float[source]

Get the average.

mean() → float[source]

Get the average. Same as get().

std() → float[source]

Get the standard deviation.
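The behavior shown in the usage example above can be approximated with the standard library alone. The following is a minimal sketch, not the library's implementation: the class name MovAvgSketch is hypothetical, returning 0.0 for an empty window is an assumption, and the population standard deviation (dividing by N) is chosen because it reproduces the 1.12 shown in the example.

```python
import math
from collections import deque
from typing import Iterable, Union

Number = Union[int, float]


class MovAvgSketch:
    """Illustrative moving average over the last `size` finite values."""

    def __init__(self, size: int = 100) -> None:
        # A bounded deque drops the oldest value once `size` is exceeded.
        self.cache = deque(maxlen=size)

    def add(self, x: Union[Number, Iterable[Number]]) -> float:
        values = [x] if isinstance(x, (int, float)) else list(x)
        for v in values:
            v = float(v)
            if math.isfinite(v):  # skip inf, -inf and NaN
                self.cache.append(v)
        return self.get()

    def get(self) -> float:
        # Assumption: an empty window reports 0.0.
        return sum(self.cache) / len(self.cache) if self.cache else 0.0

    def std(self) -> float:
        if not self.cache:
            return 0.0
        mean = self.get()
        # Population standard deviation (divide by N, not N - 1).
        var = sum((v - mean) ** 2 for v in self.cache) / len(self.cache)
        return math.sqrt(var)


stat = MovAvgSketch(size=66)
stat.add(5)
stat.add(float("inf"))  # ignored
stat.add([6, 7, 8])
print(f"{stat.get():.2f}±{stat.std():.2f}")  # 6.50±1.12
```

Because the window is bounded, a small size makes old values fall out: with size=3, adding 5 and then [6, 7, 8] leaves only 6, 7 and 8 in the window.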

class tianshou.utils.net.common.Net(layer_num, state_shape, action_shape=0, device='cpu', softmax=False, concat=False, hidden_layer_size=128)[source]

Bases: torch.nn.modules.module.Module

Simple MLP backbone. For advanced usage (how to customize the network), please refer to Build the Network.

Parameters

concat – whether the network input is the concatenation of state_shape and action_shape. If True, action_shape does not determine the output shape; it only contributes to the input shape.

forward(s, state=None, info={})[source]

s -> flatten -> logits
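To make the effect of concat on Net concrete, the sketch below computes the flattened input width an MLP would need. It is an illustration under assumptions: the helper names flat_size and mlp_input_size are hypothetical, and the rule that shapes are flattened by multiplying their dimensions is inferred from the "s -> flatten -> logits" description rather than confirmed by this page.

```python
from math import prod
from typing import Sequence, Union

Shape = Union[int, Sequence[int]]


def flat_size(shape: Shape) -> int:
    """Number of elements in a tensor of the given shape."""
    if isinstance(shape, int):
        return shape
    return prod(shape)


def mlp_input_size(state_shape: Shape, action_shape: Shape = 0,
                   concat: bool = False) -> int:
    """Flattened input width; the action only counts when concat is True."""
    size = flat_size(state_shape)
    if concat:
        size += flat_size(action_shape)
    return size


print(mlp_input_size((4,), 2))                          # 4
print(mlp_input_size((4,), 2, concat=True))             # 6
print(mlp_input_size((3, 84, 84), (1,), concat=True))   # 21169
```

A concatenated input of this kind is what a critic taking both the state and the action, such as Q(s, a), would consume.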

class tianshou.utils.net.common.Recurrent(layer_num, state_shape, action_shape, device='cpu', hidden_layer_size=128)[source]

Bases: torch.nn.modules.module.Module

Simple Recurrent network based on LSTM. For advanced usage (how to customize the network), please refer to Build the Network.

forward(s, state=None, info={})[source]

In evaluation mode, s should have shape [bsz, dim]; in training mode, s should have shape [bsz, len, dim]. See the code and its comments for more detail.
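The two input layouts for Recurrent.forward can be reconciled by treating an evaluation-mode batch as a sequence of length one. The sketch below only manipulates shape tuples and is an assumption about how such a unification might work, not a statement of what the library does; the function name to_sequence_shape is hypothetical.

```python
from typing import Tuple


def to_sequence_shape(shape: Tuple[int, ...]) -> Tuple[int, int, int]:
    """Map [bsz, dim] to [bsz, 1, dim]; pass [bsz, len, dim] through."""
    if len(shape) == 2:
        bsz, dim = shape
        return (bsz, 1, dim)  # single-step input becomes a length-1 sequence
    if len(shape) == 3:
        bsz, length, dim = shape
        return (bsz, length, dim)
    raise ValueError(f"expected a 2-D or 3-D shape, got {shape!r}")


print(to_sequence_shape((32, 8)))      # (32, 1, 8)
print(to_sequence_shape((32, 16, 8)))  # (32, 16, 8)
```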

class tianshou.utils.net.discrete.Actor(preprocess_net, action_shape, hidden_layer_size=128)[source]

Bases: torch.nn.modules.module.Module

For advanced usage (how to customize the network), please refer to Build the Network.

forward(s, state=None, info={})[source]

s -> Q(s, *)

class tianshou.utils.net.discrete.Critic(preprocess_net, hidden_layer_size=128)[source]

Bases: torch.nn.modules.module.Module

For advanced usage (how to customize the network), please refer to Build the Network.

forward(s, **kwargs)[source]

s -> V(s)

class tianshou.utils.net.discrete.DQN(h, w, action_shape, device='cpu')[source]

Bases: torch.nn.modules.module.Module

For advanced usage (how to customize the network), please refer to Build the Network.

forward(x, state=None, info={})[source]

x -> Q(x, *)

class tianshou.utils.net.continuous.Actor(preprocess_net, action_shape, max_action, device='cpu', hidden_layer_size=128)[source]

Bases: torch.nn.modules.module.Module

For advanced usage (how to customize the network), please refer to Build the Network.

forward(s, state=None, info={})[source]

s -> logits -> action
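The page does not say how max_action enters the "logits -> action" step. One common way to bound a deterministic action is to squash the logits with tanh and scale by max_action; the sketch below illustrates that technique as an assumption, with the hypothetical helper squash_action standing in for the final step.

```python
import math
from typing import List


def squash_action(logits: List[float], max_action: float) -> List[float]:
    """Bound each action component to [-max_action, max_action] via tanh."""
    return [max_action * math.tanh(x) for x in logits]


actions = squash_action([-10.0, 0.0, 0.5], max_action=2.0)
print(actions)
```

With this scheme a zero logit maps to a zero action, and very large logits saturate near the bounds instead of exceeding them.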

class tianshou.utils.net.continuous.ActorProb(preprocess_net, action_shape, max_action, device='cpu', unbounded=False, hidden_layer_size=128)[source]

Bases: torch.nn.modules.module.Module

For advanced usage (how to customize the network), please refer to Build the Network.

forward(s, state=None, **kwargs)[source]

s -> logits -> (mu, sigma)

class tianshou.utils.net.continuous.Critic(preprocess_net, device='cpu', hidden_layer_size=128)[source]

Bases: torch.nn.modules.module.Module

For advanced usage (how to customize the network), please refer to Build the Network.

forward(s, a=None, **kwargs)[source]

(s, a) -> logits -> Q(s, a)

class tianshou.utils.net.continuous.RecurrentActorProb(layer_num, state_shape, action_shape, max_action, device='cpu', hidden_layer_size=128)[source]

Bases: torch.nn.modules.module.Module

For advanced usage (how to customize the network), please refer to Build the Network.

forward(s, **kwargs)[source]

Almost the same as Recurrent.

class tianshou.utils.net.continuous.RecurrentCritic(layer_num, state_shape, action_shape=0, device='cpu', hidden_layer_size=128)[source]

Bases: torch.nn.modules.module.Module

For advanced usage (how to customize the network), please refer to Build the Network.

forward(s, a=None)[source]

Almost the same as Recurrent.