collector#


class AsyncCollector(policy: BasePolicy, env: BaseVectorEnv, buffer: ReplayBuffer | None = None, preprocess_fn: Callable[[...], RolloutBatchProtocol] | None = None, exploration_noise: bool = False)[source]#

AsyncCollector handles asynchronous vector environments.

The constructor arguments are exactly the same as those of Collector; please refer to Collector for a more detailed explanation.

collect(n_step: int | None = None, n_episode: int | None = None, random: bool = False, render: float | None = None, no_grad: bool = True, gym_reset_kwargs: dict[str, Any] | None = None) CollectStats[source]#

Collect a specified number of steps or episodes in the async env setting.

This function does not collect exactly n_step or n_episode transitions. To support the async setting, it may collect more than the requested n_step or n_episode transitions and save them into the buffer.

Parameters:
  • n_step – how many steps you want to collect.

  • n_episode – how many episodes you want to collect.

  • random – whether to use a random policy to collect data. Defaults to False.

  • render – the sleep time between rendering consecutive frames. Defaults to None (no rendering).

  • no_grad – whether to run policy.forward() without tracking gradients. Defaults to True (no gradients are retained).

  • gym_reset_kwargs – extra keyword arguments to pass into the environment’s reset function. Defaults to None (no extra keyword arguments).

Note

Exactly one collection specification is permitted: either n_step or n_episode.

Returns:

A CollectStats dataclass object containing the statistics of this collection run.
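
As a minimal sketch of an async collection run (the policy instance, environment id, and buffer size below are placeholder assumptions, not prescribed by this API):

    import gymnasium as gym

    from tianshou.data import AsyncCollector, VectorReplayBuffer
    from tianshou.env import SubprocVectorEnv

    # wait_num < number of envs makes the vector env asynchronous:
    # step() returns as soon as 2 of the 4 workers have finished.
    envs = SubprocVectorEnv(
        [lambda: gym.make("CartPole-v1") for _ in range(4)], wait_num=2
    )
    buffer = VectorReplayBuffer(total_size=20000, buffer_num=4)
    collector = AsyncCollector(policy, envs, buffer)  # policy: any BasePolicy

    stats = collector.collect(n_step=100)
    # Due to the async setting, the collector may overshoot the budget:
    print(stats.n_collected_steps)  # may be larger than 100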

reset_env(gym_reset_kwargs: dict[str, Any] | None = None) None[source]#

Reset all of the environments.

class CollectStats(*, n_collected_episodes: int = 0, n_collected_steps: int = 0, collect_time: float = 0.0, collect_speed: float = 0.0, returns: ndarray, returns_stat: SequenceSummaryStats | None, lens: ndarray, lens_stat: SequenceSummaryStats | None)[source]#

A data structure for storing the statistics of rollouts.

collect_speed: float = 0.0#

The speed of collecting (env_step per second).

collect_time: float = 0.0#

The time (in seconds) spent collecting transitions.

lens: ndarray#

The collected episode lengths.

lens_stat: SequenceSummaryStats | None#

Stats of the collected episode lengths.

returns: ndarray#

The collected episode returns.

returns_stat: SequenceSummaryStats | None#

Stats of the collected returns.
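
For illustration, a short sketch of reading these fields after a collect() call (collector is assumed to be an already-constructed Collector):

    stats = collector.collect(n_episode=10)

    print(stats.n_collected_episodes)  # 10
    print(stats.collect_speed)         # env steps per second
    # The *_stat fields are None if no episode finished during the call:
    if stats.returns_stat is not None:
        print(stats.returns_stat.mean, stats.returns_stat.std)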

class CollectStatsBase(*, n_collected_episodes: int = 0, n_collected_steps: int = 0)[source]#

The most basic stats, often used for offline learning.

n_collected_episodes: int = 0#

The number of collected episodes.

n_collected_steps: int = 0#

The number of collected steps.

class Collector(policy: BasePolicy, env: Env | BaseVectorEnv, buffer: ReplayBuffer | None = None, preprocess_fn: Callable[[...], RolloutBatchProtocol] | None = None, exploration_noise: bool = False)[source]#

The Collector enables the policy to interact with different types of envs for an exact number of steps or episodes.

Parameters:
  • policy – an instance of the BasePolicy class.

  • env – a gym.Env environment or an instance of the BaseVectorEnv class.

  • buffer – an instance of the ReplayBuffer class. If set to None, no data will be stored. Defaults to None.

  • preprocess_fn (function) – a function called before the data is added to the buffer, see issue #42 and Handle Batched Data Stream in Collector. Defaults to None.

  • exploration_noise – whether the action should be modified with the policy’s exploration noise. If True, policy.exploration_noise(act, batch) will be called automatically to add exploration noise to the action. Defaults to False.

The “preprocess_fn” is a function called, with data in batch format, before the data is added to the buffer. It receives only “obs” and “env_id” when the collector resets the environment, and receives the keys “obs_next”, “rew”, “terminated”, “truncated”, “info”, “policy” and “env_id” on a normal env step. Alternatively, it may also accept the keys “obs_next”, “rew”, “done”, “info”, “policy” and “env_id”. It returns either a dict or a Batch with the modified keys and values. Examples can be found in “test/base/test_collector.py”.
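
A minimal sketch of such a function (the reward clipping is only an illustration, not part of the API): it must handle both the reset call, which passes only “obs” and “env_id”, and the normal step call.

    import numpy as np

    from tianshou.data import Batch

    def preprocess_fn(**kwargs) -> Batch:
        # On reset only "obs" and "env_id" arrive, so treat "rew" as optional.
        if "rew" in kwargs:
            return Batch(rew=np.clip(kwargs["rew"], -1.0, 1.0))
        return Batch()  # nothing to modify on reset

It would then be passed to the collector via the preprocess_fn argument described above.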

Note

Please make sure the given environment has a time limit when using the n_episode collect option.

Note

In past versions of Tianshou, the replay buffer that was passed to __init__ was automatically reset. This is not done in the current implementation.
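
As a hedged sketch of typical construction (policy stands for any BasePolicy instance built elsewhere; the env id and buffer sizes are placeholders):

    import gymnasium as gym

    from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
    from tianshou.env import DummyVectorEnv

    # A single env with a plain ReplayBuffer...
    collector = Collector(policy, gym.make("CartPole-v1"), ReplayBuffer(size=20000))

    # ...or a vector env, which needs a VectorReplayBuffer with one
    # sub-buffer per environment.
    envs = DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(4)])
    vec_collector = Collector(
        policy, envs, VectorReplayBuffer(total_size=20000, buffer_num=4)
    )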

collect(n_step: int | None = None, n_episode: int | None = None, random: bool = False, render: float | None = None, no_grad: bool = True, gym_reset_kwargs: dict[str, Any] | None = None) CollectStats[source]#

Collect a specified number of steps or episodes.

To ensure an unbiased sampling result with the n_episode option, this function first collects n_episode - env_num episodes; the last env_num episodes are then collected evenly, one from each env.

Parameters:
  • n_step – how many steps you want to collect.

  • n_episode – how many episodes you want to collect.

  • random – whether to use a random policy to collect data. Defaults to False.

  • render – the sleep time between rendering consecutive frames. Defaults to None (no rendering).

  • no_grad – whether to run policy.forward() without tracking gradients. Defaults to True (no gradients are retained).

  • gym_reset_kwargs – extra keyword arguments to pass into the environment’s reset function. Defaults to None (no extra keyword arguments).

Note

Exactly one collection specification is permitted: either n_step or n_episode.

Returns:

A CollectStats dataclass object containing the statistics of this collection run.
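
For illustration, a few call patterns (using the collector sketched above; the budgets are arbitrary):

    stats = collector.collect(n_step=1024)              # fixed step budget
    stats = collector.collect(n_episode=16)             # fixed episode budget
    stats = collector.collect(n_step=256, random=True)  # warm-up with random actions
    # Render while collecting; `render` is the sleep time between frames:
    stats = collector.collect(n_episode=1, render=1 / 30)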

reset(reset_buffer: bool = True, gym_reset_kwargs: dict[str, Any] | None = None) None[source]#

Reset the environments, statistics, current data, and possibly the replay buffer.

Parameters:
  • reset_buffer – if True, reset the replay buffer attached to the collector.

  • gym_reset_kwargs – extra keyword arguments to pass into the environment’s reset function. Defaults to None (no extra keyword arguments).

reset_buffer(keep_statistics: bool = False) None[source]#

Reset the data buffer.

reset_env(gym_reset_kwargs: dict[str, Any] | None = None) None[source]#

Reset all of the environments.

reset_stat() None[source]#

Reset the statistic variables.