kitae.loops package
Submodules
kitae.loops.train module
- kitae.loops.train.check_env(env: Env | EnvPool | ParallelEnv | SubProcVecParallelEnv) -> Env | EnvPool | ParallelEnv | SubProcVecParallelEnv
Checks whether the environment can be used for training.
- Parameters:
env (EnvLike) – An environment.
- Returns:
The original environment or a compatible one.
- kitae.loops.train.process_termination(global_step: int, next_observations: ndarray, infos: dict, writer: SummaryWriter | None) -> ndarray
Processes the termination part of the loop.
- Parameters:
global_step (int) – The current environment step.
next_observations (Array) – The observations resulting from the last actions.
infos (dict) – The info resulting from the last actions.
writer (SummaryWriter | None) – A writer to log metrics; may be None.
- Returns:
The next observations to store in the buffer.
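The following is a minimal sketch of what such a termination step might look like, not kitae's actual implementation. It assumes the Gymnasium (pre-1.0) vector-environment convention, in which an auto-reset sub-environment returns its reset observation and reports the true last observation under infos["final_observation"], with a boolean mask under infos["_final_observation"]. The name swap_final_observations is hypothetical, plain lists stand in for ndarray, and logging through the writer is omitted.

```python
from typing import Any, Sequence


def swap_final_observations(next_observations: Sequence[Any], infos: dict) -> list:
    """Return the observations that should be stored in the buffer.

    Hypothetical helper: assumes the Gymnasium (pre-1.0) auto-reset
    convention, where finished sub-environments report their true last
    observation under infos["final_observation"].
    """
    # Copy so the observations used to pick the next actions stay untouched.
    stored = list(next_observations)
    final_obs = infos.get("final_observation")
    if final_obs is None:
        # No sub-environment finished on this step.
        return stored
    # Fall back to "entry is not None" when the boolean mask is absent.
    mask = infos.get("_final_observation", [obs is not None for obs in final_obs])
    for idx, finished in enumerate(mask):
        if finished:
            # Replace the reset observation with the real terminal one.
            stored[idx] = final_obs[idx]
    return stored


# Two sub-environments; the second one just finished and was auto-reset.
next_obs = [[0.1, 0.2], [0.0, 0.0]]
infos = {
    "final_observation": [None, [0.9, 1.0]],
    "_final_observation": [False, True],
}
stored_obs = swap_final_observations(next_obs, infos)
```

Here the buffer would receive the terminal observation of the second sub-environment, while next_obs itself (the reset observation) remains available for choosing the next action.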
- kitae.loops.train.vectorized_train(seed: int, agent: IAgent, env: Env | EnvPool | ParallelEnv | SubProcVecParallelEnv, n_env_steps: int, algo_type: AlgoType, *, start_step: int = 1, checkpointer: Checkpointer | None = None, writer: SummaryWriter | None = None) -> None
Trains an agent in a vectorized environment.
Important
env.close() will be called at the end of the training.
- Parameters:
seed (int) – An int for reproducibility.
agent (IAgent) – An agent to train.
env (EnvLike) – An environment to train in.
n_env_steps (int) – The number of steps to do in the environment.
algo_type (AlgoType) – The type of algorithm.
start_step (int) – The starting step in the environment.
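checkpointer (Checkpointer | None) – A checkpointer instance. Defaults to None.
writer (SummaryWriter | None) – A writer to log metrics. Defaults to None.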
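To illustrate the env.close() note and the role of start_step, here is a small self-contained sketch of a loop with the same overall shape. It is not the kitae implementation: DummyVecEnv, train_sketch and select_action are hypothetical stand-ins, the step range start_step..n_env_steps (inclusive) is an assumption, and the try/finally is this sketch's own choice for guaranteeing the close.

```python
class DummyVecEnv:
    """Hypothetical stand-in for a vectorized environment."""

    def __init__(self, n_envs: int = 2) -> None:
        self.n_envs = n_envs
        self.closed = False

    def reset(self, seed=None):
        return [0.0] * self.n_envs, {}

    def step(self, actions):
        observations = [float(a) for a in actions]
        rewards = [1.0] * self.n_envs
        terminated = [False] * self.n_envs
        truncated = [False] * self.n_envs
        return observations, rewards, terminated, truncated, {}

    def close(self) -> None:
        self.closed = True


def train_sketch(seed, select_action, env, n_env_steps, *, start_step=1):
    """Assumed loop shape: steps start_step..n_env_steps, then env.close()."""
    observations, _ = env.reset(seed=seed)
    steps_done = 0
    try:
        # Assumption: a resumed run (start_step > 1) only performs the
        # remaining steps up to n_env_steps.
        for global_step in range(start_step, n_env_steps + 1):
            actions = select_action(observations)
            observations, *_ = env.step(actions)
            steps_done += 1
    finally:
        # Mirrors the documented behaviour: the environment is closed
        # when training ends, so the caller must not reuse it.
        env.close()
    return steps_done


env = DummyVecEnv()
steps = train_sketch(0, lambda obs: [0] * len(obs), env, n_env_steps=10, start_step=4)
```

After the call, env.closed is True, so a caller that wants to evaluate the agent afterwards would need to create a fresh environment.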
kitae.loops.update module
- kitae.loops.update.update_epoch(key: Array, state: AgentPyTree, experience: NamedTuple, batchify_fn: Callable, update_batch_fn: Callable[[AgentPyTree, Array, NamedTuple], tuple[AgentPyTree, dict[str, Array]]], *, experience_type: type[NamedTuple], batch_size: int) -> tuple[AgentPyTree, dict[str, Array]]
Updates a state in a single epoch.
This function uses jax.lax.scan, which can reduce compilation time compared with an unrolled Python loop over the batches.
- Parameters:
key – A PRNGKeyArray for reproducibility.
state – An AgentPyTree containing the agent’s state.
experience – An ExperienceTuple containing processed trajectories.
batchify_fn – A Callable that processes the experience into batches.
update_batch_fn – An UpdateStepFn that updates the agent’s state.
experience_type – A custom experience type.
batch_size – An int that determines the size of a batch.
- Returns:
The updated agent state and the corresponding loss dictionary.
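The scan pattern behind an epoch update can be sketched in plain Python. The scan helper below reproduces the reference semantics of jax.lax.scan (carry a value through a sequence and collect per-step outputs) without requiring JAX. Everything else is a hypothetical illustration: the Experience type, the batchify and update_batch functions, the assumed batchify_fn(experience, batch_size) call shape, and the choice to return per-batch values in the info dictionary. The experience_type argument is left out for brevity.

```python
from typing import Any, Callable, NamedTuple


class Experience(NamedTuple):
    """Hypothetical experience type with a single field."""

    reward: list


def scan(f: Callable, init: Any, xs: list) -> tuple:
    """Plain-Python reference semantics of jax.lax.scan."""
    carry, ys = init, []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, ys


def batchify(experience: Experience, batch_size: int) -> list:
    """Hypothetical batchify_fn: split the field into consecutive batches."""
    n = len(experience.reward)
    return [
        Experience(reward=experience.reward[i : i + batch_size])
        for i in range(0, n, batch_size)
    ]


def update_batch(state: float, key: int, batch: Experience) -> tuple:
    """Hypothetical update_batch_fn: the state is a running sum of rewards."""
    batch_sum = sum(batch.reward)
    return state + batch_sum, {"loss": batch_sum}


def update_epoch_sketch(key, state, experience, batchify_fn, update_batch_fn, *, batch_size):
    """Thread the state through every batch of one epoch."""
    batches = batchify_fn(experience, batch_size)

    def body(carry, batch):
        # The key is passed through unchanged here; JAX code would
        # typically derive a fresh key for each batch.
        return update_batch_fn(carry, key, batch)

    state, infos = scan(body, state, batches)
    # Collect the per-batch values under each metric name.
    names = infos[0] if infos else {}
    info = {name: [step[name] for step in infos] for name in names}
    return state, info


experience = Experience(reward=[1.0, 2.0, 3.0, 4.0])
new_state, info = update_epoch_sketch(
    0, 0.0, experience, batchify, update_batch, batch_size=2
)
```

With a batch size of 2, the four rewards form two batches, so new_state is 10.0 and info is {"loss": [3.0, 7.0]}. In JAX the same body function would be handed to jax.lax.scan, which traces it once instead of once per batch.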