kitae.modules package

Contents

kitae.modules package#

Submodules#

kitae.modules.encoder module#

class kitae.modules.encoder.VectorEncoder(preprocess_fn: Callable = None, parent: Union[Type[flax.linen.module.Module], flax.core.scope.Scope, Type[flax.linen.module._Sentinel], NoneType] = <flax.linen.module._Sentinel object at 0x7f46e474eb30>, name: Optional[str] = None)#

Bases: Module

name: str | None = None#
parent: Type[Module] | Scope | Type[_Sentinel] | None = None#
preprocess_fn: Callable = None#
scope: Scope | None = None#
class kitae.modules.encoder.VisionEncoder(rearrange_pattern: str, preprocess_fn: Callable = None, parent: Union[Type[flax.linen.module.Module], flax.core.scope.Scope, Type[flax.linen.module._Sentinel], NoneType] = <flax.linen.module._Sentinel object at 0x7f46e474eb30>, name: Optional[str] = None)#

Bases: Module

name: str | None = None#
parent: Type[Module] | Scope | Type[_Sentinel] | None = None#
preprocess_fn: Callable = None#
rearrange_pattern: str#
scope: Scope | None = None#
kitae.modules.encoder.encoder_factory(observation_space: Space, *, rearrange_pattern: str = 'b h w c -> b h w c', preprocess_fn: Callable | None = None) -> Type[Module]#

kitae.modules.modules module#

class kitae.modules.modules.IndependentVariable(init_fn: ~typing.Callable, shape: tuple[int], name: str | None = None, parent: ~typing.Type[~flax.linen.module.Module] | ~flax.core.scope.Scope | ~typing.Type[~flax.linen.module._Sentinel] | None = <flax.linen.module._Sentinel object>)#

Bases: Module

Class for independent learnable variables.

Used in SAC as the temperature coefficient.

init_fn: Callable#
name: str | None = None#
parent: Type[Module] | Scope | Type[_Sentinel] | None = None#
scope: Scope | None = None#
shape: tuple[int]#
class kitae.modules.modules.MLP(layers: list[int], activation: Callable, final_activation: Callable, parent: Union[Type[flax.linen.module.Module], flax.core.scope.Scope, Type[flax.linen.module._Sentinel], NoneType] = <flax.linen.module._Sentinel object at 0x7f46e474eb30>, name: Optional[str] = None)#

Bases: Module

activation: Callable#
final_activation: Callable#
layers: list[int]#
name: str | None = None#
parent: Type[Module] | Scope | Type[_Sentinel] | None = None#
scope: Scope | None = None#
class kitae.modules.modules.PassThrough(parent: Union[Type[flax.linen.module.Module], flax.core.scope.Scope, Type[flax.linen.module._Sentinel], NoneType] = <flax.linen.module._Sentinel object at 0x7f46e474eb30>, name: Optional[str] = None)#

Bases: Module

name: str | None = None#
parent: Type[Module] | Scope | Type[_Sentinel] | None = None#
scope: Scope | None = None#
kitae.modules.modules.conv_layer(features: int, kernel_size: int, strides: int, kernel_init_std: float = 1.4142135623730951, bias_init_cst: float = 0.0) -> Conv#
kitae.modules.modules.init_params(key: Array, module: Module, input_shapes: Sequence[tuple[int]], tabulate: bool) -> FrozenDict#

Initializes a module's parameters.

kitae.modules.modules.parallel_copy(module: Module, n: int)#

Encapsulates copies of a module and runs inference on them in parallel.

kitae.modules.optimizer module#

kitae.modules.optimizer.linear_learning_rate_schedule(init_learning_rate: float, end_learning_rate: float, *, n_envs: int, n_env_steps: int, max_buffer_size: int, batch_size: int, num_epochs: int) -> Callable[[Array | ndarray | bool_ | number | float | int], Array | ndarray | bool_ | number | float | int]#

kitae.modules.policy module#

class kitae.modules.policy.PolicyCategorical(num_outputs: int, parent: Union[Type[flax.linen.module.Module], flax.core.scope.Scope, Type[flax.linen.module._Sentinel], NoneType] = <flax.linen.module._Sentinel object at 0x7f46e474eb30>, name: Optional[str] = None)#

Bases: PolicyOutput

