tianshou.policy

class tianshou.policy.BasePolicy(**kwargs)[source]

Bases: abc.ABC, torch.nn.modules.module.Module

Tianshou aims to modularize RL algorithms. They are implemented as several classes of policies in Tianshou, and all of the policy classes must inherit BasePolicy.

A policy class typically has four parts:

  • __init__(): initialize the policy, including copying the target network and so on;

  • __call__(): compute action with given observation;

  • process_fn(): pre-process data from the replay buffer (this function can interact with the replay buffer);

  • learn(): update policy with a given batch of data.

Most policies need a neural network to predict the action and an optimizer to optimize the policy. The rules for self-defined networks are as follows (a minimal sketch is given after the list):

  1. Input: observation obs (may be a numpy.ndarray or torch.Tensor), hidden state state (for RNN usage), and other information info provided by the environment.

  2. Output: some logits and the next hidden state state. The logits could be a tuple instead of a torch.Tensor, depending on how the policy processes the network output. For example, in PPO the network might return (mu, sigma), state for a Gaussian policy.
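
The following is a minimal sketch (not part of Tianshou) of a self-defined network that obeys these rules: it accepts (obs, state, info) and returns (logits, state). The class name MyNet and the layer sizes are illustrative assumptions.

import numpy as np
import torch
from torch import nn

class MyNet(nn.Module):
    def __init__(self, state_dim, action_dim, hidden=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, action_dim),
        )

    def forward(self, obs, state=None, info={}):
        # obs may arrive as a numpy.ndarray; convert it to a tensor first
        if isinstance(obs, np.ndarray):
            obs = torch.as_tensor(obs, dtype=torch.float32)
        logits = self.net(obs)
        # no RNN here, so the hidden state is passed through unchanged
        return logits, state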

Since BasePolicy inherits torch.nn.Module, you can use BasePolicy in almost the same way as torch.nn.Module, for instance, to load and save the model:

torch.save(policy.state_dict(), 'policy.pth')
policy.load_state_dict(torch.load('policy.pth'))

abstract __call__(batch, state=None, **kwargs)[source]

Compute action over the given batch data.

Returns

A Batch which MUST have the following keys:

  • act a numpy.ndarray or a torch.Tensor, the action over the given batch data.

  • state a dict, a numpy.ndarray or a torch.Tensor, the internal state of the policy, None by default.

Other keys are user-defined and depend on the algorithm. For example,

# some code
return Batch(logits=..., act=..., state=None, dist=...)

abstract learn(batch, **kwargs)[source]

Update policy with a given batch of data.

Returns

A dict which includes the loss and its corresponding label.

process_fn(batch, buffer, indice)[source]

Pre-process the data from the provided replay buffer. Check out Policy for more information.

class tianshou.policy.DQNPolicy(model, optim, discount_factor=0.99, estimation_step=1, target_update_freq=0, **kwargs)[source]

Bases: tianshou.policy.base.BasePolicy

Implementation of Deep Q Network. arXiv:1312.5602

Parameters
  • model (torch.nn.Module) – a model following the rules in BasePolicy. (s -> logits)

  • optim (torch.optim.Optimizer) – a torch.optim for optimizing the model.

  • discount_factor (float) – in [0, 1].

  • estimation_step (int) – the number of steps to look ahead for the n-step return, defaults to 1.

  • target_update_freq (int) – the target network update frequency (0 if you do not use the target network).
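
A minimal construction sketch, reusing the MyNet class sketched under BasePolicy above; the hyper-parameter values (learning rate, n-step, target update frequency) are illustrative assumptions rather than recommended settings.

import torch
from tianshou.policy import DQNPolicy

net = MyNet(state_dim=4, action_dim=2)   # any model that obeys the BasePolicy rules
optim = torch.optim.Adam(net.parameters(), lr=1e-3)
policy = DQNPolicy(net, optim, discount_factor=0.99,
                   estimation_step=3, target_update_freq=320)
policy.set_eps(0.1)  # epsilon-greedy exploration while collecting data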

__call__(batch, state=None, model='model', input='obs', eps=None, **kwargs)[source]

Compute action over the given batch data.

Parameters

eps (float) – in [0, 1], for the epsilon-greedy exploration method.

Returns

A Batch which has 3 keys:

  • act the action.

  • logits the network’s raw output.

  • state the hidden state.

More information can be found at __call__().

eval()[source]

Set the module in evaluation mode, except for the target network.

learn(batch, **kwargs)[source]

Update policy with a given batch of data.

Returns

A dict which includes the loss and its corresponding label.

process_fn(batch, buffer, indice)[source]

Compute the n-step return for Q-learning targets:

G_t = \sum_{i = t}^{t + n - 1} \gamma^{i - t} (1 - d_i) r_i + \gamma^n (1 - d_{t + n}) \max_a Q_{old}(s_{t + n}, \arg\max_a Q_{new}(s_{t + n}, a))

where \gamma \in [0, 1] is the discount factor and d_t is the done flag of step t. If there is no target network, Q_{old} is equal to Q_{new}.
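
For example, substituting n = 1 into the formula above yields the one-step target G_t = (1 - d_t) r_t + \gamma (1 - d_{t + 1}) \max_a Q_{old}(s_{t + 1}, \arg\max_a Q_{new}(s_{t + 1}, a)).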

set_eps(eps)[source]

Set the eps for epsilon-greedy exploration.

sync_weight()[source]

Synchronize the weight for the target network.

train()[source]

Set the module in training mode, except for the target network.

class tianshou.policy.PGPolicy(model, optim, dist_fn=<class 'torch.distributions.categorical.Categorical'>, discount_factor=0.99, **kwargs)[source]

Bases: tianshou.policy.base.BasePolicy

Implementation of Vanilla Policy Gradient.

Parameters
  • model (torch.nn.Module) – a model following the rules in BasePolicy. (s -> logits)

  • optim (torch.optim.Optimizer) – a torch.optim for optimizing the model.

  • dist_fn (torch.distributions.Distribution) – for computing the action, defaults to torch.distributions.Categorical (see the sketch after this parameter list).

  • discount_factor (float) – in [0, 1].
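
The following is a conceptual sketch of what dist_fn does, not the library's exact internals: the network output is passed to dist_fn to build a distribution, from which the action is sampled. The probability values are chosen for illustration only.

import torch

probs = torch.tensor([[0.1, 0.7, 0.2]])        # example network output
dist = torch.distributions.Categorical(probs)  # dist_fn applied to the output
act = dist.sample()                            # the act key of the returned Batch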

__call__(batch, state=None, **kwargs)[source]

Compute action over the given batch data.

Returns

A Batch which has 4 keys:

  • act the action.

  • logits the network’s raw output.

  • dist the action distribution.

  • state the hidden state.

More information can be found at __call__().

learn(batch, batch_size=None, repeat=1, **kwargs)[source]

Update policy with a given batch of data.

Returns

A dict which includes the loss and its corresponding label.

process_fn(batch, buffer, indice)[source]

Compute the discounted returns for each frame:

G_t = \sum_{i = t}^{T} \gamma^{i - t} r_i

where T is the terminal time step and \gamma \in [0, 1] is the discount factor.
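
As a quick numeric illustration (values chosen arbitrarily): for a three-step episode with r_0 = r_1 = r_2 = 1 and \gamma = 0.9, the returns are G_2 = 1, G_1 = 1 + 0.9 = 1.9, and G_0 = 1 + 0.9 + 0.81 = 2.71.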

class tianshou.policy.A2CPolicy(actor, critic, optim, dist_fn=<class 'torch.distributions.categorical.Categorical'>, discount_factor=0.99, vf_coef=0.5, ent_coef=0.01, max_grad_norm=None, **kwargs)[source]

Bases: tianshou.policy.pg.PGPolicy

Implementation of Synchronous Advantage Actor-Critic. arXiv:1602.01783

Parameters
  • actor (torch.nn.Module) – the actor network following the rules in BasePolicy. (s -> logits)

  • critic (torch.nn.Module) – the critic network. (s -> V(s))

  • optim (torch.optim.Optimizer) – the optimizer for both the actor and the critic network (see the sketch after this parameter list).

  • dist_fn (torch.distributions.Distribution) – for computing the action, defaults to torch.distributions.Categorical.

  • discount_factor (float) – in [0, 1], defaults to 0.99.

  • vf_coef (float) – weight for value loss, defaults to 0.5.

  • ent_coef (float) – weight for entropy loss, defaults to 0.01.

  • max_grad_norm (float) – clipping gradients in back propagation, defaults to None.
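
A minimal construction sketch, assuming actor and critic modules that follow the interfaces above (actor: s -> logits, critic: s -> V(s)); the tiny linear networks and the hyper-parameter values are illustrative assumptions.

import torch
from torch import nn
from tianshou.policy import A2CPolicy

class Actor(nn.Module):  # s -> logits, following the BasePolicy rules
    def __init__(self, state_dim=4, action_dim=2):
        super().__init__()
        self.net = nn.Linear(state_dim, action_dim)

    def forward(self, obs, state=None, info={}):
        obs = torch.as_tensor(obs, dtype=torch.float32)
        return self.net(obs), state

class Critic(nn.Module):  # s -> V(s)
    def __init__(self, state_dim=4):
        super().__init__()
        self.net = nn.Linear(state_dim, 1)

    def forward(self, obs, **kwargs):
        obs = torch.as_tensor(obs, dtype=torch.float32)
        return self.net(obs)

actor, critic = Actor(), Critic()
# a single optimizer updates both networks, as the optim parameter expects
optim = torch.optim.Adam(
    list(actor.parameters()) + list(critic.parameters()), lr=3e-4)
policy = A2CPolicy(actor, critic, optim, discount_factor=0.99,
                   vf_coef=0.5, ent_coef=0.01, max_grad_norm=0.5)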

__call__(batch, state=None, **kwargs)[source]

Compute action over the given batch data.

Returns

A Batch which has 4 keys:

  • act the action.

  • logits the network’s raw output.

  • dist the action distribution.

  • state the hidden state.

More information can be found at __call__().

learn(batch, batch_size=None, repeat=1, **kwargs)[source]

Update policy with a given batch of data.

Returns

A dict which includes the loss and its corresponding label.

class tianshou.policy.DDPGPolicy(actor, actor_optim, critic, critic_optim, tau=0.005, gamma=0.99, exploration_noise=0.1, action_range=None, reward_normalization=False, ignore_done=False, **kwargs)[source]

Bases: tianshou.policy.base.BasePolicy

Implementation of Deep Deterministic Policy Gradient. arXiv:1509.02971

Parameters
  • actor (torch.nn.Module) – the actor network following the rules in BasePolicy. (s -> logits)

  • actor_optim (torch.optim.Optimizer) – the optimizer for actor network.

  • critic (torch.nn.Module) – the critic network. (s, a -> Q(s, a))

  • critic_optim (torch.optim.Optimizer) – the optimizer for critic network.

  • tau (float) – param for soft update of the target network, defaults to 0.005.

  • gamma (float) – discount factor, in [0, 1], defaults to 0.99.

  • exploration_noise (float) – the intensity of the exploration noise added to the action, defaults to 0.1.

  • action_range ([float, float]) – the action range (minimum, maximum).

  • reward_normalization (bool) – normalize the reward to Normal(0, 1), defaults to False.

  • ignore_done (bool) – ignore the done flag while training the policy, defaults to False.
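
A minimal construction sketch; the actor (s -> action squashed into [-1, 1]) and critic (s, a -> Q(s, a)) below are illustrative stand-ins, and the action_range value is an assumption for the example.

import torch
from torch import nn
from tianshou.policy import DDPGPolicy

class Actor(nn.Module):  # s -> action, squashed into [-1, 1]
    def __init__(self, state_dim=3, action_dim=1):
        super().__init__()
        self.net = nn.Linear(state_dim, action_dim)

    def forward(self, obs, state=None, info={}):
        obs = torch.as_tensor(obs, dtype=torch.float32)
        return torch.tanh(self.net(obs)), state

class Critic(nn.Module):  # (s, a) -> Q(s, a)
    def __init__(self, state_dim=3, action_dim=1):
        super().__init__()
        self.net = nn.Linear(state_dim + action_dim, 1)

    def forward(self, obs, act, **kwargs):
        obs = torch.as_tensor(obs, dtype=torch.float32)
        act = torch.as_tensor(act, dtype=torch.float32)
        return self.net(torch.cat([obs, act], dim=1))

actor, critic = Actor(), Critic()
policy = DDPGPolicy(
    actor, torch.optim.Adam(actor.parameters(), lr=1e-3),
    critic, torch.optim.Adam(critic.parameters(), lr=1e-3),
    tau=0.005, gamma=0.99, exploration_noise=0.1,
    action_range=[-2.0, 2.0])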

__call__(batch, state=None, model='actor', input='obs', eps=None, **kwargs)[source]

Compute action over the given batch data.

Parameters

eps (float) – in [0, 1], for exploration use.

Returns

A Batch which has 2 keys:

  • act the action.

  • state the hidden state.

More information can be found at __call__().

eval()[source]

Set the module in evaluation mode, except for the target network.

learn(batch, **kwargs)[source]

Update policy with a given batch of data.

Returns

A dict which includes the loss and its corresponding label.

process_fn(batch, buffer, indice)[source]

Pre-process the data from the provided replay buffer. Check out Policy for more information.

set_eps(eps)[source]

Set the eps for exploration.

sync_weight()[source]

Soft-update the weight for the target network.

train()[source]

Set the module in training mode, except for the target network.

class tianshou.policy.PPOPolicy(actor, critic, optim, dist_fn, discount_factor=0.99, max_grad_norm=0.5, eps_clip=0.2, vf_coef=0.5, ent_coef=0.0, action_range=None, **kwargs)[source]

Bases: tianshou.policy.pg.PGPolicy

Implementation of Proximal Policy Optimization. arXiv:1707.06347

Parameters
  • actor (torch.nn.Module) – the actor network following the rules in BasePolicy. (s -> logits)

  • critic (torch.nn.Module) – the critic network. (s -> V(s))

  • optim (torch.optim.Optimizer) – the optimizer for actor and critic network.

  • dist_fn (torch.distributions.Distribution) – for computing the action (a continuous-action sketch follows this parameter list).

  • discount_factor (float) – in [0, 1], defaults to 0.99.

  • max_grad_norm (float) – clipping gradients in back propagation, defaults to 0.5.

  • eps_clip (float) – \epsilon in L_{CLIP} in the original paper, defaults to 0.2.

  • vf_coef (float) – weight for value loss, defaults to 0.5.

  • ent_coef (float) – weight for entropy loss, defaults to 0.0.

  • action_range ([float, float]) – the action range (minimum, maximum).
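
Since a PPO actor for continuous control may return a (mu, sigma) tuple, as noted under BasePolicy, dist_fn is typically a callable that turns the network output into a distribution. The sketch below is a conceptual illustration, not the library's internals; gaussian_dist and the tensor values are assumptions.

import torch

def gaussian_dist(mu, sigma):
    # independent Normal over each action dimension
    return torch.distributions.Independent(
        torch.distributions.Normal(mu, sigma), 1)

mu = torch.zeros(1, 2)     # example actor output: means for a 2-d action
sigma = torch.ones(1, 2)   # example actor output: standard deviations
dist = gaussian_dist(mu, sigma)
act = dist.sample()        # a continuous action sample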

__call__(batch, state=None, model='actor', **kwargs)[source]

Compute action over the given batch data.

Returns

A Batch which has 4 keys:

  • act the action.

  • logits the network’s raw output.

  • dist the action distribution.

  • state the hidden state.

More information can be found at __call__().

eval()[source]

Set the module in evaluation mode, except for the target network.

learn(batch, batch_size=None, repeat=1, **kwargs)[source]

Update policy with a given batch of data.

Returns

A dict which includes the loss and its corresponding label.

sync_weight()[source]

Synchronize the weight for the target network.

train()[source]

Set the module in training mode, except for the target network.

class tianshou.policy.TD3Policy(actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, tau=0.005, gamma=0.99, exploration_noise=0.1, policy_noise=0.2, update_actor_freq=2, noise_clip=0.5, action_range=None, reward_normalization=False, ignore_done=False, **kwargs)[source]

Bases: tianshou.policy.ddpg.DDPGPolicy

Implementation of Twin Delayed Deep Deterministic Policy Gradient. arXiv:1802.09477

Parameters
  • actor (torch.nn.Module) – the actor network following the rules in BasePolicy. (s -> logits)

  • actor_optim (torch.optim.Optimizer) – the optimizer for actor network.

  • critic1 (torch.nn.Module) – the first critic network. (s, a -> Q(s, a))

  • critic1_optim (torch.optim.Optimizer) – the optimizer for the first critic network.

  • critic2 (torch.nn.Module) – the second critic network. (s, a -> Q(s, a))

  • critic2_optim (torch.optim.Optimizer) – the optimizer for the second critic network.

  • tau (float) – param for soft update of the target network, defaults to 0.005.

  • gamma (float) – discount factor, in [0, 1], defaults to 0.99.

  • exploration_noise (float) – the intensity of the exploration noise added to the action, defaults to 0.1.

  • policy_noise (float) – the noise used when updating the policy network, defaults to 0.2.

  • update_actor_freq (int) – the update frequency of the actor network, defaults to 2.

  • noise_clip (float) – the clipping range of the noise used when updating the policy network, defaults to 0.5.

  • action_range ([float, float]) – the action range (minimum, maximum).

  • reward_normalization (bool) – normalize the reward to Normal(0, 1), defaults to False.

  • ignore_done (bool) – ignore the done flag while training the policy, defaults to False.

eval()[source]

Set the module in evaluation mode, except for the target network.

learn(batch, **kwargs)[source]

Update policy with a given batch of data.

Returns

A dict which includes the loss and its corresponding label.

sync_weight()[source]

Soft-update the weight for the target network.

train()[source]

Set the module in training mode, except for the target network.

class tianshou.policy.SACPolicy(actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, tau=0.005, gamma=0.99, alpha=0.2, action_range=None, reward_normalization=False, ignore_done=False, **kwargs)[source]

Bases: tianshou.policy.ddpg.DDPGPolicy

Implementation of Soft Actor-Critic. arXiv:1812.05905

Parameters
  • actor (torch.nn.Module) – the actor network following the rules in BasePolicy. (s -> logits)

  • actor_optim (torch.optim.Optimizer) – the optimizer for actor network.

  • critic1 (torch.nn.Module) – the first critic network. (s, a -> Q(s, a))

  • critic1_optim (torch.optim.Optimizer) – the optimizer for the first critic network.

  • critic2 (torch.nn.Module) – the second critic network. (s, a -> Q(s, a))

  • critic2_optim (torch.optim.Optimizer) – the optimizer for the second critic network.

  • tau (float) – param for soft update of the target network, defaults to 0.005.

  • gamma (float) – discount factor, in [0, 1], defaults to 0.99.

  • exploration_noise (float) – the intensity of the exploration noise added to the action, defaults to 0.1.

  • alpha (float) – entropy regularization coefficient, defaults to 0.2.

  • action_range ([float, float]) – the action range (minimum, maximum).

  • reward_normalization (bool) – normalize the reward to Normal(0, 1), defaults to False.

  • ignore_done (bool) – ignore the done flag while training the policy, defaults to False.

__call__(batch, state=None, input='obs', **kwargs)[source]

Compute action over the given batch data.

Parameters

eps (float) – in [0, 1], for exploration use.

Returns

A Batch which has 2 keys:

  • act the action.

  • state the hidden state.

More information can be found at __call__().

eval()[source]

Set the module in evaluation mode, except for the target network.

learn(batch, **kwargs)[source]

Update policy with a given batch of data.

Returns

A dict which includes the loss and its corresponding label.

sync_weight()[source]

Soft-update the weight for the target network.

train()[source]

Set the module in training mode, except for the target network.