bijx.Radial

class bijx.Radial

Bases: ApplyBijection

Radial bijection with learnable scaling and centering.

This bijection implements the radial transformation \(g(x) = c + S^{-1} f(\lVert y \rVert)\, \hat{y}\), where \(y = S(x-c)\), \(\hat{y} = y / \lVert y \rVert\) is the unit vector in the direction of \(y\), \(S\) is a diagonal scaling matrix, and \(c\) is a center vector. The scalar function \(f: \mathbb{R}_+ \to \mathbb{R}_+\), \(r \mapsto r'\), is provided by another bijection.

The transformation is composed as:
  1. Centering and scaling: \(y = S(x-c)\).
  2. Radial transformation of \(y\): \(y' = f(\lVert y \rVert)\, y / \lVert y \rVert\).
  3. Un-scaling and un-centering: \(g(x) = c + S^{-1} y'\).
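The three steps can be sketched directly in NumPy. This is a minimal illustration of the formula, not the bijx implementation: radial_forward is a hypothetical helper, s holds the positive diagonal of \(S\), and a plain Python function f stands in for scalar_bijection.

```python
import numpy as np

def radial_forward(x, c, s, f):
    """Sketch of g(x) = c + S^{-1} f(|y|) y/|y| with y = S(x - c).

    x, c, s: arrays of shape (n,); s is the positive diagonal of S.
    f: scalar map R+ -> R+ standing in for scalar_bijection (hypothetical).
    """
    y = s * (x - c)              # 1. center and scale
    r = np.linalg.norm(y)        # radius in the scaled frame (assumed > 0)
    y_new = f(r) * y / r         # 2. move along the unit vector y/r
    return c + y_new / s         # 3. un-scale and un-center

c = np.array([1.0, -1.0])
s = np.array([2.0, 0.5])
x = np.array([2.0, 1.0])
# y = [2, 1]; with f(r) = 2r the radial step doubles y to [4, 2],
# and un-scaling gives c + [4/2, 2/0.5] = [3, 3].
print(radial_forward(x, c, s, lambda r: 2.0 * r))  # ≈ [3. 3.]
```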

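The log-determinant contributions from the scaling \(S\) and the un-scaling \(S^{-1}\) cancel (\(\log\lvert\det S\rvert - \log\lvert\det S\rvert = 0\)), so only the radial transformation's Jacobian determinant is included. With \(r = \lVert S(x-c) \rVert\), the radial map has eigenvalue \(f'(r)\) in the radial direction and \(f(r)/r\) in each of the \(n-1\) orthogonal directions, where \(n\) is n_dims. Therefore \(\log\lvert\det J\rvert = \log\lvert f'(r)\rvert + (n-1)\log\bigl(f(r)/r\bigr)\).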
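The cancellation can be checked numerically. The sketch below compares the closed-form radial log-determinant against a finite-difference Jacobian of the full map, including \(S\) and \(S^{-1}\). It uses NumPy only. The helpers and the example map f(r) = r + r³ are hypothetical choices for illustration and are not part of bijx.

```python
import numpy as np

def radial_forward(x, c, s, f):
    y = s * (x - c)
    r = np.linalg.norm(y)
    return c + (f(r) * y / r) / s

def radial_logdet(x, c, s, f, df):
    # log|det dg/dx| = log|f'(r)| + (n - 1) * log(f(r) / r), r = |S(x - c)|.
    # The log|det S| terms from scaling and un-scaling cancel, so s only
    # enters through r.
    r = np.linalg.norm(s * (x - c))
    n = x.shape[0]
    return np.log(np.abs(df(r))) + (n - 1) * np.log(f(r) / r)

def numerical_logdet(g, x, eps=1e-6):
    # Central finite-difference Jacobian; column j approximates dg/dx_j.
    n = x.shape[0]
    jac = np.empty((n, n))
    for j in range(n):
        e = np.zeros(n)
        e[j] = eps
        jac[:, j] = (g(x + e) - g(x - e)) / (2 * eps)
    return np.linalg.slogdet(jac)[1]

f = lambda r: r + r**3          # increasing, positive on R+, f(0) = 0
df = lambda r: 1 + 3 * r**2
c = np.array([0.5, -0.2, 1.0])
s = np.array([2.0, 0.5, 1.5])
x = np.array([1.0, 0.3, -0.4])

analytic = radial_logdet(x, c, s, f, df)
numeric = numerical_logdet(lambda z: radial_forward(z, c, s, f), x)
print(np.isclose(analytic, numeric, atol=1e-5))  # True
```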

Parameters:
  • scalar_bijection (Bijection) – A bijection that transforms a scalar radius. It should map positive values to positive values (\(\mathbb{R}_+ \to \mathbb{R}_+\)). Most bijections are orientation-preserving (increasing), but they do not in general satisfy \(f(0) = 0\); use RayTransform if needed to ensure it.

  • n_dims – The dimensionality of the space.

  • center (Union[Variable, Array, ndarray, Sequence[Union[int, Any]]]) – Initial center vector \(c\). If not given, defaults to zeros.

  • scale (Union[Variable, Array, ndarray, Sequence[Union[int, Any]]]) – Initial scale vector for \(S\). If not given, defaults to ones.
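The requirement \(f(0) = 0\) on scalar_bijection matters because otherwise the radial map jumps at the center. Points arbitrarily close to \(c\) in opposite directions land a finite distance apart, and a ball around \(c\) is never reached, so the map is not onto \(\mathbb{R}^n\). The sketch below shows this with a hypothetical f(r) = r + 1. Subtracting f(0) is one simple way to enforce the condition. This is an assumption for illustration, not necessarily how RayTransform works.

```python
import numpy as np

def radial_map(y, f):
    # Radial step in the already centered and scaled frame: y -> f(|y|) y/|y|
    r = np.linalg.norm(y)
    return f(r) * y / r

h = lambda r: r + 1.0            # positive and increasing, but h(0) = 1
shifted = lambda r: h(r) - h(0)  # hypothetical fix: subtract the value at 0

# Two points very close to the origin, on opposite sides
a = np.array([1e-9, 0.0])
b = np.array([-1e-9, 0.0])
print(radial_map(a, h), radial_map(b, h))              # ≈ [1, 0] and [-1, 0]: a jump
print(radial_map(a, shifted), radial_map(b, shifted))  # both ≈ 0: continuous at 0
```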

__init__(scalar_bijection, center=(), scale=(), rngs=None)
Parameters:
  • scalar_bijection (Bijection)

  • center (Variable | Array | ndarray | Sequence[int | Any])

  • scale (Variable | Array | ndarray | Sequence[int | Any])

  • rngs (Rngs)

Methods

apply(x, log_density[, reverse])

Apply radial transformation, forward or reverse.

forward(x, log_density, **kwargs)

Apply forward transformation.

invert()

Create an inverted version of this bijection.

reverse(x, log_density, **kwargs)

Apply reverse (inverse) transformation.
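The reverse transformation only needs to invert the radius. The radial step leaves the direction \(y / \lVert y \rVert\) unchanged, so the reverse pass recenters and rescales, applies \(f^{-1}\) to the radius, and undoes the frame change. Below is a minimal NumPy sketch of a round trip. The helpers are hypothetical, and f(r) = eᡃ⁻¹ style maps are avoided in favor of the concrete choice f(r) = exp(r) - 1, which is increasing with f(0) = 0.

```python
import numpy as np

def radial_forward(x, c, s, f):
    y = s * (x - c)
    r = np.linalg.norm(y)
    return c + (f(r) * y / r) / s

def radial_reverse(x_new, c, s, f_inv):
    # The radial map keeps the direction of y, so only the radius is inverted.
    y_new = s * (x_new - c)
    r_new = np.linalg.norm(y_new)
    return c + (f_inv(r_new) * y_new / r_new) / s

f = lambda r: np.exp(r) - 1.0    # increasing, f(0) = 0
f_inv = np.log1p                 # exact inverse of f on R+
c = np.array([0.0, 1.0])
s = np.array([1.0, 3.0])
x = np.array([0.7, -0.4])

x_rt = radial_reverse(radial_forward(x, c, s, f), c, s, f_inv)
print(np.allclose(x, x_rt))  # True
```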

Attributes

scale

Positive scaling factors.

property scale: Array

Positive scaling factors.

apply(x, log_density, reverse=False, **kwargs)

Apply radial transformation, forward or reverse.

Parameters:
  • x (Array)

  • log_density (Array)

  • reverse (bool)
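A sketch of how a single apply with a reverse flag can update both the sample and its log-density. It assumes that log_density tracks the log-density of the samples, so the forward pass subtracts \(\log\lvert\det J\rvert\) and the reverse pass adds the forward log-determinant, evaluated at the pre-image radius. This sign convention and the helper radial_apply are assumptions for illustration. Consult the bijx Bijection documentation for the library's actual convention.

```python
import numpy as np

def radial_apply(x, log_density, c, s, f, f_inv, df, reverse=False):
    """Hypothetical stand-in for apply(); not the bijx implementation.

    Assumes log_density is the log-density of the samples x.
    """
    n = x.shape[0]
    y = s * (x - c)
    r = np.linalg.norm(y)
    if not reverse:
        r_out = f(r)
        log_det = np.log(df(r)) + (n - 1) * np.log(r_out / r)
        log_density = log_density - log_det
    else:
        r_out = f_inv(r)
        # Forward log|det J| at the pre-image radius r_out, where f(r_out) = r.
        log_det = np.log(df(r_out)) + (n - 1) * np.log(r / r_out)
        log_density = log_density + log_det
    return c + (r_out * y / r) / s, log_density

f = lambda r: np.exp(r) - 1.0
f_inv = np.log1p
df = np.exp                      # derivative of exp(r) - 1
c, s = np.zeros(3), np.array([1.0, 2.0, 0.5])
x, ld = np.array([0.3, -0.1, 0.8]), 0.0

x1, ld1 = radial_apply(x, ld, c, s, f, f_inv, df)
x2, ld2 = radial_apply(x1, ld1, c, s, f, f_inv, df, reverse=True)
print(np.allclose(x, x2), np.isclose(ld, ld2))  # True True
```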