Basic concepts in Tianshou¶
Tianshou splits a Reinforcement Learning agent training procedure into these parts: trainer, collector, policy, and data buffer. The general control flow can be described as:
Here is a more detailed description, where Env
is the environment and Model
is the neural network:
Data Batch¶
Tianshou provides Batch
as the internal data
structure to pass any kind of data to other methods, for example, a
collector gives a Batch
to policy for learning.
Here is the usage:
>>> import numpy as np
>>> from tianshou.data import Batch
>>> data = Batch(a=4, b=[5, 5], c='2312312')
>>> # the list will automatically be converted to numpy array
>>> data.b
array([5, 5])
>>> data.b = np.array([3, 4, 5])
>>> print(data)
Batch(
a: 4,
b: array([3, 4, 5]),
c: '2312312',
)
In short, you can define a Batch
with any key-value pair.
For Numpy arrays, only data types with np.object
, bool, and number
are supported. For strings or other data types, however, they can be
held in np.object
arrays.
The current implementation of Tianshou typically use 7 reserved keys in
Batch
:
obs
the observation of step ;act
the action of step ;rew
the reward of step ;done
the done flag of step ;obs_next
the observation of step ;info
the info of step (ingym.Env
, theenv.step()
function returns 4 arguments, and the last one isinfo
);policy
the data computed by policy in step ;
Batch
object can be initialized by a wide variety
of arguments, ranging from the key/value pairs or dictionary, to list and
Numpy arrays of dict
or Batch instances where each element is
considered as an individual sample and get stacked together:
>>> data = Batch([{'a': {'b': [0.0, "info"]}}])
>>> print(data[0])
Batch(
a: Batch(
b: array([0.0, 'info'], dtype=object),
),
)
Batch
has the same API as a native Python
dict
. In this regard, one can access stored data using string
key, or iterate over stored data:
>>> data = Batch(a=4, b=[5, 5])
>>> print(data["a"])
4
>>> for key, value in data.items():
>>> print(f"{key}: {value}")
a: 4
b: [5, 5]
Batch
also partially reproduces the Numpy API for
arrays. It also supports the advanced slicing method, such as batch[:, i],
if the index is valid. You can access or iterate over the individual
samples, if any:
>>> data = Batch(a=np.array([[0.0, 2.0], [1.0, 3.0]]), b=[[5, -5]])
>>> print(data[0])
Batch(
a: array([0., 2.])
b: array([ 5, -5]),
)
>>> for sample in data:
>>> print(sample.a)
[0., 2.]
>>> print(data.shape)
[1, 2]
>>> data[:, 1] += 1
>>> print(data)
Batch(
a: array([[0., 3.],
[1., 4.]]),
b: array([[ 5, -4]]),
)
Similarly, one can also perform simple algebra on it, and stack, split or concatenate multiple instances:
>>> data_1 = Batch(a=np.array([0.0, 2.0]), b=5)
>>> data_2 = Batch(a=np.array([1.0, 3.0]), b=-5)
>>> data = Batch.stack((data_1, data_2))
>>> print(data)
Batch(
b: array([ 5, -5]),
a: array([[0., 2.],
[1., 3.]]),
)
>>> print(np.mean(data))
Batch(
b: 0.0,
a: array([0.5, 2.5]),
)
>>> data_split = list(data.split(1, False))
>>> print(list(data.split(1, False)))
[Batch(
b: array([5]),
a: array([[0., 2.]]),
), Batch(
b: array([-5]),
a: array([[1., 3.]]),
)]
>>> data_cat = Batch.cat(data_split)
>>> print(data_cat)
Batch(
b: array([ 5, -5]),
a: array([[0., 2.],
[1., 3.]]),
)
Note that stacking of inconsistent data is also supported. In which case,
None
is added in list or np.ndarray
of objects, 0 otherwise.
>>> data_1 = Batch(a=np.array([0.0, 2.0]))
>>> data_2 = Batch(a=np.array([1.0, 3.0]), b='done')
>>> data = Batch.stack((data_1, data_2))
>>> print(data)
Batch(
a: array([[0., 2.],
[1., 3.]]),
b: array([None, 'done'], dtype=object),
)
Method empty_
sets elements to 0 or None
for np.object
.
>>> data.empty_()
>>> print(data)
Batch(
a: array([[0., 0.],
[0., 0.]]),
b: array([None, None], dtype=object),
)
>>> data = Batch(a=[False, True], b={'c': [2., 'st'], 'd': [1., 0.]})
>>> data[0] = Batch.empty(data[1])
>>> data
Batch(
a: array([False, True]),
b: Batch(
c: array([None, 'st']),
d: array([0., 0.]),
),
)
shape()
and __len__()
methods are also provided to respectively get the shape and the length of
a Batch
instance. It mimics the Numpy API for Numpy arrays, which
means that getting the length of a scalar Batch raises an exception.
>>> data = Batch(a=[5., 4.], b=np.zeros((2, 3, 4)))
>>> data.shape
[2]
>>> len(data)
2
>>> data[0].shape
[]
>>> len(data[0])
TypeError: Object of type 'Batch' has no len()
Convenience helpers are available to convert in-place the stored data into Numpy arrays or Torch tensors.
Finally, note that Batch
is serializable and
therefore Pickle compatible. This is especially important for distributed
sampling.
-
tianshou.data.Batch.
shape
Return self.shape.
Data Buffer¶
ReplayBuffer
stores data generated from
interaction between the policy and environment. It stores basically 7 types
of data, as mentioned in Batch
, based on
numpy.ndarray
. Here is the usage:
>>> import numpy as np
>>> from tianshou.data import ReplayBuffer
>>> buf = ReplayBuffer(size=20)
>>> for i in range(3):
... buf.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
>>> len(buf)
3
>>> buf.obs
# since we set size = 20, len(buf.obs) == 20.
array([0., 1., 2., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0.])
>>> buf2 = ReplayBuffer(size=10)
>>> for i in range(15):
... buf2.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
>>> len(buf2)
10
>>> buf2.obs
# since its size = 10, it only stores the last 10 steps' result.
array([10., 11., 12., 13., 14., 5., 6., 7., 8., 9.])
>>> # move buf2's result into buf (meanwhile keep it chronologically)
>>> buf.update(buf2)
array([ 0., 1., 2., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14.,
0., 0., 0., 0., 0., 0., 0.])
>>> # get a random sample from buffer
>>> # the batch_data is equal to buf[incide].
>>> batch_data, indice = buf.sample(batch_size=4)
>>> batch_data.obs == buf[indice].obs
array([ True, True, True, True])
ReplayBuffer
also supports frame_stack sampling
(typically for RNN usage, see issue#19), ignoring storing the next
observation (save memory in atari tasks), and multi-modal observation (see
issue#38):
>>> buf = ReplayBuffer(size=9, stack_num=4, ignore_obs_next=True)
>>> for i in range(16):
... done = i % 5 == 0
... buf.add(obs={'id': i}, act=i, rew=i, done=done,
... obs_next={'id': i + 1})
>>> print(buf) # you can see obs_next is not saved in buf
ReplayBuffer(
act: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]),
done: array([0., 1., 0., 0., 0., 0., 1., 0., 0.]),
info: Batch(),
obs: Batch(
id: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]),
),
policy: Batch(),
rew: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]),
)
>>> index = np.arange(len(buf))
>>> print(buf.get(index, 'obs').id)
[[ 7. 7. 8. 9.]
[ 7. 8. 9. 10.]
[11. 11. 11. 11.]
[11. 11. 11. 12.]
[11. 11. 12. 13.]
[11. 12. 13. 14.]
[12. 13. 14. 15.]
[ 7. 7. 7. 7.]
[ 7. 7. 7. 8.]]
>>> # here is another way to get the stacked data
>>> # (stack only for obs and obs_next)
>>> abs(buf.get(index, 'obs')['id'] - buf[index].obs.id).sum().sum()
0.0
>>> # we can get obs_next through __getitem__, even if it doesn't exist
>>> print(buf[:].obs_next.id)
[[ 7. 8. 9. 10.]
[ 7. 8. 9. 10.]
[11. 11. 11. 12.]
[11. 11. 12. 13.]
[11. 12. 13. 14.]
[12. 13. 14. 15.]
[12. 13. 14. 15.]
[ 7. 7. 7. 8.]
[ 7. 7. 8. 9.]]
- param int size
the size of replay buffer.
- param int stack_num
the frame-stack sampling argument, should be greater than 1, defaults to 0 (no stacking).
- param bool ignore_obs_next
whether to store obs_next, defaults to
False
.- param bool sample_avail
the parameter indicating sampling only available index when using frame-stack sampling method, defaults to
False
. This feature is not supported in Prioritized Replay Buffer currently.
Tianshou provides other type of data buffer such as ListReplayBuffer
(based on list), PrioritizedReplayBuffer
(based on Segment Tree and numpy.ndarray
). Check out ReplayBuffer
for more detail.
Policy¶
Tianshou aims to modularizing RL algorithms. It comes into several classes of policies in Tianshou. All of the policy classes must inherit BasePolicy
.
A policy class typically has four parts:
__init__()
: initialize the policy, including coping the target network and so on;forward()
: compute action with given observation;process_fn()
: pre-process data from the replay buffer (this function can interact with replay buffer);learn()
: update policy with a given batch of data.
Take 2-step return DQN as an example. The 2-step return DQN compute each frame’s return as:
where is the discount factor, . Here is the pseudocode showing the training process without Tianshou framework:
# pseudocode, cannot work
s = env.reset()
buffer = Buffer(size=10000)
agent = DQN()
for i in range(int(1e6)):
a = agent.compute_action(s)
s_, r, d, _ = env.step(a)
buffer.store(s, a, s_, r, d)
s = s_
if i % 1000 == 0:
b_s, b_a, b_s_, b_r, b_d = buffer.get(size=64)
# compute 2-step returns. How?
b_ret = compute_2_step_return(buffer, b_r, b_d, ...)
# update DQN policy
agent.update(b_s, b_a, b_s_, b_r, b_d, b_ret)
Thus, we need a time-related interface for calculating the 2-step return. process_fn()
finishes this work by providing the replay buffer, the sample index, and the sample batch data. Since we store all the data in the order of time, you can simply compute the 2-step return as:
class DQN_2step(BasePolicy):
"""some code"""
def process_fn(self, batch, buffer, indice):
buffer_len = len(buffer)
batch_2 = buffer[(indice + 2) % buffer_len]
# this will return a batch data where batch_2.obs is s_t+2
# we can also get s_t+2 through:
# batch_2_obs = buffer.obs[(indice + 2) % buffer_len]
# in short, buffer.obs[i] is equal to buffer[i].obs, but the former is more effecient.
Q = self(batch_2, eps=0) # shape: [batchsize, action_shape]
maxQ = Q.max(dim=-1)
batch.returns = batch.rew \
+ self._gamma * buffer.rew[(indice + 1) % buffer_len] \
+ self._gamma ** 2 * maxQ
return batch
This code does not consider the done flag, so it may not work very well. It shows two ways to get from the replay buffer easily in process_fn()
.
For other method, you can check out tianshou.policy. We give the usage of policy class a high-level explanation in A High-level Explanation.
Collector¶
The Collector
enables the policy to interact with different types of environments conveniently.
In short, Collector
has two main methods:
collect()
: let the policy perform (at least) a specified number of stepn_step
or episoden_episode
and store the data in the replay buffer;sample()
: sample a data batch from replay buffer; it will callprocess_fn()
before returning the final batch data.
Why do we mention at least here? For a single environment, the collector will finish exactly n_step
or n_episode
. However, for multiple environments, we could not directly store the collected data into the replay buffer, since it breaks the principle of storing data chronologically.
The solution is to add some cache buffers inside the collector. Once collecting a full episode of trajectory, it will move the stored data from the cache buffer to the main buffer. To satisfy this condition, the collector will interact with environments that may exceed the given step number or episode number.
The general explanation is listed in A High-level Explanation. Other usages of collector are listed in Collector
documentation.
Trainer¶
Once you have a collector and a policy, you can start writing the training method for your RL agent. Trainer, to be honest, is a simple wrapper. It helps you save energy for writing the training loop. You can also construct your own trainer: Train a Policy with Customized Codes.
Tianshou has two types of trainer: onpolicy_trainer()
and offpolicy_trainer()
, corresponding to on-policy algorithms (such as Policy Gradient) and off-policy algorithms (such as DQN). Please check out tianshou.trainer for the usage.
There will be more types of trainers, for instance, multi-agent trainer.
A High-level Explanation¶
We give a high-level explanation through the pseudocode used in section Policy:
# pseudocode, cannot work # methods in tianshou
s = env.reset()
buffer = Buffer(size=10000) # buffer = tianshou.data.ReplayBuffer(size=10000)
agent = DQN() # policy.__init__(...)
for i in range(int(1e6)): # done in trainer
a = agent.compute_action(s) # policy(batch, ...)
s_, r, d, _ = env.step(a) # collector.collect(...)
buffer.store(s, a, s_, r, d) # collector.collect(...)
s = s_ # collector.collect(...)
if i % 1000 == 0: # done in trainer
b_s, b_a, b_s_, b_r, b_d = buffer.get(size=64) # collector.sample(batch_size)
# compute 2-step returns. How?
b_ret = compute_2_step_return(buffer, b_r, b_d, ...) # policy.process_fn(batch, buffer, indice)
# update DQN policy
agent.update(b_s, b_a, b_s_, b_r, b_d, b_ret) # policy.learn(batch, ...)
Conclusion¶
So far, we go through the overall framework of Tianshou. Really simple, isn’t it?