pg


class PGPolicy(*, actor: Module, optim: Optimizer, dist_fn: Callable[[...], Distribution], action_space: Space, discount_factor: float = 0.99, reward_normalization: bool = False, deterministic_eval: bool = False, observation_space: Space | None = None, action_scaling: bool = True, action_bound_method: Literal['clip', 'tanh'] | None = 'clip', lr_scheduler: LRScheduler | MultipleLRSchedulers | None = None)

Implementation of the REINFORCE algorithm.

Parameters:
  • actor – mapping (s->model_output), should follow the rules in BasePolicy.

  • optim – optimizer for actor network.

  • dist_fn – distribution class for computing the action. Maps model_output -> distribution. Typically a Gaussian distribution taking model_output=(mean, std) as input for continuous action spaces, or a categorical distribution taking model_output=logits for discrete action spaces. Note that you, as the user, are responsible for ensuring that the distribution is compatible with the action space.

  • action_space – the environment’s action space.

  • discount_factor – the discount factor γ, in [0, 1].

  • reward_normalization – if True, will normalize the returns by subtracting the running mean and dividing by the running standard deviation. Can be detrimental to performance! See TODO in process_fn.

  • deterministic_eval – if True, uses the deterministic action (the distribution’s mode) instead of a stochastic one during evaluation. Does not affect training.

  • observation_space – the environment’s observation space.

  • action_scaling – if True, scale the action from [-1, 1] to the range of action_space. Only used if the action_space is continuous.

  • action_bound_method – method used to bound the action to the range [-1, 1] (one of 'clip', 'tanh', or None). Only used if the action_space is continuous.

  • lr_scheduler – if not None, will be called in policy.update().
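The interaction of action_bound_method and action_scaling described above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the library's code: the helper name `bound_and_scale` and its arguments are hypothetical, the affine map from [-1, 1] to [low, high] is the usual linear rescaling, and treating None as "no bounding" is an assumption.

```python
import math


def bound_and_scale(raw_action, low, high, bound_method="clip", scaling=True):
    """Hypothetical helper: bound each component to [-1, 1], then rescale to [low, high]."""
    if bound_method == "clip":
        # Hard clipping: values outside [-1, 1] are cut off at the boundary.
        bounded = [max(-1.0, min(1.0, a)) for a in raw_action]
    elif bound_method == "tanh":
        # Smooth squashing: tanh maps any real number into (-1, 1).
        bounded = [math.tanh(a) for a in raw_action]
    elif bound_method is None:
        # Assumption: None leaves the raw action untouched.
        bounded = list(raw_action)
    else:
        raise ValueError(f"unknown bound_method: {bound_method!r}")

    if not scaling:
        return bounded
    # Affine map from [-1, 1] to [low, high], applied per component.
    return [lo + (hi - lo) * (a + 1.0) / 2.0 for a, lo, hi in zip(bounded, low, high)]


# A raw output of 2.0 is clipped to 1.0 and lands on the upper bound of its range.
scaled = bound_and_scale([2.0, -0.5], low=[0.0, 0.0], high=[10.0, 4.0])
```

With 'clip', out-of-range outputs saturate at the bounds of the action space; with 'tanh', they approach the bounds smoothly without reaching them.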

See also

Please refer to BasePolicy for a more detailed explanation.

forward(batch: ObsBatchProtocol, state: dict | BatchProtocol | ndarray | None = None, **kwargs: Any) → DistBatchProtocol

Compute action over the given batch data by applying the actor.

Will sample from the dist_fn, if appropriate. Returns a new object representing the processed batch data (unlike other methods, which modify the input batch in place).
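The difference between sampling from the distribution and taking its mode (the behavior that deterministic_eval selects during evaluation) can be illustrated for a categorical policy. This is a standalone sketch, not the library's implementation: `softmax` and `select_action` are hypothetical helpers, and the logits stand in for the actor's model_output.

```python
import math
import random


def softmax(logits):
    """Convert logits to probabilities, subtracting the max for numerical stability."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def select_action(logits, deterministic, rng=random):
    """Hypothetical helper: return the mode if deterministic, else a sampled action index."""
    probs = softmax(logits)
    if deterministic:
        # Mode of a categorical distribution: the index with the highest probability.
        return max(range(len(probs)), key=probs.__getitem__)
    # Stochastic action: draw one index according to the probabilities.
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


logits = [0.1, 2.0, -1.0]
greedy_action = select_action(logits, deterministic=True)
sampled_action = select_action(logits, deterministic=False, rng=random.Random(0))
```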

See also

Please refer to BasePolicy.forward() for a more detailed explanation.

learn(batch: BatchWithReturnsProtocol, batch_size: int | None, repeat: int, *args: Any, **kwargs: Any) → TPGTrainingStats

Update policy with a given batch of data.

Returns:

A dataclass object, including the data needed to be logged (e.g., loss).
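For orientation, the textbook REINFORCE objective is the negative mean of log-probability times return over the batch. The sketch below shows only that formula on plain lists; the function name `reinforce_loss` is hypothetical, and the exact loss computed inside learn() is not shown in this section, so treat this as the standard form rather than a description of the library's code.

```python
def reinforce_loss(log_probs, returns):
    """Standard REINFORCE surrogate loss: -mean(log pi(a_t | s_t) * G_t)."""
    if len(log_probs) != len(returns):
        raise ValueError("log_probs and returns must have the same length")
    # Minimizing this loss increases the log-probability of actions with high returns.
    return -sum(lp * g for lp, g in zip(log_probs, returns)) / len(log_probs)


# Two transitions with log-probabilities -0.5 and -1.0 and returns 2.0 and 1.0.
loss = reinforce_loss([-0.5, -1.0], [2.0, 1.0])
```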

Note

To distinguish the collecting, updating, and testing states, you can check the policy state via self.training and self.updating. Please refer to States for policy for a more detailed explanation.

Warning

If you use torch.distributions.Normal or torch.distributions.Categorical to calculate the log_prob, be careful about the shape: the Categorical distribution gives a “[batch_size]” shape, while the Normal distribution gives a “[batch_size, 1]” shape. Because torch tensors broadcast automatically, combining the two shapes in a numerical operation does not raise an error; it silently produces a “[batch_size, batch_size]” result and amplifies the mistake.
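The shape pitfall in this warning can be reproduced in a few lines. The sketch uses NumPy as a stand-in for torch, on the grounds that both follow the same broadcasting rules; the array names are illustrative placeholders, not fields of the batch.

```python
import numpy as np

batch_size = 4

# Shape (4, 1): what a Normal distribution over a 1-dimensional action yields for log_prob.
log_prob_column = np.full((batch_size, 1), -0.5)
# Shape (4,): one return per transition.
returns_flat = np.ones(batch_size)

# Broadcasting (4, 1) against (4,) silently produces a (4, 4) matrix, not 4 products.
bad = log_prob_column * returns_flat

# Flattening the log-probabilities first gives the intended elementwise product.
good = log_prob_column.reshape(-1) * returns_flat

print(bad.shape, good.shape)
```

In torch, one common way to avoid the extra trailing dimension is to wrap the Normal in torch.distributions.Independent with one reinterpreted batch dimension, so that log_prob sums over the action dimension and returns a “[batch_size]” tensor.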

process_fn(batch: RolloutBatchProtocol, buffer: ReplayBuffer, indices: ndarray) → BatchWithReturnsProtocol

Compute the discounted returns (Monte Carlo estimates) for each transition.

They are added to the batch under the field returns. Note: this function will modify the input batch!

\[G_t = \sum_{i=t}^T \gamma^{i-t}r_i\]

where \(T\) is the terminal time step, \(\gamma\) is the discount factor, \(\gamma \in [0, 1]\).
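The formula above can be evaluated with a single backward pass over the rewards. The following is a minimal sketch on plain lists, assuming a hypothetical helper `discounted_returns`; it resets the running sum at every done flag and treats a trailing unfinished episode as if it ended at the last stored step, which is a simplification made here and not a statement about how the library handles unfinished episodes.

```python
def discounted_returns(rewards, dones, gamma=0.99):
    """Compute G_t = sum_{i=t}^{T} gamma^(i-t) * r_i for each step, per episode."""
    returns = [0.0] * len(rewards)
    running = 0.0
    # Walk backwards so each G_t reuses G_{t+1}: G_t = r_t + gamma * G_{t+1}.
    for t in reversed(range(len(rewards))):
        if dones[t]:
            # Step t is the last step of an episode, so nothing follows it.
            running = 0.0
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns


# Two episodes of two steps each, stored back to back.
example = discounted_returns(
    rewards=[1.0, 2.0, 3.0, 4.0],
    dones=[False, True, False, True],
    gamma=0.5,
)
```

The backward recursion costs one pass over the data, whereas evaluating the sum independently for every t would be quadratic in the episode length.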

Parameters:
  • batch – a data batch which contains several episodes of data in sequential order. The end of each finished episode in the batch should be marked by the done flag; unfinished (or still-collecting) episodes are recognized via buffer.unfinished_index().

  • buffer – the corresponding replay buffer.

  • indices (numpy.ndarray) – the batch’s location in the buffer; batch is equal to buffer[indices].

class PGTrainingStats(*, train_time: float = 0.0, smoothed_loss: dict = <factory>, loss: tianshou.data.stats.SequenceSummaryStats)
loss: SequenceSummaryStats