bijx.default_wrap
- bijx.default_wrap(x, cls=<class 'flax.nnx.variablelib.Param'>, init_fn=<function normal.<locals>.init>, init_cls=<class 'flax.nnx.variablelib.Param'>, rngs=None)
Flexibly wrap parameter specifications into nnx.Variable instances.
This function provides a unified interface for parameter initialization, accepting various input types and converting them to appropriate nnx.Variable instances for use in Flax modules.
- Parameters:
x (Union[Variable, Array, ndarray, Sequence[Union[int, Any]]]) – Parameter specification (Variable, array, or shape).
cls – Variable class to use for array wrapping.
init_fn – Initialization function for shape-based parameters.
init_cls – Variable class to use for initialized parameters.
rngs (Rngs | None) – Random number generators for initialization.
- Returns:
Wrapped parameter as an nnx.Variable instance.
- Raises:
ValueError – If the type of the parameter specification is not supported.
Example
>>> import jax.numpy as jnp
>>> from flax import nnx
>>> from bijx import default_wrap
>>> rngs = nnx.Rngs(0)
>>> # Direct array
>>> param = default_wrap(jnp.array([1.0, 2.0]))
>>> # Shape-based initialization
>>> param = default_wrap((10, 5), rngs=rngs)
>>> # Already wrapped
>>> some_array = jnp.array([1.0, 2.0])
>>> param = default_wrap(nnx.Param(some_array))
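As a rough illustration of the three input cases described above (Variable, array, or shape), the following sketch mimics the dispatch with pure-Python stand-ins. It is not bijx's implementation: `Variable`, `Param`, `Array`, `normal_init`, and `wrap_sketch` are hypothetical names defined here. The sketch also assumes that already-wrapped values are returned unchanged and that a shape given without an RNG is an error, neither of which is confirmed by the documentation.

```python
import random
from dataclasses import dataclass
from typing import Any, Sequence


@dataclass
class Variable:
    """Hypothetical stand-in for nnx.Variable: a box around a value."""
    value: Any


class Param(Variable):
    """Hypothetical stand-in for nnx.Param."""


@dataclass
class Array:
    """Hypothetical stand-in for jax.Array / numpy.ndarray."""
    data: Any


def normal_init(rng: random.Random, shape: Sequence[int]) -> Any:
    """Build a nested list of the given shape filled with Gaussian samples."""
    if not shape:
        return rng.gauss(0.0, 0.01)
    return [normal_init(rng, shape[1:]) for _ in range(shape[0])]


def wrap_sketch(x, cls=Param, init_fn=normal_init, init_cls=Param, rng=None):
    """Dispatch on the kind of parameter specification (illustrative only)."""
    if isinstance(x, Variable):
        # Already wrapped: assumed to pass through unchanged.
        return x
    if isinstance(x, Array):
        # Concrete array: wrap it in `cls`.
        return cls(x)
    if isinstance(x, (tuple, list)):
        # Shape: initialize with `init_fn`, then wrap in `init_cls`.
        if rng is None:
            # Assumption of this sketch, not documented behavior.
            raise ValueError("shape-based initialization needs an RNG")
        return init_cls(Array(init_fn(rng, tuple(x))))
    raise ValueError(f"unsupported parameter specification: {type(x).__name__}")
```

In this sketch, `wrap_sketch(Array([1.0, 2.0]))` wraps the data in `Param`, `wrap_sketch((10, 5), rng=random.Random(0))` creates a freshly initialized 10-by-5 nested list inside a `Param`, and any other input type raises `ValueError`, mirroring the documented Raises entry.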