agent#


class A2CAgentFactory(params: TActorCriticParams, sampling_config: SamplingConfig, actor_factory: ActorFactory, critic_factory: CriticFactory, optimizer_factory: OptimizerFactory)[source]#
class ActorCriticAgentFactory(params: TActorCriticParams, sampling_config: SamplingConfig, actor_factory: ActorFactory, critic_factory: CriticFactory, optimizer_factory: OptimizerFactory)[source]#
create_actor_critic_module_opt(envs: Environments, device: str | device, lr: float) → ActorCriticOpt[source]#
class ActorDualCriticsAgentFactory(params: TActorDualCriticsParams, sampling_config: SamplingConfig, actor_factory: ActorFactory, critic1_factory: CriticFactory, critic2_factory: CriticFactory, optim_factory: OptimizerFactory)[source]#
class AgentFactory(sampling_config: SamplingConfig, optim_factory: OptimizerFactory)[source]#

Factory for creating an agent’s policy, its trainer, and its train/test collectors.

create_policy(envs: Environments, device: str | device) → BasePolicy[source]#
create_train_test_collector(policy: BasePolicy, envs: Environments) → tuple[Collector, Collector][source]#
abstract create_trainer(world: World, policy_persistence: PolicyPersistence) → BaseTrainer[source]#
set_policy_wrapper_factory(policy_wrapper_factory: PolicyWrapperFactory | None) → None[source]#
set_trainer_callbacks(callbacks: TrainerCallbacks) → None[source]#
class DDPGAgentFactory(params: DDPGParams, sampling_config: SamplingConfig, actor_factory: ActorFactory, critic_factory: CriticFactory, optim_factory: OptimizerFactory)[source]#
class DQNAgentFactory(params: TDiscreteCriticOnlyParams, sampling_config: SamplingConfig, model_factory: ModuleFactory, optim_factory: OptimizerFactory)[source]#
class DiscreteCriticOnlyAgentFactory(params: TDiscreteCriticOnlyParams, sampling_config: SamplingConfig, model_factory: ModuleFactory, optim_factory: OptimizerFactory)[source]#
class DiscreteSACAgentFactory(params: TActorDualCriticsParams, sampling_config: SamplingConfig, actor_factory: ActorFactory, critic1_factory: CriticFactory, critic2_factory: CriticFactory, optim_factory: OptimizerFactory)[source]#
class IQNAgentFactory(params: TDiscreteCriticOnlyParams, sampling_config: SamplingConfig, model_factory: ModuleFactory, optim_factory: OptimizerFactory)[source]#
class NPGAgentFactory(params: TActorCriticParams, sampling_config: SamplingConfig, actor_factory: ActorFactory, critic_factory: CriticFactory, optimizer_factory: OptimizerFactory)[source]#
class OffPolicyAgentFactory(sampling_config: SamplingConfig, optim_factory: OptimizerFactory)[source]#
create_trainer(world: World, policy_persistence: PolicyPersistence) → OffpolicyTrainer[source]#
class OnPolicyAgentFactory(sampling_config: SamplingConfig, optim_factory: OptimizerFactory)[source]#
create_trainer(world: World, policy_persistence: PolicyPersistence) → OnpolicyTrainer[source]#
class PGAgentFactory(params: PGParams, sampling_config: SamplingConfig, actor_factory: ActorFactory, optim_factory: OptimizerFactory)[source]#
class PPOAgentFactory(params: TActorCriticParams, sampling_config: SamplingConfig, actor_factory: ActorFactory, critic_factory: CriticFactory, optimizer_factory: OptimizerFactory)[source]#
class REDQAgentFactory(params: REDQParams, sampling_config: SamplingConfig, actor_factory: ActorFactory, critic_ensemble_factory: CriticEnsembleFactory, optim_factory: OptimizerFactory)[source]#
class SACAgentFactory(params: TActorDualCriticsParams, sampling_config: SamplingConfig, actor_factory: ActorFactory, critic1_factory: CriticFactory, critic2_factory: CriticFactory, optim_factory: OptimizerFactory)[source]#
class TD3AgentFactory(params: TActorDualCriticsParams, sampling_config: SamplingConfig, actor_factory: ActorFactory, critic1_factory: CriticFactory, critic2_factory: CriticFactory, optim_factory: OptimizerFactory)[source]#
class TRPOAgentFactory(params: TActorCriticParams, sampling_config: SamplingConfig, actor_factory: ActorFactory, critic_factory: CriticFactory, optimizer_factory: OptimizerFactory)[source]#