Custom Agents#

Agents in Kitae follow the IAgent interface. Most of the functionality is already implemented by the BaseAgent class, from which we will derive a DQN agent in this tutorial.

To simplify the implementation of agents, Kitae uses four factory functions (sketched together after the list):

  • train_state_factory: creates the agent’s states

  • explore_factory: creates the function used to interact in the environment

  • process_experience_factory: creates the function to process the data before updating

  • update_step_factory: creates the function to update the agent’s states
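
Conceptually, the outputs of these factories are chained at every interaction and update step. The snippet below is a rough, non-authoritative sketch of that data flow; the actual orchestration is handled by BaseAgent, and names such as config, obs and experience are placeholders:

# Illustrative only: how the four factory outputs fit together.
state = train_state_factory(key, config)
explore_fn = explore_factory(config)
process_experience_fn = process_experience_factory(config)
update_step_fn = update_step_factory(config)

# Interaction: pick actions for a batch of observations.
actions, log_probs = explore_fn(state, key, obs)

# Update: turn raw experience into training targets, then apply gradients.
processed = process_experience_fn(state, key, experience)
state, info = update_step_fn(state, key, processed)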

0. Import everything#

from collections import namedtuple
from dataclasses import dataclass
from typing import Callable

import distrax as dx
import flax.linen as nn
import jax
import jax.numpy as jnp
import optax

from kitae.base import OffPolicyAgent
import kitae.config as cfg
from kitae.types import Params

from kitae.buffer import Experience
from kitae.timesteps import compute_td_targets

from kitae.modules.modules import init_params
from kitae.modules.pytree import AgentPyTree, TrainState
from kitae.modules.qvalue import qvalue_factory

DQN_tuple = namedtuple("DQN_tuple", ["observation", "action", "return_"])
NO_EXPLORATION = 0.0

1. DQN parameters#

@dataclass
class DQNParams:
    exploration: float
    gamma: float
    skip_steps: int
    start_step: int = -1

An AlgoParams is a simple dataclass instance. Inheriting from kitae.config.AlgoParams is optional.

2. TrainState Factory#

The train_state_factory takes a key and an AlgoConfig as arguments.

Its output will be stored in the agent as the state attribute.

class DQNState(AgentPyTree):
    qvalue: TrainState

def train_state_dqn_factory(
    key: jax.Array,
    config: cfg.AlgoConfig,
    *,
    preprocess_fn: Callable = None,
    tabulate: bool = False,
) -> DQNState:
    observation_shape = config.env_cfg.observation_space.shape
    n_actions = config.env_cfg.action_space.n   # discrete spaces only

    class QValue(nn.Module):
        @nn.compact
        def __call__(self, observations: jax.Array) -> jax.Array:
            x = observations
            x = nn.relu(nn.Dense(64)(x))
            x = nn.relu(nn.Dense(64)(x))
            return nn.Dense(n_actions)(x)
    
    qvalue = QValue()
    return DQNState(
        TrainState.create(
            apply_fn=jax.jit(qvalue.apply),
            params=init_params(key, qvalue, [observation_shape], tabulate),
            target_params=init_params(key, qvalue, [observation_shape], False),
            tx=optax.adam(config.update_cfg.learning_rate),
        )
    )

Here we use kitae.modules.pytree.TrainState, which has an additional target_params attribute. In this case, a plain flax.training.train_state.TrainState would have been enough, since the target parameters are never used in this tutorial.
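
If you do want a target network, a soft update could look like the sketch below. This is illustrative only and assumes kitae's TrainState behaves like a flax struct dataclass (i.e. exposes a replace method):

def soft_update_target(state: DQNState, tau: float = 0.005) -> DQNState:
    # Polyak averaging of the online parameters into the target parameters.
    new_target = optax.incremental_update(
        state.qvalue.params, state.qvalue.target_params, tau
    )
    return DQNState(qvalue=state.qvalue.replace(target_params=new_target))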

3. Explore factory#

The explore_factory takes an AlgoConfig as argument.

Its output is a function that takes a state, a key and an arbitrary number of pytrees as positional arguments. This function should return two Arrays: an action and the log_prob associated with that action.

This function should expect inputs of shape [batch_size, ...].

def explore_factory(config: cfg.AlgoConfig) -> Callable:
    @jax.jit
    def explore_fn(
        dqn_state: TrainState,
        key: jax.Array,
        observations: jax.Array,
        exploration: float,
    ) -> tuple[jax.Array, jax.Array]:
        all_qvalues = dqn_state.apply_fn(dqn_state.params, observations)
        actions, log_probs = dx.EpsilonGreedy(
            all_qvalues, exploration
        ).sample_and_log_prob(seed=key)

        return actions, log_probs

    return explore_fn
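
As a quick sanity check, the returned function can be called on a batch of observations. The snippet below is purely illustrative and assumes a CartPole-like 4-dimensional observation space, a config built as in section 7, and a dqn_state created with train_state_dqn_factory:

explore_fn = explore_factory(config)

dummy_observations = jnp.zeros((8, 4))  # [batch_size, obs_dim]
actions, log_probs = explore_fn(
    dqn_state.qvalue, jax.random.PRNGKey(0), dummy_observations, 0.1
)
# actions and log_probs both have shape (8,): one action and one
# log-probability per observation in the batch.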

4. Process experience factory#

The process_experience_factory takes an AlgoConfig as argument.

Its output is a function that takes a state, a key and a tuple of Arrays. This function should return a tuple of Arrays after processing the inputs.

This function should expect inputs of shape [batch_size, ...].

def process_experience_factory(config: cfg.AlgoConfig) -> Callable:
    algo_params = config.algo_params

    @jax.jit
    def process_experience_fn(
        dqn_state: DQNState,
        key: jax.Array,
        experience: Experience,
    ) -> tuple[jax.Array, ...]:

        all_next_qvalues = dqn_state.qvalue.apply_fn(
            dqn_state.qvalue.params, experience.next_observation
        )
        next_qvalues = jnp.max(all_next_qvalues, axis=-1, keepdims=True)

        discounts = algo_params.gamma * (1.0 - experience.done[..., None])
        returns = compute_td_targets(
            experience.reward[..., None], discounts, next_qvalues
        )
        actions = experience.action[..., None]

        return (experience.observation, actions, returns)

    return process_experience_fn
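
For reference, assuming compute_td_targets implements the standard one-step bootstrap, the returns computed above correspond to the usual Q-learning regression target:

# Equivalent target, written out explicitly (inside process_experience_fn):
#   returns = reward + gamma * (1 - done) * max_a Q(next_observation, a)
returns = experience.reward[..., None] + discounts * next_qvalues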

5. Update factory#

The update_step_factory takes an AlgoConfig as argument.

Its output is a function that takes a state, a key and a tuple of Arrays. This function should return an updated version of the state and a dictionary with the update information.

This function should expect inputs of shape [batch_size, ...].

def update_step_factory(config: cfg.AlgoConfig) -> Callable:

    @jax.jit
    def update_step_fn(
        dqn_state: DQNState,
        key: jax.Array,
        experiences: tuple[jax.Array, ...],
    ) -> tuple[DQNState, dict]:
        batch = DQN_tuple(*experiences)

        def loss_fn(params: Params):
            all_qvalues = dqn_state.qvalue.apply_fn(params, batch.observation)
            qvalues = jnp.take_along_axis(all_qvalues, batch.action, axis=-1)
            loss = jnp.mean(optax.l2_loss(qvalues, batch.return_))
            return loss, {"loss_qvalue": loss}

        (_, info), grads = jax.value_and_grad(loss_fn, has_aux=True)(
            dqn_state.qvalue.params
        )
        dqn_state.qvalue = dqn_state.qvalue.apply_gradients(grads=grads)

        return dqn_state, info

    return update_step_fn
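
To see the pieces working together outside of the framework's training loop, a single manual update could look like the following sketch. It is illustrative only: config, dqn_state and key are assumed to exist, and experience is assumed to be a batched Experience sampled from the replay buffer:

process_experience_fn = process_experience_factory(config)
update_step_fn = update_step_factory(config)

# Turn raw experience into (observation, action, return_) targets, then update.
processed = process_experience_fn(dqn_state, key, experience)
dqn_state, info = update_step_fn(dqn_state, key, processed)
print(info["loss_qvalue"])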

6. The DQN class#

Thanks to the kitae.base.OffPolicyAgent class, only the explore and select_action methods need to be implemented.

class DQN(OffPolicyAgent):
    def __init__(
        self,
        run_name: str,
        config: cfg.AlgoConfig,
        *,
        preprocess_fn: Callable = None,
        tabulate: bool = False,
    ):
        super().__init__(
            run_name,
            config,
            train_state_dqn_factory,
            explore_factory,
            process_experience_factory,
            update_step_factory,
            preprocess_fn=preprocess_fn,
            tabulate=tabulate,
            experience_type=Experience,
        )

        self.algo_params = self.config.algo_params

    def select_action(self, observation: jax.Array) -> tuple[jax.Array, jax.Array]:
        keys = self.interact_keys(observation)

        action, log_prob = self.explore_fn(
            self.state.qvalue, keys, observation, exploration=NO_EXPLORATION
        )
        return action, log_prob

    def explore(self, observation: jax.Array) -> tuple[jax.Array, jax.Array]:
        keys = self.interact_keys(observation)

        action, log_prob = self.explore_fn(
            self.state.qvalue,
            keys,
            observation,
            exploration=self.algo_params.exploration,
        )
        return action, log_prob

7. Training#

You can now instantiate and train your DQN!

from kitae.envs.make import make_vec_env

SEED = 0
ENV_ID = "CartPole-v1"
N_ENVS = 16

env = make_vec_env(
    env_id=ENV_ID,
    n_envs=N_ENVS,
    capture_video=False,
    run_name=None,
)

env_cfg = cfg.EnvConfig(
    task_name=ENV_ID,
    observation_space=env.single_observation_space,
    action_space=env.single_action_space,
    n_envs=N_ENVS,
    n_agents=1
)

dqn_params = DQNParams(
    exploration=0.1,
    gamma=0.99,
    skip_steps=1,
    start_step=-1
)

update_cfg = cfg.UpdateConfig(
    learning_rate=0.0003,
    learning_rate_annealing=True,
    max_grad_norm=0.5,
    max_buffer_size=64,
    batch_size=256,
    n_epochs=1,
    shared_encoder=True,
)

train_cfg = cfg.TrainConfig(n_env_steps=5*10**5, save_frequency=-1)

algo_config = cfg.AlgoConfig(
    seed=SEED,
    algo_params=dqn_params,
    update_cfg=update_cfg,
    train_cfg=train_cfg,
    env_cfg=env_cfg,
)

dqn = DQN(run_name="example-dqn", config=algo_config)
dqn.train(env, dqn.config.train_cfg.n_env_steps)
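
Once training is done, the agent can be evaluated greedily with select_action. The loop below is a sketch, assuming the vectorised environment follows the gymnasium API:

import numpy as np

observations, _ = env.reset(seed=SEED)
for _ in range(1_000):
    # select_action uses NO_EXPLORATION, i.e. the greedy action per environment.
    actions, _ = dqn.select_action(observations)
    observations, rewards, terminated, truncated, infos = env.step(np.asarray(actions))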