Source code for tianshou.highlevel.params.policy_wrapper

from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import Generic, TypeVar

from tianshou.highlevel.env import Environments
from tianshou.highlevel.module.core import TDevice
from tianshou.highlevel.module.intermediate import IntermediateModuleFactory
from tianshou.highlevel.optim import OptimizerFactory
from tianshou.policy import BasePolicy, ICMPolicy
from tianshou.utils.net.discrete import IntrinsicCuriosityModule
from tianshou.utils.string import ToStringMixin

TPolicyOut = TypeVar("TPolicyOut", bound=BasePolicy)


class PolicyWrapperFactory(Generic[TPolicyOut], ToStringMixin, ABC):
    """Factory for wrapping a policy created by the high-level API, e.g. to add auxiliary components."""

    @abstractmethod
    def create_wrapped_policy(
        self,
        policy: BasePolicy,
        envs: Environments,
        optim_factory: OptimizerFactory,
        device: TDevice,
    ) -> TPolicyOut:
        pass
class PolicyWrapperFactoryIntrinsicCuriosity(
    PolicyWrapperFactory[ICMPolicy],
):
    """Factory for policies wrapped with an intrinsic curiosity module (ICM) for exploration."""

    def __init__(
        self,
        *,
        feature_net_factory: IntermediateModuleFactory,
        hidden_sizes: Sequence[int],
        lr: float,
        lr_scale: float,
        reward_scale: float,
        forward_loss_weight: float,
    ):
        self.feature_net_factory = feature_net_factory
        self.hidden_sizes = hidden_sizes
        self.lr = lr
        self.lr_scale = lr_scale
        self.reward_scale = reward_scale
        self.forward_loss_weight = forward_loss_weight

    def create_wrapped_policy(
        self,
        policy: BasePolicy,
        envs: Environments,
        optim_factory: OptimizerFactory,
        device: TDevice,
    ) -> ICMPolicy:
        # Build the feature network from which the ICM derives its state embedding.
        feature_net = self.feature_net_factory.create_intermediate_module(envs, device)
        # The ICM module used here requires an integer action dimension (discrete action spaces).
        action_dim = envs.get_action_shape()
        if not isinstance(action_dim, int):
            raise ValueError(f"Environment action shape must be an integer, got {action_dim}")
        feature_dim = feature_net.output_dim
        icm_net = IntrinsicCuriosityModule(
            feature_net.module,
            feature_dim,
            action_dim,
            hidden_sizes=self.hidden_sizes,
            device=device,
        )
        icm_optim = optim_factory.create_optimizer(icm_net, lr=self.lr)
        return ICMPolicy(
            policy=policy,
            model=icm_net,
            optim=icm_optim,
            action_space=envs.get_action_space(),
            lr_scale=self.lr_scale,
            reward_scale=self.reward_scale,
            forward_loss_weight=self.forward_loss_weight,
        ).to(device)
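
A minimal usage sketch follows for orientation. It is not part of this module: MyFeatureNetFactory is a hypothetical IntermediateModuleFactory subclass, the with_policy_wrapper_factory builder hook is an assumption modeled on Tianshou's high-level experiment builders, and the keyword values are placeholders.

# Usage sketch (illustrative only). `MyFeatureNetFactory` is a hypothetical
# IntermediateModuleFactory subclass; passing the factory to an experiment builder via
# `with_policy_wrapper_factory` is an assumption, not something defined in this module.
icm_wrapper_factory = PolicyWrapperFactoryIntrinsicCuriosity(
    feature_net_factory=MyFeatureNetFactory(),  # supplies the shared feature network
    hidden_sizes=[512],       # hidden layer sizes of the ICM's forward/inverse models
    lr=1e-3,                  # learning rate for the dedicated ICM optimizer
    lr_scale=1.0,             # passed through to ICMPolicy's lr_scale
    reward_scale=0.01,        # passed through to ICMPolicy's reward_scale
    forward_loss_weight=0.2,  # passed through to ICMPolicy's forward_loss_weight
)
# experiment_builder.with_policy_wrapper_factory(icm_wrapper_factory)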