tianshou.data
-
class tianshou.data.Batch(**kwargs)
Bases: object
Tianshou provides Batch as the internal data structure for passing any kind of data to other methods; for example, a collector gives a Batch to the policy for learning. Here is the usage:

>>> import numpy as np
>>> from tianshou.data import Batch
>>> data = Batch(a=4, b=[5, 5], c='2312312')
>>> data.b
[5, 5]
>>> data.b = np.array([3, 4, 5])
>>> print(data)
Batch(
    a: 4,
    b: array([3, 4, 5]),
    c: '2312312',
)
In short, you can define a Batch with any key-value pairs. The current implementation of Tianshou typically uses 7 reserved keys in Batch:
obs – the observation of step t;
act – the action of step t;
rew – the reward of step t;
done – the done flag of step t;
obs_next – the observation of step t+1;
info – the info of step t (in gym.Env, the env.step() function returns 4 values, and the last one is info);
policy – the data computed by the policy in step t.
Batch has other methods, including __getitem__(), __len__(), append(), and split():

>>> data = Batch(obs=np.array([0, 11, 22]), rew=np.array([6, 6, 6]))
>>> # here we test __getitem__
>>> index = [2, 1]
>>> data[index].obs
array([22, 11])
>>> # here we test __len__
>>> len(data)
3
>>> data.append(data)  # similar to list.append
>>> data.obs
array([0, 11, 22, 0, 11, 22])
>>> # split whole data into multiple small batch
>>> for d in data.split(size=2, shuffle=False):
...     print(d.obs, d.rew)
[ 0 11] [6 6]
[22 0] [6 6]
[11 22] [6 6]
-
__getitem__(index: Union[str, slice]) → Union[tianshou.data.batch.Batch, dict]
Return self[index].
-
get(k: str, d: Optional[Any] = None) → Union[tianshou.data.batch.Batch, Any]
Return self[k] if k in self else d. d defaults to None.
-
split(size: Optional[int] = None, shuffle: bool = True) → Iterator[tianshou.data.batch.Batch]
Split the whole data batch into multiple smaller batches.
Parameters
size (int) – if it is None, the data batch is not split; otherwise the data batch is divided into batches of the given size. Defaults to None.
shuffle (bool) – randomly shuffle the entire data batch if it is True; otherwise keep the original order. Defaults to True.
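The semantics of size and shuffle can be pictured with a small, dependency-free sketch. The helper name split_indices is hypothetical and not part of Tianshou; it only mirrors the documented behaviour by yielding chunks of indices rather than Batch objects.

```python
import random
from typing import Iterator, List, Optional


def split_indices(length: int, size: Optional[int] = None,
                  shuffle: bool = True) -> Iterator[List[int]]:
    """Yield index chunks of at most `size` items (hypothetical helper)."""
    indices = list(range(length))
    if shuffle:
        # Shuffle once over the whole range, then cut into chunks.
        random.shuffle(indices)
    if size is None:
        # No size given: return everything as a single chunk.
        size = max(length, 1)
    for start in range(0, length, size):
        yield indices[start:start + size]


chunks = list(split_indices(6, size=2, shuffle=False))
print(chunks)
```

With length 6 and size=2 this produces three chunks in order, which corresponds to the three small batches printed in the Batch example.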
-
tianshou.data.to_numpy(x: Union[torch.Tensor, dict, tianshou.data.batch.Batch, numpy.ndarray]) → Union[dict, tianshou.data.batch.Batch, numpy.ndarray]
Return an object without torch.Tensor.
-
tianshou.data.to_torch(x: Union[torch.Tensor, dict, tianshou.data.batch.Batch, numpy.ndarray], dtype: Optional[torch.dtype] = None, device: Union[str, int] = 'cpu') → Union[dict, tianshou.data.batch.Batch, torch.Tensor]
Return an object without np.ndarray.
-
class tianshou.data.ReplayBuffer(size: int, stack_num: Optional[int] = 0, ignore_obs_next: bool = False, **kwargs)
Bases: object
ReplayBuffer stores data generated from interaction between the policy and the environment. It stores the 7 types of data mentioned in Batch, based on numpy.ndarray. Here is the usage:

>>> import numpy as np
>>> from tianshou.data import ReplayBuffer
>>> buf = ReplayBuffer(size=20)
>>> for i in range(3):
...     buf.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
>>> len(buf)
3
>>> buf.obs  # since we set size = 20, len(buf.obs) == 20.
array([0., 1., 2., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0.])
>>> buf2 = ReplayBuffer(size=10)
>>> for i in range(15):
...     buf2.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
>>> len(buf2)
10
>>> buf2.obs  # since its size = 10, it only stores the last 10 steps' result.
array([10., 11., 12., 13., 14., 5., 6., 7., 8., 9.])
>>> # move buf2's result into buf (meanwhile keep it chronologically)
>>> buf.update(buf2)
>>> buf.obs
array([ 0., 1., 2., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14.,
        0., 0., 0., 0., 0., 0., 0.])
>>> # get a random sample from buffer
>>> # the batch_data is equal to buf[indice].
>>> batch_data, indice = buf.sample(batch_size=4)
>>> batch_data.obs == buf[indice].obs
array([ True, True, True, True])
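The overwrite behaviour seen with buf2 above (only the last 10 steps survive) is what a fixed-size circular store does. The following is a minimal pure-Python sketch of that idea; the class RingBuffer and its method names are hypothetical and are not Tianshou's implementation, which stores several keys in numpy arrays.

```python
from typing import List


class RingBuffer:
    """Fixed-size circular store (illustrative only, not Tianshou's class)."""

    def __init__(self, size: int) -> None:
        self._maxsize = size
        self._data: List[float] = [0.0] * size  # pre-allocated storage
        self._index = 0  # next write position
        self._size = 0   # number of valid entries

    def add(self, value: float) -> None:
        # Write at the current position, then advance and wrap around.
        self._data[self._index] = value
        self._index = (self._index + 1) % self._maxsize
        self._size = min(self._size + 1, self._maxsize)

    def __len__(self) -> int:
        return self._size

    def chronological(self) -> List[float]:
        """Return the valid entries ordered from oldest to newest."""
        if self._size < self._maxsize:
            return self._data[:self._size]
        return self._data[self._index:] + self._data[:self._index]


buf2 = RingBuffer(size=10)
for i in range(15):
    buf2.add(float(i))
print(len(buf2))             # number of stored steps
print(buf2._data)            # raw storage: oldest entries were overwritten
print(buf2.chronological())  # same entries, oldest first
```

Reading the entries back from the write position onwards is one way to keep them chronological when moving them into another buffer, as the update() example does.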
ReplayBuffer also supports frame_stack sampling (typically for RNN usage, see issue #19), skipping storage of the next observation (to save memory in Atari tasks), and multi-modal observations (see issue #38, requires version >= 0.2.3):

>>> buf = ReplayBuffer(size=9, stack_num=4, ignore_obs_next=True)
>>> for i in range(16):
...     done = i % 5 == 0
...     buf.add(obs={'id': i}, act=i, rew=i, done=done,
...             obs_next={'id': i + 1})
>>> print(buf)  # you can see obs_next is not saved in buf
ReplayBuffer(
    act: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]),
    done: array([0., 1., 0., 0., 0., 0., 1., 0., 0.]),
    info: Batch(),
    obs: Batch(
             id: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]),
         ),
    policy: Batch(),
    rew: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]),
)
>>> index = np.arange(len(buf))
>>> print(buf.get(index, 'obs').id)
[[ 7. 7. 8. 9.]
 [ 7. 8. 9. 10.]
 [11. 11. 11. 11.]
 [11. 11. 11. 12.]
 [11. 11. 12. 13.]
 [11. 12. 13. 14.]
 [12. 13. 14. 15.]
 [ 7. 7. 7. 7.]
 [ 7. 7. 7. 8.]]
>>> # here is another way to get the stacked data
>>> # (stack only for obs and obs_next)
>>> abs(buf.get(index, 'obs')['id'] - buf[index].obs.id).sum().sum()
0.0
>>> # we can get obs_next through __getitem__, even if it doesn't exist
>>> print(buf[:].obs_next.id)
[[ 7. 8. 9. 10.]
 [ 7. 8. 9. 10.]
 [11. 11. 11. 12.]
 [11. 11. 12. 13.]
 [11. 12. 13. 14.]
 [12. 13. 14. 15.]
 [12. 13. 14. 15.]
 [ 7. 7. 7. 8.]
 [ 7. 7. 8. 9.]]
-
__getitem__(index: Union[slice, numpy.ndarray]) → tianshou.data.batch.Batch
Return a data batch: self[index]. If stack_num is set to be > 0, return the stacked obs and obs_next with shape [batch, len, …].
-
add(obs: Union[dict, numpy.ndarray], act: Union[numpy.ndarray, float], rew: float, done: bool, obs_next: Union[dict, numpy.ndarray, None] = None, info: dict = {}, policy: Union[dict, tianshou.data.batch.Batch, None] = {}, **kwargs) → None
Add a batch of data into replay buffer.
-
get(indice: Union[slice, numpy.ndarray], key: str, stack_num: Optional[int] = None) → Union[tianshou.data.batch.Batch, numpy.ndarray]
Return the stacked result, e.g. [s_{t-3}, s_{t-2}, s_{t-1}, s_t], where s is self.key and t is indice. The stack_num (here equal to 4) is taken from the buffer initialization.
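The stacking rule can be illustrated for a single index with a pure-Python sketch. The function stacked_frames and its last_index argument are hypothetical; the logic is an assumption chosen to reproduce the stacked output printed in the ReplayBuffer example (frames never reach back across an episode end, and the first frame of an episode is repeated instead), not a copy of Tianshou's code.

```python
from typing import List, Sequence


def stacked_frames(obs: Sequence[float], done: Sequence[bool], t: int,
                   stack_num: int, last_index: int) -> List[float]:
    """Return [s_{t-stack_num+1}, ..., s_t] without crossing episode ends."""
    size = len(obs)
    done = list(done)
    # Treat the most recently written slot as an episode end, so a stack
    # never reaches across the write position into older, unrelated data.
    done[last_index] = True
    frames: List[float] = []
    idx = t
    for _ in range(stack_num):
        frames.insert(0, obs[idx])
        prev = (idx - 1) % size  # wrap around the circular storage
        if not done[prev]:
            idx = prev  # previous step belongs to the same episode
        # otherwise stay on idx, which repeats the episode's first frame
    return frames


# Same contents as the buffer printed in the frame-stack example.
obs = [9, 10, 11, 12, 13, 14, 15, 7, 8]
done = [0, 1, 0, 0, 0, 0, 1, 0, 0]
last_index = 6  # slot of the most recent write (observation 15)
for t in (0, 2, 6, 7):
    print(t, stacked_frames(obs, done, t, stack_num=4, last_index=last_index))
```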
-
-
class tianshou.data.ListReplayBuffer(**kwargs)
Bases: tianshou.data.buffer.ReplayBuffer
ListReplayBuffer works almost the same as ReplayBuffer. The only difference is that ListReplayBuffer is based on list.
See also
Please refer to ReplayBuffer for a more detailed explanation.
-
class tianshou.data.PrioritizedReplayBuffer(size: int, alpha: float, beta: float, mode: str = 'weight', **kwargs)
Bases: tianshou.data.buffer.ReplayBuffer
Prioritized replay buffer implementation.
Parameters
alpha (float) – the prioritization exponent.
beta (float) – the importance sample soft coefficient.
mode (str) – defaults to weight.
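For readers unfamiliar with the roles of alpha and beta, here is a sketch of the usual prioritized experience replay formulation: sampling probability P(i) = p_i^alpha / sum_k p_k^alpha and importance weight w_i = (N * P(i))^(-beta), normalized by the largest weight. This is the textbook form only. Whether this class follows it exactly, and what the mode argument changes, is not stated in this reference, and the function name below is hypothetical.

```python
from typing import List, Sequence, Tuple


def per_probabilities_and_weights(
        priorities: Sequence[float], alpha: float,
        beta: float) -> Tuple[List[float], List[float]]:
    """Textbook prioritized-replay probabilities and importance weights."""
    # alpha = 0 gives uniform sampling; larger alpha favours high priorities.
    scaled = [p ** alpha for p in priorities]
    total = sum(scaled)
    probs = [s / total for s in scaled]
    # beta controls how strongly the sampling bias is corrected.
    n = len(priorities)
    raw = [(n * p) ** (-beta) for p in probs]
    max_w = max(raw)
    weights = [w / max_w for w in raw]  # normalize so the largest weight is 1
    return probs, weights


probs, weights = per_probabilities_and_weights(
    [1.0, 1.0, 2.0, 4.0], alpha=1.0, beta=1.0)
print(probs)
print(weights)
```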
See also
Please refer to ReplayBuffer for a more detailed explanation.
-
__getitem__(index: Union[slice, numpy.ndarray]) → tianshou.data.batch.Batch
Return a data batch: self[index]. If stack_num is set to be > 0, return the stacked obs and obs_next with shape [batch, len, …].
-
add(obs: Union[dict, numpy.ndarray], act: Union[numpy.ndarray, float], rew: float, done: bool, obs_next: Union[dict, numpy.ndarray, None] = None, info: dict = {}, policy: Union[dict, tianshou.data.batch.Batch, None] = {}, weight: float = 1.0, **kwargs) → None
Add a batch of data into replay buffer.
-
class tianshou.data.Collector(policy: tianshou.policy.base.BasePolicy, env: Union[gym.core.Env, tianshou.env.vecenv.BaseVectorEnv], buffer: Union[tianshou.data.buffer.ReplayBuffer, List[tianshou.data.buffer.ReplayBuffer], None] = None, preprocess_fn: Callable[[Any], Union[dict, tianshou.data.batch.Batch]] = None, stat_size: Optional[int] = 100, **kwargs)
Bases: object
The Collector enables the policy to interact with different types of environments conveniently.
Parameters
policy – an instance of the BasePolicy class.
env – a gym.Env environment or an instance of the BaseVectorEnv class.
buffer – an instance of the ReplayBuffer class, or a list of ReplayBuffer. If set to None, it will automatically assign a small-size ReplayBuffer.
preprocess_fn (function) – a function called before the data has been added to the buffer, see issue #42. Defaults to None.
stat_size (int) – the window size of the moving average used for recording speed. Defaults to 100.
The preprocess_fn is a function called, with data in batch format, before the data is added to the buffer. It receives up to 7 keys, as listed in Batch. When the collector resets the environment, it receives only obs. It returns either a dict or a Batch with the modified keys and values. Examples are in “test/base/test_collector.py”.
Example:
import gym
from tianshou.data import Collector, ReplayBuffer
from tianshou.env import VectorEnv
from tianshou.policy import PGPolicy

policy = PGPolicy(...)  # or other policies if you wish
env = gym.make('CartPole-v0')
replay_buffer = ReplayBuffer(size=10000)
# here we set up a collector with a single environment
collector = Collector(policy, env, buffer=replay_buffer)

# the collector supports vectorized environments as well
envs = VectorEnv([lambda: gym.make('CartPole-v0') for _ in range(3)])
buffers = [ReplayBuffer(size=5000) for _ in range(3)]
# you can also pass a list of replay buffer to collector, for multi-env
# collector = Collector(policy, envs, buffer=buffers)
collector = Collector(policy, envs, buffer=replay_buffer)

# collect at least 3 episodes
collector.collect(n_episode=3)
# collect 1 episode for the first env, 3 for the third env
collector.collect(n_episode=[1, 0, 3])
# collect at least 2 steps
collector.collect(n_step=2)
# collect episodes with visual rendering (the render argument is the
#   sleep time between rendering consecutive frames)
collector.collect(n_episode=1, render=0.03)

# sample data with a given number of batch-size:
batch_data = collector.sample(batch_size=64)
# policy.learn(batch_data)  # btw, vanilla policy gradient only
#   supports on-policy training, so here we pick all data in the buffer
batch_data = collector.sample(batch_size=0)
policy.learn(batch_data)
# on-policy algorithms use the collected data only once, so here we
#   clear the buffer
collector.reset_buffer()
When collecting data from multiple environments into a single buffer, the cache buffers are turned on automatically. In this case the collector may return more data than the given limit.
Note
Please make sure the given environment has a time limit.
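To make the preprocess_fn contract described above more concrete, here is a minimal sketch of such a function. It assumes the collector passes the batch keys as keyword arguments and that rew is an iterable of numbers; the name clip_reward_fn and the clipping itself are illustrative choices, not something this reference prescribes.

```python
from typing import Any, Dict


def clip_reward_fn(**kwargs: Any) -> Dict[str, Any]:
    """Clip rewards to [-1, 1] before they are stored (illustrative)."""
    # On environment reset only `obs` is passed, so there is nothing to modify.
    if 'rew' not in kwargs:
        return {}
    clipped = [max(-1.0, min(1.0, float(r))) for r in kwargs['rew']]
    # Return only the keys that were modified; other keys stay unchanged.
    return {'rew': clipped}


print(clip_reward_fn(obs=[0, 1, 2], rew=[5.0, -3.0, 0.5]))
print(clip_reward_fn(obs=[0]))
```

Such a function would be passed through the preprocess_fn argument of the Collector constructor. A real version would probably return a numpy array rather than a list.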
-
collect(n_step: int = 0, n_episode: Union[int, List[int]] = 0, render: Optional[float] = None, log_fn: Optional[Callable[[dict], None]] = None) → Dict[str, float]
Collect a specified number of steps or episodes.
Parameters
n_step (int) – how many steps you want to collect.
n_episode (int or list) – how many episodes you want to collect (in each environment).
render (float) – the sleep time between rendering consecutive frames. Defaults to None (no rendering).
log_fn (function) – a function which receives env info, typically for tensorboard logging.
Note
Exactly one collection number must be specified: either n_step or n_episode.
Returns
A dict including the following keys:
n/ep – the number of collected episodes.
n/st – the number of collected steps.
v/st – the speed in steps per second.
v/ep – the speed in episodes per second.
rew – the mean reward over collected episodes.
len – the mean length over collected episodes.
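The meaning of the returned keys can be illustrated with a small sketch that computes them from per-episode totals. The function summarize is hypothetical. It assumes every step belongs to a finished episode and uses a single elapsed time, whereas the collector itself records speed as a moving average (see stat_size).

```python
from typing import Dict, Sequence


def summarize(episode_rewards: Sequence[float],
              episode_lengths: Sequence[int],
              elapsed_seconds: float) -> Dict[str, float]:
    """Build a dict with the same keys as the collect() result (sketch)."""
    n_ep = len(episode_rewards)
    n_st = sum(episode_lengths)
    return {
        'n/ep': n_ep,                         # number of episodes
        'n/st': n_st,                         # number of steps
        'v/st': n_st / elapsed_seconds,       # steps per second
        'v/ep': n_ep / elapsed_seconds,       # episodes per second
        'rew': sum(episode_rewards) / n_ep,   # mean episode reward
        'len': n_st / n_ep,                   # mean episode length
    }


stats = summarize([10.0, 20.0], [5, 15], elapsed_seconds=2.0)
print(stats)
```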
-
reset_env() → None
Reset the states of all the environment(s) and reset all of the cache buffers (if needed).
-
sample(batch_size: int) → tianshou.data.batch.Batch
Sample a data batch from the internal replay buffer. It will call process_fn() before returning the final batch data.
Parameters
batch_size (int) – 0 means it will extract all the data from the buffer; otherwise it will extract the data with the given batch_size.
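The batch_size convention can be sketched as an index-selection helper. The function sample_indices is hypothetical; drawing with replacement and returning indices in storage order for batch_size=0 are assumptions made for illustration, not documented behaviour.

```python
import random
from typing import List


def sample_indices(buffer_len: int, batch_size: int) -> List[int]:
    """Pick buffer indices following the batch_size convention (sketch)."""
    if batch_size > 0:
        # Random draw with replacement (assumed for this sketch).
        return [random.randrange(buffer_len) for _ in range(batch_size)]
    # batch_size == 0: take every stored transition.
    return list(range(buffer_len))


print(sample_indices(5, batch_size=0))
print(sample_indices(5, batch_size=3))
```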