tianshou.utils

class tianshou.utils.MovAvg(size: int = 100)[source]

Bases: object

Class for computing a moving average. Infinite and NaN values are automatically excluded. Usage:

>>> stat = MovAvg(size=66)
>>> stat.add(torch.tensor(5))
5.0
>>> stat.add(float('inf'))  # inf is excluded, so the statistics are unchanged
5.0
>>> stat.add([6, 7, 8])
6.5
>>> stat.get()
6.5
>>> print(f'{stat.mean():.2f}±{stat.std():.2f}')
6.50±1.12
add(x: Union[float, list, numpy.ndarray, torch.Tensor]) → float[source]

Add a scalar into MovAvg. You can add a torch.Tensor with a single element, a Python scalar, or a list of Python scalars.

get() → float[source]

Get the average.

mean() → float[source]

Get the average. Same as get().

std() → float[source]

Get the standard deviation.
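
Beyond the doctest above, MovAvg is convenient for smoothing a noisy per-step statistic during training. A minimal sketch (the loss values are placeholders):

import torch
from tianshou.utils import MovAvg

stat = MovAvg(size=100)            # keep a window of the last 100 valid values
for step in range(1000):
    loss = torch.rand(1)           # stand-in for a real training loss
    stat.add(loss)                 # a one-element tensor is accepted
print(f'smoothed loss: {stat.get():.3f} ± {stat.std():.3f}')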

class tianshou.utils.net.common.Net(layer_num, state_shape, action_shape=0, device='cpu', softmax=False, concat=False)[source]

Bases: torch.nn.modules.module.Module

Simple MLP backbone. For advanced usage (how to customize the network), please refer to Build the Network.

Parameters

concat – whether the input is the concatenation of state_shape and action_shape. If True, action_shape is not used as the output shape; it only contributes to the input shape.

forward(s, state=None, info={})[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
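
A minimal construction sketch. The shapes are illustrative (a CartPole-like task), and the forward call assumes the convention, suggested by the state argument, of returning an (output, state) pair:

import numpy as np
from tianshou.utils.net.common import Net

state_shape = (4,)                 # illustrative observation shape
action_shape = 2                   # illustrative number of discrete actions
net = Net(layer_num=1, state_shape=state_shape, action_shape=action_shape)

obs = np.zeros((10, 4), dtype=np.float32)   # a batch of 10 observations
logits, state = net(obs, state=None)        # presumed (output, state) return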

class tianshou.utils.net.common.Recurrent(layer_num, state_shape, action_shape, device='cpu')[source]

Bases: torch.nn.modules.module.Module

Simple recurrent network based on LSTM. For advanced usage (how to customize the network), please refer to Build the Network.

forward(s, state=None, info={})[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
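
A construction sketch under the same assumptions as for Net above; the returned state presumably carries the LSTM hidden state and is meant to be passed back in on the next call:

import numpy as np
from tianshou.utils.net.common import Recurrent

net = Recurrent(layer_num=1, state_shape=(4,), action_shape=2)

obs = np.zeros((10, 4), dtype=np.float32)
out, state = net(obs, state=None)       # first call, no hidden state yet
out, state = net(obs, state=state)      # later calls reuse the returned state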

class tianshou.utils.net.discrete.Actor(preprocess_net, action_shape)[source]

Bases: torch.nn.modules.module.Module

For advanced usage (how to customize the network), please refer to Build the Network.

forward(s, state=None, info={})[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
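
This head is typically stacked on top of a backbone such as Net; the sketch below only uses the documented constructor signatures, with illustrative shapes:

from tianshou.utils.net.common import Net
from tianshou.utils.net.discrete import Actor

net = Net(layer_num=1, state_shape=(4,))   # shared feature backbone
actor = Actor(net, action_shape=2)         # policy head over 2 discrete actions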

class tianshou.utils.net.discrete.Critic(preprocess_net)[source]

Bases: torch.nn.modules.module.Module

For advanced usage (how to customize the network), please refer to Build the Network.

forward(s, **kwargs)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
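
The critic head can share the same backbone as the actor, giving an actor-critic pair; sharing the backbone is a design choice, not a requirement. Sketch:

from tianshou.utils.net.common import Net
from tianshou.utils.net.discrete import Actor, Critic

net = Net(layer_num=1, state_shape=(4,))
actor = Actor(net, action_shape=2)
critic = Critic(net)                       # value head on the same features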

class tianshou.utils.net.discrete.DQN(h, w, action_shape, device='cpu')[source]

Bases: torch.nn.modules.module.Module

For advanced usage (how to customize the network), please refer to Build the Network.

forward(x, state=None, info={})[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
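
This is a convolutional network for image observations; h and w are the frame height and width. A construction sketch with the standard 84x84 Atari frame size (the exact input layout and frame stacking are left to the implementation):

from tianshou.utils.net.discrete import DQN

net = DQN(h=84, w=84, action_shape=6, device='cpu')   # 6 discrete actions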

class tianshou.utils.net.continuous.Actor(preprocess_net, action_shape, max_action, device='cpu')[source]

Bases: torch.nn.modules.module.Module

For advanced usage (how to customize the network), please refer to Build the Network.

forward(s, state=None, info={})[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
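
A deterministic actor for continuous control (as used with DDPG-style policies); max_action is presumably used to scale the output into the environment's action range. Sketch with Pendulum-like shapes:

from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import Actor

state_shape = (3,)
action_shape = (1,)
max_action = 2.0                            # the environment's action bound
net = Net(layer_num=1, state_shape=state_shape)
actor = Actor(net, action_shape, max_action)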

class tianshou.utils.net.continuous.ActorProb(preprocess_net, action_shape, max_action, device='cpu', unbounded=False)[source]

Bases: torch.nn.modules.module.Module

For advanced usage (how to customize the network), please refer to Build the Network.

forward(s, state=None, **kwargs)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
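
A stochastic actor for continuous control, presumably producing the parameters of a Gaussian action distribution rather than an action; the unbounded flag presumably controls whether the output mean is squashed into the max_action range. Sketch:

from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import ActorProb

net = Net(layer_num=1, state_shape=(3,))
actor = ActorProb(net, action_shape=(1,), max_action=2.0, unbounded=False)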

class tianshou.utils.net.continuous.Critic(preprocess_net, device='cpu')[source]

Bases: torch.nn.modules.module.Module

For advanced usage (how to customize the network), please refer to Build the Network.

forward(s, a=None, **kwargs)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
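
For a Q(s, a) critic the backbone must accept the concatenation of state and action, which is exactly what the concat flag of Net is for (see above). Since a defaults to None, the same class presumably also works as a state-only critic when the backbone takes only the state. Sketch:

from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import Critic

state_shape, action_shape = (3,), (1,)
net_c = Net(layer_num=1, state_shape=state_shape,
            action_shape=action_shape, concat=True)
critic = Critic(net_c)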

class tianshou.utils.net.continuous.RecurrentActorProb(layer_num, state_shape, action_shape, max_action, device='cpu')[source]

Bases: torch.nn.modules.module.Module

For advanced usage (how to customize the network), please refer to Build the Network.

forward(s, **kwargs)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
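
The recurrent counterpart of ActorProb is built directly from layer and shape arguments instead of a preprocess_net. Construction sketch:

from tianshou.utils.net.continuous import RecurrentActorProb

actor = RecurrentActorProb(layer_num=1, state_shape=(3,),
                           action_shape=(1,), max_action=2.0)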

class tianshou.utils.net.continuous.RecurrentCritic(layer_num, state_shape, action_shape=0, device='cpu')[source]

Bases: torch.nn.modules.module.Module

For advanced usage (how to customize the network), please refer to Build the Network.

forward(s, a=None)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
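
Likewise for the recurrent critic; passing a non-zero action_shape presumably makes it a Q(s, a) critic, while the default of 0 gives a state-only critic. Construction sketch:

from tianshou.utils.net.continuous import RecurrentCritic

critic = RecurrentCritic(layer_num=1, state_shape=(3,), action_shape=(1,))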