kitae.modules package#
Submodules#
kitae.modules.encoder module#
- class kitae.modules.encoder.VectorEncoder(preprocess_fn: Callable = None, parent: Union[Type[flax.linen.module.Module], flax.core.scope.Scope, Type[flax.linen.module._Sentinel], NoneType] = <flax.linen.module._Sentinel object at 0x7f46e474eb30>, name: Optional[str] = None)#
Bases:
Module
- name: str | None = None#
- parent: Type[Module] | Scope | Type[_Sentinel] | None = None#
- preprocess_fn: Callable = None#
- scope: Scope | None = None#
- class kitae.modules.encoder.VisionEncoder(rearrange_pattern: str, preprocess_fn: Callable = None, parent: Union[Type[flax.linen.module.Module], flax.core.scope.Scope, Type[flax.linen.module._Sentinel], NoneType] = <flax.linen.module._Sentinel object at 0x7f46e474eb30>, name: Optional[str] = None)#
Bases:
Module
- name: str | None = None#
- parent: Type[Module] | Scope | Type[_Sentinel] | None = None#
- preprocess_fn: Callable = None#
- rearrange_pattern: str#
- scope: Scope | None = None#
- kitae.modules.encoder.encoder_factory(observation_space: Space, *, rearrange_pattern: str = 'b h w c -> b h w c', preprocess_fn: Callable | None = None) -> Type[Module] #
kitae.modules.modules module#
- class kitae.modules.modules.IndependentVariable(init_fn: ~typing.Callable, shape: tuple[int], name: str | None = None, parent: ~typing.Type[~flax.linen.module.Module] | ~flax.core.scope.Scope | ~typing.Type[~flax.linen.module._Sentinel] | None = <flax.linen.module._Sentinel object>)#
Bases:
Module
Class for independent learnable variables.
Used in SAC as the temperature coefficient.
- init_fn: Callable#
- name: str | None = None#
- parent: Type[Module] | Scope | Type[_Sentinel] | None = None#
- scope: Scope | None = None#
- shape: tuple[int]#
- class kitae.modules.modules.MLP(layers: list[int], activation: Callable, final_activation: Callable, parent: Union[Type[flax.linen.module.Module], flax.core.scope.Scope, Type[flax.linen.module._Sentinel], NoneType] = <flax.linen.module._Sentinel object at 0x7f46e474eb30>, name: Optional[str] = None)#
Bases:
Module
- activation: Callable#
- final_activation: Callable#
- layers: list[int]#
- name: str | None = None#
- parent: Type[Module] | Scope | Type[_Sentinel] | None = None#
- scope: Scope | None = None#
- class kitae.modules.modules.PassThrough(parent: Union[Type[flax.linen.module.Module], flax.core.scope.Scope, Type[flax.linen.module._Sentinel], NoneType] = <flax.linen.module._Sentinel object at 0x7f46e474eb30>, name: Optional[str] = None)#
Bases:
Module
- name: str | None = None#
- parent: Type[Module] | Scope | Type[_Sentinel] | None = None#
- scope: Scope | None = None#
- kitae.modules.modules.conv_layer(features: int, kernel_size: int, strides: int, kernel_init_std: float = 1.4142135623730951, bias_init_cst: float = 0.0) -> Conv #
- kitae.modules.modules.init_params(key: Array, module: Module, input_shapes: Sequence[tuple[int]], tabulate: bool) -> FrozenDict #
Initializes a module's parameters.
- kitae.modules.modules.parallel_copy(module: Module, n: int)#
Encapsulates n copies of a module and runs them in parallel.
kitae.modules.optimizer module#
- kitae.modules.optimizer.linear_learning_rate_schedule(init_learning_rate: float, end_learning_rate: float, *, n_envs: int, n_env_steps: int, max_buffer_size: int, batch_size: int, num_epochs: int) -> Callable[[Array | ndarray | bool_ | number | float | int], Array | ndarray | bool_ | number | float | int] #
kitae.modules.policy module#
- class kitae.modules.policy.PolicyCategorical(num_outputs: int, parent: Union[Type[flax.linen.module.Module], flax.core.scope.Scope, Type[flax.linen.module._Sentinel], NoneType] = <flax.linen.module._Sentinel object at 0x7f46e474eb30>, name: Optional[str] = None)#
Bases:
PolicyOutput
- name: str | None = None#
- num_outputs: int#
- parent: Type[Module] | Scope | Type[_Sentinel] | None = None#
- scope: Scope | None = None#
- class kitae.modules.policy.PolicyNormal(num_outputs: int, parent: Union[Type[flax.linen.module.Module], flax.core.scope.Scope, Type[flax.linen.module._Sentinel], NoneType] = <flax.linen.module._Sentinel object at 0x7f46e474eb30>, name: Optional[str] = None)#
Bases:
PolicyOutput
- name: str | None = None#
- num_outputs: int#
- parent: Type[Module] | Scope | Type[_Sentinel] | None = None#
- scope: Scope | None = None#
- setup() -> None #
Initializes a Module lazily (similar to a lazy `__init__`).
`setup` is called once lazily on a module instance when a module is bound, immediately before any other methods like `__call__` are invoked, or before a `setup`-defined attribute on `self` is accessed.
This can happen in three cases:
1. Immediately when invoking `apply()`, `init()` or `init_and_output()`.
2. Once the module is given a name by being assigned to an attribute of another module inside the other module's `setup` method (see `__setattr__()`):
>>> class MyModule(nn.Module):
...     def setup(self):
...         submodule = nn.Conv(...)
...         # Accessing `submodule` attributes does not yet work here.
...         # The following line invokes `self.__setattr__`, which gives
...         # `submodule` the name "conv1".
...         self.conv1 = submodule
...         # Accessing `submodule` attributes or methods is now safe and
...         # triggers setup() (which runs at most once).
3. Once a module is constructed inside a method wrapped with `compact()`, immediately before another method is called or a `setup`-defined attribute is accessed.
- class kitae.modules.policy.PolicyNormalExternalStd(num_outputs: int, action_scale: jax.Array, action_bias: jax.Array, parent: Union[Type[flax.linen.module.Module], flax.core.scope.Scope, Type[flax.linen.module._Sentinel], NoneType] = <flax.linen.module._Sentinel object at 0x7f46e474eb30>, name: Optional[str] = None)#
Bases:
PolicyOutput
- action_bias: Array#
- action_scale: Array#
- name: str | None = None#
- num_outputs: int#
- parent: Type[Module] | Scope | Type[_Sentinel] | None = None#
- scope: Scope | None = None#
- setup() -> None #
Initializes a Module lazily (similar to a lazy `__init__`).
`setup` is called once lazily on a module instance when a module is bound, immediately before any other methods like `__call__` are invoked, or before a `setup`-defined attribute on `self` is accessed.
This can happen in three cases:
1. Immediately when invoking `apply()`, `init()` or `init_and_output()`.
2. Once the module is given a name by being assigned to an attribute of another module inside the other module's `setup` method (see `__setattr__()`):
>>> class MyModule(nn.Module):
...     def setup(self):
...         submodule = nn.Conv(...)
...         # Accessing `submodule` attributes does not yet work here.
...         # The following line invokes `self.__setattr__`, which gives
...         # `submodule` the name "conv1".
...         self.conv1 = submodule
...         # Accessing `submodule` attributes or methods is now safe and
...         # triggers setup() (which runs at most once).
3. Once a module is constructed inside a method wrapped with `compact()`, immediately before another method is called or a `setup`-defined attribute is accessed.
- class kitae.modules.policy.PolicyOutput(num_outputs: int, parent: Union[Type[flax.linen.module.Module], flax.core.scope.Scope, Type[flax.linen.module._Sentinel], NoneType] = <flax.linen.module._Sentinel object at 0x7f46e474eb30>, name: Optional[str] = None)#
Bases:
Module
- name: str | None = None#
- num_outputs: int#
- parent: Type[Module] | Scope | Type[_Sentinel] | None = None#
- scope: Scope | None = None#
- class kitae.modules.policy.PolicyTanhNormal(num_outputs: int, log_std_min: float, log_std_max: float, parent: Union[Type[flax.linen.module.Module], flax.core.scope.Scope, Type[flax.linen.module._Sentinel], NoneType] = <flax.linen.module._Sentinel object at 0x7f46e474eb30>, name: Optional[str] = None)#
Bases:
PolicyOutput
- log_std_max: float#
- log_std_min: float#
- name: str | None = None#
- num_outputs: int#
- parent: Type[Module] | Scope | Type[_Sentinel] | None = None#
- scope: Scope | None = None#
- setup() -> None #
Initializes a Module lazily (similar to a lazy `__init__`).
`setup` is called once lazily on a module instance when a module is bound, immediately before any other methods like `__call__` are invoked, or before a `setup`-defined attribute on `self` is accessed.
This can happen in three cases:
1. Immediately when invoking `apply()`, `init()` or `init_and_output()`.
2. Once the module is given a name by being assigned to an attribute of another module inside the other module's `setup` method (see `__setattr__()`):
>>> class MyModule(nn.Module):
...     def setup(self):
...         submodule = nn.Conv(...)
...         # Accessing `submodule` attributes does not yet work here.
...         # The following line invokes `self.__setattr__`, which gives
...         # `submodule` the name "conv1".
...         self.conv1 = submodule
...         # Accessing `submodule` attributes or methods is now safe and
...         # triggers setup() (which runs at most once).
3. Once a module is constructed inside a method wrapped with `compact()`, immediately before another method is called or a `setup`-defined attribute is accessed.
- kitae.modules.policy.get_log_prob(distribution: Distribution, value: Array) -> Array #
- kitae.modules.policy.make_policy(encoder: Module, policy_output: PolicyOutput) -> Module #
- kitae.modules.policy.policy_output_factory(action_space: Discrete) -> type[PolicyOutput] #
- kitae.modules.policy.sample_and_log_prob(distribution: Distribution, key: Array) -> tuple[Array, Array] #
kitae.modules.pytree module#
- class kitae.modules.pytree.AgentPyTree#
Bases:
object
Default Agent State class.
Unlike Flax's PyTreeNode, an AgentPyTree is mutable.
- class kitae.modules.pytree.TrainState(step: int | Array, apply_fn: Callable, params: FrozenDict[str, Any], tx: GradientTransformation, opt_state: Array | ndarray | bool_ | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree], target_params: FrozenDict = None)#
Bases:
TrainState
Modified TrainState with target_params attribute.
- replace(**updates)#
Returns a new object replacing the specified fields with new values.
- target_params: FrozenDict = None#
kitae.modules.qvalue module#
- class kitae.modules.qvalue.DoubleQValueContinuousOutput(parent: Union[Type[flax.linen.module.Module], flax.core.scope.Scope, Type[flax.linen.module._Sentinel], NoneType] = <flax.linen.module._Sentinel object at 0x7f46e474eb30>, name: Optional[str] = None)#
Bases:
Module
- name: str | None = None#
- parent: Type[Module] | Scope | Type[_Sentinel] | None = None#
- scope: Scope | None = None#
- class kitae.modules.qvalue.QValueContinousOutput(parent: Union[Type[flax.linen.module.Module], flax.core.scope.Scope, Type[flax.linen.module._Sentinel], NoneType] = <flax.linen.module._Sentinel object at 0x7f46e474eb30>, name: Optional[str] = None)#
Bases:
Module
- name: str | None = None#
- parent: Type[Module] | Scope | Type[_Sentinel] | None = None#
- scope: Scope | None = None#
- class kitae.modules.qvalue.QValueDiscreteOutput(num_outputs: int, parent: Union[Type[flax.linen.module.Module], flax.core.scope.Scope, Type[flax.linen.module._Sentinel], NoneType] = <flax.linen.module._Sentinel object at 0x7f46e474eb30>, name: Optional[str] = None)#
Bases:
Module
- name: str | None = None#
- num_outputs: int#
- parent: Type[Module] | Scope | Type[_Sentinel] | None = None#
- scope: Scope | None = None#
- kitae.modules.qvalue.make_double_q_value(q1: Module, q2: Module) -> Module #
- kitae.modules.qvalue.qvalue_factory(observation_space: Space, action_space: Space, *, rearrange_pattern: str = 'b h w c -> b h w c', preprocess_fn: Callable | None = None) -> Type[Module] #
kitae.modules.value module#
- class kitae.modules.value.ValueOutput(parent: Union[Type[flax.linen.module.Module], flax.core.scope.Scope, Type[flax.linen.module._Sentinel], NoneType] = <flax.linen.module._Sentinel object at 0x7f46e474eb30>, name: Optional[str] = None)#
Bases:
Module
- name: str | None = None#
- parent: Type[Module] | Scope | Type[_Sentinel] | None = None#
- scope: Scope | None = None#