Source code for tianshou.highlevel.module.special

from collections.abc import Sequence

from tianshou.highlevel.env import Environments
from tianshou.highlevel.module.core import ModuleFactory, TDevice
from tianshou.highlevel.module.intermediate import IntermediateModuleFactory
from tianshou.utils.net.discrete import ImplicitQuantileNetwork
from tianshou.utils.string import ToStringMixin


class ImplicitQuantileNetworkFactory(ModuleFactory, ToStringMixin):
    """Module factory that produces an :class:`ImplicitQuantileNetwork`.

    The network is assembled on top of a preprocessing module, which is itself
    obtained from ``preprocess_net_factory`` each time a module is created.
    """

    def __init__(
        self,
        preprocess_net_factory: IntermediateModuleFactory,
        hidden_sizes: Sequence[int] = (),
        num_cosines: int = 64,
    ) -> None:
        """
        :param preprocess_net_factory: factory whose intermediate module is used as the
            network's ``preprocess_net``; its reported ``output_dim`` is forwarded as
            ``preprocess_net_output_dim``.
        :param hidden_sizes: forwarded unchanged as ``hidden_sizes`` to the network
            (empty tuple by default).
        :param num_cosines: forwarded unchanged as ``num_cosines`` to the network
            (64 by default).
        """
        # NOTE(review): attribute names and assignment order are kept as-is, since
        # ToStringMixin presumably renders instance attributes; confirm before renaming.
        self.preprocess_net_factory = preprocess_net_factory
        self.hidden_sizes = hidden_sizes
        self.num_cosines = num_cosines

    def create_module(self, envs: Environments, device: TDevice) -> ImplicitQuantileNetwork:
        """Build a fresh network for ``envs`` and place it on ``device``.

        :param envs: the environments; ``envs.get_action_shape()`` supplies the network's
            ``action_shape``, and ``envs`` is also handed to the preprocessing factory.
        :param device: passed to the preprocessing factory, to the network constructor,
            and finally to ``.to(...)``.
        :return: the newly constructed :class:`ImplicitQuantileNetwork` after ``.to(device)``.
        """
        # The preprocessing module must exist first: both its module object and its
        # output dimension are constructor arguments of the quantile network.
        intermediate = self.preprocess_net_factory.create_intermediate_module(envs, device)
        network = ImplicitQuantileNetwork(
            preprocess_net=intermediate.module,
            action_shape=envs.get_action_shape(),
            hidden_sizes=self.hidden_sizes,
            num_cosines=self.num_cosines,
            preprocess_net_output_dim=intermediate.output_dim,
            device=device,
        )
        return network.to(device)