ddpg
Source code: tianshou/policy/modelfree/ddpg.py
- class DDPGPolicy(*, actor: Module, actor_optim: Optimizer, critic: Module, critic_optim: Optimizer, action_space: Space, tau: float = 0.005, gamma: float = 0.99, exploration_noise: BaseNoise | Literal['default'] | None = 'default', estimation_step: int = 1, observation_space: Space | None = None, action_scaling: bool = True, action_bound_method: Literal['clip'] | None = 'clip', lr_scheduler: LRScheduler | MultipleLRSchedulers | None = None)
Implementation of Deep Deterministic Policy Gradient. arXiv:1509.02971.
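For orientation, the DDPG update from arXiv:1509.02971 can be summarized as below, written with the gamma and tau constructor parameters. This is a summary of the paper's algorithm, not a transcription of this class's code. The terminal mask (1 - d_i) is a common practical addition that the paper's pseudocode omits, and with estimation_step > 1 the one-step target would presumably be replaced by an n-step return.

```latex
\begin{align*}
% Critic target from the target actor \mu' and target critic Q'
y_i &= r_i + \gamma \, (1 - d_i) \, Q'\big(s_{i+1}, \mu'(s_{i+1})\big) \\
% Critic loss over a minibatch of N transitions
L_Q &= \frac{1}{N} \sum_{i=1}^{N} \big(y_i - Q(s_i, a_i)\big)^2 \\
% Actor loss: maximize Q(s, \mu(s)) by minimizing its negative mean
L_\mu &= -\frac{1}{N} \sum_{i=1}^{N} Q\big(s_i, \mu(s_i)\big) \\
% Soft update of the target networks with rate \tau
\theta' &\leftarrow \tau \, \theta + (1 - \tau) \, \theta'
\end{align*}
```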
- Parameters:
actor – The actor network following the rules in BasePolicy (s -> model_output).
actor_optim – The optimizer for the actor network.
critic – The critic network (s, a -> Q(s, a)).
critic_optim – The optimizer for the critic network.
action_space – The environment's action space.
tau – Parameter for the soft update of the target networks.
gamma – Discount factor, in [0, 1].
exploration_noise – The exploration noise added to the action. Defaults to GaussianNoise(sigma=0.1).
estimation_step – The number of steps to look ahead.
observation_space – The environment's observation space.
action_scaling – If True, scale the action from [-1, 1] to the range of action_space. Only used if the action_space is continuous.
action_bound_method – Method used to bound the action to the range [-1, 1]. Only used if the action_space is continuous.
lr_scheduler – If not None, it will be called in policy.update().
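The combined effect of action_bound_method='clip' and action_scaling=True on a continuous action can be sketched as follows: the raw action is first clipped to [-1, 1], then mapped linearly onto the action space's [low, high] bounds. bound_and_scale is a hypothetical helper written for illustration only; this sketch does not show where Tianshou itself applies the mapping.

```python
import numpy as np


def bound_and_scale(act, low, high, action_bound_method="clip", action_scaling=True):
    """Hypothetical helper: map a raw actor output to an environment action."""
    act = np.asarray(act, dtype=np.float64)
    if action_bound_method == "clip":
        # Bound the raw action to [-1, 1].
        act = np.clip(act, -1.0, 1.0)
    if action_scaling:
        # Linearly map [-1, 1] onto [low, high], element-wise per action dimension.
        act = low + (act + 1.0) * (high - low) / 2.0
    return act


low = np.array([-2.0, 0.0])
high = np.array([2.0, 10.0])
# 0.5 maps to 1.0 in [-2, 2]; 1.7 is clipped to 1.0 and maps to 10.0 in [0, 10].
print(bound_and_scale([0.5, 1.7], low, high).tolist())  # [1.0, 10.0]
```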
See also
Please refer to BasePolicy for a more detailed explanation.
- exploration_noise(act: ndarray | BatchProtocol, batch: RolloutBatchProtocol) → ndarray | BatchProtocol
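Modify the action from policy.forward with exploration noise.
NOTE: this docstring is inherited from BasePolicy, whose default implementation does not add any noise and must be overridden by subclasses to do something. DDPGPolicy overrides it to add the noise given by the exploration_noise constructor argument; if that argument is None, the action is returned unchanged.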
- Parameters:
act – A data batch or numpy.ndarray containing the action computed by policy.forward.
batch – The input batch for policy.forward, kept for advanced usage.
- Returns:
The action, in the same form as the input act, with exploration noise added.
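To illustrate what adding Gaussian exploration noise to a numpy.ndarray action means, here is a minimal NumPy sketch. GaussianNoiseSketch and add_exploration_noise are hypothetical stand-ins, not Tianshou's GaussianNoise class or this method. The sketch assumes zero-mean noise with sigma=0.1, matching the documented default.

```python
import numpy as np


class GaussianNoiseSketch:
    """Hypothetical zero-mean Gaussian noise source (stand-in for GaussianNoise)."""

    def __init__(self, mu=0.0, sigma=0.1, seed=None):
        self.mu = mu
        self.sigma = sigma
        self.rng = np.random.default_rng(seed)

    def __call__(self, size):
        # Draw one independent sample per action element.
        return self.rng.normal(self.mu, self.sigma, size)


def add_exploration_noise(act, noise):
    """Hypothetical helper: add noise to an ndarray action, or pass it through."""
    if noise is None:
        # No noise configured: return the action unchanged.
        return act
    act = np.asarray(act, dtype=np.float64)
    return act + noise(act.shape)


noisy = add_exploration_noise(np.zeros((4, 2)), GaussianNoiseSketch(seed=0))
print(noisy.shape)  # (4, 2)
```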
- forward(batch: ObsBatchProtocol, state: dict | BatchProtocol | ndarray | None = None, model: Literal['actor', 'actor_old'] = 'actor', **kwargs: Any) → ActStateBatchProtocol
Compute the action for the given batch of data.
- Returns:
A Batch with two keys:
act – the action.
state – the hidden state.
See also
Please refer to BasePolicy.forward() for a more detailed explanation.
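The model argument selects which actor produces the action: 'actor' for the online network or 'actor_old' for the old (target) copy, which in DDPG is typically what bootstrapped critic targets are computed with. The toy below is a hypothetical sketch of that switch using a linear-tanh actor and a plain dict instead of a Batch; it is not this class's forward().

```python
import numpy as np


class ToyDeterministicPolicy:
    """Hypothetical toy policy with an online actor and an old (target) actor."""

    def __init__(self, weight):
        # Linear-tanh actor: act = tanh(obs @ W), which keeps actions in [-1, 1].
        self.actor_w = np.array(weight, dtype=float)
        # The old/target actor starts as an exact copy of the online actor.
        self.actor_old_w = self.actor_w.copy()

    def forward(self, obs, state=None, model="actor"):
        if model == "actor":
            w = self.actor_w
        elif model == "actor_old":
            w = self.actor_old_w
        else:
            raise ValueError(f"unknown model: {model!r}")
        act = np.tanh(np.asarray(obs, dtype=float) @ w)
        # A feed-forward actor has no recurrent state, so it is passed through.
        return {"act": act, "state": state}


policy = ToyDeterministicPolicy([[1.0], [0.0]])
obs = np.array([[1.0, 0.0]])
policy.actor_w *= 2.0  # pretend the online actor was just updated
# The online and old actors now produce different actions for the same obs.
print(policy.forward(obs)["act"], policy.forward(obs, model="actor_old")["act"])
```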
- learn(batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) → TDDPGTrainingStats
Update policy with a given batch of data.
- Returns:
A dataclass object, including the data needed to be logged (e.g., loss).
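The core arithmetic of one DDPG learning step can be sketched with plain NumPy arrays standing in for network outputs. All function names here are hypothetical. The sketch only illustrates the one-step TD target with gamma, the critic and actor losses, and the tau-weighted soft update; it is not this class's learn(), which works with torch networks and optimizers.

```python
import numpy as np


def td_target(rew, done, target_q_next, gamma=0.99):
    """One-step critic target: y = r + gamma * (1 - done) * Q'(s', mu'(s'))."""
    rew = np.asarray(rew, dtype=float)
    done = np.asarray(done, dtype=float)
    target_q_next = np.asarray(target_q_next, dtype=float)
    return rew + gamma * (1.0 - done) * target_q_next


def critic_loss(q_pred, y):
    """Mean squared TD error between the critic's Q(s, a) and the target y."""
    return float(np.mean((np.asarray(y, dtype=float) - np.asarray(q_pred, dtype=float)) ** 2))


def actor_loss(q_of_actor_actions):
    """The actor maximizes Q(s, mu(s)), i.e. minimizes its negative mean."""
    return float(-np.mean(np.asarray(q_of_actor_actions, dtype=float)))


def soft_update(target_params, online_params, tau=0.005):
    """In place: theta' <- tau * theta + (1 - tau) * theta' for each array."""
    for t, o in zip(target_params, online_params):
        t *= 1.0 - tau
        t += tau * np.asarray(o, dtype=float)


y = td_target([1.0, 0.0], [0.0, 1.0], [2.0, 5.0], gamma=0.5)
print(y.tolist(), critic_loss([1.0, 0.0], y))  # [2.0, 0.0] 0.5
```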
Note
To distinguish between the collecting, updating and testing states, check the policy state via self.training and self.updating. Please refer to States for policy for a more detailed explanation.
Warning
If you use torch.distributions.Normal and torch.distributions.Categorical to calculate the log_prob, be careful about the shapes: a Categorical distribution gives a “[batch_size]” shape, while a Normal distribution gives a “[batch_size, 1]” shape. Automatic broadcasting in numerical operations on torch tensors amplifies this error: combining a [batch_size] tensor with a [batch_size, 1] tensor silently produces a [batch_size, batch_size] tensor instead of raising an error.
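The shape pitfall can be reproduced with NumPy, whose broadcasting rules PyTorch tensors follow. The arrays below are placeholders standing in for log_prob outputs. Summing over the trailing action dimension (equivalent to squeezing it for a one-dimensional action) restores the intended [batch_size] shape.

```python
import numpy as np

batch_size = 4
log_prob_categorical = np.zeros(batch_size)  # shape [batch_size]
log_prob_normal = np.zeros((batch_size, 1))  # shape [batch_size, 1]

# Broadcasting silently expands both operands to [batch_size, batch_size].
combined = log_prob_categorical + log_prob_normal
print(combined.shape)  # (4, 4)

# Reduce over the action dimension first so both operands are [batch_size].
fixed = log_prob_categorical + log_prob_normal.sum(axis=-1)
print(fixed.shape)  # (4,)
```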
- process_fn(batch: RolloutBatchProtocol, buffer: ReplayBuffer, indices: ndarray) → RolloutBatchProtocol | BatchWithReturnsProtocol
Pre-process the data from the provided replay buffer.
Meant to be overridden by subclasses. Typical usage is to add new keys to the batch, e.g., to add the value function of the next state. Used in update(), which is usually called repeatedly during training. For modifying the replay buffer only once at the beginning (e.g., for offline learning), see process_buffer().
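For an off-policy method like DDPG, a natural use of this pre-processing step is attaching an n-step TD target, with n given by estimation_step and discount gamma. The sketch below computes such a target for one index of a single stored trajectory. nstep_return and its arguments are hypothetical and assume simple list indexing; this is not Tianshou's implementation, which works with a ReplayBuffer and its own indexing.

```python
def nstep_return(rewards, dones, bootstrap_q, t, n, gamma=0.99):
    """Hypothetical n-step target for transition t of one trajectory.

    rewards[i], dones[i]: reward and terminal flag of transition i.
    bootstrap_q[i]: target critic value Q'(s_{i+1}, mu'(s_{i+1})) after transition i.
    """
    if t + n > len(rewards):
        raise ValueError("not enough stored transitions after index t")
    ret = 0.0
    for k in range(n):
        i = t + k
        ret += gamma**k * rewards[i]
        if dones[i]:
            # Episode ended inside the n-step window: no bootstrapping.
            return ret
    # Bootstrap from the target critic at state s_{t+n}.
    return ret + gamma**n * bootstrap_q[t + n - 1]


# 1 + 0.5 * 1 + 0.25 * 10 = 4.0
print(nstep_return([1.0, 1.0, 1.0, 1.0], [0, 0, 0, 0], [10.0] * 4, t=0, n=2, gamma=0.5))
```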