# Source code for tianshou.highlevel.trainer

import logging
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass
from typing import TypeVar, cast

from tianshou.highlevel.env import Environments
from tianshou.highlevel.logger import TLogger
from tianshou.policy import BasePolicy, DQNPolicy
from tianshou.utils.string import ToStringMixin

TPolicy = TypeVar("TPolicy", bound=BasePolicy)
log = logging.getLogger(__name__)


class TrainingContext:
    """Bundles the objects that are handed to the trainer callbacks.

    The three constructor arguments are stored unchanged as the attributes
    `policy`, `envs` and `logger`.
    """

    def __init__(self, policy: TPolicy, envs: Environments, logger: TLogger):
        """:param policy: the policy; stored as `self.policy`
        :param envs: the environments; stored as `self.envs`
        :param logger: the logger; stored as `self.logger`
        """
        self.logger = logger
        self.envs = envs
        self.policy = policy
class EpochTrainCallback(ToStringMixin, ABC):
    """Hook invoked at the start of every epoch, i.e. before the epoch's
    data collection phase begins.
    """

    @abstractmethod
    def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
        """Performs the per-epoch action; to be implemented by subclasses.

        :param epoch: the epoch value passed on by the trainer
        :param env_step: the environment step value passed on by the trainer
        :param context: the training context
        """

    def get_trainer_fn(self, context: TrainingContext) -> Callable[[int, int], None]:
        """Creates a two-argument function which forwards to :meth:`callback`.

        :param context: the training context that is bound into the returned function
        :return: a function taking `(epoch, env_step)` which calls
            `self.callback(epoch, env_step, context)` and returns its result
        """

        def bound_callback(epoch: int, env_step: int) -> None:
            # `context` is captured from the enclosing scope
            return self.callback(epoch, env_step, context)

        return bound_callback
class EpochTestCallback(ToStringMixin, ABC):
    """Hook invoked when the test phase of an epoch begins."""

    @abstractmethod
    def callback(self, epoch: int, env_step: int | None, context: TrainingContext) -> None:
        """Performs the pre-test action; to be implemented by subclasses.

        :param epoch: the epoch value passed on by the trainer
        :param env_step: the environment step value passed on by the trainer; may be None
        :param context: the training context
        """

    def get_trainer_fn(self, context: TrainingContext) -> Callable[[int, int | None], None]:
        """Creates a two-argument function which forwards to :meth:`callback`.

        :param context: the training context that is bound into the returned function
        :return: a function taking `(epoch, env_step)` which calls
            `self.callback(epoch, env_step, context)` and returns its result
        """

        def bound_callback(epoch: int, env_step: int | None) -> None:
            # `context` is captured from the enclosing scope
            return self.callback(epoch, env_step, context)

        return bound_callback
class EpochStopCallback(ToStringMixin, ABC):
    """Hook invoked once the test phase of an epoch has finished; its result
    decides whether training is terminated early.
    """

    @abstractmethod
    def should_stop(self, mean_rewards: float, context: TrainingContext) -> bool:
        """Decides whether training is to be stopped.

        :param mean_rewards: the average undiscounted returns of the testing result
        :param context: the training context
        :return: True if the goal has been reached and training should stop, False otherwise
        """

    def get_trainer_fn(self, context: TrainingContext) -> Callable[[float], bool]:
        """Creates a one-argument function which forwards to :meth:`should_stop`.

        :param context: the training context that is bound into the returned function
        :return: a function taking `mean_rewards` which returns
            `self.should_stop(mean_rewards, context)`
        """

        def bound_should_stop(mean_rewards: float) -> bool:
            # `context` is captured from the enclosing scope
            return self.should_stop(mean_rewards, context)

        return bound_should_stop
@dataclass
class TrainerCallbacks:
    """Groups the optional callbacks that can be used during training.

    Every field defaults to None, meaning that no callback of that kind is set.
    """

    # callback for the beginning of each epoch (prior to data collection)
    epoch_train_callback: EpochTrainCallback | None = None
    # callback for the beginning of each epoch's test phase
    epoch_test_callback: EpochTestCallback | None = None
    # callback deciding, after each epoch's test phase, whether to stop early
    epoch_stop_callback: EpochStopCallback | None = None
