Source code for tianshou.highlevel.params.dist_fn
from abc import ABC, abstractmethod
import torch
from tianshou.highlevel.env import Environments, EnvType
from tianshou.policy.modelfree.pg import TDistributionFunction
from tianshou.utils.string import ToStringMixin
[docs]
class DistributionFunctionFactory(ToStringMixin, ABC):
[docs]
@abstractmethod
def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
pass
[docs]
class DistributionFunctionFactoryCategorical(DistributionFunctionFactory):
[docs]
def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
envs.get_type().assert_discrete(self)
return self._dist_fn
@staticmethod
def _dist_fn(p: torch.Tensor) -> torch.distributions.Distribution:
return torch.distributions.Categorical(logits=p)
[docs]
class DistributionFunctionFactoryIndependentGaussians(DistributionFunctionFactory):
[docs]
def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
envs.get_type().assert_continuous(self)
return self._dist_fn
@staticmethod
def _dist_fn(*p: torch.Tensor) -> torch.distributions.Distribution:
return torch.distributions.Independent(torch.distributions.Normal(*p), 1)
[docs]
class DistributionFunctionFactoryDefault(DistributionFunctionFactory):
[docs]
def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
match envs.get_type():
case EnvType.DISCRETE:
return DistributionFunctionFactoryCategorical().create_dist_fn(envs)
case EnvType.CONTINUOUS:
return DistributionFunctionFactoryIndependentGaussians().create_dist_fn(envs)
case _:
raise ValueError(envs.get_type())