Deep Q Network

Deep reinforcement learning has achieved significant successes in various applications, and Deep Q Network (DQN) [MKS+15] was one of the pioneering algorithms. In this tutorial, we show step by step how to train a DQN agent on CartPole with Tianshou. The full script is at test/discrete/

In contrast to existing deep RL libraries such as RLlib, which accept only a config specification of hyperparameters, networks, and other settings, Tianshou lets you construct everything directly in code.

Make an Environment

First of all, you have to make an environment for your agent to interact with. For environment interfaces, we follow the convention of OpenAI Gym. In your Python code, simply import Tianshou and make the environment:

import gym
import tianshou as ts

env = gym.make('CartPole-v0')

CartPole-v0 is a simple environment with a discrete action space, to which DQN applies. You have to identify whether the action space is continuous or discrete and choose an eligible algorithm. DDPG [LHP+16], for example, can only be applied to continuous action spaces, while almost all other policy gradient methods can be applied to both, depending on the probability distribution over actions.

Setup Multi-environment Wrapper

If you want to use the original gym.Env:

train_envs = gym.make('CartPole-v0')
test_envs = gym.make('CartPole-v0')

Tianshou supports parallel sampling for all algorithms. It provides four types of vectorized environment wrapper: DummyVectorEnv, SubprocVectorEnv, ShmemVectorEnv, and RayVectorEnv. They can be used as follows (more explanation can be found at Parallel Sampling):

train_envs = ts.env.DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(8)])
test_envs = ts.env.DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(100)])

Here, we set up 8 environments in train_envs and 100 environments in test_envs.

For this demonstration, we use the second code block, i.e. the vectorized environments.
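Conceptually, a vectorized wrapper holds several environment instances, steps them together, and returns batched results. The following is a minimal sketch of that idea, not Tianshou's implementation; ToyEnv and ToyVectorEnv are hypothetical names introduced only for illustration.

```python
import numpy as np


class ToyEnv:
    """Hypothetical stand-in for a gym-style environment (illustration only)."""

    def __init__(self, horizon=3):
        self.horizon = horizon
        self.t = 0

    def reset(self):
        self.t = 0
        return np.zeros(2, dtype=np.float32)

    def step(self, action):
        self.t += 1
        obs = np.full(2, self.t, dtype=np.float32)
        done = self.t >= self.horizon
        return obs, 1.0, done, {}


class ToyVectorEnv:
    """Steps a list of environments one after another and stacks the results."""

    def __init__(self, env_fns):
        # Each element is a zero-argument factory, mirroring the
        # `lambda: gym.make(...)` pattern used in this tutorial.
        self.envs = [fn() for fn in env_fns]

    def __len__(self):
        return len(self.envs)

    def reset(self):
        return np.stack([env.reset() for env in self.envs])

    def step(self, actions):
        results = [env.step(a) for env, a in zip(self.envs, actions)]
        obs, rew, done, info = zip(*results)
        return np.stack(obs), np.array(rew), np.array(done), list(info)


venv = ToyVectorEnv([lambda: ToyEnv() for _ in range(8)])
batch_obs = venv.reset()                      # one row per environment
obs, rew, done, info = venv.step([0] * len(venv))
```

Stepping the environments sequentially in one process is the simplest design; the subprocess-based wrappers trade extra communication overhead for true parallelism.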


If you use your own environment, please make sure the seed method is set up properly, e.g.,

def seed(self, seed):
    np.random.seed(seed)

Otherwise, the outputs of these envs may be identical to each other.
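As one possible pattern, an environment can keep a private random generator and re-create it in seed, so that every copy inside a vectorized wrapper can be seeded differently. NoisyEnv below is a hypothetical environment used only to illustrate this; it assumes the classic Gym convention of a seed(seed) method.

```python
import numpy as np


class NoisyEnv:
    """Hypothetical environment whose initial observation is random."""

    def __init__(self):
        self.rng = np.random.RandomState()

    def seed(self, seed=None):
        # Re-create the private generator so this instance is reproducible
        # and independent of the other instances.
        self.rng = np.random.RandomState(seed)
        return [seed]

    def reset(self):
        return self.rng.uniform(-0.05, 0.05, size=4)


envs = [NoisyEnv() for _ in range(3)]
for i, env in enumerate(envs):
    env.seed(100 + i)                 # a different seed for every copy
first_obs = [env.reset() for env in envs]

# Re-seeding with the same value reproduces the same first observation.
envs[0].seed(100)
repeat_obs = envs[0].reset()
```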

Build the Network

Tianshou supports any user-defined PyTorch networks and optimizers. Yet, of course, the inputs and outputs must comply with Tianshou’s API. Here is an example:

import torch, numpy as np
from torch import nn

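class Net(nn.Module):
    def __init__(self, state_shape, action_shape):
        super().__init__()
        # MLP mapping a flattened observation to one output per action
        self.model = nn.Sequential(
            nn.Linear(int(np.prod(state_shape)), 128), nn.ReLU(inplace=True),
            nn.Linear(128, 128), nn.ReLU(inplace=True),
            nn.Linear(128, 128), nn.ReLU(inplace=True),
            nn.Linear(128, int(np.prod(action_shape))),
        )
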
    def forward(self, obs, state=None, info={}):
        if not isinstance(obs, torch.Tensor):
            obs = torch.tensor(obs, dtype=torch.float)
        batch = obs.shape[0]
        logits = self.model(obs.view(batch, -1))
        return logits, state

state_shape = env.observation_space.shape or env.observation_space.n
action_shape = env.action_space.shape or env.action_space.n
net = Net(state_shape, action_shape)
optim = torch.optim.Adam(net.parameters(), lr=1e-3)

It is also possible to use the pre-defined MLP networks in tianshou.utils.net (the common, discrete, and continuous modules). The rules for self-defined networks are:

  1. Input: observation obs (may be a numpy.ndarray, torch.Tensor, dict, or self-defined class), hidden state state (for RNN usage), and other information info provided by the environment.

  2. Output: some logits and the next hidden state state. The logits could be a tuple instead of a torch.Tensor, or some other useful variables or results produced during the policy forwarding procedure. It depends on how the policy class processes the network output. For example, in PPO [SWD+17], the return of the network might be (mu, sigma), state for a Gaussian policy.


Here, logits refers to the raw output of the network. In supervised learning, the raw output of a prediction/classification model is called logits, and here we extend this definition to any raw output of the neural network.
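To make the calling convention concrete without depending on a particular model, here is a framework-free sketch: a callable that takes (obs, state, info) and returns (logits, state), where logits is a tuple (mu, sigma) as in the Gaussian-policy example. GaussianHead is a hypothetical NumPy stand-in, not a Tianshou or PyTorch class; a real network would subclass torch.nn.Module.

```python
import numpy as np


class GaussianHead:
    """Hypothetical NumPy stand-in that follows the (logits, state) contract."""

    def __init__(self, obs_dim, act_dim, seed=0):
        rng = np.random.RandomState(seed)
        self.w_mu = rng.normal(scale=0.1, size=(obs_dim, act_dim))
        self.log_sigma = np.zeros(act_dim)

    def __call__(self, obs, state=None, info=None):
        obs = np.asarray(obs, dtype=np.float32)
        batch = obs.shape[0]
        flat = obs.reshape(batch, -1)      # flatten all but the batch axis
        mu = flat @ self.w_mu
        sigma = np.broadcast_to(np.exp(self.log_sigma), mu.shape)
        # "logits" is a tuple here; `state` is passed through unchanged
        # because this toy model has no recurrent hidden state.
        return (mu, sigma), state


head = GaussianHead(obs_dim=4, act_dim=2)
(mu, sigma), next_state = head(np.zeros((5, 4)))
```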

Setup Policy

We use the defined net and optim above, with extra policy hyper-parameters, to define a policy. Here we define a DQN policy with a target network:

policy = ts.policy.DQNPolicy(net, optim, discount_factor=0.9, estimation_step=3, target_update_freq=320)
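With estimation_step=3 and discount_factor=0.9, the regression target for a Q-value is built from three observed rewards plus a discounted bootstrap value from the target network. The sketch below computes this standard n-step target for a single transition; n_step_target is a hypothetical helper and does not mirror Tianshou's internal code, which also has to handle batches and episode boundaries.

```python
import numpy as np


def n_step_target(rewards, bootstrap_q, gamma, done=False):
    """Return sum_k gamma^k * r_k + gamma^n * max_a Q_target(s_n, a)."""
    rewards = np.asarray(rewards, dtype=np.float64)
    n = len(rewards)
    discounts = gamma ** np.arange(n)
    target = float(np.sum(discounts * rewards))
    if not done:
        # Bootstrap from the target network only if the episode continues.
        target += gamma ** n * float(np.max(bootstrap_q))
    return target


target = n_step_target(rewards=[1.0, 1.0, 1.0],
                       bootstrap_q=[8.0, 10.0],
                       gamma=0.9)
# 1 + 0.9 + 0.81 + 0.729 * 10 = 10.0 (up to floating-point rounding)
```

A larger estimation step propagates rewards faster but relies more heavily on data gathered by older versions of the policy.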

Setup Collector

The collector is a key concept in Tianshou. It allows the policy to interact with different types of environments conveniently. In each step, the collector will let the policy perform (at least) a specified number of steps or episodes and store the data in a replay buffer.

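# the training collector stores transitions in a replay buffer
# (the buffer size here is an assumed, illustrative value)
train_collector = ts.data.Collector(policy, train_envs, ts.data.ReplayBuffer(size=20000))
test_collector = ts.data.Collector(policy, test_envs)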
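A replay buffer behaves like a fixed-size ring: once it is full, new transitions overwrite the oldest ones, and updates draw random minibatches from it. A minimal sketch of that behaviour follows; RingBuffer is a hypothetical class for illustration and is far simpler than Tianshou's own buffer.

```python
import random


class RingBuffer:
    """Hypothetical fixed-size buffer that overwrites its oldest entries."""

    def __init__(self, size):
        self.size = size
        self.data = []
        self.index = 0          # position of the next write

    def add(self, transition):
        if len(self.data) < self.size:
            self.data.append(transition)
        else:
            self.data[self.index] = transition
        self.index = (self.index + 1) % self.size

    def sample(self, batch_size, rng=random):
        # Uniform sampling with replacement.
        return [rng.choice(self.data) for _ in range(batch_size)]

    def __len__(self):
        return len(self.data)


buf = RingBuffer(size=3)
for step in range(5):
    buf.add({"obs": step, "act": 0, "rew": 1.0, "done": False})
stored = sorted(t["obs"] for t in buf.data)   # steps 0 and 1 were overwritten
batch = buf.sample(2, rng=random.Random(0))
```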

Train Policy with a Trainer

Tianshou provides onpolicy_trainer(), offpolicy_trainer(), and offline_trainer(). The trainer automatically stops training when the policy reaches the stop condition stop_fn on the test collector. Since DQN is an off-policy algorithm, we use offpolicy_trainer() as follows:

result = ts.trainer.offpolicy_trainer(
    policy, train_collector, test_collector,
    max_epoch=10, step_per_epoch=1000, collect_per_step=10,
    episode_per_test=100, batch_size=64,
    train_fn=lambda epoch, env_step: policy.set_eps(0.1),
    test_fn=lambda epoch, env_step: policy.set_eps(0.05),
    stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold,
    writer=None)
print(f'Finished training! Use {result["duration"]}')

The meaning of each parameter is as follows (full description can be found at offpolicy_trainer()):

  • max_epoch: The maximum number of epochs for training. The training process might finish before reaching max_epoch;

  • step_per_epoch: The number of policy network update steps in one epoch;

  • collect_per_step: The number of frames the collector collects before each network update. For example, the code above means “collect 10 frames and do one policy network update”;

  • episode_per_test: The number of episodes for one policy evaluation;

  • batch_size: The batch size of the sampled data that is fed into the policy network;

  • train_fn: A function that receives the current epoch number and step index and performs some operations at the beginning of training in this epoch. For example, the code above means “reset the epsilon to 0.1 in DQN before training”;

  • test_fn: A function that receives the current epoch number and step index and performs some operations at the beginning of testing in this epoch. For example, the code above means “reset the epsilon to 0.05 in DQN before testing”;

  • stop_fn: A function that receives the average undiscounted return of the test result and returns a boolean indicating whether the goal has been reached;

  • writer: See below.
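How these parameters interact can be illustrated with a toy loop that uses stub callables in place of the collector, the policy update, and the test run. This is a deliberately simplified sketch of the control flow, not the code of offpolicy_trainer(); toy_offpolicy_loop and the fake test rewards are hypothetical, and 195.0 is the reward threshold of CartPole-v0.

```python
def toy_offpolicy_loop(collect, update, evaluate, stop_fn,
                       max_epoch, step_per_epoch, collect_per_step):
    """Simplified control flow only; not Tianshou's offpolicy_trainer."""
    frames = updates = 0
    for epoch in range(1, max_epoch + 1):
        for _ in range(step_per_epoch):
            frames += collect(collect_per_step)   # gather new transitions
            update()                              # one policy network update
            updates += 1
        mean_reward = evaluate()                  # test after every epoch
        if stop_fn(mean_reward):
            return {"epoch": epoch, "frames": frames,
                    "updates": updates, "stopped_early": True}
    return {"epoch": max_epoch, "frames": frames,
            "updates": updates, "stopped_early": False}


# Stub callables standing in for the collector, policy, and test run.
fake_test_rewards = iter([50.0, 120.0, 199.0])
summary = toy_offpolicy_loop(
    collect=lambda n: n,
    update=lambda: None,
    evaluate=lambda: next(fake_test_rewards),
    stop_fn=lambda mean_rewards: mean_rewards >= 195.0,
    max_epoch=10, step_per_epoch=1000, collect_per_step=10,
)
```

With these settings each epoch performs 1000 updates on 10000 newly collected frames, and the loop ends in the third epoch because the stubbed test reward reaches the threshold.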

The trainer supports TensorBoard for logging. It can be used as:

from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter('log/dqn')

Pass the writer into the trainer, and the training result will be recorded into the TensorBoard.

The returned result is a dictionary as follows:

{
    'train_step': 9246,
    'train_episode': 504.0,
    'train_time/collector': '0.65s',
    'train_time/model': '1.97s',
    'train_speed': '3518.79 step/s',
    'test_step': 49112,
    'test_episode': 400.0,
    'test_time': '1.38s',
    'test_speed': '35600.52 step/s',
    'best_reward': 199.03,
    'duration': '4.01s'
}

It shows that within approximately 4 seconds, we finished training a DQN agent on CartPole. The mean return over 100 consecutive episodes is 199.03.

Save/Load Policy

Since the policy inherits from torch.nn.Module, saving and loading the policy work exactly as for any torch module:

torch.save(policy.state_dict(), 'dqn.pth')
policy.load_state_dict(torch.load('dqn.pth'))

Watch the Agent’s Performance

Collector supports rendering. Here is an example of watching the agent’s performance at 35 FPS:

collector = ts.data.Collector(policy, env)
collector.collect(n_episode=1, render=1 / 35)

Train a Policy with Customized Codes

“I don’t want to use your provided trainer. I want to customize it!”

Tianshou supports user-defined training code. Here is the code snippet:

# pre-collect at least 5000 frames with random action before training
policy.set_eps(1)
train_collector.collect(n_step=5000)

policy.set_eps(0.1)
for i in range(int(1e6)):  # total step
    collect_result = train_collector.collect(n_step=10)

    # once the collected episodes' mean return reaches the threshold,
    # or every 1000 steps, we test it on test_collector
    if collect_result['rew'] >= env.spec.reward_threshold or i % 1000 == 0:
        policy.set_eps(0.05)
        result = test_collector.collect(n_episode=100)
        if result['rew'] >= env.spec.reward_threshold:
            print(f'Finished training! Test mean returns: {result["rew"]}')
            break
        else:
            # back to training eps
            policy.set_eps(0.1)

    # train policy with a sampled batch data from buffer
    losses = policy.update(64, train_collector.buffer)
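The epsilon value that this tutorial sets to 0.1 for training and 0.05 for testing governs epsilon-greedy action selection: with probability epsilon a random action replaces the greedy one. The following sketch shows the idea on a batch of Q-values; epsilon_greedy is a hypothetical helper written for illustration and is not Tianshou's implementation.

```python
import numpy as np


def epsilon_greedy(q_values, eps, rng):
    """Pick argmax-Q actions, replacing each with a random one w.p. eps."""
    q_values = np.asarray(q_values)
    batch, n_actions = q_values.shape
    greedy = q_values.argmax(axis=1)
    explore = rng.random_sample(batch) < eps
    random_actions = rng.randint(0, n_actions, size=batch)
    return np.where(explore, random_actions, greedy)


rng = np.random.RandomState(0)
q = np.array([[0.2, 0.8], [0.9, 0.1], [0.4, 0.6]])
greedy_actions = epsilon_greedy(q, eps=0.0, rng=rng)   # purely greedy
random_actions = epsilon_greedy(q, eps=1.0, rng=rng)   # purely random
```

Setting epsilon to 1 during the pre-collection phase corresponds to the purely random case, which fills the buffer with diverse transitions before learning starts.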

For further usage, you can refer to the Cheat Sheet.



[MKS+15] Volodymyr Mnih, Koray Kavukcuoglu, David Silver, Andrei A. Rusu, Joel Veness, Marc G. Bellemare, Alex Graves, Martin A. Riedmiller, Andreas Fidjeland, Georg Ostrovski, Stig Petersen, Charles Beattie, Amir Sadik, Ioannis Antonoglou, Helen King, Dharshan Kumaran, Daan Wierstra, Shane Legg, and Demis Hassabis. Human-level control through deep reinforcement learning. Nature, 518(7540):529–533, 2015. doi:10.1038/nature14236.


[LHP+16] Timothy P. Lillicrap, Jonathan J. Hunt, Alexander Pritzel, Nicolas Heess, Tom Erez, Yuval Tassa, David Silver, and Daan Wierstra. Continuous control with deep reinforcement learning. In 4th International Conference on Learning Representations, ICLR 2016, San Juan, Puerto Rico, May 2-4, 2016, Conference Track Proceedings. 2016.


[SWD+17] John Schulman, Filip Wolski, Prafulla Dhariwal, Alec Radford, and Oleg Klimov. Proximal policy optimization algorithms. CoRR, 2017. arXiv:1707.06347.