tianshou.utils

class tianshou.utils.MovAvg(size: int = 100)
Bases: object
Class for computing a moving average. Infinite and NaN values are excluded automatically. Usage:
>>> import torch
>>> from tianshou.utils import MovAvg
>>> stat = MovAvg(size=66)
>>> stat.add(torch.tensor(5))
5.0
>>> stat.add(float('inf'))  # inf is ignored and not added to stat
5.0
>>> stat.add([6, 7, 8])
6.5
>>> stat.get()
6.5
>>> print(f'{stat.mean():.2f}±{stat.std():.2f}')
6.50±1.12
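The behaviour shown in the usage example can be sketched without tianshou or torch. The following is a minimal stand-in, not tianshou's implementation: the class name `SimpleMovAvg` is hypothetical, and the sketch assumes that the window keeps only the most recent `size` finite values and that `std()` is the population standard deviation (which is consistent with the 1.12 in the usage example).

```python
import math
from collections import deque
from statistics import fmean, pstdev


class SimpleMovAvg:
    """Hypothetical stand-in for MovAvg; not tianshou's actual code."""

    def __init__(self, size=100):
        # Assumption: only the most recent `size` finite values are kept.
        self.cache = deque(maxlen=size)

    def add(self, x):
        """Add a scalar or a list/tuple of scalars, then return the mean."""
        values = x if isinstance(x, (list, tuple)) else [x]
        for v in values:
            v = float(v)
            if math.isfinite(v):  # skips inf, -inf and NaN
                self.cache.append(v)
        return self.get()

    def get(self):
        """Return the current average, or 0.0 if nothing has been added."""
        return fmean(self.cache) if self.cache else 0.0

    def mean(self):
        return self.get()

    def std(self):
        """Population standard deviation of the cached values."""
        return pstdev(self.cache) if self.cache else 0.0


stat = SimpleMovAvg(size=66)
stat.add(5)
stat.add(float('inf'))  # ignored, the average stays at 5.0
stat.add([6, 7, 8])
print(f'{stat.mean():.2f}±{stat.std():.2f}')  # 6.50±1.12
```

A bounded `deque` is used here so that old values drop out of the window without explicit slicing; the real class may manage its cache differently.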

class tianshou.utils.net.common.Net(layer_num, state_shape, action_shape=0, device='cpu', softmax=False, concat=False)
Bases: torch.nn.modules.module.Module
Simple MLP backbone. For advanced usage (how to customize the network), please refer to Build the Network.
Parameters
concat – whether the input shape is the concatenation of state_shape and action_shape. If True, action_shape is not the output shape but instead contributes to the input shape.
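The effect of concat on the input width can be illustrated with a small calculation. This is a sketch under assumptions, not the library's code: the helper names `flat_size` and `mlp_input_dim` are hypothetical, and it assumes that shapes are flattened to their element count before being added together.

```python
import math


def flat_size(shape):
    """Number of elements for an int or a tuple/list shape."""
    if isinstance(shape, (tuple, list)):
        return math.prod(shape)
    return int(shape)


def mlp_input_dim(state_shape, action_shape=0, concat=False):
    """Input width of the first layer (assumed flattening behaviour)."""
    dim = flat_size(state_shape)
    if concat:
        # With concat=True the action is part of the input, not the output.
        dim += flat_size(action_shape)
    return dim


print(mlp_input_dim((4,), 2))               # 4
print(mlp_input_dim((4,), 2, concat=True))  # 6
```

This is the typical setup for a Q(s, a)-style critic, where the state and the action are fed in together.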

forward(s, state=None, info={})
Defines the computation performed at every call. Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
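The note about hooks applies to every forward method on this page. To show why calling the instance differs from calling forward directly, here is a simplified pure-Python stand-in. `TinyModule` and `Doubler` are hypothetical names; the sketch only mimics the idea of forward hooks and is not PyTorch's implementation (for instance, PyTorch passes the inputs to a hook as a tuple).

```python
class TinyModule:
    """Simplified stand-in for the call mechanics of torch.nn.Module."""

    def __init__(self):
        self._forward_hooks = []

    def register_forward_hook(self, hook):
        self._forward_hooks.append(hook)

    def forward(self, x):
        raise NotImplementedError

    def __call__(self, x):
        # Calling the instance runs forward and then every registered hook.
        out = self.forward(x)
        for hook in self._forward_hooks:
            result = hook(self, x, out)
            if result is not None:
                out = result  # a hook may replace the output
        return out


class Doubler(TinyModule):
    def forward(self, x):
        return 2 * x


calls = []
net = Doubler()
net.register_forward_hook(lambda module, inp, out: calls.append(out))

net(3)          # runs forward, then the hook
net.forward(3)  # bypasses the hook entirely
print(calls)    # [6]
```

In practice this means writing `net(s)` rather than `net.forward(s)` when using any of the networks documented here.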

class tianshou.utils.net.common.Recurrent(layer_num, state_shape, action_shape, device='cpu')
Bases: torch.nn.modules.module.Module
Simple recurrent network based on LSTM. For advanced usage (how to customize the network), please refer to Build the Network.

forward(s, state=None, info={})
Defines the computation performed at every call. Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class tianshou.utils.net.discrete.Actor(preprocess_net, action_shape)
Bases: torch.nn.modules.module.Module
For advanced usage (how to customize the network), please refer to Build the Network.

forward(s, state=None, info={})
Defines the computation performed at every call. Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class tianshou.utils.net.discrete.Critic(preprocess_net)
Bases: torch.nn.modules.module.Module
For advanced usage (how to customize the network), please refer to Build the Network.

forward(s, **kwargs)
Defines the computation performed at every call. Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class tianshou.utils.net.discrete.DQN(h, w, action_shape, device='cpu')
Bases: torch.nn.modules.module.Module
For advanced usage (how to customize the network), please refer to Build the Network.

forward(x, state=None, info={})
Defines the computation performed at every call. Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class tianshou.utils.net.continuous.Actor(preprocess_net, action_shape, max_action, device='cpu')
Bases: torch.nn.modules.module.Module
For advanced usage (how to customize the network), please refer to Build the Network.

forward(s, state=None, info={})
Defines the computation performed at every call. Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class tianshou.utils.net.continuous.ActorProb(preprocess_net, action_shape, max_action, device='cpu', unbounded=False)
Bases: torch.nn.modules.module.Module
For advanced usage (how to customize the network), please refer to Build the Network.

forward(s, state=None, **kwargs)
Defines the computation performed at every call. Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class tianshou.utils.net.continuous.Critic(preprocess_net, device='cpu')
Bases: torch.nn.modules.module.Module
For advanced usage (how to customize the network), please refer to Build the Network.

forward(s, a=None, **kwargs)
Defines the computation performed at every call. Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class tianshou.utils.net.continuous.RecurrentActorProb(layer_num, state_shape, action_shape, max_action, device='cpu')
Bases: torch.nn.modules.module.Module
For advanced usage (how to customize the network), please refer to Build the Network.

forward(s, **kwargs)
Defines the computation performed at every call. Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class tianshou.utils.net.continuous.RecurrentCritic(layer_num, state_shape, action_shape=0, device='cpu')
Bases: torch.nn.modules.module.Module
For advanced usage (how to customize the network), please refer to Build the Network.

forward(s, a=None)
Defines the computation performed at every call. Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.