# Source code for tianshou.highlevel.params.dist_fn

from abc import ABC, abstractmethod

import torch

from tianshou.highlevel.env import Environments, EnvType
from tianshou.policy.modelfree.pg import TDistributionFunction
from tianshou.utils.string import ToStringMixin


class DistributionFunctionFactory(ToStringMixin, ABC):
    """Abstract factory producing a distribution function for a set of environments.

    A distribution function is a callable that builds a
    ``torch.distributions.Distribution`` from tensor parameters.
    """

    @abstractmethod
    def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
        """Build the distribution function suited to ``envs``.

        :param envs: the environments the resulting function will be used with.
        :return: a callable mapping tensor(s) to a torch distribution.
        """
        ...
class DistributionFunctionFactoryCategorical(DistributionFunctionFactory):
    """Factory for a categorical distribution parameterized by logits, for discrete environments."""

    def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
        """Return the categorical distribution function after validating the environment type.

        :param envs: the environments; their type must be discrete.
        :return: a callable that builds ``torch.distributions.Categorical`` from logits.
        """
        env_type = envs.get_type()
        # NOTE(review): assert_discrete is expected to raise for non-discrete envs — confirm in EnvType
        env_type.assert_discrete(self)
        return self._dist_fn

    @staticmethod
    def _dist_fn(p: torch.Tensor) -> torch.distributions.Distribution:
        # p is interpreted as unnormalized log-probabilities (logits), not probabilities
        return torch.distributions.Categorical(logits=p)
class DistributionFunctionFactoryIndependentGaussians(DistributionFunctionFactory):
    """Factory for independent Gaussians (a diagonal normal), for continuous environments."""

    def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
        """Return the independent-Gaussian distribution function after validating the environment type.

        :param envs: the environments; their type must be continuous.
        :return: a callable that builds ``Independent(Normal(...), 1)`` from its tensor arguments.
        """
        env_type = envs.get_type()
        # NOTE(review): assert_continuous is expected to raise for non-continuous envs — confirm in EnvType
        env_type.assert_continuous(self)
        return self._dist_fn

    @staticmethod
    def _dist_fn(*p: torch.Tensor) -> torch.distributions.Distribution:
        # arguments are forwarded positionally to Normal — presumably (loc, scale); confirm against callers
        base = torch.distributions.Normal(*p)
        # reinterpret the last batch dimension as the event dimension, so log_prob sums over it
        return torch.distributions.Independent(base, 1)
[docs] class DistributionFunctionFactoryDefault(DistributionFunctionFactory):
[docs] def create_dist_fn(self, envs: Environments) -> TDistributionFunction: match envs.get_type(): case EnvType.DISCRETE: return DistributionFunctionFactoryCategorical().create_dist_fn(envs) case EnvType.CONTINUOUS: return DistributionFunctionFactoryIndependentGaussians().create_dist_fn(envs) case _: raise ValueError(envs.get_type())