Source code for tianshou.highlevel.persistence

import logging
from abc import ABC, abstractmethod
from import Callable
from enum import Enum
from typing import TYPE_CHECKING

import torch

from import World

    from tianshou.highlevel.module.core import TDevice

# Module-level logger, named after this module (via ``__name__``) so that its level and
# handlers can be configured through the standard logging hierarchy.
log = logging.getLogger(name=__name__)

class PersistEvent(Enum):
    """Persistence events to which :class:`Persistence` handlers can react."""

    #: The policy's neural network is persisted (a new best policy was found).
    PERSIST_POLICY = "persist_policy"
class RestoreEvent(Enum):
    """Restoration events to which :class:`Persistence` handlers can react."""

    #: The policy's neural network parameters are restored.
    RESTORE_POLICY = "restore_policy"
class Persistence(ABC):
    """Abstract interface for handlers that react to persistence and restoration events.

    Concrete subclasses must implement both :meth:`persist` and :meth:`restore`.
    """

    @abstractmethod
    def persist(self, event: PersistEvent, world: World) -> None:
        """React to a persistence event.

        :param event: the persistence event that occurred
        :param world: the ``World`` instance passed along with the event
        """

    @abstractmethod
    def restore(self, event: RestoreEvent, world: World) -> None:
        """React to a restoration event.

        :param event: the restoration event that occurred
        :param world: the ``World`` instance passed along with the event
        """
class PersistenceGroup(Persistence):
    """Groups several persistence handlers such that they can be applied collectively.

    Events are forwarded to the grouped handlers in the order in which they were passed
    to the constructor.
    """

    def __init__(self, *p: Persistence, enabled: bool = True):
        """:param p: the persistence handlers to group
        :param enabled: whether persist events are forwarded to the handlers;
            restore events are forwarded regardless of this flag
        """
        self.items = p
        self.enabled = enabled

    def persist(self, event: PersistEvent, world: World) -> None:
        """Forward a persist event to every grouped handler, unless the group is disabled."""
        if self.enabled:
            for handler in self.items:
                handler.persist(event, world)

    def restore(self, event: RestoreEvent, world: World) -> None:
        """Forward a restore event to every grouped handler (``enabled`` is not consulted)."""
        for handler in self.items:
            handler.restore(event, world)
[docs] class PolicyPersistence:
[docs] class Mode(Enum): """Mode of persistence.""" POLICY_STATE_DICT = "policy_state_dict" """Persist only the policy's state dictionary. Note that for a policy to be restored from such a dictionary, it is necessary to first create a structurally equivalent object which can accept the respective state.""" POLICY = "policy" """Persist the entire policy. This is larger but has the advantage of the policy being loadable without requiring an environment to be instantiated. It has the potential disadvantage that upon breaking code changes in the policy implementation (e.g. renamed/moved class), it will no longer be loadable. Note that a precondition is that the policy be picklable in its entirety. """
[docs] def get_filename(self) -> str: return self.value + ".pt"
def __init__( self, additional_persistence: Persistence | None = None, enabled: bool = True, mode: Mode = Mode.POLICY, ): """Handles persistence of the policy. :param additional_persistence: a persistence instance which is to be invoked whenever this object is used to persist/restore data :param enabled: whether persistence is enabled (restoration is always enabled) :param mode: the persistence mode """ self.additional_persistence = additional_persistence self.enabled = enabled self.mode = mode
[docs] def persist(self, policy: torch.nn.Module, world: World) -> None: if not self.enabled: return path = world.persist_path(self.mode.get_filename()) match self.mode: case self.Mode.POLICY_STATE_DICT:"Saving policy state dictionary in {path}"), path) case self.Mode.POLICY:"Saving policy object in {path}"), path) case _: raise NotImplementedError if self.additional_persistence is not None: self.additional_persistence.persist(PersistEvent.PERSIST_POLICY, world)
[docs] def restore(self, policy: torch.nn.Module, world: World, device: "TDevice") -> None: path = world.restore_path(self.mode.get_filename())"Restoring policy from {path}") match self.mode: case self.Mode.POLICY_STATE_DICT: state_dict = torch.load(path, map_location=device) case self.Mode.POLICY: loaded_policy: torch.nn.Module = torch.load(path, map_location=device) state_dict = loaded_policy.state_dict() case _: raise NotImplementedError policy.load_state_dict(state_dict) if self.additional_persistence is not None: self.additional_persistence.restore(RestoreEvent.RESTORE_POLICY, world)
[docs] def get_save_best_fn(self, world: World) -> Callable[[torch.nn.Module], None]: def save_best_fn(pol: torch.nn.Module) -> None: self.persist(pol, world) return save_best_fn