tianshou.trainer¶
On-policy¶
- class tianshou.trainer.OnpolicyTrainer(policy: ~tianshou.policy.base.BasePolicy, max_epoch: int, batch_size: int | None, train_collector: ~tianshou.data.collector.Collector | None = None, test_collector: ~tianshou.data.collector.Collector | None = None, buffer: ~tianshou.data.buffer.base.ReplayBuffer | None = None, step_per_epoch: int | None = None, repeat_per_collect: int | None = None, episode_per_test: int | None = None, update_per_step: float = 1.0, step_per_collect: int | None = None, episode_per_collect: int | None = None, train_fn: ~collections.abc.Callable[[int, int], None] | None = None, test_fn: ~collections.abc.Callable[[int, int | None], None] | None = None, stop_fn: ~collections.abc.Callable[[float], bool] | None = None, save_best_fn: ~collections.abc.Callable[[~tianshou.policy.base.BasePolicy], None] | None = None, save_checkpoint_fn: ~collections.abc.Callable[[int, int, int], str] | None = None, resume_from_log: bool = False, reward_metric: ~collections.abc.Callable[[~numpy.ndarray], ~numpy.ndarray] | None = None, logger: ~tianshou.utils.logger.base.BaseLogger = <tianshou.utils.logger.base.LazyLogger object>, verbose: bool = True, show_progress: bool = True, test_in_train: bool = True, save_fn: ~collections.abc.Callable[[~tianshou.policy.base.BasePolicy], None] | None = None)[source]¶
Bases: BaseTrainer
An iterator class for the on-policy trainer procedure.
Returns an iterator that yields a 3-tuple (epoch, stats, info) of train results on every epoch.
The “step” in the on-policy trainer means an environment step (a.k.a. transition).
Example usage:
trainer = OnpolicyTrainer(...)
for epoch, epoch_stat, info in trainer:
    print("Epoch:", epoch)
    print(epoch_stat)
    print(info)
    do_something_with_policy()
    query_something_about_policy()
    make_a_plot_with(epoch_stat)
    display(info)
- epoch int: the epoch number
- epoch_stat dict: a large collection of metrics of the current epoch
- info dict: result returned from gather_info()
You can even iterate on several trainers at the same time:
trainer1 = OnpolicyTrainer(...)
trainer2 = OnpolicyTrainer(...)
for result1, result2, ... in zip(trainer1, trainer2, ...):
    compare_results(result1, result2, ...)
- Parameters:
policy – an instance of the BasePolicy class.
batch_size – the batch size of sample data that is fed into the policy network. If None, the whole buffer is used in each gradient step.
train_collector – the collector used for training.
test_collector – the collector used for testing. If it is None, no testing will be performed.
buffer – the replay buffer used for off-policy algorithms or for pre-training. If a policy overrides the process_buffer method, the replay buffer will be pre-processed before training.
max_epoch – the maximum number of epochs for training. The training process might finish before reaching max_epoch if stop_fn is set.
step_per_epoch – the number of transitions collected per epoch.
repeat_per_collect – the number of times the policy learns from each collected batch of data; for example, setting it to 2 means the policy learns each given batch twice. Only used in on-policy algorithms.
episode_per_test – the number of episodes for one policy evaluation.
update_per_step – only used in off-policy algorithms. The number of gradient steps to perform per step in the environment (i.e., per sample added to the buffer).
step_per_collect – the number of transitions the collector collects before the network update, i.e., the trainer will collect “step_per_collect” transitions and perform some policy network updates repeatedly in each epoch.
episode_per_collect – the number of episodes the collector collects before the network update, i.e., the trainer will collect “episode_per_collect” episodes and perform some policy network updates repeatedly in each epoch.
train_fn – a hook called at the beginning of training in each epoch. It can be used to perform custom additional operations, with the signature f(num_epoch: int, step_idx: int) -> None.
test_fn – a hook called at the beginning of testing in each epoch. It can be used to perform custom additional operations, with the signature f(num_epoch: int, step_idx: int) -> None.
save_best_fn – a hook called when the undiscounted average mean reward in the evaluation phase gets better, with the signature f(policy: BasePolicy) -> None. It was previously named save_fn.
save_checkpoint_fn – a function to save the training process and return the saved checkpoint path, with the signature f(epoch: int, env_step: int, gradient_step: int) -> str; you can save whatever you want.
resume_from_log – resume env_step/gradient_step and other metadata from an existing tensorboard log.
stop_fn – a function with the signature f(mean_rewards: float) -> bool that receives the average undiscounted returns of the testing result and returns a boolean indicating whether the goal has been reached.
reward_metric – a function with the signature f(rewards: np.ndarray with shape (num_episode, agent_num)) -> np.ndarray with shape (num_episode,), used in multi-agent RL. It must return a single scalar for each episode’s result in order to monitor training in the multi-agent RL setting. This function specifies the desired metric, e.g., the reward of agent 1 or the average reward over all agents.
logger – a logger that logs statistics during training/testing/updating. To not log anything, keep the default logger.
verbose – whether to print status information to stdout. If set to False, status information will still be logged (provided that logging is enabled via the logging module).
show_progress – whether to display a progress bar when training.
test_in_train – whether to test in the training phase.
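A minimal construction sketch, shown only for orientation: it assumes that policy (e.g., an on-policy policy such as PPO), train_collector, and test_collector have already been built elsewhere, and the numeric values are placeholders to be tuned per task. The keyword arguments mirror the signature above; trainer.run() (inherited from BaseTrainer) consumes the iterator and returns the final info dictionary.

from tianshou.trainer import OnpolicyTrainer

# `policy`, `train_collector` and `test_collector` are assumed to be
# constructed beforehand; the hyperparameters below are placeholders.
trainer = OnpolicyTrainer(
    policy=policy,
    train_collector=train_collector,
    test_collector=test_collector,
    max_epoch=10,
    step_per_epoch=50000,        # environment steps per epoch
    repeat_per_collect=4,        # learn each collected batch 4 times
    episode_per_test=10,
    batch_size=256,
    step_per_collect=2000,       # collect 2000 transitions between updates
    stop_fn=lambda mean_rewards: mean_rewards >= 195,
)
result = trainer.run()           # equivalent to exhausting the iterator
print(result)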
Off-policy¶
- class tianshou.trainer.OffpolicyTrainer(policy: ~tianshou.policy.base.BasePolicy, max_epoch: int, batch_size: int | None, train_collector: ~tianshou.data.collector.Collector | None = None, test_collector: ~tianshou.data.collector.Collector | None = None, buffer: ~tianshou.data.buffer.base.ReplayBuffer | None = None, step_per_epoch: int | None = None, repeat_per_collect: int | None = None, episode_per_test: int | None = None, update_per_step: float = 1.0, step_per_collect: int | None = None, episode_per_collect: int | None = None, train_fn: ~collections.abc.Callable[[int, int], None] | None = None, test_fn: ~collections.abc.Callable[[int, int | None], None] | None = None, stop_fn: ~collections.abc.Callable[[float], bool] | None = None, save_best_fn: ~collections.abc.Callable[[~tianshou.policy.base.BasePolicy], None] | None = None, save_checkpoint_fn: ~collections.abc.Callable[[int, int, int], str] | None = None, resume_from_log: bool = False, reward_metric: ~collections.abc.Callable[[~numpy.ndarray], ~numpy.ndarray] | None = None, logger: ~tianshou.utils.logger.base.BaseLogger = <tianshou.utils.logger.base.LazyLogger object>, verbose: bool = True, show_progress: bool = True, test_in_train: bool = True, save_fn: ~collections.abc.Callable[[~tianshou.policy.base.BasePolicy], None] | None = None)[source]¶
Bases: BaseTrainer
Off-policy trainer, which samples mini-batches from the buffer and passes them to update.
Note that with this trainer, it is expected that the policy’s learn method does not perform additional mini-batching but simply updates parameters from the received mini-batch.
An iterator class for the off-policy trainer procedure.
Returns an iterator that yields a 3-tuple (epoch, stats, info) of train results on every epoch.
The “step” in the off-policy trainer means an environment step (a.k.a. transition).
Example usage:
trainer = OffpolicyTrainer(...)
for epoch, epoch_stat, info in trainer:
    print("Epoch:", epoch)
    print(epoch_stat)
    print(info)
    do_something_with_policy()
    query_something_about_policy()
    make_a_plot_with(epoch_stat)
    display(info)
- epoch int: the epoch number
- epoch_stat dict: a large collection of metrics of the current epoch
- info dict: result returned from gather_info()
You can even iterate on several trainers at the same time:
trainer1 = OffpolicyTrainer(...)
trainer2 = OffpolicyTrainer(...)
for result1, result2, ... in zip(trainer1, trainer2, ...):
    compare_results(result1, result2, ...)
- Parameters:
policy – an instance of the BasePolicy class.
batch_size – the batch size of sample data that is fed into the policy network. If None, the whole buffer is used in each gradient step.
train_collector – the collector used for training.
test_collector – the collector used for testing. If it is None, no testing will be performed.
buffer – the replay buffer used for off-policy algorithms or for pre-training. If a policy overrides the process_buffer method, the replay buffer will be pre-processed before training.
max_epoch – the maximum number of epochs for training. The training process might finish before reaching max_epoch if stop_fn is set.
step_per_epoch – the number of transitions collected per epoch.
repeat_per_collect – the number of times the policy learns from each collected batch of data; for example, setting it to 2 means the policy learns each given batch twice. Only used in on-policy algorithms.
episode_per_test – the number of episodes for one policy evaluation.
update_per_step – only used in off-policy algorithms. The number of gradient steps to perform per step in the environment (i.e., per sample added to the buffer).
step_per_collect – the number of transitions the collector collects before the network update, i.e., the trainer will collect “step_per_collect” transitions and perform some policy network updates repeatedly in each epoch.
episode_per_collect – the number of episodes the collector collects before the network update, i.e., the trainer will collect “episode_per_collect” episodes and perform some policy network updates repeatedly in each epoch.
train_fn – a hook called at the beginning of training in each epoch. It can be used to perform custom additional operations, with the signature f(num_epoch: int, step_idx: int) -> None.
test_fn – a hook called at the beginning of testing in each epoch. It can be used to perform custom additional operations, with the signature f(num_epoch: int, step_idx: int) -> None.
save_best_fn – a hook called when the undiscounted average mean reward in the evaluation phase gets better, with the signature f(policy: BasePolicy) -> None. It was previously named save_fn.
save_checkpoint_fn – a function to save the training process and return the saved checkpoint path, with the signature f(epoch: int, env_step: int, gradient_step: int) -> str; you can save whatever you want.
resume_from_log – resume env_step/gradient_step and other metadata from an existing tensorboard log.
stop_fn – a function with the signature f(mean_rewards: float) -> bool that receives the average undiscounted returns of the testing result and returns a boolean indicating whether the goal has been reached.
reward_metric – a function with the signature f(rewards: np.ndarray with shape (num_episode, agent_num)) -> np.ndarray with shape (num_episode,), used in multi-agent RL. It must return a single scalar for each episode’s result in order to monitor training in the multi-agent RL setting. This function specifies the desired metric, e.g., the reward of agent 1 or the average reward over all agents.
logger – a logger that logs statistics during training/testing/updating. To not log anything, keep the default logger.
verbose – whether to print status information to stdout. If set to False, status information will still be logged (provided that logging is enabled via the logging module).
show_progress – whether to display a progress bar when training.
test_in_train – whether to test in the training phase.
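A corresponding sketch for the off-policy case, again assuming that policy (e.g., DQN or SAC) and the two collectors have been constructed elsewhere. Here update_per_step rather than repeat_per_collect controls how many gradient steps follow each collected transition, and the save_best_fn argument illustrates the hook signature documented above; log_path and all numeric values are placeholders.

import os
import torch
from tianshou.trainer import OffpolicyTrainer

# `policy`, `train_collector` and `test_collector` are assumed to exist;
# `log_path` is a placeholder directory for saved weights.
trainer = OffpolicyTrainer(
    policy=policy,
    train_collector=train_collector,
    test_collector=test_collector,
    max_epoch=10,
    step_per_epoch=10000,       # environment steps per epoch
    step_per_collect=10,        # collect 10 transitions between updates
    episode_per_test=10,
    batch_size=64,              # mini-batch sampled from the replay buffer
    update_per_step=0.1,        # one gradient step per 10 environment steps
    save_best_fn=lambda policy: torch.save(
        policy.state_dict(), os.path.join(log_path, "best_policy.pth")
    ),
)
result = trainer.run()
print(result)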
Offline¶
- class tianshou.trainer.OfflineTrainer(policy: ~tianshou.policy.base.BasePolicy, max_epoch: int, batch_size: int | None, train_collector: ~tianshou.data.collector.Collector | None = None, test_collector: ~tianshou.data.collector.Collector | None = None, buffer: ~tianshou.data.buffer.base.ReplayBuffer | None = None, step_per_epoch: int | None = None, repeat_per_collect: int | None = None, episode_per_test: int | None = None, update_per_step: float = 1.0, step_per_collect: int | None = None, episode_per_collect: int | None = None, train_fn: ~collections.abc.Callable[[int, int], None] | None = None, test_fn: ~collections.abc.Callable[[int, int | None], None] | None = None, stop_fn: ~collections.abc.Callable[[float], bool] | None = None, save_best_fn: ~collections.abc.Callable[[~tianshou.policy.base.BasePolicy], None] | None = None, save_checkpoint_fn: ~collections.abc.Callable[[int, int, int], str] | None = None, resume_from_log: bool = False, reward_metric: ~collections.abc.Callable[[~numpy.ndarray], ~numpy.ndarray] | None = None, logger: ~tianshou.utils.logger.base.BaseLogger = <tianshou.utils.logger.base.LazyLogger object>, verbose: bool = True, show_progress: bool = True, test_in_train: bool = True, save_fn: ~collections.abc.Callable[[~tianshou.policy.base.BasePolicy], None] | None = None)[source]¶
Bases: BaseTrainer
Offline trainer, which samples mini-batches from the buffer and passes them to update.
It uses a buffer directly and usually does not have a collector.
An iterator class for the offline trainer procedure.
Returns an iterator that yields a 3-tuple (epoch, stats, info) of train results on every epoch.
The “step” in the offline trainer means a gradient step.
Example usage:
trainer = OfflineTrainer(...)
for epoch, epoch_stat, info in trainer:
    print("Epoch:", epoch)
    print(epoch_stat)
    print(info)
    do_something_with_policy()
    query_something_about_policy()
    make_a_plot_with(epoch_stat)
    display(info)
- epoch int: the epoch number
- epoch_stat dict: a large collection of metrics of the current epoch
- info dict: result returned from gather_info()
You can even iterate on several trainers at the same time:
trainer1 = OfflineTrainer(...)
trainer2 = OfflineTrainer(...)
for result1, result2, ... in zip(trainer1, trainer2, ...):
    compare_results(result1, result2, ...)
- Parameters:
policy – an instance of the BasePolicy class.
batch_size – the batch size of sample data that is fed into the policy network. If None, the whole buffer is used in each gradient step.
train_collector – the collector used for training.
test_collector – the collector used for testing. If it is None, no testing will be performed.
buffer – the replay buffer used for off-policy algorithms or for pre-training. If a policy overrides the process_buffer method, the replay buffer will be pre-processed before training.
max_epoch – the maximum number of epochs for training. The training process might finish before reaching max_epoch if stop_fn is set.
step_per_epoch – the number of transitions collected per epoch.
repeat_per_collect – the number of times the policy learns from each collected batch of data; for example, setting it to 2 means the policy learns each given batch twice. Only used in on-policy algorithms.
episode_per_test – the number of episodes for one policy evaluation.
update_per_step – only used in off-policy algorithms. The number of gradient steps to perform per step in the environment (i.e., per sample added to the buffer).
step_per_collect – the number of transitions the collector collects before the network update, i.e., the trainer will collect “step_per_collect” transitions and perform some policy network updates repeatedly in each epoch.
episode_per_collect – the number of episodes the collector collects before the network update, i.e., the trainer will collect “episode_per_collect” episodes and perform some policy network updates repeatedly in each epoch.
train_fn – a hook called at the beginning of training in each epoch. It can be used to perform custom additional operations, with the signature f(num_epoch: int, step_idx: int) -> None.
test_fn – a hook called at the beginning of testing in each epoch. It can be used to perform custom additional operations, with the signature f(num_epoch: int, step_idx: int) -> None.
save_best_fn – a hook called when the undiscounted average mean reward in the evaluation phase gets better, with the signature f(policy: BasePolicy) -> None. It was previously named save_fn.
save_checkpoint_fn – a function to save the training process and return the saved checkpoint path, with the signature f(epoch: int, env_step: int, gradient_step: int) -> str; you can save whatever you want.
resume_from_log – resume env_step/gradient_step and other metadata from an existing tensorboard log.
stop_fn – a function with the signature f(mean_rewards: float) -> bool that receives the average undiscounted returns of the testing result and returns a boolean indicating whether the goal has been reached.
reward_metric – a function with the signature f(rewards: np.ndarray with shape (num_episode, agent_num)) -> np.ndarray with shape (num_episode,), used in multi-agent RL. It must return a single scalar for each episode’s result in order to monitor training in the multi-agent RL setting. This function specifies the desired metric, e.g., the reward of agent 1 or the average reward over all agents.
logger – a logger that logs statistics during training/testing/updating. To not log anything, keep the default logger.
verbose – whether to print status information to stdout. If set to False, status information will still be logged (provided that logging is enabled via the logging module).
show_progress – whether to display a progress bar when training.
test_in_train – whether to test in the training phase.
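For the offline case, a sketch assuming that policy (e.g., an offline algorithm such as BCQ or CQL), a pre-filled replay buffer of logged transitions, and an optional test_collector already exist. There is no train_collector, and step_per_epoch here counts gradient steps, matching the note above that a “step” in the offline trainer means a gradient step; all values are placeholders.

from tianshou.trainer import OfflineTrainer

# `policy`, a pre-filled `buffer` and `test_collector` are assumed to be
# constructed beforehand; no train_collector is needed for offline training.
trainer = OfflineTrainer(
    policy=policy,
    buffer=buffer,              # dataset of logged transitions
    test_collector=test_collector,
    max_epoch=10,
    step_per_epoch=1000,        # gradient steps per epoch
    episode_per_test=10,
    batch_size=256,
)
result = trainer.run()
print(result)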
utils¶
- tianshou.trainer.test_episode(policy: BasePolicy, collector: Collector, test_fn: Callable[[int, int | None], None] | None, epoch: int, n_episode: int, logger: BaseLogger | None = None, global_step: int | None = None, reward_metric: Callable[[ndarray], ndarray] | None = None) dict[str, Any] [source]¶
A simple wrapper for testing a policy with a collector.
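A hedged usage sketch, assuming a policy and a test collector already exist: the wrapper runs n_episode evaluation episodes and returns the collected statistics as a dict. The optional hooks are left at their defaults here.

from tianshou.trainer import test_episode

# `policy` and `test_collector` are assumed to exist; the optional hooks
# (test_fn, logger, reward_metric) are left at None in this sketch.
stats = test_episode(
    policy=policy,
    collector=test_collector,
    test_fn=None,
    epoch=0,
    n_episode=10,
)
print(stats)  # dict of collection statistics for the evaluation run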
- tianshou.trainer.gather_info(start_time: float, train_collector: Collector | None, test_collector: Collector | None, best_reward: float, best_reward_std: float) dict[str, float | str] [source]¶
A simple wrapper for gathering information from collectors.
- Returns:
A dictionary with the following keys:
train_step – the total collected steps of the training collector;
train_episode – the total collected episodes of the training collector;
train_time/collector – the time for collecting transitions in the training collector;
train_time/model – the time for training models;
train_speed – the speed of training (env_step per second);
test_step – the total collected steps of the test collector;
test_episode – the total collected episodes of the test collector;
test_time – the time for testing;
test_speed – the speed of testing (env_step per second);
best_reward – the best reward over the test results;
duration – the total elapsed time.
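A hedged sketch of calling gather_info manually at the end of a training run, assuming that start_time was recorded with time.time() before training and that best_reward and best_reward_std were tracked during evaluation; the returned dictionary carries the keys listed above.

import time
from tianshou.trainer import gather_info

start_time = time.time()
# ... run the training loop, tracking best_reward and best_reward_std ...
info = gather_info(
    start_time=start_time,
    train_collector=train_collector,  # may be None (e.g., offline training)
    test_collector=test_collector,    # may be None if no testing is done
    best_reward=best_reward,
    best_reward_std=best_reward_std,
)
print(info["duration"], info["best_reward"])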