tianshou.data¶
-
class
tianshou.data.
Batch
(batch_dict: Optional[Union[dict, Batch, Tuple[Union[dict, Batch]], List[Union[dict, Batch]], numpy.ndarray]] = None, copy: bool = False, **kwargs)[source]¶ Bases:
object
Tianshou provides
Batch
as the internal data structure to pass any kind of data to other methods; for example, a collector gives a Batch
to the policy for learning. For a detailed description, please refer to Understand Batch.
-
__getitem__
(index: Union[str, slice, int, numpy.integer, numpy.ndarray, List[int]]) → tianshou.data.batch.Batch[source]¶ Return self[index].
-
__setitem__
(index: Union[str, slice, int, numpy.integer, numpy.ndarray, List[int]], value: Any) → None[source]¶ Assign value to self[index].
-
static
cat
(batches: List[Union[dict, Batch]]) → tianshou.data.batch.Batch[source]¶ Concatenate a list of
Batch
objects into a single new batch. For keys that are not shared across all batches, batches that do not have these keys will be padded with zeros of the appropriate shapes. E.g.
>>> a = Batch(a=np.zeros([3, 4]), common=Batch(c=np.zeros([3, 5])))
>>> b = Batch(b=np.zeros([4, 3]), common=Batch(c=np.zeros([4, 5])))
>>> c = Batch.cat([a, b])
>>> c.a.shape
(7, 4)
>>> c.b.shape
(7, 3)
>>> c.common.c.shape
(7, 5)
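The zero-padding rule for keys that are not shared can be illustrated with a minimal sketch on plain dicts of row lists. This is not Tianshou's implementation: `cat_with_padding` is a hypothetical helper, it assumes every batch is non-empty and flat (no nested batches), and it takes the row width of a missing key from the first batch that has it.

```python
def cat_with_padding(batches):
    """Concatenate dicts of 2-D row lists along the first axis.

    Hypothetical helper, not part of Tianshou. A batch that lacks a key
    contributes rows of zeros with that key's row width.
    """
    keys = sorted({k for b in batches for k in b})
    # Row width of each key, taken from the first batch that has it.
    width = {k: len(next(b[k] for b in batches if k in b)[0]) for k in keys}
    out = {k: [] for k in keys}
    for b in batches:
        n_rows = len(next(iter(b.values())))  # rows held by this batch
        for k in keys:
            rows = b.get(k)
            if rows is None:
                rows = [[0.0] * width[k] for _ in range(n_rows)]
            out[k].extend(list(r) for r in rows)
    return out
```

With two inputs of 3 and 4 rows, every key in the result has 7 rows, which mirrors the shapes in the example above.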
-
cat_
(batches: Union[Batch, List[Union[dict, Batch]]]) → None[source]¶ Concatenate a list of (or one)
Batch
objects into the current batch.
-
static
empty
(batch: tianshou.data.batch.Batch, index: Union[str, slice, int, numpy.integer, numpy.ndarray, List[int]] = None) → tianshou.data.batch.Batch[source]¶ Return an empty
Batch
object filled with 0 or None; its shape is the same as that of the given Batch.
-
empty_
(index: Union[str, slice, int, numpy.integer, numpy.ndarray, List[int]] = None) → tianshou.data.batch.Batch[source]¶ Return an empty
Batch
object filled with 0 or None. If index is specified, it will only reset the data at that index.
>>> data.empty_()
>>> print(data)
Batch(
    a: array([[0., 0.],
              [0., 0.]]),
    b: array([None, None], dtype=object),
)
>>> b = {'c': [2., 'st'], 'd': [1., 0.]}
>>> data = Batch(a=[False, True], b=b)
>>> data[0] = Batch.empty(data[1])
>>> data
Batch(
    a: array([False, True]),
    b: Batch(
           c: array([None, 'st']),
           d: array([0., 0.]),
       ),
)
-
get
(k: str, d: Optional[Any] = None) → Any[source]¶ Return self[k] if k in self else d. d defaults to None.
-
is_empty
(recurse: bool = False)[source]¶ Test if a Batch is empty. If recurse=True, it further tests the values of the object; otherwise it only tests the existence of any key.
b.is_empty(recurse=True) is mainly used to distinguish Batch(a=Batch(a=Batch())) and Batch(a=1). They both raise exceptions when applied to len(), but the former can be used in cat, while the latter is a scalar and cannot be used in cat.
Another usage is in __len__, where we have to skip checking the length of a recursively empty Batch.
>>> Batch().is_empty()
True
>>> Batch(a=Batch(), b=Batch(c=Batch())).is_empty()
False
>>> Batch(a=Batch(), b=Batch(c=Batch())).is_empty(recurse=True)
True
>>> Batch(d=1).is_empty()
False
>>> Batch(a=np.float64(1.0)).is_empty()
False
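The difference between the shallow and the recursive test can be sketched on nested dicts. This is only an illustration of the rule described above; `is_empty` here is a hypothetical free function that treats a `dict` as a stand-in for a Batch and any other value as data.

```python
def is_empty(node, recurse=False):
    """Sketch of the emptiness test on nested dicts (dict stands in for Batch)."""
    if not isinstance(node, dict):
        return False          # a leaf value (scalar, array, ...) counts as data
    if len(node) == 0:
        return True           # no keys at all
    if not recurse:
        return False          # shallow test: any key means non-empty
    # Recursive test: empty only if every value is itself recursively empty.
    return all(is_empty(v, recurse=True) for v in node.values())
```

A structure made only of nested empty dicts is non-empty under the shallow test and empty under the recursive one, matching the doctest above.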
-
pop
(k: str, d: Optional[Any] = None) → Any[source]¶ Return and remove self[k] if k in self else d. d defaults to None.
-
property
shape
¶ Return self.shape.
-
split
(size: int, shuffle: bool = True, merge_last: bool = False) → Iterator[tianshou.data.batch.Batch][source]¶ Split whole data into multiple small batches.
- Parameters
size (int) – divide the data batch into batches of the given size; yield a single batch if the length of the data is smaller than size.
shuffle (bool) – randomly shuffle the entire data batch if it is True, otherwise keep the original order. Defaults to True.
merge_last (bool) – merge the last batch into the previous one. Defaults to False.
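How `size`, `shuffle`, and `merge_last` interact is easier to see on bare indices. The generator below is a sketch under stated assumptions, not the library code: `split_indices` and its `rng` argument are hypothetical, and it assumes that `merge_last` only matters when the length is not a multiple of `size`.

```python
import random


def split_indices(length, size, shuffle=True, merge_last=False, rng=None):
    """Yield lists of indices of at most `size` items (hypothetical helper)."""
    indices = list(range(length))
    if shuffle:
        (rng or random).shuffle(indices)
    # Merging is only meaningful when a short trailing chunk exists.
    merge_last = merge_last and length % size > 0
    for start in range(0, length, size):
        if merge_last and start + 2 * size >= length:
            yield indices[start:]      # fold the short tail into this chunk
            break
        yield indices[start:start + size]
```

For example, `list(split_indices(10, 3, shuffle=False))` ends with a chunk of one index, while adding `merge_last=True` makes the last chunk hold four indices instead.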
-
static
stack
(batches: List[Union[dict, Batch]], axis: int = 0) → tianshou.data.batch.Batch[source]¶ Stack a list of
Batch
objects into a single new batch. For keys that are not shared across all batches, batches that do not have these keys will be padded with zeros. E.g.
>>> a = Batch(a=np.zeros([4, 4]), common=Batch(c=np.zeros([4, 5])))
>>> b = Batch(b=np.zeros([4, 6]), common=Batch(c=np.zeros([4, 5])))
>>> c = Batch.stack([a, b])
>>> c.a.shape
(2, 4, 4)
>>> c.b.shape
(2, 4, 6)
>>> c.common.c.shape
(2, 4, 5)
Note
If there are keys that are not shared across all batches, stack with axis != 0 is undefined and will cause an exception.
-
stack_
(batches: List[Union[dict, Batch]], axis: int = 0) → None[source]¶ Stack a list of
Batch
objects into the current batch.
-
to_torch
(dtype: Optional[torch.dtype] = None, device: Union[str, int, torch.device] = 'cpu') → None[source]¶ Change all numpy.ndarray to torch.Tensor. This is an in-place operation.
-
-
class
tianshou.data.
Collector
(policy: tianshou.policy.base.BasePolicy, env: Union[gym.core.Env, tianshou.env.venvs.BaseVectorEnv], buffer: Optional[tianshou.data.buffer.ReplayBuffer] = None, preprocess_fn: Callable[[Any], tianshou.data.batch.Batch] = None, action_noise: Optional[tianshou.exploration.random.BaseNoise] = None, reward_metric: Optional[Callable[[numpy.ndarray], float]] = None)[source]¶ Bases:
object
The
Collector
enables the policy to interact with different types of environments conveniently.
- Parameters
policy – an instance of the BasePolicy class.
env – a gym.Env environment or an instance of the BaseVectorEnv class.
buffer – an instance of the ReplayBuffer class. If set to None (testing phase), it will not store the data.
preprocess_fn (function) – a function called before the data is added to the buffer, see issue #42 and Handle Batched Data Stream in Collector, defaults to None.
action_noise (BaseNoise) – add noise to continuous actions. Normally a policy already has a noise parameter for exploration in the training phase, so this option is mainly intended for use in a test collector.
reward_metric (function) – to be used in multi-agent RL. The reward to report is of shape [agent_num], but we need to return a single scalar to monitor training. This function specifies the desired metric, e.g., the reward of agent 1 or the average reward over all agents. By default, the reward of agent 1 is selected.
The preprocess_fn is a function called, with data in batch format, before the data is added to the buffer. It receives up to 7 keys as listed in Batch, and receives only obs when the collector resets the environment. It returns either a dict or a Batch with the modified keys and values. Examples are in “test/base/test_collector.py”.
Example:
policy = PGPolicy(...)  # or other policies if you wish
env = gym.make('CartPole-v0')
replay_buffer = ReplayBuffer(size=10000)
# here we set up a collector with a single environment
collector = Collector(policy, env, buffer=replay_buffer)
# the collector supports vectorized environments as well
envs = DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(3)])
collector = Collector(policy, envs, buffer=replay_buffer)
# collect 3 episodes
collector.collect(n_episode=3)
# collect 1 episode for the first env, 3 for the third env
collector.collect(n_episode=[1, 0, 3])
# collect at least 2 steps
collector.collect(n_step=2)
# collect episodes with visual rendering (the render argument is the
# sleep time between rendering consecutive frames)
collector.collect(n_episode=1, render=0.03)
Collected data always consist of full episodes. So if only the n_step argument is given, the collector may return more data than the n_step limit. The same holds for n_episode in the multiple-environment case.
Note
Please make sure the given environment has a time limit.
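The two callables accepted by the Collector, preprocess_fn and reward_metric, can be sketched as plain functions. Both functions below are hypothetical illustrations. The sketch assumes that preprocess_fn receives the available keys as keyword arguments and may return an empty dict when it has nothing to modify; check “test/base/test_collector.py” for the calling convention actually used.

```python
def scale_reward_preprocess_fn(**kwargs):
    """Hypothetical preprocess_fn: halve the rewards when 'rew' is present.

    Assumption: keys such as obs/act/rew/done are passed as keyword
    arguments, and only 'obs' is passed right after an environment reset.
    """
    if "rew" in kwargs:
        return {"rew": [0.5 * r for r in kwargs["rew"]]}
    return {}  # reset case: nothing to modify


def mean_reward_metric(rew):
    """Hypothetical reward_metric: average the per-agent rewards to a scalar."""
    return float(sum(rew) / len(rew))
```

Under these assumptions they would be passed as `Collector(policy, env, preprocess_fn=scale_reward_preprocess_fn, reward_metric=mean_reward_metric)`, using the parameter names from the signature above.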
-
collect
(n_step: Optional[int] = None, n_episode: Optional[Union[int, List[int]]] = None, random: bool = False, render: Optional[float] = None, no_grad: bool = True) → Dict[str, float][source]¶ Collect a specified number of steps or episodes.
- Parameters
n_step (int) – how many steps you want to collect.
n_episode – how many episodes you want to collect. If it is an int, it means to collect at least n_episode episodes; if it is a list, it means to collect exactly n_episode[i] episodes in the i-th environment.
random (bool) – whether to use a random policy for collecting data, defaults to False.
render (float) – the sleep time between rendering consecutive frames, defaults to None (no rendering).
no_grad (bool) – whether to disable gradient retention in policy.forward, defaults to True (no gradient retaining).
Note
One and only one collection number specification is permitted, either n_step or n_episode.
- Returns
A dict including the following keys:
n/ep – the collected number of episodes.
n/st – the collected number of steps.
v/st – the speed of steps per second.
v/ep – the speed of episodes per second.
rew – the mean reward over collected episodes.
len – the mean length over collected episodes.
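The relationship between the returned keys can be written out as a small calculation. `summarize_collection` is a hypothetical helper, not a Tianshou function; it assumes at least one finished episode and a positive elapsed time.

```python
def summarize_collection(episode_rewards, episode_lengths, elapsed_seconds):
    """Build a dict with the same keys as described above (illustrative only)."""
    n_ep = len(episode_rewards)
    n_st = sum(episode_lengths)
    return {
        "n/ep": n_ep,                            # collected episodes
        "n/st": n_st,                            # collected steps
        "v/st": n_st / elapsed_seconds,          # steps per second
        "v/ep": n_ep / elapsed_seconds,          # episodes per second
        "rew": sum(episode_rewards) / n_ep,      # mean episode reward
        "len": n_st / n_ep,                      # mean episode length
    }
```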
-
reset_env
() → None[source]¶ Reset all of the environment(s)’ states and reset all of the cache buffers (if needed).
-
sample
(batch_size: int) → tianshou.data.batch.Batch[source]¶ Sample a data batch from the internal replay buffer. It will call
process_fn()
before returning the final batch data.
- Parameters
batch_size (int) – 0 means it will extract all the data from the buffer; otherwise it will extract the data with the given batch_size.
-
class
tianshou.data.
ListReplayBuffer
(**kwargs)[source]¶ Bases:
tianshou.data.buffer.ReplayBuffer
The function of ListReplayBuffer is almost the same as ReplayBuffer. The only difference is that ListReplayBuffer is based on list. Therefore, it does not support advanced indexing, which means you cannot sample a batch of data out of it. It is typically used for storing data.
See also
Please refer to ReplayBuffer for a more detailed explanation.
-
class
tianshou.data.
PrioritizedReplayBuffer
(size: int, alpha: float, beta: float, **kwargs)[source]¶ Bases:
tianshou.data.buffer.ReplayBuffer
Implementation of Prioritized Experience Replay. arXiv:1511.05952
- Parameters
alpha (float) – the prioritization exponent.
beta (float) – the importance sample soft coefficient.
See also
Please refer to
ReplayBuffer
for more detailed explanation.
-
__getitem__
(index: Union[slice, int, numpy.integer, numpy.ndarray]) → tianshou.data.batch.Batch[source]¶ Return a data batch: self[index]. If stack_num is larger than 1, return the stacked obs and obs_next with shape [batch, len, …].
-
add
(obs: Union[dict, tianshou.data.batch.Batch, numpy.ndarray, float], act: Union[dict, tianshou.data.batch.Batch, numpy.ndarray, float], rew: Union[int, float], done: Union[bool, int], obs_next: Optional[Union[dict, tianshou.data.batch.Batch, numpy.ndarray, float]] = None, info: Optional[Union[dict, tianshou.data.batch.Batch]] = {}, policy: Optional[Union[dict, tianshou.data.batch.Batch]] = {}, weight: Optional[float] = None, **kwargs) → None[source]¶ Add a batch of data into replay buffer.
-
sample
(batch_size: int) → Tuple[tianshou.data.batch.Batch, numpy.ndarray][source]¶ Get a random sample from buffer with priority probability. Return all the data in the buffer if batch_size is
0.
- Returns
Sample data and its corresponding index inside the buffer.
The weight in the returned Batch is the weight on the loss function used to de-bias the sampling process (some transition tuples are sampled more often, so their losses are weighted less).
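The roles of alpha and beta can be made concrete with a sketch that follows the formulation in the Prioritized Experience Replay paper (arXiv:1511.05952): sampling probability proportional to priority^alpha, and importance weight (N · P(i))^(-beta) scaled by the largest weight. This is an assumption about the general technique, not a description of Tianshou's exact normalization. All function names are hypothetical, and priorities are assumed to be strictly positive.

```python
import random


def per_probabilities(priorities, alpha):
    """P(i) = p_i^alpha / sum_k p_k^alpha."""
    scaled = [p ** alpha for p in priorities]
    total = sum(scaled)
    return [s / total for s in scaled]


def per_weights(probs, beta):
    """w_i = (N * P(i))^(-beta), divided by the largest weight."""
    n = len(probs)
    raw = [(n * p) ** (-beta) for p in probs]
    top = max(raw)
    return [w / top for w in raw]


def per_sample(priorities, batch_size, alpha, beta, rng=random):
    """Draw indices by priority and return them with their loss weights."""
    probs = per_probabilities(priorities, alpha)
    weights = per_weights(probs, beta)
    indices = rng.choices(range(len(priorities)), weights=probs, k=batch_size)
    return indices, [weights[i] for i in indices]
```

With priorities `[1.0, 3.0]`, `alpha=1` and `beta=1`, the second transition is sampled three times as often and its loss weight is one third of the first one's, which is the de-biasing effect described above.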
-
class
tianshou.data.
ReplayBuffer
(size: int, stack_num: int = 1, ignore_obs_next: bool = False, save_only_last_obs: bool = False, sample_avail: bool = False)[source]¶ Bases:
object
ReplayBuffer stores data generated from interaction between the policy and environment. The current implementation of Tianshou typically uses 7 reserved keys in Batch:
obs – the observation of step \(t\);
act – the action of step \(t\);
rew – the reward of step \(t\);
done – the done flag of step \(t\);
obs_next – the observation of step \(t+1\);
info – the info of step \(t\) (in gym.Env, the env.step() function returns 4 values, and the last one is info);
policy – the data computed by policy in step \(t\).
The following code snippet illustrates its usage:
>>> import pickle, numpy as np
>>> from tianshou.data import ReplayBuffer
>>> buf = ReplayBuffer(size=20)
>>> for i in range(3):
...     buf.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
>>> buf.obs  # since we set size = 20, len(buf.obs) == 20.
array([0., 1., 2., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0.])
>>> # but there are only three valid items, so len(buf) == 3.
>>> len(buf)
3
>>> pickle.dump(buf, open('buf.pkl', 'wb'))  # save to file "buf.pkl"
>>> buf2 = ReplayBuffer(size=10)
>>> for i in range(15):
...     buf2.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
>>> len(buf2)
10
>>> buf2.obs  # since its size = 10, it only stores the last 10 steps' result.
array([10., 11., 12., 13., 14., 5., 6., 7., 8., 9.])
>>> # move buf2's result into buf (meanwhile keep it chronologically)
>>> buf.update(buf2)
array([ 0., 1., 2., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14.,
        0., 0., 0., 0., 0., 0., 0.])
>>> # get a random sample from buffer
>>> # the batch_data is equal to buf[indice].
>>> batch_data, indice = buf.sample(batch_size=4)
>>> batch_data.obs == buf[indice].obs
array([ True, True, True, True])
>>> len(buf)
13
>>> buf = pickle.load(open('buf.pkl', 'rb'))  # load from "buf.pkl"
>>> len(buf)
3
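The overwrite behaviour seen in buf2 above (only the last 10 of 15 steps survive, stored out of chronological order) is that of a circular buffer. A minimal pure-Python sketch of that idea follows; `RingBuffer` and its methods are hypothetical names, and it stores a single float per step instead of the 7 reserved keys.

```python
class RingBuffer:
    """Fixed-size circular store (illustrative, not Tianshou's ReplayBuffer)."""

    def __init__(self, size):
        self.maxsize = size
        self.data = [0.0] * size
        self.index = 0   # next write position
        self.count = 0   # number of valid items

    def add(self, item):
        self.data[self.index] = float(item)
        self.index = (self.index + 1) % self.maxsize   # wrap around
        self.count = min(self.count + 1, self.maxsize)

    def __len__(self):
        return self.count

    def chronological(self):
        """Return the valid items from oldest to newest."""
        if self.count < self.maxsize:
            return self.data[:self.count]
        return self.data[self.index:] + self.data[:self.index]
```

Adding the values 0 to 14 to a `RingBuffer(10)` leaves `data` as `[10, 11, 12, 13, 14, 5, 6, 7, 8, 9]` (as floats), the same layout as `buf2.obs`, while `chronological()` returns 5 through 14 in order.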
ReplayBuffer also supports frame_stack sampling (typically for RNN usage, see issue #19), skipping storage of the next observation (to save memory in Atari tasks), and multi-modal observations (see issue #38):
>>> buf = ReplayBuffer(size=9, stack_num=4, ignore_obs_next=True)
>>> for i in range(16):
...     done = i % 5 == 0
...     buf.add(obs={'id': i}, act=i, rew=i, done=done,
...             obs_next={'id': i + 1})
>>> print(buf)  # you can see obs_next is not saved in buf
ReplayBuffer(
    act: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]),
    done: array([0., 1., 0., 0., 0., 0., 1., 0., 0.]),
    info: Batch(),
    obs: Batch(
             id: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]),
         ),
    policy: Batch(),
    rew: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]),
)
>>> index = np.arange(len(buf))
>>> print(buf.get(index, 'obs').id)
[[ 7. 7. 8. 9.]
 [ 7. 8. 9. 10.]
 [11. 11. 11. 11.]
 [11. 11. 11. 12.]
 [11. 11. 12. 13.]
 [11. 12. 13. 14.]
 [12. 13. 14. 15.]
 [ 7. 7. 7. 7.]
 [ 7. 7. 7. 8.]]
>>> # here is another way to get the stacked data
>>> # (stack only for obs and obs_next)
>>> abs(buf.get(index, 'obs')['id'] - buf[index].obs.id).sum().sum()
0.0
>>> # we can get obs_next through __getitem__, even if it doesn't exist
>>> print(buf[:].obs_next.id)
[[ 7. 8. 9. 10.]
 [ 7. 8. 9. 10.]
 [11. 11. 11. 12.]
 [11. 11. 12. 13.]
 [11. 12. 13. 14.]
 [12. 13. 14. 15.]
 [12. 13. 14. 15.]
 [ 7. 7. 7. 8.]
 [ 7. 7. 8. 9.]]
- Parameters
size (int) – the size of the replay buffer.
stack_num (int) – the frame-stack sampling argument, should be greater than or equal to 1, defaults to 1 (no stacking).
ignore_obs_next (bool) – whether to store obs_next, defaults to False.
save_only_last_obs (bool) – only save the last obs/obs_next when it has a shape of (timestep, …) because of temporal stacking, defaults to False.
sample_avail (bool) – whether to sample only available indices when using the frame-stack sampling method, defaults to False. This feature is currently not supported in the Prioritized Replay Buffer.
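The frame-stack output printed in the ReplayBuffer example above can be reproduced with a short index walk: step backwards through the circular storage, and stop moving once the previous slot is marked done, so that a stack never crosses an episode boundary. The function below is a sketch of that rule for a single index. `stacked_indices` is a hypothetical name and the code is not taken from Tianshou.

```python
def stacked_indices(index, done, stack_num):
    """Return the storage indices [t-stack_num+1, ..., t] for one stacked frame.

    `done` is the list of done flags in storage order; index -1 wraps
    to the end of the buffer.
    """
    size = len(done)
    out = []
    for _ in range(stack_num):
        out.insert(0, index)
        prev = (index - 1) % size
        # If the previous slot ended an episode, repeat the current frame.
        index = index if done[prev] else prev
    return out


# Storage layout copied from the printed buffer above.
obs_id = [9, 10, 11, 12, 13, 14, 15, 7, 8]
done = [0, 1, 0, 0, 0, 0, 1, 0, 0]
stacked = [[obs_id[j] for j in stacked_indices(i, done, 4)] for i in range(9)]
```

Each row of `stacked` corresponds to one row of the `buf.get(index, 'obs').id` output shown above.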
-
__getitem__
(index: Union[slice, int, numpy.integer, numpy.ndarray]) → tianshou.data.batch.Batch[source]¶ Return a data batch: self[index]. If stack_num is larger than 1, return the stacked obs and obs_next with shape [batch, len, …].
-
add
(obs: Union[dict, tianshou.data.batch.Batch, numpy.ndarray, float], act: Union[dict, tianshou.data.batch.Batch, numpy.ndarray, float], rew: Union[int, float], done: Union[bool, int], obs_next: Optional[Union[dict, tianshou.data.batch.Batch, numpy.ndarray, float]] = None, info: Optional[Union[dict, tianshou.data.batch.Batch]] = {}, policy: Optional[Union[dict, tianshou.data.batch.Batch]] = {}, **kwargs) → None[source]¶ Add a batch of data into replay buffer.
-
get
(indice: Union[slice, int, numpy.integer, numpy.ndarray], key: str, stack_num: Optional[int] = None) → Union[tianshou.data.batch.Batch, numpy.ndarray][source]¶ Return the stacked result, e.g. [s_{t-3}, s_{t-2}, s_{t-1}, s_t], where s is self.key and t is indice. The stack_num (here equal to 4) is given at buffer initialization.
-
sample
(batch_size: int) → Tuple[tianshou.data.batch.Batch, numpy.ndarray][source]¶ Get a random sample from buffer with size equal to batch_size. Return all the data in the buffer if batch_size is
0.
- Returns
Sample data and its corresponding index inside the buffer.
-
property
stack_num
¶
-
class
tianshou.data.
SegmentTree
(size: int)[source]¶ Bases:
object
Implementation of Segment Tree: store an array arr with size n in a segment tree, supporting value updates and fast queries of the sum over the interval [left, right) in O(log n) time.
The detailed procedure is as follows:
1. Pad the array to a length that is a power of 2, so that leaf nodes in the segment tree have the same depth.
2. Store the segment tree in a binary heap.
- Parameters
size (int) – the size of segment tree.
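The two-step procedure (pad to a power of two, store in a binary heap layout) can be sketched in pure Python as follows. `SumTree` and `range_sum` are hypothetical names chosen for this illustration; the sketch handles scalar indices only and is not the library's implementation.

```python
class SumTree:
    """Sum segment tree stored in a flat, heap-ordered list (illustrative)."""

    def __init__(self, size):
        bound = 1
        while bound < size:
            bound *= 2                     # pad leaf count to a power of two
        self.size = size
        self.bound = bound
        # tree[1] is the root; node i has children 2i and 2i+1;
        # the leaves occupy tree[bound : 2 * bound].
        self.tree = [0.0] * (2 * bound)

    def __getitem__(self, index):
        return self.tree[index + self.bound]

    def __setitem__(self, index, value):
        i = index + self.bound
        self.tree[i] = value
        i //= 2
        while i >= 1:                      # refresh ancestors: O(log n)
            self.tree[i] = self.tree[2 * i] + self.tree[2 * i + 1]
            i //= 2

    def range_sum(self, left, right):
        """Sum of arr[left:right] in O(log n)."""
        total = 0.0
        lo, hi = left + self.bound, right + self.bound
        while lo < hi:
            if lo & 1:                     # lo is a right child: take it
                total += self.tree[lo]
                lo += 1
            if hi & 1:                     # hi - 1 is a left child: take it
                hi -= 1
                total += self.tree[hi]
            lo //= 2
            hi //= 2
        return total
```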
-
__getitem__
(index: Union[int, numpy.ndarray]) → Union[float, numpy.ndarray][source]¶ Return self[index]
-
__setitem__
(index: Union[int, numpy.ndarray], value: Union[float, numpy.ndarray]) → None[source]¶ Duplicate values in
index
are handled by numpy: a later index overwrites earlier ones.
>>> a = np.array([1, 2, 3, 4])
>>> a[[0, 1, 0, 1]] = [4, 5, 6, 7]
>>> print(a)
[6 7 3 4]
-
get_prefix_sum_idx
(value: Union[float, numpy.ndarray]) → Union[int, numpy.ndarray][source]¶ Return the minimum index for each
v in value so that \(v \le \mathrm{sums}_i\), where \(\mathrm{sums}_i = \sum_{j=0}^{i} \mathrm{arr}_j\).
Warning
Please make sure all of the values inside the segment tree are non-negative when using this function.
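The prefix-sum lookup defined above can be done by descending a heap-ordered sum tree from the root: go left when the left subtree's sum already reaches v, otherwise subtract that sum and go right. The sketch below handles a single scalar v; `build_sum_tree` and `prefix_sum_idx` are hypothetical helpers, and the non-negativity requirement from the warning is assumed to hold.

```python
def build_sum_tree(arr):
    """Return (tree, bound) for a heap-ordered sum tree over `arr`."""
    bound = 1
    while bound < len(arr):
        bound *= 2
    tree = [0.0] * (2 * bound)
    tree[bound:bound + len(arr)] = [float(x) for x in arr]
    for i in range(bound - 1, 0, -1):      # fill internal nodes bottom-up
        tree[i] = tree[2 * i] + tree[2 * i + 1]
    return tree, bound


def prefix_sum_idx(tree, bound, v):
    """Smallest i with v <= arr[0] + ... + arr[i], in O(log n)."""
    i = 1
    while i < bound:
        i *= 2                             # move to the left child
        if tree[i] < v:                    # left subtree cannot reach v
            v -= tree[i]
            i += 1                         # continue in the right child
    return i - bound
```

For `arr = [1, 2, 3, 4]` the prefix sums are 1, 3, 6, 10, so a query value of 3 maps to index 1 and 3.5 maps to index 2. Drawing v uniformly from (0, total] therefore selects index i with probability proportional to `arr[i]`, which is the usual reason a prioritized buffer relies on such a lookup.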
-
tianshou.data.
to_numpy
(x: Union[tianshou.data.batch.Batch, dict, list, tuple, numpy.ndarray, torch.Tensor]) → Union[tianshou.data.batch.Batch, dict, list, tuple, numpy.ndarray, torch.Tensor][source]¶ Return an object without torch.Tensor.
-
tianshou.data.
to_torch
(x: Union[tianshou.data.batch.Batch, dict, list, tuple, numpy.ndarray, torch.Tensor], dtype: Optional[torch.dtype] = None, device: Union[str, int, torch.device] = 'cpu') → Union[tianshou.data.batch.Batch, dict, list, tuple, numpy.ndarray, torch.Tensor][source]¶ Return an object without np.ndarray.
-
tianshou.data.
to_torch_as
(x: Union[tianshou.data.batch.Batch, dict, list, tuple, numpy.ndarray, torch.Tensor], y: torch.Tensor) → Union[tianshou.data.batch.Batch, dict, list, tuple, numpy.ndarray, torch.Tensor][source]¶ Return an object without np.ndarray. Same as
to_torch(x, dtype=y.dtype, device=y.device)
.