prio


class PrioritizedReplayBuffer(size: int, alpha: float, beta: float, weight_norm: bool = True, **kwargs: Any)

Implementation of Prioritized Experience Replay. arXiv:1511.05952.

Parameters:
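  • alpha – the prioritization exponent. 0 gives uniform sampling; larger values sample high-priority transitions more often.

  • beta – the importance-sampling exponent. It controls how strongly the returned weights correct for non-uniform sampling; 1 gives full correction.

  • weight_norm – whether to normalize the returned weights by the maximum weight value within the batch. Defaults to True.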
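The effect of alpha can be illustrated with the sampling rule from the cited paper, P(i) = p_i^alpha / sum_k p_k^alpha. The following is a minimal pure-Python sketch of that formula, not this class's implementation; the function name sampling_probabilities is hypothetical.

```python
def sampling_probabilities(priorities, alpha):
    """Return P(i) = p_i**alpha / sum_k p_k**alpha for positive priorities."""
    scaled = [p ** alpha for p in priorities]
    total = sum(scaled)
    return [s / total for s in scaled]


priorities = [1.0, 2.0, 4.0]

# alpha = 0 ignores the priorities: every transition is equally likely.
uniform = sampling_probabilities(priorities, alpha=0.0)

# alpha = 1 samples in direct proportion to priority.
proportional = sampling_probabilities(priorities, alpha=1.0)
```

Values of alpha between 0 and 1 interpolate between these two extremes.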

See also

Please refer to ReplayBuffer for other APIs’ usage.

add(batch: RolloutBatchProtocol, buffer_ids: ndarray | list[int] | None = None) → tuple[ndarray, ndarray, ndarray, ndarray]

Add a batch of data into replay buffer.

Parameters:
  • batch – the input data batch. “obs”, “act”, “rew”, “terminated”, “truncated” are required keys.

  • buffer_ids – accepted for consistency with the add function of other buffers; if it is not None, the first dimension of the input batch is assumed to be 1.

Returns (current_index, episode_reward, episode_length, episode_start_index). If the episode is not finished, episode_length and episode_reward are 0.

get_weight(index: int | ndarray) → float | ndarray

Get the importance sampling weight.

The “weight” in the returned Batch is applied to the loss function to debias the sampling process: transition tuples that are sampled more often have their losses weighted less.
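As an illustration of this debiasing, the cited paper defines the weight as w_i = (N * P(i))^(-beta), where N is the number of stored transitions and P(i) is the sampling probability. The sketch below follows that formula and assumes the three listed transitions make up the whole buffer; importance_weights is a hypothetical helper, and the arithmetic this class uses internally may differ.

```python
def importance_weights(probs, beta, weight_norm=True):
    """w_i = (N * P(i)) ** (-beta), optionally divided by the batch maximum."""
    n = len(probs)  # assumption: the batch is the entire buffer
    weights = [(n * p) ** (-beta) for p in probs]
    if weight_norm:
        # Mirrors the weight_norm option: scale so the largest weight is 1.
        top = max(weights)
        weights = [w / top for w in weights]
    return weights


probs = [0.1, 0.2, 0.7]

# The most frequently sampled transition (P = 0.7) gets the smallest weight.
weights = importance_weights(probs, beta=1.0)

# beta = 0 switches the correction off: all weights are equal.
no_correction = importance_weights(probs, beta=0.0)
```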

init_weight(index: int | ndarray) → None

sample_indices(batch_size: int | None) → ndarray

Get a random sample of indices of size batch_size.

Return all available indices in the buffer if batch_size is 0; return an empty numpy array if batch_size < 0 or no available index can be sampled.

Parameters:
  • batch_size – the number of indices to be sampled. If None, it will be set to the length of the buffer (i.e. return all available indices in a random order).
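For intuition about priority-proportional sampling in general, the following sketch draws indices by searching a list of cumulative priority sums. It is an assumption-laden illustration: sample_proportional is a hypothetical function, and a plain prefix-sum list stands in for whatever data structure this buffer actually uses.

```python
import bisect
import itertools
import random


def sample_proportional(priorities, batch_size, rng):
    """Draw batch_size indices with probability proportional to priority."""
    prefix = list(itertools.accumulate(priorities))  # cumulative sums
    total = prefix[-1]
    indices = []
    for _ in range(batch_size):
        target = rng.random() * total  # uniform point in [0, total)
        # First index whose cumulative sum exceeds the target; the min()
        # guards against floating-point rounding at the upper edge.
        idx = min(bisect.bisect_right(prefix, target), len(priorities) - 1)
        indices.append(idx)
    return indices


# Index 0 has zero priority and is never drawn; index 2 is drawn most often.
indices = sample_proportional([0.0, 1.0, 3.0], batch_size=4, rng=random.Random(0))
```

Each draw here costs a binary search over the prefix sums, but rebuilding the sums after every priority change is linear in the buffer size, which is why tree-based structures are a common choice for large buffers.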

set_beta(beta: float) → None

update(buffer: ReplayBuffer) → ndarray

Move the data from the given buffer to the current buffer.

Return the updated indices. If the update fails, return an empty array.

update_weight(index: ndarray, new_weight: ndarray | Tensor) → None

Update priority weight by index in this buffer.

Parameters:
  • index (np.ndarray) – the indices whose priority weights should be updated.

  • new_weight (np.ndarray) – the new priority weights to assign to those indices.
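A common pattern in prioritized replay is to pass the TD errors of the sampled transitions as the new weights. The sketch below shows one plausible version of that update on a hypothetical PriorityTable class; taking the absolute value, adding a small eps, and raising the result to alpha are conventions from the cited paper and are assumptions here, not a description of this method's internals.

```python
class PriorityTable:
    """Hypothetical stand-in for a buffer's per-index priority storage."""

    def __init__(self, size, alpha, eps=1e-6):
        self.alpha = alpha
        self.eps = eps  # assumed constant that keeps priorities above zero
        self.weight = [0.0] * size
        self.max_prio = 1.0

    def update_weight(self, index, new_weight):
        for i, w in zip(index, new_weight):
            prio = abs(w) + self.eps  # sign of the TD error is irrelevant
            self.weight[i] = prio ** self.alpha
            self.max_prio = max(self.max_prio, prio)


table = PriorityTable(size=4, alpha=0.6)
td_errors = [0.5, -2.0]

# Only indices 1 and 3 change; the larger error yields the larger priority.
table.update_weight([1, 3], td_errors)
```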