bijx.ComplexScaling

class bijx.ComplexScaling

Bases: Bijection

Diagonal complex affine bijection.

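Transform: \(y = e^{s} \cdot e^{i\varphi} \cdot x + b\), with an optional shift \(b\), an optional phase \(\varphi\), and an optional complex_mask selecting which entries carry an imaginary degree of freedom. The phase rotation and the shift both have unit Jacobian determinant, so only the log-scale \(s\) contributes to the log-Jacobian: \(\sum_i w_i s_i\), with \(w_i = 2\) in the fully complex layout (each entry has a real and an imaginary degree of freedom, both scaled by \(e^{s_i}\)) or \(w_i = 1 + \text{mask}_i\) when a mask is given.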
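The log-Jacobian weight of 2 per complex entry can be checked by brute force. The sketch below is an illustrative NumPy re-implementation, not the bijx source: complex_scaling_forward and real_jacobian_logdet are hypothetical helpers. It treats each complex entry as two real coordinates, builds the real Jacobian of the linear part column by column, and compares its log-determinant with \(\sum_i w_i s_i\).

```python
import numpy as np

def complex_scaling_forward(x, s, phi=0.0, b=0.0, mask=None):
    """Hypothetical sketch: y = exp(s) * exp(i*phi) * x + b and sum_i w_i s_i."""
    y = np.exp(s) * np.exp(1j * phi) * x + b
    # Fully complex: two real DoFs (Re, Im) per entry, both scaled by e^s.
    # With a 0/1 mask: real-only entries have one DoF, so w = 1 + mask.
    w = 2.0 if mask is None else 1.0 + mask
    log_jac = np.sum(w * np.broadcast_to(s, np.shape(x)))
    return y, log_jac

def real_jacobian_logdet(s, phi):
    """Build the 2n x 2n real Jacobian of x -> e^s e^{i phi} x and return log|det|."""
    n = len(s)
    J = np.zeros((2 * n, 2 * n))
    for k in range(2 * n):
        e = np.zeros(2 * n)
        e[k] = 1.0
        x = e[:n] + 1j * e[n:]               # real parts first, then imaginary parts
        y = np.exp(s) * np.exp(1j * phi) * x
        J[:, k] = np.concatenate([y.real, y.imag])
    return np.linalg.slogdet(J)[1]

rng = np.random.default_rng(0)
s = rng.normal(size=3)
phi = rng.uniform(-np.pi, np.pi, size=3)
x = rng.normal(size=3) + 1j * rng.normal(size=3)
y, log_jac = complex_scaling_forward(x, s, phi, b=0.5 - 0.25j)
print(np.isclose(log_jac, real_jacobian_logdet(s, phi)))  # True
```

Each 2x2 block of the real Jacobian is \(e^{s_i}\) times a rotation, whose determinant is \(e^{2 s_i}\); the shift is a translation and drops out.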

For the shift and phase terms, shift_init=None and phase_init=None (the defaults) disable the term. Pass an initializer (e.g. nnx.initializers.zeros) to enable it, or pass a pre-built nnx.Variable / GroupedParam via the corresponding shift= / phase= / scale= argument to bypass construction entirely.
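The precedence rule for each term (pre-built parameter, then initializer, then disabled) can be sketched as follows. resolve_term and zeros are hypothetical stand-ins written for illustration; zeros only mimics the (key, shape) call signature of JAX initializers and is not nnx.initializers.zeros.

```python
import numpy as np

def resolve_term(prebuilt, init, shape, key=None):
    """Hypothetical helper mirroring the documented rule for one term."""
    if prebuilt is not None:
        return prebuilt          # a pre-built parameter wins and is used as-is
    if init is None:
        return None              # no parameter and no initializer: term disabled
    return init(key, shape)      # otherwise build a fresh array

def zeros(key, shape):
    # Stand-in with the (key, shape) signature used by JAX-style initializers.
    return np.zeros(shape)

shape = (4,)
print(resolve_term(None, None, shape))             # None: disabled (the shift/phase default)
print(resolve_term(None, zeros, shape))            # [0. 0. 0. 0.]: enabled by an initializer
print(resolve_term(np.ones(shape), zeros, shape))  # [1. 1. 1. 1.]: pre-built overrides init
```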

Index sharing across the event (parameter tying within label groups) is supported by passing a GroupedParam for any of the terms; see GroupedParam.from_int_index(). The bijection itself does not know about grouping: any object whose .get_value() returns an array broadcastable to shape works.
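Because only .get_value() is consulted, a duck-typed object is enough to tie parameters. TiedParam below is a hypothetical stand-in, not GroupedParam: it stores one free value per label group and gathers them into the event shape using an integer label array.

```python
import numpy as np

class TiedParam:
    """Hypothetical stand-in for GroupedParam: one free value per label group."""

    def __init__(self, group_values, index):
        self.group_values = np.asarray(group_values)  # shape (n_groups,)
        self.index = np.asarray(index)                # integer labels, event-shaped

    def get_value(self):
        # Gather: every event entry reads the value of its group.
        return self.group_values[self.index]

labels = np.array([[0, 1], [1, 0]])        # 2x2 event, two groups
log_scale = TiedParam([0.1, -0.3], labels)
print(log_scale.get_value())
# [[ 0.1 -0.3]
#  [-0.3  0.1]]
```

Entries sharing a label share a single trainable value, while the log-Jacobian still sums \(w_i s_i\) over every event entry.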

Parameters:
  • shape – Event shape; the layout the bijection acts on.

  • scale (Variable | None) – Optional pre-built parameter overriding scale_init.

  • shift (Variable | None) – Optional pre-built parameter; if None and shift_init is None, no shift term is added.

  • phase (Variable | None) – Optional pre-built parameter; if None and phase_init is None, no phase term is added.

  • scale_init – Initializer for the unconstrained log-scale.

  • shift_init – Initializer returning an array of shape (2, *shape), with the real and imaginary parts stacked along the first axis; None disables the shift.

  • phase_init – Initializer for the phase angle; None disables the phase.

  • complex_mask – Optional 0/1 array broadcastable to shape marking which entries carry an imaginary DoF. When provided, the phase and the imaginary shift component are masked on real entries, and the log-Jacobian weight becomes 1 + mask.

  • rngs – nnx random number generators (required when any term needs initialization).
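The stacked shift layout and the effect of complex_mask can be sketched as below. effective_terms is a hypothetical helper, assuming index 0 of the stacked shift holds the real part and index 1 the imaginary part, as the shift_init description suggests; it shows how masked entries keep neither a phase nor an imaginary shift.

```python
import numpy as np

def effective_terms(shift_stacked, phase, mask=None):
    """Hypothetical sketch: turn raw parameters into the complex shift b and the phase."""
    re, im = shift_stacked[0], shift_stacked[1]  # (2, *shape): real, imaginary
    if mask is not None:
        im = im * mask        # no imaginary shift on real-only entries
        phase = phase * mask  # no rotation on real-only entries
    return re + 1j * im, phase

shape = (3,)
mask = np.array([1.0, 0.0, 1.0])        # middle entry is real-only
shift = np.array([[0.5, 0.5, 0.5],      # real parts
                  [2.0, 2.0, 2.0]])     # imaginary parts
b, phi = effective_terms(shift, np.full(shape, 0.3), mask)
print(b)    # middle entry has zero imaginary part
print(phi)  # middle entry has zero phase
```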

__init__(shape, *, scale=None, shift=None, phase=None, scale_init=<function normal.<locals>.init>, shift_init=None, phase_init=None, complex_mask=None, rngs=None)
Parameters:
  • scale (Variable | None)

  • shift (Variable | None)

  • phase (Variable | None)

Methods

apply(x, log_density[, reverse])

Apply the bijection in either direction.

forward(x, log_density, **kwargs)

Apply forward transformation.

invert()

Create an inverted version of this bijection.

reverse(x, log_density, **kwargs)

Apply reverse (inverse) transformation.
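The relationship between the four methods can be illustrated with a toy real scaling \(y = e^{s} x\). ToyScaling is hypothetical and only mirrors the documented roles (apply dispatches on direction, invert swaps forward and reverse); it is not the bijx Bijection base class, and the sign of the log-density update is an assumption.

```python
import numpy as np

class ToyScaling:
    """Hypothetical toy bijection y = exp(s) * x illustrating the method layout."""

    def __init__(self, s):
        self.s = s

    def forward(self, x, log_density):
        # Assumed convention: log q(y) = log p(x) - log|det J|, with log|det J| = sum(s).
        return np.exp(self.s) * x, log_density - np.sum(self.s)

    def reverse(self, x, log_density):
        return np.exp(-self.s) * x, log_density + np.sum(self.s)

    def apply(self, x, log_density, reverse=False):
        # Single entry point that dispatches on direction.
        fn = self.reverse if reverse else self.forward
        return fn(x, log_density)

    def invert(self):
        # Forward of the inverted bijection equals reverse of this one.
        return ToyScaling(-self.s)

s = np.array([0.2, -0.1])
bij = ToyScaling(s)
x, ld = np.array([1.0, 2.0]), 0.0
y, ld_y = bij.apply(x, ld)                     # same as bij.forward(x, ld)
x_back, ld_back = bij.invert().apply(y, ld_y)  # undo via invert()
print(np.allclose(x_back, x), np.isclose(ld_back, ld))  # True True
```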

forward(x, log_density, **kwargs)

Apply forward transformation.

Transforms the input through the bijection and updates the log-density according to the change-of-variables formula.

For convenience, a bare Bijection() (the base class, with no arguments) is the identity bijection.

Parameters:
  • x – Input data of any pytree structure.

  • log_density – Log density values corresponding to the input.

  • **kwargs – Additional transformation-specific arguments.

Returns:

Tuple of (transformed_data, updated_log_density) where the log-density incorporates the log absolute determinant of the transformation Jacobian.
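For ComplexScaling the log-Jacobian does not depend on x, so every sample in a batch receives the same log-density correction. The sketch below is a hypothetical free function, not the bijx method; it assumes x has shape (batch, *event), log_density has shape (batch,), the fully complex layout (weight 2), and the standard convention of subtracting log|det J| (the text above only says the log-density incorporates it).

```python
import numpy as np

def forward(x, log_density, s, phi, b):
    """Hypothetical sketch of a batched ComplexScaling forward pass."""
    y = np.exp(s) * np.exp(1j * phi) * x + b
    # Diagonal, x-independent map: one scalar log|det J| shared by all samples.
    log_jac = 2.0 * np.sum(s)
    return y, log_density - log_jac   # assumed sign convention

rng = np.random.default_rng(1)
event = (3,)
s = 0.1 * rng.normal(size=event)
phi = rng.uniform(-np.pi, np.pi, size=event)
b = np.zeros(event, dtype=complex)
x = rng.normal(size=(5, *event)) + 1j * rng.normal(size=(5, *event))
ld = np.zeros(5)
y, ld_y = forward(x, ld, s, phi, b)
print(y.shape, ld_y.shape)  # (5, 3) (5,)
```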

reverse(x, log_density, **kwargs)

Apply reverse (inverse) transformation.

Transforms the input through the inverse bijection and updates the log-density accordingly.

Parameters:
  • x – Input data of any pytree structure.

  • log_density – Log density values corresponding to the input.

  • **kwargs – Additional transformation-specific arguments.

Returns:

Tuple of (inverse_transformed_data, updated_log_density) where the log-density change has the opposite sign compared to forward().
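The inverse of the transform is \(x = e^{-s} e^{-i\varphi} (y - b)\), and its log-density update has the opposite sign, so a forward pass followed by a reverse pass should restore both the data and the log-density. The paired free functions below are a hypothetical sketch (fully complex layout, assumed sign convention), not the bijx implementation.

```python
import numpy as np

def forward(x, log_density, s, phi, b):
    y = np.exp(s) * np.exp(1j * phi) * x + b
    return y, log_density - 2.0 * np.sum(s)   # assumed sign convention

def reverse(y, log_density, s, phi, b):
    # Undo the shift first, then the rotation and scaling.
    x = np.exp(-s) * np.exp(-1j * phi) * (y - b)
    return x, log_density + 2.0 * np.sum(s)   # opposite sign to forward()

rng = np.random.default_rng(2)
s = rng.normal(size=4)
phi = rng.uniform(-np.pi, np.pi, size=4)
b = rng.normal(size=4) + 1j * rng.normal(size=4)
x = rng.normal(size=(8, 4)) + 1j * rng.normal(size=(8, 4))
ld = rng.normal(size=8)

y, ld_y = forward(x, ld, s, phi, b)
x_back, ld_back = reverse(y, ld_y, s, phi, b)
print(np.allclose(x_back, x), np.allclose(ld_back, ld))  # True True
```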