bijx.ModuleReconstructor

class bijx.ModuleReconstructor

Bases: object

Parameter management utility for dynamically parameterizing modules.

For convenience, it can decompose and reconstruct either modules or states.

Extracts the parameter structure from a module or state and provides methods to reconstruct the module from different parameter representations (arrays, dicts, leaves). This is useful for coupling layers, where one network outputs the parameters of another bijection.

Representations include:
  • A single array of size params_total_size (use from_array)

  • A list of array leaves matching param_leaves (use from_leaves)

  • A dict of params matching params_dict (use from_dict)

  • A full nnx state (use from_params)
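The array, leaves, and dict representations are related by flattening and splitting. The following is a minimal sketch of that relationship in plain JAX, not bijx's implementation: the example params_dict and the helpers leaves_from_array and dict_from_array are hypothetical, and the local variables only mirror what attributes such as params_shapes, params_sizes, params_total_size, and params_array_splits presumably describe.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical parameter dict standing in for a module's parameters.
params_dict = {"scale": jnp.ones((3,)), "shift": jnp.zeros((2, 2))}

# Flatten the dict into a list of leaves plus the structure needed to rebuild it.
leaves, treedef = jax.tree_util.tree_flatten(params_dict)

# Per-leaf bookkeeping (assumed analogues of the params_* attributes).
shapes = [leaf.shape for leaf in leaves]
sizes = [int(np.prod(shape)) for shape in shapes]
total_size = sum(sizes)
# Indices at which a flat array is cut into per-leaf chunks.
splits = [int(i) for i in np.cumsum(sizes)[:-1]]


def leaves_from_array(flat):
    """Split a flat array of length total_size into correctly shaped leaves."""
    chunks = jnp.split(flat, splits)
    return [chunk.reshape(shape) for chunk, shape in zip(chunks, shapes)]


def dict_from_array(flat):
    """Rebuild the parameter dict from a single flat array."""
    return jax.tree_util.tree_unflatten(treedef, leaves_from_array(flat))


flat = jnp.arange(total_size, dtype=jnp.float32)
rebuilt = dict_from_array(flat)
```

In a coupling layer, the flat array would typically be the output of the conditioning network, so its last dimension has to equal the total parameter size.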

__init__(module_or_state, filter=flax.nnx.variablelib.Param)
Parameters:
  • module_or_state (State | Module)

  • filter (Param)

Methods

from_params(params[, autovmap])

Reconstructs the module from different parameter representations.

from_state(params)

Attributes

property params
property params_dict
property params_shapes
property params_shape_dict
property params_dtypes
property params_sizes
property params_total_size
property params_array_splits
property has_complex_params

from_state(params)
Parameters:
  • params (State)

from_params(params, autovmap=False)

Reconstructs the module from different parameter representations.

This method dispatches to the correct reconstruction logic based on the input type.

If autovmap is True, the returned object behaves almost like the module, except that function calls are automatically vectorized (via vmap) over parameters and inputs.

Parameters:
  • params (dict | list[Array] | Array | State) – Can be a single array, a list of arrays, a dict, or a full nnx state.

  • autovmap (bool) – If True, wrap the reconstruction in an AutoVmapReconstructor.
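To illustrate what vectorizing over parameters and inputs means, the sketch below applies jax.vmap directly to a toy affine map in which each sample has its own flat parameter vector. It does not use AutoVmapReconstructor, whose exact behaviour may differ; affine, params_from_array, and apply_single are hypothetical names chosen for this example.

```python
import jax
import jax.numpy as jnp


def affine(params, x):
    """Toy bijection: elementwise affine map with per-sample parameters."""
    return x * jnp.exp(params["log_scale"]) + params["shift"]


def params_from_array(flat):
    """Interpret a flat array of length 4 as two parameter vectors of length 2."""
    return {"log_scale": flat[:2], "shift": flat[2:]}


def apply_single(flat, x):
    """Rebuild the parameters for one sample and apply the map to that sample."""
    return affine(params_from_array(flat), x)


# One flat parameter vector per sample, e.g. produced by a conditioning network.
batch_flat = jnp.stack([jnp.zeros(4), jnp.array([0.0, 0.0, 1.0, 1.0])])
batch_x = jnp.ones((2, 2))

# vmap maps over the leading axis of both the parameters and the inputs.
batched = jax.vmap(apply_single)(batch_flat, batch_x)
```

Here the first sample is left unchanged (zero log-scale and zero shift), while the second is shifted by one, which shows that each input is paired with its own parameters.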