name: str | None = None#
num_outputs: int#
parent: Type[Module] | Scope | Type[_Sentinel] | None = None#
scope: Scope | None = None#
class kitae.modules.policy.PolicyNormal(num_outputs: int, parent: Union[Type[flax.linen.module.Module], flax.core.scope.Scope, Type[flax.linen.module._Sentinel], NoneType] = <flax.linen.module._Sentinel object at 0x7f46e474eb30>, name: Optional[str] = None)#

Bases: PolicyOutput

name: str | None = None#
num_outputs: int#
parent: Type[Module] | Scope | Type[_Sentinel] | None = None#
scope: Scope | None = None#
setup() -> None#

Initializes a Module lazily (similar to a lazy __init__).

setup is called once lazily on a module instance when a module is bound, immediately before any other methods like __call__ are invoked, or before a setup-defined attribute on self is accessed.

This can happen in three cases:

  1. Immediately when invoking apply(), init() or init_and_output().

  2. Once the module is given a name by being assigned to an attribute of another module inside the other module’s setup method (see __setattr__()):

    >>> class MyModule(nn.Module):
    ...   def setup(self):
    ...     submodule = nn.Conv(...)
    
    ...     # Accessing `submodule` attributes does not yet work here.
    
    ...     # The following line invokes `self.__setattr__`, which gives
    ...     # `submodule` the name "conv1".
    ...     self.conv1 = submodule
    
    ...     # Accessing `submodule` attributes or methods is now safe and
    ...     # causes setup() to be called once.
    
  3. Once a module is constructed inside a method wrapped with compact(), immediately before another method is called or a setup-defined attribute is accessed.

class kitae.modules.policy.PolicyNormalExternalStd(num_outputs: int, action_scale: jax.Array, action_bias: jax.Array, parent: Union[Type[flax.linen.module.Module], flax.core.scope.Scope, Type[flax.linen.module._Sentinel], NoneType] = <flax.linen.module._Sentinel object at 0x7f46e474eb30>, name: Optional[str] = None)#

Bases: PolicyOutput

action_bias: Array#
action_scale: Array#
name: str | None = None#
num_outputs: int#
parent: Type[Module] | Scope | Type[_Sentinel] | None = None#
scope: Scope | None = None#
setup() -> None#

Initializes a Module lazily (similar to a lazy __init__).

setup is called once lazily on a module instance when a module is bound, immediately before any other methods like __call__ are invoked, or before a setup-defined attribute on self is accessed.

This can happen in three cases:

  1. Immediately when invoking apply(), init() or init_and_output().

  2. Once the module is given a name by being assigned to an attribute of another module inside the other module’s setup method (see __setattr__()):

    >>> class MyModule(nn.Module):
    ...   def setup(self):
    ...     submodule = nn.Conv(...)
    
    ...     # Accessing `submodule` attributes does not yet work here.
    
    ...     # The following line invokes `self.__setattr__`, which gives
    ...     # `submodule` the name "conv1".
    ...     self.conv1 = submodule
    
    ...     # Accessing `submodule` attributes or methods is now safe and
    ...     # causes setup() to be called once.
    
  3. Once a module is constructed inside a method wrapped with compact(), immediately before another method is called or a setup-defined attribute is accessed.

class kitae.modules.policy.PolicyOutput(num_outputs: int, parent: Union[Type[flax.linen.module.Module], flax.core.scope.Scope, Type[flax.linen.module._Sentinel], NoneType] = <flax.linen.module._Sentinel object at 0x7f46e474eb30>, name: Optional[str] = None)#

Bases: Module

name: str | None = None#
num_outputs: int#
parent: Type[Module] | Scope | Type[_Sentinel] | None = None#
scope: Scope | None = None#
class kitae.modules.policy.PolicyTanhNormal(num_outputs: int, log_std_min: float, log_std_max: float, parent: Union[Type[flax.linen.module.Module], flax.core.scope.Scope, Type[flax.linen.module._Sentinel], NoneType] = <flax.linen.module._Sentinel object at 0x7f46e474eb30>, name: Optional[str] = None)#

Bases: PolicyOutput

log_std_max: float#
log_std_min: float#
name: str | None = None#
num_outputs: int#
parent: Type[Module] | Scope | Type[_Sentinel] | None = None#
scope: Scope | None = None#
setup() -> None#