class EpochTrainCallbackDQNSetEps(EpochTrainCallback):
    """Applies a fixed epsilon value to a DQN-based policy when the training
    stage of each epoch begins.
    """

    def __init__(self, eps_test: float):
        """:param eps_test: the epsilon value to set on the policy.
        NOTE: despite its name, this value is applied at the start of the
        training stage (see :meth:`callback`).
        """
        self.eps_test = eps_test

    def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
        """Sets the stored epsilon on the context's policy.

        :param epoch: unused
        :param env_step: unused
        :param context: the training context whose policy is treated as a `DQNPolicy`
        """
        # `cast` only informs the type checker; no runtime check is performed
        cast(DQNPolicy, context.policy).set_eps(self.eps_test)
class EpochTrainCallbackDQNEpsLinearDecay(EpochTrainCallback):
    """Sets the epsilon value for DQN-based policies at the beginning of the training
    stage in each epoch, using a linear decay in the first `decay_steps` steps.

    While `env_step < decay_steps`, epsilon is interpolated linearly from `eps_train`
    (at step 0) towards `eps_train_final`; from `decay_steps` onwards it is exactly
    `eps_train_final`.
    """

    def __init__(self, eps_train: float, eps_train_final: float, decay_steps: int = 1000000):
        """:param eps_train: the epsilon value at environment step 0
        :param eps_train_final: the epsilon value used once the decay has finished
        :param decay_steps: the number of environment steps over which epsilon is decayed;
            a value <= 0 disables the decay, i.e. `eps_train_final` is used right away
        """
        self.eps_train = eps_train
        self.eps_train_final = eps_train_final
        self.decay_steps = decay_steps

    def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
        """Computes the epsilon for the given step, sets it on the policy and logs it.

        :param epoch: unused
        :param env_step: the current environment step; determines the decay progress
        :param context: the training context providing the policy (treated as a
            `DQNPolicy`) and the logger
        """
        policy = cast(DQNPolicy, context.policy)
        logger = context.logger
        # Strict comparison on purpose:
        #  * it keeps `decay_steps == 0` from reaching the division below
        #    (previously a ZeroDivisionError at env_step == 0), and
        #  * at env_step == decay_steps the final value is used directly instead of
        #    an interpolation result that may be off by floating-point rounding.
        if env_step < self.decay_steps:
            eps = self.eps_train - env_step / self.decay_steps * (
                self.eps_train - self.eps_train_final
            )
        else:
            eps = self.eps_train_final
        policy.set_eps(eps)
        logger.write("train/env_step", env_step, {"train/eps": eps})
class EpochTestCallbackDQNSetEps(EpochTestCallback):
    """Applies a fixed epsilon value to a DQN-based policy when the test stage
    of each epoch begins.
    """

    def __init__(self, eps_test: float):
        """:param eps_test: the epsilon value to set on the policy for testing"""
        self.eps_test = eps_test

    def callback(self, epoch: int, env_step: int | None, context: TrainingContext) -> None:
        """Sets the stored epsilon on the context's policy.

        :param epoch: unused
        :param env_step: unused
        :param context: the training context whose policy is treated as a `DQNPolicy`
        """
        # `cast` only informs the type checker; no runtime check is performed
        cast(DQNPolicy, context.policy).set_eps(self.eps_test)
class EpochStopCallbackRewardThreshold(EpochStopCallback):
    """Stops training once the mean rewards reach or exceed the given reward threshold or
    the threshold that is specified in the gymnasium environment
    (i.e. `env.spec.reward_threshold`).
    """

    def __init__(self, threshold: float | None = None):
        """:param threshold: the reward threshold beyond which to stop training.
        If it is None, use threshold given by the environment, i.e. `env.spec.reward_threshold`.
        """
        self.threshold = threshold

    def should_stop(self, mean_rewards: float, context: TrainingContext) -> bool:
        """Checks the mean rewards against the applicable reward threshold.

        :param mean_rewards: the average undiscounted returns of the testing result
        :param context: the training context; its environment's spec supplies the
            threshold if none was given at construction
        :return: True if `mean_rewards >= threshold`, False otherwise
        :raises AssertionError: if no threshold was given at construction and the
            environment's spec does not define one either
        """
        threshold = self.threshold
        if threshold is None:
            # fall back to the threshold declared by the environment
            threshold = context.envs.env.spec.reward_threshold  # type: ignore
        if threshold is None:
            # Raised explicitly instead of via `assert`, so that the check is not
            # stripped under `python -O` (where the comparison below would otherwise
            # fail with an unrelated TypeError). The exception type is unchanged.
            raise AssertionError(
                "No reward threshold available: none was given and the environment's "
                "spec does not define `reward_threshold`",
            )
        is_reached = mean_rewards >= threshold
        if is_reached:
            log.info(f"Reward threshold ({threshold}) exceeded")
        return is_reached