kitae package#

Submodules#

kitae.actor module#

class kitae.actor.PolicyActor(seed: int, policy_state: TrainState, select_action_fn: Callable, *, vectorized: bool = True)#

Bases: IActor, Seeded

Wraps a policy_state into a deployed Actor.

select_action(observation: ObsType) tuple[ActionType, Array]#

Selects an action from an observation.

Parameters:

observation (ObsType) – An observation from the environment.

Returns:

A tuple of actions and log_probs as Array.
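
A minimal usage sketch; policy_state, select_action_fn and observation are placeholders for a trained TrainState, the algorithm's action-selection function, and an environment observation:

from kitae.actor import PolicyActor

# policy_state and select_action_fn are assumed to come from a trained agent.
actor = PolicyActor(
    seed=0,
    policy_state=policy_state,
    select_action_fn=select_action_fn,
    vectorized=True,
)
action, log_prob = actor.select_action(observation)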

kitae.agent module#

Contains the agent classes for reinforcement learning.

class kitae.agent.AgentInfo(config: kitae.config.AlgoConfig, extra: dict)#

Bases: object

config: AlgoConfig#
extra: dict#
class kitae.agent.AgentSerializable#

Bases: Serializable

static serialize(agent_info: AgentInfo, path: Path)#
static unserialize(path: Path) AgentInfo#
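
A round-trip sketch; agent_info is an existing AgentInfo instance and the save path is illustrative:

from pathlib import Path
from kitae.agent import AgentSerializable

# agent_info is an existing AgentInfo; the path is illustrative.
AgentSerializable.serialize(agent_info, Path("runs/example/agent_info"))
restored = AgentSerializable.unserialize(Path("runs/example/agent_info"))
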
class kitae.agent.BaseAgent(run_name: str, config: ~kitae.config.AlgoConfig, train_state_factory: ~typing.Callable, explore_factory: ~typing.Callable, process_experience_factory: ~typing.Callable, update_step_factory: ~typing.Callable, *, preprocess_fn: ~typing.Callable | None = None, tabulate: bool = False, experience_type: type = <class 'kitae.buffer.Experience'>)#

Bases: IAgent, SerializableObject, Seeded

explore(observation: ObsType) tuple[ActionType, ndarray]#

Uses the policy to explore the environment.

Parameters:

observation (ObsType) – An ObsType within the observation_space.

Returns:

A tuple of an ActionType within the action_space and the corresponding log probabilities as an ndarray.

interact_keys(observation: ObsType)#
restore(step: int = -1) int#

Restores the agent’s states from the given step.

resume(env, n_env_steps)#

Resumes the training of the agent from the last training step.

Parameters:
  • env (gym.Env) – An EnvLike environment to train in.

  • n_env_steps (int) – An int representing the number of steps in a single environment.

select_action(observation: ObsType) tuple[ActionType, ndarray]#

Exploits the policy to interact with the environment.

Parameters:

observation (ObsType) – An ObsType within the observation_space.

Returns:

A tuple of an ActionType within the action_space and the corresponding log probabilities as an ndarray.

serializable_dict = {'agent_info': <class 'kitae.agent.AgentSerializable'>}#
train(env, n_env_steps)#

Starts the training of the agent.

Parameters:
  • env (gym.Env) – An EnvLike environment to train in.

  • n_env_steps (int) – An int representing the number of steps in a single environment.

classmethod unserialize(path: str | Path)#

Creates a new instance of the agent given the save directory.

Parameters:

path – A string or Path to the save directory.

Returns:

An instance of the chosen agent.
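
A sketch of reloading a saved agent; in practice unserialize is called on the concrete agent class, "runs/example" is an illustrative save directory, and env is a pre-built environment:

from kitae.agent import BaseAgent

agent = BaseAgent.unserialize("runs/example")
agent.restore()  # restores the latest saved step by default
agent.resume(env, n_env_steps=1_000_000)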

update(buffer: IBuffer) dict#
class kitae.agent.OffPolicyAgent(run_name: str, config: ~kitae.config.AlgoConfig, train_state_factory: ~typing.Callable, explore_factory: ~typing.Callable, process_experience_factory: ~typing.Callable, update_step_factory: ~typing.Callable, *, preprocess_fn: ~typing.Callable | None = None, tabulate: bool = False, experience_type: type = <class 'kitae.buffer.Experience'>)#

Bases: BaseAgent

algo_type = 'off_policy'#
should_update(step: int, buffer: IBuffer) bool#

Determines if the agent should be updated.

Parameters:
  • step (int) – An int representing the current step for a single environment.

  • buffer (IBuffer) – A Buffer containing the transitions obtained from the environment.

Returns:

A boolean indicating whether the agent should be updated.

step = 0#
class kitae.agent.OnPolicyAgent(run_name: str, config: ~kitae.config.AlgoConfig, train_state_factory: ~typing.Callable, explore_factory: ~typing.Callable, process_experience_factory: ~typing.Callable, update_step_factory: ~typing.Callable, *, preprocess_fn: ~typing.Callable | None = None, tabulate: bool = False, experience_type: type = <class 'kitae.buffer.Experience'>)#

Bases: BaseAgent

algo_type = 'on_policy'#
should_update(step: int, buffer: IBuffer) bool#

Determines if the agent should be updated.

Parameters:
  • step (int) – An int representing the current step for a single environment.

  • buffer (IBuffer) – A Buffer containing the transitions obtained from the environment.

Returns:

A boolean indicating whether the agent should be updated.

kitae.buffer module#

class kitae.buffer.Buffer(seed: int, max_buffer_size: int = 0)#

Bases: IBuffer

Base Buffer class.

rng#

A numpy random number generator.

Type:

np.random.Generator

max_buffer_size#

The maximum size of the buffer.

Type:

int

buffer#

A list or deque that contains the transitions.

Type:

list | deque

add(experience: Experience) None#

Adds a transition to the buffer.

class kitae.buffer.Experience(observation, action, reward, done, next_observation, log_prob)#

Bases: tuple

action#

Alias for field number 1

done#

Alias for field number 3

log_prob#

Alias for field number 5

next_observation#

Alias for field number 4

observation#

Alias for field number 0

reward#

Alias for field number 2

class kitae.buffer.OffPolicyBuffer(seed: int, max_buffer_size: int = 0)#

Bases: Buffer

OffPolicy variant of the buffer class.

sample(batch_size: int) list[Experience]#

Samples from the OffPolicy buffer.

Parameters:

batch_size (int) – The number of elements to sample.

Returns:

A list of transitions as tuples.
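
A minimal sketch of the add/sample cycle; the transition values are dummies standing in for an actual environment step:

import numpy as np
from kitae.buffer import Experience, OffPolicyBuffer

buffer = OffPolicyBuffer(seed=0, max_buffer_size=10_000)
# Dummy transition values standing in for a real environment step.
obs = np.zeros(4, dtype=np.float32)
next_obs = np.zeros(4, dtype=np.float32)
buffer.add(Experience(
    observation=obs,
    action=1,
    reward=1.0,
    done=False,
    next_observation=next_obs,
    log_prob=-0.5,
))
batch = buffer.sample(batch_size=32)  # list[Experience]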

class kitae.buffer.OnPolicyBuffer(seed: int, max_buffer_size: int = 0)#

Bases: Buffer

OnPolicy variant of the buffer class.

sample(batch_size: int = -1) list[Experience]#

Samples from the OnPolicy buffer and then empties it.

Returns:

A list of transitions as tuples.
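
Unlike the off-policy variant, sampling consumes the stored rollout; a short sketch:

from kitae.buffer import OnPolicyBuffer

buffer = OnPolicyBuffer(seed=0)
# ... add one rollout of transitions ...
rollout = buffer.sample()  # returns all stored transitions and empties the buffer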

kitae.buffer.batchify(data: tuple[Array, ...], batch_size: int) tuple[Array, ...]#

Creates batches from a tuple of Array.

Tip

Typical Usage:

batches = batchify(data, batch_size)
for batch in zip(*batches):
    ...

Parameters:
  • data (tuple[jax.Array, ...]) – A tuple of Array of shape [T, …]

  • batch_size (int) – An int that represents the size of each batch

Returns:

A tuple of Array of shape [T // batch_size, batch_size, …]

Raises:

AssertionError – if the batch_size is strictly greater than the number of elements

kitae.buffer.batchify_and_randomize(key: Array, data: tuple[Array, ...], batch_size: int) tuple[Array, ...]#

Randomizes and creates batches from a tuple of Array.

Tip

Typical Usage:

batches = batchify_and_randomize(key, data, batch_size)
for batch in zip(*batches):
    ...

Parameters:
  • key (jax.Array) – An Array for randomness

  • data (tuple[jax.Array, ...]) – A tuple of Array of shape [T, …]

  • batch_size (int) – An int that represents the size of each batch

Returns:

A tuple of Array of shape [T // batch_size, batch_size, …]

Raises:

AssertionError – if the batch_size is strictly greater than the number of elements

kitae.buffer.buffer_factory(seed: int, algo_type: AlgoType, max_buffer_size: int) Buffer#

Generates a buffer based on the AlgoType.

Parameters:
  • seed (int) – An int for reproducibility.

  • algo_type (AlgoType) – The type of algorithm, on-policy or off-policy.

  • max_buffer_size (int) – The maximum size of the buffer.

Returns:

An empty instance of the corresponding buffer.
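
A short sketch of selecting the buffer variant from the algorithm type:

from kitae.buffer import buffer_factory
from kitae.interface import AlgoType

buffer = buffer_factory(seed=0, algo_type=AlgoType.OFF_POLICY, max_buffer_size=10_000)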

kitae.buffer.jax_stack_experiences(experiences: list[Experience]) Experience#

Stacks a list of Experience into a single Experience using JAX operations.

Parameters:

experiences – a list of Experience to stack.

Returns:

An Experience of the stacked inputs.

kitae.buffer.numpy_stack_experiences(experiences: list[Experience]) Experience#

Stacks a list of Experience into a single Experience using NumPy operations.

Parameters:

experiences – a list of Experience to stack.

Returns:

An Experience of the stacked inputs.
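
A typical use is batching a sampled list of transitions before an update step; buffer is assumed to be a filled OffPolicyBuffer:

from kitae.buffer import numpy_stack_experiences

batch = numpy_stack_experiences(buffer.sample(batch_size=32))
# batch.observation, batch.action, ... are now stacked NumPy arrays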

kitae.config module#

class kitae.config.AlgoConfig(seed: int, algo_params: kitae.config.AlgoParams, update_cfg: kitae.config.UpdateConfig, train_cfg: kitae.config.TrainConfig, env_cfg: kitae.config.EnvConfig)#

Bases: object

algo_params: AlgoParams#
env_cfg: EnvConfig#
seed: int#
train_cfg: TrainConfig#
update_cfg: UpdateConfig#
class kitae.config.AlgoParams#

Bases: object

class kitae.config.ConfigSerializable#

Bases: Serializable

Static Serializable class for AlgoConfig.

static serialize(config: AlgoConfig, path: Path)#
static unserialize(path: Path) AlgoConfig#
class kitae.config.EnvConfig(task_name: str, observation_space: gymnasium.spaces.space.Space, action_space: gymnasium.spaces.space.Space, n_envs: int, n_agents: int)#

Bases: object

action_space: Space#
n_agents: int#
n_envs: int#
observation_space: Space#
task_name: str#
class kitae.config.TrainConfig(n_env_steps: int, save_frequency: int)#

Bases: object

n_env_steps: int#
save_frequency: int#
class kitae.config.UpdateConfig(learning_rate: float, learning_rate_annealing: bool, max_grad_norm: float, max_buffer_size: int, batch_size: int, n_epochs: int, shared_encoder: bool)#

Bases: object

batch_size: int#
learning_rate: float#
learning_rate_annealing: bool#
max_buffer_size: int#
max_grad_norm: float#
n_epochs: int#
shared_encoder: bool#
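
A sketch of assembling a full AlgoConfig; the hyperparameter values are illustrative, and AlgoParams would normally be replaced by an algorithm-specific subclass:

import gymnasium as gym

from kitae.config import (
    AlgoConfig,
    AlgoParams,
    EnvConfig,
    TrainConfig,
    UpdateConfig,
)

env = gym.make("CartPole-v1")
config = AlgoConfig(
    seed=0,
    algo_params=AlgoParams(),  # replaced by an algorithm-specific subclass in practice
    update_cfg=UpdateConfig(
        learning_rate=3e-4,
        learning_rate_annealing=True,
        max_grad_norm=0.5,
        max_buffer_size=10_000,
        batch_size=64,
        n_epochs=4,
        shared_encoder=False,
    ),
    train_cfg=TrainConfig(n_env_steps=100_000, save_frequency=10_000),
    env_cfg=EnvConfig(
        task_name="CartPole-v1",
        observation_space=env.observation_space,
        action_space=env.action_space,
        n_envs=1,
        n_agents=1,
    ),
)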

kitae.interface module#

class kitae.interface.AlgoType(value)#

Bases: Enum

An enumeration of the supported algorithm types.

OFF_POLICY = 'off_policy'#
ON_POLICY = 'on_policy'#
class kitae.interface.IActor#

Bases: ABC

Interface for Actor instances.

abstract select_action(observation: ObsType) tuple[ActionType, Array]#

Exploits the policy to interact with the environment.

Parameters:

observation (ObsType) – An ObsType within the observation_space.

Returns:

A tuple of an ActionType within the action_space and the corresponding log probabilities as an Array.

class kitae.interface.IAgent#

Bases: ABC

Interface for Agent instances.

abstract explore(observation: ObsType) tuple[ActionType, Array]#

Uses the policy to explore the environment.

Parameters:

observation (ObsType) – An ObsType within the observation_space.

Returns:

A tuple of an ActionType within the action_space and the corresponding log probabilities as an Array.

abstract resume(env: Any, n_env_steps: int) None#

Resumes the training of the agent from the last training step.

Parameters:
  • env (gym.Env) – An EnvLike environment to train in.

  • n_env_steps (int) – An int representing the number of steps in a single environment.

abstract select_action(observation: ObsType) tuple[ActionType, Array]#

Exploits the policy to interact with the environment.

Parameters:

observation (ObsType) – An ObsType within the observation_space.

Returns:

A tuple of an ActionType within the action_space and the corresponding log probabilities as an Array.

abstract should_update(step: int, buffer: IBuffer) bool#

Determines if the agent should be updated.

Parameters:
  • step (int) – An int representing the current step for a single environment.

  • buffer (IBuffer) – A Buffer containing the transitions obtained from the environment.

Returns:

A boolean indicating whether the agent should be updated.

abstract train(env: Any, n_env_steps: int) None#

Starts the training of the agent.

Parameters:
  • env (gym.Env) – An EnvLike environment to train in.

  • n_env_steps (int) – An int representing the number of steps in a single environment.

class kitae.interface.IBuffer#

Bases: ABC

Interface for Buffer instances.

abstract add(experience: type[NamedTuple]) None#
abstract sample(sample_size: int) list[type[NamedTuple]]#

kitae.saving module#

class kitae.saving.SaverContext(checkpointer: Checkpointer, save_frequency: int)#

Bases: AbstractContextManager

A context manager that ensures the agent state is saved when training is interrupted.

Tip

Typical usage:

with SaverContext(saver, save_frequency) as s:
    for step in range(n_env_steps):
        ...

        s.update(step, agent.state)

saver#

A saver instance.

Type:

Saver

update(step: int, state: Any) None#

Informs the Saver of a new state, and saves it when necessary.

Parameters:
  • step (int) – The current step of the environment.

  • state (Any) – The state of the agent.

kitae.saving.default_run_name(env_id: str) str#

Generates a default name for a run.
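
A short sketch; the generated name is typically passed as the run_name of an agent:

from kitae.saving import default_run_name

run_name = default_run_name("CartPole-v1")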

kitae.types module#

kitae.version module#

Module contents#