Bijections & normalizing flows with JAX/NNX. Provides flexible tools for building normalizing flows and bijections with tractable change-of-density computations, with a focus on physics research.
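The change-of-density bookkeeping that a bijection library automates can be sketched for a single elementwise affine map. This is a minimal illustration assuming a standard-normal base distribution; `affine_forward`, `affine_inverse`, and `log_prob` are hypothetical names, not this library's API.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# Hypothetical elementwise affine bijection y = exp(log_scale) * x + shift.
# Not the library's API; only the change-of-variables bookkeeping.
def affine_forward(x, log_scale, shift):
    y = jnp.exp(log_scale) * x + shift
    log_det = jnp.sum(log_scale)  # log|det dy/dx| for a diagonal Jacobian
    return y, log_det

def affine_inverse(y, log_scale, shift):
    x = (y - shift) * jnp.exp(-log_scale)
    return x, -jnp.sum(log_scale)  # log|det dx/dy|

def log_prob(y, log_scale, shift):
    # Change of variables: log p_Y(y) = log p_X(f^{-1}(y)) + log|det d f^{-1}/dy|,
    # with an assumed standard-normal base density p_X.
    x, inv_log_det = affine_inverse(y, log_scale, shift)
    return jnp.sum(norm.logpdf(x)) + inv_log_det
```

Because the Jacobian is diagonal, its log-determinant is a plain sum; richer bijections keep the same interface (sample, transformed value, log-determinant) but need more care to keep that term tractable.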
Diffusion utilities for consistent basis transforms, noise schedules, and sampling. Implements the methods described in 'GUD: Generation with Unified Diffusion'; the library is built around latent component spaces.
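One generic way to combine a basis transform with per-component noise schedules is a variance-preserving forward process applied independently to each coefficient in an orthogonal basis. The sketch below is illustrative only: the offset-based schedule and the names `componentwise_schedule` and `forward_noise` are hypothetical assumptions, not the library's API or the paper's exact parameterization.

```python
import jax
import jax.numpy as jnp

def to_basis(x, U):
    # Orthogonal change of basis: coefficients c = U^T x.
    return U.T @ x

def from_basis(c, U):
    # Inverse transform; consistent with to_basis because U is orthogonal.
    return U @ c

def componentwise_schedule(t, offsets, width=0.5):
    # Hypothetical schedule: component i starts noising at offsets[i] and is
    # fully noised `width` later (t in [0, 1]). alpha^2 + sigma^2 = 1.
    s = jnp.clip((t - offsets) / width, 0.0, 1.0)
    return jnp.cos(0.5 * jnp.pi * s), jnp.sin(0.5 * jnp.pi * s)

def forward_noise(key, x0, t, U, offsets):
    # Variance-preserving noising of each basis coefficient:
    # c_t = alpha_i(t) * c_0 + sigma_i(t) * eps.
    c0 = to_basis(x0, U)
    alpha, sigma = componentwise_schedule(t, offsets)
    eps = jax.random.normal(key, c0.shape)
    return from_basis(alpha * c0 + sigma * eps, U)
```

Staggering the offsets lets some components stay clean while others are already noised, which is the kind of component-wise control a shared scalar schedule cannot express.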
Machine learning Calabi-Yau metrics with JAX. A library for numerically approximating Calabi-Yau metrics using machine learning, implementing the algebraic ansatz from Donaldson's algorithm.
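The algebraic ansatz can be written down in one common convention (normalization constants differ between references, and none of this is taken from the library's code): a Kähler potential is built from a basis of sections and a Hermitian matrix, and the metric is its complex Hessian. Donaldson's algorithm determines the matrix H by iterating a balancing condition.

```latex
% s_\alpha(z): basis of N homogeneous polynomials of degree k (sections of O(k)) restricted to X
% H^{\alpha\bar\beta}: N x N Hermitian positive-definite parameter matrix
K_H(z,\bar z) = \frac{1}{k\pi}\,\log \sum_{\alpha,\beta=1}^{N} s_\alpha(z)\,H^{\alpha\bar\beta}\,\overline{s_\beta(z)},
\qquad
g_{i\bar j} = \partial_i \partial_{\bar j} K_H .
```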
Continuous normalizing flow for lattice quantum field theory. Research code for 'Sampling Lattice Field Theory with Continuous Normalizing Flows'; superseded by Bijx.
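In a continuous normalizing flow the log-density evolves with the divergence of the vector field, d(log p)/dt = -tr(∂f/∂x). The sketch below integrates both with fixed-step Euler for a hypothetical linear field; the field, the names, and the integrator are illustrative assumptions, not the research code's implementation.

```python
import jax
import jax.numpy as jnp

def vector_field(x, t, params):
    # Hypothetical time-dependent linear field: f(x, t) = (A + t * B) x.
    A, B = params
    return (A + t * B) @ x

def divergence(x, t, params):
    # Exact trace of the Jacobian df/dx via autodiff (fine for small dimensions;
    # large models often use stochastic trace estimators instead).
    J = jax.jacfwd(vector_field)(x, t, params)
    return jnp.trace(J)

def integrate(x0, params, t1=1.0, n_steps=1000):
    # Fixed-step Euler for the state and the accumulated log-density change:
    # dx/dt = f(x, t),  d(log p)/dt = -tr(df/dx).
    dt = t1 / n_steps

    def step(carry, i):
        x, delta_logp = carry
        t = i * dt
        x_next = x + dt * vector_field(x, t, params)
        delta_next = delta_logp - dt * divergence(x, t, params)
        return (x_next, delta_next), None

    init = (x0, jnp.zeros((), dtype=x0.dtype))
    (x1, delta_logp), _ = jax.lax.scan(step, init, jnp.arange(n_steps))
    return x1, delta_logp
```

Then log p_1(x1) = log p_0(x0) + delta_logp; an adaptive ODE solver would normally replace the Euler loop.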
Automatic vectorization for JAX. A utility that applies `jax.vmap` automatically by inferring the batch axes from input shapes, simplifying batched pipelines.
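The core idea can be sketched by comparing each argument's rank with its unbatched ("core") rank and mapping over any extra leading axes. This assumes the core ranks are declared up front; `auto_vmap` and `core_ndims` are hypothetical names, not this utility's API, and input validation is omitted.

```python
import jax
import jax.numpy as jnp

def auto_vmap(fn, core_ndims):
    """Hypothetical helper: vmap `fn` over leading axes beyond each argument's core rank."""
    def wrapped(*args):
        extra = [jnp.ndim(a) - n for a, n in zip(args, core_ndims)]
        depth = max(extra)
        if depth <= 0:
            return fn(*args)  # every argument is at its core rank: call directly
        # Map the leading axis of arguments with the most batch dims; broadcast the
        # rest (in_axes=None). Recursion then aligns batch dims from the right.
        in_axes = tuple(0 if e == depth else None for e in extra)
        return jax.vmap(wrapped, in_axes=in_axes)(*args)
    return wrapped
```

For example, `auto_vmap(lambda m, v: m @ v, core_ndims=(2, 1))` applies a single matrix to a stack of vectors, or a stack of matrices to a single vector, without hand-written `in_axes`.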
Graph-based lazy evaluation with caching. Demonstrates metaprogramming and advanced Python features. Used in Cyjax.
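A graph of lazily evaluated, cached nodes can be sketched with version counters: a node recomputes only when the versions of its inputs have changed since its last evaluation. `Node` and `Input` are hypothetical names for illustration, not this library's or Cyjax's API, and the sketch uses none of the metaprogramming the library demonstrates.

```python
class Node:
    """Hypothetical lazy node: evaluates fn(*deps) on demand and caches the result."""

    def __init__(self, fn, *deps):
        self.fn, self.deps = fn, deps
        self._cache = None
        self._cached_version = None
        self.calls = 0  # counts actual evaluations, for illustration

    @property
    def version(self):
        # A derived node's version is the tuple of its dependencies' versions.
        return tuple(d.version for d in self.deps)

    def value(self):
        v = self.version
        if self._cached_version != v:  # inputs changed (or never evaluated)
            self._cache = self.fn(*(d.value() for d in self.deps))
            self._cached_version = v
            self.calls += 1
        return self._cache


class Input(Node):
    """Leaf node holding a value; bumping its version invalidates downstream caches."""

    def __init__(self, value):
        super().__init__(None)
        self._value, self._version = value, 0

    @property
    def version(self):
        return self._version

    def set(self, value):
        self._value = value
        self._version += 1

    def value(self):
        return self._value
```

Nothing is computed until `value()` is requested, and repeated requests reuse the cache until an upstream `Input` changes.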