tianshou.data

class tianshou.data.Batch(**kwargs)[source]

Bases: object

Tianshou provides Batch as the internal data structure for passing any kind of data between methods; for example, a collector gives a Batch to the policy for learning. Here is the usage:

>>> import numpy as np
>>> from tianshou.data import Batch
>>> data = Batch(a=4, b=[5, 5], c='2312312')
>>> data.b
[5, 5]
>>> data.b = np.array([3, 4, 5])
>>> len(data.b)
3
>>> data.b[-1]
5

In short, you can define a Batch with any key-value pairs. The current implementation of Tianshou typically uses 6 keys in Batch:

  • obs: the observation at step t;

  • act: the action at step t;

  • rew: the reward at step t;

  • done: the done flag at step t;

  • obs_next: the observation at step t+1;

  • info: the info at step t (in gym.Env, the env.step() function returns 4 values, and the last one is info).
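
For instance, a single transition can be stored with exactly these keys (a minimal illustrative sketch; the values are made up):

>>> transition = Batch(obs=np.array([0.0, 1.0]), act=0, rew=1.0,
...                    done=False, obs_next=np.array([1.0, 2.0]), info={})
>>> transition.rew
1.0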

Batch has other methods, including __getitem__(), __len__(), append(), and split():

>>> data = Batch(obs=np.array([0, 11, 22]), rew=np.array([6, 6, 6]))
>>> # here we test __getitem__
>>> index = [2, 1]
>>> data[index].obs
array([22, 11])

>>> # here we test __len__
>>> len(data)
3

>>> data.append(data)  # similar to list.append
>>> data.obs
array([0, 11, 22, 0, 11, 22])

>>> # split the whole data into multiple small batches
>>> for d in data.split(size=2, permute=False):
...     print(d.obs, d.rew)
[ 0 11] [6 6]
[22  0] [6 6]
[11 22] [6 6]
__getitem__(index)[source]

Return self[index].

__len__()[source]

Return len(self).

append(batch)[source]

Append a Batch object to the current batch.

split(size=None, permute=True)[source]

Split the whole data into multiple small batches.

Parameters
  • size (int) – if it is None, the data batch is not split; otherwise the data batch is divided into mini-batches of the given size. Defaults to None.

  • permute (bool) – randomly shuffle the entire data batch if it is True, otherwise keep the original order. Defaults to True.
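
A minimal sketch of the size=None case, assuming (per the description above) that the whole data batch is yielded in one piece:

>>> data = Batch(obs=np.arange(4), rew=np.zeros(4))
>>> for d in data.split(size=None, permute=False):
...     print(d.obs)
[0 1 2 3]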

class tianshou.data.ReplayBuffer(size)[source]

Bases: object

ReplayBuffer stores data generated from the interaction between the policy and the environment. It stores the 6 basic types of data mentioned in Batch, using numpy.ndarray as the underlying storage. Here is the usage:

>>> from tianshou.data import ReplayBuffer
>>> buf = ReplayBuffer(size=20)
>>> for i in range(3):
...     buf.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
>>> len(buf)
3
>>> buf.obs
# since we set size = 20, len(buf.obs) == 20.
array([0., 1., 2., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0.])

>>> buf2 = ReplayBuffer(size=10)
>>> for i in range(15):
...     buf2.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
>>> len(buf2)
10
>>> buf2.obs
# since its size = 10, it only stores the last 10 steps' result.
array([10., 11., 12., 13., 14.,  5.,  6.,  7.,  8.,  9.])

>>> # move buf2's data into buf (keeping the chronological order)
>>> buf.update(buf2)
array([ 0.,  1.,  2.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13., 14.,
        0.,  0.,  0.,  0.,  0.,  0.,  0.])

>>> # get a random sample from buffer
>>> # batch_data is equal to buf[indice].
>>> batch_data, indice = buf.sample(batch_size=4)
>>> batch_data.obs == buf[indice].obs
array([ True,  True,  True,  True])
__getitem__(index)[source]

Return a data batch: self[index].

__len__()[source]

Return len(self).

add(obs, act, rew, done, obs_next=0, info={}, weight=None)[source]

Add a batch of data into replay buffer.

reset()[source]

Clear all the data in replay buffer.
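
A minimal sketch, assuming the buffer is empty again after the call:

>>> buf = ReplayBuffer(size=10)
>>> buf.add(obs=0, act=0, rew=0, done=0, obs_next=1, info={})
>>> buf.reset()
>>> len(buf)
0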

sample(batch_size)[source]

Get a random sample from buffer with size equal to batch_size. Return all the data in the buffer if batch_size is 0.

Returns

Sample data and its corresponding index inside the buffer.

update(buffer)[source]

Move the data from the given buffer to self.

class tianshou.data.ListReplayBuffer[source]

Bases: tianshou.data.buffer.ReplayBuffer

The function of ListReplayBuffer is almost the same as ReplayBuffer. The only difference is that ListReplayBuffer stores its data in Python lists.
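
Usage mirrors ReplayBuffer. Here is an illustrative sketch, assuming the constructor takes no size argument (as in the signature above) and that the stored data are plain Python lists:

>>> from tianshou.data import ListReplayBuffer
>>> buf = ListReplayBuffer()
>>> for i in range(3):
...     buf.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
>>> len(buf)
3
>>> buf.obs
[0, 1, 2]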

reset()[source]

Clear all the data in replay buffer.

class tianshou.data.PrioritizedReplayBuffer(size)[source]

Bases: tianshou.data.buffer.ReplayBuffer

Prioritized replay buffer, derived from ReplayBuffer; it shares the same interface.

add(obs, act, rew, done, obs_next=0, info={}, weight=None)[source]

Add a batch of data into replay buffer.

reset()[source]

Clear all the data in replay buffer.

sample(batch_size)[source]

Get a random sample from buffer with size equal to batch_size. Return all the data in the buffer if batch_size is 0.

Returns

Sample data and its corresponding index inside the buffer.

class tianshou.data.Collector(policy, env, buffer=None, stat_size=100, store_obs_next=True, **kwargs)[source]

Bases: object

The Collector enables the policy to interact with different types of environments conveniently.

Parameters
  • policy – an instance of the BasePolicy class.

  • env – an environment or an instance of the BaseVectorEnv class.

  • buffer – an instance of the ReplayBuffer class, or a list of ReplayBuffer. If set to None, it will automatically assign a small-size ReplayBuffer.

  • stat_size (int) – the window size of the moving average used to track the collecting speed, defaults to 100.

  • store_obs_next (bool) – whether to store the obs_next to replay buffer, defaults to True.

Example:

policy = PGPolicy(...)  # or other policies if you wish
env = gym.make('CartPole-v0')
replay_buffer = ReplayBuffer(size=10000)
# here we set up a collector with a single environment
collector = Collector(policy, env, buffer=replay_buffer)

# the collector supports vectorized environments as well
envs = VectorEnv([lambda: gym.make('CartPole-v0') for _ in range(3)])
buffers = [ReplayBuffer(size=5000) for _ in range(3)]
# you can also pass a list of replay buffers to the collector for multi-env
# collector = Collector(policy, envs, buffer=buffers)
collector = Collector(policy, envs, buffer=replay_buffer)

# collect at least 3 episodes
collector.collect(n_episode=3)
# collect 1 episode for the first env, 3 for the third env
collector.collect(n_episode=[1, 0, 3])
# collect at least 2 steps
collector.collect(n_step=2)
# collect episodes with visual rendering (the render argument is the
#   sleep time between rendering consecutive frames)
collector.collect(n_episode=1, render=0.03)

# sample data with a given batch size:
batch_data = collector.sample(batch_size=64)
# policy.learn(batch_data)  # btw, vanilla policy gradient only
#   supports on-policy training, so here we pick all data in the buffer
batch_data = collector.sample(batch_size=0)
policy.learn(batch_data)
# on-policy algorithms use the collected data only once, so here we
#   clear the buffer
collector.reset_buffer()

When collecting data from multiple environments into a single buffer, the cache buffers are turned on automatically. In this case, the collector may return more data than the given limit.

Note

Please make sure the given environment has a time limit.

close()[source]

Close the environment(s).

collect(n_step=0, n_episode=0, render=0)[source]

Collect a specified number of steps or episodes.

Parameters
  • n_step (int) – how many steps you want to collect.

  • n_episode (int or list) – how many episodes you want to collect (in each environment).

  • render (float) – the sleep time between rendering consecutive frames. No rendering if it is 0 (default option).

Note

Exactly one collection count may be specified: either n_step or n_episode.

Returns

A dict including the following keys

  • n/ep: the number of collected episodes.

  • n/st: the number of collected steps.

  • v/st: the collecting speed in steps per second.

  • v/ep: the collecting speed in episodes per second.

  • rew: the mean reward over collected episodes.

  • len: the mean episode length over collected episodes.
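
For example, the returned statistics can be read off this dict (an illustrative sketch; the key names are those listed above):

result = collector.collect(n_episode=3)
print(result['n/ep'], result['n/st'], result['rew'], result['len'])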

get_env_num()[source]

Return the number of environments the collector has.

render(**kwargs)[source]

Render all the environment(s).

reset_buffer()[source]

Reset the main data buffer.

reset_env()[source]

Reset all of the environment(s)’ states and reset all of the cache buffers (if needed).

sample(batch_size)[source]

Sample a data batch from the internal replay buffer. It will call process_fn() before returning the final batch data.

Parameters

batch_size (int) – if 0, it will extract all the data from the buffer; otherwise it will extract a sample with the given batch_size.

seed(seed=None)[source]

Reset all the seed(s) of the given environment(s).