base


class ReplayBuffer(size: int, stack_num: int = 1, ignore_obs_next: bool = False, save_only_last_obs: bool = False, sample_avail: bool = False, **kwargs: Any)

ReplayBuffer stores data generated from the interaction between the policy and the environment.

ReplayBuffer can be considered a specialized form (or manager) of Batch. It stores all data in a single Batch organized as a circular queue: once the buffer is full, each new transition overwrites the oldest one.
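The circular-queue policy can be sketched in plain Python as follows. CircularStore is a hypothetical class written only for illustration; it is not part of Tianshou, and it merely shows how a write pointer and a fill count produce the overwrite-oldest behavior.

```python
class CircularStore:
    """Hypothetical fixed-size store that overwrites its oldest entry when full."""

    def __init__(self, size: int) -> None:
        self.maxsize = size
        self.data = [None] * size
        self._index = 0  # slot that the next add() writes to
        self._size = 0   # number of slots currently holding data

    def add(self, item) -> int:
        """Store item and return the slot it was written to."""
        ptr = self._index
        self.data[ptr] = item
        # Advance the write pointer, wrapping back to slot 0 at the end.
        self._index = (self._index + 1) % self.maxsize
        self._size = min(self._size + 1, self.maxsize)
        return ptr

    def __len__(self) -> int:
        return self._size


store = CircularStore(size=3)
positions = [store.add(x) for x in ["a", "b", "c", "d"]]
print(positions)   # [0, 1, 2, 0]
print(store.data)  # ['d', 'b', 'c'], "a" was overwritten
print(len(store))  # 3
```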

For example usage of ReplayBuffer, see the Buffer section of Basic concepts in Tianshou.

Parameters:
  • size – the maximum size of the replay buffer.

  • stack_num – the frame-stack sampling argument; must be greater than or equal to 1. Defaults to 1 (no stacking).

  • ignore_obs_next – whether to skip storing obs_next. Defaults to False.

  • save_only_last_obs – when obs/obs_next has a shape of (timestep, …) because of temporal stacking, store only its last frame. Defaults to False.

  • sample_avail – whether to sample only available indices when using the frame-stack sampling method. Defaults to False.
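To illustrate why save_only_last_obs can be lossless, the sketch below assumes a frame-stacking wrapper that emits, at step t, the k most recent raw frames, repeating the first frame at the start of an episode. Under that assumption consecutive stacked observations overlap, so keeping only the last frame of each one preserves everything needed to rebuild the stacks later, at a quarter of the storage when k = 4. stack_frames is a hypothetical helper, and the sketch covers a single episode only.

```python
def stack_frames(frames: list, t: int, k: int) -> list:
    """Hypothetical frame-stacking wrapper output at step t: the k most recent
    frames, oldest first, repeating frame 0 before the episode start."""
    return [frames[max(t - j, 0)] for j in range(k - 1, -1, -1)]


frames = ["f0", "f1", "f2", "f3", "f4"]  # raw frames of one episode
k = 4
stacked_obs = [stack_frames(frames, t, k) for t in range(len(frames))]

# What save_only_last_obs would keep: the newest frame of each stacked obs.
stored = [obs[-1] for obs in stacked_obs]
print(stored)  # ['f0', 'f1', 'f2', 'f3', 'f4']

# Storage cost: 20 frames when stacked, 5 when only the last frame is kept.
print(sum(len(obs) for obs in stacked_obs), len(stored))  # 20 5

# Restacking the stored frames reproduces every original stacked observation.
print(all(stack_frames(stored, t, k) == stacked_obs[t] for t in range(len(frames))))  # True
```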

add(batch: RolloutBatchProtocol, buffer_ids: ndarray | list[int] | None = None) → tuple[ndarray, ndarray, ndarray, ndarray]

Add a batch of data into the replay buffer.

Parameters:
  • batch – the input data batch. The keys “obs”, “act”, “rew”, “terminated”, and “truncated” are required.

  • buffer_ids – kept for consistency with the add function of other buffers; if it is not None, the input batch’s first dimension is assumed to be 1.

Returns the tuple (current_index, episode_reward, episode_length, episode_start_index). If the episode is not finished, episode_reward and episode_length are 0.
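The episode bookkeeping behind add()'s return value can be sketched as below. EpisodeTracker is hypothetical; it handles one transition at a time, ignores circular wrap-around, and assumes that episode_start_index is reported even for unfinished episodes, which the description above does not specify.

```python
from typing import Tuple


class EpisodeTracker:
    """Hypothetical per-episode bookkeeping, one transition per add() call."""

    def __init__(self) -> None:
        self.index = 0      # slot of the next transition (no wrap-around here)
        self.ep_rew = 0.0   # reward accumulated in the current episode
        self.ep_len = 0     # steps taken in the current episode
        self.ep_start = 0   # slot where the current episode began

    def add(self, rew: float, terminated: bool, truncated: bool) -> Tuple[int, float, int, int]:
        ptr = self.index
        self.index += 1
        if self.ep_len == 0:
            self.ep_start = ptr  # first transition of a new episode
        self.ep_rew += rew
        self.ep_len += 1
        if terminated or truncated:
            result = (ptr, self.ep_rew, self.ep_len, self.ep_start)
            self.ep_rew, self.ep_len = 0.0, 0  # start fresh for the next episode
            return result
        # Unfinished episode: reward and length are reported as 0.
        return (ptr, 0.0, 0, self.ep_start)


tracker = EpisodeTracker()
steps = [(1.0, False, False), (2.0, False, False), (0.5, True, False), (1.0, False, True)]
results = [tracker.add(*step) for step in steps]
print(results)  # [(0, 0.0, 0, 0), (1, 0.0, 0, 0), (2, 3.5, 3, 0), (3, 1.0, 1, 3)]
```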

classmethod from_data(obs: Dataset, act: Dataset, rew: Dataset, terminated: Dataset, truncated: Dataset, done: Dataset, obs_next: Dataset) → Self

get(index: int | list[int] | ndarray, key: str, default_value: Any = None, stack_num: int | None = None) → Batch | ndarray

Return the stacked result for key at the given index.

For example, with key = "obs", stack_num = 4 and index = t, it returns [obs[t-3], obs[t-2], obs[t-1], obs[t]].

Parameters:
  • index – the index (or indices) at which to get stacked data.

  • key (str) – the key to get; should be one of the reserved_keys.

  • default_value – the value to return if no data is found for the given key and default_value is set.

  • stack_num – the number of frames to stack. Defaults to self.stack_num.
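A minimal sketch of frame-stacked retrieval over a flat list, assuming (as the description of prev() suggests) that stepping backwards stops at the first transition of an episode, so frames near an episode start are repeated. prev_index and stacked_get are hypothetical helpers, and the sketch ignores circular wrap-around.

```python
from typing import List


def prev_index(i: int, done: List[bool]) -> int:
    # Stay put at the first transition of an episode: slot 0, or a slot whose
    # predecessor ended an episode.
    if i == 0 or done[i - 1]:
        return i
    return i - 1


def stacked_get(data: list, done: List[bool], index: int, stack_num: int) -> list:
    """Hypothetical frame-stacked lookup, oldest frame first."""
    frames = []
    i = index
    for _ in range(stack_num):
        frames.append(data[i])
        i = prev_index(i, done)
    return frames[::-1]


obs = [10, 11, 12, 20, 21]                # two episodes: 10..12 and 20..21
done = [False, False, True, False, True]
print(stacked_get(obs, done, 2, 4))  # [10, 10, 11, 12], first frame repeated
print(stacked_get(obs, done, 4, 3))  # [20, 20, 21], stops at the episode start
print(stacked_get(obs, done, 4, 1))  # [21], no stacking
```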

classmethod load_hdf5(path: str, device: str | None = None) → Self

Load a replay buffer from an HDF5 file.

next(index: int | ndarray) → ndarray

Return the index of the next transition.

The index is returned unchanged if it is at the end of an episode.

prev(index: int | ndarray) → ndarray

Return the index of the previous transition.

The index is returned unchanged if it is at the beginning of an episode.
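The next() and prev() rules can be sketched over a circular buffer as follows. Besides the done flags, the sketch assumes that the slot written most recently (last_index) marks the seam between the newest and the oldest data, so neither function steps across it. next_index and prev_index are hypothetical names that operate on single integers rather than arrays.

```python
from typing import List


def next_index(index: int, done: List[bool], size: int, last_index: int) -> int:
    # Stay put at the end of an episode or at the newest transition.
    if done[index] or index == last_index:
        return index
    return (index + 1) % size


def prev_index(index: int, done: List[bool], size: int, last_index: int) -> int:
    candidate = (index - 1) % size
    # If the preceding slot ended an episode or holds the newest transition,
    # `index` is the first step available for its episode, so stay put.
    if done[candidate] or candidate == last_index:
        return index
    return candidate


# A full buffer of size 5 whose newest write landed in slot 1, so slot 2 holds
# the oldest data; slots 2, 3, 4, 0, 1 hold steps of one ongoing episode.
done = [False, False, False, False, False]
print([next_index(i, done, 5, last_index=1) for i in range(5)])  # [1, 1, 3, 4, 0]
print([prev_index(i, done, 5, last_index=1) for i in range(5)])  # [4, 0, 2, 2, 3]

# An episode boundary after slot 3: neither function crosses it.
done2 = [False, False, False, True, False]
print(next_index(3, done2, 5, last_index=1), prev_index(4, done2, 5, last_index=1))  # 3 4
```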

reset(keep_statistics: bool = False) → None

Clear all data in the replay buffer and, unless keep_statistics is True, the episode statistics.

sample(batch_size: int | None) → tuple[RolloutBatchProtocol, ndarray]

Get a random sample of size batch_size from the buffer.

Return all data in the buffer if batch_size is 0.

Returns:

The sampled data and the corresponding indices inside the buffer.

sample_indices(batch_size: int | None) → ndarray

Get a random sample of batch_size indices.

Return all available indices in the buffer if batch_size is 0; return an empty numpy array if batch_size < 0 or no available index can be sampled.

Parameters:

  • batch_size – the number of indices to sample. If None, it is set to the length of the buffer (i.e., all available indices are returned in random order).
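The batch_size conventions of sample_indices() can be sketched with the standard random module. This is a hypothetical stand-in: it treats every stored index as available (ignoring frame-stack availability), returns indices in storage order for batch_size == 0, and samples with replacement for positive sizes; all three choices are assumptions of the sketch.

```python
import random
from typing import List, Optional


def sample_indices(size: int, batch_size: Optional[int], rng: random.Random) -> List[int]:
    """Hypothetical stand-in for the batch_size rules described above."""
    if size == 0 or (batch_size is not None and batch_size < 0):
        return []  # nothing stored, or a negative request
    if batch_size is None:
        indices = list(range(size))
        rng.shuffle(indices)  # all indices, in random order
        return indices
    if batch_size == 0:
        return list(range(size))  # all indices, in storage order
    # Positive batch_size: draw uniformly with replacement.
    return [rng.randrange(size) for _ in range(batch_size)]


rng = random.Random(0)
print(sample_indices(4, 0, rng))             # [0, 1, 2, 3]
print(sample_indices(4, -1, rng))            # []
print(sorted(sample_indices(4, None, rng)))  # [0, 1, 2, 3]
print(len(sample_indices(4, 10, rng)))       # 10
```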

save_hdf5(path: str, compression: str | None = None) → None

Save the replay buffer to an HDF5 file.

set_batch(batch: RolloutBatchProtocol) → None

Manually choose the batch you want the ReplayBuffer to manage.

unfinished_index() → ndarray

Return the index of the unfinished episode.
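A sketch of how the unfinished episode can be located, assuming it can only be the one containing the most recently added transition: if that transition is not marked done, its index is returned; otherwise the result is empty. The function name and its arguments are hypothetical.

```python
from typing import List


def unfinished_index(done: List[bool], size: int, index: int) -> List[int]:
    """Hypothetical: index of the newest transition if its episode is still open.

    size is the number of stored transitions and index the next write slot,
    so the newest transition sits at (index - 1) % size.
    """
    if size == 0:
        return []
    last = (index - 1) % size
    return [] if done[last] else [last]


print(unfinished_index([False, True, False], size=3, index=3))  # [2]
print(unfinished_index([False, True, True], size=3, index=3))   # []
print(unfinished_index([], size=0, index=0))                    # []
```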

update(buffer: ReplayBuffer) → ndarray

Move the data from the given buffer into the current buffer.

Return the updated indices, or an empty array if the update fails.
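Moving data between buffers amounts to reading the source in chronological order and writing each item into the destination's circular queue. The sketch below assumes plain Python lists as storage, with the oldest item of a full source sitting at its write pointer; chronological and update_from are hypothetical helpers, and an empty source simply yields an empty list of written indices.

```python
from typing import List, Tuple


def chronological(data: list, size: int, index: int) -> list:
    """Items of a hypothetical circular store, oldest first.

    data has one slot per unit of capacity, size counts filled slots and index
    is the next write slot. Until the store is full, items fill slots
    0..size-1; once full, the oldest item sits at the write slot.
    """
    if size < len(data):
        return data[:size]
    return data[index:] + data[:index]


def update_from(dst: list, dst_size: int, dst_index: int, items: list) -> Tuple[List[int], int, int]:
    """Write items into dst's circular queue; return (written slots, new size, new index)."""
    written = []
    for item in items:
        dst[dst_index] = item
        written.append(dst_index)
        dst_index = (dst_index + 1) % len(dst)
        dst_size = min(dst_size + 1, len(dst))
    return written, dst_size, dst_index


src = ["d", "b", "c"]  # full source of capacity 3 after adding a, b, c, d
items = chronological(src, size=3, index=1)
print(items)  # ['b', 'c', 'd']

dst = [None] * 4
written, dst_size, dst_index = update_from(dst, 0, 0, items)
print(written, dst_size, dst_index)  # [0, 1, 2] 3 3
print(dst)  # ['b', 'c', 'd', None]
```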