batch


class Batch(batch_dict: dict | BatchProtocol | Sequence[dict | BatchProtocol] | ndarray | None = None, copy: bool = False, **kwargs: Any)

The internal data structure in Tianshou.

Batch is a kind of supercharged array (of temporal data) whose entries are stored in a recursive dictionary: each value is a numpy array, a torch tensor, or another Batch. It is designed to make it easy to access, manipulate, and set partial views of heterogeneous data.

For a detailed description, please refer to Understand Batch.
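The recursive-dictionary idea can be illustrated without Tianshou. This is a minimal sketch that assumes plain nested dicts of numpy arrays as a stand-in for Batch; `index_nested` is a hypothetical helper, not part of Tianshou. It shows how applying one index to every leaf yields a consistent partial view of heterogeneous data.

```python
from typing import Any

import numpy as np


def index_nested(data: dict[str, Any], index: Any) -> dict[str, Any]:
    """Hypothetical helper: apply one index to every leaf of a nested dict."""
    out: dict[str, Any] = {}
    for key, value in data.items():
        if isinstance(value, dict):
            out[key] = index_nested(value, index)  # recurse into sub-"batches"
        else:
            out[key] = value[index]  # numpy basic slicing returns a view
    return out


data = {
    "obs": np.arange(8).reshape(4, 2),
    "info": {"reward": np.array([0.0, 1.0, 0.5, 2.0])},
}
first_two = index_nested(data, slice(0, 2))
print(first_two["obs"].shape)       # (2, 2)
print(first_two["info"]["reward"])  # [0. 1.]
```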

static cat(batches: Sequence[dict | TBatch]) → TBatch

Concatenate a list of Batch objects into a single new batch.

For keys that are not shared across all batches, batches that do not have these keys are padded with zeros of the appropriate shape. For example:

>>> a = Batch(a=np.zeros([3, 4]), common=Batch(c=np.zeros([3, 5])))
>>> b = Batch(b=np.zeros([4, 3]), common=Batch(c=np.zeros([4, 5])))
>>> c = Batch.cat([a, b])
>>> c.a.shape
(7, 4)
>>> c.b.shape
(7, 3)
>>> c.common.c.shape
(7, 5)
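A minimal sketch of the zero-padding rule, assuming flat dicts of numpy arrays instead of Batch objects. `cat_with_padding` is a hypothetical helper, not Tianshou's implementation; it borrows the trailing shape and dtype for each key from the first input that has that key.

```python
import numpy as np


def cat_with_padding(batches: list[dict[str, np.ndarray]]) -> dict[str, np.ndarray]:
    """Hypothetical: concatenate along axis 0, zero-padding missing keys."""
    # Assumes each dict is non-empty and all of its arrays share one length.
    lengths = [len(next(iter(b.values()))) for b in batches]
    keys = {k for b in batches for k in b}
    out: dict[str, np.ndarray] = {}
    for key in sorted(keys):
        # Template for trailing shape and dtype: first batch that has this key.
        ref = next(b[key] for b in batches if key in b)
        parts = [
            b[key] if key in b else np.zeros((n, *ref.shape[1:]), dtype=ref.dtype)
            for b, n in zip(batches, lengths)
        ]
        out[key] = np.concatenate(parts, axis=0)
    return out


a = {"a": np.ones((3, 4))}
b = {"b": np.ones((4, 3))}
c = cat_with_padding([a, b])
print(c["a"].shape, c["b"].shape)  # (7, 4) (7, 3)
print(c["a"][3:].sum())            # 0.0: rows coming from b are zero padding
```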

cat_(batches: BatchProtocol | Sequence[dict | BatchProtocol]) → None

Concatenate a list of Batch objects (or a single one) into the current batch in place.

static empty(batch: TBatch, index: slice | int | ndarray | list[int] | None = None) → TBatch

Return an empty Batch object filled with 0 or None.

Its shape is the same as that of the given Batch.

empty_(index: slice | int | ndarray | list[int] | None = None) → Self

Reset this Batch in place, filling it with 0 or None, and return it.

If index is specified, only the data at that index is reset.

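>>> data = Batch(a=np.array([[1., 2.], [3., 4.]]), b=np.array(['x', 'y'], dtype=object))
>>> _ = data.empty_()  # resets data in place (and also returns it)
>>> print(data)
Batch(
    a: array([[0., 0.],
              [0., 0.]]),
    b: array([None, None], dtype=object),
)
>>> b = {'c': [2., 'st'], 'd': [1., 0.]}
>>> data = Batch(a=[False, True], b=b)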
>>> data[0] = Batch.empty(data[1])
>>> data
Batch(
    a: array([False,  True]),
    b: Batch(
           c: array([None, 'st']),
           d: array([0., 0.]),
       ),
)
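The fill rule (0 or None, optionally only at an index) can be sketched on plain nested dicts. `empty_inplace` is hypothetical and assumes every leaf is a bool, numeric, or object numpy array; it reproduces the second example above with the same data.

```python
from typing import Any

import numpy as np


def empty_inplace(data: dict[str, Any], index: Any = None) -> dict[str, Any]:
    """Hypothetical: bool -> False, numeric -> 0, object -> None, in place."""
    sel = slice(None) if index is None else index  # None means "everything"
    for value in data.values():
        if isinstance(value, dict):
            empty_inplace(value, index)  # recurse into nested "batches"
        elif value.dtype == np.bool_:
            value[sel] = False
        elif value.dtype == object:
            value[sel] = None
        else:
            value[sel] = 0  # assumes a numeric dtype
    return data


data = {
    "a": np.array([False, True]),
    "b": {"c": np.array([2.0, "st"], dtype=object), "d": np.array([1.0, 0.0])},
}
empty_inplace(data, index=0)
print(data["b"]["c"])  # [None 'st']
print(data["b"]["d"])  # [0. 0.]
```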

is_empty(recurse: bool = False) → bool

Test if a Batch is empty.

If recurse=True, it also tests the values recursively; otherwise it only tests whether any key exists.

b.is_empty(recurse=True) is mainly used to distinguish Batch(a=Batch(a=Batch())) from Batch(a=1). Both raise exceptions when passed to len(), but the former can be used in cat, while the latter is a scalar and cannot.

It is also used in __len__, where the length check must skip recursively empty Batch objects.

>>> Batch().is_empty()
True
>>> Batch(a=Batch(), b=Batch(c=Batch())).is_empty()
False
>>> Batch(a=Batch(), b=Batch(c=Batch())).is_empty(recurse=True)
True
>>> Batch(d=1).is_empty()
False
>>> Batch(a=np.float64(1.0)).is_empty()
False
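The difference between shallow and recursive emptiness can be sketched with nested dicts standing in for Batch. This `is_empty` is a hypothetical stand-alone function that mirrors the doctests above, not Tianshou's method.

```python
from typing import Any


def is_empty(data: dict[str, Any], recurse: bool = False) -> bool:
    """Hypothetical stand-alone version of the shallow/recursive check."""
    if not recurse:
        return len(data) == 0  # only checks whether any key exists
    # Recursive: empty only if every value is itself a recursively empty dict.
    return all(isinstance(v, dict) and is_empty(v, recurse=True) for v in data.values())


print(is_empty({"a": {}, "b": {"c": {}}}))                # False
print(is_empty({"a": {}, "b": {"c": {}}}, recurse=True))  # True
```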

property shape: list[int]

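Return the shape of the batch: the element-wise minimum of the shapes of its values, or an empty list for an empty Batch.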
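Assuming the property reports the element-wise minimum over the shapes of all values (and an empty list when there are none), the rule can be sketched as follows. `batch_shape` is hypothetical, and values without a shape attribute are treated as scalars.

```python
from typing import Any

import numpy as np


def batch_shape(values: list[Any]) -> list[int]:
    """Hypothetical: element-wise minimum over the shapes of all values."""
    if not values:
        return []  # an empty batch has no shape
    # Values without a `shape` attribute (e.g. Python scalars) count as shape ().
    shapes = [list(getattr(v, "shape", ())) for v in values]
    if len(shapes) == 1:
        return shapes[0]
    # zip() truncates to the shortest rank; min() picks the smallest size per axis.
    return [min(dims) for dims in zip(*shapes)]


print(batch_shape([np.zeros((4, 3)), np.zeros((4, 5, 2))]))  # [4, 3]
```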

split(size: int, shuffle: bool = True, merge_last: bool = False) → Iterator[Self]

Split the whole batch into multiple smaller batches.

Parameters:
  • size – the size of each smaller batch. If the batch is shorter than size, it is yielded as a single batch. A size of -1 means the whole batch.

  • shuffle – if True, randomly shuffle the entire batch before splitting; otherwise keep the original order. Defaults to True.

  • merge_last – if True, merge the last (shorter) batch into the previous one. Defaults to False.
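The interaction of size, shuffle, and merge_last can be sketched over plain index arrays. `split_indices` is a hypothetical helper; its seed parameter is an addition for reproducibility, and the merge condition (merge only when the final chunk would be shorter than size) is an assumption about how merge_last behaves.

```python
from collections.abc import Iterator
from typing import Optional

import numpy as np


def split_indices(length: int, size: int, shuffle: bool = True,
                  merge_last: bool = False,
                  seed: Optional[int] = None) -> Iterator[np.ndarray]:
    """Hypothetical: yield index chunks following the split() rules."""
    if size == -1:
        size = length  # -1 means "the whole batch"
    assert size >= 1
    rng = np.random.default_rng(seed)
    indices = rng.permutation(length) if shuffle else np.arange(length)
    # Only merge when the final chunk would actually be shorter than `size`.
    merge_last = merge_last and length % size > 0
    for start in range(0, length, size):
        if merge_last and start + 2 * size >= length:
            yield indices[start:]  # current chunk plus the short remainder
            break
        yield indices[start:start + size]


for chunk in split_indices(10, 4, shuffle=False, merge_last=True):
    print(chunk.tolist())
# [0, 1, 2, 3]
# [4, 5, 6, 7, 8, 9]
```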

static stack(batches: Sequence[dict | TBatch], axis: int = 0) → TBatch

Stack a list of Batch objects into a single new batch.

For keys that are not shared across all batches, batches that do not have these keys are padded with zeros. For example:

>>> a = Batch(a=np.zeros([4, 4]), common=Batch(c=np.zeros([4, 5])))
>>> b = Batch(b=np.zeros([4, 6]), common=Batch(c=np.zeros([4, 5])))
>>> c = Batch.stack([a, b])
>>> c.a.shape
(2, 4, 4)
>>> c.b.shape
(2, 4, 6)
>>> c.common.c.shape
(2, 4, 5)

Note

If there are keys that are not shared across all batches, stacking with axis != 0 is undefined and raises an exception.
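A minimal sketch of stacking with zero padding, restricted to axis 0 in line with the note. `stack_with_padding` is hypothetical, works on flat dicts of numpy arrays, and assumes every input that has a key stores it with the same shape.

```python
import numpy as np


def stack_with_padding(batches: list[dict[str, np.ndarray]]) -> dict[str, np.ndarray]:
    """Hypothetical: stack along a new axis 0, zero-filling missing keys."""
    keys = sorted({k for b in batches for k in b})
    out: dict[str, np.ndarray] = {}
    for key in keys:
        ref = next(b[key] for b in batches if key in b)  # shape/dtype template
        parts = [b.get(key, np.zeros_like(ref)) for b in batches]
        out[key] = np.stack(parts, axis=0)
    return out


a = {"a": np.ones((4, 4)), "c": np.ones((4, 5))}
b = {"b": np.ones((4, 6)), "c": np.ones((4, 5))}
c = stack_with_padding([a, b])
print(c["a"].shape, c["b"].shape, c["c"].shape)  # (2, 4, 4) (2, 4, 6) (2, 4, 5)
```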

stack_(batches: Sequence[dict | BatchProtocol], axis: int = 0) → None

Stack a list of Batch objects into the current batch in place.

to_numpy() → None

Convert all torch.Tensor values to numpy.ndarray in place.

to_torch(dtype: dtype | None = None, device: str | int | device = 'cpu') → None

Convert all numpy.ndarray values to torch.Tensor in place.

update(batch: dict | Self | None = None, **kwargs: Any) → None

Update this batch from another dict or Batch, and/or from keyword arguments.

class BatchProtocol(*args, **kwargs)

The internal data structure in Tianshou.

Batch is a kind of supercharged array (of temporal data) whose entries are stored in a recursive dictionary: each value is a numpy array, a torch tensor, or another Batch. It is designed to make it easy to access, manipulate, and set partial views of heterogeneous data.

For a detailed description, please refer to Understand Batch.

static cat(batches: Sequence[dict | TBatch]) → TBatch

Concatenate a list of Batch objects into a single new batch.

For keys that are not shared across all batches, batches that do not have these keys are padded with zeros of the appropriate shape. For example:

>>> a = Batch(a=np.zeros([3, 4]), common=Batch(c=np.zeros([3, 5])))
>>> b = Batch(b=np.zeros([4, 3]), common=Batch(c=np.zeros([4, 5])))
>>> c = Batch.cat([a, b])
>>> c.a.shape
(7, 4)
>>> c.b.shape
(7, 3)
>>> c.common.c.shape
(7, 5)

cat_(batches: Self | Sequence[dict | Self]) → None

Concatenate a list of Batch objects (or a single one) into the current batch in place.

static empty(batch: TBatch, index: slice | int | ndarray | list[int] | None = None) → TBatch

Return an empty Batch object filled with 0 or None.

Its shape is the same as that of the given Batch.

empty_(index: slice | int | ndarray | list[int] | None = None) → Self

Reset this Batch in place, filling it with 0 or None, and return it.

If index is specified, only the data at that index is reset.

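>>> data = Batch(a=np.array([[1., 2.], [3., 4.]]), b=np.array(['x', 'y'], dtype=object))
>>> _ = data.empty_()  # resets data in place (and also returns it)
>>> print(data)
Batch(
    a: array([[0., 0.],
              [0., 0.]]),
    b: array([None, None], dtype=object),
)
>>> b = {'c': [2., 'st'], 'd': [1., 0.]}
>>> data = Batch(a=[False, True], b=b)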
>>> data[0] = Batch.empty(data[1])
>>> data
Batch(
    a: array([False,  True]),
    b: Batch(
           c: array([None, 'st']),
           d: array([0., 0.]),
       ),
)

is_empty(recurse: bool = False) → bool

property shape: list[int]

split(size: int, shuffle: bool = True, merge_last: bool = False) → Iterator[Self]

Split the whole batch into multiple smaller batches.

Parameters:
  • size – the size of each smaller batch. If the batch is shorter than size, it is yielded as a single batch. A size of -1 means the whole batch.

  • shuffle – if True, randomly shuffle the entire batch before splitting; otherwise keep the original order. Defaults to True.

  • merge_last – if True, merge the last (shorter) batch into the previous one. Defaults to False.

static stack(batches: Sequence[dict | TBatch], axis: int = 0) → TBatch

Stack a list of Batch objects into a single new batch.

For keys that are not shared across all batches, batches that do not have these keys are padded with zeros. For example:

>>> a = Batch(a=np.zeros([4, 4]), common=Batch(c=np.zeros([4, 5])))
>>> b = Batch(b=np.zeros([4, 6]), common=Batch(c=np.zeros([4, 5])))
>>> c = Batch.stack([a, b])
>>> c.a.shape
(2, 4, 4)
>>> c.b.shape
(2, 4, 6)
>>> c.common.c.shape
(2, 4, 5)

Note

If there are keys that are not shared across all batches, stacking with axis != 0 is undefined and raises an exception.

stack_(batches: Sequence[dict | Self], axis: int = 0) → None

Stack a list of Batch objects into the current batch in place.

to_numpy() → None

Convert all torch.Tensor values to numpy.ndarray in place.

to_torch(dtype: dtype | None = None, device: str | int | device = 'cpu') → None

Convert all numpy.ndarray values to torch.Tensor in place.

update(batch: dict | Self | None = None, **kwargs: Any) → None

Update this batch from another dict or Batch, and/or from keyword arguments.

alloc_by_keys_diff(meta: BatchProtocol, batch: BatchProtocol, size: int, stack: bool = True) → None

Create placeholders in meta for keys that are in batch but not in meta.

This is mainly an internal method; use it only if you know what you are doing.
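The key-difference step can be sketched on flat dicts of numpy arrays. `alloc_missing_keys` is hypothetical: it covers only the stack=True case and ignores nested batches and torch tensors, which the real function may handle differently.

```python
import numpy as np


def alloc_missing_keys(meta: dict[str, np.ndarray], batch: dict[str, np.ndarray],
                       size: int) -> None:
    """Hypothetical: add zero placeholders to meta for keys only in batch."""
    for key in batch.keys() - meta.keys():
        value = np.asarray(batch[key])
        # Stack-style placeholder: `size` rows, each shaped like the value.
        meta[key] = np.zeros((size, *value.shape), dtype=value.dtype)


meta = {"obs": np.zeros((10, 3))}
alloc_missing_keys(meta, {"obs": np.ones(3), "rew": np.float32(1.0)}, size=10)
print(sorted(meta), meta["rew"].shape, meta["rew"].dtype)  # ['obs', 'rew'] (10,) float32
```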

create_value(inst: Any, size: int, stack: bool = True) → Batch | ndarray | Tensor

Create empty placeholders shaped according to inst.

Parameters:
  • stack – whether to stack or to concatenate. For example, if inst has shape (3, 5) and size = 10, stack=True returns an np.array of shape (10, 3, 5); otherwise the result has shape (10, 5).
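The stack/concatenate shape rule from this example can be sketched for numpy inputs. `placeholder_like` is hypothetical and ignores Batch and torch.Tensor inputs, which create_value also covers according to its return type.

```python
import numpy as np


def placeholder_like(inst: np.ndarray, size: int, stack: bool = True) -> np.ndarray:
    """Hypothetical numpy-only version of the placeholder-shape rule."""
    inst = np.asarray(inst)
    # stack=True: add a new leading axis of length `size`.
    # stack=False: replace the existing leading axis with `size` (concatenation).
    shape = (size, *inst.shape) if stack else (size, *inst.shape[1:])
    if inst.dtype == object:
        return np.full(shape, None, dtype=object)  # object data gets None
    return np.zeros(shape, dtype=inst.dtype)       # bool -> False, numbers -> 0


print(placeholder_like(np.ones((3, 5)), 10, stack=True).shape)   # (10, 3, 5)
print(placeholder_like(np.ones((3, 5)), 10, stack=False).shape)  # (10, 5)
```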