bijx.rational_quadratic_spline
- bijx.rational_quadratic_spline(inputs, bin_widths, bin_heights, knot_slopes, *, inverse=False, min_bin_width=0.001, min_bin_height=0.001, min_slope=0.001)
Apply monotonic rational quadratic spline transformation.
Implements the rational quadratic spline bijection from Durkan et al. (arXiv:1906.04032). The transformation constructs a smooth, monotonic function using piecewise rational quadratic segments between knot points.
Type: [0, 1] → [0, 1] (with identity extension outside domain)
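Within bin \(k\), which spans \([x_k, x_k + w_k]\) and maps onto \([y_k, y_k + h_k]\), the segment formula of Durkan et al. reads
\[
y = y_k + \frac{h_k \left[ s_k (x')^2 + d_k \, x'(1 - x') \right]}{s_k + (d_{k+1} + d_k - 2 s_k) \, x'(1 - x')}
\]
where \(x' = (x - x_k)/w_k\) is the normalized position within bin \(k\), \(s_k = h_k/w_k\) is the bin slope, and \(d_k\), \(d_{k+1}\) are the knot slopes at the left and right edges of the bin.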
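To make the roles of the normalized position and the bin slope concrete, here is a minimal plain-Python sketch that evaluates a single segment with the formula from Durkan et al. The function name `rq_segment` and the arguments `d_k`, `d_k1` (knot slopes at the left and right edges of the bin) are hypothetical and not part of bijx, which operates on JAX arrays rather than Python scalars.

```python
def rq_segment(x, x_k, y_k, w_k, h_k, d_k, d_k1):
    """Evaluate one rational quadratic segment (hypothetical helper, not bijx API)."""
    xi = (x - x_k) / w_k   # normalized position x' in [0, 1]
    s_k = h_k / w_k        # bin slope
    num = h_k * (s_k * xi**2 + d_k * xi * (1.0 - xi))
    den = s_k + (d_k1 + d_k - 2.0 * s_k) * xi * (1.0 - xi)
    return y_k + num / den

# One bin covering [0.0, 0.5] in x, mapped onto [0.0, 0.25] in y,
# with knot slopes 1.0 (left) and 2.0 (right).
bin_args = (0.0, 0.0, 0.5, 0.25, 1.0, 2.0)
left = rq_segment(0.0, *bin_args)    # lands on the left knot, y_k
mid = rq_segment(0.25, *bin_args)    # strictly between the two knots
right = rq_segment(0.5, *bin_args)   # lands on the right knot, y_k + h_k
```

The segment interpolates the two knots exactly, and intermediate points stay between them, which is what makes the piecewise map monotonic.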
- Key features:
Monotonic by construction through parameter normalization
Identity transformation outside [0,1] domain
Numerically stable with minimum parameter constraints
Efficient inverse computation via quadratic formula
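The closed-form inverse mentioned above can be sketched as follows. This is a plain-Python illustration of the quadratic-formula inversion described by Durkan et al., not the bijx implementation; `rq_forward` and `rq_inverse` are hypothetical helper names, and the single-bin parameters are made up for the example.

```python
import math

def rq_forward(x, x_k, y_k, w_k, h_k, d_k, d_k1):
    """Forward map of one rational quadratic segment."""
    xi = (x - x_k) / w_k
    s_k = h_k / w_k
    t = xi * (1.0 - xi)
    return y_k + h_k * (s_k * xi**2 + d_k * t) / (s_k + (d_k1 + d_k - 2.0 * s_k) * t)

def rq_inverse(y, x_k, y_k, w_k, h_k, d_k, d_k1):
    """Invert the segment by solving a * xi^2 + b * xi + c = 0 for xi."""
    s_k = h_k / w_k
    dy = y - y_k
    curv = d_k1 + d_k - 2.0 * s_k
    a = h_k * (s_k - d_k) + dy * curv
    b = h_k * d_k - dy * curv
    c = -s_k * dy
    # Numerically stable form of the root that lies in [0, 1].
    xi = 2.0 * c / (-b - math.sqrt(b * b - 4.0 * a * c))
    return x_k + xi * w_k

bin_args = (0.0, 0.0, 0.5, 0.25, 1.0, 2.0)
xs = [0.0, 0.1, 0.25, 0.4, 0.5]
roundtrip = [rq_inverse(rq_forward(x, *bin_args), *bin_args) for x in xs]
max_err = max(abs(a - b) for a, b in zip(xs, roundtrip))
```

Because the inverse is a single quadratic solve per input, no iterative root finding is needed.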
- Parameters:
inputs – Input values to transform.
bin_widths – Unnormalized bin widths (softmax applied internally).
bin_heights – Unnormalized bin heights (softmax applied internally).
knot_slopes – Internal knot slopes (softplus applied for positivity).
inverse – Whether to apply inverse transformation.
min_bin_width – Minimum bin width for numerical stability.
min_bin_height – Minimum bin height for numerical stability.
min_slope – Minimum knot slope for numerical stability.
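The parameter list does not spell out how the minimum constraints are combined with the softmax and softplus. One common convention, assumed here for illustration and not confirmed for bijx, reserves `min_size` for every bin and distributes the remaining length by softmax, and adds `min_slope` to a softplus output. All function names below are hypothetical.

```python
import math

def softmax(raw):
    m = max(raw)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in raw]
    total = sum(exps)
    return [e / total for e in exps]

def softplus(v):
    # Stable evaluation of log(1 + exp(v)).
    return max(v, 0.0) + math.log1p(math.exp(-abs(v)))

def normalize_bins(raw, min_size=1e-3):
    """Assumed convention: each bin gets min_size plus a softmax share of the rest."""
    k = len(raw)
    return [min_size + (1.0 - min_size * k) * p for p in softmax(raw)]

def positive_slopes(raw, min_slope=1e-3):
    """Assumed convention: slopes are softplus outputs shifted by min_slope."""
    return [min_slope + softplus(v) for v in raw]

widths = normalize_bins([0.0, 1.0, -2.0, 0.5])   # 4 bins, sums to 1
slopes = positive_slopes([-5.0, 0.0, 3.0])       # 3 internal knots, all positive
```

Under this convention the widths (and, analogously, the heights) always sum to 1 and never fall below the minimum, so the knots partition [0, 1] without degenerate bins.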
- Returns:
Tuple of (transformed_inputs, log_determinant) where log_determinant gives the log absolute Jacobian determinant of the transformation.
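For a scalar input, the log-determinant is the log of the segment's derivative. The sketch below uses the derivative formula from Durkan et al. and compares it with a central finite difference. It is a plain-Python illustration with a hypothetical helper `rq_forward_with_logdet`, not the bijx implementation.

```python
import math

def rq_forward_with_logdet(x, x_k, y_k, w_k, h_k, d_k, d_k1):
    """Return (y, log |dy/dx|) for one rational quadratic segment."""
    xi = (x - x_k) / w_k
    s_k = h_k / w_k
    t = xi * (1.0 - xi)
    den = s_k + (d_k1 + d_k - 2.0 * s_k) * t
    y = y_k + h_k * (s_k * xi**2 + d_k * t) / den
    dydx = s_k**2 * (d_k1 * xi**2 + 2.0 * s_k * t + d_k * (1.0 - xi)**2) / den**2
    return y, math.log(dydx)

bin_args = (0.0, 0.0, 0.5, 0.25, 1.0, 2.0)
y, log_det = rq_forward_with_logdet(0.25, *bin_args)

# Central finite difference as an independent check of the analytic derivative.
eps = 1e-6
y_plus = rq_forward_with_logdet(0.25 + eps, *bin_args)[0]
y_minus = rq_forward_with_logdet(0.25 - eps, *bin_args)[0]
finite_diff = (y_plus - y_minus) / (2.0 * eps)
```

At the bin edges the derivative reduces to the knot slopes `d_k` and `d_k1`, which is why positive knot slopes keep the log-determinant finite.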
Note
Boundary knot slopes are fixed to 1.0 to ensure smooth linear tails outside the spline domain. Only internal knot slopes are trainable.
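The relationship between internal and boundary slopes can be sketched as a simple padding step: K bins have K + 1 knots, of which the K - 1 internal ones carry trainable slopes. The helper `pad_boundary_slopes` is hypothetical and only illustrates the bookkeeping, not how bijx implements it.

```python
def pad_boundary_slopes(internal_slopes):
    """K - 1 internal slopes -> K + 1 knot slopes, with slope 1.0 at both ends."""
    return [1.0, *internal_slopes, 1.0]

internal = [0.7, 1.3, 2.0]                 # 3 internal knots, i.e. 4 bins
knot_slopes = pad_boundary_slopes(internal)
num_bins = len(knot_slopes) - 1
```

With unit slope at both boundary knots, the derivative of the spline matches that of the identity tails at 0 and 1, so the overall map has no kink at the edges of the domain.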
Example
>>> import jax.numpy as jnp
>>> from bijx import rational_quadratic_spline
>>> x = jnp.array([0.2, 0.5, 0.8])
>>> widths = jnp.ones((3, 4))   # 4 bins
>>> heights = jnp.ones((3, 4))
>>> slopes = jnp.ones((3, 3))   # 3 internal knots
>>> y, log_det = rational_quadratic_spline(x, widths, heights, slopes)