tianshou.data

class tianshou.data.Batch(**kwargs)[source]

Bases: object

Tianshou provides Batch as the internal data structure to pass any kind of data to other methods, for example, a collector gives a Batch to the policy for learning. Here is the usage:

>>> import numpy as np
>>> from tianshou.data import Batch
>>> data = Batch(a=4, b=[5, 5], c='2312312')
>>> data.b
[5, 5]
>>> data.b = np.array([3, 4, 5])
>>> print(data)
Batch(
    a: 4,
    b: array([3, 4, 5]),
    c: '2312312',
)

In short, you can define a Batch with any key-value pairs. The current implementation of Tianshou typically uses 7 reserved keys in Batch:

  • obs the observation of step t ;

  • act the action of step t ;

  • rew the reward of step t ;

  • done the done flag of step t ;

  • obs_next the observation of step t+1 ;

  • info the info of step t (in gym.Env, the env.step() function returns 4 values, and the last one is info);

  • policy the data computed by the policy in step t.
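
For instance, a single transition can be stored with these keys (a sketch continuing the example above; the values here are arbitrary):

>>> transition = Batch(obs=np.array([0.1, 0.2]), act=1, rew=1.0, done=False,
...                    obs_next=np.array([0.2, 0.3]), info={}, policy=Batch())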

Batch has other methods, including __getitem__(), __len__(), append(), and split():

>>> data = Batch(obs=np.array([0, 11, 22]), rew=np.array([6, 6, 6]))
>>> # here we test __getitem__
>>> index = [2, 1]
>>> data[index].obs
array([22, 11])

>>> # here we test __len__
>>> len(data)
3

>>> data.append(data)  # similar to list.append
>>> data.obs
array([0, 11, 22, 0, 11, 22])

>>> # split the whole data into multiple small batches
>>> for d in data.split(size=2, shuffle=False):
...     print(d.obs, d.rew)
[ 0 11] [6 6]
[22  0] [6 6]
[11 22] [6 6]
__getitem__(index: Union[str, slice]) → Union[tianshou.data.batch.Batch, dict][source]

Return self[index].

__len__() → int[source]

Return len(self).

append(batch: tianshou.data.batch.Batch) → None[source]

Append a Batch object to current batch.

get(k: str, d: Optional[Any] = None) → Union[tianshou.data.batch.Batch, Any][source]

Return self[k] if k in self else d. d defaults to None.
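
A short usage sketch of get(), following the dict-like behaviour described above:

>>> data = Batch(a=4, b=[5, 5])
>>> data.get('a')
4
>>> data.get('z', 0)  # a missing key falls back to the default
0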

keys() → List[str][source]

Return self.keys().

split(size: Optional[int] = None, shuffle: bool = True) → Iterator[tianshou.data.batch.Batch][source]

Split the whole data into multiple small batches.

Parameters
  • size (int) – if it is None, it does not split the data batch; otherwise it will divide the data batch with the given size. Defaults to None.

  • shuffle (bool) – randomly shuffle the entire data batch if it is True, otherwise keep the original order. Defaults to True.
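
A small sketch of the size=None case, assuming the whole batch is then yielded as a single chunk (as the parameter description above suggests):

>>> data = Batch(obs=np.arange(4), rew=np.zeros(4))
>>> chunks = list(data.split(size=None, shuffle=False))
>>> len(chunks), len(chunks[0])
(1, 4)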

to_numpy() → None[source]

Change all torch.Tensor to numpy.ndarray. This is an inplace operation.

to_torch(dtype: Optional[torch.dtype] = None, device: Union[str, int, torch.device] = 'cpu') → None[source]

Change all numpy.ndarray to torch.Tensor. This is an inplace operation.
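
A round-trip sketch of the two in-place converters, assuming torch is available:

>>> import torch
>>> data = Batch(obs=np.array([1., 2., 3.]))
>>> data.to_torch(dtype=torch.float32, device='cpu')
>>> isinstance(data.obs, torch.Tensor)
True
>>> data.to_numpy()
>>> isinstance(data.obs, np.ndarray)
True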

values() → List[Any][source]

Return self.values().

tianshou.data.to_numpy(x: Union[torch.Tensor, dict, tianshou.data.batch.Batch, numpy.ndarray]) → Union[dict, tianshou.data.batch.Batch, numpy.ndarray][source]

Return an object without torch.Tensor.

tianshou.data.to_torch(x: Union[torch.Tensor, dict, tianshou.data.batch.Batch, numpy.ndarray], dtype: Optional[torch.dtype] = None, device: Union[str, int] = 'cpu') → Union[dict, tianshou.data.batch.Batch, torch.Tensor][source]

Return an object without np.ndarray.
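
The module-level helpers mirror the in-place methods but return the converted object; a minimal sketch:

>>> import numpy as np
>>> from tianshou.data import to_numpy, to_torch
>>> t = to_torch(np.zeros(3))   # numpy.ndarray -> torch.Tensor
>>> a = to_numpy(t)             # torch.Tensor -> numpy.ndarray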

class tianshou.data.ReplayBuffer(size: int, stack_num: Optional[int] = 0, ignore_obs_next: bool = False, **kwargs)[source]

Bases: object

ReplayBuffer stores data generated from the interaction between the policy and the environment. It basically stores the 7 types of data mentioned in Batch, backed by numpy.ndarray. Here is the usage:

>>> import numpy as np
>>> from tianshou.data import ReplayBuffer
>>> buf = ReplayBuffer(size=20)
>>> for i in range(3):
...     buf.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
>>> len(buf)
3
>>> buf.obs
# since we set size = 20, len(buf.obs) == 20.
array([0., 1., 2., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0.])

>>> buf2 = ReplayBuffer(size=10)
>>> for i in range(15):
...     buf2.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
>>> len(buf2)
10
>>> buf2.obs
# since its size = 10, it only stores the last 10 steps' result.
array([10., 11., 12., 13., 14.,  5.,  6.,  7.,  8.,  9.])

>>> # move buf2's result into buf (meanwhile keep it chronologically)
>>> buf.update(buf2)
array([ 0.,  1.,  2.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13., 14.,
        0.,  0.,  0.,  0.,  0.,  0.,  0.])

>>> # get a random sample from buffer
>>> # the batch_data is equal to buf[indice].
>>> batch_data, indice = buf.sample(batch_size=4)
>>> batch_data.obs == buf[indice].obs
array([ True,  True,  True,  True])

ReplayBuffer also supports frame_stack sampling (typically for RNN usage, see issue#19), ignoring the storage of the next observation (to save memory in Atari tasks), and multi-modal observations (see issue#38, need version >= 0.2.3):

>>> buf = ReplayBuffer(size=9, stack_num=4, ignore_obs_next=True)
>>> for i in range(16):
...     done = i % 5 == 0
...     buf.add(obs={'id': i}, act=i, rew=i, done=done,
...             obs_next={'id': i + 1})
>>> print(buf)  # you can see obs_next is not saved in buf
ReplayBuffer(
    act: array([ 9., 10., 11., 12., 13., 14., 15.,  7.,  8.]),
    done: array([0., 1., 0., 0., 0., 0., 1., 0., 0.]),
    info: Batch(),
    obs: Batch(
             id: array([ 9., 10., 11., 12., 13., 14., 15.,  7.,  8.]),
         ),
    policy: Batch(),
    rew: array([ 9., 10., 11., 12., 13., 14., 15.,  7.,  8.]),
)
>>> index = np.arange(len(buf))
>>> print(buf.get(index, 'obs').id)
[[ 7.  7.  8.  9.]
 [ 7.  8.  9. 10.]
 [11. 11. 11. 11.]
 [11. 11. 11. 12.]
 [11. 11. 12. 13.]
 [11. 12. 13. 14.]
 [12. 13. 14. 15.]
 [ 7.  7.  7.  7.]
 [ 7.  7.  7.  8.]]
>>> # here is another way to get the stacked data
>>> # (stack only for obs and obs_next)
>>> abs(buf.get(index, 'obs')['id'] - buf[index].obs.id).sum().sum()
0.0
>>> # we can get obs_next through __getitem__, even if it doesn't exist
>>> print(buf[:].obs_next.id)
[[ 7.  8.  9. 10.]
 [ 7.  8.  9. 10.]
 [11. 11. 11. 12.]
 [11. 11. 12. 13.]
 [11. 12. 13. 14.]
 [12. 13. 14. 15.]
 [12. 13. 14. 15.]
 [ 7.  7.  7.  8.]
 [ 7.  7.  8.  9.]]
__getitem__(index: Union[slice, numpy.ndarray]) → tianshou.data.batch.Batch[source]

Return a data batch: self[index]. If stack_num is set to be > 0, return the stacked obs and obs_next with shape [batch, len, …].

__len__() → int[source]

Return len(self).

add(obs: Union[dict, numpy.ndarray], act: Union[numpy.ndarray, float], rew: float, done: bool, obs_next: Union[dict, numpy.ndarray, None] = None, info: dict = {}, policy: Union[dict, tianshou.data.batch.Batch, None] = {}, **kwargs) → None[source]

Add a batch of data into replay buffer.

get(indice: Union[slice, numpy.ndarray], key: str, stack_num: Optional[int] = None) → Union[tianshou.data.batch.Batch, numpy.ndarray][source]

Return the stacked result, e.g. [s_{t-3}, s_{t-2}, s_{t-1}, s_t], where s is self.key and t is the indice. The stack_num (here equal to 4) is given during buffer initialization.

reset() → None[source]

Clear all the data in replay buffer.

sample(batch_size: int) → Tuple[tianshou.data.batch.Batch, numpy.ndarray][source]

Get a random sample from buffer with size equal to batch_size. Return all the data in the buffer if batch_size is 0.

Returns

Sample data and its corresponding index inside the buffer.

update(buffer: tianshou.data.buffer.ReplayBuffer) → None[source]

Move the data from the given buffer to self.

class tianshou.data.ListReplayBuffer(**kwargs)[source]

Bases: tianshou.data.buffer.ReplayBuffer

The function of ListReplayBuffer is almost the same as ReplayBuffer. The only difference is that ListReplayBuffer is based on list.

See also

Please refer to ReplayBuffer for more detailed explanation.
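
A minimal instantiation sketch, assuming add() takes the same arguments as in ReplayBuffer and the underlying list grows as needed:

>>> from tianshou.data import ListReplayBuffer
>>> buf = ListReplayBuffer()
>>> for i in range(3):
...     buf.add(obs=i, act=i, rew=i, done=0, obs_next=i + 1, info={})
>>> len(buf)
3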

reset() → None[source]

Clear all the data in replay buffer.

class tianshou.data.PrioritizedReplayBuffer(size: int, alpha: float, beta: float, mode: str = 'weight', **kwargs)[source]

Bases: tianshou.data.buffer.ReplayBuffer

Prioritized replay buffer implementation.

Parameters
  • alpha (float) – the prioritization exponent.

  • beta (float) – the importance sample soft coefficient.

  • mode (str) – defaults to weight.

See also

Please refer to ReplayBuffer for more detailed explanation.
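
A construction sketch; the alpha and beta values below are illustrative placeholders, not recommendations:

>>> from tianshou.data import PrioritizedReplayBuffer
>>> buf = PrioritizedReplayBuffer(size=10, alpha=0.6, beta=0.4)
>>> for i in range(5):
...     buf.add(obs=i, act=i, rew=i, done=0, obs_next=i + 1,
...             info={}, weight=1.0)
>>> batch, indice = buf.sample(batch_size=4)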

__getitem__(index: Union[slice, numpy.ndarray]) → tianshou.data.batch.Batch[source]

Return a data batch: self[index]. If stack_num is set to be > 0, return the stacked obs and obs_next with shape [batch, len, …].

add(obs: Union[dict, numpy.ndarray], act: Union[numpy.ndarray, float], rew: float, done: bool, obs_next: Union[dict, numpy.ndarray, None] = None, info: dict = {}, policy: Union[dict, tianshou.data.batch.Batch, None] = {}, weight: float = 1.0, **kwargs) → None[source]

Add a batch of data into replay buffer.

reset() → None[source]

Clear all the data in replay buffer.

sample(batch_size: int, importance_sample: bool = True) → Tuple[tianshou.data.batch.Batch, numpy.ndarray][source]

Get a random sample from buffer with priority probability. Return all the data in the buffer if batch_size is 0.

Returns

Sample data and its corresponding index inside the buffer.

update_weight(indice: Union[slice, numpy.ndarray], new_weight: numpy.ndarray) → None[source]

Update priority weight by indice in this buffer.

Parameters
  • indice (np.ndarray) – indice whose weight you want to update

  • new_weight (np.ndarray) – new priority weight you want to update
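
A usage sketch continuing the PrioritizedReplayBuffer example above; in practice new_weight would come from the learner's TD errors rather than random values:

>>> import numpy as np
>>> new_weight = np.abs(np.random.randn(len(indice)))  # e.g. |TD error|
>>> buf.update_weight(indice, new_weight=new_weight)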

class tianshou.data.Collector(policy: tianshou.policy.base.BasePolicy, env: Union[gym.core.Env, tianshou.env.vecenv.BaseVectorEnv], buffer: Union[tianshou.data.buffer.ReplayBuffer, List[tianshou.data.buffer.ReplayBuffer], None] = None, preprocess_fn: Callable[[Any], Union[dict, tianshou.data.batch.Batch]] = None, stat_size: Optional[int] = 100, **kwargs)[source]

Bases: object

The Collector enables the policy to interact with different types of environments conveniently.

Parameters
  • policy – an instance of the BasePolicy class.

  • env – a gym.Env environment or an instance of the BaseVectorEnv class.

  • buffer – an instance of the ReplayBuffer class, or a list of ReplayBuffer. If set to None, it will automatically assign a small-size ReplayBuffer.

  • preprocess_fn (function) – a function called before the data has been added to the buffer, see issue #42, defaults to None.

  • stat_size (int) – for the moving average of recording speed, defaults to 100.

The preprocess_fn is a function called before the data is added to the buffer, in batch format; it receives up to 7 keys as listed in Batch. It receives only obs when the collector resets the environment. It returns either a dict or a Batch with the modified keys and values. Examples are in “test/base/test_collector.py”.
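
An illustrative sketch of such a function that clips rewards before they are stored (the clipping itself is just an example choice, not part of the library):

import numpy as np

def preprocess_fn(**kwargs):
    # the collector passes only obs on environment reset,
    # otherwise the full transition (obs/act/rew/done/obs_next/info/...)
    if 'rew' in kwargs:
        return {'rew': np.clip(kwargs['rew'], -1, 1)}
    return {'obs': kwargs['obs']}

# collector = Collector(policy, env, buffer=replay_buffer,
#                       preprocess_fn=preprocess_fn)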

Example:

policy = PGPolicy(...)  # or other policies if you wish
env = gym.make('CartPole-v0')
replay_buffer = ReplayBuffer(size=10000)
# here we set up a collector with a single environment
collector = Collector(policy, env, buffer=replay_buffer)

# the collector supports vectorized environments as well
envs = VectorEnv([lambda: gym.make('CartPole-v0') for _ in range(3)])
buffers = [ReplayBuffer(size=5000) for _ in range(3)]
# you can also pass a list of replay buffer to collector, for multi-env
# collector = Collector(policy, envs, buffer=buffers)
collector = Collector(policy, envs, buffer=replay_buffer)

# collect at least 3 episodes
collector.collect(n_episode=3)
# collect 1 episode for the first env, 3 for the third env
collector.collect(n_episode=[1, 0, 3])
# collect at least 2 steps
collector.collect(n_step=2)
# collect episodes with visual rendering (the render argument is the
#   sleep time between rendering consecutive frames)
collector.collect(n_episode=1, render=0.03)

# sample data with a given number of batch-size:
batch_data = collector.sample(batch_size=64)
# policy.learn(batch_data)  # btw, vanilla policy gradient only
#   supports on-policy training, so here we pick all data in the buffer
batch_data = collector.sample(batch_size=0)
policy.learn(batch_data)
# on-policy algorithms use the collected data only once, so here we
#   clear the buffer
collector.reset_buffer()

When collecting data from multiple environments into a single buffer, cache buffers are enabled automatically. The collected data may exceed the given limit.

Note

Please make sure the given environment has a time limit.

close() → None[source]

Close the environment(s).

collect(n_step: int = 0, n_episode: Union[int, List[int]] = 0, render: Optional[float] = None, log_fn: Optional[Callable[[dict], None]] = None) → Dict[str, float][source]

Collect a specified number of steps or episodes.

Parameters
  • n_step (int) – how many steps you want to collect.

  • n_episode (int or list) – how many episodes you want to collect (in each environment).

  • render (float) – the sleep time between rendering consecutive frames, defaults to None (no rendering).

  • log_fn (function) – a function which receives env info, typically for tensorboard logging.

Note

One and only one collection number specification is permitted, either n_step or n_episode.

Returns

A dict including the following keys

  • n/ep the collected number of episodes.

  • n/st the collected number of steps.

  • v/st the speed of steps per second.

  • v/ep the speed of episodes per second.

  • rew the mean reward over collected episodes.

  • len the mean length over collected episodes.
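
For example, continuing the collector sketch above, the returned statistics can be read as follows:

result = collector.collect(n_episode=3)
print(result['n/ep'], result['n/st'], result['rew'])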

get_env_num() → int[source]

Return the number of environments the collector has.

render(**kwargs) → None[source]

Render all the environment(s).

reset() → None[source]

Reset all related variables in the collector.

reset_buffer() → None[source]

Reset the main data buffer.

reset_env() → None[source]

Reset all of the environment(s)’ states and reset all of the cache buffers (if needed).

sample(batch_size: int) → tianshou.data.batch.Batch[source]

Sample a data batch from the internal replay buffer. It will call process_fn() before returning the final batch data.

Parameters

batch_size (int) – 0 means it will extract all the data from the buffer, otherwise it will extract the data with the given batch_size.

seed(seed: Union[int, List[int], None] = None) → None[source]

Reset all the seed(s) of the given environment(s).