tianshou.trainer
tianshou.trainer.gather_info(start_time: float, train_c: tianshou.data.collector.Collector, test_c: tianshou.data.collector.Collector, best_reward: float) → Dict[str, Union[float, str]]

A simple wrapper for gathering information from collectors.

- Returns
  A dictionary with the following keys:
  - train_step – the total collected step of the training collector;
  - train_episode – the total collected episode of the training collector;
  - train_time/collector – the time for collecting frames in the training collector;
  - train_time/model – the time for training models;
  - train_speed – the speed of training (frames per second);
  - test_step – the total collected step of the test collector;
  - test_episode – the total collected episode of the test collector;
  - test_time – the time for testing;
  - test_speed – the speed of testing (frames per second);
  - best_reward – the best reward over the test results;
  - duration – the total elapsed time.
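The shape of the returned dictionary can be sketched as follows. Here each collector is replaced by a plain dict of `collect_step` / `collect_episode` / `collect_time` counters — hypothetical stand-ins for the statistics a Collector tracks, not Tianshou's actual attributes — so the exact formulas are illustrative only:

```python
import time

def gather_info_sketch(start_time, train_c, test_c, best_reward):
    """Build the summary dict from two collector-statistics dicts.

    train_c / test_c are hypothetical stand-ins for Collector objects,
    each holding 'collect_step', 'collect_episode', 'collect_time'.
    """
    duration = time.time() - start_time
    # Model-update time: whatever was not spent collecting in either phase.
    model_time = duration - train_c["collect_time"] - test_c["collect_time"]
    return {
        "train_step": train_c["collect_step"],
        "train_episode": train_c["collect_episode"],
        "train_time/collector": f'{train_c["collect_time"]:.2f}s',
        "train_time/model": f"{model_time:.2f}s",
        "train_speed": f'{train_c["collect_step"] / (duration - test_c["collect_time"]):.2f} step/s',
        "test_step": test_c["collect_step"],
        "test_episode": test_c["collect_episode"],
        "test_time": f'{test_c["collect_time"]:.2f}s',
        "test_speed": f'{test_c["collect_step"] / test_c["collect_time"]:.2f} step/s',
        "best_reward": best_reward,
        "duration": f"{duration:.2f}s",
    }
```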
tianshou.trainer.offpolicy_trainer(policy: tianshou.policy.base.BasePolicy, train_collector: tianshou.data.collector.Collector, test_collector: tianshou.data.collector.Collector, max_epoch: int, step_per_epoch: int, collect_per_step: int, episode_per_test: Union[int, List[int]], batch_size: int, update_per_step: int = 1, train_fn: Optional[Callable[[int, int], None]] = None, test_fn: Optional[Callable[[int, Optional[int]], None]] = None, stop_fn: Optional[Callable[[float], bool]] = None, save_fn: Optional[Callable[[tianshou.policy.base.BasePolicy], None]] = None, writer: Optional[torch.utils.tensorboard.writer.SummaryWriter] = None, log_interval: int = 1, verbose: bool = True, test_in_train: bool = True) → Dict[str, Union[float, str]]

A wrapper for the off-policy trainer procedure.

The “step” in the trainer means a policy network update.

- Parameters
  - policy – an instance of the BasePolicy class.
  - train_collector (Collector) – the collector used for training.
  - test_collector (Collector) – the collector used for testing.
  - max_epoch (int) – the maximum number of epochs for training. The training process might finish before reaching max_epoch.
  - step_per_epoch (int) – the number of policy network update steps in one epoch.
  - collect_per_step (int) – the number of frames the collector collects before each network update; in other words, collect some frames, then do some policy network updates.
  - episode_per_test (int or list of ints) – the number of episodes for one policy evaluation.
  - batch_size (int) – the batch size of the sample data fed into the policy network.
  - update_per_step (int) – the number of times the policy network is updated after frames are collected; for example, setting it to 256 means the policy is updated 256 times after every collect_per_step collected frames.
  - train_fn (function) – a function that receives the current epoch number and step index and performs some operations at the beginning of training in this epoch.
  - test_fn (function) – a function that receives the current epoch number and step index and performs some operations at the beginning of testing in this epoch.
  - save_fn (function) – a function for saving the policy whenever the undiscounted average reward in the evaluation phase improves.
  - stop_fn (function) – a function that receives the average undiscounted return of the test results and returns a boolean indicating whether the goal has been reached.
  - writer (torch.utils.tensorboard.SummaryWriter) – a TensorBoard SummaryWriter.
  - log_interval (int) – the log interval of the writer.
  - verbose (bool) – whether to print the information.
  - test_in_train (bool) – whether to test in the training phase.
- Returns
  See gather_info().
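The interplay of max_epoch, step_per_epoch, collect_per_step and update_per_step can be sketched as a bare loop. The `collect` and `update` callables below are dummies standing in for the Collector and the policy's learning step; the real trainer additionally handles testing, logging, saving and early stopping, so this is a simplification, not the actual implementation:

```python
def offpolicy_loop_sketch(collect, update, max_epoch, step_per_epoch,
                          collect_per_step, update_per_step):
    """Each 'step' is one policy network update; frames are collected between updates."""
    gradient_steps = 0
    for epoch in range(1, max_epoch + 1):
        steps_this_epoch = 0
        while steps_this_epoch < step_per_epoch:
            collect(n_step=collect_per_step)   # gather transitions into the replay buffer
            for _ in range(update_per_step):   # then do several gradient updates
                update()
                gradient_steps += 1
                steps_this_epoch += 1
    return gradient_steps
```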
tianshou.trainer.onpolicy_trainer(policy: tianshou.policy.base.BasePolicy, train_collector: tianshou.data.collector.Collector, test_collector: tianshou.data.collector.Collector, max_epoch: int, step_per_epoch: int, collect_per_step: int, repeat_per_collect: int, episode_per_test: Union[int, List[int]], batch_size: int, train_fn: Optional[Callable[[int, int], None]] = None, test_fn: Optional[Callable[[int, Optional[int]], None]] = None, stop_fn: Optional[Callable[[float], bool]] = None, save_fn: Optional[Callable[[tianshou.policy.base.BasePolicy], None]] = None, writer: Optional[torch.utils.tensorboard.writer.SummaryWriter] = None, log_interval: int = 1, verbose: bool = True, test_in_train: bool = True) → Dict[str, Union[float, str]]

A wrapper for the on-policy trainer procedure.

The “step” in the trainer means a policy network update.

- Parameters
  - policy – an instance of the BasePolicy class.
  - train_collector (Collector) – the collector used for training.
  - test_collector (Collector) – the collector used for testing.
  - max_epoch (int) – the maximum number of epochs for training. The training process might finish before reaching max_epoch.
  - step_per_epoch (int) – the number of policy network update steps in one epoch.
  - collect_per_step (int) – the number of episodes the collector collects before each network update; in other words, collect some episodes, then do one policy network update.
  - repeat_per_collect (int) – the number of learning passes over each collected batch; for example, setting it to 2 means the policy learns each given batch of data twice.
  - episode_per_test (int or list of ints) – the number of episodes for one policy evaluation.
  - batch_size (int) – the batch size of the sample data fed into the policy network.
  - train_fn (function) – a function that receives the current epoch number and step index and performs some operations at the beginning of training in this epoch.
  - test_fn (function) – a function that receives the current epoch number and step index and performs some operations at the beginning of testing in this epoch.
  - save_fn (function) – a function for saving the policy whenever the undiscounted average reward in the evaluation phase improves.
  - stop_fn (function) – a function that receives the average undiscounted return of the test results and returns a boolean indicating whether the goal has been reached.
  - writer (torch.utils.tensorboard.SummaryWriter) – a TensorBoard SummaryWriter.
  - log_interval (int) – the log interval of the writer.
  - verbose (bool) – whether to print the information.
  - test_in_train (bool) – whether to test in the training phase.
- Returns
  See gather_info().
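The on-policy variant differs from the off-policy loop in two ways: it collects whole episodes rather than a frame count, and it replays the same freshly collected batch repeat_per_collect times before discarding it. A sketch with dummy callables (again a simplification of the real trainer, which also interleaves testing, logging and stopping; here one collect-and-learn round counts as one trainer step):

```python
def onpolicy_loop_sketch(collect_episodes, learn_batch, max_epoch,
                         step_per_epoch, collect_per_step, repeat_per_collect):
    """Collect fresh on-policy episodes, learn from them a few times, discard."""
    updates = 0
    for epoch in range(1, max_epoch + 1):
        steps_this_epoch = 0
        while steps_this_epoch < step_per_epoch:
            batch = collect_episodes(n_episode=collect_per_step)  # fresh on-policy data
            for _ in range(repeat_per_collect):                   # reuse the same batch
                learn_batch(batch)
                updates += 1
            steps_this_epoch += 1   # one collect-and-learn round per step
    return updates
```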
tianshou.trainer.test_episode(policy: tianshou.policy.base.BasePolicy, collector: tianshou.data.collector.Collector, test_fn: Optional[Callable[[int, Optional[int]], None]], epoch: int, n_episode: Union[int, List[int]], writer: Optional[torch.utils.tensorboard.writer.SummaryWriter] = None, global_step: Optional[int] = None) → Dict[str, float]

A simple wrapper for testing the policy in a collector.