kitae.algos.collections package
Submodules
kitae.algos.collections.dqn module
Deep Q-Network (DQN)
- class kitae.algos.collections.dqn.DQN(run_name: str, config: AlgoConfig, *, preprocess_fn: Callable | None = None, tabulate: bool = False)
Bases: OffPolicyAgent
Deep Q-Network (DQN). Paper: https://arxiv.org/abs/1312.5602
- explore(observation: Array) → tuple[Array, Array]
Uses the policy to explore the environment.
- Parameters:
observation (ObsType) – An ObsType within the observation_space.
- Returns:
An ActionType within the action_space.
- select_action(observation: Array) → tuple[Array, Array]
Exploits the policy to interact with the environment.
- Parameters:
observation (ObsType) – An ObsType within the observation_space.
- Returns:
An ActionType within the action_space.
- class kitae.algos.collections.dqn.DQNParams(exploration: float = 0.1, gamma: float = 0.99, skip_steps: int = 0, start_step: int = -1)
Bases: AlgoParams
Deep Q-Network parameters.
- exploration: float = 0.1
- gamma: float = 0.99
- skip_steps: int = 0
- start_step: int = -1
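A minimal sketch of how `exploration` could act as the epsilon of epsilon-greedy action selection, assuming `DQN.explore` picks a uniformly random action with that probability and `select_action` always acts greedily. The helper `epsilon_greedy` is hypothetical, and the sketch ignores the second array that `explore` returns.

```python
import random
from typing import Sequence


def epsilon_greedy(q_values: Sequence[float], exploration: float, rng: random.Random) -> int:
    """Hypothetical helper: epsilon-greedy choice over a list of Q-values."""
    if rng.random() < exploration:
        # Explore: uniformly random action index.
        return rng.randrange(len(q_values))
    # Exploit: index of the largest Q-value (the first one on ties).
    return max(range(len(q_values)), key=lambda a: q_values[a])


rng = random.Random(0)
greedy_action = epsilon_greedy([0.1, 0.5, 0.2], exploration=0.0, rng=rng)  # always 1
```

With the default `exploration=0.1`, roughly one action in ten would be random under this assumption.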
- class kitae.algos.collections.dqn.DQNState(qvalue_state: kitae.modules.pytree.TrainState)
Bases: AgentPyTree
- qvalue_state: TrainState
- replace(**updates)
Returns a new object replacing the specified fields with new values.
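`replace` appears to mirror the immutable-update method of Flax struct dataclasses, which delegates to the standard library's `dataclasses.replace`. Assuming that holds, the standard-library analogue below shows the behavior: the original object is left untouched and a modified copy is returned. `ToyState` is a hypothetical stand-in for `DQNState`.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToyState:
    # Hypothetical stand-in for DQNState: a plain int replaces the TrainState field.
    qvalue_state: int


old_state = ToyState(qvalue_state=0)
new_state = replace(old_state, qvalue_state=1)  # old_state is not modified
```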
- class kitae.algos.collections.dqn.DQN_tuple(observation, action, return_)
Bases: tuple
- action
Alias for field number 1
- observation
Alias for field number 0
- return_
Alias for field number 2
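Assuming `return_` holds the standard one-step Q-learning target from the DQN paper, with `gamma` from `DQNParams` and no bootstrapping past a terminal state, it could be computed per transition as below. Whether kitae uses a separate target network or multi-step returns is not stated on this page; the helper name is hypothetical.

```python
from typing import Sequence


def q_learning_target(reward: float, done: bool, next_q_values: Sequence[float],
                      gamma: float = 0.99) -> float:
    """Hypothetical helper: r + gamma * (1 - done) * max_a' Q(s', a')."""
    bootstrap = 0.0 if done else max(next_q_values)
    return reward + gamma * bootstrap


target = q_learning_target(1.0, False, [0.0, 2.0], gamma=0.5)  # 1.0 + 0.5 * 2.0 = 2.0
```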
- kitae.algos.collections.dqn.explore_factory(config: AlgoConfig) → Callable[[PyTreeNode, Array, Array], tuple[Array, Array]]
- kitae.algos.collections.dqn.process_experience_factory(config: AlgoConfig) → Callable[[PyTreeNode, Array, NamedTuple], NamedTuple]
- kitae.algos.collections.dqn.train_state_dqn_factory(key: Array, config: AlgoConfig, *, preprocess_fn: Callable, tabulate: bool = False) → DQNState
- kitae.algos.collections.dqn.update_qvalue_factory(config: AlgoConfig) → Callable[[PyTreeNode, Array, NamedTuple], tuple[PyTreeNode, dict[str, Array]]]
- kitae.algos.collections.dqn.update_step_factory(config: AlgoConfig) → Callable
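Each `*_factory` takes an `AlgoConfig` and returns a function; judging from the signatures, update functions map (state, key, experience) to (new state, dict of metrics). A stripped-down sketch of that closure pattern, using a hypothetical `ToyConfig`, scalar "parameters" in place of PyTrees, and no random key:

```python
from dataclasses import dataclass
from typing import Callable


@dataclass(frozen=True)
class ToyConfig:
    # Hypothetical stand-in for AlgoConfig; only the field this sketch uses.
    learning_rate: float


def toy_update_step_factory(
    config: ToyConfig,
) -> Callable[[float, float], tuple[float, dict[str, float]]]:
    # Read hyperparameters once; the closure captures them, so the returned
    # function depends only on its arguments (the property jax.jit relies on).
    lr = config.learning_rate

    def update_step(param: float, grad: float) -> tuple[float, dict[str, float]]:
        # Plain gradient-descent step, returning (new state, info dict).
        return param - lr * grad, {"grad_abs": abs(grad)}

    return update_step


update_step = toy_update_step_factory(ToyConfig(learning_rate=0.1))
new_param, info = update_step(1.0, 2.0)
```

Building the function once per config keeps the hot loop free of config lookups and gives a single pure function to compile.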
kitae.algos.collections.ppo module
Proximal Policy Optimization (PPO)
- class kitae.algos.collections.ppo.PPO(run_name: str, config: AlgoConfig, *, preprocess_fn: Callable | None = None, tabulate: bool = False)
Bases: OnPolicyAgent
Proximal Policy Optimization (PPO).
Paper: https://arxiv.org/abs/1707.06347
Implementation details: https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/
- class kitae.algos.collections.ppo.PPOParams(gamma: float = 0.99, _lambda: float = 0.95, clip_eps: float = 0.2, entropy_coef: float = 0.1, value_coef: float = 0.5, normalize: bool = True)
Bases: AlgoParams
Proximal Policy Optimization parameters.
- clip_eps: float = 0.2
- entropy_coef: float = 0.1
- gamma: float = 0.99
- _lambda: float = 0.95
- normalize: bool = True
- value_coef: float = 0.5
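The three coefficients presumably enter the clipped-surrogate objective from the PPO paper. A per-sample sketch, assuming a squared-error value loss and per-batch advantage normalization for `normalize=True` (neither is confirmed by this page); both function names are hypothetical.

```python
import math


def ppo_sample_loss(log_prob: float, old_log_prob: float, advantage: float,
                    value: float, target: float, entropy: float,
                    clip_eps: float = 0.2, entropy_coef: float = 0.1,
                    value_coef: float = 0.5) -> float:
    """Hypothetical per-sample PPO loss (to be minimized)."""
    ratio = math.exp(log_prob - old_log_prob)
    clipped_ratio = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
    # Pessimistic (clipped) surrogate, negated so that lower is better.
    policy_loss = -min(ratio * advantage, clipped_ratio * advantage)
    value_loss = (value - target) ** 2
    return policy_loss + value_coef * value_loss - entropy_coef * entropy


def normalize(advantages: list[float], eps: float = 1e-8) -> list[float]:
    """Zero-mean, unit-variance advantages (assumed meaning of `normalize`)."""
    mean = sum(advantages) / len(advantages)
    std = math.sqrt(sum((a - mean) ** 2 for a in advantages) / len(advantages))
    return [(a - mean) / (std + eps) for a in advantages]
```

When the new policy doubles an action's probability (ratio 2) on a positive advantage, the clipped term caps the gain at `1 + clip_eps`, which is what keeps each update close to the behavior policy.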
- class kitae.algos.collections.ppo.PPOState(encoder_state: kitae.modules.pytree.TrainState, policy_state: kitae.modules.pytree.TrainState, value_state: kitae.modules.pytree.TrainState)
Bases: AgentPyTree
- encoder_state: TrainState
- policy_state: TrainState
- replace(**updates)
Returns a new object replacing the specified fields with new values.
- value_state: TrainState
- class kitae.algos.collections.ppo.PPO_tuple(observation, action, log_prob, gae, target, value)
Bases: tuple
- action
Alias for field number 1
- gae
Alias for field number 3
- log_prob
Alias for field number 2
- observation
Alias for field number 0
- target
Alias for field number 4
- value
Alias for field number 5
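The `gae` and `target` fields suggest generalized advantage estimation with `gamma` and `_lambda` from `PPOParams`, and value targets computed as advantage plus value, as described in the implementation-details post linked in the PPO class docstring. A sketch for one trajectory segment, assuming `dones[t]` marks that the episode ended after step `t` and `next_value` is the value estimate of the state following the segment; `compute_gae` is hypothetical.

```python
def compute_gae(rewards: list[float], values: list[float], dones: list[bool],
                next_value: float, gamma: float = 0.99, lam: float = 0.95
                ) -> tuple[list[float], list[float]]:
    """Hypothetical helper returning (advantages, value targets)."""
    advantages = [0.0] * len(rewards)
    gae = 0.0
    # Sweep backwards so each step reuses the advantage of the step after it.
    for t in reversed(range(len(rewards))):
        next_v = next_value if t == len(rewards) - 1 else values[t + 1]
        not_done = 0.0 if dones[t] else 1.0
        delta = rewards[t] + gamma * next_v * not_done - values[t]
        gae = delta + gamma * lam * not_done * gae
        advantages[t] = gae
    # Value targets (returns) are advantages plus the baseline values.
    targets = [a + v for a, v in zip(advantages, values)]
    return advantages, targets


adv, tgt = compute_gae([1.0, 1.0], [0.0, 0.0], [False, False], 0.0, gamma=1.0, lam=1.0)
```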
- kitae.algos.collections.ppo.explore_factory(config: AlgoConfig) → Callable[[PyTreeNode, Array, Array], tuple[Array, Array]]
- kitae.algos.collections.ppo.process_experience_factory(config: AlgoConfig) → Callable[[PyTreeNode, Array, NamedTuple], NamedTuple]
Returns a function that processes experience PPO-style.
- kitae.algos.collections.ppo.train_state_ppo_factory(key: Array, config: AlgoConfig, *, preprocess_fn: Callable, tabulate: bool = False) → PPOState
- kitae.algos.collections.ppo.update_policy_value_factory(config: AlgoConfig) → Callable[[PyTreeNode, Array, NamedTuple], tuple[PyTreeNode, dict[str, Array]]]
- kitae.algos.collections.ppo.update_step_factory(config: AlgoConfig) → Callable[[PyTreeNode, Array, NamedTuple], tuple[PyTreeNode, dict[str, Array]]]
kitae.algos.collections.sac module
Soft Actor Critic (SAC)
- class kitae.algos.collections.sac.SAC(run_name: str, config: AlgoConfig, *, preprocess_fn: Callable | None = None, tabulate: bool = False)
Bases: OffPolicyAgent
Soft Actor Critic (SAC). Paper: https://arxiv.org/abs/1812.05905
- explore(observation: Array) → tuple[Array, Array]
Uses the policy to explore the environment.
- Parameters:
observation (ObsType) – An ObsType within the observation_space.
- Returns:
An ActionType within the action_space.
- select_action(observation: Array) → tuple[Array, Array]
Exploits the policy to interact with the environment.
- Parameters:
observation (ObsType) – An ObsType within the observation_space.
- Returns:
An ActionType within the action_space.
- class kitae.algos.collections.sac.SACParams(gamma: float = 0.99, tau: float = 0.005, log_std_min: float = -20, log_std_max: float = 5, initial_alpha: float = 0.1, skip_steps: int = 1, start_step: int = -1)
Bases: AlgoParams
Soft Actor Critic parameters.
- gamma: float = 0.99
- initial_alpha: float = 0.1
- log_std_max: float = 5
- log_std_min: float = -20
- skip_steps: int = 1
- start_step: int = -1
- tau: float = 0.005
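Two common uses of these parameters, sketched under the assumption that kitae follows the usual SAC recipe: `log_std_min` and `log_std_max` bound the Gaussian policy's log standard deviation, and `tau` is the Polyak coefficient for the target critic. Flat lists of floats stand in for parameter PyTrees; both helpers are hypothetical.

```python
def clamp_log_std(log_std: float, log_std_min: float = -20.0,
                  log_std_max: float = 5.0) -> float:
    """Keep the policy's log standard deviation inside [log_std_min, log_std_max]."""
    return max(log_std_min, min(log_std, log_std_max))


def soft_update(target: list[float], online: list[float], tau: float = 0.005) -> list[float]:
    """Polyak averaging: target <- tau * online + (1 - tau) * target."""
    return [tau * o + (1.0 - tau) * t for t, o in zip(target, online)]
```

A small `tau` such as 0.005 makes the target critic trail the online critic slowly, which stabilizes the bootstrapped targets.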
- class kitae.algos.collections.sac.SACState(policy_state: kitae.modules.pytree.TrainState, qvalue_state: kitae.modules.pytree.TrainState, alpha_state: kitae.modules.pytree.TrainState)
Bases: AgentPyTree
- alpha_state: TrainState
- policy_state: TrainState
- qvalue_state: TrainState
- replace(**updates)
Returns a new object replacing the specified fields with new values.
- class kitae.algos.collections.sac.SAC_tuple(observation, action, target)
Bases: tuple
- action
Alias for field number 1
- observation
Alias for field number 0
- target
Alias for field number 2
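Assuming `target` in `SAC_tuple` is the soft Bellman target of the SAC paper, computed with two critics, a next action sampled from the current policy, and `alpha` the current temperature, one transition's target could look like this. The twin-critic setup and the helper name are assumptions.

```python
def sac_target(reward: float, done: bool, next_q1: float, next_q2: float,
               next_log_prob: float, alpha: float, gamma: float = 0.99) -> float:
    """Hypothetical soft Bellman target using the smaller of two critic estimates."""
    # Entropy bonus: subtracting alpha * log pi(a'|s') rewards stochastic policies.
    soft_value = min(next_q1, next_q2) - alpha * next_log_prob
    return reward + gamma * (0.0 if done else soft_value)


target = sac_target(1.0, False, 2.0, 3.0, next_log_prob=-1.0, alpha=0.5, gamma=1.0)
```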
- kitae.algos.collections.sac.explore_factory(config: AlgoConfig) → Callable[[PyTreeNode, Array, Array], tuple[Array, Array]]
- kitae.algos.collections.sac.process_experience_factory(config: AlgoConfig) → Callable[[PyTreeNode, Array, NamedTuple], NamedTuple]
- kitae.algos.collections.sac.train_state_sac_factory(key: Array, config: AlgoConfig, *, preprocess_fn: Callable, tabulate: bool = False) → SACState
- kitae.algos.collections.sac.update_alpha_factory(config: AlgoConfig) → Callable[[PyTreeNode, Array, NamedTuple], tuple[PyTreeNode, dict[str, Array]]]
- kitae.algos.collections.sac.update_policy_factory(config: AlgoConfig) → Callable[[PyTreeNode, Array, NamedTuple], tuple[PyTreeNode, dict[str, Array]]]
- kitae.algos.collections.sac.update_qvalue_factory(config: AlgoConfig) → Callable[[PyTreeNode, Array, NamedTuple], tuple[PyTreeNode, dict[str, Array]]]
- kitae.algos.collections.sac.update_step_factory(config: AlgoConfig) → Callable[[PyTreeNode, Array, NamedTuple], tuple[PyTreeNode, dict[str, Array]]]
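`update_alpha_factory` presumably builds the automatic temperature adjustment of the SAC paper, which tunes alpha so that policy entropy tracks a target value. A scalar sketch, assuming alpha is parameterized through its logarithm and initialized from `initial_alpha`; the loss form follows the paper, while the `target_entropy` value (commonly minus the action dimension), the learning rate, and the helper name are assumptions.

```python
import math


def alpha_step(log_alpha: float, log_probs: list[float], target_entropy: float,
               lr: float = 3e-4) -> float:
    """Hypothetical gradient step on log(alpha) for the SAC temperature loss.

    Loss: J = -alpha * mean(log_prob + target_entropy), with alpha = exp(log_alpha).
    """
    alpha = math.exp(log_alpha)
    mean_term = sum(lp + target_entropy for lp in log_probs) / len(log_probs)
    grad = -alpha * mean_term  # dJ/d(log_alpha)
    return log_alpha - lr * grad


log_alpha0 = math.log(0.1)  # initial_alpha = 0.1
# Policy entropy (-log_prob = -2.0) below the target (-1.0): alpha should grow.
raised = alpha_step(log_alpha0, [2.0], target_entropy=-1.0)
# Entropy (0.0) above the target (-1.0): alpha should shrink.
lowered = alpha_step(log_alpha0, [0.0], target_entropy=-1.0)
```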
kitae.algos.collections.td3 module
Twin Delayed Deep Deterministic Policy Gradient (TD3)
- class kitae.algos.collections.td3.TD3(run_name: str, config: AlgoConfig, *, preprocess_fn: Callable | None = None, tabulate: bool = False)
Bases: OffPolicyAgent
Twin Delayed Deep Deterministic Policy Gradient (TD3), an extension of DDPG. DDPG paper: https://arxiv.org/abs/1509.02971
- explore(observation: Array) → tuple[Array, Array]
Uses the policy to explore the environment.
- Parameters:
observation (ObsType) – An ObsType within the observation_space.
- Returns:
An ActionType within the action_space.
- select_action(observation: Array) → tuple[Array, Array]
Exploits the policy to interact with the environment.
- Parameters:
observation (ObsType) – An ObsType within the observation_space.
- Returns:
An ActionType within the action_space.
- update(buffer: OffPolicyBuffer) → dict
- class kitae.algos.collections.td3.TD3Params(gamma: float = 0.99, tau: float = 0.005, action_noise: float = 0.1, policy_update_frequency: int = 1, target_noise_std: float = 0.2, target_noise_clip: float = 0.5, skip_steps: int = 1, start_step: int = -1)
Bases: AlgoParams
TD3 parameters.
- action_noise: float = 0.1
- gamma: float = 0.99
- policy_update_frequency: int = 1
- skip_steps: int = 1
- start_step: int = -1
- target_noise_clip: float = 0.5
- target_noise_std: float = 0.2
- tau: float = 0.005
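Assuming `action_noise` is the standard deviation of the Gaussian noise that `explore` adds to the deterministic action, and `policy_update_frequency` is the TD3 "delayed update" interval (the actor is updated once per that many critic updates; the TD3 paper uses 2, so the default of 1 here would disable the delay), the two mechanisms can be sketched as follows. The action bounds and helper names are hypothetical.

```python
import random


def noisy_action(action: list[float], action_noise: float, low: float, high: float,
                 rng: random.Random) -> list[float]:
    """Hypothetical exploration step: Gaussian noise, then clip to the action bounds."""
    return [min(max(a + rng.gauss(0.0, action_noise), low), high) for a in action]


def policy_update_steps(num_critic_updates: int, policy_update_frequency: int) -> list[int]:
    """Critic-update indices (1-based) after which the actor is also updated."""
    return [s for s in range(1, num_critic_updates + 1) if s % policy_update_frequency == 0]


rng = random.Random(0)
explored = noisy_action([0.9, -0.9], action_noise=0.1, low=-1.0, high=1.0, rng=rng)
```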
- class kitae.algos.collections.td3.TD3State(policy_state: kitae.modules.pytree.TrainState, qvalue_state: kitae.modules.pytree.TrainState)
Bases: AgentPyTree
- policy_state: TrainState
- qvalue_state: TrainState
- replace(**updates)
Returns a new object replacing the specified fields with new values.
- class kitae.algos.collections.td3.TD3_tuple(observation, action, target)
Bases: tuple
- action
Alias for field number 1
- observation
Alias for field number 0
- target
Alias for field number 2
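The `target` field presumably follows the TD3 paper's target, which combines target policy smoothing (`target_noise_std`, `target_noise_clip`) with clipped double-Q learning. A sketch for one transition, assuming two target critics, actions bounded in [-1, 1], and `next_action` produced by the target policy; plain callables stand in for the critic networks and all names are hypothetical.

```python
import random
from typing import Callable

QFn = Callable[[list[float]], float]


def td3_target(reward: float, done: bool, next_action: list[float],
               target_q1: QFn, target_q2: QFn, rng: random.Random,
               gamma: float = 0.99, target_noise_std: float = 0.2,
               target_noise_clip: float = 0.5, low: float = -1.0,
               high: float = 1.0) -> float:
    """Hypothetical TD3 target for one transition."""
    # Target policy smoothing: clipped Gaussian noise on the target action.
    smoothed = []
    for a in next_action:
        noise = min(max(rng.gauss(0.0, target_noise_std), -target_noise_clip), target_noise_clip)
        smoothed.append(min(max(a + noise, low), high))
    # Clipped double-Q: take the smaller of the two target critics.
    next_q = min(target_q1(smoothed), target_q2(smoothed))
    return reward + gamma * (0.0 if done else next_q)


rng = random.Random(0)
t = td3_target(1.0, False, [0.0], lambda a: 3.0, lambda a: 2.0, rng, gamma=0.5)
```

Taking the minimum of the two critics counters the overestimation bias that a single bootstrapped critic tends to accumulate.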
- kitae.algos.collections.td3.explore_factory(config: AlgoConfig) → Callable[[PyTreeNode, Array, Array], tuple[Array, Array]]
- kitae.algos.collections.td3.process_experience_factory(config: AlgoConfig) → Callable[[PyTreeNode, Array, NamedTuple], NamedTuple]
- kitae.algos.collections.td3.train_state_ddpg_factory(key: Array, config: AlgoConfig, *, preprocess_fn: Callable, tabulate: bool = False) → TD3State
- kitae.algos.collections.td3.update_policy_factory(config: AlgoConfig) → Callable[[PyTreeNode, Array, NamedTuple], tuple[PyTreeNode, dict[str, Array]]]
- kitae.algos.collections.td3.update_qvalue_factory(config: AlgoConfig) → Callable[[PyTreeNode, Array, NamedTuple], tuple[PyTreeNode, dict[str, Array]]]
- kitae.algos.collections.td3.update_step_factory(config: AlgoConfig) → Callable[[PyTreeNode, Array, NamedTuple], tuple[PyTreeNode, dict[str, Array]]]