tianshou.data

Batch

class tianshou.data.Batch(batch_dict: Optional[Union[dict, tianshou.data.batch.Batch, Sequence[Union[dict, tianshou.data.batch.Batch]], numpy.ndarray]] = None, copy: bool = False, **kwargs: Any)[source]

Bases: object

The internal data structure in Tianshou.

Batch is a kind of supercharged array (of temporal data) stored individually in a (recursive) dictionary of objects that can be numpy arrays, torch tensors, or Batches themselves. It is designed to make it extremely easy to access, manipulate, and set partial views of the heterogeneous data.

For a detailed description, please refer to Understand Batch.
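
As a quick orientation, here is a minimal sketch of constructing and indexing a Batch (the key names and values are illustrative, not part of the API):

>>> import numpy as np
>>> from tianshou.data import Batch
>>> data = Batch(obs=np.arange(4).reshape(2, 2), rew=np.array([0.0, 1.0]))
>>> data.obs.shape
(2, 2)
>>> data[0].obs            # index along the first (batch) dimension
array([0, 1])
>>> len(data)
2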

__getitem__(index: Union[str, slice, int, numpy.ndarray, List[int]]) → Any[source]

Return self[index].

__setitem__(index: Union[str, slice, int, numpy.ndarray, List[int]], value: Any) → None[source]

Assign value to self[index].

to_numpy() → None[source]

Change all torch.Tensor to numpy.ndarray in-place.

to_torch(dtype: Optional[torch.dtype] = None, device: Union[str, int, torch.device] = 'cpu') → None[source]

Change all numpy.ndarray to torch.Tensor in-place.
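
A small sketch of the two in-place conversions (the dtype and device values are illustrative):

>>> import numpy as np
>>> import torch
>>> from tianshou.data import Batch
>>> data = Batch(obs=np.zeros((3, 4)))
>>> data.to_torch(dtype=torch.float32, device='cpu')
>>> isinstance(data.obs, torch.Tensor)
True
>>> data.to_numpy()
>>> isinstance(data.obs, np.ndarray)
True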

cat_(batches: Union[tianshou.data.batch.Batch, Sequence[Union[dict, tianshou.data.batch.Batch]]]) → None[source]

Concatenate a list of (or one) Batch objects into the current batch.

static cat(batches: Sequence[Union[dict, tianshou.data.batch.Batch]]) → tianshou.data.batch.Batch[source]

Concatenate a list of Batch objects into a single new batch.

For keys that are not shared across all batches, batches that do not have these keys will be padded by zeros with appropriate shapes. E.g.

>>> a = Batch(a=np.zeros([3, 4]), common=Batch(c=np.zeros([3, 5])))
>>> b = Batch(b=np.zeros([4, 3]), common=Batch(c=np.zeros([4, 5])))
>>> c = Batch.cat([a, b])
>>> c.a.shape
(7, 4)
>>> c.b.shape
(7, 3)
>>> c.common.c.shape
(7, 5)

stack_(batches: Sequence[Union[dict, tianshou.data.batch.Batch]], axis: int = 0) → None[source]

Stack a list of Batch objects into the current batch.

static stack(batches: Sequence[Union[dict, tianshou.data.batch.Batch]], axis: int = 0) → tianshou.data.batch.Batch[source]

Stack a list of Batch objects into a single new batch.

For keys that are not shared across all batches, batches that do not have these keys will be padded by zeros. E.g.

>>> a = Batch(a=np.zeros([4, 4]), common=Batch(c=np.zeros([4, 5])))
>>> b = Batch(b=np.zeros([4, 6]), common=Batch(c=np.zeros([4, 5])))
>>> c = Batch.stack([a, b])
>>> c.a.shape
(2, 4, 4)
>>> c.b.shape
(2, 4, 6)
>>> c.common.c.shape
(2, 4, 5)

Note

If there are keys that are not shared across all batches, stack with axis != 0 is undefined, and will cause an exception.

empty_(index: Optional[Union[slice, int, numpy.ndarray, List[int]]] = None) → tianshou.data.batch.Batch[source]

Return an empty Batch object with 0 or None filled.

If “index” is specified, it only resets the data at the given index.

>>> data.empty_()
>>> print(data)
Batch(
    a: array([[0., 0.],
              [0., 0.]]),
    b: array([None, None], dtype=object),
)
>>> b={'c': [2., 'st'], 'd': [1., 0.]}
>>> data = Batch(a=[False,  True], b=b)
>>> data[0] = Batch.empty(data[1])
>>> data
Batch(
    a: array([False,  True]),
    b: Batch(
           c: array([None, 'st']),
           d: array([0., 0.]),
       ),
)

static empty(batch: tianshou.data.batch.Batch, index: Optional[Union[slice, int, numpy.ndarray, List[int]]] = None) → tianshou.data.batch.Batch[source]

Return an empty Batch object with 0 or None filled.

The shape is the same as the given Batch.

update(batch: Optional[Union[dict, tianshou.data.batch.Batch]] = None, **kwargs: Any) → None[source]

Update this batch from another dict/Batch.

__len__() → int[source]

Return len(self).

is_empty(recurse: bool = False) → bool[source]

Test if a Batch is empty.

If recurse=True, it further tests the values of the object; else it only tests the existence of any key.

b.is_empty(recurse=True) is mainly used to distinguish Batch(a=Batch(a=Batch())) and Batch(a=1). They both raise exceptions when applied to len(), but the former can be used in cat, while the latter is a scalar and cannot be used in cat.

Another usage is in __len__, where we have to skip checking the length of recursively empty Batch.

>>> Batch().is_empty()
True
>>> Batch(a=Batch(), b=Batch(c=Batch())).is_empty()
False
>>> Batch(a=Batch(), b=Batch(c=Batch())).is_empty(recurse=True)
True
>>> Batch(d=1).is_empty()
False
>>> Batch(a=np.float64(1.0)).is_empty()
False

property shape

Return self.shape.

split(size: int, shuffle: bool = True, merge_last: bool = False) → Iterator[tianshou.data.batch.Batch][source]

Split whole data into multiple small batches.

Parameters
  • size (int) – divide the data batch into chunks of the given size; a single batch is returned if the length of the batch is smaller than “size”.

  • shuffle (bool) – randomly shuffle the entire data batch if True, otherwise keep the original order. Default to True.

  • merge_last (bool) – merge the last batch into the previous one. Default to False.
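
For example, a sketch of iterating over minibatches with split (shuffle is disabled here so that the order is predictable):

>>> import numpy as np
>>> from tianshou.data import Batch
>>> data = Batch(obs=np.arange(10))
>>> for minibatch in data.split(size=4, shuffle=False, merge_last=True):
...     print(minibatch.obs)
[0 1 2 3]
[4 5 6 7 8 9]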

Buffer

ReplayBuffer

class tianshou.data.ReplayBuffer(size: int, stack_num: int = 1, ignore_obs_next: bool = False, save_only_last_obs: bool = False, sample_avail: bool = False, **kwargs: Any)[source]

Bases: object

ReplayBuffer stores data generated from interaction between the policy and environment.

ReplayBuffer can be considered a specialized form (or management) of Batch. It stores all the data in a batch in circular-queue style.

For the example usage of ReplayBuffer, please check out Section Buffer in Basic concepts in Tianshou.

Parameters
  • size (int) – the maximum size of replay buffer.

  • stack_num (int) – the frame-stack sampling argument, should be greater than or equal to 1. Default to 1 (no stacking).

  • ignore_obs_next (bool) – whether to ignore storing obs_next (the next observation can be recovered from the following transition’s obs). Default to False.

  • save_only_last_obs (bool) – only save the last obs/obs_next when it has a shape of (timestep, …) because of temporal stacking. Default to False.

  • sample_avail (bool) – whether to sample only available indices when using the frame-stack sampling method. Default to False.
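
As a rough usage sketch (the transition values below are made up; in practice they come from env.step()):

>>> import numpy as np
>>> from tianshou.data import Batch, ReplayBuffer
>>> buf = ReplayBuffer(size=20)
>>> for i in range(5):
...     # obs/act/rew/done are required keys; the shapes here are illustrative
...     buf.add(Batch(obs=i, act=0, rew=i / 2.0, done=(i == 4), obs_next=i + 1))
>>> len(buf)
5
>>> batch, indices = buf.sample(batch_size=2)   # random transitions plus their indices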

__len__() → int[source]

Return len(self).

save_hdf5(path: str) → None[source]

Save replay buffer to HDF5 file.

classmethod load_hdf5(path: str, device: Optional[str] = None) → tianshou.data.buffer.base.ReplayBuffer[source]

Load replay buffer from HDF5 file.
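
A brief sketch of persisting a buffer and loading it back (the file path is hypothetical):

>>> buf.save_hdf5('buffer.hdf5')                  # write the whole buffer to disk
>>> buf2 = ReplayBuffer.load_hdf5('buffer.hdf5')  # restore it later, e.g. in another process
>>> len(buf2) == len(buf)
True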

reset(keep_statistics: bool = False) → None[source]

Clear all the data in replay buffer and episode statistics.

set_batch(batch: tianshou.data.batch.Batch) → None[source]

Manually choose the batch you want the ReplayBuffer to manage.

unfinished_index() → numpy.ndarray[source]

Return the indices of unfinished episodes.

prev(index: Union[int, numpy.ndarray]) → numpy.ndarray[source]

Return the index of previous transition.

The index won’t be modified if it is the beginning of an episode.

next(index: Union[int, numpy.ndarray]) → numpy.ndarray[source]

Return the index of next transition.

The index won’t be modified if it is the end of an episode.
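
For example, assuming a small buffer whose third transition ends an episode (a sketch; the transition values are made up):

>>> buf = ReplayBuffer(size=10)
>>> for i in range(3):
...     buf.add(Batch(obs=i, act=0, rew=0.0, done=(i == 2), obs_next=i + 1))
>>> buf.prev(np.array([0, 1, 2]))   # index 0 starts an episode, so it stays 0
array([0, 0, 1])
>>> buf.next(np.array([0, 1, 2]))   # index 2 ends an episode, so it stays 2
array([1, 2, 2])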

update(buffer: tianshou.data.buffer.base.ReplayBuffer) → numpy.ndarray[source]

Move the data from the given buffer to the current buffer.

Return the updated indices. If update fails, return an empty array.

add(batch: tianshou.data.batch.Batch, buffer_ids: Optional[Union[numpy.ndarray, List[int]]] = None) → Tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray, numpy.ndarray][source]

Add a batch of data into replay buffer.

Parameters
  • batch (Batch) – the input data batch. Its keys must belong to the 7 reserved keys, and “obs”, “act”, “rew”, “done” are required.

  • buffer_ids – present to keep the signature consistent with other buffers’ add functions; if it is not None, the input batch’s first dimension is assumed to be 1.

Return (current_index, episode_reward, episode_length, episode_start_index). If the episode is not finished, the returned episode_reward and episode_length are 0.

sample_index(batch_size: int) → numpy.ndarray[source]

Get a random sample of indices with size = batch_size.

Return all available indices in the buffer if batch_size is 0; return an empty numpy array if batch_size < 0 or no available index can be sampled.

sample(batch_size: int) → Tuple[tianshou.data.batch.Batch, numpy.ndarray][source]

Get a random sample from buffer with size = batch_size.

Return all the data in the buffer if batch_size is 0.

Returns

Sample data and its corresponding index inside the buffer.

get(index: Union[int, List[int], numpy.ndarray], key: str, default_value: Optional[Any] = None, stack_num: Optional[int] = None) → Union[tianshou.data.batch.Batch, numpy.ndarray][source]

Return the stacked result.

E.g., if you set key = "obs", stack_num = 4, index = t, it returns the stacked result as [obs[t-3], obs[t-2], obs[t-1], obs[t]].

Parameters
  • index – the index for getting stacked data.

  • key (str) – the key to get, should be one of the reserved_keys.

  • default_value – if the given key’s data is not found and default_value is set, return this default_value.

  • stack_num (int) – Default to self.stack_num.

__getitem__(index: Union[slice, int, List[int], numpy.ndarray]) → tianshou.data.batch.Batch[source]

Return a data batch: self[index].

If stack_num is larger than 1, return the stacked obs and obs_next with shape (batch, len, …).

PrioritizedReplayBuffer

class tianshou.data.PrioritizedReplayBuffer(size: int, alpha: float, beta: float, **kwargs: Any)[source]

Bases: tianshou.data.buffer.base.ReplayBuffer

Implementation of Prioritized Experience Replay. arXiv:1511.05952.

Parameters
  • alpha (float) – the prioritization exponent.

  • beta (float) – the importance sampling soft coefficient.

See also

Please refer to ReplayBuffer for other APIs’ usage.
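
A rough sketch of the typical prioritized-replay loop (the TD-error values here are placeholders for whatever per-sample error your algorithm computes):

>>> import numpy as np
>>> from tianshou.data import Batch, PrioritizedReplayBuffer
>>> buf = PrioritizedReplayBuffer(size=100, alpha=0.6, beta=0.4)
>>> for i in range(5):
...     buf.add(Batch(obs=i, act=0, rew=1.0, done=False, obs_next=i + 1))
>>> batch, indices = buf.sample(batch_size=4)
>>> weights = batch.weight                      # importance sampling weights to de-bias the loss
>>> td_error = np.random.rand(len(indices))     # placeholder for the per-sample priority
>>> buf.update_weight(indices, td_error)        # refresh priorities of the sampled transitions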

init_weight(index: Union[int, numpy.ndarray]) → None[source]

update(buffer: tianshou.data.buffer.base.ReplayBuffer) → numpy.ndarray[source]

Move the data from the given buffer to the current buffer.

Return the updated indices. If update fails, return an empty array.

add(batch: tianshou.data.batch.Batch, buffer_ids: Optional[Union[numpy.ndarray, List[int]]] = None) → Tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray, numpy.ndarray][source]

Add a batch of data into replay buffer.

Parameters
  • batch (Batch) – the input data batch. Its keys must belong to the 7 reserved keys, and “obs”, “act”, “rew”, “done” are required.

  • buffer_ids – present to keep the signature consistent with other buffers’ add functions; if it is not None, the input batch’s first dimension is assumed to be 1.

Return (current_index, episode_reward, episode_length, episode_start_index). If the episode is not finished, the returned episode_reward and episode_length are 0.

sample_index(batch_size: int) → numpy.ndarray[source]

Get a random sample of indices with size = batch_size.

Return all available indices in the buffer if batch_size is 0; return an empty numpy array if batch_size < 0 or no available index can be sampled.

get_weight(index: Union[int, numpy.ndarray]) → Union[float, numpy.ndarray][source]

Get the importance sampling weight.

The “weight” in the returned Batch is the weight on the loss function to de-bias the sampling process (some transition tuples are sampled more often, so their losses are weighted less).

update_weight(index: numpy.ndarray, new_weight: Union[numpy.ndarray, torch.Tensor]) → None[source]

Update priority weight by index in this buffer.

Parameters
  • index (np.ndarray) – the index whose weight you want to update.

  • new_weight (np.ndarray) – the new priority weight.

__getitem__(index: Union[slice, int, List[int], numpy.ndarray]) → tianshou.data.batch.Batch[source]

Return a data batch: self[index].

If stack_num is larger than 1, return the stacked obs and obs_next with shape (batch, len, …).

ReplayBufferManager

class tianshou.data.ReplayBufferManager(buffer_list: List[tianshou.data.buffer.base.ReplayBuffer])[source]

Bases: tianshou.data.buffer.base.ReplayBuffer

ReplayBufferManager contains a list of ReplayBuffer with exactly the same configuration.

These replay buffers have contiguous memory layout, and the storage space each buffer has is a shallow copy of the topmost memory.

Parameters

buffer_list – a list of ReplayBuffer objects to be managed.

See also

Please refer to ReplayBuffer for other APIs’ usage.

__len__() → int[source]

Return len(self).

reset(keep_statistics: bool = False) → None[source]

Clear all the data in replay buffer and episode statistics.

set_batch(batch: tianshou.data.batch.Batch) → None[source]

Manually choose the batch you want the ReplayBuffer to manage.

unfinished_index() → numpy.ndarray[source]

Return the indices of unfinished episodes.

prev(index: Union[int, numpy.ndarray]) → numpy.ndarray[source]

Return the index of previous transition.

The index won’t be modified if it is the beginning of an episode.

next(index: Union[int, numpy.ndarray]) → numpy.ndarray[source]

Return the index of next transition.

The index won’t be modified if it is the end of an episode.

update(buffer: tianshou.data.buffer.base.ReplayBuffer) → numpy.ndarray[source]

The ReplayBufferManager cannot be updated by any buffer.

add(batch: tianshou.data.batch.Batch, buffer_ids: Optional[Union[numpy.ndarray, List[int]]] = None) → Tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray, numpy.ndarray][source]

Add a batch of data into ReplayBufferManager.

Each of the data’s length (first dimension) must equal the length of buffer_ids. By default buffer_ids is [0, 1, …, buffer_num - 1].

Return (current_index, episode_reward, episode_length, episode_start_index). If the episode is not finished, the returned episode_reward and episode_length are 0.

sample_index(batch_size: int) → numpy.ndarray[source]

Get a random sample of indices with size = batch_size.

Return all available indices in the buffer if batch_size is 0; return an empty numpy array if batch_size < 0 or no available index can be sampled.

PrioritizedReplayBufferManager

class tianshou.data.PrioritizedReplayBufferManager(buffer_list: Sequence[tianshou.data.buffer.prio.PrioritizedReplayBuffer])[source]

Bases: tianshou.data.buffer.prio.PrioritizedReplayBuffer, tianshou.data.buffer.manager.ReplayBufferManager

PrioritizedReplayBufferManager contains a list of PrioritizedReplayBuffer with exactly the same configuration.

These replay buffers have contiguous memory layout, and the storage space each buffer has is a shallow copy of the topmost memory.

Parameters

buffer_list – a list of PrioritizedReplayBuffer objects to be managed.

See also

Please refer to ReplayBuffer for other APIs’ usage.

VectorReplayBuffer

class tianshou.data.VectorReplayBuffer(total_size: int, buffer_num: int, **kwargs: Any)[source]

Bases: tianshou.data.buffer.manager.ReplayBufferManager

VectorReplayBuffer contains n ReplayBuffer instances of the same size.

It is used for storing transitions from different environments while keeping the temporal order.

Parameters
  • total_size (int) – the total size of VectorReplayBuffer.

  • buffer_num (int) – the number of ReplayBuffer instances it uses, all under the same configuration.

Other input arguments (stack_num/ignore_obs_next/save_only_last_obs/sample_avail) are the same as ReplayBuffer.

See also

Please refer to ReplayBuffer for other APIs’ usage.
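
A rough sketch of adding one transition per sub-buffer (the values are made up; in real use they come from a vectorized env step, and each entry’s first dimension equals len(buffer_ids)):

>>> import numpy as np
>>> from tianshou.data import Batch, VectorReplayBuffer
>>> buf = VectorReplayBuffer(total_size=1000, buffer_num=4)
>>> batch = Batch(obs=np.zeros((4, 8)), act=np.zeros(4), rew=np.zeros(4),
...               done=np.array([False, False, True, False]), obs_next=np.zeros((4, 8)))
>>> ptr, ep_rew, ep_len, ep_idx = buf.add(batch, buffer_ids=[0, 1, 2, 3])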

PrioritizedVectorReplayBuffer

class tianshou.data.PrioritizedVectorReplayBuffer(total_size: int, buffer_num: int, **kwargs: Any)[source]

Bases: tianshou.data.buffer.manager.PrioritizedReplayBufferManager

PrioritizedVectorReplayBuffer contains n PrioritizedReplayBuffer instances of the same size.

It is used for storing transitions from different environments while keeping the temporal order.

Parameters
  • total_size (int) – the total size of PrioritizedVectorReplayBuffer.

  • buffer_num (int) – the number of PrioritizedReplayBuffer instances it uses, all under the same configuration.

Other input arguments (alpha/beta/stack_num/ignore_obs_next/save_only_last_obs/sample_avail) are the same as PrioritizedReplayBuffer.

See also

Please refer to ReplayBuffer for other APIs’ usage.

CachedReplayBuffer

class tianshou.data.CachedReplayBuffer(main_buffer: tianshou.data.buffer.base.ReplayBuffer, cached_buffer_num: int, max_episode_length: int)[source]

Bases: tianshou.data.buffer.manager.ReplayBufferManager

CachedReplayBuffer contains a given main buffer and n cached buffers, cached_buffer_num * ReplayBuffer(size=max_episode_length).

The memory layout is: | main_buffer | cached_buffers[0] | cached_buffers[1] | ... | cached_buffers[cached_buffer_num - 1] |.

The data is first stored in cached buffers. When an episode is terminated, the data will move to the main buffer and the corresponding cached buffer will be reset.

Parameters
  • main_buffer (ReplayBuffer) – the main buffer whose .update() function behaves normally.

  • cached_buffer_num (int) – the number of cached buffers (ReplayBuffer instances) to be created.

  • max_episode_length (int) – the maximum length of one episode, used as each cached buffer’s maxsize.

See also

Please refer to ReplayBuffer for other APIs’ usage.
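
For instance, a construction sketch (the sizes are illustrative):

>>> from tianshou.data import ReplayBuffer, CachedReplayBuffer
>>> main_buf = ReplayBuffer(size=10000)
>>> # main buffer followed by 4 cached buffers of up to 200 transitions each
>>> buf = CachedReplayBuffer(main_buffer=main_buf, cached_buffer_num=4,
...                          max_episode_length=200)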

add(batch: tianshou.data.batch.Batch, buffer_ids: Optional[Union[numpy.ndarray, List[int]]] = None) → Tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray, numpy.ndarray][source]

Add a batch of data into CachedReplayBuffer.

Each of the data’s length (first dimension) must equal the length of buffer_ids. By default buffer_ids is [0, 1, …, cached_buffer_num - 1].

Return (current_index, episode_reward, episode_length, episode_start_index), each of shape (len(buffer_ids), …), where (current_index[i], episode_reward[i], episode_length[i], episode_start_index[i]) refers to the episode result of the cached_buffer_ids[i]-th cached buffer.

Collector

Collector

class tianshou.data.Collector(policy: tianshou.policy.base.BasePolicy, env: Union[gym.core.Env, tianshou.env.venvs.BaseVectorEnv], buffer: Optional[tianshou.data.buffer.base.ReplayBuffer] = None, preprocess_fn: Optional[Callable[[], tianshou.data.batch.Batch]] = None, exploration_noise: bool = False)[source]

Bases: object

Collector enables the policy to interact with different types of envs with an exact number of steps or episodes.

Parameters
  • policy – an instance of the BasePolicy class.

  • env – a gym.Env environment or an instance of the BaseVectorEnv class.

  • buffer – an instance of the ReplayBuffer class. If set to None, it will not store the data. Default to None.

  • preprocess_fn (function) – a function called before the data is added to the buffer, see issue #42 and Handle Batched Data Stream in Collector. Default to None.

  • exploration_noise (bool) – determine whether the action needs to be modified with the corresponding policy’s exploration noise. If so, “policy.exploration_noise(act, batch)” will be called automatically to add the exploration noise into the action. Default to False.

The “preprocess_fn” is a function called before the data is added to the buffer, in batch format. It receives only “obs” when the collector resets the environment, and receives the four keys “obs_next”, “rew”, “done”, “info” on a normal env step. It returns either a dict or a Batch with the modified keys and values. Examples are in “test/base/test_collector.py”.
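
A hypothetical preprocess_fn sketch that rescales rewards before they are stored (the keyword-argument style follows the description above; the rescaling itself is made up):

>>> from tianshou.data import Batch, Collector
>>> def preprocess_fn(**kwargs):
...     if "rew" not in kwargs:      # env reset: only "obs" is passed, nothing to change
...         return Batch()
...     return Batch(rew=kwargs["rew"] / 10.0)   # normal step: rescale reward before buffering
>>> # policy, env and buffer are assumed to be constructed elsewhere
>>> collector = Collector(policy, env, buffer, preprocess_fn=preprocess_fn)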

Note

Please make sure the given environment has a time limit if using the n_episode collect option.

reset() → None[source]

Reset all related variables in the collector.

reset_stat() → None[source]

Reset the statistic variables.

reset_buffer(keep_statistics: bool = False) → None[source]

Reset the data buffer.

reset_env() → None[source]

Reset all of the environments.

collect(n_step: Optional[int] = None, n_episode: Optional[int] = None, random: bool = False, render: Optional[float] = None, no_grad: bool = True) → Dict[str, Any][source]

Collect a specified number of steps or episodes.

To ensure an unbiased sampling result with the n_episode option, this function first collects n_episode - env_num episodes, then the last env_num episodes are collected evenly from each env.

Parameters
  • n_step (int) – how many steps you want to collect.

  • n_episode (int) – how many episodes you want to collect.

  • random (bool) – whether to use random policy for collecting data. Default to False.

  • render (float) – the sleep time between rendering consecutive frames. Default to None (no rendering).

  • no_grad (bool) – whether to retain gradient in policy.forward(). Default to True (no gradient retaining).

Note

One and only one collection number specification is permitted, either n_step or n_episode.

Returns

A dict including the following keys

  • n/ep – collected number of episodes.

  • n/st – collected number of steps.

  • rews – array of episode reward over collected episodes.

  • lens – array of episode length over collected episodes.

  • idxs – array of episode start index in buffer over collected episodes.
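
Putting it together, a rough usage sketch (policy construction is omitted; the environment and sizes are placeholders):

>>> import gym
>>> from tianshou.env import DummyVectorEnv
>>> from tianshou.data import Collector, VectorReplayBuffer
>>> train_envs = DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(4)])
>>> buffer = VectorReplayBuffer(total_size=20000, buffer_num=4)
>>> collector = Collector(policy, train_envs, buffer, exploration_noise=True)  # policy defined elsewhere
>>> result = collector.collect(n_step=256)   # pass exactly one of n_step / n_episode
>>> steps, episodes = result['n/st'], result['n/ep']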

AsyncCollector

class tianshou.data.AsyncCollector(policy: tianshou.policy.base.BasePolicy, env: tianshou.env.venvs.BaseVectorEnv, buffer: Optional[tianshou.data.buffer.base.ReplayBuffer] = None, preprocess_fn: Optional[Callable[[], tianshou.data.batch.Batch]] = None, exploration_noise: bool = False)[source]

Bases: tianshou.data.collector.Collector

Async Collector handles async vector environments.

The arguments are exactly the same as Collector; please refer to Collector for a more detailed explanation.

reset_env() → None[source]

Reset all of the environments.

collect(n_step: Optional[int] = None, n_episode: Optional[int] = None, random: bool = False, render: Optional[float] = None, no_grad: bool = True) → Dict[str, Any][source]

Collect a specified number of steps or episodes with the async env setting.

This function does not collect exactly n_step or n_episode transitions. Instead, to support the async setting, it may collect more than the given n_step or n_episode transitions and save them into the buffer.

Parameters
  • n_step (int) – how many steps you want to collect.

  • n_episode (int) – how many episodes you want to collect.

  • random (bool) – whether to use random policy for collecting data. Default to False.

  • render (float) – the sleep time between rendering consecutive frames. Default to None (no rendering).

  • no_grad (bool) – whether to retain gradient in policy.forward(). Default to True (no gradient retaining).

Note

One and only one collection number specification is permitted, either n_step or n_episode.

Returns

A dict including the following keys

  • n/ep – collected number of episodes.

  • n/st – collected number of steps.

  • rews – array of episode reward over collected episodes.

  • lens – array of episode length over collected episodes.

  • idxs – array of episode start index in buffer over collected episodes.