bijx.ComplexScaling
- class bijx.ComplexScaling
Bases: Bijection
Diagonal complex affine bijection.
Transform: \(y = e^{s} \cdot e^{i\varphi} \cdot x + b\), with optional shift \(b\), optional phase \(\varphi\), and an optional complex_mask selecting which entries carry an imaginary degree of freedom. The log-Jacobian contribution is \(\sum_i w_i s_i\) with \(w_i = 2\) in fully-complex layout or \(w_i = 1 + \text{mask}_i\) when a mask is given.
For shift and phase, ..._init=None (default) disables the term. Pass an initializer (e.g. nnx.initializers.zeros) to enable it, or pass a pre-built nnx.Variable / GroupedParam via the corresponding shift= / phase= / scale= argument to override construction entirely.
Index sharing across the event (parameter tying within label groups) is supported by passing a GroupedParam for any of the terms; see GroupedParam.from_int_index(). The bijection itself does not know about grouping: anything with a .get_value() returning an array broadcastable to shape works.
- Parameters:
shape – Event shape; the layout the bijection acts on.
scale (Variable | None) – Optional pre-built parameter overriding scale_init.
shift (Variable | None) – Optional pre-built parameter; if None and shift_init is None, no shift term is added.
phase (Variable | None) – Optional pre-built parameter; if None and phase_init is None, no phase term is added.
scale_init – Initializer for the unconstrained log-scale.
shift_init – Initializer returning shape (2, *shape) (real/imag stacked); None disables the shift.
phase_init – Initializer for the phase angle; None disables the phase.
complex_mask – Optional 0/1 array broadcastable to shape marking which entries carry an imaginary DoF. When provided, the phase and the imaginary shift component are masked on real entries and the log-Jacobian weight becomes 1 + mask.
rngs – nnx random number generators (required when any term needs initialization).
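The weights \(w_i\) in the log-Jacobian can be checked by hand. The following is a minimal NumPy sketch, not bijx code: complex_scale and log_jacobian are hypothetical helper names introduced here for illustration. It treats one complex entry as a map on \(\mathbb{R}^2\) and confirms that the real Jacobian determinant of \(x \mapsto e^{s} e^{i\varphi} x\) is \(e^{2s}\), which is where the weight 2 comes from; a real entry only contributes \(s\), giving 1 + mask.

```python
import numpy as np

def complex_scale(x, s, phi=0.0, b=0.0):
    """Elementwise y = exp(s) * exp(i*phi) * x + b (illustrative helper)."""
    return np.exp(s) * np.exp(1j * phi) * x + b

def log_jacobian(s, complex_mask=None):
    """sum_i w_i * s_i with w_i = 2 (fully complex) or 1 + mask_i."""
    s = np.asarray(s, dtype=float)
    w = 2.0 if complex_mask is None else 1.0 + np.asarray(complex_mask)
    return float(np.sum(w * s))

# One complex entry viewed as a real map (Re x, Im x) -> (Re y, Im y).
s, phi = 0.3, 0.7
a = np.exp(s) * np.exp(1j * phi)
J = np.array([[a.real, -a.imag],
              [a.imag,  a.real]])
log_det = np.log(abs(np.linalg.det(J)))  # |a|^2 = exp(2 s), so this is 2 * s

# Mixed layout: entries 0 and 2 are complex, entry 1 is real.
mixed = log_jacobian([0.3, -0.1, 0.5], complex_mask=[1, 0, 1])
```

The phase and shift do not appear in the determinant: a rotation has unit determinant and a translation has none, so only the log-scale \(s\) enters.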
- __init__(shape, *, scale=None, shift=None, phase=None, scale_init=<function normal.<locals>.init>, shift_init=None, phase_init=None, complex_mask=None, rngs=None)
- Parameters:
scale (Variable | None)
shift (Variable | None)
phase (Variable | None)
Methods
apply(x, log_density[, reverse]) – Apply the bijection in either direction.
forward(x, log_density, **kwargs) – Apply forward transformation.
invert() – Create an inverted version of this bijection.
reverse(x, log_density, **kwargs) – Apply reverse (inverse) transformation.
- forward(x, log_density, **kwargs)
Apply forward transformation.
Transforms input through the bijection and updates log-density according to the change of variables formula.
For convenience, Bijection() gives the default identity bijection.
- Parameters:
x – Input data of any pytree structure.
log_density – Log density values corresponding to the input.
**kwargs – Additional transformation-specific arguments.
- Returns:
Tuple of (transformed_data, updated_log_density) where the log-density incorporates the log absolute determinant of the transformation Jacobian.
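The log-density update can be illustrated with a case where the pushed-forward density is known in closed form. The sketch below is plain NumPy, not the bijx implementation; it assumes the forward pass subtracts \(\log|\det J|\) from the incoming log-density (the usual change-of-variables convention), and forward and log_complex_normal are hypothetical helpers. Pushing a unit circular complex normal density through \(y = e^{s} e^{i\varphi} x + b\) should give a complex normal with mean \(b\) and variance \(e^{2s}\).

```python
import numpy as np

def forward(x, log_density, s, phi, b):
    """Sketch of a forward pass: transform x and subtract log|det J|."""
    y = np.exp(s) * np.exp(1j * phi) * x + b
    # Fully complex layout: log|det J| = sum_i 2 * s_i over the event axis.
    log_det = np.sum(2.0 * s * np.ones(x.shape), axis=-1)
    return y, log_density - log_det

def log_complex_normal(z, mean=0.0, var=1.0):
    """Log density of independent circular complex normals, summed over the last axis."""
    return np.sum(-np.log(np.pi * var) - np.abs(z - mean) ** 2 / var, axis=-1)

rng = np.random.default_rng(0)
# Arbitrary test points: batch of 4, event shape (3,).
x = rng.normal(size=(4, 3)) + 1j * rng.normal(size=(4, 3))
s = np.array([0.2, -0.5, 0.1])
phi = np.array([0.3, 1.0, -0.7])
b = np.array([1.0 + 2.0j, 0.0, -1.0j])

y, log_q = forward(x, log_complex_normal(x), s, phi, b)
# Analytic density of the transformed variable, evaluated at y.
expected = log_complex_normal(y, mean=b, var=np.exp(2.0 * s))
```

Here log_q and expected agree up to floating-point error, one value per batch element.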
- reverse(x, log_density, **kwargs)
Apply reverse (inverse) transformation.
Transforms input through the inverse bijection and updates log-density accordingly.
- Parameters:
x – Input data of any pytree structure.
log_density – Log density values corresponding to the input.
**kwargs – Additional transformation-specific arguments.
- Returns:
Tuple of (inverse_transformed_data, updated_log_density) where the log-density change has the opposite sign compared to forward().
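A round trip makes the opposite-sign relationship concrete. The following NumPy sketch is not bijx code: _terms, forward, and reverse are hypothetical functions, and the way the mask is applied (zeroing the phase and the imaginary shift on real entries, with weight 1 + mask) is one reading of the complex_mask description rather than the confirmed implementation.

```python
import numpy as np

def _terms(s, phi, b, mask):
    """Mask phase and imaginary shift on real entries; weights are 1 + mask."""
    scale = np.exp(s) * np.exp(1j * mask * phi)
    shift = b.real + 1j * mask * b.imag
    log_det = np.sum((1.0 + mask) * s)
    return scale, shift, log_det

def forward(x, log_density, s, phi, b, mask):
    scale, shift, log_det = _terms(s, phi, b, mask)
    return scale * x + shift, log_density - log_det

def reverse(y, log_density, s, phi, b, mask):
    scale, shift, log_det = _terms(s, phi, b, mask)
    return (y - shift) / scale, log_density + log_det

mask = np.array([1.0, 0.0, 1.0])                  # middle entry is real-valued
s = np.array([0.2, -0.5, 0.1])
phi = np.array([0.3, 1.0, -0.7])
b = np.array([1.0 + 2.0j, 0.5 + 9.0j, -1.0j])
x = np.array([0.4 - 0.1j, 2.0 + 0.0j, -1.5 + 0.3j])

y, ld_fwd = forward(x, 0.0, s, phi, b, mask)
x_back, ld_back = reverse(y, ld_fwd, s, phi, b, mask)
```

In this sketch x_back recovers x, ld_back returns to the starting value 0.0, and the real entry stays real after the forward pass because its phase and imaginary shift are masked out.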