%%capture

import pickle

import numpy as np

from tianshou.data import Batch, ReplayBuffer

Buffer#

The replay buffer is a very common module in DRL implementations. In Tianshou, you can consider the Buffer module a specialized form of Batch, which tracks all data trajectories and provides utilities such as sampling, in addition to the basic storage.

There are many kinds of Buffer modules in Tianshou; the two most basic ones are ReplayBuffer and VectorReplayBuffer. The latter is specially designed for parallelized environments (introduced in the Vectorized Environment tutorial). In this tutorial, we will focus on ReplayBuffer.

Usages#

Basic usages as a batch#

Usually a buffer stores all the data in a Batch, in circular-queue style.

# a buffer is initialised with its maxsize set to 10 (older data will be discarded when more data flows in).
print("========================================")
dummy_buf = ReplayBuffer(size=10)
print(dummy_buf)
print(f"maxsize: {dummy_buf.maxsize}, data length: {len(dummy_buf)}")

# add 3 steps of data into ReplayBuffer sequentially
print("========================================")
for i in range(3):
    dummy_buf.add(
        Batch(obs=i, act=i, rew=i, terminated=0, truncated=0, done=0, obs_next=i + 1, info={}),
    )
print(dummy_buf)
print(f"maxsize: {dummy_buf.maxsize}, data length: {len(dummy_buf)}")

# add another 10 steps of data into ReplayBuffer sequentially
print("========================================")
for i in range(3, 13):
    dummy_buf.add(
        Batch(obs=i, act=i, rew=i, terminated=0, truncated=0, done=0, obs_next=i + 1, info={}),
    )
print(dummy_buf)
print(f"maxsize: {dummy_buf.maxsize}, data length: {len(dummy_buf)}")
========================================
ReplayBuffer()
maxsize: 10, data length: 0
========================================
ReplayBuffer(
    obs: array([0, 1, 2, 0, 0, 0, 0, 0, 0, 0]),
    act: array([0, 1, 2, 0, 0, 0, 0, 0, 0, 0]),
    rew: array([0., 1., 2., 0., 0., 0., 0., 0., 0., 0.]),
    terminated: array([False, False, False, False, False, False, False, False, False,
                       False]),
    truncated: array([False, False, False, False, False, False, False, False, False,
                      False]),
    done: array([False, False, False, False, False, False, False, False, False,
                 False]),
    obs_next: array([1, 2, 3, 0, 0, 0, 0, 0, 0, 0]),
    info: Batch(),
)
maxsize: 10, data length: 3
========================================
ReplayBuffer(
    obs: array([10, 11, 12,  3,  4,  5,  6,  7,  8,  9]),
    act: array([10, 11, 12,  3,  4,  5,  6,  7,  8,  9]),
    rew: array([10., 11., 12.,  3.,  4.,  5.,  6.,  7.,  8.,  9.]),
    terminated: array([False, False, False, False, False, False, False, False, False,
                       False]),
    truncated: array([False, False, False, False, False, False, False, False, False,
                      False]),
    done: array([False, False, False, False, False, False, False, False, False,
                 False]),
    obs_next: array([11, 12, 13,  4,  5,  6,  7,  8,  9, 10]),
    info: Batch(),
)
maxsize: 10, data length: 10

Just like Batch, ReplayBuffer supports concatenation, splitting, advanced slicing and indexing, etc.

print(dummy_buf[-1])
print(dummy_buf[-3:])
# Try more methods you find useful in Batch yourself.
Batch(
    obs: array(9),
    act: array(9),
    rew: array(9.),
    terminated: array(False),
    truncated: array(False),
    done: array(False),
    obs_next: array(10),
    info: Batch(),
    policy: Batch(),
)
Batch(
    obs: array([7, 8, 9]),
    act: array([7, 8, 9]),
    rew: array([7., 8., 9.]),
    terminated: array([False, False, False]),
    truncated: array([False, False, False]),
    done: array([False, False, False]),
    obs_next: array([ 8,  9, 10]),
    info: Batch(),
    policy: Batch(),
)

A ReplayBuffer can also be saved to local disk while still keeping track of the trajectories. This is extremely helpful in offline DRL settings.

_dummy_buf = pickle.loads(pickle.dumps(dummy_buf))
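Beyond in-memory pickling, the same mechanism lets you write a buffer to disk and restore it later. A minimal sketch using the standard library (the file name "dummy_buf.pkl" is just an illustration):

with open("dummy_buf.pkl", "wb") as f:
    pickle.dump(dummy_buf, f)  # persist the buffer, trajectories included

with open("dummy_buf.pkl", "rb") as f:
    restored_buf = pickle.load(f)  # restore it later, e.g. for offline RL

assert len(restored_buf) == len(dummy_buf)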

Understanding reserved keys for buffer#

As explained above, ReplayBuffer is specially designed to support the implementation of DRL algorithms. So, for convenience, nine keys are reserved in Batch:

  • obs

  • act

  • rew

  • terminated

  • truncated

  • done

  • obs_next

  • info

  • policy

The meaning of these nine reserved keys is consistent with their meaning in Gymnasium. We recommend that you stick to these nine keys when adding batched data into a ReplayBuffer, because some of them are tracked by the buffer (e.g. the "done" value is tracked to help determine a trajectory's start and end indices, together with its total reward and episode length).

buf.add(Batch(......, extra_info=0))  # This is okay but not recommended.
buf.add(Batch(......, info={"extra_info": 0}))  # Recommended.
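As a runnable sketch of the recommended pattern (the key name "velocity" and the buffer name info_buf are made up for this example), extra per-step information goes inside the reserved info key:

info_buf = ReplayBuffer(size=10)
info_buf.add(
    Batch(
        obs=0,
        act=0,
        rew=0.0,
        terminated=0,
        truncated=0,
        done=0,
        obs_next=1,
        info={"velocity": 0.5},  # custom data lives under the reserved "info" key
    ),
)
print(info_buf.info)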

Data sampling#

We keep a replay buffer in DRL for one purpose: to sample data from it for training. ReplayBuffer.sample() and ReplayBuffer.split(..., shuffle=True) can both fulfill this need.

dummy_buf.sample(batch_size=5)
(Batch(
     obs: array([3, 7, 7, 9, 8]),
     act: array([3, 7, 7, 9, 8]),
     rew: array([3., 7., 7., 9., 8.]),
     terminated: array([False, False, False, False, False]),
     truncated: array([False, False, False, False, False]),
     done: array([False, False, False, False, False]),
     obs_next: array([ 4,  8,  8, 10,  9]),
     info: Batch(),
     policy: Batch(),
 ),
 array([3, 7, 7, 9, 8]))
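sample() returns both the sampled batch and the indexes of those samples in the buffer. A common pattern (the variable names here are only illustrative) is to unpack the two, since the indexes are often needed later, e.g. for n-step lookups or priority updates:

sampled_batch, sampled_indices = dummy_buf.sample(batch_size=5)
print(sampled_batch.obs)
print(sampled_indices)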

Trajectory tracking#

Compared to Batch, a unique feature of ReplayBuffer is that it can help you track the environment trajectories.

First, let us simulate a situation where we add three trajectories into the buffer. The last trajectory is not finished yet.

trajectory_buffer = ReplayBuffer(size=10)
# Add the first trajectory (length is 3) into ReplayBuffer
print("========================================")
for i in range(3):
    result = trajectory_buffer.add(
        Batch(
            obs=i,
            act=i,
            rew=i,
            terminated=1 if i == 2 else 0,
            truncated=0,
            done=i == 2,
            obs_next=i + 1,
            info={},
        ),
    )
    print(result)
print(trajectory_buffer)
print(f"maxsize: {trajectory_buffer.maxsize}, data length: {len(trajectory_buffer)}")

# Add the second trajectory (length is 5) into ReplayBuffer
print("========================================")
for i in range(3, 8):
    result = trajectory_buffer.add(
        Batch(
            obs=i,
            act=i,
            rew=i,
            terminated=1 if i == 7 else 0,
            truncated=0,
            done=i == 7,
            obs_next=i + 1,
            info={},
        ),
    )
    print(result)
print(trajectory_buffer)
print(f"maxsize: {trajectory_buffer.maxsize}, data length: {len(trajectory_buffer)}")

# Add the third trajectory (length is 5, still not finished) into ReplayBuffer
print("========================================")
for i in range(8, 13):
    result = trajectory_buffer.add(
        Batch(obs=i, act=i, rew=i, terminated=0, truncated=0, done=False, obs_next=i + 1, info={}),
    )
    print(result)
print(trajectory_buffer)
print(f"maxsize: {trajectory_buffer.maxsize}, data length: {len(trajectory_buffer)}")
========================================
(array([0]), array([0.]), array([0]), array([0]))
(array([1]), array([0.]), array([0]), array([0]))
(array([2]), array([3.]), array([3]), array([0]))
ReplayBuffer(
    obs: array([0, 1, 2, 0, 0, 0, 0, 0, 0, 0]),
    act: array([0, 1, 2, 0, 0, 0, 0, 0, 0, 0]),
    rew: array([0., 1., 2., 0., 0., 0., 0., 0., 0., 0.]),
    terminated: array([False, False,  True, False, False, False, False, False, False,
                       False]),
    truncated: array([False, False, False, False, False, False, False, False, False,
                      False]),
    done: array([False, False,  True, False, False, False, False, False, False,
                 False]),
    obs_next: array([1, 2, 3, 0, 0, 0, 0, 0, 0, 0]),
    info: Batch(),
)
maxsize: 10, data length: 3
========================================
(array([3]), array([0.]), array([0]), array([3]))
(array([4]), array([0.]), array([0]), array([3]))
(array([5]), array([0.]), array([0]), array([3]))
(array([6]), array([0.]), array([0]), array([3]))
(array([7]), array([25.]), array([5]), array([3]))
ReplayBuffer(
    obs: array([0, 1, 2, 3, 4, 5, 6, 7, 0, 0]),
    act: array([0, 1, 2, 3, 4, 5, 6, 7, 0, 0]),
    rew: array([0., 1., 2., 3., 4., 5., 6., 7., 0., 0.]),
    terminated: array([False, False,  True, False, False, False, False,  True, False,
                       False]),
    truncated: array([False, False, False, False, False, False, False, False, False,
                      False]),
    done: array([False, False,  True, False, False, False, False,  True, False,
                 False]),
    obs_next: array([1, 2, 3, 4, 5, 6, 7, 8, 0, 0]),
    info: Batch(),
)
maxsize: 10, data length: 8
========================================
(array([8]), array([0.]), array([0]), array([8]))
(array([9]), array([0.]), array([0]), array([8]))
(array([0]), array([0.]), array([0]), array([8]))
(array([1]), array([0.]), array([0]), array([8]))
(array([2]), array([0.]), array([0]), array([8]))
ReplayBuffer(
    obs: array([10, 11, 12,  3,  4,  5,  6,  7,  8,  9]),
    act: array([10, 11, 12,  3,  4,  5,  6,  7,  8,  9]),
    rew: array([10., 11., 12.,  3.,  4.,  5.,  6.,  7.,  8.,  9.]),
    terminated: array([False, False, False, False, False, False, False,  True, False,
                       False]),
    truncated: array([False, False, False, False, False, False, False, False, False,
                      False]),
    done: array([False, False, False, False, False, False, False,  True, False,
                 False]),
    obs_next: array([11, 12, 13,  4,  5,  6,  7,  8,  9, 10]),
    info: Batch(),
)
maxsize: 10, data length: 10

Episode length and rewards tracking#

Notice that ReplayBuffer.add() returns a tuple of four values every time: (current_index, episode_reward, episode_length, episode_start_index). episode_reward and episode_length are valid only when a trajectory is finished. This can save developers some trouble.
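As a small sketch of how this can be used (a fresh buffer demo_buf is created here so the trajectory_buffer above stays untouched), the episode statistics are only meaningful on the step that finishes the episode:

demo_buf = ReplayBuffer(size=10)
for i in range(3):
    idx, ep_rew, ep_len, ep_start = demo_buf.add(
        Batch(
            obs=i,
            act=i,
            rew=i,
            terminated=i == 2,
            truncated=0,
            done=i == 2,
            obs_next=i + 1,
            info={},
        ),
    )
    if ep_len[0] > 0:  # non-zero only when this step finished an episode
        print(f"episode finished: reward={ep_rew[0]}, length={ep_len[0]}, start index={ep_start[0]}")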

Episode index management#

In the ReplayBuffer above, we can access any data step by indexing.

print(trajectory_buffer)
print("========================================")

data = trajectory_buffer[6]
print(data)
ReplayBuffer(
    obs: array([10, 11, 12,  3,  4,  5,  6,  7,  8,  9]),
    act: array([10, 11, 12,  3,  4,  5,  6,  7,  8,  9]),
    rew: array([10., 11., 12.,  3.,  4.,  5.,  6.,  7.,  8.,  9.]),
    terminated: array([False, False, False, False, False, False, False,  True, False,
                       False]),
    truncated: array([False, False, False, False, False, False, False, False, False,
                      False]),
    done: array([False, False, False, False, False, False, False,  True, False,
                 False]),
    obs_next: array([11, 12, 13,  4,  5,  6,  7,  8,  9, 10]),
    info: Batch(),
)
========================================
Batch(
    obs: array(6),
    act: array(6),
    rew: array(6.),
    terminated: array(False),
    truncated: array(False),
    done: array(False),
    obs_next: array(7),
    info: Batch(),
    policy: Batch(),
)

We know that step “6” is not the start of an episode (it should be step “3”, since steps “3-7” form the second trajectory we added to the ReplayBuffer), but how do we get the earliest index of that episode?

This may seem easy, but it actually is not. We cannot simply look for the “done” flag preceding the start of a new episode, because the third-added trajectory is not finished yet, so step “3” is surrounded by “False” flags. There are many edge cases to consider. Things get even trickier with more advanced buffers like VectorReplayBuffer, which does not store the data in a simple circular queue.

Luckily, all ReplayBuffer variants help you identify step indexes through a unified API. One can simply pass in an array of indexes and look up the previous index of each within its episode.

# previous step of indexes [0, 1, 2, 3, 4, 5, 6] are:
print(trajectory_buffer.prev(np.array([0, 1, 2, 3, 4, 5, 6])))
[9 0 1 3 3 4 5]

Using ReplayBuffer.prev(), we know that the earliest step of that episode is step “3”. Similarly, ReplayBuffer.next() helps us identify the last index of an episode regardless of which kind of ReplayBuffer we are using.

# next step of indexes [4,5,6,7,8,9] are:
print(trajectory_buffer.next(np.array([4, 5, 6, 7, 8, 9])))
[5 6 7 7 9 0]

We can also search for indexes whose “done” flag is False but which are the last step of a (still unfinished) trajectory.

print(trajectory_buffer.unfinished_index())
[2]

The aforementioned APIs will be helpful when we calculate quantities like GAE and n-step returns in DRL algorithms (example usage in Tianshou). The unified APIs ensure a modular design and a flexible interface.
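As a small illustration of the idea, the sketch below walks backwards with prev() until the index stops changing, which recovers the start index of the episode containing step 6 (this is only a demonstration, not how Tianshou implements its return computations):

index = np.array([6])
while True:
    prev_index = trajectory_buffer.prev(index)
    if (prev_index == index).all():  # prev() maps an episode-start index to itself
        break
    index = prev_index
print(index)  # the earliest index of the episode containing step 6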

Further Reading#

Other Buffer Module#

  • PrioritizedReplayBuffer, which helps you implement prioritized experience replay (a construction sketch follows after this list)

  • CachedReplayBuffer, one main buffer with several cached buffers (higher sample efficiency in some scenarios)

  • ReplayBufferManager, a base class that can be inherited from (it may help you manage multiple buffers).

Check the documentation and the source code for more details.
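For instance, a PrioritizedReplayBuffer is constructed much like a plain ReplayBuffer, with the usual prioritized-experience-replay hyperparameters; a minimal sketch (the alpha and beta values below are illustrative, not recommendations):

from tianshou.data import PrioritizedReplayBuffer

prio_buf = PrioritizedReplayBuffer(size=10, alpha=0.6, beta=0.4)
prio_buf.add(
    Batch(obs=0, act=0, rew=0.0, terminated=0, truncated=0, done=0, obs_next=1, info={}),
)
batch, indices = prio_buf.sample(batch_size=1)  # sampling is weighted by priority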

Support for stacking steps to use RNNs in DRL#

There is an option called stack_num (defaulting to 1) when initializing a ReplayBuffer, which may help you use an RNN in your algorithm. Check the documentation for details.
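A minimal sketch of the idea, assuming stack_num=4 and a small vector observation (the exact stacking behavior is described in the ReplayBuffer documentation): when indexing the buffer, obs comes back with the most recent stack_num observations stacked together, which is convenient as RNN input.

stacked_buf = ReplayBuffer(size=10, stack_num=4)
for i in range(5):
    stacked_buf.add(
        Batch(
            obs=np.array([i]),
            act=i,
            rew=i,
            terminated=0,
            truncated=0,
            done=0,
            obs_next=np.array([i + 1]),
            info={},
        ),
    )
print(stacked_buf[4].obs)  # the observation at index 4 stacked with its predecessors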