tianshou.data
-
class tianshou.data.Batch(batch_dict: Optional[Union[dict, Batch, Sequence[Union[dict, Batch]], numpy.ndarray]] = None, copy: bool = False, **kwargs: Any)
Bases: object
The internal data structure in Tianshou.
Batch is a kind of supercharged array (of temporal data) stored as a (recursive) dictionary whose values can be numpy arrays, torch tensors, or Batch objects themselves. It is designed to make it extremely easy to access, manipulate, and set partial views of heterogeneous data.
For a detailed description, please refer to Understand Batch.
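To make the recursive-dictionary idea concrete, here is a minimal pure-Python sketch. MiniBatch is a hypothetical stand-in, not Tianshou's class: leaves are plain lists instead of numpy arrays or torch tensors, and only indexing is modelled. A string index returns the stored value, while an integer or slice index is applied to every leaf recursively, which is what lets a Batch be sliced as if it were a single array.

```python
from typing import Any, Dict, Union


class MiniBatch:
    """Hypothetical, simplified stand-in for Batch (not Tianshou's code).

    Values are plain Python lists or nested MiniBatch objects; plain dicts
    passed to the constructor are converted to nested MiniBatch objects.
    """

    def __init__(self, **kwargs: Any) -> None:
        self._data: Dict[str, Any] = {
            key: MiniBatch(**value) if isinstance(value, dict) else value
            for key, value in kwargs.items()
        }

    def __getitem__(self, index: Union[str, int, slice]) -> Any:
        if isinstance(index, str):
            # a key: return the stored list or nested MiniBatch
            return self._data[index]
        # an int or slice: apply it to every leaf, recursing into sub-batches
        out = MiniBatch()
        for key, value in self._data.items():
            out._data[key] = value[index]
        return out


batch = MiniBatch(obs=[0, 1, 2], info={"score": [10, 20, 30]})
print(batch[1]["obs"])              # 1
print(batch[0:2]["info"]["score"])  # [10, 20]
```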
-
__getitem__(index: Union[str, slice, int, numpy.integer, numpy.ndarray, List[int]]) → Any
Return self[index].
-
__setitem__(index: Union[str, slice, int, numpy.integer, numpy.ndarray, List[int]], value: Any) → None
Assign value to self[index].
-
static cat(batches: Sequence[Union[dict, Batch]]) → tianshou.data.batch.Batch
Concatenate a list of Batch objects into a single new batch.
For keys that are not shared across all batches, the batches that lack these keys are padded with zeros of the appropriate shape. For example:
>>> a = Batch(a=np.zeros([3, 4]), common=Batch(c=np.zeros([3, 5])))
>>> b = Batch(b=np.zeros([4, 3]), common=Batch(c=np.zeros([4, 5])))
>>> c = Batch.cat([a, b])
>>> c.a.shape
(7, 4)
>>> c.b.shape
(7, 3)
>>> c.common.c.shape
(7, 5)
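The zero-padding rule can be illustrated without Tianshou. The sketch below is a hypothetical helper, cat_flat, operating on flat dicts of equal-length lists with one scalar per row; it is not the library implementation and ignores nested batches and multi-dimensional shapes. Every key seen in any input appears in the output, and a batch that lacks a key contributes one zero per row.

```python
from typing import Dict, List


def cat_flat(batches: List[Dict[str, list]]) -> Dict[str, list]:
    """Hypothetical sketch of Batch.cat's padding rule on flat dicts."""
    keys: List[str] = []
    for b in batches:
        for k in b:
            if k not in keys:
                keys.append(k)  # keep first-seen key order
    out: Dict[str, list] = {k: [] for k in keys}
    for b in batches:
        # number of rows in this batch (0 for an empty dict)
        n = len(next(iter(b.values()), []))
        for k in keys:
            # a missing key is padded with n zeros
            out[k].extend(b.get(k, [0] * n))
    return out


a = {"a": [1, 1, 1], "common": [5, 5, 5]}
b = {"b": [2, 2, 2, 2], "common": [6, 6, 6, 6]}
c = cat_flat([a, b])
print(c["a"])  # [1, 1, 1, 0, 0, 0, 0]
print(c["b"])  # [0, 0, 0, 2, 2, 2, 2]
```

As in the doctest above, every output key ends up with 3 + 4 = 7 rows.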
-
cat_(batches: Union[Batch, Sequence[Union[dict, Batch]]]) → None
Concatenate a list of (or one) Batch objects into the current batch.
-
static empty(batch: tianshou.data.batch.Batch, index: Union[str, slice, int, numpy.integer, numpy.ndarray, List[int]] = None) → tianshou.data.batch.Batch
Return an empty Batch object filled with 0 or None.
The shape is the same as the given Batch.
-
empty_(index: Union[str, slice, int, numpy.integer, numpy.ndarray, List[int]] = None) → tianshou.data.batch.Batch
Return an empty Batch object filled with 0 or None.
If index is specified, only the data at that index is reset.
>>> data.empty_()
>>> print(data)
Batch(
    a: array([[0., 0.],
              [0., 0.]]),
    b: array([None, None], dtype=object),
)
>>> b={'c': [2., 'st'], 'd': [1., 0.]}
>>> data = Batch(a=[False, True], b=b)
>>> data[0] = Batch.empty(data[1])
>>> data
Batch(
    a: array([False, True]),
    b: Batch(
           c: array([None, 'st']),
           d: array([0., 0.]),
       ),
)
-
is_empty(recurse: bool = False) → bool
Test if a Batch is empty.
If recurse=True, it further tests the values of the object; otherwise it only tests for the existence of any key.
b.is_empty(recurse=True) is mainly used to distinguish Batch(a=Batch(a=Batch())) from Batch(a=1). Both raise exceptions when passed to len(), but the former can be used in cat, while the latter is a scalar and cannot be used in cat.
Another usage is in __len__, where we have to skip checking the length of a recursively empty Batch.
>>> Batch().is_empty()
True
>>> Batch(a=Batch(), b=Batch(c=Batch())).is_empty()
False
>>> Batch(a=Batch(), b=Batch(c=Batch())).is_empty(recurse=True)
True
>>> Batch(d=1).is_empty()
False
>>> Batch(a=np.float64(1.0)).is_empty()
False
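The two modes can be mimicked on plain nested dicts, where a nested dict plays the role of a nested Batch. This is a hypothetical sketch of the semantics described above, not Tianshou's code; in particular it treats every non-dict value as non-empty.

```python
from typing import Any, Dict


def is_empty(d: Dict[str, Any], recurse: bool = False) -> bool:
    """Sketch of Batch.is_empty semantics on nested dicts."""
    if not recurse:
        # shallow check: empty only if there are no keys at all
        return len(d) == 0
    # recursive check: every value must itself be a recursively empty dict
    return all(isinstance(v, dict) and is_empty(v, recurse=True)
               for v in d.values())


nested = {"a": {}, "b": {"c": {}}}
print(is_empty(nested))                # False: the top level has keys
print(is_empty(nested, recurse=True))  # True: no leaf holds any data
```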
-
property shape
Return self.shape.
-
split(size: int, shuffle: bool = True, merge_last: bool = False) → Iterator[tianshou.data.batch.Batch]
Split the whole data into multiple small batches.
- Parameters
size (int) – the size of each small batch. If the length of the data is smaller than size, the whole data is returned as a single batch.
shuffle (bool) – if True, randomly shuffle the entire data batch before splitting; otherwise keep the original order. Defaults to True.
merge_last (bool) – merge the last batch into the previous one. Defaults to False.
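The splitting rules can be sketched over a plain Python list. This is a hypothetical re-implementation for illustration only, and it assumes that merge_last only takes effect when the final chunk would be shorter than size: for 10 items and size=4, the chunks are 4, 4, 2 without merging and 4, 6 with it.

```python
import random
from typing import Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def split(data: Sequence[T], size: int, shuffle: bool = True,
          merge_last: bool = False) -> Iterator[List[T]]:
    """Hypothetical sketch of Batch.split on a plain sequence."""
    n = len(data)
    order = list(range(n))
    if shuffle:
        random.shuffle(order)  # one permutation for the whole pass
    # merging only matters when the final chunk would be a partial one
    merge_last = merge_last and n % size > 0
    for start in range(0, n, size):
        if merge_last and start + 2 * size >= n:
            # the next chunk would be the short last one: emit both together
            yield [data[i] for i in order[start:]]
            return
        yield [data[i] for i in order[start:start + size]]


items = list(range(10))
print([len(c) for c in split(items, 4, shuffle=False)])                   # [4, 4, 2]
print([len(c) for c in split(items, 4, shuffle=False, merge_last=True)])  # [4, 6]
```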
-
static stack(batches: Sequence[Union[dict, Batch]], axis: int = 0) → tianshou.data.batch.Batch
Stack a list of Batch objects into a single new batch.
For keys that are not shared across all batches, the batches that lack these keys are padded with zeros. For example:
>>> a = Batch(a=np.zeros([4, 4]), common=Batch(c=np.zeros([4, 5])))
>>> b = Batch(b=np.zeros([4, 6]), common=Batch(c=np.zeros([4, 5])))
>>> c = Batch.stack([a, b])
>>> c.a.shape
(2, 4, 4)
>>> c.b.shape
(2, 4, 6)
>>> c.common.c.shape
(2, 4, 5)
Note
If there are keys that are not shared across all batches, stack with axis != 0 is undefined and will raise an exception.
-
stack_(batches: Sequence[Union[dict, Batch]], axis: int = 0) → None
Stack a list of Batch objects into the current batch.
-
-
class tianshou.data.Collector(policy: tianshou.policy.base.BasePolicy, env: Union[gym.core.Env, tianshou.env.venvs.BaseVectorEnv], buffer: Optional[tianshou.data.buffer.ReplayBuffer] = None, preprocess_fn: Optional[Callable[[…], tianshou.data.batch.Batch]] = None, action_noise: Optional[tianshou.exploration.random.BaseNoise] = None, reward_metric: Optional[Callable[[numpy.ndarray], float]] = None)
Bases: object
Collector enables the policy to interact with different types of envs.
- Parameters
policy – an instance of the BasePolicy class.
env – a gym.Env environment or an instance of the BaseVectorEnv class.
buffer – an instance of the ReplayBuffer class. If set to None (testing phase), it will not store the data.
preprocess_fn (function) – a function called before the data is added to the buffer; see issue #42 and Handle Batched Data Stream in Collector. Defaults to None.
action_noise (BaseNoise) – add noise to continuous actions. Normally a policy already has a noise parameter for exploration in the training phase, so this option is mainly recommended for the test collector.
reward_metric (function) – to be used in multi-agent RL. The reward to report is of shape [agent_num], but we need to return a single scalar to monitor training. This function specifies the desired metric, e.g., the reward of agent 1 or the average reward over all agents. By default, the behavior is to select the reward of agent 1.
The preprocess_fn is a function called before the data is added to the buffer in batch format. It receives up to 7 keys as listed in Batch, and it receives only obs when the collector resets the environment. It returns either a dict or a Batch with the modified keys and values. Examples are in “test/base/test_collector.py”.
Here is an example of basic usage:
import gym
from tianshou.data import Collector, ReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.policy import PGPolicy

policy = PGPolicy(...)  # or other policies if you wish
env = gym.make('CartPole-v0')
replay_buffer = ReplayBuffer(size=10000)
# here we set up a collector with a single environment
collector = Collector(policy, env, buffer=replay_buffer)
# the collector supports vectorized environments as well
envs = DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(3)])
collector = Collector(policy, envs, buffer=replay_buffer)
# collect 3 episodes
collector.collect(n_episode=3)
# collect 1 episode for the first env, 3 for the third env
collector.collect(n_episode=[1, 0, 3])
# collect at least 2 steps
collector.collect(n_step=2)
# collect episodes with visual rendering (the render argument is the
# sleep time between rendering consecutive frames)
collector.collect(n_episode=1, render=0.03)
Collected data always consist of full episodes. So if only the n_step argument is given, the collector may return more data than the n_step limit. The same applies to n_episode in the multiple-environment case.
Note
Please make sure the given environment has a time limit.
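As an illustration of the preprocess_fn contract described above, the following hypothetical function clips rewards to [-1, 1]. It assumes that the collector passes the keys as keyword arguments and that only the keys present in the returned dict are modified; plain lists stand in for the numpy arrays a real collector would pass. Check test/base/test_collector.py for how the library itself calls this hook.

```python
from typing import Any, Dict


def clip_reward_preprocess(**kwargs: Any) -> Dict[str, Any]:
    """Hypothetical preprocess_fn that clips every reward into [-1, 1].

    Assumption: keys arrive as keyword arguments, and only the keys in
    the returned dict are modified.
    """
    if "rew" not in kwargs:
        # reset call: only ``obs`` is given, so return it unchanged
        return {"obs": kwargs["obs"]}
    return {"rew": [max(-1.0, min(1.0, float(r))) for r in kwargs["rew"]]}


print(clip_reward_preprocess(obs=[0.0, 0.0]))  # {'obs': [0.0, 0.0]}
print(clip_reward_preprocess(obs_next=[1, 2], rew=[5.0, -0.5],
                             done=[False, True], info={}))  # {'rew': [1.0, -0.5]}
```

Under those assumptions it would be passed as Collector(policy, env, buffer, preprocess_fn=clip_reward_preprocess).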
-
collect(n_step: Optional[int] = None, n_episode: Optional[Union[int, List[int]]] = None, random: bool = False, render: Optional[float] = None, no_grad: bool = True) → Dict[str, float]
Collect a specified number of steps or episodes.
- Parameters
n_step (int) – how many steps you want to collect.
n_episode – how many episodes you want to collect. If it is an int, it means to collect at least n_episode episodes; if it is a list, it means to collect exactly n_episode[i] episodes in the i-th environment.
random (bool) – whether to use a random policy for collecting data. Defaults to False.
render (float) – the sleep time between rendering consecutive frames. Defaults to None (no rendering).
no_grad (bool) – whether to disable gradient computation in policy.forward. Defaults to True (no gradients are retained).
Note
One and only one collection number specification is permitted, either n_step or n_episode.
- Returns
A dict including the following keys:
n/ep – the collected number of episodes.
n/st – the collected number of steps.
v/st – the speed in steps per second.
v/ep – the speed in episodes per second.
rew – the mean reward over collected episodes.
len – the mean length over collected episodes.
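The overshoot mentioned in the Collector description, and the meaning of the returned keys, can be reproduced with a toy model. The hypothetical collect_full_episodes below consumes pre-recorded episodes, each given as (length, total reward), until at least n_step steps have been gathered, and never truncates an episode. It takes the elapsed time as an argument instead of measuring it, and it is not the Collector implementation.

```python
from typing import Dict, Iterable, List, Tuple


def collect_full_episodes(episodes: Iterable[Tuple[int, float]], n_step: int,
                          elapsed: float) -> Dict[str, float]:
    """Hypothetical toy model of step-based collection.

    Consumes whole episodes, given as (length, total_reward), until at
    least n_step steps are gathered, then reports the documented keys.
    Assumes at least one episode is collected.
    """
    lengths: List[int] = []
    rewards: List[float] = []
    for length, reward in episodes:
        lengths.append(length)
        rewards.append(reward)
        if sum(lengths) >= n_step:
            break  # an episode is never cut short, so n/st may exceed n_step
    n_ep, n_st = len(lengths), sum(lengths)
    return {
        "n/ep": n_ep,
        "n/st": n_st,
        "v/st": n_st / elapsed,
        "v/ep": n_ep / elapsed,
        "rew": sum(rewards) / n_ep,
        "len": n_st / n_ep,
    }


stats = collect_full_episodes([(3, 1.0), (4, 2.0), (5, 3.0)], n_step=5, elapsed=2.0)
print(stats["n/st"])  # 7: two whole episodes, although only 5 steps were requested
```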
-
class tianshou.data.ListReplayBuffer(**kwargs: Any)
Bases: tianshou.data.buffer.ReplayBuffer
List-based replay buffer.
The function of ListReplayBuffer is almost the same as ReplayBuffer. The only difference is that ListReplayBuffer is based on Python lists. Therefore, it does not support advanced indexing, which means you cannot sample a batch of data out of it. It is typically used for storing data.
See also
Please refer to ReplayBuffer for a more detailed explanation.
-
class tianshou.data.PrioritizedReplayBuffer(size: int, alpha: float, beta: float, **kwargs: Any)
Bases: tianshou.data.buffer.ReplayBuffer
Implementation of Prioritized Experience Replay. arXiv:1511.05952.
- Parameters
alpha (float) – the prioritization exponent.
beta (float) – the importance-sampling exponent, which controls how strongly the loss weights correct the sampling bias.
See also
Please refer to ReplayBuffer for a more detailed explanation.
-
__getitem__(index: Union[slice, int, numpy.integer, numpy.ndarray]) → tianshou.data.batch.Batch
Return a data batch: self[index].
If stack_num is larger than 1, return the stacked obs and obs_next with shape (batch, len, …).
-
add(obs: Any, act: Any, rew: Union[numbers.Number, numpy.number, numpy.ndarray], done: Union[numbers.Number, numpy.number, numpy.bool_], obs_next: Any = None, info: Optional[Union[dict, tianshou.data.batch.Batch]] = {}, policy: Optional[Union[dict, tianshou.data.batch.Batch]] = {}, weight: Optional[Union[numbers.Number, numpy.number]] = None, **kwargs: Any) → None
Add a batch of data into the replay buffer.
-
sample(batch_size: int) → Tuple[tianshou.data.batch.Batch, numpy.ndarray]
Get a random sample from the buffer, where each transition is drawn with a probability determined by its priority.
Return all the data in the buffer if batch_size is 0.
- Returns
Sample data and its corresponding index inside the buffer.
The “weight” in the returned Batch is the weight applied to the loss function to de-bias the sampling process (some transitions are sampled more often, so their losses are weighted less).
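For reference, the proportional variant described in arXiv:1511.05952 samples index \(i\) with probability \(P(i) = p_i^\alpha / \sum_k p_k^\alpha\) and weights its loss by \(w_i = (N \cdot P(i))^{-\beta}\), normalized so that the largest weight is 1. The sketch below applies that textbook formula using only the standard library. It normalizes by the largest weight in the sampled batch (one common choice), does not handle batch_size=0, and is not guaranteed to match Tianshou's exact computation.

```python
import random
from typing import List, Sequence, Tuple


def per_sample(priorities: Sequence[float], batch_size: int, alpha: float,
               beta: float, rng: random.Random) -> Tuple[List[int], List[float]]:
    """Proportional prioritized sampling with importance-sampling weights."""
    n = len(priorities)
    scaled = [p ** alpha for p in priorities]  # p_i ** alpha
    total = sum(scaled)
    probs = [s / total for s in scaled]        # P(i)
    indices = rng.choices(range(n), weights=scaled, k=batch_size)
    # w_i = (N * P(i)) ** (-beta): frequently sampled items get smaller weights
    weights = [(n * probs[i]) ** (-beta) for i in indices]
    max_w = max(weights)
    return indices, [w / max_w for w in weights]


rng = random.Random(0)
idx, w = per_sample([1.0, 2.0, 3.0, 4.0], batch_size=8, alpha=0.6, beta=0.4, rng=rng)
```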
-
class tianshou.data.ReplayBuffer(size: int, stack_num: int = 1, ignore_obs_next: bool = False, save_only_last_obs: bool = False, sample_avail: bool = False)
Bases: object
ReplayBuffer stores data generated from interaction between the policy and environment.
The current implementation of Tianshou typically uses 7 reserved keys in Batch:
obs – the observation of step \(t\);
act – the action of step \(t\);
rew – the reward of step \(t\);
done – the done flag of step \(t\);
obs_next – the observation of step \(t+1\);
info – the info of step \(t\) (in gym.Env, the env.step() function returns 4 values, and the last one is info);
policy – the data computed by the policy in step \(t\).
The following code snippet illustrates its usage:
>>> import pickle, numpy as np
>>> from tianshou.data import ReplayBuffer
>>> buf = ReplayBuffer(size=20)
>>> for i in range(3):
...     buf.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
>>> buf.obs  # since we set size = 20, len(buf.obs) == 20.
array([0., 1., 2., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
>>> # but there are only three valid items, so len(buf) == 3.
>>> len(buf)
3
>>> pickle.dump(buf, open('buf.pkl', 'wb'))  # save to file "buf.pkl"
>>> buf2 = ReplayBuffer(size=10)
>>> for i in range(15):
...     buf2.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
>>> len(buf2)
10
>>> buf2.obs  # since its size = 10, it only stores the last 10 steps' result.
array([10., 11., 12., 13., 14., 5., 6., 7., 8., 9.])
>>> # move buf2's result into buf (meanwhile keep it chronologically)
>>> buf.update(buf2)
>>> buf.obs
array([ 0., 1., 2., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 0., 0., 0., 0., 0., 0., 0.])
>>> # get a random sample from buffer
>>> # the batch_data is equal to buf[indice].
>>> batch_data, indice = buf.sample(batch_size=4)
>>> batch_data.obs == buf[indice].obs
array([ True, True, True, True])
>>> len(buf)
13
>>> buf = pickle.load(open('buf.pkl', 'rb'))  # load from "buf.pkl"
>>> len(buf)
3
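The wrap-around behaviour in the example above (buf2 holding [10, ..., 14, 5, ..., 9] after 15 insertions into 10 slots) is that of a circular buffer. RingBuffer below is a hypothetical, obs-only sketch of that behaviour, not ReplayBuffer itself; chronological() is an illustrative helper that returns the valid items oldest first, which is the order in which update is described as moving data.

```python
from typing import Any, List


class RingBuffer:
    """Hypothetical obs-only circular buffer (not Tianshou's ReplayBuffer)."""

    def __init__(self, size: int) -> None:
        self.size = size
        self.obs: List[Any] = [0.0] * size  # pre-allocated storage
        self._index = 0  # next position to write
        self._len = 0    # number of valid items

    def add(self, obs: Any) -> None:
        self.obs[self._index] = obs  # overwrite the oldest slot once full
        self._index = (self._index + 1) % self.size
        self._len = min(self._len + 1, self.size)

    def __len__(self) -> int:
        return self._len

    def chronological(self) -> List[Any]:
        """Valid items ordered from oldest to newest."""
        if self._len < self.size:
            return self.obs[:self._len]
        return self.obs[self._index:] + self.obs[:self._index]


buf2 = RingBuffer(size=10)
for i in range(15):
    buf2.add(float(i))
print(buf2.obs)  # [10.0, 11.0, 12.0, 13.0, 14.0, 5.0, 6.0, 7.0, 8.0, 9.0]

buf = RingBuffer(size=20)
for i in range(3):
    buf.add(float(i))
for x in buf2.chronological():  # mimic buf.update(buf2): oldest first
    buf.add(x)
print(len(buf))  # 13
```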
ReplayBuffer also supports frame-stack sampling (typically for RNN usage, see issue#19), skipping storage of the next observation (to save memory in Atari tasks), and multi-modal observations (see issue#38):
>>> buf = ReplayBuffer(size=9, stack_num=4, ignore_obs_next=True)
>>> for i in range(16):
...     done = i % 5 == 0
...     buf.add(obs={'id': i}, act=i, rew=i, done=done,
...             obs_next={'id': i + 1})
>>> print(buf)  # you can see obs_next is not saved in buf
ReplayBuffer(
    act: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]),
    done: array([0., 1., 0., 0., 0., 0., 1., 0., 0.]),
    info: Batch(),
    obs: Batch(
             id: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]),
         ),
    policy: Batch(),
    rew: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]),
)
>>> index = np.arange(len(buf))
>>> print(buf.get(index, 'obs').id)
[[ 7. 7. 8. 9.]
 [ 7. 8. 9. 10.]
 [11. 11. 11. 11.]
 [11. 11. 11. 12.]
 [11. 11. 12. 13.]
 [11. 12. 13. 14.]
 [12. 13. 14. 15.]
 [ 7. 7. 7. 7.]
 [ 7. 7. 7. 8.]]
>>> # here is another way to get the stacked data
>>> # (stack only for obs and obs_next)
>>> abs(buf.get(index, 'obs')['id'] - buf[index].obs.id).sum().sum()
0.0
>>> # we can get obs_next through __getitem__, even if it doesn't exist
>>> print(buf[:].obs_next.id)
[[ 7. 8. 9. 10.]
 [ 7. 8. 9. 10.]
 [11. 11. 11. 12.]
 [11. 11. 12. 13.]
 [11. 12. 13. 14.]
 [12. 13. 14. 15.]
 [12. 13. 14. 15.]
 [ 7. 7. 7. 8.]
 [ 7. 7. 8. 9.]]
- Parameters
size (int) – the size of replay buffer.
stack_num (int) – the frame-stack sampling argument, should be greater than or equal to 1, defaults to 1 (no stacking).
ignore_obs_next (bool) – whether to skip storing obs_next; it can still be obtained through __getitem__. Defaults to False.
save_only_last_obs (bool) – only save the last obs/obs_next when it has a shape of (timestep, …) because of temporal stacking. Defaults to False.
sample_avail (bool) – whether to sample only from available indices when using the frame-stack sampling method. Defaults to False. This feature is currently not supported in PrioritizedReplayBuffer.
-
__getitem__(index: Union[slice, int, numpy.integer, numpy.ndarray]) → tianshou.data.batch.Batch
Return a data batch: self[index].
If stack_num is larger than 1, return the stacked obs and obs_next with shape (batch, len, …).
-
add(obs: Any, act: Any, rew: Union[numbers.Number, numpy.number, numpy.ndarray], done: Union[numbers.Number, numpy.number, numpy.bool_], obs_next: Any = None, info: Optional[Union[dict, tianshou.data.batch.Batch]] = {}, policy: Optional[Union[dict, tianshou.data.batch.Batch]] = {}, **kwargs: Any) → None
Add a batch of data into the replay buffer.
-
get(indice: Union[slice, int, numpy.integer, numpy.ndarray], key: str, stack_num: Optional[int] = None) → Union[tianshou.data.batch.Batch, numpy.ndarray]
Return the stacked result.
For example, \([s_{t-3}, s_{t-2}, s_{t-1}, s_t]\), where \(s\) is the data stored under key and \(t\) is the index given by indice. The stack_num (here equal to 4) is given by the buffer initialization procedure.
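The stacking rule can be read off the frame-stack doctest above: frames are gathered backwards from index t through the circular storage, and once the previous step ends an episode (its done flag is set), the earliest frame of the current episode is repeated instead of crossing the boundary. The sketch below is a hypothetical reconstruction of that rule from the doctest output, assuming a completely filled buffer; it reproduces the printed buf.get(index, 'obs').id rows for the same data but is not Tianshou's implementation.

```python
from typing import List, Sequence


def stack_frames(obs: Sequence[float], done: Sequence[bool], t: int,
                 stack_num: int) -> List[float]:
    """Return [s_{t-stack_num+1}, ..., s_t] from a full circular buffer."""
    n = len(obs)
    frames = [obs[t]]
    cur = t
    for _ in range(stack_num - 1):
        prev = (cur - 1) % n  # step back, wrapping around the storage
        if done[prev]:
            # prev ended an earlier episode: repeat the earliest frame instead
            frames.insert(0, frames[0])
        else:
            cur = prev
            frames.insert(0, obs[cur])
    return frames


# storage contents of the size-9 buffer in the frame-stack doctest
obs = [9, 10, 11, 12, 13, 14, 15, 7, 8]
done = [False, True, False, False, False, False, True, False, False]
for t in range(len(obs)):
    print(stack_frames(obs, done, t, stack_num=4))
```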
-
sample(batch_size: int) → Tuple[tianshou.data.batch.Batch, numpy.ndarray]
Get a random sample of size batch_size from the buffer.
Return all the data in the buffer if batch_size is 0.
- Returns
Sample data and its corresponding index inside the buffer.
-
property stack_num
-
class tianshou.data.SegmentTree(size: int)
Bases: object
Implementation of Segment Tree.
The segment tree stores an array arr with size n. It supports value update and fast query of the sum over the interval [left, right) in O(log n) time. The detailed procedure is as follows:
1. Pad the array to a length that is a power of 2, so that all leaf nodes in the segment tree have the same depth.
2. Store the segment tree in a binary heap.
- Parameters
size (int) – the size of segment tree.
-
__getitem__(index: Union[int, numpy.ndarray]) → Union[float, numpy.ndarray]
Return self[index].
-
__setitem__(index: Union[int, numpy.ndarray], value: Union[float, numpy.ndarray]) → None
Update values in the segment tree.
Duplicate values in index are handled by numpy: later indices overwrite earlier ones.
>>> a = np.array([1, 2, 3, 4])
>>> a[[0, 1, 0, 1]] = [4, 5, 6, 7]
>>> print(a)
[6 7 3 4]
-
get_prefix_sum_idx(value: Union[float, numpy.ndarray]) → Union[int, numpy.ndarray]
Find the index with the given prefix sum.
Return the minimum index for each v in value so that \(v \le \mathrm{sums}_i\), where \(\mathrm{sums}_i = \sum_{j = 0}^{i} \mathrm{arr}_j\).
Warning
Please make sure all of the values inside the segment tree are non-negative when using this function.
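The procedure above (pad to a power of two, store the tree as a binary heap) can be sketched in pure Python for scalar indices. SumTree is a hypothetical illustration, not Tianshou's class: the root lives at heap index 1, leaf i at index i + bound, and each internal node holds the sum of its two children. range_sum(left, right) sums the half-open interval [left, right) bottom-up, and get_prefix_sum_idx descends from the root, which is why all stored values must be non-negative. Going left when v equals the left subtree's sum follows the \(v \le \mathrm{sums}_i\) definition above.

```python
from typing import List


class SumTree:
    """Hypothetical sum segment tree stored as a binary heap."""

    def __init__(self, size: int) -> None:
        bound = 1
        while bound < size:  # pad to the next power of two
            bound *= 2
        self._bound = bound
        self._tree: List[float] = [0.0] * (2 * bound)  # index 0 unused

    def __setitem__(self, index: int, value: float) -> None:
        i = index + self._bound
        self._tree[i] = value
        i //= 2
        while i >= 1:  # refresh every ancestor's sum
            self._tree[i] = self._tree[2 * i] + self._tree[2 * i + 1]
            i //= 2

    def __getitem__(self, index: int) -> float:
        return self._tree[index + self._bound]

    def range_sum(self, left: int, right: int) -> float:
        """Sum of arr[left:right] in O(log n)."""
        total = 0.0
        lo, hi = left + self._bound, right + self._bound
        while lo < hi:
            if lo & 1:  # lo is a right child: take it, move right
                total += self._tree[lo]
                lo += 1
            if hi & 1:  # hi is a right child: take its left sibling
                hi -= 1
                total += self._tree[hi]
            lo //= 2
            hi //= 2
        return total

    def get_prefix_sum_idx(self, value: float) -> int:
        """Smallest i with value <= arr[0] + ... + arr[i]."""
        i = 1
        while i < self._bound:  # descend until a leaf is reached
            left = 2 * i
            if value <= self._tree[left]:
                i = left
            else:
                value -= self._tree[left]
                i = left + 1
        return i - self._bound


tree = SumTree(4)
for i, v in enumerate([1.0, 2.0, 3.0, 4.0]):
    tree[i] = v
print(tree.range_sum(0, 4))          # 10.0
print(tree.get_prefix_sum_idx(3.5))  # 2, since 1 + 2 < 3.5 <= 1 + 2 + 3
```

The same descent is what makes priority-proportional sampling cheap: drawing a uniform value in [0, total) and calling get_prefix_sum_idx lands on index i with probability arr_i / total.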
-
tianshou.data.to_numpy(x: Optional[Union[tianshou.data.batch.Batch, dict, list, tuple, numpy.number, numpy.bool_, numbers.Number, numpy.ndarray, torch.Tensor]]) → Union[tianshou.data.batch.Batch, dict, list, tuple, numpy.ndarray]
Return an object without torch.Tensor, i.e., with every torch.Tensor converted to numpy.ndarray.
-
tianshou.data.to_torch(x: Union[tianshou.data.batch.Batch, dict, list, tuple, numpy.number, numpy.bool_, numbers.Number, numpy.ndarray, torch.Tensor], dtype: Optional[torch.dtype] = None, device: Union[str, int, torch.device] = 'cpu') → Union[tianshou.data.batch.Batch, dict, list, tuple, torch.Tensor]
Return an object without np.ndarray, i.e., with every numpy.ndarray converted to torch.Tensor.