discrete_crr


class DiscreteCRRPolicy(*, actor: Module, critic: Module, optim: Optimizer, action_space: Discrete, discount_factor: float = 0.99, policy_improvement_mode: Literal['exp', 'binary', 'all'] = 'exp', ratio_upper_bound: float = 20.0, beta: float = 1.0, min_q_weight: float = 10.0, target_update_freq: int = 0, reward_normalization: bool = False, observation_space: Space | None = None, lr_scheduler: LRScheduler | MultipleLRSchedulers | None = None)

Implementation of discrete Critic Regularized Regression. arXiv:2006.15134.

Parameters:
  • actor – the actor network following the rules in BasePolicy. (s -> logits)

  • critic – the action-value critic (i.e., Q function) network. (s -> Q(s, *))

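  • optim – a torch.optim.Optimizer that optimizes both the actor and the critic.

  • action_space – the environment’s action space (a Discrete space).

  • discount_factor – the discount factor, in [0, 1].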

  • policy_improvement_mode (str) – type of the weight function f. Possible values: “binary”, “exp”, “all”.

  • ratio_upper_bound – when policy_improvement_mode is “exp”, the value of the exp function is upper-bounded by this parameter.

  • beta – when policy_improvement_mode is “exp”, the advantage is divided by beta inside the exp function, i.e. the weight is exp(A / beta).

  • min_q_weight – weight of the CQL loss/regularizer. Defaults to 10.

  • target_update_freq – the target network update frequency (0 if you do not use the target network).

  • reward_normalization – if True, normalizes the returns by subtracting the running mean and dividing by the running standard deviation. This can be detrimental to performance; see the TODO in process_fn.

  • observation_space – the environment’s observation space.

  • lr_scheduler – if not None, will be called in policy.update().
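To make the three policy_improvement_mode options concrete, here is a minimal PyTorch sketch of the weight f that multiplies log π(a|s) in the actor loss. It assumes the advantage is A(s, a) = Q(s, a) - Σ_a' π(a'|s) Q(s, a'), computed exactly over the discrete actions. `crr_weight` and `advantage` are hypothetical helpers for illustration, not part of Tianshou's API.

```python
import torch


def advantage(q: torch.Tensor, logits: torch.Tensor, act: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: A(s, a) = Q(s, a) - sum_a' pi(a'|s) Q(s, a').

    q: [B, n_actions] Q-values, logits: [B, n_actions] actor output,
    act: [B] dataset actions (long).
    """
    probs = torch.softmax(logits, dim=-1)
    v = (probs * q).sum(dim=-1)                    # V(s) under the current policy, [B]
    qa = q.gather(1, act.unsqueeze(1)).squeeze(1)  # Q(s, a) of the dataset actions, [B]
    return qa - v


def crr_weight(
    adv: torch.Tensor,
    mode: str = "exp",
    beta: float = 1.0,
    ratio_upper_bound: float = 20.0,
) -> torch.Tensor:
    """Hypothetical helper: per-sample weight f for the CRR actor loss."""
    if mode == "binary":
        # 1 where the dataset action beats the policy's average, else 0
        return (adv > 0).float()
    if mode == "exp":
        # exp(A / beta), clipped from above by ratio_upper_bound
        return torch.exp(adv / beta).clamp(max=ratio_upper_bound)
    if mode == "all":
        # uniform weight: every dataset action is imitated (behavior cloning)
        return torch.ones_like(adv)
    raise ValueError(f"unknown policy_improvement_mode: {mode!r}")
```

Lowering `beta` sharpens the exp weights toward high-advantage actions, while `ratio_upper_bound` keeps a single large advantage from dominating the batch.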

See also

Please refer to PGPolicy for more detailed explanation.

learn(batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TDiscreteCRRTrainingStats

Update policy with a given batch of data.

Returns:

A dataclass object, including the data needed to be logged (e.g., loss).
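The three losses reported in DiscreteCRRTrainingStats (actor_loss, critic_loss, cql_loss) can be illustrated with a self-contained sketch of one update computed from plain tensors. This is not Tianshou's implementation: the function name `crr_losses`, the target networks `actor_old`/`critic_old`, the MSE TD loss, the fixed “exp” mode, and the combination `actor_loss + critic_loss + min_q_weight * cql_loss` are all assumptions made for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical


def crr_losses(
    actor: nn.Module,
    critic: nn.Module,
    actor_old: nn.Module,   # hypothetical target actor
    critic_old: nn.Module,  # hypothetical target critic
    obs: torch.Tensor,       # [B, obs_dim]
    act: torch.Tensor,       # [B], long
    rew: torch.Tensor,       # [B]
    obs_next: torch.Tensor,  # [B, obs_dim]
    done: torch.Tensor,      # [B], 1.0 if terminal else 0.0
    gamma: float = 0.99,
    beta: float = 1.0,
    ratio_upper_bound: float = 20.0,
    min_q_weight: float = 10.0,
):
    q = critic(obs)                                # Q(s, *), [B, n_actions]
    qa = q.gather(1, act.unsqueeze(1)).squeeze(1)  # Q(s, a) of dataset actions, [B]

    # Critic: one-step TD target from the target networks' expected Q under pi_old
    with torch.no_grad():
        probs_next = torch.softmax(actor_old(obs_next), dim=-1)
        v_next = (probs_next * critic_old(obs_next)).sum(-1)
        target = rew + gamma * (1.0 - done) * v_next
    critic_loss = 0.5 * F.mse_loss(qa, target)

    # Actor: advantage-weighted log-likelihood of dataset actions ("exp" mode)
    dist = Categorical(logits=actor(obs))
    with torch.no_grad():
        adv = qa - (dist.probs * q).sum(-1)
        weight = torch.exp(adv / beta).clamp(max=ratio_upper_bound)
    actor_loss = -(dist.log_prob(act) * weight).mean()  # both factors are [B]

    # CQL regularizer: lower logsumexp_a Q(s, a), raise Q of the dataset action
    cql_loss = (torch.logsumexp(q, dim=1) - qa).mean()

    total = actor_loss + critic_loss + min_q_weight * cql_loss
    return total, actor_loss, critic_loss, cql_loss
```

The advantage weights are computed under `torch.no_grad()`, so the actor gradient flows only through log π(a|s); the critic is trained by the TD and CQL terms.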

Note

To distinguish between the collecting, updating, and testing states, check the policy state via self.training and self.updating. Please refer to States for policy for a more detailed explanation.

Warning

If you use torch.distributions.Normal or torch.distributions.Categorical to calculate the log_prob, be careful about the shape: a Categorical distribution gives a “[batch_size]” shape, while a Normal distribution gives a “[batch_size, 1]” shape. Automatic broadcasting of numerical operations on torch tensors will silently amplify this error.
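The shape mismatch described in this warning is easy to reproduce. The snippet below uses plain PyTorch (nothing Tianshou-specific) and shows how multiplying a “[batch_size]” log-probability by a “[batch_size, 1]” weight silently yields a “[batch_size, batch_size]” tensor, and how squeezing the weight restores element-wise behavior.

```python
import torch
from torch.distributions import Categorical, Normal

batch_size = 4

# Categorical over 3 discrete actions: log_prob has shape [batch_size]
cat = Categorical(logits=torch.zeros(batch_size, 3))
cat_logp = cat.log_prob(torch.zeros(batch_size, dtype=torch.long))

# Normal over a 1-D continuous action: log_prob keeps the trailing dim, [batch_size, 1]
normal = Normal(torch.zeros(batch_size, 1), torch.ones(batch_size, 1))
normal_logp = normal.log_prob(torch.zeros(batch_size, 1))

# A per-sample weight kept as [batch_size, 1], e.g. Q-values from gather()
weight = torch.ones(batch_size, 1)

wrong = cat_logp * weight              # broadcasts to [batch_size, batch_size]
right = cat_logp * weight.squeeze(-1)  # element-wise, [batch_size]
```

Taking `.mean()` of `wrong` still returns a scalar, which is why this bug rarely raises an error and instead shows up as degraded learning.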

sync_weight() -> None
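The documentation does not describe sync_weight further. Given the target_update_freq parameter, one plausible reading is a hard copy of the online actor and critic into target copies every target_update_freq updates, with no separate target networks when the frequency is 0. The sketch below assumes exactly that; `TargetSyncer`, `actor_old`, `critic_old` and `step` are hypothetical names, not Tianshou's API.

```python
import copy

import torch.nn as nn


class TargetSyncer:
    """Hypothetical helper: hard-copies online weights into target networks."""

    def __init__(self, actor: nn.Module, critic: nn.Module, target_update_freq: int):
        self.actor, self.critic = actor, critic
        self.freq = target_update_freq
        self.use_target = target_update_freq > 0
        # With freq == 0 the "target" is simply the online network itself
        self.actor_old = copy.deepcopy(actor) if self.use_target else actor
        self.critic_old = copy.deepcopy(critic) if self.use_target else critic
        self.iter = 0

    def sync_weight(self) -> None:
        # Hard update: overwrite target parameters with the online ones
        self.actor_old.load_state_dict(self.actor.state_dict())
        self.critic_old.load_state_dict(self.critic.state_dict())

    def step(self) -> None:
        """Call once per update; syncs on every freq-th call, starting with the first."""
        if self.use_target and self.iter % self.freq == 0:
            self.sync_weight()
        self.iter += 1
```

Syncing at the start of an update (rather than after it) means the very first update already uses targets identical to the freshly initialized online networks.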
class DiscreteCRRTrainingStats(actor_loss: float, critic_loss: float, cql_loss: float, *, train_time: float = 0.0, smoothed_loss: dict = <factory>, loss: tianshou.data.stats.SequenceSummaryStats)
actor_loss: float
cql_loss: float
critic_loss: float