bijx.ModuleReconstructor
- class bijx.ModuleReconstructor
Bases: object
Parameter management utility for dynamically parameterizing modules.
For convenience, it can decompose and reconstruct either modules or states.
It extracts the parameter structure from a module or state and provides methods to reconstruct the module from different parameter representations (arrays, dicts, leaves). This is useful for coupling layers, where one network outputs the parameters of another bijection.
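The coupling-layer use case can be sketched in plain JAX. This is an illustration of the idea, not bijx's implementation: the conditioner is a hypothetical linear map, the target bijection is a hypothetical elementwise affine transform, and `unflatten` is a hypothetical helper standing in for reconstructing the module from a single array.

```python
import jax.numpy as jnp
import numpy as np

# Hypothetical target bijection: elementwise affine map y = x * exp(log_scale) + shift.
# Its parameter shapes are fixed up front, like the structure ModuleReconstructor extracts.
param_shapes = {"log_scale": (3,), "shift": (3,)}
sizes = [int(np.prod(s)) for s in param_shapes.values()]
total_size = sum(sizes)  # plays the role of params_total_size


def unflatten(flat):
    """Split a flat vector into named parameter arrays (hypothetical helper)."""
    splits = np.cumsum(sizes)[:-1].tolist()
    chunks = jnp.split(flat, splits)
    return {name: c.reshape(shape) for (name, shape), c in zip(param_shapes.items(), chunks)}


def affine_forward(params, x):
    """Apply the affine bijection and return its log-determinant."""
    y = x * jnp.exp(params["log_scale"]) + params["shift"]
    log_det = jnp.sum(params["log_scale"])
    return y, log_det


def coupling_forward(conditioner_weights, x):
    # The first half of x conditions the transform of the second half.
    x_a, x_b = x[:3], x[3:]
    # A hypothetical linear conditioner emits exactly total_size numbers.
    flat_params = conditioner_weights @ x_a
    y_b, log_det = affine_forward(unflatten(flat_params), x_b)
    return jnp.concatenate([x_a, y_b]), log_det
```

With all-zero conditioner weights the emitted parameters are zero, so the layer reduces to the identity with zero log-determinant. The point of the pattern is that the conditioner only needs to know `total_size`; the split back into named, shaped parameters is handled separately.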
- Representations include:
  - a single array of size params_total_size: use from_array
  - a list of array leaves matching param_leaves: use from_leaves
  - a dict of params matching params_dict: use from_dict
  - a full nnx state: use from_params
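The relationship between the single-array, leaves, and dict representations can be sketched with JAX pytree utilities. This is a minimal sketch assuming a plain dict of arrays stands in for the module's parameters; the helper names `to_array`, `array_to_leaves`, and `leaves_to_dict` are hypothetical and are not part of bijx.

```python
import jax
import jax.numpy as jnp

# Stand-in parameter pytree; in bijx this structure would come from an nnx module or state.
params_dict = {"bias": jnp.zeros(2), "kernel": jnp.ones((3, 2))}

# Leaves: a flat list of arrays plus a treedef that records the nesting structure.
leaves, treedef = jax.tree_util.tree_flatten(params_dict)
shapes = [leaf.shape for leaf in leaves]
sizes = [leaf.size for leaf in leaves]


def to_array(leaves):
    """Concatenate all leaves into a single 1-D array."""
    return jnp.concatenate([jnp.ravel(leaf) for leaf in leaves])


def array_to_leaves(flat):
    """Inverse of to_array: slice by sizes and restore each leaf's shape."""
    leaves_out, start = [], 0
    for shape, size in zip(shapes, sizes):
        leaves_out.append(flat[start:start + size].reshape(shape))
        start += size
    return leaves_out


def leaves_to_dict(leaves):
    """Rebuild the original nested structure from a list of leaves."""
    return jax.tree_util.tree_unflatten(treedef, leaves)


flat = to_array(leaves)                         # single-array representation, shape (8,)
rebuilt = leaves_to_dict(array_to_leaves(flat))  # back to the dict representation
```

The leaf order is fixed by the pytree flattening (sorted dict keys here), which is why a leaves list or a flat array is only meaningful relative to a stored structure.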
- __init__(module_or_state, filter=<class 'flax.nnx.variablelib.Param'>)
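- Parameters:
module_or_state (State | Module) – The module or state whose parameter structure is extracted.
filter (Param) – Selects which nnx variables are treated as parameters; defaults to nnx.Param.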
Methods
from_params(params[, autovmap]) – Reconstructs the module from different parameter representations.
from_state(params)
Attributes
- property params
- property params_dict
- property params_shapes
- property params_shape_dict
- property params_dtypes
- property params_sizes
- property params_total_size
- property params_array_splits
- property has_complex_params
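The attribute names suggest metadata derived from the parameter leaves. The sketch below computes analogous quantities for a plain JAX pytree; it is one plausible interpretation based only on the names, assuming `params_array_splits` holds the split indices used to cut a flat vector back into per-leaf chunks. bijx's actual definitions may differ.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical parameter pytree, including one complex-valued leaf.
params = {"phase": jnp.zeros(4, dtype=jnp.complex64), "w": jnp.ones((2, 3))}

leaves = jax.tree_util.tree_leaves(params)
shapes = [leaf.shape for leaf in leaves]          # compare params_shapes
dtypes = [leaf.dtype for leaf in leaves]          # compare params_dtypes
sizes = [int(np.prod(s)) for s in shapes]         # compare params_sizes
total_size = sum(sizes)                           # compare params_total_size
# Boundaries between consecutive leaves in a flat vector (all but the final end point).
array_splits = np.cumsum(sizes)[:-1].tolist()     # compare params_array_splits
has_complex = any(jnp.iscomplexobj(leaf) for leaf in leaves)  # compare has_complex_params

# Cutting a flat vector of length total_size at array_splits yields one chunk per leaf.
chunks = jnp.split(jnp.arange(total_size, dtype=jnp.float32), array_splits)
```

Storing split indices rather than recomputing them keeps the flat-array path to a single `jnp.split` followed by one reshape per leaf.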
- from_params(params, autovmap=False)
Reconstructs the module from different parameter representations.
This method dispatches to the correct reconstruction logic based on the input type.
If autovmap is True, an object is returned that behaves almost like the module except that function calls are automatically vectorized (via vmap) over parameters and inputs.
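Dispatching on the input type can be sketched as a chain of `isinstance` checks. This is a hypothetical illustration rather than bijx's code: a dict of arrays serves as the template, all helper names are invented for the example, and the nnx `State` branch is omitted because it would need Flax.

```python
from collections.abc import Mapping

import jax
import jax.numpy as jnp

# Hypothetical template describing the parameter structure.
template = {"bias": jnp.zeros(2), "kernel": jnp.zeros((2, 2))}
template_leaves, treedef = jax.tree_util.tree_flatten(template)


def from_leaves(leaves):
    """List of arrays in pytree order -> structured params."""
    return jax.tree_util.tree_unflatten(treedef, list(leaves))


def from_array(flat):
    """Single flat array -> structured params, slicing leaf by leaf."""
    out, start = [], 0
    for t in template_leaves:
        out.append(flat[start:start + t.size].reshape(t.shape))
        start += t.size
    return from_leaves(out)


def from_dict(d):
    """Dict keyed like the template -> structured params with template shapes."""
    if set(d) != set(template):
        raise KeyError("parameter names do not match the template")
    return {k: jnp.asarray(d[k]).reshape(template[k].shape) for k in template}


def reconstruct(params):
    # Order matters: check mappings and sequences before falling back to arrays.
    if isinstance(params, Mapping):
        return from_dict(params)
    if isinstance(params, (list, tuple)):
        return from_leaves(params)
    return from_array(jnp.asarray(params))
```

For example, `reconstruct(jnp.arange(6.0))` fills `bias` with `[0, 1]` and `kernel` with `[[2, 3], [4, 5]]`, following the sorted-key leaf order.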
- Parameters:
params (dict | list[Array] | Array | State) – Can be a single array, a list of arrays, a dict, or a full nnx state.
autovmap (bool) – If True, wrap the reconstruction in an AutoVmapReconstructor.
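The effect of vectorizing over both parameters and inputs can be illustrated with `jax.vmap` directly. This sketch assumes a leading batch axis on both arguments; how AutoVmapReconstructor actually chooses axes is not documented here, and `affine_apply` is a hypothetical module function.

```python
import jax
import jax.numpy as jnp


def affine_apply(flat_params, x):
    # Hypothetical module: first half of flat_params is a scale, second half a shift.
    n = x.shape[-1]
    scale, shift = flat_params[:n], flat_params[n:]
    return x * scale + shift


# A batch of 4 parameter vectors (scale = i, shift = 1) and 4 matching inputs.
batch_params = jnp.stack(
    [jnp.concatenate([jnp.full(3, float(i)), jnp.ones(3)]) for i in range(4)]
)
batch_x = jnp.ones((4, 3))

# Map over the leading axis of both arguments: element i uses parameter set i.
batched_apply = jax.vmap(affine_apply, in_axes=(0, 0))
out = batched_apply(batch_params, batch_x)  # shape (4, 3)
```

Without vectorization, the same result would require a Python loop that reconstructs and calls the module once per parameter set; `vmap` traces the function once and evaluates all batch elements together.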