Bridge with FlowJax
As mentioned in the introduction, bijx does not aim to reimplement every common normalizing flow construction, especially since other libraries already do that well.
Fortunately, the flexibility of JAX makes it straightforward to build “bridges” between libraries, which keeps the ecosystem modular.
For example, the functions from_flowjax() and to_flowjax() in the bijx.flowjax submodule translate between bijx and flowjax. The submodule has to be imported explicitly, since flowjax is not a dependency of bijx.
Objects constructed with either library can then be used in the other.
The same kind of translation can likely be extended to other JAX-based libraries.
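At its core, such a bridge is an adapter between calling conventions. The examples on this page show two of them: a flowjax distribution is sampled as dist.sample(key, sample_shape), while a bijx distribution is sampled as dist.sample(batch_shape, rng=key) and also returns log-densities. The sketch below illustrates only that adapter pattern and is not how bijx.flowjax is implemented. ToyFlowjaxStyleNormal and BijxStyleAdapter are hypothetical names, and a NumPy random generator stands in for a JAX key to keep the example dependency-light.

```python
import math

import numpy as np


class ToyFlowjaxStyleNormal:
    """Hypothetical stand-in for a flowjax-style distribution (standard normal)."""

    def __init__(self, dim: int):
        self.shape = (dim,)  # fixed event shape, declared up front

    def log_prob(self, x: np.ndarray) -> np.ndarray:
        # Sum the per-dimension standard-normal log-densities over the event axis.
        return np.sum(-0.5 * x**2 - 0.5 * math.log(2 * math.pi), axis=-1)

    def sample(self, key: np.random.Generator, sample_shape=()) -> np.ndarray:
        # Key first, then the sample shape; the event shape is appended.
        return key.standard_normal(tuple(sample_shape) + self.shape)


class BijxStyleAdapter:
    """Hypothetical adapter exposing a bijx-style sampling interface."""

    def __init__(self, inner):
        self.inner = inner

    def sample(self, batch_shape=(), rng=None):
        # Shape first, key as keyword; return samples and their log-densities.
        x = self.inner.sample(rng, batch_shape)
        return x, self.inner.log_prob(x)


dist = BijxStyleAdapter(ToyFlowjaxStyleNormal(dim=2))
x, log_density = dist.sample((100,), rng=np.random.default_rng(0))
print(x.shape, log_density.shape)  # (100, 2) (100,)
```

The adapter owns no parameters of its own; it only reorders arguments and bundles the log-density with the samples, which is why such wrappers stay thin.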
import bijx
import jax.numpy as jnp
import jax.random as jr
from flax import nnx
# register the %shapes notebook magic, used below to print output shapes
bijx.utils.load_shapes_magic()
# flowjax is not a dependency of bijx and has to be installed by the user
import flowjax
import flowjax.bijections
from flowjax.flows import block_neural_autoregressive_flow
from flowjax.distributions import Normal
# the bridge submodule is not loaded by `import bijx` and has to be imported explicitly
import bijx.flowjax as bridge
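Because flowjax is optional, a script that should also run without it can check whether the package is installed before importing the bridge. The following is a minimal sketch using only the standard library; optional_module is a hypothetical helper, not part of bijx, and it is meant for top-level package names only.

```python
import importlib
import importlib.util


def optional_module(name: str):
    """Import and return the top-level module `name`, or None if it is not installed."""
    # For a top-level name, find_spec returns None when no such package can be found.
    if importlib.util.find_spec(name) is None:
        return None
    return importlib.import_module(name)


if optional_module("flowjax") is None:
    print("flowjax not installed: skip `import bijx.flowjax`")
else:
    print("flowjax available: `import bijx.flowjax as bridge` can be used")
```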
Flowjax → bijx
# example from flowjax documentation
data_key, flow_key, train_key, sample_key = jr.split(jr.key(0), 4)
x = jr.uniform(data_key, (100, 2)) # Toy data
dist_flowjax = block_neural_autoregressive_flow(
key=flow_key,
base_dist=Normal(jnp.zeros(x.shape[1])),
)
# generate samples with the flowjax model
%shapes dist_flowjax.sample(sample_key, (100,))
(100, 2)
dist_bijx = bridge.from_flowjax(dist_flowjax)
# bijx's sample returns the samples together with their log-densities, like flowjax's sample_and_log_prob
%shapes dist_bijx.sample((100,), rng=sample_key)
((100, 2), (100,))
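The log-density returned with each sample follows from the change-of-variables formula: if x = f(z) with z drawn from the base density p, then log q(x) = log p(z) − log|det ∂f/∂z|. A flow can therefore compute log q(x) alongside x during sampling, without evaluating the inverse map. The sketch below checks this for a hypothetical one-dimensional affine map f(z) = a·z + b, for which the transformed density is known to be Normal(b, a²). It illustrates the formula and is not the implementation used by bijx or flowjax.

```python
import math

import numpy as np


def normal_logpdf(x, loc=0.0, scale=1.0):
    # Log-density of a univariate normal distribution.
    return -0.5 * ((x - loc) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2 * math.pi)


def sample_and_log_density(rng, n, a, b):
    """Sample x = a*z + b with z ~ N(0, 1) and track log q(x) via change of variables."""
    z = rng.standard_normal(n)
    x = a * z + b
    # log q(x) = log p(z) - log|df/dz|, and df/dz = a for an affine map.
    log_q = normal_logpdf(z) - math.log(abs(a))
    return x, log_q


a, b = -2.5, 1.0  # hypothetical parameters of the toy bijection
x, log_q = sample_and_log_density(np.random.default_rng(0), 5, a, b)
# The pushforward of N(0, 1) under x = a*z + b is N(b, a**2).
print(np.allclose(log_q, normal_logpdf(x, loc=b, scale=abs(a))))  # True
```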
Bijx → flowjax
flow_bijx = bijx.MonotoneRQSpline(10, (), rngs=nnx.Rngs(params=0))
dist_bijx = bijx.Transformed(bijx.IndependentUniform(()), flow_bijx)
# this direction needs extra information (the shape argument),
# because flowjax imposes stricter shape constraints
flow_flowjax = bridge.to_flowjax(flow_bijx, shape=())
# the same applies to distributions
dist_flowjax = bridge.to_flowjax(dist_bijx, shape=())
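flowjax bijections and distributions carry an explicit shape attribute describing a single, unbatched input, so the bridge has to be told what that shape is. The sketch below uses a hypothetical ShapedElementwise wrapper, not the actual to_flowjax logic, to show what fixing a shape provides: inputs are validated, and the per-element log-Jacobian terms are summed over exactly the declared event dimensions.

```python
import numpy as np


class ShapedElementwise:
    """Hypothetical wrapper giving an elementwise bijection a fixed event shape."""

    def __init__(self, forward, log_deriv, shape):
        self.forward = forward      # elementwise map y = f(x)
        self.log_deriv = log_deriv  # elementwise log|f'(x)|
        self.shape = tuple(shape)   # declared shape of one unbatched input

    def transform_and_log_det(self, x):
        x = np.asarray(x, dtype=float)
        if x.shape != self.shape:
            raise ValueError(f"expected shape {self.shape}, got {x.shape}")
        # One scalar log-determinant per input: sum over all event elements.
        return self.forward(x), np.sum(self.log_deriv(x))


# Toy elementwise bijection y = exp(x), with log|dy/dx| = x.
scalar = ShapedElementwise(np.exp, lambda x: x, shape=())
vector = ShapedElementwise(np.exp, lambda x: x, shape=(2,))

print(scalar.transform_and_log_det(0.5)[1])         # 0.5
print(vector.transform_and_log_det([0.5, 1.0])[1])  # 1.5
try:
    scalar.transform_and_log_det([0.5, 1.0])
except ValueError as err:
    print(err)  # expected shape (), got (2,)
```

Batched inputs would then be handled by mapping over leading axes (for example with jax.vmap), so the declared shape is what separates batch dimensions from event dimensions.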
%shapes dist_flowjax.sample(sample_key, (100,))
(100,)