bijx.rational_quadratic_spline

bijx.rational_quadratic_spline(inputs, bin_widths, bin_heights, knot_slopes, *, inverse=False, min_bin_width=0.001, min_bin_height=0.001, min_slope=0.001)

Apply a monotonic rational quadratic spline transformation.

Implements the rational quadratic spline bijection from Durkan et al. (arXiv:1906.04032). The transformation is a continuously differentiable, monotonic function built piecewise from rational quadratic segments joined at knot points.

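Type: [0, 1] → [0, 1] (with identity extension outside the domain)

Within bin \(k\), which maps \([x_k, x_k + w_k]\) onto \([y_k, y_k + h_k]\), the transformation is

\[y = y_k + \frac{h_k \left[ s_k x'^2 + \delta_k x'(1 - x') \right]}{s_k + \left( \delta_{k+1} + \delta_k - 2 s_k \right) x'(1 - x')}\]

where \(x' = (x - x_k)/w_k\) is the normalized position within bin \(k\), \(s_k = h_k/w_k\) is the bin slope, and \(\delta_k\) and \(\delta_{k+1}\) are the knot slopes at the left and right edges of the bin.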
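A single bin of the spline can be evaluated directly in plain Python. The sketch below writes out the per-bin rational quadratic map from Durkan et al. using this section's symbols (\(x'\) as `xi`, \(s_k\) as `s_k`), with `d_k` and `d_k1` standing for the knot slopes at the bin's left and right edges. The helper name `rq_bin_forward` is hypothetical; this is an illustrative scalar version, not bijx's vectorized implementation.

```python
def rq_bin_forward(x, x_k, w_k, y_k, h_k, d_k, d_k1):
    """Forward map of one rational quadratic bin (hypothetical helper).

    The bin spans [x_k, x_k + w_k] -> [y_k, y_k + h_k]; d_k and d_k1 are the
    (positive) knot slopes at its left and right edges.
    """
    xi = (x - x_k) / w_k  # normalized position x' in [0, 1]
    s_k = h_k / w_k       # bin slope
    numerator = h_k * (s_k * xi**2 + d_k * xi * (1 - xi))
    denominator = s_k + (d_k1 + d_k - 2 * s_k) * xi * (1 - xi)
    return y_k + numerator / denominator


# The left edge of the bin maps onto the left knot y_k = 0.1.
print(rq_bin_forward(0.25, 0.25, 0.5, 0.1, 0.6, 0.7, 2.0))
# An interior point lands strictly between y_k and y_k + h_k.
print(rq_bin_forward(0.50, 0.25, 0.5, 0.1, 0.6, 0.7, 2.0))
```

When both knot slopes equal the bin slope, the curvature term in the denominator vanishes and the segment reduces to linear interpolation.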

Key features:
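  • Monotonic by construction: softmax normalization keeps bin widths and heights positive, and softplus keeps knot slopes positive

  • Identity transformation outside the [0, 1] domain

  • Numerically stable: bin widths, bin heights, and knot slopes are bounded below by min_bin_width, min_bin_height, and min_slope

  • Closed-form inverse: within each bin, the inverse solves a quadratic equation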
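The closed-form inverse from the feature list can be sketched for one bin. Rearranging the per-bin rational quadratic map gives a quadratic \(a x'^2 + b x' + c = 0\) in the normalized position; the sketch takes the root in \([0, 1]\) in the rationalized form \(x' = 2c / (-b - \sqrt{b^2 - 4ac})\), which never divides by \(a\) (and \(a\) can be zero, for example when both knot slopes equal the bin slope). The helper name `rq_bin_inverse` and its arguments (`d_k`, `d_k1` for the knot slopes at the bin edges) are hypothetical, and the bin containing `y` is assumed to be located already.

```python
import math


def rq_bin_inverse(y, x_k, w_k, y_k, h_k, d_k, d_k1):
    """Inverse map of one rational quadratic bin (hypothetical helper).

    Solves a*xi**2 + b*xi + c = 0 for the normalized position xi in [0, 1],
    then maps xi back to x = x_k + xi * w_k.
    """
    s_k = h_k / w_k                   # bin slope
    dy = y - y_k                      # offset of y within the bin
    slope_gap = d_k1 + d_k - 2 * s_k  # same combination as in the forward denominator
    a = h_k * (s_k - d_k) + dy * slope_gap
    b = h_k * d_k - dy * slope_gap
    c = -s_k * dy
    discriminant = max(b * b - 4 * a * c, 0.0)  # clamp tiny negative rounding error
    # Rationalized quadratic formula: no division by a, which may vanish.
    xi = 2 * c / (-b - math.sqrt(discriminant))
    return x_k + xi * w_k


# With both knot slopes equal to the bin slope (0.6 / 0.5 = 1.2) the bin is
# linear, so the inverse is plain linear interpolation (result close to 0.5).
print(rq_bin_inverse(0.4, 0.25, 0.5, 0.1, 0.6, 1.2, 1.2))
```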

Parameters:
  • inputs – Input values to transform.

  • bin_widths – Unnormalized bin widths (softmax applied internally).

  • bin_heights – Unnormalized bin heights (softmax applied internally).

  • knot_slopes – Internal knot slopes (softplus applied for positivity).

  • inverse – Whether to apply the inverse transformation.

  • min_bin_width – Minimum bin width for numerical stability.

  • min_bin_height – Minimum bin height for numerical stability.

  • min_slope – Minimum knot slope for numerical stability.
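The parameter normalization described above can be sketched in plain Python. How the minimum constraints enter here (each bin reserves `min_bin_width` or `min_bin_height` of the unit interval and softmax shares out the remainder; slopes are `min_slope + softplus(...)`) follows a common convention, for example in the nflows reference code, and is an assumption about bijx, as are the helper names `softmax`, `softplus`, and `normalize_params`. Cumulative sums of the widths and heights give knot coordinates running from 0 to 1, and the two boundary slopes are padded with 1.0.

```python
import math


def softmax(values):
    """Numerically stable softmax over a list of floats."""
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def softplus(v):
    """Numerically stable log(1 + exp(v))."""
    return max(v, 0.0) + math.log1p(math.exp(-abs(v)))


def normalize_params(bin_widths, bin_heights, knot_slopes,
                     min_bin_width=1e-3, min_bin_height=1e-3, min_slope=1e-3):
    """Turn unnormalized parameters into knots (hypothetical sketch).

    ASSUMPTION: nflows-style minimum constraints; bijx may differ in detail.
    Requires n * min_bin_width < 1 and n * min_bin_height < 1.
    """
    n = len(bin_widths)
    # Each bin reserves the minimum; softmax distributes the remaining mass.
    widths = [min_bin_width + (1 - min_bin_width * n) * p for p in softmax(bin_widths)]
    heights = [min_bin_height + (1 - min_bin_height * n) * p for p in softmax(bin_heights)]
    # Cumulative sums give knot coordinates running from 0 to (about) 1.
    xs, ys = [0.0], [0.0]
    for w, h in zip(widths, heights):
        xs.append(xs[-1] + w)
        ys.append(ys[-1] + h)
    # Internal slopes: softplus keeps them positive, min_slope bounds them below.
    # The two boundary slopes are fixed to 1.0.
    slopes = [1.0] + [min_slope + softplus(d) for d in knot_slopes] + [1.0]
    return xs, ys, slopes


xs, ys, slopes = normalize_params([1.0] * 4, [1.0] * 4, [1.0] * 3)
print(len(xs), len(slopes))  # 5 5: five knots and five slopes for four bins
```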

Returns:

Tuple (transformed_inputs, log_determinant), where log_determinant is the log absolute determinant of the transformation's Jacobian.
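Because the spline acts elementwise, its Jacobian is diagonal, and the log absolute determinant is the sum of per-element log-derivatives. The sketch below computes that per-element term for one bin using the closed-form derivative of the rational quadratic segment from Durkan et al.; `rq_bin_log_derivative` and its arguments (`d_k`, `d_k1` for the knot slopes at the bin edges) are hypothetical names, not bijx API.

```python
import math


def rq_bin_log_derivative(x, x_k, w_k, h_k, d_k, d_k1):
    """log|dy/dx| of one rational quadratic bin (hypothetical helper).

    Uses the closed-form derivative of the per-bin map; d_k and d_k1 are the
    knot slopes at the bin's left and right edges.
    """
    xi = (x - x_k) / w_k
    s_k = h_k / w_k
    numerator = s_k**2 * (d_k1 * xi**2 + 2 * s_k * xi * (1 - xi) + d_k * (1 - xi) ** 2)
    denominator = (s_k + (d_k1 + d_k - 2 * s_k) * xi * (1 - xi)) ** 2
    return math.log(numerator) - math.log(denominator)


# Three inputs that happen to fall in the same bin: summing the per-element
# terms gives the log|det J| of the elementwise map on this 3-vector.
terms = [rq_bin_log_derivative(x, 0.25, 0.5, 0.6, 0.7, 2.0) for x in (0.3, 0.5, 0.7)]
log_det = sum(terms)
```

At the bin edges the derivative reduces to the knot slopes themselves (`d_k` at \(x' = 0\), `d_k1` at \(x' = 1\)), which is what makes adjacent segments join with matching slopes.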

Note

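Boundary knot slopes are fixed to 1.0 so that the spline's derivative at 0 and 1 matches the unit slope of the identity extension outside [0, 1], making the transformation continuously differentiable across the boundaries. Only internal knot slopes are trainable.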
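The effect of fixing the boundary slopes to 1.0 can be checked numerically: at each knot the rational quadratic segment's derivative equals that knot's slope, so a boundary slope of 1.0 matches the unit derivative of the identity tails. The following is a hypothetical scalar forward pass (`rqs_forward_with_tails` is an illustrative name, not bijx API) that assumes already-normalized knots `xs`, `ys` running from 0.0 to 1.0 and a full slope list including both boundary entries; it returns the output together with its log-derivative.

```python
import bisect
import math


def rqs_forward_with_tails(x, xs, ys, slopes):
    """Scalar spline forward pass with identity tails (hypothetical sketch).

    xs, ys: increasing knot coordinates with xs[0] == ys[0] == 0.0 and
    xs[-1] == ys[-1] == 1.0; slopes: one positive slope per knot, with the
    boundary entries slopes[0] and slopes[-1] equal to 1.0.
    Returns (y, log|dy/dx|).
    """
    if x <= xs[0] or x >= xs[-1]:
        return x, 0.0  # identity outside [0, 1]: y = x and log-derivative 0
    # Locate the bin with xs[k] <= x < xs[k + 1].
    k = min(bisect.bisect_right(xs, x) - 1, len(xs) - 2)
    w, h = xs[k + 1] - xs[k], ys[k + 1] - ys[k]
    s, d0, d1 = h / w, slopes[k], slopes[k + 1]
    xi = (x - xs[k]) / w
    denominator = s + (d1 + d0 - 2 * s) * xi * (1 - xi)
    y = ys[k] + h * (s * xi**2 + d0 * xi * (1 - xi)) / denominator
    dydx = s**2 * (d1 * xi**2 + 2 * s * xi * (1 - xi) + d0 * (1 - xi) ** 2) / denominator**2
    return y, math.log(dydx)


xs, ys, slopes = [0.0, 0.3, 1.0], [0.0, 0.6, 1.0], [1.0, 1.5, 1.0]
# Just inside each boundary the log-derivative is close to 0 (slope close to 1),
# matching the identity tails just outside.
print(rqs_forward_with_tails(1e-6, xs, ys, slopes))
print(rqs_forward_with_tails(1 - 1e-6, xs, ys, slopes))
```

With a boundary slope other than 1.0, the output would still be continuous at 0 and 1, but the derivative would jump there, so the log-determinant would be discontinuous across the domain edges.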

Example

>>> import jax.numpy as jnp
>>> from bijx import rational_quadratic_spline
>>> x = jnp.array([0.2, 0.5, 0.8])
>>> widths = jnp.ones((3, 4))  # 4 bins
>>> heights = jnp.ones((3, 4))
>>> slopes = jnp.ones((3, 3))  # 3 internal knots
>>> y, log_det = rational_quadratic_spline(x, widths, heights, slopes)