# Source code for tianshou.highlevel.params.dist_fn
from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import Any
import torch
from tianshou.highlevel.env import Environments, EnvType
from tianshou.policy.modelfree.pg import TDistFnDiscrete, TDistFnDiscrOrCont
from tianshou.utils.string import ToStringMixin
class DistributionFunctionFactory(ToStringMixin, ABC):
    """Abstract base for factories that build a distribution function.

    A distribution function is a callable that turns its input into a
    ``torch.distributions.Distribution``. This base class only promises
    ``Callable[[Any], Distribution]``; the precise (narrower) return type is
    declared by each subclass.
    """

    @abstractmethod
    def create_dist_fn(
        self,
        envs: Environments,
    ) -> Callable[[Any], torch.distributions.Distribution]:
        """Create the distribution function appropriate for ``envs``.

        :param envs: the environments the distribution function is created for.
        :return: a callable producing a ``torch.distributions.Distribution``.
        """
class DistributionFunctionFactoryCategorical(DistributionFunctionFactory):
    """Factory producing a categorical distribution function; discrete environments only."""

    def create_dist_fn(self, envs: Environments) -> TDistFnDiscrete:
        """Return the function mapping logits to a ``Categorical`` distribution.

        :param envs: the environments; their type is checked via ``assert_discrete``
            (presumably raising for non-discrete types — see ``EnvType``).
        :return: the static method :meth:`_dist_fn`.
        """
        env_type = envs.get_type()
        # ``self`` is handed to the check, presumably so an error can name the factory.
        env_type.assert_discrete(self)
        return self._dist_fn

    @staticmethod
    def _dist_fn(p: torch.Tensor) -> torch.distributions.Categorical:
        """Build a ``Categorical`` that interprets ``p`` as unnormalized log-probabilities."""
        categorical = torch.distributions.Categorical(logits=p)
        return categorical
class DistributionFunctionFactoryIndependentGaussians(DistributionFunctionFactory):
    """Factory producing a diagonal-Gaussian distribution function; continuous environments only."""

    def create_dist_fn(
        self,
        envs: Environments,
    ) -> Callable[[tuple[torch.Tensor, torch.Tensor]], torch.distributions.Distribution]:
        """Return the function mapping ``(loc, scale)`` to independent Normal distributions.

        The return annotation is the precise continuous signature rather than the
        discrete-or-continuous union, matching how the categorical factory declares
        its narrow type. It is still assignable wherever the union is expected.

        :param envs: the environments; their type is checked via ``assert_continuous``
            (presumably raising for non-continuous types — see ``EnvType``).
        :return: the static method :meth:`_dist_fn`.
        """
        envs.get_type().assert_continuous(self)
        return self._dist_fn

    @staticmethod
    def _dist_fn(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> torch.distributions.Distribution:
        """Build ``Independent(Normal(loc, scale), 1)`` from a ``(loc, scale)`` pair.

        :param loc_scale: the mean and standard deviation tensors.
        :return: a distribution whose rightmost dimension is an event dimension.
        """
        loc, scale = loc_scale
        # reinterpreted_batch_ndims=1 turns the last dimension into an event dimension,
        # so log_prob sums over it instead of returning per-element values.
        return torch.distributions.Independent(torch.distributions.Normal(loc, scale), 1)
[docs]
class DistributionFunctionFactoryDefault(DistributionFunctionFactory):
[docs]
def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont:
match envs.get_type():
case EnvType.DISCRETE:
return DistributionFunctionFactoryCategorical().create_dist_fn(envs)
case EnvType.CONTINUOUS:
return DistributionFunctionFactoryIndependentGaussians().create_dist_fn(envs)
case _:
raise ValueError(envs.get_type())