kitae package#
Subpackages#
- kitae.algos package
- Subpackages
- Submodules
- kitae.algos.experience module
ExperiencePipeline
ExperiencePipeline.transforms
ExperiencePipeline.vectorized
ExperiencePipeline.parallel
ExperiencePipeline.run()
ExperiencePipeline.run_parallel()
ExperiencePipeline.run_parallel_vectorized()
ExperiencePipeline.run_single_pipe()
ExperiencePipeline.run_vectorized()
dict_to_tuple()
merge_n_first_dims()
stack_and_merge_n_first_dims()
tuple_to_dict()
- kitae.algos.factory module
- Module contents
- kitae.envs package
- kitae.loops package
- kitae.modules package
- Submodules
- kitae.modules.encoder module
- kitae.modules.modules module
- kitae.modules.optimizer module
- kitae.modules.policy module
- kitae.modules.pytree module
- kitae.modules.qvalue module
- kitae.modules.value module
- Module contents
- kitae.operations package
Submodules#
kitae.actor module#
- class kitae.actor.PolicyActor(seed: int, policy_state: TrainState, select_action_fn: Callable, *, vectorized: bool = True)#
Bases: IActor, Seeded
Wraps a policy_state into a deployed Actor.
- select_action(observation: ObsType) → tuple[ActionType, Array]#
Selects an action from an observation.
- Parameters:
observation (ObsType) – An observation from the environment.
- Returns:
A tuple of actions and log_probs as Array.
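A minimal deployment sketch; the trained `policy_state` and `select_fn` below are placeholders, and the exact signature expected for `select_action_fn` is an assumption:

    from kitae.actor import PolicyActor

    # policy_state: the TrainState of a trained policy (placeholder).
    # select_fn: an action-selection callable as expected by
    # select_action_fn (placeholder; its signature is an assumption).
    actor = PolicyActor(seed=0, policy_state=policy_state, select_action_fn=select_fn)

    obs, _ = env.reset(seed=0)  # env: a gymnasium-style environment
    action, log_prob = actor.select_action(obs)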
kitae.agent module#
Contains the agent classes for reinforcement learning.
- class kitae.agent.AgentInfo(config: kitae.config.AlgoConfig, extra: dict)#
Bases: object
- config: AlgoConfig#
- extra: dict#
- class kitae.agent.AgentSerializable#
Bases: Serializable
- class kitae.agent.BaseAgent(run_name: str, config: ~kitae.config.AlgoConfig, train_state_factory: ~typing.Callable, explore_factory: ~typing.Callable, process_experience_factory: ~typing.Callable, update_step_factory: ~typing.Callable, *, preprocess_fn: ~typing.Callable | None = None, tabulate: bool = False, experience_type: type = <class 'kitae.buffer.Experience'>)#
Bases: IAgent, SerializableObject, Seeded
- explore(observation: ObsType) → tuple[ActionType, ndarray]#
Uses the policy to explore the environment.
- Parameters:
observation (ObsType) – An ObsType within the observation_space.
- Returns:
An ActionType within the action_space.
- interact_keys(observation: ObsType) → Array#
- restore(step: int = -1) → int#
Restores the agent’s states from the given step.
- resume(env, n_env_steps)#
Resumes the training of the agent from the last training step.
- Parameters:
env (gym.Env) – An EnvLike environment to train in.
n_env_steps (int) – An int representing the number of steps in a single environment.
- select_action(observation: ObsType) → tuple[ActionType, ndarray]#
Exploits the policy to interact with the environment.
- Parameters:
observation (ObsType) – An ObsType within the observation_space.
- Returns:
An ActionType within the action_space.
- serializable_dict = {'agent_info': <class 'kitae.agent.AgentSerializable'>}#
- train(env, n_env_steps)#
Starts the training of the agent.
- Parameters:
env (gym.Env) – An EnvLike environment to train in.
n_env_steps (int) – An int representing the number of steps in a single environment.
- classmethod unserialize(path: str | Path)#
Creates a new instance of the agent given the save directory.
- Parameters:
path – A string or Path to the save directory.
- Returns:
An instance of the chosen agent.
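A hedged sketch of the agent lifecycle using only the methods above; `agent` stands for any concrete BaseAgent subclass instance, and the step counts and save path are illustrative:

    import gymnasium as gym

    env = gym.make("CartPole-v1")
    agent.train(env, n_env_steps=100_000)   # fresh training run
    agent.resume(env, n_env_steps=50_000)   # continue from the last step

    step = agent.restore()                  # reload the latest saved state
    # Rebuild the agent later from its save directory:
    agent2 = type(agent).unserialize("runs/example_run")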
- class kitae.agent.OffPolicyAgent(run_name: str, config: ~kitae.config.AlgoConfig, train_state_factory: ~typing.Callable, explore_factory: ~typing.Callable, process_experience_factory: ~typing.Callable, update_step_factory: ~typing.Callable, *, preprocess_fn: ~typing.Callable | None = None, tabulate: bool = False, experience_type: type = <class 'kitae.buffer.Experience'>)#
Bases: BaseAgent
- algo_type = 'off_policy'#
- should_update(step: int, buffer: IBuffer) → bool#
Determines if the agent should be updated.
- Parameters:
step (int) – An int representing the current step for a single environment.
buffer (IBuffer) – A Buffer containing the transitions obtained from the environment.
- Returns:
A boolean indicating whether the agent should be updated.
- step = 0#
- class kitae.agent.OnPolicyAgent(run_name: str, config: ~kitae.config.AlgoConfig, train_state_factory: ~typing.Callable, explore_factory: ~typing.Callable, process_experience_factory: ~typing.Callable, update_step_factory: ~typing.Callable, *, preprocess_fn: ~typing.Callable | None = None, tabulate: bool = False, experience_type: type = <class 'kitae.buffer.Experience'>)#
Bases: BaseAgent
- algo_type = 'on_policy'#
- should_update(step: int, buffer: IBuffer) → bool#
Determines if the agent should be updated.
- Parameters:
step (int) – An int representing the current step for a single environment.
buffer (IBuffer) – A Buffer containing the transitions obtained from the environment.
- Returns:
A boolean indicating whether the agent should be updated.
kitae.buffer module#
- class kitae.buffer.Buffer(seed: int, max_buffer_size: int = 0)#
Bases: IBuffer
Base Buffer class.
- rng#
A numpy random number generator.
- Type:
np.random.Generator
- max_buffer_size#
The maximum size of the buffer.
- Type:
int
- buffer#
A list or deque that contains the transitions.
- Type:
list | deque
- add(experience: Experience) → None#
Adds a transition to the buffer.
- class kitae.buffer.Experience(observation, action, reward, done, next_observation, log_prob)#
Bases: tuple
- action#
Alias for field number 1
- done#
Alias for field number 3
- log_prob#
Alias for field number 5
- next_observation#
Alias for field number 4
- observation#
Alias for field number 0
- reward#
Alias for field number 2
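Experience is a plain tuple subclass (a NamedTuple), so transitions can be built directly; the array shapes below are illustrative:

    import numpy as np
    from kitae.buffer import Experience

    exp = Experience(
        observation=np.zeros(4, dtype=np.float32),
        action=np.array(1),
        reward=np.array(1.0, dtype=np.float32),
        done=np.array(False),
        next_observation=np.zeros(4, dtype=np.float32),
        log_prob=np.array(-0.69, dtype=np.float32),
    )
    assert exp.action is exp[1]  # field names alias the positional indices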
- class kitae.buffer.OffPolicyBuffer(seed: int, max_buffer_size: int = 0)#
Bases: Buffer
OffPolicy variant of the buffer class.
- sample(batch_size: int) → list[Experience]#
Samples from the OffPolicy buffer.
- Parameters:
batch_size (int) – the number of elements to sample.
- Returns:
A list of transitions as tuples.
- class kitae.buffer.OnPolicyBuffer(seed: int, max_buffer_size: int = 0)#
Bases: Buffer
OnPolicy variant of the buffer class.
- sample(batch_size: int = -1) → list[Experience]#
Samples from the OnPolicy buffer and then empties it.
- Returns:
A list of transitions as tuples.
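A short add/sample sketch covering both variants; `exp` is any Experience, as in the example above:

    from kitae.buffer import OffPolicyBuffer, OnPolicyBuffer

    replay = OffPolicyBuffer(seed=0, max_buffer_size=10_000)
    replay.add(exp)
    # once enough transitions are stored, draw a random batch:
    batch = replay.sample(batch_size=32)

    rollout = OnPolicyBuffer(seed=0)
    rollout.add(exp)
    # returns every stored transition and empties the buffer:
    trajectory = rollout.sample()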
- kitae.buffer.batchify(data: tuple[Array, ...], batch_size: int) → tuple[Array, ...]#
Creates batches from a tuple of Array.
Tip
Typical Usage:
    batches = batchify(data, batch_size)
    for batch in zip(*batches):
        ...
- Parameters:
data (tuple[jax.Array, ...]) – A tuple of Array of shape [T, …]
batch_size (int) – An int representing the length of each batch
- Returns:
A tuple of Array of shape [T // batch_size, batch_size, …]
- Raises:
AssertionError – if the batch_size is strictly greater than the number of elements
- kitae.buffer.batchify_and_randomize(key: Array, data: tuple[Array, ...], batch_size: int) → tuple[Array, ...]#
Randomizes and creates batches from a tuple of Array.
Tip
Typical Usage:
    batches = batchify_and_randomize(key, data, batch_size)
    for batch in zip(*batches):
        ...
- Parameters:
key (jax.Array) – An Array for randomness
data (tuple[jax.Array, ...]) – A tuple of Array of shape [T, …]
batch_size (int) – An int representing the length of each batch
- Returns:
A tuple of Array of shape [T // batch_size, batch_size, …]
- Raises:
AssertionError – if the batch_size is strictly greater than the number of elements
- kitae.buffer.buffer_factory(seed: int, algo_type: AlgoType, max_buffer_size: int) → Buffer#
Generates a buffer based on the AlgoType.
- Parameters:
seed (int) – An int for reproducibility.
algo_type (AlgoType) – The type of the algorithm (on-policy or off-policy).
max_buffer_size (int) – The maximum size of the buffer.
- Returns:
An empty instance of the corresponding buffer.
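For example, using the AlgoType enum documented in kitae.interface below:

    from kitae.buffer import OffPolicyBuffer, buffer_factory
    from kitae.interface import AlgoType

    buffer = buffer_factory(seed=0, algo_type=AlgoType.OFF_POLICY, max_buffer_size=10_000)
    assert isinstance(buffer, OffPolicyBuffer)  # AlgoType.ON_POLICY yields an OnPolicyBuffer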
- kitae.buffer.jax_stack_experiences(experiences: list[Experience]) → Experience#
Stacks a list of Experience into a single Experience.
- Parameters:
experiences – a list of Experience to stack.
- Returns:
An Experience of the stacked inputs.
- kitae.buffer.numpy_stack_experiences(experiences: list[Experience]) → Experience#
Stacks a list of Experience into a single Experience.
- Parameters:
experiences – a list of Experience to stack.
- Returns:
An Experience of the stacked inputs.
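A sketch of stacking: each field of the result gains a leading axis of length len(experiences). Shapes are illustrative, and jax_stack_experiences presumably behaves the same with jax.Array fields:

    import numpy as np
    from kitae.buffer import Experience, numpy_stack_experiences

    exp = Experience(np.zeros(4), np.array(0), np.array(0.0),
                     np.array(False), np.zeros(4), np.array(0.0))
    stacked = numpy_stack_experiences([exp] * 3)
    assert stacked.observation.shape == (3, 4)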
kitae.config module#
- class kitae.config.AlgoConfig(seed: int, algo_params: kitae.config.AlgoParams, update_cfg: kitae.config.UpdateConfig, train_cfg: kitae.config.TrainConfig, env_cfg: kitae.config.EnvConfig)#
Bases: object
- algo_params: AlgoParams#
- env_cfg: EnvConfig#
- seed: int#
- train_cfg: TrainConfig#
- update_cfg: UpdateConfig#
- class kitae.config.AlgoParams#
Bases: object
- class kitae.config.ConfigSerializable#
Bases: Serializable
Static Serializable class for AlgoConfig.
- static serialize(config: AlgoConfig, path: Path)#
- static unserialize(path: Path) → AlgoConfig#
- class kitae.config.EnvConfig(task_name: str, observation_space: gymnasium.spaces.space.Space, action_space: gymnasium.spaces.space.Space, n_envs: int, n_agents: int)#
Bases: object
- action_space: Space#
- n_agents: int#
- n_envs: int#
- observation_space: Space#
- task_name: str#
- class kitae.config.TrainConfig(n_env_steps: int, save_frequency: int)#
Bases: object
- n_env_steps: int#
- save_frequency: int#
- class kitae.config.UpdateConfig(learning_rate: float, learning_rate_annealing: bool, max_grad_norm: float, max_buffer_size: int, batch_size: int, n_epochs: int, shared_encoder: bool)#
Bases: object
- batch_size: int#
- learning_rate: float#
- learning_rate_annealing: bool#
- max_buffer_size: int#
- max_grad_norm: float#
- n_epochs: int#
- shared_encoder: bool#
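A hedged sketch of assembling a full AlgoConfig from the dataclasses above; hyperparameter values are illustrative, and concrete algorithms subclass AlgoParams:

    import gymnasium as gym
    from kitae.config import AlgoConfig, AlgoParams, EnvConfig, TrainConfig, UpdateConfig

    env = gym.make("CartPole-v1")
    config = AlgoConfig(
        seed=0,
        algo_params=AlgoParams(),  # replace with an algorithm-specific subclass
        update_cfg=UpdateConfig(
            learning_rate=3e-4,
            learning_rate_annealing=True,
            max_grad_norm=0.5,
            max_buffer_size=2048,
            batch_size=64,
            n_epochs=4,
            shared_encoder=False,
        ),
        train_cfg=TrainConfig(n_env_steps=100_000, save_frequency=10_000),
        env_cfg=EnvConfig(
            task_name="CartPole-v1",
            observation_space=env.observation_space,
            action_space=env.action_space,
            n_envs=1,
            n_agents=1,
        ),
    )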
kitae.interface module#
- class kitae.interface.AlgoType(value)#
Bases: Enum
An enumeration.
- OFF_POLICY = 'off_policy'#
- ON_POLICY = 'on_policy'#
- class kitae.interface.IActor#
Bases: ABC
Interface for Actor instances.
- abstract select_action(observation: ObsType) → tuple[ActionType, Array]#
Exploits the policy to interact with the environment.
- Parameters:
observation (ObsType) – An ObsType within the observation_space.
- Returns:
An ActionType within the action_space.
- class kitae.interface.IAgent#
Bases: ABC
Interface for Agent instances.
- abstract explore(observation: ObsType) → tuple[ActionType, Array]#
Uses the policy to explore the environment.
- Parameters:
observation (ObsType) – An ObsType within the observation_space.
- Returns:
An ActionType within the action_space.
- abstract resume(env: Any, n_env_steps: int) → None#
Resumes the training of the agent from the last training step.
- Parameters:
env (gym.Env) – An EnvLike environment to train in.
n_env_steps (int) – An int representing the number of steps in a single environment.
- abstract select_action(observation: ObsType) → tuple[ActionType, Array]#
Exploits the policy to interact with the environment.
- Parameters:
observation (ObsType) – An ObsType within the observation_space.
- Returns:
An ActionType within the action_space.
- abstract should_update(step: int, buffer: IBuffer) → bool#
Determines if the agent should be updated.
- Parameters:
step (int) – An int representing the current step for a single environment.
buffer (IBuffer) – A Buffer containing the transitions obtained from the environment.
- Returns:
A boolean indicating whether the agent should be updated.
- abstract train(env: Any, n_env_steps: int) → None#
Starts the training of the agent.
- Parameters:
env (gym.Env) – An EnvLike environment to train in.
n_env_steps (int) – An int representing the number of steps in a single environment.
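Together these methods suggest the shape of a generic interaction loop. A hedged sketch written against the interface only, where the environment follows gymnasium conventions and the update step itself is algorithm-specific and elided:

    obs, _ = env.reset(seed=0)
    for step in range(n_env_steps):
        action, log_prob = agent.explore(obs)  # stochastic action while training
        next_obs, reward, terminated, truncated, _ = env.step(action)
        buffer.add(Experience(obs, action, reward,
                              terminated or truncated, next_obs, log_prob))
        if agent.should_update(step, buffer):
            ...  # algorithm-specific update (elided)
        obs = next_obs if not (terminated or truncated) else env.reset()[0]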
kitae.saving module#
- class kitae.saving.SaverContext(checkpointer: Checkpointer, save_frequency: int)#
Bases: AbstractContextManager
A context manager that ensures the agent state is saved when training is interrupted.
Tip
Typical usage:
    with SaverContext(saver, save_frequency) as s:
        for step in range(n_env_steps):
            ...
            s.update(step, agent.state)
- saver#
A saver instance.
- Type:
Saver
- update(step: int, state: Any) → None#
Informs the Saver of a new state, and saves it when necessary.
- Parameters:
step (int) – The current step of the environment.
state (Any) – The state of the agent.
- kitae.saving.default_run_name(env_id: str) → str#
Generates a default name for a run.