tianshou.data

class tianshou.data.Batch(batch_dict: Optional[Union[dict, Batch, Sequence[Union[dict, Batch]], numpy.ndarray]] = None, copy: bool = False, **kwargs: Any)[source]

Bases: object

The internal data structure in Tianshou.

Batch is a kind of supercharged array (of temporal data) stored individually in a (recursive) dictionary of objects that can be either numpy arrays, torch tensors, or batches themselves. It is designed to make it extremely easy to access, manipulate and set partial views of the heterogeneous data conveniently.

For a detailed description, please refer to Understand Batch.
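
A quick sketch of typical usage (the key names here are arbitrary):

>>> import numpy as np
>>> from tianshou.data import Batch
>>> data = Batch(a=np.arange(5), b=Batch(c=np.zeros((5, 4))))
>>> data.a.shape, data.b.c.shape
((5,), (5, 4))
>>> data[:2].a  # indexing recurses into every leaf array
array([0, 1])
>>> data.d = np.ones(5)  # new keys can be attached on the fly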

__getitem__(index: Union[str, slice, int, numpy.integer, numpy.ndarray, List[int]]) → Any[source]

Return self[index].

__len__() → int[source]

Return len(self).

__setitem__(index: Union[str, slice, int, numpy.integer, numpy.ndarray, List[int]], value: Any) → None[source]

Assign value to self[index].

static cat(batches: Sequence[Union[dict, Batch]]) → tianshou.data.batch.Batch[source]

Concatenate a list of Batch objects into a single new batch.

For keys that are not shared across all batches, batches that do not have these keys will be padded by zeros with appropriate shapes. E.g.

>>> a = Batch(a=np.zeros([3, 4]), common=Batch(c=np.zeros([3, 5])))
>>> b = Batch(b=np.zeros([4, 3]), common=Batch(c=np.zeros([4, 5])))
>>> c = Batch.cat([a, b])
>>> c.a.shape
(7, 4)
>>> c.b.shape
(7, 3)
>>> c.common.c.shape
(7, 5)
cat_(batches: Union[Batch, Sequence[Union[dict, Batch]]]) → None[source]

Concatenate a list of (or one) Batch objects into the current batch.
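
For example (a minimal in-place sketch):

>>> a = Batch(x=np.zeros([2, 3]))
>>> a.cat_(Batch(x=np.ones([3, 3])))
>>> a.x.shape
(5, 3)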

static empty(batch: tianshou.data.batch.Batch, index: Union[str, slice, int, numpy.integer, numpy.ndarray, List[int]] = None) → tianshou.data.batch.Batch[source]

Return an empty Batch object filled with 0 or None.

The shape is the same as the given Batch.

empty_(index: Union[str, slice, int, numpy.integer, numpy.ndarray, List[int]] = None) → tianshou.data.batch.Batch[source]

Return an empty Batch object filled with 0 or None.

If “index” is specified, it only resets the data at the given index.

>>> data.empty_()
>>> print(data)
Batch(
    a: array([[0., 0.],
              [0., 0.]]),
    b: array([None, None], dtype=object),
)
>>> b={'c': [2., 'st'], 'd': [1., 0.]}
>>> data = Batch(a=[False,  True], b=b)
>>> data[0] = Batch.empty(data[1])
>>> data
Batch(
    a: array([False,  True]),
    b: Batch(
           c: array([None, 'st']),
           d: array([0., 0.]),
       ),
)
is_empty(recurse: bool = False) → bool[source]

Test if a Batch is empty.

If recurse=True, it further tests the values of the object; else it only tests the existence of any key.

b.is_empty(recurse=True) is mainly used to distinguish Batch(a=Batch(a=Batch())) and Batch(a=1). They both raise exceptions when applied to len(), but the former can be used in cat, while the latter is a scalar and cannot be used in cat.

Another usage is in __len__, where we have to skip checking the length of a recursively empty Batch.

>>> Batch().is_empty()
True
>>> Batch(a=Batch(), b=Batch(c=Batch())).is_empty()
False
>>> Batch(a=Batch(), b=Batch(c=Batch())).is_empty(recurse=True)
True
>>> Batch(d=1).is_empty()
False
>>> Batch(a=np.float64(1.0)).is_empty()
False
property shape

Return self.shape.

split(size: int, shuffle: bool = True, merge_last: bool = False) → Iterator[tianshou.data.batch.Batch][source]

Split the whole data into multiple small batches.

Parameters
  • size (int) – the size of each split batch; if the length of the batch is smaller than “size”, a single batch containing all the data is yielded.

  • shuffle (bool) – randomly shuffle the entire data batch if it is True, otherwise keep the original order. Defaults to True.

  • merge_last (bool) – merge the last batch into the previous one. Defaults to False.
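
E.g., splitting 10 items with size=4 and shuffle=False yields chunks of 4, 4 and 2; with merge_last=True the trailing chunk would instead be merged into the previous one, yielding 4 and 6:

>>> data = Batch(obs=np.arange(10))
>>> for minibatch in data.split(size=4, shuffle=False):
...     print(minibatch.obs)
[0 1 2 3]
[4 5 6 7]
[8 9]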

static stack(batches: Sequence[Union[dict, Batch]], axis: int = 0) → tianshou.data.batch.Batch[source]

Stack a list of Batch objects into a single new batch.

For keys that are not shared across all batches, batches that do not have these keys will be padded by zeros. E.g.

>>> a = Batch(a=np.zeros([4, 4]), common=Batch(c=np.zeros([4, 5])))
>>> b = Batch(b=np.zeros([4, 6]), common=Batch(c=np.zeros([4, 5])))
>>> c = Batch.stack([a, b])
>>> c.a.shape
(2, 4, 4)
>>> c.b.shape
(2, 4, 6)
>>> c.common.c.shape
(2, 4, 5)

Note

If there are keys that are not shared across all batches, stack with axis != 0 is undefined, and will cause an exception.

stack_(batches: Sequence[Union[dict, Batch]], axis: int = 0) → None[source]

Stack a list of Batch objects into the current batch.

to_numpy() → None[source]

Change all torch.Tensor to numpy.ndarray in-place.

to_torch(dtype: Optional[torch.dtype] = None, device: Union[str, int, torch.device] = 'cpu') → None[source]

Change all numpy.ndarray to torch.Tensor in-place.
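
Both conversions operate in place, e.g.:

>>> import torch
>>> b = Batch(a=np.ones((2, 3)))
>>> b.to_torch(dtype=torch.float32)
>>> isinstance(b.a, torch.Tensor)
True
>>> b.to_numpy()
>>> isinstance(b.a, np.ndarray)
True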

update(batch: Optional[Union[dict, Batch]] = None, **kwargs: Any) → None[source]

Update this batch from another dict/Batch.
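
Existing keys are overwritten and new keys are added, following dict.update semantics, e.g. (a minimal sketch):

>>> b = Batch(a=np.zeros(3))
>>> b.update(Batch(a=np.ones(3)), c=np.full(3, 2.))
>>> b.a
array([1., 1., 1.])
>>> b.c
array([2., 2., 2.])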

class tianshou.data.Collector(policy: tianshou.policy.base.BasePolicy, env: Union[gym.core.Env, tianshou.env.venvs.BaseVectorEnv], buffer: Optional[tianshou.data.buffer.ReplayBuffer] = None, preprocess_fn: Optional[Callable[[], tianshou.data.batch.Batch]] = None, action_noise: Optional[tianshou.exploration.random.BaseNoise] = None, reward_metric: Optional[Callable[[numpy.ndarray], float]] = None)[source]

Bases: object

Collector enables the policy to interact with different types of envs.

Parameters
  • policy – an instance of the BasePolicy class.

  • env – a gym.Env environment or an instance of the BaseVectorEnv class.

  • buffer – an instance of the ReplayBuffer class. If set to None (testing phase), it will not store the data.

  • preprocess_fn (function) – a function called before the data is added to the buffer; see issue #42 and Handle Batched Data Stream in Collector. Defaults to None.

  • action_noise (BaseNoise) – add noise to the continuous action. Normally a policy already has a noise parameter for exploration during training, so this option is mainly recommended for the test collector.

  • reward_metric (function) – used in multi-agent RL. The reward to report has shape [agent_num], but we need to return a single scalar to monitor training. This function specifies the desired metric, e.g., the reward of agent 1 or the average reward over all agents. By default, the reward of agent 1 is selected.

The preprocess_fn is a function called, in batch format, before the data is added to the buffer; it receives up to 7 keys as listed in Batch. It receives only obs when the collector resets the environment. It returns either a dict or a Batch with the modified keys and values. Examples are in “test/base/test_collector.py”.
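
A minimal sketch of such a function, which rescales the reward before it is stored (the scaling factor is for illustration only):

def preprocess_fn(**kwargs):
    # the collector passes only "obs" when the environment is reset,
    # so guard the access to "rew"
    if 'rew' in kwargs:
        return Batch(rew=kwargs['rew'] / 10.0)
    return Batch()

It can then be passed to the constructor via the preprocess_fn argument.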

Here is the example:

policy = PGPolicy(...)  # or other policies if you wish
env = gym.make('CartPole-v0')
replay_buffer = ReplayBuffer(size=10000)
# here we set up a collector with a single environment
collector = Collector(policy, env, buffer=replay_buffer)

# the collector supports vectorized environments as well
envs = DummyVectorEnv([lambda: gym.make('CartPole-v0')
                       for _ in range(3)])
collector = Collector(policy, envs, buffer=replay_buffer)

# collect 3 episodes
collector.collect(n_episode=3)
# collect 1 episode for the first env, 3 for the third env
collector.collect(n_episode=[1, 0, 3])
# collect at least 2 steps
collector.collect(n_step=2)
# collect episodes with visual rendering (the render argument is the
#   sleep time between rendering consecutive frames)
collector.collect(n_episode=1, render=0.03)

Collected data always consists of full episodes. So if only the n_step argument is given, the collector may return more data than the n_step limitation. The same holds for n_episode in the multiple-environment case.

Note

Please make sure the given environment has a time limit.

collect(n_step: Optional[int] = None, n_episode: Optional[Union[int, List[int]]] = None, random: bool = False, render: Optional[float] = None, no_grad: bool = True) → Dict[str, float][source]

Collect a specified number of steps or episodes.

Parameters
  • n_step (int) – how many steps you want to collect.

  • n_episode – how many episodes you want to collect. If it is an int, it means to collect at least n_episode episodes; if it is a list, it means to collect exactly n_episode[i] episodes in the i-th environment.

  • random (bool) – whether to use a random policy for collecting data, defaults to False.

  • render (float) – the sleep time between rendering consecutive frames, defaults to None (no rendering).

  • no_grad (bool) – whether to retain gradient in policy.forward, defaults to True (no gradient retained).

Note

One and only one collection number specification is permitted, either n_step or n_episode.

Returns

A dict including the following keys

  • n/ep the collected number of episodes.

  • n/st the collected number of steps.

  • v/st the speed of steps per second.

  • v/ep the speed of episodes per second.

  • rew the mean reward over collected episodes.

  • len the mean length over collected episodes.
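
For example, the returned statistics can be consumed directly (reusing the collector from the example above):

result = collector.collect(n_episode=10)
print(f"collected {result['n/ep']} episodes ({result['n/st']} steps), "
      f"mean reward {result['rew']:.2f}")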

get_env_num() → int[source]

Return the number of environments the collector has.

reset() → None[source]

Reset all related variables in the collector.

reset_buffer() → None[source]

Reset the main data buffer.

reset_env() → None[source]

Reset all of the environment(s)’ states and the cache buffers.

reset_stat() → None[source]

Reset the statistic variables.

class tianshou.data.ListReplayBuffer(**kwargs: Any)[source]

Bases: tianshou.data.buffer.ReplayBuffer

List-based replay buffer.

The function of ListReplayBuffer is almost the same as ReplayBuffer. The only difference is that ListReplayBuffer is based on a list. Therefore, it does not support advanced indexing, which means you cannot sample a batch of data out of it. It is typically used for storing data.

See also

Please refer to ReplayBuffer for more detailed explanation.

reset() → None[source]

Clear all the data in the replay buffer.

sample(batch_size: int) → Tuple[tianshou.data.batch.Batch, numpy.ndarray][source]

Get a random sample from the buffer with size equal to batch_size.

Return all the data in the buffer if batch_size is 0.

Returns

Sample data and its corresponding index inside the buffer.

class tianshou.data.PrioritizedReplayBuffer(size: int, alpha: float, beta: float, **kwargs: Any)[source]

Bases: tianshou.data.buffer.ReplayBuffer

Implementation of Prioritized Experience Replay. arXiv:1511.05952.

Parameters
  • alpha (float) – the prioritization exponent.

  • beta (float) – the importance sampling soft coefficient.

See also

Please refer to ReplayBuffer for more detailed explanation.

__getitem__(index: Union[slice, int, numpy.integer, numpy.ndarray]) → tianshou.data.batch.Batch[source]

Return a data batch: self[index].

If stack_num is larger than 1, return the stacked obs and obs_next with shape (batch, len, …).

add(obs: Any, act: Any, rew: Union[numbers.Number, numpy.number, numpy.ndarray], done: Union[numbers.Number, numpy.number, numpy.bool_], obs_next: Any = None, info: Optional[Union[dict, tianshou.data.batch.Batch]] = {}, policy: Optional[Union[dict, tianshou.data.batch.Batch]] = {}, weight: Optional[Union[numbers.Number, numpy.number]] = None, **kwargs: Any) → None[source]

Add a batch of data into the replay buffer.

sample(batch_size: int) → Tuple[tianshou.data.batch.Batch, numpy.ndarray][source]

Get a random sample from the buffer with priority probability.

Return all the data in the buffer if batch_size is 0.

Returns

Sample data and its corresponding index inside the buffer.

The “weight” in the returned Batch is the importance weight applied to the loss function to de-bias the sampling process (some transition tuples are sampled more often, so their losses are weighted less).
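
A minimal end-to-end sketch (the hyperparameter values and the random stand-in for the TD error are illustrative only; in practice the error comes from the learner):

import numpy as np
import torch
from tianshou.data import PrioritizedReplayBuffer

buf = PrioritizedReplayBuffer(size=100, alpha=0.6, beta=0.4)
for i in range(10):
    buf.add(obs=i, act=i, rew=float(i), done=False, obs_next=i + 1)
batch, indice = buf.sample(batch_size=4)
td_error = torch.rand(4)  # stand-in for the real TD error
# de-bias the loss with the returned importance weight
loss = (torch.as_tensor(batch.weight, dtype=torch.float32)
        * td_error.pow(2)).mean()
# feed the new priorities back into the buffer
buf.update_weight(indice, td_error.abs().numpy())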

update_weight(indice: numpy.ndarray, new_weight: Union[numpy.ndarray, torch.Tensor]) → None[source]

Update the priority weight at the given indices in this buffer.

Parameters
  • indice (np.ndarray) – the indices at which you want to update the weight.

  • new_weight (np.ndarray) – the new priority weight.

class tianshou.data.ReplayBuffer(size: int, stack_num: int = 1, ignore_obs_next: bool = False, save_only_last_obs: bool = False, sample_avail: bool = False)[source]

Bases: object

ReplayBuffer stores data generated from interaction between the policy and environment.

The current implementation of Tianshou typically uses 7 reserved keys in Batch:

  • obs the observation of step \(t\) ;

  • act the action of step \(t\) ;

  • rew the reward of step \(t\) ;

  • done the done flag of step \(t\) ;

  • obs_next the observation of step \(t+1\) ;

  • info the info of step \(t\) (in gym.Env, the env.step() function returns 4 values, and the last one is info);

  • policy the data computed by policy in step \(t\);

The following code snippet illustrates its usage:

>>> import pickle, numpy as np
>>> from tianshou.data import ReplayBuffer
>>> buf = ReplayBuffer(size=20)
>>> for i in range(3):
...     buf.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
>>> buf.obs
# since we set size = 20, len(buf.obs) == 20.
array([0., 1., 2., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0.])
>>> # but there are only three valid items, so len(buf) == 3.
>>> len(buf)
3
>>> pickle.dump(buf, open('buf.pkl', 'wb'))  # save to file "buf.pkl"
>>> buf2 = ReplayBuffer(size=10)
>>> for i in range(15):
...     buf2.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
>>> len(buf2)
10
>>> buf2.obs
# since its size = 10, it only stores the last 10 steps' result.
array([10., 11., 12., 13., 14.,  5.,  6.,  7.,  8.,  9.])

>>> # move buf2's result into buf (meanwhile keep it chronologically)
>>> buf.update(buf2)
array([ 0.,  1.,  2.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13., 14.,
        0.,  0.,  0.,  0.,  0.,  0.,  0.])

>>> # get a random sample from buffer
>>> # the batch_data is equal to buf[indice].
>>> batch_data, indice = buf.sample(batch_size=4)
>>> batch_data.obs == buf[indice].obs
array([ True,  True,  True,  True])
>>> len(buf)
13
>>> buf = pickle.load(open('buf.pkl', 'rb'))  # load from "buf.pkl"
>>> len(buf)
3

ReplayBuffer also supports frame-stack sampling (typically for RNN usage, see issue#19), skipping the storage of the next observation (to save memory in Atari tasks), and multi-modal observation (see issue#38):

>>> buf = ReplayBuffer(size=9, stack_num=4, ignore_obs_next=True)
>>> for i in range(16):
...     done = i % 5 == 0
...     buf.add(obs={'id': i}, act=i, rew=i, done=done,
...             obs_next={'id': i + 1})
>>> print(buf)  # you can see obs_next is not saved in buf
ReplayBuffer(
    act: array([ 9., 10., 11., 12., 13., 14., 15.,  7.,  8.]),
    done: array([0., 1., 0., 0., 0., 0., 1., 0., 0.]),
    info: Batch(),
    obs: Batch(
             id: array([ 9., 10., 11., 12., 13., 14., 15.,  7.,  8.]),
         ),
    policy: Batch(),
    rew: array([ 9., 10., 11., 12., 13., 14., 15.,  7.,  8.]),
)
>>> index = np.arange(len(buf))
>>> print(buf.get(index, 'obs').id)
[[ 7.  7.  8.  9.]
 [ 7.  8.  9. 10.]
 [11. 11. 11. 11.]
 [11. 11. 11. 12.]
 [11. 11. 12. 13.]
 [11. 12. 13. 14.]
 [12. 13. 14. 15.]
 [ 7.  7.  7.  7.]
 [ 7.  7.  7.  8.]]
>>> # here is another way to get the stacked data
>>> # (stack only for obs and obs_next)
>>> abs(buf.get(index, 'obs')['id'] - buf[index].obs.id).sum().sum()
0.0
>>> # we can get obs_next through __getitem__, even if it doesn't exist
>>> print(buf[:].obs_next.id)
[[ 7.  8.  9. 10.]
 [ 7.  8.  9. 10.]
 [11. 11. 11. 12.]
 [11. 11. 12. 13.]
 [11. 12. 13. 14.]
 [12. 13. 14. 15.]
 [12. 13. 14. 15.]
 [ 7.  7.  7.  8.]
 [ 7.  7.  8.  9.]]
Parameters
  • size (int) – the size of replay buffer.

  • stack_num (int) – the frame-stack sampling argument, should be greater than or equal to 1, defaults to 1 (no stacking).

  • ignore_obs_next (bool) – whether to skip storing obs_next, defaults to False.

  • save_only_last_obs (bool) – only save the last obs/obs_next when it has a shape of (timestep, …) because of temporal stacking, defaults to False.

  • sample_avail (bool) – whether to sample only from the available (fully stacked) indices when using the frame-stack sampling method, defaults to False. This feature is not currently supported in Prioritized Replay Buffer.

__getitem__(index: Union[slice, int, numpy.integer, numpy.ndarray]) → tianshou.data.batch.Batch[source]

Return a data batch: self[index].

If stack_num is larger than 1, return the stacked obs and obs_next with shape (batch, len, …).

__len__() → int[source]

Return len(self).

add(obs: Any, act: Any, rew: Union[numbers.Number, numpy.number, numpy.ndarray], done: Union[numbers.Number, numpy.number, numpy.bool_], obs_next: Any = None, info: Optional[Union[dict, tianshou.data.batch.Batch]] = {}, policy: Optional[Union[dict, tianshou.data.batch.Batch]] = {}, **kwargs: Any) → None[source]

Add a batch of data into the replay buffer.

get(indice: Union[slice, int, numpy.integer, numpy.ndarray], key: str, stack_num: Optional[int] = None) → Union[tianshou.data.batch.Batch, numpy.ndarray][source]

Return the stacked result.

E.g., [s_{t-3}, s_{t-2}, s_{t-1}, s_t], where s is self.key and t is the indice. The stack_num (here equal to 4) is given at buffer initialization.

reset() → None[source]

Clear all the data in the replay buffer.

sample(batch_size: int) → Tuple[tianshou.data.batch.Batch, numpy.ndarray][source]

Get a random sample from the buffer with size equal to batch_size.

Return all the data in the buffer if batch_size is 0.

Returns

Sample data and its corresponding index inside the buffer.

property stack_num
update(buffer: tianshou.data.buffer.ReplayBuffer) → None[source]

Move the data from the given buffer to self.

class tianshou.data.SegmentTree(size: int)[source]

Bases: object

Implementation of Segment Tree.

The segment tree stores an array arr with size n. It supports value update and fast query of the sum for the interval [left, right) in O(log n) time. The detailed procedure is as follows:

  1. Pad the array to a length that is a power of 2, so that leaf nodes in the segment tree have the same depth.

  2. Store the segment tree in a binary heap.

Parameters

size (int) – the size of segment tree.
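
For example, a minimal sketch of setting values and querying interval sums:

>>> import numpy as np
>>> from tianshou.data import SegmentTree
>>> tree = SegmentTree(size=4)
>>> tree[np.arange(4)] = np.array([1., 2., 3., 4.])
>>> tree.reduce()      # sum over the whole array
10.0
>>> tree.reduce(1, 3)  # sum over the half-open interval [1, 3)
5.0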

__getitem__(index: Union[int, numpy.ndarray]) → Union[float, numpy.ndarray][source]

Return self[index].

__len__() → int[source]
__setitem__(index: Union[int, numpy.ndarray], value: Union[float, numpy.ndarray]) → None[source]

Update values in segment tree.

Duplicate values in index are handled by numpy: later indices overwrite earlier ones.

>>> a = np.array([1, 2, 3, 4])
>>> a[[0, 1, 0, 1]] = [4, 5, 6, 7]
>>> print(a)
[6 7 3 4]
get_prefix_sum_idx(value: Union[float, numpy.ndarray]) → Union[int, numpy.ndarray][source]

Find the index corresponding to the given value.

Return the minimum index for each v in value so that \(v \le \mathrm{sums}_i\), where \(\mathrm{sums}_i = \sum_{j = 0}^{i} \mathrm{arr}_j\).

Warning

Please make sure all of the values inside the segment tree are non-negative when using this function.
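
For example, continuing the sketch above with arr = [1, 2, 3, 4], the prefix sums are [1, 3, 6, 10]:

>>> tree = SegmentTree(size=4)
>>> tree[np.arange(4)] = np.array([1., 2., 3., 4.])
>>> tree.get_prefix_sum_idx(0.5)  # 0.5 <= sums_0 = 1.0
0
>>> tree.get_prefix_sum_idx(3.5)  # sums_1 = 3.0 < 3.5 <= sums_2 = 6.0
2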

reduce(start: int = 0, end: Optional[int] = None) → float[source]

Return operation(value[start:end]).

tianshou.data.to_numpy(x: Optional[Union[tianshou.data.batch.Batch, dict, list, tuple, numpy.number, numpy.bool_, numbers.Number, numpy.ndarray, torch.Tensor]]) → Union[tianshou.data.batch.Batch, dict, list, tuple, numpy.ndarray][source]

Return an object with all torch.Tensor converted to numpy.ndarray.

tianshou.data.to_torch(x: Union[tianshou.data.batch.Batch, dict, list, tuple, numpy.number, numpy.bool_, numbers.Number, numpy.ndarray, torch.Tensor], dtype: Optional[torch.dtype] = None, device: Union[str, int, torch.device] = 'cpu') → Union[tianshou.data.batch.Batch, dict, list, tuple, torch.Tensor][source]

Return an object with all numpy.ndarray converted to torch.Tensor.

tianshou.data.to_torch_as(x: Union[tianshou.data.batch.Batch, dict, list, tuple, numpy.ndarray, torch.Tensor], y: torch.Tensor) → Union[tianshou.data.batch.Batch, dict, list, tuple, torch.Tensor][source]

Return an object with all numpy.ndarray converted to torch.Tensor.

Same as to_torch(x, dtype=y.dtype, device=y.device).
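
For example, converting an array to match an existing tensor’s dtype and device:

>>> import numpy as np, torch
>>> from tianshou.data import to_torch_as
>>> y = torch.zeros(3, dtype=torch.float64)
>>> x = to_torch_as(np.array([1, 2, 3]), y)
>>> x.dtype
torch.float64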