Quick Start
This guide shows how to use the agents that ship with Kitae. To implement your own, see Custom Agents.
1. Installation
This package requires Python 3.10 or later and a working JAX installation. To install JAX, refer to the official JAX installation instructions.
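For example, at the time of writing a CPU-only JAX build can be installed with the following command (see the JAX documentation for GPU/TPU wheels):
pip install --upgrade "jax[cpu]"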
pip install --upgrade pip
pip install --upgrade git+https://github.com/Raffaelbdl/flax_rl
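To verify the installation, check that both packages import and that JAX detects a device (a minimal sanity check; the exact device list depends on your JAX build):
import jax
import kitae

print(jax.devices())  # e.g. [CpuDevice(id=0)] on a CPU-only install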
2. Create an environment
Kitae's training loop uses vectorized environments. You can load one using Kitae's make functions.
from kitae.envs.make import make_vec_env

SEED = 0
ENV_ID = "CartPole-v1"
N_ENVS = 16

env = make_vec_env(
    env_id=ENV_ID,
    n_envs=N_ENVS,
    capture_video=False,
    run_name=None,
)
Here we create a CartPole environment, vectorized across 16 processes. If you only need one environment, you can either set n_envs=1 or use the make_single_env function.
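The returned object behaves like a batched environment. As a quick sanity check (this assumes the Gymnasium VectorEnv API, which the single_observation_space and single_action_space attributes used below suggest):
import numpy as np

observations, infos = env.reset(seed=SEED)
print(observations.shape)  # (16, 4): one observation per parallel environment

# Step all 16 environments at once with random actions.
actions = np.array([env.single_action_space.sample() for _ in range(N_ENVS)])
observations, rewards, terminations, truncations, infos = env.step(actions)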
3. Create an AlgoConfig
All parameters are encapsulated in an AlgoConfig.
An AlgoConfig is composed of four sub-configs:
- AlgoParams: the algorithm-specific parameters
- UpdateConfig: the update-specific parameters
- TrainConfig: the training-specific parameters
- EnvConfig: the environment-specific parameters
In this guide, we will configure a PPO agent.
from kitae import config as cfg
from kitae.algos.collections import ppo

env_cfg = cfg.EnvConfig(
    task_name=ENV_ID,
    observation_space=env.single_observation_space,
    action_space=env.single_action_space,
    n_envs=N_ENVS,
    n_agents=1,
)
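Note that EnvConfig receives the single-environment spaces (single_observation_space and single_action_space), not the batched spaces of the vectorized environment.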
algo_params = ppo.PPOParams(
    gamma=0.99,
    _lambda=0.95,
    clip_eps=0.2,
    entropy_coef=0.01,
    value_coef=0.5,
    normalize=True,
)
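For reference, gamma is the discount factor, _lambda the GAE coefficient, and clip_eps the clipping range of PPO's surrogate objective. Below is a minimal sketch of that objective (an illustration of the standard algorithm, not code taken from Kitae's source):
import jax.numpy as jnp

def clipped_surrogate_loss(ratio, advantage, clip_eps=0.2):
    # ratio: pi_new(a|s) / pi_old(a|s) for each sampled transition.
    # advantage: GAE estimate computed with gamma and _lambda.
    clipped_ratio = jnp.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps)
    # Pessimistic objective, negated so that minimizing the loss maximizes it.
    return -jnp.minimum(ratio * advantage, clipped_ratio * advantage).mean()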
update_cfg = cfg.UpdateConfig(
    learning_rate=0.0003,
    learning_rate_annealing=True,
    max_grad_norm=0.5,
    max_buffer_size=64,
    batch_size=256,
    n_epochs=1,
    shared_encoder=True,
)
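A note on two of these values: shared_encoder=True presumably makes the actor and critic share a single encoder network, and with max_buffer_size=64 each update would consume 64 steps from each of the 16 environments (1024 samples, split into batches of 256). Both readings are inferred from the parameter names rather than stated in Kitae's documentation.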
train_cfg = cfg.TrainConfig(n_env_steps=5*10**5, save_frequency=-1)
algo_config = cfg.AlgoConfig(
    seed=SEED,
    algo_params=algo_params,
    update_cfg=update_cfg,
    train_cfg=train_cfg,
    env_cfg=env_cfg,
)
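The sub-configs stay accessible as attributes of the assembled config; the training call below reads n_env_steps back this way. (Attribute names are assumed to mirror the constructor arguments.)
print(algo_config.train_cfg.n_env_steps)  # 500000
print(algo_config.env_cfg.n_envs)         # 16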
4. Instantiate and train an agent
Finally, let’s instantiate a PPO agent.
agent = ppo.PPO(run_name="example-ppo", config=algo_config)
Once a config has been defined, instantiating an agent is as simple as that! Training it is no more complex.
agent.train(env, agent.config.train_cfg.n_env_steps)
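This call trains the agent for the 500,000 environment steps set in TrainConfig, collecting experience from all 16 environments in parallel.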