trainer


class EpochStopCallback

Callback which is called after the test phase of each epoch in order to determine whether training should stop early.

get_trainer_fn(context: TrainingContext) → Callable[[float], bool]
abstract should_stop(mean_rewards: float, context: TrainingContext) → bool

Determines whether training should stop.

Parameters:
  • mean_rewards – the average undiscounted return obtained in the test phase

  • context – the training context

Returns:

True if the goal has been reached and training should stop, False otherwise

class EpochStopCallbackRewardThreshold(threshold: float | None = None)

Stops training once the mean rewards exceed the given reward threshold or, if no threshold is given, the threshold specified in the Gymnasium environment (i.e. env.spec.reward_threshold).

should_stop(mean_rewards: float, context: TrainingContext) → bool

Determines whether training should stop.

Parameters:
  • mean_rewards – the average undiscounted return obtained in the test phase

  • context – the training context

Returns:

True if the goal has been reached and training should stop, False otherwise

class EpochTestCallback

Callback which is called at the beginning of the test phase of each epoch.

abstract callback(epoch: int, env_step: int | None, context: TrainingContext) → None
get_trainer_fn(context: TrainingContext) → Callable[[int, int | None], None]
class EpochTestCallbackDQNSetEps(eps_test: float)

Sets the epsilon value of DQN-based policies to eps_test at the beginning of the test stage in each epoch.

callback(epoch: int, env_step: int | None, context: TrainingContext) → None
class EpochTrainCallback

Callback which is called at the beginning of each epoch, i.e. prior to the data collection phase of each epoch.

abstract callback(epoch: int, env_step: int, context: TrainingContext) → None
get_trainer_fn(context: TrainingContext) → Callable[[int, int], None]
class EpochTrainCallbackDQNEpsLinearDecay(eps_train: float, eps_train_final: float, decay_steps: int = 1000000)

Sets the epsilon value of DQN-based policies at the beginning of the training stage in each epoch, decaying it linearly from eps_train to eps_train_final over the first decay_steps environment steps.

callback(epoch: int, env_step: int, context: TrainingContext) → None
class EpochTrainCallbackDQNSetEps(eps_test: float)

Sets the epsilon value of DQN-based policies to eps_test at the beginning of the training stage in each epoch. Note that, despite its name, eps_test here is the epsilon used during training.

callback(epoch: int, env_step: int, context: TrainingContext) → None
class TrainerCallbacks(epoch_train_callback: EpochTrainCallback | None = None, epoch_test_callback: EpochTestCallback | None = None, epoch_stop_callback: EpochStopCallback | None = None)

Container for callbacks used during training.

epoch_stop_callback: EpochStopCallback | None = None
epoch_test_callback: EpochTestCallback | None = None
epoch_train_callback: EpochTrainCallback | None = None
class TrainingContext(policy: TPolicy, envs: Environments, logger: BaseLogger)