tianshou.data

class tianshou.data.Batch(batch_dict: Optional[Union[dict, Batch, Tuple[Union[dict, Batch]], List[Union[dict, Batch]], numpy.ndarray]] = None, copy: bool = False, **kwargs)[source]

Bases: object

Tianshou provides Batch as the internal data structure to pass any kind of data to other methods; for example, a collector passes a Batch to the policy for learning. Here is the usage:

>>> import numpy as np
>>> from tianshou.data import Batch
>>> data = Batch(a=4, b=[5, 5], c='2312312')
>>> # the list will automatically be converted to numpy array
>>> data.b
array([5, 5])
>>> data.b = np.array([3, 4, 5])
>>> print(data)
Batch(
    a: 4,
    b: array([3, 4, 5]),
    c: '2312312',
)

In short, you can define a Batch with any key-value pair.

For Numpy arrays, only np.object, bool, and numeric data types are supported. Strings and other data types can, however, be held in np.object arrays.
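
For instance, a hypothetical Batch holding strings via an explicit np.object array (a small sketch; the exact repr may vary slightly across Numpy versions):

>>> data = Batch(a=np.array(['hello', 'world'], dtype=object))
>>> data.a
array(['hello', 'world'], dtype=object)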

The current implementation of Tianshou typically uses 7 reserved keys in Batch:

  • obs the observation of step t ;

  • act the action of step t ;

  • rew the reward of step t ;

  • done the done flag of step t ;

  • obs_next the observation of step t+1 ;

  • info the info of step t (in gym.Env, the env.step() function returns 4 arguments, and the last one is info);

  • policy the data computed by policy in step t;

A Batch object can be initialized from a wide variety of arguments, ranging from key/value pairs or a dictionary, to lists and Numpy arrays of dict or Batch instances, where each element is considered an individual sample and gets stacked together:

>>> data = Batch([{'a': {'b': [0.0, "info"]}}])
>>> print(data[0])
Batch(
    a: Batch(
        b: array([0.0, 'info'], dtype=object),
    ),
)

Batch has the same API as a native Python dict. In this regard, one can access stored data using a string key, or iterate over stored data:

>>> data = Batch(a=4, b=[5, 5])
>>> print(data["a"])
4
>>> for key, value in data.items():
>>>     print(f"{key}: {value}")
a: 4
b: [5, 5]

Batch also partially reproduces the Numpy API for arrays. It supports advanced slicing, such as batch[:, i], as long as the index is valid. You can access or iterate over the individual samples, if any:

>>> data = Batch(a=np.array([[0.0, 2.0], [1.0, 3.0]]), b=[[5, -5]])
>>> print(data[0])
Batch(
    a: array([0., 2.]),
    b: array([ 5, -5]),
)
>>> for sample in data:
>>>     print(sample.a)
[0. 2.]

>>> print(data.shape)
[1, 2]
>>> data[:, 1] += 1
>>> print(data)
Batch(
    a: array([[0., 3.],
              [1., 4.]]),
    b: array([[ 5, -4]]),
)

Similarly, one can also perform simple algebra on it, and stack, split or concatenate multiple instances:

>>> data_1 = Batch(a=np.array([0.0, 2.0]), b=5)
>>> data_2 = Batch(a=np.array([1.0, 3.0]), b=-5)
>>> data = Batch.stack((data_1, data_2))
>>> print(data)
Batch(
    b: array([ 5, -5]),
    a: array([[0., 2.],
              [1., 3.]]),
)
>>> print(np.mean(data))
Batch(
    b: 0.0,
    a: array([0.5, 2.5]),
)
>>> data_split = list(data.split(1, False))
>>> print(list(data.split(1, False)))
[Batch(
    b: array([5]),
    a: array([[0., 2.]]),
), Batch(
    b: array([-5]),
    a: array([[1., 3.]]),
)]
>>> data_cat = Batch.cat(data_split)
>>> print(data_cat)
Batch(
    b: array([ 5, -5]),
    a: array([[0., 2.],
              [1., 3.]]),
)

Note that stacking inconsistent data is also supported: missing values are padded with None for lists or np.ndarrays of objects, and with 0 otherwise.

>>> data_1 = Batch(a=np.array([0.0, 2.0]))
>>> data_2 = Batch(a=np.array([1.0, 3.0]), b='done')
>>> data = Batch.stack((data_1, data_2))
>>> print(data)
Batch(
    a: array([[0., 2.],
              [1., 3.]]),
    b: array([None, 'done'], dtype=object),
)

The method empty_ resets elements to 0, or to None for np.object arrays.

>>> data.empty_()
>>> print(data)
Batch(
    a: array([[0., 0.],
              [0., 0.]]),
    b: array([None, None], dtype=object),
)
>>> data = Batch(a=[False,  True], b={'c': [2., 'st'], 'd': [1., 0.]})
>>> data[0] = Batch.empty(data[1])
>>> data
Batch(
    a: array([False,  True]),
    b: Batch(
           c: array([None, 'st']),
           d: array([0., 0.]),
       ),
)

The shape property and the __len__() method are also provided to get the shape and the length of a Batch instance, respectively. They mimic the Numpy API for Numpy arrays, which means that getting the length of a scalar Batch raises an exception.

>>> data = Batch(a=[5., 4.], b=np.zeros((2, 3, 4)))
>>> data.shape
[2]
>>> len(data)
2
>>> data[0].shape
[]
>>> len(data[0])
TypeError: Object of type 'Batch' has no len()

Convenience helpers are available to convert in-place the stored data into Numpy arrays or Torch tensors.
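
For example, assuming PyTorch is installed, a minimal sketch of these in-place conversions:

>>> import torch
>>> data = Batch(a=np.zeros((2, 3)))
>>> data.to_torch(dtype=torch.float32, device='cpu')
>>> isinstance(data.a, torch.Tensor)
True
>>> data.to_numpy()
>>> isinstance(data.a, np.ndarray)
True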

Finally, note that Batch is serializable and therefore Pickle compatible. This is especially important for distributed sampling.
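
A minimal sketch of round-tripping a Batch through Pickle:

>>> import pickle
>>> data = Batch(a=np.array([0., 2.]), b='done')
>>> data2 = pickle.loads(pickle.dumps(data))
>>> (data2.a == data.a).all()
True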

__getitem__(index: Union[str, slice, int, numpy.integer, numpy.ndarray, List[int]]) → tianshou.data.batch.Batch[source]

Return self[index].

__len__() → int[source]

Return len(self).

static cat(batches: List[Union[dict, Batch]]) → tianshou.data.batch.Batch[source]

Concatenate a list of Batch objects into a single new batch. For keys that are not shared across all batches, the batches that lack those keys are padded with zeros of the appropriate shape. E.g.

>>> a = Batch(a=np.zeros([3, 4]), common=Batch(c=np.zeros([3, 5])))
>>> b = Batch(b=np.zeros([4, 3]), common=Batch(c=np.zeros([4, 5])))
>>> c = Batch.cat([a, b])
>>> c.a.shape
(7, 4)
>>> c.b.shape
(7, 3)
>>> c.common.c.shape
(7, 5)
cat_(batches: Union[Batch, List[Union[dict, Batch]]]) → None[source]

Concatenate a list of (or a single) Batch object into the current batch.

static empty(batch: tianshou.data.batch.Batch, index: Union[str, slice, int, numpy.integer, numpy.ndarray, List[int]] = None) → tianshou.data.batch.Batch[source]

Return an empty Batch object filled with 0 or None, with the same shape as the given Batch.

empty_(index: Union[str, slice, int, numpy.integer, numpy.ndarray, List[int]] = None) → tianshou.data.batch.Batch[source]

Fill the Batch object in place with 0 or None and return it. If index is specified, only the data at the given index is reset.

get(k: str, d: Optional[Any] = None) → Union[tianshou.data.batch.Batch, Any][source]

Return self[k] if k in self else d. d defaults to None.
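
A quick illustration of the documented behavior:

>>> data = Batch(a=4)
>>> data.get('a')
4
>>> data.get('b', 0)
0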

is_empty()[source]

items() → List[Tuple[str, Any]][source]

Return self.items().

keys() → List[str][source]

Return self.keys().

property shape

Return self.shape.

split(size: Optional[int] = None, shuffle: bool = True) → Iterator[tianshou.data.batch.Batch][source]

Split whole data into multiple small batches.

Parameters
  • size (int) – if it is None, the data batch is not split; otherwise the data batch is divided into mini-batches of the given size. Defaults to None.

  • shuffle (bool) – randomly shuffle the entire data batch if True, otherwise keep the original order. Defaults to True.
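
A small sketch of splitting without shuffling:

>>> data = Batch(a=np.arange(6))
>>> for minibatch in data.split(size=2, shuffle=False):
...     print(minibatch.a)
[0 1]
[2 3]
[4 5]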

static stack(batches: List[Union[dict, Batch]], axis: int = 0) → tianshou.data.batch.Batch[source]

Stack a list of Batch objects into a single new batch. For keys that are not shared across all batches, the batches that lack those keys are padded with zeros. E.g.

>>> a = Batch(a=np.zeros([4, 4]), common=Batch(c=np.zeros([4, 5])))
>>> b = Batch(b=np.zeros([4, 6]), common=Batch(c=np.zeros([4, 5])))
>>> c = Batch.stack([a, b])
>>> c.a.shape
(2, 4, 4)
>>> c.b.shape
(2, 4, 6)
>>> c.common.c.shape
(2, 4, 5)

Note

If there are keys that are not shared across all batches, stacking with axis != 0 is undefined and raises an exception.

stack_(batches: List[Union[dict, Batch]], axis: int = 0) → None[source]

Stack a list of Batch objects into the current batch.

to_numpy() → None[source]

Change all torch.Tensor to numpy.ndarray. This is an in-place operation.

to_torch(dtype: Optional[torch.dtype] = None, device: Union[str, int, torch.device] = 'cpu') → None[source]

Change all numpy.ndarray to torch.Tensor. This is an in-place operation.

update(batch: Optional[Union[dict, Batch]] = None, **kwargs) → None[source]

Update this batch from another dict/Batch.
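
For example:

>>> data = Batch(a=4, b=np.zeros(2))
>>> data.update(a=6)
>>> data.update({'c': 1})
>>> data.a, data.c
(6, 1)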

values() → List[Any][source]

Return self.values().

class tianshou.data.Collector(policy: tianshou.policy.base.BasePolicy, env: Union[gym.core.Env, tianshou.env.vecenv.BaseVectorEnv], buffer: Optional[tianshou.data.buffer.ReplayBuffer] = None, preprocess_fn: Callable[[Any], Union[dict, tianshou.data.batch.Batch]] = None, stat_size: Optional[int] = 100, action_noise: Optional[tianshou.exploration.random.BaseNoise] = None, reward_metric: Optional[Callable[[numpy.ndarray], float]] = None, **kwargs)[source]

Bases: object

The Collector enables the policy to interact with different types of environments conveniently.

Parameters
  • policy – an instance of the BasePolicy class.

  • env – a gym.Env environment or an instance of the BaseVectorEnv class.

  • buffer – an instance of the ReplayBuffer class, or a list of ReplayBuffer. If set to None, it will automatically assign a small-size ReplayBuffer.

  • preprocess_fn (function) – a function called before the data is added to the buffer; see issue #42 and Handle Batched Data Stream in Collector. Defaults to None.

  • stat_size (int) – for the moving average of recording speed, defaults to 100.

  • action_noise (BaseNoise) – add noise to continuous actions. Normally a policy already has a noise parameter for exploration during training, so this is mainly recommended for use in the test collector.

  • reward_metric (function) – to be used in multi-agent RL. The reward to report is of shape [agent_num], but we need a single scalar to monitor training. This function specifies the desired metric, e.g., the reward of agent 1 or the average reward over all agents. By default, the reward of agent 1 is selected.

The preprocess_fn is a function called, in batch format, before the data is added to the buffer; it receives up to 7 keys, as listed in Batch. It receives only obs when the collector resets the environment. It returns either a dict or a Batch with the modified keys and values. Examples are in “test/base/test_collector.py”.
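
A minimal sketch of such a function (the reward-clipping logic below is purely illustrative):

import numpy as np
from tianshou.data import Batch

def preprocess_fn(**kwargs):
    # regular step: the reserved keys listed in Batch are passed in
    if 'rew' in kwargs:
        return Batch(rew=np.clip(kwargs['rew'], -1.0, 1.0))
    # environment reset: only 'obs' is given, nothing to modify here
    return Batch()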

Example:

policy = PGPolicy(...)  # or other policies if you wish
env = gym.make('CartPole-v0')
replay_buffer = ReplayBuffer(size=10000)
# here we set up a collector with a single environment
collector = Collector(policy, env, buffer=replay_buffer)

# the collector supports vectorized environments as well
envs = VectorEnv([lambda: gym.make('CartPole-v0') for _ in range(3)])
buffers = [ReplayBuffer(size=5000) for _ in range(3)]
# you can also pass a list of replay buffer to collector, for multi-env
# collector = Collector(policy, envs, buffer=buffers)
collector = Collector(policy, envs, buffer=replay_buffer)

# collect at least 3 episodes
collector.collect(n_episode=3)
# collect 1 episode for the first env, 3 for the third env
collector.collect(n_episode=[1, 0, 3])
# collect at least 2 steps
collector.collect(n_step=2)
# collect episodes with visual rendering (the render argument is the
#   sleep time between rendering consecutive frames)
collector.collect(n_episode=1, render=0.03)

# sample data with a given number of batch-size:
batch_data = collector.sample(batch_size=64)
# policy.learn(batch_data)  # btw, vanilla policy gradient only
#   supports on-policy training, so here we pick all data in the buffer
batch_data = collector.sample(batch_size=0)
policy.learn(batch_data)
# on-policy algorithms use the collected data only once, so here we
#   clear the buffer
collector.reset_buffer()

When collecting data from multiple environments into a single buffer, the cache buffers are turned on automatically. In that case, the collector may return more data than the given limit.

Note

Please make sure the given environment has a time limit.

close() → None[source]

Close the environment(s).

collect(n_step: int = 0, n_episode: Union[int, List[int]] = 0, random: bool = False, render: Optional[float] = None, log_fn: Optional[Callable[[dict], None]] = None) → Dict[str, float][source]

Collect a specified number of steps or episodes.

Parameters
  • n_step (int) – how many steps you want to collect.

  • n_episode (int or list) – how many episodes you want to collect (in each environment).

  • random (bool) – whether to use random policy for collecting data, defaults to False.

  • render (float) – the sleep time between rendering consecutive frames, defaults to None (no rendering).

  • log_fn (function) – a function which receives env info, typically for tensorboard logging.

Note

One and only one collection number specification is permitted, either n_step or n_episode.

Returns

A dict including the following keys

  • n/ep the collected number of episodes.

  • n/st the collected number of steps.

  • v/st the speed of steps per second.

  • v/ep the speed of episodes per second.

  • rew the mean reward over collected episodes.

  • len the mean length over collected episodes.
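
For instance, the returned statistics might be consumed as follows (a sketch):

result = collector.collect(n_episode=10)
print(f"collected {result['n/ep']} episodes ({result['n/st']} steps), "
      f"mean reward {result['rew']:.2f}, mean length {result['len']:.1f}")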

get_env_num() → int[source]

Return the number of environments the collector has.

render(**kwargs) → None[source]

Render all the environment(s).

reset() → None[source]

Reset all related variables in the collector.

reset_buffer() → None[source]

Reset the main data buffer.

reset_env() → None[source]

Reset all of the environments’ states and reset all of the cache buffers (if needed).

sample(batch_size: int) → tianshou.data.batch.Batch[source]

Sample a data batch from the internal replay buffer. It will call process_fn() before returning the final batch data.

Parameters

batch_size (int) – 0 means it will extract all the data from the buffer, otherwise it will extract the data with the given batch_size.

seed(seed: Optional[Union[int, List[int]]] = None) → None[source]

Reset all the seed(s) of the given environment(s).

class tianshou.data.ListReplayBuffer(**kwargs)[source]

Bases: tianshou.data.buffer.ReplayBuffer

The function of ListReplayBuffer is almost the same as ReplayBuffer. The only difference is that ListReplayBuffer is based on list. Therefore, it does not support advanced indexing, which means you cannot sample a batch of data out of it. It is typically used for storing data.

See also

Please refer to ReplayBuffer for more detailed explanation.
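
A short sketch of its intended use (storage only, no random sampling):

>>> from tianshou.data import ListReplayBuffer
>>> buf = ListReplayBuffer()
>>> for i in range(3):
...     buf.add(obs=i, act=i, rew=i, done=0, obs_next=i + 1, info={})
>>> len(buf)
3
>>> buf.reset()
>>> len(buf)
0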

reset() → None[source]

Clear all the data in replay buffer.

sample(batch_size: int) → Tuple[tianshou.data.batch.Batch, numpy.ndarray][source]

Get a random sample from buffer with size equal to batch_size. Return all the data in the buffer if batch_size is 0.

Returns

Sample data and its corresponding index inside the buffer.

class tianshou.data.PrioritizedReplayBuffer(size: int, alpha: float, beta: float, mode: str = 'weight', replace: bool = False, **kwargs)[source]

Bases: tianshou.data.buffer.ReplayBuffer

Prioritized replay buffer implementation.

Parameters
  • alpha (float) – the prioritization exponent.

  • beta (float) – the importance sample soft coefficient.

  • mode (str) – defaults to weight.

  • replace (bool) – whether to sample with replacement.

See also

Please refer to ReplayBuffer for more detailed explanation.
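
A usage sketch (the hyper-parameters and the random priorities below are purely illustrative):

>>> import numpy as np
>>> from tianshou.data import PrioritizedReplayBuffer
>>> buf = PrioritizedReplayBuffer(size=1000, alpha=0.6, beta=0.4)
>>> for i in range(5):
...     buf.add(obs=i, act=i, rew=i, done=i % 2 == 0, obs_next=i + 1, info={})
>>> batch, indice = buf.sample(batch_size=4)
>>> # after computing new priorities (e.g. absolute TD errors) for the
>>> # sampled transitions, write them back into the buffer:
>>> buf.update_weight(indice, new_weight=np.abs(np.random.randn(len(indice))))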

__getitem__(index: Union[slice, int, numpy.integer, numpy.ndarray]) → tianshou.data.batch.Batch[source]

Return a data batch: self[index]. If stack_num is set to be > 0, return the stacked obs and obs_next with shape [batch, len, …].

add(obs: Union[dict, numpy.ndarray], act: Union[numpy.ndarray, float], rew: Union[int, float], done: bool, obs_next: Optional[Union[dict, numpy.ndarray]] = None, info: dict = {}, policy: Optional[Union[dict, tianshou.data.batch.Batch]] = {}, weight: float = 1.0, **kwargs) → None[source]

Add a batch of data into replay buffer.

property replace

sample(batch_size: int) → Tuple[tianshou.data.batch.Batch, numpy.ndarray][source]

Get a random sample from buffer with priority probability. Return all the data in the buffer if batch_size is 0.

Returns

Sample data and its corresponding index inside the buffer.

update_weight(indice: Union[slice, numpy.ndarray], new_weight: numpy.ndarray) → None[source]

Update priority weight by indice in this buffer.

Parameters
  • indice (np.ndarray) – the indices whose priority weight you want to update.

  • new_weight (np.ndarray) – the new priority weight values.

class tianshou.data.ReplayBuffer(size: int, stack_num: Optional[int] = 0, ignore_obs_next: bool = False, sample_avail: bool = False, **kwargs)[source]

Bases: object

ReplayBuffer stores data generated from the interaction between the policy and the environment. It stores essentially the 7 reserved keys mentioned in Batch, based on numpy.ndarray. Here is the usage:

>>> import numpy as np
>>> from tianshou.data import ReplayBuffer
>>> buf = ReplayBuffer(size=20)
>>> for i in range(3):
...     buf.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
>>> len(buf)
3
>>> buf.obs
# since we set size = 20, len(buf.obs) == 20.
array([0., 1., 2., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0.])

>>> buf2 = ReplayBuffer(size=10)
>>> for i in range(15):
...     buf2.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
>>> len(buf2)
10
>>> buf2.obs
# since its size = 10, it only stores the last 10 steps' result.
array([10., 11., 12., 13., 14.,  5.,  6.,  7.,  8.,  9.])

>>> # move buf2's result into buf (meanwhile keep it chronologically)
>>> buf.update(buf2)
>>> buf.obs
array([ 0.,  1.,  2.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13., 14.,
        0.,  0.,  0.,  0.,  0.,  0.,  0.])

>>> # get a random sample from buffer
>>> # the batch_data is equal to buf[indice].
>>> batch_data, indice = buf.sample(batch_size=4)
>>> batch_data.obs == buf[indice].obs
array([ True,  True,  True,  True])

ReplayBuffer also supports frame_stack sampling (typically for RNN usage, see issue#19), ignoring storing the next observation (save memory in atari tasks), and multi-modal observation (see issue#38):

>>> buf = ReplayBuffer(size=9, stack_num=4, ignore_obs_next=True)
>>> for i in range(16):
...     done = i % 5 == 0
...     buf.add(obs={'id': i}, act=i, rew=i, done=done,
...             obs_next={'id': i + 1})
>>> print(buf)  # you can see obs_next is not saved in buf
ReplayBuffer(
    act: array([ 9., 10., 11., 12., 13., 14., 15.,  7.,  8.]),
    done: array([0., 1., 0., 0., 0., 0., 1., 0., 0.]),
    info: Batch(),
    obs: Batch(
             id: array([ 9., 10., 11., 12., 13., 14., 15.,  7.,  8.]),
         ),
    policy: Batch(),
    rew: array([ 9., 10., 11., 12., 13., 14., 15.,  7.,  8.]),
)
>>> index = np.arange(len(buf))
>>> print(buf.get(index, 'obs').id)
[[ 7.  7.  8.  9.]
 [ 7.  8.  9. 10.]
 [11. 11. 11. 11.]
 [11. 11. 11. 12.]
 [11. 11. 12. 13.]
 [11. 12. 13. 14.]
 [12. 13. 14. 15.]
 [ 7.  7.  7.  7.]
 [ 7.  7.  7.  8.]]
>>> # here is another way to get the stacked data
>>> # (stack only for obs and obs_next)
>>> abs(buf.get(index, 'obs')['id'] - buf[index].obs.id).sum().sum()
0.0
>>> # we can get obs_next through __getitem__, even if it doesn't exist
>>> print(buf[:].obs_next.id)
[[ 7.  8.  9. 10.]
 [ 7.  8.  9. 10.]
 [11. 11. 11. 12.]
 [11. 11. 12. 13.]
 [11. 12. 13. 14.]
 [12. 13. 14. 15.]
 [12. 13. 14. 15.]
 [ 7.  7.  7.  8.]
 [ 7.  7.  8.  9.]]

Parameters
  • size (int) – the size of replay buffer.

  • stack_num (int) – the frame-stack sampling argument, should be greater than 1, defaults to 0 (no stacking).

  • ignore_obs_next (bool) – whether to store obs_next, defaults to False.

  • sample_avail (bool) – whether to sample only available indices when using the frame-stack sampling method. Defaults to False. This feature is currently not supported in PrioritizedReplayBuffer.

__getitem__(index: Union[slice, int, numpy.integer, numpy.ndarray]) → tianshou.data.batch.Batch[source]

Return a data batch: self[index]. If stack_num is set to be > 0, return the stacked obs and obs_next with shape [batch, len, …].

__len__() → int[source]

Return len(self).

add(obs: Union[dict, tianshou.data.batch.Batch, numpy.ndarray], act: Union[numpy.ndarray, float], rew: Union[int, float], done: bool, obs_next: Optional[Union[dict, tianshou.data.batch.Batch, numpy.ndarray]] = None, info: dict = {}, policy: Optional[Union[dict, tianshou.data.batch.Batch]] = {}, **kwargs) → None[source]

Add a batch of data into replay buffer.

get(indice: Union[slice, int, numpy.integer, numpy.ndarray], key: str, stack_num: Optional[int] = None) → Union[tianshou.data.batch.Batch, numpy.ndarray][source]

Return the stacked result, e.g. [s_{t-3}, s_{t-2}, s_{t-1}, s_t], where s is self.key and t is indice. The stack_num (equal to 4 in this example) is given at buffer initialization.

reset() → None[source]

Clear all the data in replay buffer.

sample(batch_size: int) → Tuple[tianshou.data.batch.Batch, numpy.ndarray][source]

Get a random sample from buffer with size equal to batch_size. Return all the data in the buffer if batch_size is 0.

Returns

Sample data and its corresponding index inside the buffer.

update(buffer: tianshou.data.buffer.ReplayBuffer) → None[source]

Move the data from the given buffer to self.

tianshou.data.to_numpy(x: Union[torch.Tensor, dict, tianshou.data.batch.Batch, numpy.ndarray]) → Union[dict, tianshou.data.batch.Batch, numpy.ndarray][source]

Return an object with all torch.Tensor converted to numpy.ndarray.

tianshou.data.to_torch(x: Union[torch.Tensor, dict, tianshou.data.batch.Batch, numpy.ndarray], dtype: Optional[torch.dtype] = None, device: Union[str, int, torch.device] = 'cpu') → Union[dict, tianshou.data.batch.Batch, torch.Tensor][source]

Return an object with all numpy.ndarray converted to torch.Tensor.

tianshou.data.to_torch_as(x: Union[torch.Tensor, dict, tianshou.data.batch.Batch, numpy.ndarray], y: torch.Tensor) → Union[dict, tianshou.data.batch.Batch, torch.Tensor][source]

Return an object with all numpy.ndarray converted to torch.Tensor; equivalent to to_torch(x, dtype=y.dtype, device=y.device).
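
A brief sketch of these helpers (assuming PyTorch is installed):

>>> import numpy as np
>>> import torch
>>> from tianshou.data import Batch, to_numpy, to_torch_as
>>> y = torch.zeros(3, dtype=torch.float32)
>>> data = to_torch_as(Batch(a=np.zeros((2, 3))), y)
>>> data.a.dtype
torch.float32
>>> isinstance(to_numpy(data).a, np.ndarray)
True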