kitae.algos package#
Subpackages#
- kitae.algos.collections package
- Submodules
- kitae.algos.collections.dqn module
- kitae.algos.collections.ppo module
- kitae.algos.collections.sac module
- kitae.algos.collections.td3 module
- Module contents
Submodules#
kitae.algos.experience module#
- class kitae.algos.experience.ExperiencePipeline(transforms: list[Callable[[AgentPyTree, Array, NamedTuple], NamedTuple]], vectorized: bool = True, parallel: bool = False)#
Bases:
object
Dataclass for ExperiencePipeline.
An ExperiencePipeline handles a sequence of transformations designed for a single-agent, non-vectorized environment, i.e.:
`observation = Array(T, 4,)  # transforms are designed for this`
`observation = {"a": Array(T, 4,), "b": Array(T, 4,)}  # not for this`
`observation = Array(T, E, 4)  # not for this`
where T is the number of steps and E the number of environments. ExperiencePipeline.run automatically broadcasts the sequence of transforms to multi-agent and vectorized environments by treating each agent of each environment separately, then merging the results.
In a sequence of transforms, the output of each item should be of the type the next item expects as input. This is not validated: no error will be raised.
If no transforms are provided, the pipeline will simply merge the agents and the environments without applying any transformation.
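The sequential chaining described above can be sketched as follows. This is a minimal illustration, not kitae's implementation: `Experience`, `scale_rewards`, `center_observations` and `run_single_pipe_sketch` are hypothetical names, a plain dict stands in for the AgentPyTree, and NumPy stands in for JAX arrays.

```python
from typing import Any, NamedTuple

import numpy as np


class Experience(NamedTuple):  # hypothetical container, not part of kitae
    observation: np.ndarray  # shape (T, 4)
    reward: np.ndarray  # shape (T,)


def scale_rewards(state: Any, key: Any, experience: Experience) -> Experience:
    # Hypothetical transform: rescale rewards by a factor stored in the state.
    return experience._replace(reward=experience.reward * state["reward_scale"])


def center_observations(state: Any, key: Any, experience: Experience) -> Experience:
    # Hypothetical transform: subtract the per-feature mean over the T steps.
    mean = experience.observation.mean(axis=0, keepdims=True)
    return experience._replace(observation=experience.observation - mean)


def run_single_pipe_sketch(transforms, state, key, experience):
    # Feed the output of each transform to the next one, in order.
    for transform in transforms:
        experience = transform(state, key, experience)
    return experience


state = {"reward_scale": 0.5}  # stand-in for an AgentPyTree
key = None  # stand-in for a PRNG key, unused by these transforms
experience = Experience(observation=np.ones((8, 4)), reward=np.full((8,), 2.0))
result = run_single_pipe_sketch(
    [scale_rewards, center_observations], state, key, experience
)
```

Each transform takes and returns the same NamedTuple type here, which is what keeps the chain consistent from one item to the next.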
- transforms#
A list of ExperienceTransform to execute sequentially
- Type:
list[Callable[[kitae.modules.pytree.AgentPyTree, jax.Array, NamedTuple], NamedTuple]]
- vectorized#
A boolean that indicates if the environment is vectorized
- Type:
bool
- parallel#
A boolean that indicates if the environment is multi-agent
- Type:
bool
- parallel: bool = False#
- run(state: AgentPyTree, key: Array, experience: NamedTuple) NamedTuple #
Sequentially runs the experience transforms.
- Parameters:
state – An AgentPyTree that contains the agent’s state.
key – A PRNGKeyArray for reproducibility.
experience – A NamedTuple of the same type as the first transform’s input.
- Returns:
A processed NamedTuple.
- run_parallel(state: AgentPyTree, key: Array, experience: NamedTuple) NamedTuple #
Runs a single pipe in parallel.
Experiences should be provided as a tuple of dictionaries, e.g.:
```
experience = Foo(
    a={"a": array_a(shape=(5, 3)), "b": array_b(shape=(5, 3))},
    b={"a": Array_A(shape=(5,)), "b": Array_B(shape=(5,))},
)
```
The output is a tuple of Arrays where keys are merged, e.g.:
```
output = Foo(
    a=array_a / array_b (shape=(10, 3)),
    b=Array_A / Array_B (shape=(10,)),
)
```
- Parameters:
state – An AgentPyTree state
key – A PRNGKeyArray for reproducibility
experience – A NamedTuple of dictionaries of Arrays
- Returns:
A processed NamedTuple where keys are merged by concatenating arrays.
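The merge step alone can be sketched as below, under the assumption that per-agent arrays are concatenated along the first axis in dictionary order (the actual ordering inside kitae is not stated here). `Foo` and `merge_agents_sketch` are hypothetical names, and NumPy stands in for JAX arrays.

```python
from typing import NamedTuple

import numpy as np


class Foo(NamedTuple):  # hypothetical experience type
    a: dict
    b: dict


def merge_agents_sketch(experience: Foo) -> Foo:
    # For every field, concatenate the per-agent arrays along the first axis.
    return type(experience)(
        *(np.concatenate(list(field.values()), axis=0) for field in experience)
    )


experience = Foo(
    a={"a": np.zeros((5, 3)), "b": np.ones((5, 3))},
    b={"a": np.zeros((5,)), "b": np.ones((5,))},
)
merged = merge_agents_sketch(experience)
print(merged.a.shape, merged.b.shape)
```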
- run_parallel_vectorized(state: AgentPyTree, key: Array, experience: NamedTuple) NamedTuple #
Runs a single pipe in parallel and vectorized.
Experiences should be provided as a tuple of dictionaries whose values are Arrays with at least 2 dimensions; the first two dimensions should be identical across all Arrays. E.g.:
```
experience = Foo(
    a={"a": array_a(shape=(5, 10, 3)), "b": array_b(shape=(5, 10, 3))},
    b={"a": Array_A(shape=(5, 10)), "b": Array_B(shape=(5, 10))},
)
```
The output is a tuple of Arrays where keys are merged, e.g.:
```
output = Foo(
    a=array_a / array_b (shape=(100, 3)),
    b=Array_A / Array_B (shape=(100,)),
)
```
- Parameters:
state – An AgentPyTree state
key – A PRNGKeyArray for reproducibility
experience – A NamedTuple of dictionaries of Arrays with at least the two first dimensions identical
- Returns:
A processed NamedTuple where keys and the first two dimensions are merged.
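The shape bookkeeping behind the `(100, 3)` output can be checked with a few lines. This sketch assumes agents are stacked on a new axis before flattening; that layout is an illustration only, and NumPy stands in for JAX arrays.

```python
import numpy as np

T, E, A = 5, 10, 2  # steps, environments, agents

# One field of the experience: a dictionary with one (T, E, 3) array per agent.
field_a = {f"agent_{i}": np.zeros((T, E, 3)) for i in range(A)}

# Stack the agents on a new axis, then flatten steps, agents and environments.
stacked = np.stack(list(field_a.values()), axis=1)  # (T, A, E, 3)
merged = stacked.reshape((-1,) + stacked.shape[3:])  # (T * A * E, 3)

print(stacked.shape, merged.shape)
```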
- run_single_pipe(state: AgentPyTree, key: Array, experience: NamedTuple) NamedTuple #
Runs the transforms sequentially for a single agent in a single environment.
- run_vectorized(state: AgentPyTree, key: Array, experience: NamedTuple) NamedTuple #
Runs a single pipe in vectorized mode.
Experiences should be provided as a tuple of Arrays with at least 2 dimensions, where the first two dimensions should be identical across all Arrays. E.g.:
` experience = Foo(a=Array(shape=(5, 10, 3)), b=Array(shape=(5, 10))) `
The output is a tuple of Arrays where the first 2 dimensions are concatenated, eg:
` output = Foo(a=Array(shape=(50, 3)), b=Array(shape=(50,))) `
- Parameters:
state – An AgentPyTree state
key – A PRNGKeyArray for reproducibility
experience – A NamedTuple of Arrays with at least the two first dimensions identical
- Returns:
A processed NamedTuple where the first two dimensions are concatenated.
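A minimal sketch of the final merge, assuming a row-major reshape so that entry `[t, e]` lands at index `t * E + e`. `Foo` and `merge_time_and_env_sketch` are hypothetical names, and NumPy stands in for JAX arrays.

```python
from typing import NamedTuple

import numpy as np


class Foo(NamedTuple):  # hypothetical experience type
    a: np.ndarray
    b: np.ndarray


def merge_time_and_env_sketch(experience: Foo) -> Foo:
    # Collapse the first two axes (T, E) of every field into one axis.
    return type(experience)(
        *(x.reshape((-1,) + x.shape[2:]) for x in experience)
    )


experience = Foo(
    a=np.arange(150).reshape(5, 10, 3),
    b=np.arange(50).reshape(5, 10),
)
output = merge_time_and_env_sketch(experience)
print(output.a.shape, output.b.shape)
```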
- transforms: list[Callable[[AgentPyTree, Array, NamedTuple], NamedTuple]]#
- vectorized: bool = True#
- kitae.algos.experience.dict_to_tuple(experience: dict[str, NamedTuple]) NamedTuple #
Converts a dictionary of NamedTuples to a NamedTuple of dictionaries.
- Eg:
{"a": Foo(a=Array_a), "b": Foo(a=Array_b)} becomes: Foo(a={"a": Array_a, "b": Array_b})
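The conversion can be sketched in a few lines of plain Python. `Foo` and `dict_to_tuple_sketch` are hypothetical names, lists stand in for Arrays, and the sketch assumes every agent carries the same NamedTuple type.

```python
from typing import Any, NamedTuple


class Foo(NamedTuple):  # hypothetical experience type
    obs: Any
    reward: Any


def dict_to_tuple_sketch(experience: dict) -> NamedTuple:
    # Take the NamedTuple type from the first agent's entry.
    cls = type(next(iter(experience.values())))
    # Build one {agent: value} dictionary per field.
    return cls(
        *(
            {agent: getattr(exp, field) for agent, exp in experience.items()}
            for field in cls._fields
        )
    )


per_agent = {
    "a": Foo(obs=[0.0, 1.0], reward=[1.0]),
    "b": Foo(obs=[2.0, 3.0], reward=[0.0]),
}
per_field = dict_to_tuple_sketch(per_agent)
```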
- kitae.algos.experience.merge_n_first_dims(array: Array, n: int = 2) Array #
Merges the n-first dimensions of an Array.
- Eg:
n==1: (T, E, …) => (T, E, …)
n==2: (T, E, …) => (T * E, …)
n==3: (T, E, A, …) => (T * E * A, …)
- Raises:
AssertionError – If n is less than or equal to 0.
AssertionError – If n is not specified as static when jitting.
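The shape table above corresponds to a single reshape. The sketch below uses NumPy and a hypothetical `merge_n_first_dims_sketch`; it reproduces the documented shapes and the first assertion only, not kitae's own code.

```python
import numpy as np


def merge_n_first_dims_sketch(array: np.ndarray, n: int = 2) -> np.ndarray:
    # n must be strictly positive, mirroring the documented AssertionError.
    assert n > 0, "n must be strictly positive"
    # Collapse the first n axes into one and keep the remaining axes.
    return array.reshape((-1,) + array.shape[n:])


x = np.zeros((5, 10, 2, 3))  # (T, E, A, features)
shapes = {n: merge_n_first_dims_sketch(x, n).shape for n in (1, 2, 3)}
print(shapes)
```

Because `n` determines the output shape, a jitted JAX version would need it marked as static, for example through the `static_argnames` argument of `jax.jit`.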
- kitae.algos.experience.stack_and_merge_n_first_dims(arrays: Sequence[Array], n: int = 2) Array #
Stacks a sequence of Arrays then merges its n-first dimensions.
- Eg:
n==1: A * (T, E, …) => (T, A, E, …)
n==2: A * (T, E, …) => (T * A, E, …)
n==3: A * (T, E, …) => (T * A * E, …)
- Raises:
AssertionError – If n is less than or equal to 0.
AssertionError – If n is not specified as static when jitting.
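Read literally, the table above implies stacking on the second axis before merging. The following sketch makes that assumption explicit; `stack_and_merge_n_first_dims_sketch` is a hypothetical name and NumPy stands in for JAX arrays.

```python
from typing import Sequence

import numpy as np


def stack_and_merge_n_first_dims_sketch(
    arrays: Sequence[np.ndarray], n: int = 2
) -> np.ndarray:
    assert n > 0, "n must be strictly positive"
    # Assumption: stack on axis 1, so A arrays of shape (T, E, ...) give (T, A, E, ...).
    stacked = np.stack(arrays, axis=1)
    # Collapse the first n axes of the stacked array.
    return stacked.reshape((-1,) + stacked.shape[n:])


arrays = [np.zeros((5, 10, 3)), np.ones((5, 10, 3))]  # A = 2 arrays of (T, E, 3)
shapes = {n: stack_and_merge_n_first_dims_sketch(arrays, n).shape for n in (1, 2, 3)}
print(shapes)
```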
- kitae.algos.experience.tuple_to_dict(experience: NamedTuple) dict[str, NamedTuple] #
Converts a NamedTuple of dictionaries to a dictionary of NamedTuples.
- Eg:
Foo(a={"a": Array_a, "b": Array_b}) becomes: {"a": Foo(a=Array_a), "b": Foo(a=Array_b)}
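A plain-Python sketch of this direction is given below. `Foo` and `tuple_to_dict_sketch` are hypothetical names, lists stand in for Arrays, and the sketch assumes every field holds the same agent keys.

```python
from typing import Any, NamedTuple


class Foo(NamedTuple):  # hypothetical experience type
    obs: Any
    reward: Any


def tuple_to_dict_sketch(experience: NamedTuple) -> dict:
    cls = type(experience)
    # Read the agent names from the first field.
    agents = list(experience[0].keys())
    # Build one NamedTuple per agent by picking that agent's entry in every field.
    return {agent: cls(*(field[agent] for field in experience)) for agent in agents}


per_field = Foo(
    obs={"a": [0.0, 1.0], "b": [2.0, 3.0]},
    reward={"a": [1.0], "b": [0.0]},
)
per_agent = tuple_to_dict_sketch(per_field)
```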
kitae.algos.factory module#
- kitae.algos.factory.explore_general_factory(explore_fn: Callable, vectorized: bool, parallel: bool) Callable #
Generalizes an explore_fn to vectorized and parallel (multi-agent) environments.
- kitae.algos.factory.fn_parallel(fn: Callable) Callable #
Parallelizes a function for multiple agents.
- Typical usage for function with args:
state
trees of structure: {"agent_0": Array, "agent_1": Array, …}
hyperparameters
The wrapped function returns a list of trees with the same structure as input.
Warning: args must be entered in the same order as in fn to allow vmapping.
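The behaviour described above can be approximated with a plain Python loop over agents. This is a sketch under loud assumptions, not kitae's implementation: it does not use vmapping, the calling convention `wrapped(state, trees, hyperparameters)` is invented for illustration, and `fn_parallel_sketch` and `step` are hypothetical names.

```python
from typing import Any, Callable


def fn_parallel_sketch(fn: Callable) -> Callable:
    def wrapped(state: Any, trees: tuple, hyperparameters: Any) -> list:
        # Assumption: every tree is a dict sharing the same agent keys.
        agents = list(trees[0].keys())
        # Call fn once per agent, passing that agent's slice of each tree in order.
        per_agent = {
            agent: fn(state, *(tree[agent] for tree in trees), hyperparameters)
            for agent in agents
        }
        # Regroup the per-agent outputs into one {agent: value} tree per output.
        n_outputs = len(next(iter(per_agent.values())))
        return [
            {agent: per_agent[agent][i] for agent in agents}
            for i in range(n_outputs)
        ]

    return wrapped


def step(state, obs, reward, scale):
    # Hypothetical single-agent function returning a tuple of two outputs.
    return obs * scale + state, reward - state


trees = ({"agent_0": 1, "agent_1": 2}, {"agent_0": 10, "agent_1": 20})
outputs = fn_parallel_sketch(step)(1, trees, 3)
```

The trees are passed to `fn` in the order given, which is the same ordering constraint the warning above describes.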