Cheat Sheet

This page collects code snippets that show how to use Tianshou to develop new algorithms or to apply existing algorithms to new scenarios.

Some of these issues can also be resolved with a gym.Wrapper, which is a general way to customize the policy-environment interaction. Alternatively, you can use the batch processor described in Handle Batched Data Stream in Collector.

Build Policy Network

See Build the Network.

Build New Policy

See BasePolicy.

Customize Training Process

See Train a Policy with Customized Codes.

Parallel Sampling

Use VectorEnv or SubprocVectorEnv.

env_fns = [
    lambda: MyTestEnv(size=2),
    lambda: MyTestEnv(size=3),
    lambda: MyTestEnv(size=4),
    lambda: MyTestEnv(size=5),
]
venv = SubprocVectorEnv(env_fns)

where env_fns is a list of callables, each of which returns a new environment. The same list can also be built with a list comprehension:

env_fns = [lambda x=i: MyTestEnv(size=x) for i in [2, 3, 4, 5]]
venv = SubprocVectorEnv(env_fns)
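The default argument x=i in the snippet above matters because a Python lambda looks up free variables when it is called, not when it is created. The following sketch shows the difference; the MyTestEnv class here is a hypothetical stand-in that only records its size.

```python
class MyTestEnv:
    """Hypothetical stand-in for the MyTestEnv used above."""

    def __init__(self, size):
        self.size = size


sizes = [2, 3, 4, 5]

# Late binding: every lambda looks up i when it is called, after the loop ended.
late = [lambda: MyTestEnv(size=i) for i in sizes]

# Default argument: x is evaluated once, when each lambda is created.
bound = [lambda x=i: MyTestEnv(size=x) for i in sizes]

late_sizes = [fn().size for fn in late]
bound_sizes = [fn().size for fn in bound]
print(late_sizes)   # [5, 5, 5, 5]
print(bound_sizes)  # [2, 3, 4, 5]
```

Without the default argument, all four workers would create an environment of size 5.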

Handle Batched Data Stream in Collector

This is related to Issue 42.

If you want to collect log statistics from the data stream, pre-process image batches, or modify the reward using the given env info, use preprocess_fn in Collector. This is a hook that is called before the data is added to the buffer.

This function typically receives up to 7 keys, as listed in Batch, and returns the modified part as a dict or a Batch. For example, you can write your hook as:

import numpy as np
from collections import deque
from tianshou.data import Batch

class MyProcessor:
    def __init__(self, size=100):
        self.episode_log = None  # per-env rewards of the running episodes
        self.main_log = deque(maxlen=size)  # mean rewards of recent episodes
        self.main_log.append(0)
        self.baseline = 0

    def preprocess_fn(self, **kwargs):
        """Change reward to zero mean."""
        if 'rew' not in kwargs:
            # called right after env.reset(), so only obs can be processed
            return {}  # none of the variables need to be updated
        else:
            n = len(kwargs['rew'])  # the number of envs in collector
            if self.episode_log is None:
                self.episode_log = [[] for i in range(n)]
            for i in range(n):
                self.episode_log[i].append(kwargs['rew'][i])
                kwargs['rew'][i] -= self.baseline
            for i in range(n):
                if kwargs['done'][i]:  # check the done flag of env i only
                    self.main_log.append(np.mean(self.episode_log[i]))
                    self.episode_log[i] = []
                    self.baseline = np.mean(self.main_log)
            return Batch(rew=kwargs['rew'])
            # you can also return {'rew': kwargs['rew']}

Finally, pass the hook to the Collector:

test_processor = MyProcessor(size=100)
collector = Collector(policy, env, buffer, test_processor.preprocess_fn)

Some examples are in test/base/test_collector.py.
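A stateless hook can be even shorter. The following is a minimal sketch, not taken from the Tianshou code base: clip_reward_fn is a hypothetical hook that clips rewards to [-1, 1] and returns a plain dict. The two calls at the bottom only simulate the call pattern described above (obs only after a reset, all keys after a step); they do not run a real Collector.

```python
import numpy as np


def clip_reward_fn(**kwargs):
    """Hypothetical hook: clip rewards to [-1, 1] before they are stored."""
    if 'rew' not in kwargs:
        # called right after env.reset(): only obs is available
        return {}
    # return only the part that changed
    return {'rew': np.clip(kwargs['rew'], -1.0, 1.0)}


# Simulated calls for two parallel environments (assumed shapes).
after_reset = clip_reward_fn(obs=np.zeros((2, 4)))
after_step = clip_reward_fn(
    obs=np.zeros((2, 4)),
    act=np.array([0, 1]),
    rew=np.array([5.0, -0.3]),
    done=np.array([False, True]),
    info=[{}, {}],
)
print(after_reset)                  # {}
print(after_step['rew'].tolist())   # [1.0, -0.3]
```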

RNN-style Training

This is related to Issue 19.

First, add an argument stack_num to ReplayBuffer:

buf = ReplayBuffer(size=size, stack_num=stack_num)
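Conceptually, stack_num makes the buffer return the stack_num most recent observations for each sampled index instead of a single one. The sketch below illustrates that idea in plain Python. It is not Tianshou's implementation, and it assumes that a stack never reaches back across an episode boundary, repeating the first frame of the episode instead. The function name stacked_obs is hypothetical.

```python
def stacked_obs(obs, done, index, stack_num):
    """Return the stack_num most recent observations ending at index.

    Assumption: once the previous step is marked done, the earliest frame
    of the current episode is repeated instead of crossing the boundary.
    """
    frames = []
    i = index
    for _ in range(stack_num):
        frames.append(obs[i])
        # step back only if a previous step exists and did not end an episode
        if i > 0 and not done[i - 1]:
            i -= 1
    return frames[::-1]  # oldest frame first


# Two episodes stored back to back: a0..a2 (ends at index 2), then b0, b1.
obs = ['a0', 'a1', 'a2', 'b0', 'b1']
done = [False, False, True, False, False]

print(stacked_obs(obs, done, 2, 3))  # ['a0', 'a1', 'a2']
print(stacked_obs(obs, done, 4, 3))  # ['b0', 'b0', 'b1']
```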

Then, change the network to a recurrent style, for example the Recurrent class in code snippet 1, or RecurrentActor and RecurrentCritic in code snippet 2.

The code above supports only stacked observations. If you want to use stacked actions (for Q(stacked-s, stacked-a)), stacked rewards, or other stacked variables, you can add a gym.Wrapper to modify the state representation. For example, if we add a wrapper that maps each [s, a] pair to a new state:

  • Before: (s, a, s’, r, d) stored in replay buffer, and get stacked s;

  • After applying wrapper: ([s, a], a, [s’, a’], r, d) stored in replay buffer, and get both stacked s and a.
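One way to build such a wrapper is sketched below. To stay free of dependencies it uses plain classes that mimic the reset/step interface; a real implementation would subclass gym.Wrapper and also update observation_space. ToyEnv and StateActionWrapper are hypothetical names, and pairing each observation with the most recently executed action (a fixed initial action after reset) is an assumption about how the [s, a] pair is formed.

```python
class ToyEnv:
    """Hypothetical 1-D environment with the classic reset/step interface."""

    def reset(self):
        self.pos = 0.0
        return [self.pos]

    def step(self, action):
        self.pos += action
        done = abs(self.pos) >= 3
        return [self.pos], -abs(self.pos), done, {}


class StateActionWrapper:
    """Append the most recently executed action to every observation."""

    def __init__(self, env, initial_action=0.0):
        self.env = env
        self.initial_action = initial_action

    def reset(self):
        s = self.env.reset()
        # no action has been taken yet, so use the placeholder action
        return list(s) + [self.initial_action]

    def step(self, action):
        s, r, d, info = self.env.step(action)
        return list(s) + [action], r, d, info


env = StateActionWrapper(ToyEnv())
first = env.reset()
second, rew, done, info = env.step(1.0)
print(first)   # [0.0, 0.0]
print(second)  # [1.0, 1.0]
```

Because the action is now part of the state, the buffer's observation stacking covers both s and a without further changes.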

User-defined Environment and Different State Representation

This is related to Issue 38 and Issue 69.

First of all, your self-defined environment must follow the Gym API. Some of its members are listed below:

  • reset() -> state

  • step(action) -> state, reward, done, info

  • seed(s) -> None

  • render(mode) -> None

  • close() -> None

  • observation_space

  • action_space
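The members listed above can be put together as in the following skeleton. This is only a sketch: GoalEnv1D is a hypothetical environment, the two space attributes are left as None placeholders where a real environment would assign gym.spaces objects, and plain lists stand in for numpy arrays.

```python
import random


class GoalEnv1D:
    """Hypothetical 1-D environment that exposes the members listed above."""

    # A real environment would assign gym.spaces objects here; None keeps
    # this sketch free of third-party dependencies.
    observation_space = None
    action_space = None

    def __init__(self, goal=3, max_steps=10):
        self.goal = goal
        self.max_steps = max_steps
        self.rng = random.Random()
        self.pos = 0
        self.steps = 0

    def seed(self, s=None):
        self.rng = random.Random(s)

    def _state(self):
        # a dictionary state; a flat array-like state works the same way
        return {'observation': [self.pos], 'desired_goal': [self.goal]}

    def reset(self):
        self.pos = self.rng.randint(-2, 2)
        self.steps = 0
        return self._state()

    def step(self, action):
        self.pos += action  # action is -1 or +1
        self.steps += 1
        reached = self.pos == self.goal
        done = reached or self.steps >= self.max_steps
        reward = 0.0 if reached else -1.0
        return self._state(), reward, done, {}

    def render(self, mode='human'):
        print(f'pos={self.pos} goal={self.goal}')

    def close(self):
        pass


env = GoalEnv1D()
env.seed(0)
state = env.reset()
state, reward, done, info = env.step(1)
env.close()
```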

The state can be a numpy.ndarray or a Python dictionary. Take FetchReach-v1 as an example:

>>> import gym
>>> e = gym.make('FetchReach-v1')
>>> e.reset()
{'observation': array([ 1.34183265e+00,  7.49100387e-01,  5.34722720e-01,  1.97805133e-04,
         7.15193042e-05,  7.73933014e-06,  5.51992816e-08, -2.42927453e-06,
         4.73325650e-06, -2.28455228e-06]),
 'achieved_goal': array([1.34183265, 0.74910039, 0.53472272]),
 'desired_goal': array([1.24073906, 0.77753463, 0.63457791])}

This shows that the state is a dictionary with 3 keys. It will be stored in ReplayBuffer as:

>>> from tianshou.data import ReplayBuffer
>>> b = ReplayBuffer(size=3)
>>> b.add(obs=e.reset(), act=0, rew=0, done=0)
>>> print(b)
ReplayBuffer(
    act: array([0, 0, 0]),
    done: array([0, 0, 0]),
    info: Batch(),
    obs: Batch(
             achieved_goal: array([[1.34183265, 0.74910039, 0.53472272],
                                   [0.        , 0.        , 0.        ],
                                   [0.        , 0.        , 0.        ]]),
             desired_goal: array([[1.42154265, 0.62505137, 0.62929863],
                                  [0.        , 0.        , 0.        ],
                                  [0.        , 0.        , 0.        ]]),
             observation: array([[ 1.34183265e+00,  7.49100387e-01,  5.34722720e-01,
                                   1.97805133e-04,  7.15193042e-05,  7.73933014e-06,
                                   5.51992816e-08, -2.42927453e-06,  4.73325650e-06,
                                  -2.28455228e-06],
                                 [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
                                   0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
                                   0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
                                   0.00000000e+00],
                                 [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
                                   0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
                                   0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
                                   0.00000000e+00]]),
         ),
    policy: Batch(),
    rew: array([0, 0, 0]),
)
>>> print(b.obs.achieved_goal)
[[1.34183265 0.74910039 0.53472272]
 [0.         0.         0.        ]
 [0.         0.         0.        ]]

A data batch sampled from this replay buffer looks like this:

>>> batch, indice = b.sample(2)
>>> batch.keys()
['act', 'done', 'info', 'obs', 'obs_next', 'policy', 'rew']
>>> batch.obs[-1]
Batch(
    achieved_goal: array([1.34183265, 0.74910039, 0.53472272]),
    desired_goal: array([1.42154265, 0.62505137, 0.62929863]),
    observation: array([ 1.34183265e+00,  7.49100387e-01,  5.34722720e-01,  1.97805133e-04,
                         7.15193042e-05,  7.73933014e-06,  5.51992816e-08, -2.42927453e-06,
                         4.73325650e-06, -2.28455228e-06]),
)
>>> batch.obs.desired_goal[-1]  # recommended
array([1.42154265, 0.62505137, 0.62929863])
>>> batch.obs[-1].desired_goal  # not recommended
array([1.42154265, 0.62505137, 0.62929863])
>>> batch[-1].obs.desired_goal  # not recommended
array([1.42154265, 0.62505137, 0.62929863])

Thus, in your self-defined network, change the forward function as follows:

def forward(self, s, ...):
    # s is a batch
    observation = s.observation
    achieved_goal = s.achieved_goal
    desired_goal = s.desired_goal
    ...

For an instance of a self-defined class, the replay buffer stores a reference to the object in a numpy.ndarray with dtype=object, e.g.:

>>> import networkx as nx
>>> b = ReplayBuffer(size=3)
>>> b.add(obs=nx.Graph(), act=0, rew=0, done=0)
>>> print(b)
ReplayBuffer(
    act: array([0, 0, 0]),
    done: array([0, 0, 0]),
    info: Batch(),
    obs: array([<networkx.classes.graph.Graph object at 0x7f5c607826a0>, None,
                None], dtype=object),
    policy: Batch(),
    rew: array([0, 0, 0]),
)

However, the state stored in the buffer may be only a shallow copy, so several entries can point to the same object. To make sure that each state stored in the buffer is distinct, return a deep copy of the state in your env:

import copy

def reset(self):
    return copy.deepcopy(self.graph)

def step(self, a):
    ...
    return copy.deepcopy(self.graph), reward, done, {}
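The difference between returning a reference and returning a deep copy can be seen in the following sketch. GraphEnv is a hypothetical environment whose state is a mutable dict, and a plain list stands in for the replay buffer's object array.

```python
import copy


class GraphEnv:
    """Hypothetical env whose state is a mutable object (a dict here)."""

    def __init__(self, deep):
        self.graph = {'edges': []}
        self.deep = deep

    def step(self, a):
        self.graph['edges'].append(a)  # the state is mutated in place
        state = copy.deepcopy(self.graph) if self.deep else self.graph
        return state, 0.0, False, {}


def collect(env, n=3):
    # a plain list stands in for the replay buffer's object array
    return [env.step(a)[0] for a in range(n)]


aliased = collect(GraphEnv(deep=False))
distinct = collect(GraphEnv(deep=True))
print([len(s['edges']) for s in aliased])   # [3, 3, 3]
print([len(s['edges']) for s in distinct])  # [1, 2, 3]
```

Without the deep copy, every stored entry refers to the same object and therefore shows only the latest state.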