bijx.MonotoneRQSpline

class bijx.MonotoneRQSpline

Bases: ApplyBijection

Monotonic rational quadratic spline bijection.

Implements element-wise rational quadratic spline transformations whose monotonicity is guaranteed by normalizing the parameters. Each element of an event is transformed independently. All elements use the same spline form, but each has its own set of width, height, and slope parameters.

Type: [0, 1]^n → [0, 1]^n (with linear extension outside)

The spline divides the [0,1] interval into bins and constructs smooth rational quadratic curves between knot points. Parameters are automatically normalized to ensure monotonicity and numerical stability.
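bijx states only that the parameters are normalized and does not say how. The sketch below uses a common scheme from Durkan et al. (2019), "Neural Spline Flows", and is an assumption rather than bijx's actual code:

- Widths and heights go through a softmax, then an affine rescaling so every bin is at least the minimum size.
- Internal slopes go through softplus, then `min_slope` is added.
- Boundary slopes are fixed to 1.

`normalize_spline_params` is a hypothetical helper. Its input sizes (K, K, K - 1) are chosen to match the `[knots, knots, knots-1]` split documented under `param_splits`.

```python
import jax
import jax.numpy as jnp


def normalize_spline_params(raw_widths, raw_heights, raw_slopes,
                            min_bin_width=1e-3, min_bin_height=1e-3,
                            min_slope=1e-3):
    """Hypothetical helper: unconstrained parameters -> knots and derivatives.

    raw_widths, raw_heights: shape (..., K); raw_slopes: shape (..., K - 1).
    Returns xs, ys, d, each of shape (..., K + 1), with xs and ys running
    from 0 to 1. Assumed scheme (Durkan et al., 2019), not bijx's code.
    """
    num_bins = raw_widths.shape[-1]

    def to_knots(raw, min_size):
        # Softmax yields positive sizes summing to 1; the affine rescaling keeps
        # the sum at 1 while flooring every bin at min_size.
        sizes = jax.nn.softmax(raw, axis=-1)
        sizes = min_size + (1.0 - min_size * num_bins) * sizes
        zero = jnp.zeros(sizes.shape[:-1] + (1,))
        knots = jnp.concatenate([zero, jnp.cumsum(sizes, axis=-1)], axis=-1)
        return knots.at[..., -1].set(1.0)  # pin the right edge despite round-off

    xs = to_knots(raw_widths, min_bin_width)
    ys = to_knots(raw_heights, min_bin_height)
    # Internal derivatives: softplus keeps them positive, min_slope floors them.
    inner = min_slope + jax.nn.softplus(raw_slopes)
    ones = jnp.ones(raw_slopes.shape[:-1] + (1,))
    d = jnp.concatenate([ones, inner, ones], axis=-1)  # boundary slopes = 1
    return xs, ys, d
```

The floors mean no bin can collapse to zero width or height, and no interior slope can reach zero. This is what keeps the inverse and the log-determinant finite.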

Parameters:
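  • knots – Number of spline bins, each with a trainable width and height. The knots-1 internal knots between the bins each get a trainable slope (see param_splits).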

  • event_shape – Shape of individual events being transformed.

  • min_bin_width – Minimum bin width for numerical stability.

  • min_bin_height – Minimum bin height for numerical stability.

  • min_slope – Minimum internal knot slope for numerical stability.

  • widths_init – Initializer for bin width parameters.

  • heights_init – Initializer for bin height parameters.

  • slopes_init – Initializer for internal slope parameters.

  • rngs (Rngs) – Random number generators for parameter initialization.

Example

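>>> import jax.numpy as jnp
>>> from flax import nnx
>>> import bijx
>>> rngs = nnx.Rngs(0)
>>> spline = bijx.MonotoneRQSpline(
...     knots=8, event_shape=(3,), rngs=rngs
... )
>>> x = jnp.array([[0.2, 0.5, 0.8]])
>>> y, log_density = spline.forward(x, jnp.zeros(1))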

__init__(knots, event_shape=(), *, min_bin_width=0.001, min_bin_height=0.001, min_slope=0.001, widths_init=<function normal.<locals>.init>, heights_init=<function normal.<locals>.init>, slopes_init=<function normal.<locals>.init>, rngs)

Initialize monotonic rational quadratic spline.

Creates trainable parameters for bin widths, heights, and internal knot slopes. Boundary slopes are fixed to 1.0 for linear tails.

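Parameters:
  • rngs (Rngs) – Random number generators for parameter initialization. The remaining arguments are documented in the class-level parameter list.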

Methods

  • apply(x, log_density, reverse, **kwargs) – Apply rational quadratic spline transformation.

  • forward(x, log_density, **kwargs) – Apply forward transformation.

  • invert() – Create an inverted version of this bijection.

  • reverse(x, log_density, **kwargs) – Apply reverse (inverse) transformation.

Attributes

  • param_count – Total number of parameters: widths + heights + internal slopes.

  • param_splits – Parameter split sizes for widths, heights, and internal slopes.

property param_count

Total number of parameters: widths + heights + internal slopes, i.e. knots + knots + (knots-1) = 3*knots - 1, the sum of param_splits.

property param_splits

Parameter split sizes for widths, heights, and internal slopes.

Returns:

List of sizes [knots, knots, knots-1] for splitting a flattened parameter vector into width, height, and slope components.
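The split can be sketched with `jnp.split`. This assumes the flattened vector is ordered widths, then heights, then slopes, as listed above. `split_spline_params` is a hypothetical helper, not part of bijx. Leading batch or event axes pass through unchanged.

```python
from itertools import accumulate

import jax.numpy as jnp


def split_spline_params(flat, knots):
    """Hypothetical helper: split (..., 3*knots - 1) into widths, heights, slopes.

    Sizes follow the documented param_splits: [knots, knots, knots - 1].
    """
    splits = [knots, knots, knots - 1]
    # The last axis must hold exactly param_count = 3 * knots - 1 entries.
    assert flat.shape[-1] == sum(splits)
    # jnp.split expects cut *indices*: cumulative sizes without the final total.
    cut_points = list(accumulate(splits))[:-1]
    widths, heights, slopes = jnp.split(flat, cut_points, axis=-1)
    return widths, heights, slopes


widths, heights, slopes = split_spline_params(jnp.arange(23.0), knots=8)
```

For knots=8 the vector has 23 entries. The three pieces have shapes (8,), (8,) and (7,).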

apply(x, log_density, reverse, **kwargs)

Apply rational quadratic spline transformation.

Transforms input through the spline bijection and updates log density with the log absolute Jacobian determinant.

Parameters:
  • x – Input array to transform.

  • log_density – Current log density values.

  • reverse – Whether to apply inverse transformation.

  • **kwargs – Additional arguments (unused).

Returns:

Tuple of (transformed_x, updated_log_density).
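Below is a minimal scalar sketch of the per-element computation behind `apply`. It assumes:

- The standard rational quadratic formulas of Durkan et al. (2019), with the inverse solving a per-bin quadratic.
- Identity tails outside [0, 1]. This follows from boundary slopes fixed to 1 and the endpoints mapping 0 to 0 and 1 to 1.

`rq_forward` and `rq_inverse` are hypothetical helpers, not bijx functions. This section does not say whether bijx adds or subtracts the log-determinant from `log_density`, so the sketch returns log|dy/dx| only.

```python
import jax
import jax.numpy as jnp


def _bin_index(knots, value):
    # Bin k with knots[k] <= value < knots[k + 1]; value == 1.0 maps to the last bin.
    k = jnp.searchsorted(knots, value, side="right") - 1
    return jnp.clip(k, 0, knots.shape[0] - 2)


def rq_forward(x, xs, ys, d):
    """Scalar y = f(x) and log|dy/dx| for one monotone RQ spline (hypothetical).

    xs, ys: knot coordinates of shape (K + 1,), increasing from 0 to 1.
    d: derivatives at all K + 1 knots, with d[0] = d[-1] = 1.
    """
    inside = (x >= 0.0) & (x <= 1.0)
    xc = jnp.clip(x, 0.0, 1.0)  # keeps the spline branch finite for tail inputs
    k = _bin_index(xs, xc)
    w, h = xs[k + 1] - xs[k], ys[k + 1] - ys[k]
    s = h / w  # average slope of bin k
    xi = (xc - xs[k]) / w  # position within the bin, in [0, 1]
    t = xi * (1.0 - xi)
    denom = s + (d[k + 1] + d[k] - 2.0 * s) * t
    y = ys[k] + h * (s * xi**2 + d[k] * t) / denom
    dnum = s**2 * (d[k + 1] * xi**2 + 2.0 * s * t + d[k] * (1.0 - xi) ** 2)
    logdet = jnp.log(dnum) - 2.0 * jnp.log(denom)
    # Slope-1 linear tails through (0, 0) and (1, 1) are the identity.
    return jnp.where(inside, y, x), jnp.where(inside, logdet, 0.0)


def rq_inverse(y, xs, ys, d):
    """Scalar x = f^{-1}(y), solving the per-bin quadratic for xi (hypothetical)."""
    inside = (y >= 0.0) & (y <= 1.0)
    yc = jnp.clip(y, 0.0, 1.0)
    k = _bin_index(ys, yc)
    w, h = xs[k + 1] - xs[k], ys[k + 1] - ys[k]
    s = h / w
    dy = yc - ys[k]
    c2 = d[k + 1] + d[k] - 2.0 * s
    a = h * (s - d[k]) + dy * c2
    b = h * d[k] - dy * c2
    c = -s * dy
    disc = jnp.maximum(b**2 - 4.0 * a * c, 0.0)  # guard tiny negative round-off
    xi = (2.0 * c) / (-b - jnp.sqrt(disc))  # numerically stable root
    return jnp.where(inside, xs[k] + xi * w, y)


# Example spline with 4 bins; boundary derivatives fixed to 1.
xs = jnp.array([0.0, 0.1, 0.5, 0.7, 1.0])
ys = jnp.array([0.0, 0.3, 0.4, 0.8, 1.0])
d = jnp.array([1.0, 0.5, 2.0, 1.5, 1.0])

x = jnp.array([-0.3, 0.2, 0.5, 0.8, 1.4])
shared = (0, None, None, None)  # vmap over x, share the knots
y, logdet = jax.vmap(rq_forward, in_axes=shared)(x, xs, ys, d)
x_back = jax.vmap(rq_inverse, in_axes=shared)(y, xs, ys, d)
```

To give each element of an event its own parameters, as the class description states, vmap with `in_axes=(0, 0, 0, 0)` over stacked per-element `xs`, `ys` and `d` instead of sharing one set of knots.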