bijx.default_wrap

bijx.default_wrap(x, cls=nnx.Param, init_fn=<normal initializer>, init_cls=nnx.Param, rngs=None)

Flexibly wrap parameter specifications into nnx.Variable instances.

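This function provides a single entry point for parameter initialization. It accepts three kinds of input: an existing nnx.Variable (passed through), an array (wrapped in cls), or a shape (initialized with init_fn and wrapped in init_cls). The result is an nnx.Variable ready for use as a Flax module attribute.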

Parameters:
  • x (Union[Variable, Array, ndarray, Sequence[Union[int, Any]]]) – Parameter specification (Variable, array, or shape).

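  • cls – Variable class used to wrap x when it is an array (default nnx.Param).

  • init_fn – Initializer used to create values when x is a shape (default: a normal initializer).

  • init_cls – Variable class used to wrap values created by init_fn (default nnx.Param).

  • rngs (Rngs | None) – Random number generators that supply the key for init_fn; needed when x is a shape.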

Returns:

Wrapped parameter as an nnx.Variable instance.

Raises:

ValueError – If parameter specification type is not supported.

Example

>>> import jax.numpy as jnp
>>> from flax import nnx
>>> from bijx import default_wrap
>>> rngs = nnx.Rngs(0)
>>> # Direct array
>>> param = default_wrap(jnp.array([1.0, 2.0]))
>>> # Shape-based initialization
>>> param = default_wrap((10, 5), rngs=rngs)
>>> # Already wrapped
>>> some_array = jnp.array([1.0, 2.0])
>>> param = default_wrap(nnx.Param(some_array))