dqn#


class DQNPolicy(*, model: Module, optim: Optimizer, action_space: Discrete, discount_factor: float = 0.99, estimation_step: int = 1, target_update_freq: int = 0, reward_normalization: bool = False, is_double: bool = True, clip_loss_grad: bool = False, observation_space: Space | None = None, lr_scheduler: LRScheduler | MultipleLRSchedulers | None = None)[source]#

Implementation of Deep Q Network. arXiv:1312.5602.

Implementation of Double Q-Learning. arXiv:1509.06461.

Implementation of Dueling DQN. arXiv:1511.06581 (the dueling architecture is implemented on the network side, not here).

Parameters:
  • model – a model following the rules in BasePolicy (s -> logits).

  • optim – a torch.optim for optimizing the model.

  • discount_factor – the discount factor for future rewards, in [0, 1].

  • estimation_step – the number of steps to look ahead.

  • target_update_freq – the target network update frequency (0 if you do not use the target network).

  • reward_normalization – normalize the returns to Normal(0, 1). TODO: rename to return_normalization?

  • is_double – whether to use Double DQN.

  • clip_loss_grad – clip the gradient of the loss in accordance with nature14236; this amounts to using the Huber loss instead of the MSE loss.

  • observation_space – Env’s observation space.

  • lr_scheduler – if not None, will be called in policy.update().

See also

Please refer to BasePolicy for a more detailed explanation.
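
A minimal construction sketch (the environment, network sizes, and hyperparameters below are illustrative, not recommended defaults):

import gymnasium as gym
import torch
from tianshou.policy import DQNPolicy
from tianshou.utils.net.common import Net

env = gym.make("CartPole-v1")

# Q-network mapping observations to one raw value per action (s -> logits).
net = Net(
    state_shape=env.observation_space.shape,
    action_shape=env.action_space.n,
    hidden_sizes=[128, 128],
)
optim = torch.optim.Adam(net.parameters(), lr=1e-3)

policy = DQNPolicy(
    model=net,
    optim=optim,
    action_space=env.action_space,
    discount_factor=0.99,
    estimation_step=3,        # 3-step TD targets
    target_update_freq=500,   # sync the target network every 500 updates
)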

compute_q_value(logits: Tensor, mask: ndarray | None) → Tensor[source]#

Compute the q value based on the network’s raw output and action mask.
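
Conceptually, masking pushes the Q-values of unavailable actions below the current minimum so that an argmax can never select them. A standalone sketch of that idea (masked_q_value is a hypothetical helper, not the library implementation):

import numpy as np
import torch

def masked_q_value(logits: torch.Tensor, mask: np.ndarray | None) -> torch.Tensor:
    """Illustrative only: suppress the Q-values of unavailable actions."""
    if mask is None:
        return logits
    mask_t = torch.as_tensor(mask, dtype=torch.bool, device=logits.device)
    # Any value strictly below the current minimum is enough for the argmax.
    min_value = logits.min() - logits.max() - 1.0
    return torch.where(mask_t, logits, min_value)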

exploration_noise(act: ndarray | BatchProtocol, batch: RolloutBatchProtocol) → ndarray | BatchProtocol[source]#

Modify the action from policy.forward with exploration noise.

NOTE: this base implementation currently does not add any noise; subclasses need to override it to actually add exploration noise.

Parameters:
  • act – a data batch or numpy.ndarray which is the action taken by policy.forward.

  • batch – the input batch for policy.forward, kept for advanced usage.

Returns:

the action, in the same form as the input “act”, but with exploration noise added.
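
Since this base implementation is a no-op, a subclass has to override it to obtain, for example, epsilon-greedy behaviour. A rough sketch of such an override (EpsGreedyDQNPolicy is hypothetical; it assumes the value passed to set_eps() is stored as self.eps):

import numpy as np

from tianshou.policy import DQNPolicy

class EpsGreedyDQNPolicy(DQNPolicy):  # hypothetical subclass, for illustration only
    def exploration_noise(self, act, batch):
        # With probability self.eps, replace the greedy action by a random one.
        if isinstance(act, np.ndarray) and self.eps > 0.0:
            rand_mask = np.random.rand(len(act)) < self.eps
            rand_act = np.random.randint(self.action_space.n, size=len(act))
            act = act.copy()
            act[rand_mask] = rand_act[rand_mask]
        return act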

forward(batch: ObsBatchProtocol, state: dict | BatchProtocol | ndarray | None = None, model: Literal['model', 'model_old'] = 'model', **kwargs: Any) → ModelOutputBatchProtocol[source]#

Compute action over the given batch data.

If you need to mask actions, add a “mask” field to batch.obs. For example, for an environment with three actions (0/1/2):

batch == Batch(
    obs=Batch(
        obs="original obs, with batch_size=1 for demonstration",
        mask=np.array([[False, True, False]]),
        # action 1 is available
        # actions 0 and 2 are unavailable
    ),
    ...
)
Returns:

A Batch which has 3 keys:

  • act – the action.

  • logits – the network’s raw output.

  • state – the hidden state.

See also

Please refer to forward() for a more detailed explanation.
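
A small usage sketch of the masking mechanism described above (network size, observation content, and hyperparameters are placeholders):

import numpy as np
import torch
from gymnasium.spaces import Discrete
from tianshou.data import Batch
from tianshou.policy import DQNPolicy
from tianshou.utils.net.common import Net

net = Net(state_shape=(4,), action_shape=3, hidden_sizes=[64])
policy = DQNPolicy(
    model=net,
    optim=torch.optim.Adam(net.parameters(), lr=1e-3),
    action_space=Discrete(3),
)

obs = Batch(
    obs=np.zeros((1, 4), dtype=np.float32),  # placeholder observation, batch_size=1
    mask=np.array([[False, True, False]]),   # only action 1 is available
)
result = policy(Batch(obs=obs, info={}))     # invokes DQNPolicy.forward
print(result.act)     # action 1, the only available action
print(result.logits)  # the network's raw output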

learn(batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) → TDQNTrainingStats[source]#

Update policy with a given batch of data.

Returns:

A dataclass object containing the data that needs to be logged (e.g., the loss).

Note

To distinguish between the collecting, updating, and testing states, you can check the policy state via self.training and self.updating. Please refer to States for policy for a more detailed explanation.

Warning

If you use torch.distributions.Normal or torch.distributions.Categorical to calculate the log_prob, be careful about the shapes: a Categorical distribution gives a “[batch_size]” shape, while a Normal distribution gives a “[batch_size, 1]” shape. Automatic broadcasting in torch tensor operations will silently amplify this mismatch.
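
The shape pitfall described in the warning can be checked directly (shapes here are illustrative):

import torch
from torch.distributions import Categorical, Normal

logits = torch.zeros(8, 3)                       # batch_size=8, 3 discrete actions
mu, sigma = torch.zeros(8, 1), torch.ones(8, 1)

cat_lp = Categorical(logits=logits).log_prob(torch.zeros(8, dtype=torch.long))
norm_lp = Normal(mu, sigma).log_prob(torch.zeros(8, 1))

print(cat_lp.shape)   # torch.Size([8])     -> "[batch_size]"
print(norm_lp.shape)  # torch.Size([8, 1])  -> "[batch_size, 1]"
# Combining the two without an explicit reshape silently broadcasts to [8, 8].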

process_fn(batch: RolloutBatchProtocol, buffer: ReplayBuffer, indices: ndarray) → BatchWithReturnsProtocol[source]#

Compute the n-step return for Q-learning targets.

More details can be found at compute_nstep_return().
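
Conceptually, the n-step target is r_t + γ·r_{t+1} + … + γ^{n-1}·r_{t+n-1} + γ^n·max_a Q_target(s_{t+n}, a). A standalone sketch of this formula for a single transition (it ignores episode termination and buffer indexing, which the real implementation handles):

import numpy as np

def nstep_target(rewards: np.ndarray, next_q_max: float, gamma: float, n: int) -> float:
    """Illustrative n-step return: n discounted rewards plus a bootstrapped tail."""
    ret = sum(gamma ** i * rewards[i] for i in range(n))
    return ret + gamma ** n * next_q_max

# Example: a 3-step target with gamma=0.99
print(nstep_target(np.array([1.0, 0.0, 1.0]), next_q_max=5.0, gamma=0.99, n=3))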

set_eps(eps: float) → None[source]#

Set the eps for epsilon-greedy exploration.
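
A typical pattern is to anneal epsilon through the trainer's train_fn/test_fn hooks (the schedule below is purely illustrative; policy refers to an existing DQNPolicy instance such as the one constructed above):

def train_fn(epoch: int, env_step: int) -> None:
    # Linearly anneal epsilon from 1.0 to 0.05 over the first 10,000 environment steps.
    policy.set_eps(max(0.05, 1.0 - env_step / 10_000))

def test_fn(epoch: int, env_step: int | None) -> None:
    policy.set_eps(0.0)  # act greedily during evaluation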

sync_weight() → None[source]#

Synchronize the weights of the target network with those of the online network.
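
Conceptually this is a hard update that copies the online parameters into the target network; a toy sketch of the idea (the networks below are stand-ins, not the policy's internal attributes):

import copy
import torch.nn as nn

online_net = nn.Linear(4, 2)            # stand-in for the online Q-network
target_net = copy.deepcopy(online_net)  # stand-in for the target network

# Hard update: overwrite the target parameters with the online parameters.
target_net.load_state_dict(online_net.state_dict())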

train(mode: bool = True) → Self[source]#

Set the module in training mode, except for the target network.

class DQNTrainingStats(*, train_time: float = 0.0, smoothed_loss: dict = <factory>, loss: float)[source]#

loss: float#
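
A small sketch of how the stats returned by learn() might be consumed for logging (the import path and field values are assumptions):

from tianshou.policy.modelfree.dqn import DQNTrainingStats  # import path assumed

stats = DQNTrainingStats(train_time=0.02, loss=0.37)  # illustrative values
print(stats.loss)        # TD loss of the gradient step
print(stats.train_time)  # wall-clock time spent in the update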