kitae.loops package

Submodules

kitae.loops.train module

kitae.loops.train.check_env(env: Env | EnvPool | ParallelEnv | SubProcVecParallelEnv) → Env | EnvPool | ParallelEnv | SubProcVecParallelEnv

Checks whether an environment can be used for training.

Parameters:

env (EnvLike) – An environment.

Returns:

The original environment or a compatible one.

kitae.loops.train.process_termination(global_step: int, next_observations: ndarray, infos: dict, writer: SummaryWriter | None) → ndarray

Processes the termination part of the training loop.

Parameters:
  • global_step (int) – The current environment step.

  • next_observations (Array) – The observations resulting from the last actions.

  • infos (dict) – The info dictionary returned by the environment after the last actions.

  • writer (SummaryWriter | None) – An optional writer used to log metrics.

Returns:

The next observations to store in the buffer.
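In many vectorized RL loops, termination handling means two things: sub-environments that just finished have already been auto-reset, so their true final observations must be swapped back in before storing transitions, and finished episodes are logged. The sketch below illustrates that idea only; it is not kitae's implementation. It assumes the Gymnasium 0.26 to 0.29 vector-env convention (infos["final_observation"] with its "_final_observation" mask, and per-episode statistics under infos["final_info"][i]["episode"]["r"] as written by RecordEpisodeStatistics). The name process_termination_sketch is hypothetical.

```python
import numpy as np


def process_termination_sketch(global_step, next_observations, infos, writer):
    """Hypothetical sketch assuming Gymnasium 0.26-0.29 vector-env autoreset infos."""
    # Copy so the reset observations used to pick the next actions stay untouched.
    real_next_observations = np.array(next_observations, copy=True)

    # Sub-envs that just finished were auto-reset: their row in next_observations
    # is the first observation of a new episode, while the true last observation
    # sits in infos["final_observation"], masked by infos["_final_observation"].
    if "final_observation" in infos:
        for idx, finished in enumerate(infos["_final_observation"]):
            if finished:
                real_next_observations[idx] = infos["final_observation"][idx]

    # Log episodic returns only when a writer is provided.
    if writer is not None and "final_info" in infos:
        for info in infos["final_info"]:
            if info is not None and "episode" in info:
                writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)

    # These are the observations a replay buffer would store as "next".
    return real_next_observations
```

The writer only needs an add_scalar(tag, value, step) method, which matches torch.utils.tensorboard.SummaryWriter; passing None skips logging, mirroring the optional writer in the signature.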

kitae.loops.train.vectorized_train(seed: int, agent: IAgent, env: Env | EnvPool | ParallelEnv | SubProcVecParallelEnv, n_env_steps: int, algo_type: AlgoType, *, start_step: int = 1, checkpointer: Checkpointer | None = None, writer: SummaryWriter | None = None) → None

Trains an agent in a vectorized environment.

Important

env.close() is called at the end of training.

Parameters:
  • seed (int) – A random seed, for reproducibility.

  • agent (IAgent) – An agent to train.

  • env (EnvLike) – An environment to train in.

  • n_env_steps (int) – The number of environment steps to perform.

  • algo_type (AlgoType) – The type of algorithm.

  • start_step (int) – The starting step in the environment.

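  • checkpointer (Checkpointer | None) – An optional checkpointer instance.

  • writer (SummaryWriter | None) – An optional writer used to log metrics.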
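The overall shape of a vectorized training loop can be sketched in plain Python. Everything here is hypothetical: ToyVecEnv, RandomAgent with its act/learn methods, and vectorized_train_sketch are illustrative stand-ins, not kitae's IAgent or EnvLike APIs. The sketch assumes one loop iteration per global step, counted from start_step up to n_env_steps, and leaves out checkpointing and logging for brevity. The try/finally is one way to guarantee that env.close() runs at the end of training, as the note on this function states.

```python
import random


class ToyVecEnv:
    """Hypothetical vectorized env: n_envs step counters that terminate after 3 steps."""

    def __init__(self, n_envs: int) -> None:
        self.n_envs = n_envs
        self.t = [0] * n_envs
        self.closed = False

    def reset(self, seed=None):
        self.t = [0] * self.n_envs
        return list(self.t), {}

    def step(self, actions):
        self.t = [t + 1 for t in self.t]
        rewards = [float(a) for a in actions]
        terminated = [t >= 3 for t in self.t]
        # Sub-envs that just terminated are auto-reset, as vectorized envs usually do.
        self.t = [0 if done else t for t, done in zip(self.t, terminated)]
        truncated = [False] * self.n_envs
        return list(self.t), rewards, terminated, truncated, {}

    def close(self) -> None:
        self.closed = True


class RandomAgent:
    """Hypothetical agent: random binary actions, counts learning calls."""

    def __init__(self, seed: int) -> None:
        self.rng = random.Random(seed)
        self.n_learn_calls = 0

    def act(self, observations):
        return [self.rng.randint(0, 1) for _ in observations]

    def learn(self, transition) -> None:
        self.n_learn_calls += 1


def vectorized_train_sketch(seed, agent, env, n_env_steps, *, start_step=1):
    observations, _ = env.reset(seed=seed)
    try:
        # Assumption: one iteration per global step, from start_step to n_env_steps.
        for global_step in range(start_step, n_env_steps + 1):
            actions = agent.act(observations)
            next_obs, rewards, terminated, truncated, infos = env.step(actions)
            # Termination processing, buffer storage and updates would happen here.
            agent.learn((observations, actions, rewards, terminated, truncated, next_obs))
            observations = next_obs
    finally:
        # Always release the environment, even if the loop raises.
        env.close()
```

With start_step=3 and n_env_steps=5, the sketch runs three iterations, which is how a resumed run would skip steps already taken.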

kitae.loops.update module

kitae.loops.update.update_epoch(key: Array, state: AgentPyTree, experience: NamedTuple, batchify_fn: Callable, update_batch_fn: Callable[[AgentPyTree, Array, NamedTuple], tuple[AgentPyTree, dict[str, Array]]], *, experience_type: type[NamedTuple], batch_size: int) → tuple[AgentPyTree, dict[str, Array]]

Updates the agent’s state over a single epoch.

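This function uses jax.lax.scan, which can reduce compilation time because the loop body is traced and compiled once instead of being unrolled as a Python loop would be.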

Parameters:
  • key – A PRNGKeyArray for reproducibility.

  • state – An AgentPyTree containing the agent’s state.

  • experience – An ExperienceTuple containing processed trajectories.

  • batchify_fn – A Callable that processes the experience into batches.

  • update_batch_fn – An UpdateStepFn that updates the agent’s state.

  • experience_type – A custom experience type (a NamedTuple subclass).

  • batch_size – An int giving the number of samples in each batch.

Returns:

The updated agent’s state and the corresponding dictionary of losses.
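A scan-based epoch can be sketched in JAX: shuffle the experience once, reshape every leaf to (n_batches, batch_size, ...), and let jax.lax.scan thread the state through update_batch_fn one batch at a time. This is a minimal sketch, not kitae’s code. It assumes update_batch_fn takes (state, key, batch) in that order, following the documented Callable[[AgentPyTree, Array, NamedTuple], ...] type; it replaces batchify_fn with a built-in shuffle-and-reshape, drops the n_samples % batch_size leftover samples, and averages each loss over batches. Experience and update_epoch_sketch are hypothetical names.

```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp


class Experience(NamedTuple):
    """Hypothetical experience type: one row per collected transition."""

    observations: jax.Array
    returns: jax.Array


def update_epoch_sketch(
    key: jax.Array,
    state,
    experience: NamedTuple,
    update_batch_fn: Callable,
    *,
    batch_size: int,
):
    # Number of transitions, read from the leading axis of the first leaf.
    n_samples = jax.tree_util.tree_leaves(experience)[0].shape[0]
    n_batches = n_samples // batch_size  # leftover samples are dropped

    perm_key, batches_key = jax.random.split(key)
    perm = jax.random.permutation(perm_key, n_samples)[: n_batches * batch_size]

    # Shuffle every field the same way, then give each a leading batch axis.
    batches = jax.tree_util.tree_map(
        lambda x: x[perm].reshape((n_batches, batch_size) + x.shape[1:]),
        experience,
    )
    batch_keys = jax.random.split(batches_key, n_batches)

    def scan_body(carry_state, xs):
        batch_key, batch = xs
        # update_batch_fn returns (new_state, loss_dict); scan stacks the dicts.
        return update_batch_fn(carry_state, batch_key, batch)

    state, losses = jax.lax.scan(scan_body, state, (batch_keys, batches))
    # Each loss has shape (n_batches,); report its mean over the epoch.
    return state, jax.tree_util.tree_map(jnp.mean, losses)
```

Because update_batch_fn is traced once as the scan body, the size of the compiled program does not grow with the number of batches, which is where the compilation-time savings come from.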

Module contents