Initializes a Module lazily (similar to a lazy __init__).

setup is called once lazily on a module instance when a module is bound, immediately before any other methods like __call__ are invoked, or before a setup-defined attribute on self is accessed.

This can happen in three cases:

  1. Immediately when invoking apply(), init() or init_and_output().

  2. Once the module is given a name by being assigned to an attribute of another module inside the other module’s setup method (see __setattr__()):

    >>> class MyModule(nn.Module):
    ...   def setup(self):
    ...     submodule = nn.Conv(...)
    
    ...     # Accessing `submodule` attributes does not yet work here.
    
    ...     # The following line invokes `self.__setattr__`, which gives
    ...     # `submodule` the name "conv1".
    ...     self.conv1 = submodule
    
    ...     # Accessing `submodule` attributes or methods is now safe and
    ...     # causes setup() to be called once.
    
  3. Once a module is constructed inside a method wrapped with compact(), immediately before another method is called or a setup-defined attribute is accessed.

kitae.modules.policy.get_log_prob(distribution: Distribution, value: Array) -> Array#
kitae.modules.policy.make_policy(encoder: Module, policy_output: PolicyOutput) -> Module#
kitae.modules.policy.policy_output_factory(action_space: Discrete) -> type[PolicyOutput]#
kitae.modules.policy.sample_and_log_prob(distribution: Distribution, key: Array) -> tuple[Array, Array]#

kitae.modules.pytree module#

class kitae.modules.pytree.AgentPyTree#

Bases: object

Default Agent State class.

Contrary to Flax’s PyTreeNode, an AgentPyTree is mutable.

class kitae.modules.pytree.TrainState(step: int | Array, apply_fn: Callable, params: FrozenDict[str, Any], tx: GradientTransformation, opt_state: Array | ndarray | bool_ | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree], target_params: FrozenDict = None)#

Bases: TrainState

Modified TrainState with target_params attribute.

replace(**updates)#

Returns a new object replacing the specified fields with new values.

target_params: FrozenDict = None#

kitae.modules.qvalue module#

class kitae.modules.qvalue.DoubleQValueContinuousOutput(parent: Union[Type[flax.linen.module.Module], flax.core.scope.Scope, Type[flax.linen.module._Sentinel], NoneType] = <flax.linen.module._Sentinel object at 0x7f46e474eb30>, name: Optional[str] = None)#

Bases: Module

name: str | None = None#
parent: Type[Module] | Scope | Type[_Sentinel] | None = None#
scope: Scope | None = None#
class kitae.modules.qvalue.QValueContinousOutput(parent: Union[Type[flax.linen.module.Module], flax.core.scope.Scope, Type[flax.linen.module._Sentinel], NoneType] = <flax.linen.module._Sentinel object at 0x7f46e474eb30>, name: Optional[str] = None)#

Bases: Module

name: str | None = None#
parent: Type[Module] | Scope | Type[_Sentinel] | None = None#
scope: Scope | None = None#
class kitae.modules.qvalue.QValueDiscreteOutput(num_outputs: int, parent: Union[Type[flax.linen.module.Module], flax.core.scope.Scope, Type[flax.linen.module._Sentinel], NoneType] = <flax.linen.module._Sentinel object at 0x7f46e474eb30>, name: Optional[str] = None)#

Bases: Module

name: str | None = None#
num_outputs: int#
parent: Type[Module] | Scope | Type[_Sentinel] | None = None#
scope: Scope | None = None#
kitae.modules.qvalue.make_double_q_value(q1: Module, q2: Module) -> Module#
kitae.modules.qvalue.qvalue_factory(observation_space: Space, action_space: Space, *, rearrange_pattern: str = 'b h w c -> b h w c', preprocess_fn: Callable | None = None) -> Type[Module]#

kitae.modules.value module#

class kitae.modules.value.ValueOutput(parent: Union[Type[flax.linen.module.Module], flax.core.scope.Scope, Type[flax.linen.module._Sentinel], NoneType] = <flax.linen.module._Sentinel object at 0x7f46e474eb30>, name: Optional[str] = None)#

Bases: Module

name: str | None = None#
parent: Type[Module] | Scope | Type[_Sentinel] | None = None#
scope: Scope | None = None#

Module contents#