# Source code for tianshou.highlevel.params.alpha
from abc import ABC, abstractmethod
import numpy as np
import torch
from tianshou.highlevel.env import Environments
from tianshou.highlevel.module.core import TDevice
from tianshou.highlevel.optim import OptimizerFactory
from tianshou.utils.string import ToStringMixin
[docs]
class AutoAlphaFactory(ToStringMixin, ABC):
    """Abstract factory interface for the objects produced by ``create_auto_alpha``.

    Subclasses must implement :meth:`create_auto_alpha`.
    """

    @abstractmethod
    def create_auto_alpha(
        self,
        envs: Environments,
        optim_factory: OptimizerFactory,
        device: TDevice,
    ) -> tuple[float, torch.Tensor, torch.optim.Optimizer]:
        """Create the (float, tensor, optimizer) triple.

        :param envs: the environments the result is created for.
        :param optim_factory: factory made available for building the optimizer.
        :param device: the torch device made available for tensor creation.
        :return: a tuple of a float, a tensor and an optimizer. The
            ``AutoAlphaFactoryDefault`` implementation in this module returns
            them as (target entropy, log-alpha tensor, optimizer for that tensor).
        """
[docs]
class AutoAlphaFactoryDefault(AutoAlphaFactory):
    """Factory that derives the target entropy from the action shape.

    :param lr: learning rate handed to the optimizer factory when the
        optimizer for the log-alpha tensor is created.
    """

    def __init__(self, lr: float = 3e-4):
        self.lr = lr

    def create_auto_alpha(
        self,
        envs: Environments,
        optim_factory: OptimizerFactory,
        device: TDevice,
    ) -> tuple[float, torch.Tensor, torch.optim.Optimizer]:
        """Build the target entropy, the log-alpha tensor and its optimizer.

        :param envs: environments queried via ``get_action_shape()``.
        :param optim_factory: factory whose ``create_optimizer_for_params`` is
            called with the log-alpha tensor and ``self.lr``.
        :param device: device on which the log-alpha tensor is allocated.
        :return: ``(target_entropy, log_alpha, alpha_optim)`` where
            ``target_entropy`` is the negated product of the action shape,
            ``log_alpha`` is a one-element zero tensor with gradients enabled,
            and ``alpha_optim`` is the optimizer created for that tensor.
        """
        # Product over all entries of the action shape.
        action_shape_product = np.prod(envs.get_action_shape())
        entropy_target = float(-action_shape_product)

        # Single trainable value, initialised to zero on the requested device.
        log_alpha_param = torch.zeros(1, requires_grad=True, device=device)

        optimizer = optim_factory.create_optimizer_for_params(
            [log_alpha_param],
            self.lr,
        )
        return entropy_target, log_alpha_param, optimizer