Source code for tianshou.highlevel.params.dist_fn

from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import Any

import torch

from tianshou.highlevel.env import Environments, EnvType
from tianshou.policy.modelfree.pg import TDistFnDiscrete, TDistFnDiscrOrCont
from tianshou.utils.string import ToStringMixin


class DistributionFunctionFactory(ToStringMixin, ABC):
    """Abstract base for factories that produce a distribution function.

    A distribution function is a callable that turns some input into a
    ``torch.distributions.Distribution``.
    """

    @abstractmethod
    def create_dist_fn(
        self,
        envs: Environments,
    ) -> Callable[[Any], torch.distributions.Distribution]:
        """Create the distribution function for the given environments.

        The return annotation here is the general form; each subclass
        declares its own, more specific return type.

        :param envs: the environments for which the function is created
        :return: a callable producing a ``torch.distributions.Distribution``
        """
class DistributionFunctionFactoryCategorical(DistributionFunctionFactory):
    """Factory whose distribution function builds a categorical distribution from logits."""

    def create_dist_fn(self, envs: Environments) -> TDistFnDiscrete:
        """Return the categorical distribution function.

        Before returning, ``assert_discrete`` is invoked on the environment
        type (with this factory passed as its argument).

        :param envs: the environments whose type is checked
        :return: the static function mapping logits to a categorical distribution
        """
        env_type = envs.get_type()
        env_type.assert_discrete(self)
        return self._dist_fn

    @staticmethod
    def _dist_fn(p: torch.Tensor) -> torch.distributions.Categorical:
        """Build a ``Categorical`` distribution, interpreting ``p`` as logits."""
        categorical = torch.distributions.Categorical(logits=p)
        return categorical
class DistributionFunctionFactoryIndependentGaussians(DistributionFunctionFactory):
    """Factory whose distribution function builds independent normal distributions."""

    def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont:
        """Return the Gaussian distribution function.

        Before returning, ``assert_continuous`` is invoked on the environment
        type (with this factory passed as its argument).

        :param envs: the environments whose type is checked
        :return: the static function mapping ``(loc, scale)`` to a distribution
        """
        env_type = envs.get_type()
        env_type.assert_continuous(self)
        return self._dist_fn

    @staticmethod
    def _dist_fn(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> torch.distributions.Distribution:
        """Build an ``Independent`` wrapper around ``Normal(loc, scale)``.

        :param loc_scale: a pair ``(loc, scale)`` of tensors parameterizing the normal
        :return: the normal distribution wrapped in ``Independent`` with one
            reinterpreted batch dimension
        """
        mean, std = loc_scale
        normal = torch.distributions.Normal(mean, std)
        # The second argument (1) is the number of batch dimensions that
        # ``Independent`` reinterprets as event dimensions.
        return torch.distributions.Independent(normal, 1)
[docs] class DistributionFunctionFactoryDefault(DistributionFunctionFactory):
[docs] def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont: match envs.get_type(): case EnvType.DISCRETE: return DistributionFunctionFactoryCategorical().create_dist_fn(envs) case EnvType.CONTINUOUS: return DistributionFunctionFactoryIndependentGaussians().create_dist_fn(envs) case _: raise ValueError(envs.get_type())