# Source code for tianshou.highlevel.module.intermediate
from abc import ABC, abstractmethod
from dataclasses import dataclass
import torch
from tianshou.highlevel.env import Environments
from tianshou.highlevel.module.core import ModuleFactory, TDevice
from tianshou.utils.string import ToStringMixin
[docs]
@dataclass
class IntermediateModule:
    """Pair a torch module that yields an intermediate representation with that representation's size.

    This is a plain dataclass record: the fields are stored as given, with no validation.
    """

    # module computing the intermediate representation
    module: torch.nn.Module
    # dimension of the representation produced by ``module``
    output_dim: int
[docs]
class IntermediateModuleFactory(ToStringMixin, ModuleFactory, ABC):
    """Abstract factory for modules that compute an intermediate representation.

    Subclasses implement :meth:`create_intermediate_module`, which returns the module
    together with its output dimension. :meth:`create_module` exposes only the module
    itself, for callers that use this object through the ``ModuleFactory`` interface.
    """

    @abstractmethod
    def create_intermediate_module(self, envs: Environments, device: TDevice) -> IntermediateModule:
        """Build the intermediate module for the given environments.

        :param envs: the environments the module is created for
        :param device: the target device, as passed in by the caller
        :return: the module paired with its output dimension
        """

    def create_module(self, envs: Environments, device: TDevice) -> torch.nn.Module:
        """Build the intermediate module and return only the ``torch.nn.Module`` part.

        The output dimension reported by :meth:`create_intermediate_module` is discarded.

        :param envs: forwarded unchanged to :meth:`create_intermediate_module`
        :param device: forwarded unchanged to :meth:`create_intermediate_module`
        :return: the ``module`` field of the created :class:`IntermediateModule`
        """
        intermediate = self.create_intermediate_module(envs, device)
        return intermediate.module