import torch
import numpy as np
from copy import deepcopy
from typing import Any, Dict, Union, Optional
from tianshou.policy import BasePolicy
from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy
class DQNPolicy(BasePolicy):
    """Implementation of Deep Q Network. arXiv:1312.5602.

    Implementation of Double Q-Learning. arXiv:1509.06461.

    Implementation of Dueling DQN. arXiv:1511.06581 (the dueling DQN is
    implemented in the network side, not here).

    :param torch.nn.Module model: a model following the rules in
        :class:`~tianshou.policy.BasePolicy`. (s -> logits)
    :param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
    :param float discount_factor: in [0, 1].
    :param int estimation_step: the number of steps to look ahead; must be
        at least 1.
    :param int target_update_freq: the target network update frequency (0 if
        you do not use the target network).
    :param bool reward_normalization: normalize the reward to Normal(0, 1),
        defaults to False.

    .. seealso::

        Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
        explanation.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        optim: torch.optim.Optimizer,
        discount_factor: float = 0.99,
        estimation_step: int = 1,
        target_update_freq: int = 0,
        reward_normalization: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self.model = model
        self.optim = optim
        # Epsilon for epsilon-greedy exploration; change it via ``set_eps``.
        self.eps = 0.0
        assert (
            0.0 <= discount_factor <= 1.0
        ), "discount factor should be in [0, 1]"
        self._gamma = discount_factor
        assert estimation_step > 0, "estimation_step should be greater than 0"
        self._n_step = estimation_step
        # A positive update frequency enables the target network copy.
        self._target = target_update_freq > 0
        self._freq = target_update_freq
        # Number of ``learn`` calls so far; drives target-network syncing.
        self._iter = 0
        if self._target:
            self.model_old = deepcopy(self.model)
            self.model_old.eval()
        self._rew_norm = reward_normalization

    def set_eps(self, eps: float) -> None:
        """Set the eps for epsilon-greedy exploration."""
        self.eps = eps

    def train(self, mode: bool = True) -> "DQNPolicy":
        """Set the module in training mode, except for the target network.

        Only ``self.model`` is switched; ``model_old`` stays in eval mode.
        """
        self.training = mode
        self.model.train(mode)
        return self

    def sync_weight(self) -> None:
        """Synchronize the weight for the target network."""
        self.model_old.load_state_dict(self.model.state_dict())

    def _target_q(
        self, buffer: ReplayBuffer, indice: np.ndarray
    ) -> torch.Tensor:
        """Return the bootstrap value Q(s_{t+n}, a*) for the n-step target.

        ``a*`` is the greedy action chosen by ``forward`` on ``obs_next``
        with the online network, so any action mask in ``obs_next`` is
        respected. With a target network its value is read from
        ``model_old`` (Double DQN); otherwise from the online network,
        which equals the masked max over actions.

        NOTE(review): assumes this runs while ``self.updating`` is true (as
        during ``process_fn``) so ``forward`` adds no exploration noise to
        ``a*`` — confirm against BasePolicy.update.
        """
        batch = buffer[indice]  # batch.obs_next: s_{t+n}
        with torch.no_grad():
            result = self(batch, input="obs_next")
            if self._target:
                # target_Q = Q_old(s_, argmax(Q_new(s_, *)))
                target_q = self(
                    batch, model="model_old", input="obs_next"
                ).logits
            else:
                # Same as logits.max(dim=1) when unmasked, but never picks
                # an action that the mask marks as unavailable.
                target_q = result.logits
            target_q = target_q[np.arange(len(result.act)), result.act]
        return target_q

    def process_fn(
        self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
    ) -> Batch:
        """Compute the n-step return for Q-learning targets.

        More details can be found at
        :meth:`~tianshou.policy.BasePolicy.compute_nstep_return`.
        """
        batch = self.compute_nstep_return(
            batch, buffer, indice, self._target_q,
            self._gamma, self._n_step, self._rew_norm)
        return batch

    def forward(
        self,
        batch: Batch,
        state: Optional[Union[dict, Batch, np.ndarray]] = None,
        model: str = "model",
        input: str = "obs",
        **kwargs: Any,
    ) -> Batch:
        """Compute action over the given batch data.

        If you need to mask the action, please add a "mask" into batch.obs, for
        example, if we have an environment that has "0/1/2" three actions:
        ::

            batch == Batch(
                obs=Batch(
                    obs="original obs, with batch_size=1 for demonstration",
                    mask=np.array([[False, True, False]]),
                    # action 1 is available
                    # action 0 and 2 are unavailable
                ),
                ...
            )

        Epsilon-greedy exploration uses ``self.eps`` (see :meth:`set_eps`)
        and is skipped while ``self.updating`` is true.

        :param str model: name of the network attribute to run, e.g.
            ``"model"`` or ``"model_old"`` (the target network).
        :param str input: key of ``batch`` holding the observation, e.g.
            ``"obs"`` or ``"obs_next"``.

        :return: A :class:`~tianshou.data.Batch` which has 3 keys:

            * ``act`` the action.
            * ``logits`` the network's raw output.
            * ``state`` the hidden state.

        .. seealso::

            Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
            more detailed explanation.
        """
        model = getattr(self, model)
        obs = batch[input]
        obs_ = obs.obs if hasattr(obs, "obs") else obs
        q, h = model(obs_, state=state, info=batch.info)
        act: np.ndarray = to_numpy(q.max(dim=1)[1])
        if hasattr(obs, "mask"):
            # some of actions are masked, they cannot be selected.
            # Copy first: ``to_numpy`` may return a view sharing memory with
            # ``q`` (e.g. for CPU tensors), and writing -inf into it would
            # overwrite the raw logits returned below.
            q_: np.ndarray = to_numpy(q).copy()
            q_[~obs.mask] = -np.inf
            act = q_.argmax(axis=1)
        # add eps to act in training or testing phase
        if not self.updating and not np.isclose(self.eps, 0.0):
            for i in range(len(q)):
                if np.random.rand() < self.eps:
                    rand_q = np.random.rand(*q[i].shape)
                    if hasattr(obs, "mask"):
                        rand_q[~obs.mask[i]] = -np.inf
                    act[i] = rand_q.argmax()
        return Batch(logits=q, act=act, state=h)

    def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
        """Run one gradient step on the (weighted) squared TD error.

        ``batch.returns`` is the regression target for Q(s, a) (presumably
        written by ``compute_nstep_return`` in :meth:`process_fn`). An
        optional ``batch.weight`` scales each sample's loss; it is then
        replaced by the TD error for prioritized-replay priority updates.

        :return: a dict with the scalar training ``"loss"``.
        """
        # Sync on the very first call and then every ``_freq`` calls.
        if self._target and self._iter % self._freq == 0:
            self.sync_weight()
        self.optim.zero_grad()
        weight = batch.pop("weight", 1.0)
        q = self(batch).logits
        q = q[np.arange(len(q)), batch.act]
        r = to_torch_as(batch.returns.flatten(), q)
        td = r - q
        loss = (td.pow(2) * weight).mean()
        batch.weight = td  # prio-buffer
        loss.backward()
        self.optim.step()
        self._iter += 1
        return {"loss": loss.item()}