kitae.algos.collections package#

Submodules#

kitae.algos.collections.dqn module#

Deep Q-Network (DQN)

class kitae.algos.collections.dqn.DQN(run_name: str, config: AlgoConfig, *, preprocess_fn: Callable | None = None, tabulate: bool = False)#

Bases: OffPolicyAgent

Deep Q-Network (DQN). Paper: https://arxiv.org/abs/1312.5602

explore(observation: Array) tuple[Array, Array]#

Uses the policy to explore the environment.

Parameters:

observation (ObsType) – An ObsType within the observation_space.

Returns:

An ActionType within the action_space.

select_action(observation: Array) tuple[Array, Array]#

Exploits the policy to interact with the environment.

Parameters:

observation (ObsType) – An ObsType within the observation_space.

Returns:

An ActionType within the action_space.
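
A minimal rollout sketch contrasting the two methods; agent is assumed to be an already-constructed DQN instance and env a Gymnasium-style environment, both created elsewhere:

    import numpy as np

    def rollout(agent, env, train: bool = True, max_steps: int = 1000) -> float:
        """Collect one episode, exploring while training and exploiting otherwise."""
        observation, _ = env.reset()
        episode_return = 0.0
        for _ in range(max_steps):
            # explore() adds exploration noise; select_action() acts greedily.
            action, _ = agent.explore(observation) if train else agent.select_action(observation)
            observation, reward, terminated, truncated, _ = env.step(np.asarray(action))
            episode_return += float(reward)
            if terminated or truncated:
                break
        return episode_return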

class kitae.algos.collections.dqn.DQNParams(exploration: float = 0.1, gamma: float = 0.99, skip_steps: int = 0, start_step: int = -1)#

Bases: AlgoParams

Deep Q-Network parameters.

exploration: float = 0.1#
gamma: float = 0.99#
skip_steps: int = 0#
start_step: int = -1#
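
The parameters are plain dataclass-style fields; a short sketch of overriding the defaults shown in the signature above:

    from kitae.algos.collections.dqn import DQNParams

    # Only the exploration rate and discount are overridden here;
    # skip_steps and start_step keep their documented defaults.
    params = DQNParams(exploration=0.05, gamma=0.997)
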
class kitae.algos.collections.dqn.DQNState(qvalue_state: kitae.modules.pytree.TrainState)#

Bases: AgentPyTree

qvalue_state: TrainState#
replace(**updates)#

Returns a new object replacing the specified fields with new values.

class kitae.algos.collections.dqn.DQN_tuple(observation, action, return_)#

Bases: tuple

action#

Alias for field number 1

observation#

Alias for field number 0

return_#

Alias for field number 2

kitae.algos.collections.dqn.explore_factory(config: AlgoConfig) Callable[[PyTreeNode, Array, Array], tuple[Array, Array]]#
kitae.algos.collections.dqn.process_experience_factory(config: AlgoConfig) Callable[[PyTreeNode, Array, NamedTuple], NamedTuple]#
kitae.algos.collections.dqn.train_state_dqn_factory(key: Array, config: AlgoConfig, *, preprocess_fn: Callable, tabulate: bool = False) DQNState#
kitae.algos.collections.dqn.update_qvalue_factory(config: AlgoConfig) Callable[[PyTreeNode, Array, NamedTuple], tuple[PyTreeNode, dict[str, Array]]]#
kitae.algos.collections.dqn.update_step_factory(config: AlgoConfig) Callable#
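
explore_factory builds the agent's action-selection callable from the config; its exact closure is not reproduced here. As an illustration of the behaviour the exploration parameter controls, a standalone epsilon-greedy sketch in JAX:

    import jax
    import jax.numpy as jnp

    def epsilon_greedy(key: jax.Array, qvalues: jax.Array, epsilon: float) -> jax.Array:
        """Take the greedy action with probability 1 - epsilon, a uniform random one otherwise."""
        explore_key, action_key = jax.random.split(key)
        greedy = jnp.argmax(qvalues, axis=-1)
        random = jax.random.randint(action_key, greedy.shape, 0, qvalues.shape[-1])
        return jnp.where(jax.random.uniform(explore_key, greedy.shape) < epsilon, random, greedy)

    action = epsilon_greedy(jax.random.PRNGKey(0), jnp.array([0.1, 0.7, 0.2]), epsilon=0.1)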

kitae.algos.collections.ppo module#

Proximal Policy Optimization (PPO)

class kitae.algos.collections.ppo.PPO(run_name: str, config: AlgoConfig, *, preprocess_fn: Callable | None = None, tabulate: bool = False)#

Bases: OnPolicyAgent

Proximal Policy Optimization (PPO).
Paper: https://arxiv.org/abs/1707.06347
Implementation details: https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/

class kitae.algos.collections.ppo.PPOParams(gamma: float = 0.99, _lambda: float = 0.95, clip_eps: float = 0.2, entropy_coef: float = 0.1, value_coef: float = 0.5, normalize: bool = True)#

Bases: AlgoParams

Proximal Policy Optimization parameters.

clip_eps: float = 0.2#
entropy_coef: float = 0.1#
gamma: float = 0.99#
normalize: bool = True#
value_coef: float = 0.5#
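
clip_eps, value_coef and entropy_coef weight the usual PPO objective. A generic sketch of that loss, not necessarily kitae's exact formulation (which lives in update_policy_value_factory below):

    import jax.numpy as jnp

    def ppo_loss(log_prob, old_log_prob, gae, value, target, entropy,
                 clip_eps=0.2, value_coef=0.5, entropy_coef=0.1):
        """Clipped surrogate policy loss + weighted value loss - entropy bonus."""
        ratio = jnp.exp(log_prob - old_log_prob)
        clipped_ratio = jnp.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps)
        policy_loss = -jnp.minimum(ratio * gae, clipped_ratio * gae).mean()
        value_loss = jnp.square(value - target).mean()
        return policy_loss + value_coef * value_loss - entropy_coef * entropy.mean()
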
class kitae.algos.collections.ppo.PPOState(encoder_state: kitae.modules.pytree.TrainState, policy_state: kitae.modules.pytree.TrainState, value_state: kitae.modules.pytree.TrainState)#

Bases: AgentPyTree

encoder_state: TrainState#
policy_state: TrainState#
replace(**updates)#

Returns a new object replacing the specified fields with new values.

value_state: TrainState#
class kitae.algos.collections.ppo.PPO_tuple(observation, action, log_prob, gae, target, value)#

Bases: tuple

action#

Alias for field number 1

gae#

Alias for field number 3

log_prob#

Alias for field number 2

observation#

Alias for field number 0

target#

Alias for field number 4

value#

Alias for field number 5

kitae.algos.collections.ppo.explore_factory(config: AlgoConfig) Callable[[PyTreeNode, Array, Array], tuple[Array, Array]]#
kitae.algos.collections.ppo.process_experience_factory(config: AlgoConfig) Callable[[PyTreeNode, Array, NamedTuple], NamedTuple]#

Process experience PPO-style.

kitae.algos.collections.ppo.train_state_ppo_factory(key: Array, config: AlgoConfig, *, preprocess_fn: Callable, tabulate: bool = False) PPOState#
kitae.algos.collections.ppo.update_policy_value_factory(config: AlgoConfig) Callable[[PyTreeNode, Array, NamedTuple], tuple[PyTreeNode, dict[str, Array]]]#
kitae.algos.collections.ppo.update_step_factory(config: AlgoConfig) Callable[[PyTreeNode, Array, NamedTuple], tuple[PyTreeNode, dict[str, Array]]]#
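
process_experience_factory fills the gae and target fields of PPO_tuple. A NumPy sketch of Generalized Advantage Estimation, the standard way those quantities are derived from gamma and _lambda (kitae's implementation may differ in details such as batching):

    import numpy as np

    def compute_gae(rewards, values, dones, last_value, gamma=0.99, lam=0.95):
        """Return (gae, target) for a single trajectory, iterating backwards in time."""
        gae = np.zeros_like(rewards)
        next_value, next_advantage = last_value, 0.0
        for t in reversed(range(len(rewards))):
            not_done = 1.0 - dones[t]
            delta = rewards[t] + gamma * next_value * not_done - values[t]
            next_advantage = delta + gamma * lam * not_done * next_advantage
            gae[t] = next_advantage
            next_value = values[t]
        return gae, gae + values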

kitae.algos.collections.sac module#

Soft Actor Critic (SAC)

class kitae.algos.collections.sac.SAC(run_name: str, config: AlgoConfig, *, preprocess_fn: Callable | None = None, tabulate: bool = False)#

Bases: OffPolicyAgent

Soft Actor Critic (SAC). Paper: https://arxiv.org/abs/1812.05905

explore(observation: Array) tuple[Array, Array]#

Uses the policy to explore the environment.

Parameters:

observation (ObsType) – An ObsType within the observation_space.

Returns:

An ActionType within the action_space.

select_action(observation: Array) tuple[Array, Array]#

Exploits the policy to interact with the environment.

Parameters:

observation (ObsType) – An ObsType within the observation_space.

Returns:

An ActionType within the action_space.

class kitae.algos.collections.sac.SACParams(gamma: float = 0.99, tau: float = 0.005, log_std_min: float = -20, log_std_max: float = 5, initial_alpha: float = 0.1, skip_steps: int = 1, start_step: int = -1)#

Bases: AlgoParams

Soft Actor Critic parameters.

gamma: float = 0.99#
initial_alpha: float = 0.1#
log_std_max: float = 5#
log_std_min: float = -20#
skip_steps: int = 1#
start_step: int = -1#
tau: float = 0.005#
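
log_std_min and log_std_max bound the policy's log standard deviation before sampling. A sketch of the tanh-squashed Gaussian parameterization commonly used in SAC (assumed here for illustration, not copied from kitae's policy module):

    import jax
    import jax.numpy as jnp

    def sample_squashed_gaussian(key, mean, log_std, log_std_min=-20.0, log_std_max=5.0):
        """Sample a tanh-squashed Gaussian action and its log-probability."""
        log_std = jnp.clip(log_std, log_std_min, log_std_max)
        noise = jax.random.normal(key, mean.shape)
        action = jnp.tanh(mean + jnp.exp(log_std) * noise)
        # Gaussian log-density plus the tanh change-of-variables correction.
        log_prob = -0.5 * (noise**2 + 2.0 * log_std + jnp.log(2.0 * jnp.pi))
        log_prob -= jnp.log(1.0 - action**2 + 1e-6)
        return action, log_prob.sum(axis=-1)
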
class kitae.algos.collections.sac.SACState(policy_state: kitae.modules.pytree.TrainState, qvalue_state: kitae.modules.pytree.TrainState, alpha_state: kitae.modules.pytree.TrainState)#

Bases: AgentPyTree

alpha_state: TrainState#
policy_state: TrainState#
qvalue_state: TrainState#
replace(**updates)#

Returns a new object replacing the specified fields with new values.

class kitae.algos.collections.sac.SAC_tuple(observation, action, target)#

Bases: tuple

action#

Alias for field number 1

observation#

Alias for field number 0

target#

Alias for field number 2

kitae.algos.collections.sac.explore_factory(config: AlgoConfig) Callable[[PyTreeNode, Array, Array], tuple[Array, Array]]#
kitae.algos.collections.sac.process_experience_factory(config: AlgoConfig) Callable[[PyTreeNode, Array, NamedTuple], NamedTuple]#
kitae.algos.collections.sac.train_state_sac_factory(key: Array, config: AlgoConfig, *, preprocess_fn: Callable, tabulate: bool = False) SACState#
kitae.algos.collections.sac.update_alpha_factory(config: AlgoConfig) Callable[[PyTreeNode, Array, NamedTuple], tuple[PyTreeNode, dict[str, Array]]]#
kitae.algos.collections.sac.update_policy_factory(config: AlgoConfig) Callable[[PyTreeNode, Array, NamedTuple], tuple[PyTreeNode, dict[str, Array]]]#
kitae.algos.collections.sac.update_qvalue_factory(config: AlgoConfig) Callable[[PyTreeNode, Array, NamedTuple], tuple[PyTreeNode, dict[str, Array]]]#
kitae.algos.collections.sac.update_step_factory(config: AlgoConfig) Callable[[PyTreeNode, Array, NamedTuple], tuple[PyTreeNode, dict[str, Array]]]#
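
update_alpha_factory adjusts the entropy temperature that initial_alpha seeds. A generic sketch of the SAC temperature loss, given as a standard formulation rather than kitae's exact one:

    import jax
    import jax.numpy as jnp

    def alpha_loss(log_alpha, log_prob, target_entropy):
        """Raise alpha when policy entropy drops below target_entropy, lower it otherwise."""
        return -(log_alpha * (jax.lax.stop_gradient(log_prob) + target_entropy)).mean()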

kitae.algos.collections.td3 module#

Twin Delayed Deep Deterministic Policy Gradient (TD3)

class kitae.algos.collections.td3.TD3(run_name: str, config: AlgoConfig, *, preprocess_fn: Callable | None = None, tabulate: bool = False)#

Bases: OffPolicyAgent

Twin Delayed Deep Deterministic Policy Gradient (TD3). Paper: https://arxiv.org/abs/1509.02971

explore(observation: Array) tuple[Array, Array]#

Uses the policy to explore the environment.

Parameters:

observation (ObsType) – An ObsType within the observation_space.

Returns:

An ActionType within the action_space.

select_action(observation: Array) tuple[Array, Array]#

Exploits the policy to interact with the environment.

Parameters:

observation (ObsType) – An ObsType within the observation_space.

Returns:

An ActionType within the action_space.

update(buffer: OffPolicyBuffer) dict#
class kitae.algos.collections.td3.TD3Params(gamma: float = 0.99, tau: float = 0.005, action_noise: float = 0.1, policy_update_frequency: int = 1, target_noise_std: float = 0.2, target_noise_clip: float = 0.5, skip_steps: int = 1, start_step: int = -1)#

Bases: AlgoParams

TD3 parameters.

action_noise: float = 0.1#
gamma: float = 0.99#
policy_update_frequency: int = 1#
skip_steps: int = 1#
start_step: int = -1#
target_noise_clip: float = 0.5#
target_noise_std: float = 0.2#
tau: float = 0.005#
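
target_noise_std and target_noise_clip implement TD3's target policy smoothing. A minimal sketch, assuming actions normalized to [-1, 1]:

    import jax
    import jax.numpy as jnp

    def smoothed_target_action(key, target_action, target_noise_std=0.2, target_noise_clip=0.5):
        """Add clipped Gaussian noise to the target policy's action before computing the critic target."""
        noise = target_noise_std * jax.random.normal(key, target_action.shape)
        noise = jnp.clip(noise, -target_noise_clip, target_noise_clip)
        return jnp.clip(target_action + noise, -1.0, 1.0)
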
class kitae.algos.collections.td3.TD3State(policy_state: kitae.modules.pytree.TrainState, qvalue_state: kitae.modules.pytree.TrainState)#

Bases: AgentPyTree

policy_state: TrainState#
qvalue_state: TrainState#
replace(**updates)#

Returns a new object replacing the specified fields with new values.

class kitae.algos.collections.td3.TD3_tuple(observation, action, target)#

Bases: tuple

action#

Alias for field number 1

observation#

Alias for field number 0

target#

Alias for field number 2

kitae.algos.collections.td3.explore_factory(config: AlgoConfig) Callable[[PyTreeNode, Array, Array], tuple[Array, Array]]#
kitae.algos.collections.td3.process_experience_factory(config: AlgoConfig) Callable[[PyTreeNode, Array, NamedTuple], NamedTuple]#
kitae.algos.collections.td3.train_state_ddpg_factory(key: Array, config: AlgoConfig, *, preprocess_fn: Callable, tabulate: bool = False) TD3State#
kitae.algos.collections.td3.update_policy_factory(config: AlgoConfig) Callable[[PyTreeNode, Array, NamedTuple], tuple[PyTreeNode, dict[str, Array]]]#
kitae.algos.collections.td3.update_qvalue_factory(config: AlgoConfig) Callable[[PyTreeNode, Array, NamedTuple], tuple[PyTreeNode, dict[str, Array]]]#
kitae.algos.collections.td3.update_step_factory(config: AlgoConfig) Callable[[PyTreeNode, Array, NamedTuple], tuple[PyTreeNode, dict[str, Array]]]#
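
Both SAC and TD3 expose a tau parameter for Polyak averaging of their target networks. A sketch of that soft update over a parameter pytree (illustrative; the actual update happens inside the respective update_step_factory functions):

    import jax

    def soft_update(target_params, online_params, tau=0.005):
        """Polyak averaging: target <- (1 - tau) * target + tau * online."""
        return jax.tree_util.tree_map(
            lambda t, o: (1.0 - tau) * t + tau * o, target_params, online_params
        )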

Module contents#