bijx.nn.conv

Symmetric convolutions with group invariance for lattice field theory.

This module implements convolutions that preserve spatial symmetries, which is particularly useful in lattice field theory. Equivariance is enforced through explicit parameter sharing. Only the “trivial” representation is supported for the channel “fibers” (i.e. the layers map spatial scalar “images” to scalar “images”).

Key components:
  • ConvSym: Symmetric convolution layer with orbit-based parameter sharing

  • kernel_d4(): D4 symmetry group (rotations and reflections) for 2D lattices

  • kernel_equidist(): Distance-based parameter sharing for isotropic kernels

Symmetric convolutions commute with group actions:

\[g \cdot (W * x) = W * (g \cdot x)\]
where \(g\) is a group element (rotation, reflection, etc.) and \(W\) is the convolution kernel.

This is achieved through orbit decomposition, where lattice sites are grouped into orbits of equivalent positions under the symmetry group. Parameters are shared within each orbit, reducing the parameter count and enforcing the desired symmetries.

The orbit construction is implemented somewhat naively, by simply applying the group operators to lattice indices and collecting the results.
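The equivariance relation above can be checked numerically. The following is a minimal sketch in plain NumPy, not the module's implementation: `circular_conv` is a hypothetical helper for a periodic convolution, and the orbit layout of the 3x3 kernel is written out by hand.

```python
import numpy as np

def circular_conv(x, w):
    """Periodic cross-correlation of a 2D field x with an odd-sized kernel w."""
    kh, kw = w.shape
    out = np.zeros(x.shape, dtype=float)
    for i in range(kh):
        for j in range(kw):
            # After this roll, w[i, j] multiplies x at offset (i - kh//2, j - kw//2).
            out += w[i, j] * np.roll(x, (kh // 2 - i, kw // 2 - j), axis=(0, 1))
    return out

# Orbit index of each kernel site under D4: corners (0), edges (1), center (2).
orbits = np.array([[0, 1, 0],
                   [1, 2, 1],
                   [0, 1, 0]])
params = np.array([0.3, -1.2, 2.0])   # one weight per orbit
w_sym = params[orbits]                # shared weights give a D4-symmetric kernel

rng = np.random.default_rng(0)
x = rng.normal(size=(6, 6))

# g = rotation by 90 degrees: rotating then convolving equals convolving then rotating.
lhs = np.rot90(circular_conv(x, w_sym))
rhs = circular_conv(np.rot90(x), w_sym)
equivariant = np.allclose(lhs, rhs)
```

A kernel without shared weights generally fails this check, which is why the sharing pattern, rather than training, is what guarantees the symmetry.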

Functions

conv_indices(shape[, return_flat, center])

Generate index matrix for translation invariant convolution layers.

flip_lattice(lattice, axis)

Apply reflection symmetry along specified axis with proper boundary handling.

fold_kernel(kernel_weights, orbits, orbit_count)

Extract symmetric parameters from full convolution kernel.

gather_orbits(shape, transformations)

Compute orbit decomposition for lattice under given symmetry group.

kernel_d4(shape)

Compute orbit decomposition for D4 dihedral group symmetry.

kernel_equidist(shape)

Compute orbit decomposition based on distance from lattice center.

resize_kernel_weights(kernel, new_shape, *)

Resize convolution kernel while preserving periodicity and symmetries.

rot_lattice_90(lattice, ax1, ax2)

Apply 90-degree rotation in the plane defined by two axes.

unfold_kernel(kernel_params, orbits)

Expand symmetric kernel parameters to full convolution kernel.

unique_index_kernel(shape)

Generate unique indices for lattice sites ordered by distance to center.

Classes

ConvSym

Symmetric convolution layer with orbit-based parameter sharing.

bijx.nn.conv.conv_indices(shape, return_flat=True, center=True)[source]

Generate index matrix for translation invariant convolution layers.

Creates the orbital index structure for translation-invariant convolutions by computing relative position differences between all pairs of lattice sites. This forms the basis for implementing convolutions through orbit decomposition.

Parameters:
  • shape (tuple[int, ...]) – Spatial dimensions of the lattice.

  • return_flat – If True, flatten spatial dimensions to single axis.

  • center – If True, center indices around lattice midpoint.

Returns:

Index array with shape [spatial_rank, *shape, *shape] if return_flat=False, or [spatial_rank, prod(shape), prod(shape)] if return_flat=True. Entry [i, x, y] contains the i-th component of displacement (y-x).
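As an illustration of the flattened layout, here is a small NumPy sketch of a pairwise displacement array. It is not the library code: `displacement_indices` is a hypothetical name, the periodic wrap-around is an assumption, and the `center` option is not modelled.

```python
import numpy as np

def displacement_indices(shape):
    """Return an array d with d[i, x, y] = i-th component of (y - x) mod shape."""
    # Coordinates of every site, flattened: shape [rank, prod(shape)].
    coords = np.stack([c.ravel() for c in np.indices(shape)])
    # Pairwise differences y - x for all flattened site pairs (x, y).
    diff = coords[:, None, :] - coords[:, :, None]
    # Assumption: displacements wrap around periodically.
    return diff % np.asarray(shape)[:, None, None]

d = displacement_indices((2, 3))   # shape (2, 6, 6)
```

Site pairs separated by the same displacement receive the same entry, which is the property a translation-invariant layer relies on when sharing weights.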

bijx.nn.conv.unique_index_kernel(shape)[source]

Generate unique indices for lattice sites ordered by distance to center.

Creates a unique identifier for each lattice site, with indices assigned in order of increasing distance from the lattice center. Sites at the same distance receive consecutive indices, which facilitates orbit decomposition algorithms.

Parameters:

shape (tuple[int, ...]) – Spatial dimensions of the lattice.

Returns:

Integer array of given shape where each entry contains a unique index. Indices are assigned in order of increasing distance from the center, with the central site receiving index 0.
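One way to obtain such an ordering is a stable sort on the squared distance. The sketch below uses plain NumPy and assumes the center sits at `shape // 2` along each axis; `distance_ordered_indices` is a hypothetical name, and tie-breaking in the actual function may differ.

```python
import numpy as np

def distance_ordered_indices(shape):
    """Assign 0..N-1 to sites in order of increasing distance from the center."""
    center = [n // 2 for n in shape]            # assumption: center at shape // 2
    dist2 = sum((g - c) ** 2 for g, c in zip(np.indices(shape), center))
    # A stable sort keeps row-major order among sites at equal distance.
    order = np.argsort(dist2.ravel(), kind="stable")
    labels = np.empty(order.size, dtype=int)
    labels[order] = np.arange(order.size)       # rank of each site in the ordering
    return labels.reshape(shape)

idx = distance_ordered_indices((3, 3))
# idx == [[5, 1, 6],
#         [2, 0, 3],
#         [7, 4, 8]]
```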

bijx.nn.conv.flip_lattice(lattice, axis)[source]

Apply reflection symmetry along specified axis with proper boundary handling.

Implements lattice reflection that respects periodic boundary conditions. The reflection is followed by a boundary-preserving roll operation to maintain correct periodicity at lattice edges.

Parameters:
  • lattice (Array) – Input lattice array to reflect.

  • axis (int) – Spatial axis along which to apply the reflection.

Return type:

Array

Returns:

Reflected lattice array with same shape as input.

Note

The roll operation by (shape[axis] % 2 - 1) ensures that the reflection operation is consistent with periodic boundary conditions, which is crucial for maintaining proper lattice symmetries.
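The flip-then-roll recipe from the note can be written in a few lines. This is a NumPy sketch of the described behaviour, with `flip_sketch` as a hypothetical name, rather than the function's actual source.

```python
import numpy as np

def flip_sketch(lattice, axis):
    """Reflect along `axis`, then roll by (n % 2 - 1) as described in the note."""
    n = lattice.shape[axis]
    return np.roll(np.flip(lattice, axis), n % 2 - 1, axis=axis)

odd = flip_sketch(np.arange(5), 0)    # no roll for odd length  -> [4 3 2 1 0]
even = flip_sketch(np.arange(4), 0)   # roll by -1 for even length -> [2 1 0 3]
```

In both cases applying the operation twice returns the original array, as expected of a reflection.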

bijx.nn.conv.rot_lattice_90(lattice, ax1, ax2)[source]

Apply 90-degree rotation in the plane defined by two axes.

Implements a 90-degree counterclockwise rotation by swapping the specified axes and applying a reflection. This operation is a fundamental building block for constructing the D4 symmetry group on 2D lattices.

Parameters:
  • lattice (Array) – Input lattice array to rotate.

  • ax1 (int) – First axis defining the rotation plane.

  • ax2 (int) – Second axis defining the rotation plane.

Return type:

Array

Returns:

Rotated lattice array with same shape as input.

Note

The composition of axis swapping followed by reflection along ax1 implements a proper 90-degree rotation that preserves the lattice structure and boundary conditions.
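The swap-then-reflect composition can be sketched as follows in NumPy. Both helper names are hypothetical, and the reflection reuses the roll by `(n % 2 - 1)` described for flip_lattice().

```python
import numpy as np

def flip_sketch(lattice, axis):
    """Reflection with the boundary-preserving roll."""
    n = lattice.shape[axis]
    return np.roll(np.flip(lattice, axis), n % 2 - 1, axis=axis)

def rot90_sketch(lattice, ax1, ax2):
    """Swap the two axes, then reflect along ax1."""
    return flip_sketch(np.swapaxes(lattice, ax1, ax2), ax1)

a = np.arange(9).reshape(3, 3)
r = rot90_sketch(a, 0, 1)
# r == [[2, 5, 8],
#       [1, 4, 7],
#       [0, 3, 6]]
```

For odd sizes this coincides with `np.rot90(a)`. For even sizes the extra roll shifts the result by one site, but four applications still give the identity.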

bijx.nn.conv.gather_orbits(shape, transformations)[source]

Compute orbit decomposition for lattice under given symmetry group.

Determines equivalence classes (orbits) of lattice sites under the action of the symmetry group generated by the provided transformations. Sites in the same orbit can be transformed into each other by group operations, enabling parameter sharing in symmetric convolutions.

Parameters:
  • shape (tuple[int, ...]) – Spatial dimensions of the lattice.

  • transformations (list[Callable[[Array], Array]]) – List of group generators (symmetry operations). Each transformation takes a lattice array and returns the transformed array under that group element.

Returns:

Tuple containing:
  • num_orbits: Total number of distinct orbits

  • orbit_indices: Array of same shape as lattice, where each entry contains the orbit index (0 to num_orbits-1) for that site
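A naive orbit construction of this kind can be sketched by propagating the smallest site label through the generators until nothing changes. This is an illustrative NumPy version under that assumption, with the hypothetical name `gather_orbits_sketch`. The actual implementation may collect orbits differently.

```python
import numpy as np

def gather_orbits_sketch(shape, transformations):
    """Label sites so that sites related by the generators share a label."""
    labels = np.arange(int(np.prod(shape))).reshape(shape)
    while True:
        new = labels
        for t in transformations:
            # Each site takes the smaller of its own label and its image's label.
            new = np.minimum(new, t(new))
        if np.array_equal(new, labels):
            break
        labels = new
    # Relabel the surviving ids as consecutive orbit indices 0..num_orbits-1.
    unique, inverse = np.unique(labels.ravel(), return_inverse=True)
    return len(unique), inverse.reshape(shape)

# Generators of D4 on a square array: a 90-degree rotation and one reflection.
num_orbits, orbits = gather_orbits_sketch((3, 3), [np.rot90, np.flipud])
# num_orbits == 3, orbits == [[0, 1, 0], [1, 2, 1], [0, 1, 0]]
```

Because every generator is a permutation of finite order, the loop terminates with labels that are constant on each orbit of the generated group.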

bijx.nn.conv.kernel_d4(shape)[source]

Compute orbit decomposition for D4 dihedral group symmetry.

Implements the D4 symmetry group of rotations and reflections for 2D shapes. The construction generalizes naturally to higher dimensions.

Example

>>> # 3x3 kernel with D4 symmetry
>>> num_orbits, orbits = kernel_d4((3, 3))
>>> # num_orbits = 3 (center, corners, edges)
Parameters:

shape (tuple[int, ...]) – Spatial dimensions of the lattice (must be square for rotations).

Returns:

Tuple containing:
  • num_orbits: Number of distinct orbits under D4 symmetry

  • orbit_indices: Array, same shape as lattice with orbit index for each site

Raises:

AssertionError – If lattice is not square (required for rotation symmetry).
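For odd square kernels the D4 orbits can also be read off directly from the offsets to the center, which gives a quick cross-check on orbit counts. The NumPy sketch below uses this alternative construction. It is not how kernel_d4() is described to work, the name `d4_orbits_by_offset` is hypothetical, and the orbit numbering may differ from the library's.

```python
import numpy as np

def d4_orbits_by_offset(n):
    """Orbit labels for an odd n x n kernel, using sorted absolute offsets."""
    c = n // 2
    dy, dx = np.abs(np.indices((n, n)) - c)
    # Rotations and reflections permute and sign-flip (dy, dx), so the sorted
    # pair of absolute offsets is shared by all sites in one orbit.
    lo, hi = np.minimum(dy, dx), np.maximum(dy, dx)
    keys = hi * (c + 1) + lo
    unique, inverse = np.unique(keys.ravel(), return_inverse=True)
    return len(unique), inverse.reshape(n, n)

counts = {n: d4_orbits_by_offset(n)[0] for n in (3, 5, 7)}
# counts == {3: 3, 5: 6, 7: 10}
```

With m = (n - 1) / 2, the count is (m + 1)(m + 2) / 2, compared with n² weights for an unconstrained kernel.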

bijx.nn.conv.kernel_equidist(shape)[source]

Compute orbit decomposition based on distance from lattice center.

Creates orbits by grouping lattice sites at equal Euclidean distance from the lattice center.
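Grouping by distance can be sketched with `np.unique` on the squared distances. This NumPy version assumes the center sits at `shape // 2`. The name `equidist_orbits_sketch` and the extra `shells` return value are illustrative only.

```python
import numpy as np

def equidist_orbits_sketch(shape):
    """Group sites by squared distance from the center site."""
    center = [n // 2 for n in shape]   # assumption: center at shape // 2
    dist2 = sum((g - c) ** 2 for g, c in zip(np.indices(shape), center))
    # np.unique sorts the distinct squared distances; `inverse` is the shell index.
    shells, inverse = np.unique(dist2.ravel(), return_inverse=True)
    return len(shells), inverse.reshape(shape), shells

num_orbits, orbits, shells = equidist_orbits_sketch((5, 5))
# shells == [0, 1, 2, 4, 5, 8], so num_orbits == 6
```

For larger kernels this sharing is coarser than D4: the offsets (5, 0) and (3, 4) lie at the same distance from the center but are not related by any rotation or reflection of the square lattice.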

Parameters:

shape (tuple[int, ...]) – Spatial dimensions of the lattice.

Returns:

Tuple containing:
  • num_orbits: Number of distinct distance shells

  • orbit_indices: Array, same shape as lattice with orbit index for each site

Example

>>> # 5x5 kernel with distance-based orbits
>>> num_orbits, orbits = kernel_equidist((5, 5))
>>> # Sites at distances 0, 1, √2, 2, √5, etc. form separate orbits
bijx.nn.conv.unfold_kernel(kernel_params, orbits)[source]

Expand symmetric kernel parameters to full convolution kernel.

Reconstructs the complete convolution kernel from the compressed orbit representation by broadcasting shared parameters to all sites within each orbit. This operation converts from the efficient symmetric representation back to standard convolution kernel format.

Note

This function inverts fold_kernel() on orbit parameters: folding the unfolded kernel returns the original kernel_params.

Parameters:
  • kernel_params (Array) – Compressed parameters with shape (num_orbits, in_channels, out_channels).

  • orbits (Array) – Integer array giving the orbit index for each lattice site.

Return type:

Array

Returns:

Full convolution kernel with shape (*lattice_shape, in_channels, out_channels) where lattice_shape matches the shape of the orbits array.
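With the shapes documented above, the expansion amounts to integer-array indexing on the leading axis. A minimal NumPy sketch, assuming a hand-written 3x3 orbit layout and arbitrary channel counts:

```python
import numpy as np

orbits = np.array([[0, 1, 0],
                   [1, 2, 1],
                   [0, 1, 0]])             # 3 orbits on a 3x3 kernel
# Compressed parameters with shape (num_orbits, in_channels, out_channels).
kernel_params = np.arange(3 * 2 * 4, dtype=float).reshape(3, 2, 4)

# Indexing with the orbit array copies each orbit's (in, out) block
# to every site that belongs to that orbit.
full_kernel = kernel_params[orbits]        # shape (3, 3, 2, 4)
```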

bijx.nn.conv.resize_kernel_weights(kernel, new_shape, *, mode='constant', constant_values=0.0, **pad_args)[source]

Resize convolution kernel while preserving periodicity and symmetries.

Increases the spatial size of a convolution kernel through careful padding that respects periodic boundary conditions and maintains proper normalization. This is particularly important for symmetric kernels with boundary effects.

Note

For even-sized dimensions, the boundary values are handled specially to maintain proper circular padding. Values at wrap-around edges are split between copies to preserve total weight normalization.

Example

>>> # Resize 3x3 kernel to 5x5
>>> kernel_3x3 = jnp.ones((3, 3, 1, 1))
>>> kernel_5x5 = resize_kernel_weights(kernel_3x3, (5, 5))
Parameters:
  • kernel (Array) – Convolution kernel to resize with shape (*spatial_dims, in_channels, out_channels).

  • new_shape (int | tuple[int, ...]) – Target spatial dimensions. Can be integer (for square kernel) or tuple specifying each dimension.

  • mode (str) – Padding mode for dimensions beyond the original kernel size.

  • constant_values (float) – Fill value when using constant padding mode.

  • **pad_args – Additional arguments passed to numpy padding function.

Return type:

Array

Returns:

Resized kernel with spatial shape matching new_shape while preserving the original channel dimensions.
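The simplest case, an odd-sized kernel grown to a larger odd size with zero padding, can be sketched with `np.pad`. This is only an illustration of centered padding on the spatial axes. The special handling of even-sized dimensions mentioned in the note is not reproduced, and `pad_odd_kernel` is a hypothetical name.

```python
import numpy as np

def pad_odd_kernel(kernel, new_size):
    """Zero-pad an odd-sized 2D kernel (h, w, in, out) to new_size x new_size."""
    h, w = kernel.shape[:2]
    assert h % 2 == 1 and w % 2 == 1 and new_size % 2 == 1
    assert new_size >= max(h, w)
    ph, pw = (new_size - h) // 2, (new_size - w) // 2
    # Pad only the spatial axes; channel axes are left untouched.
    return np.pad(kernel, ((ph, ph), (pw, pw), (0, 0), (0, 0)),
                  mode="constant", constant_values=0.0)

kernel_3x3 = np.ones((3, 3, 1, 1))
kernel_5x5 = pad_odd_kernel(kernel_3x3, 5)   # original weights stay centered
```

Equal padding on both sides keeps the kernel center fixed, so any symmetry about the center is preserved and the total weight is unchanged.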

bijx.nn.conv.fold_kernel(kernel_weights, orbits, orbit_count)[source]

Extract symmetric parameters from full convolution kernel.

Compresses a full convolution kernel into the orbit-based symmetric representation by averaging parameters within each orbit. This is useful for initializing symmetric convolutions from pre-trained standard kernels.

Note

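This function inverts unfold_kernel(): folding an unfolded kernel returns the original parameters. The converse holds only for kernels that are already symmetric, since folding averages over each orbit.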

Parameters:
  • kernel_weights (Array) – Full convolution kernel to compress.

  • orbits (Array) – Integer array giving the orbit index for each lattice site.

  • orbit_count (int) – Total number of distinct orbits.

Return type:

Array

Returns:

Compressed parameter array with shape (num_orbits, in_channels, out_channels) containing the average parameter values for each orbit.
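Orbit averaging and its relation to unfolding can be sketched in NumPy as below. `fold_sketch` is a hypothetical stand-in that assumes a kernel of shape `(*lattice_shape, in_channels, out_channels)`, and unfolding is done by plain indexing.

```python
import numpy as np

def fold_sketch(kernel, orbits, orbit_count):
    """Average the (in, out) blocks of all sites in each orbit."""
    flat = kernel.reshape(-1, *kernel.shape[orbits.ndim:])   # (sites, in, out)
    idx = orbits.ravel()
    sums = np.zeros((orbit_count, *flat.shape[1:]))
    np.add.at(sums, idx, flat)                               # per-orbit sums
    counts = np.bincount(idx, minlength=orbit_count)         # sites per orbit
    return sums / counts[:, None, None]

orbits = np.array([[0, 1, 0], [1, 2, 1], [0, 1, 0]])
rng = np.random.default_rng(0)
kernel = rng.normal(size=(3, 3, 2, 4))

params = fold_sketch(kernel, orbits, 3)          # shape (3, 2, 4)
symmetric = params[orbits]                       # unfold by indexing
refolded = fold_sketch(symmetric, orbits, 3)     # equals params again
```

Folding followed by unfolding therefore acts as a projection onto symmetric kernels: `symmetric` differs from the random `kernel`, but folding it again changes nothing.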

class bijx.nn.conv.ConvSym[source]

Bases: Module

Symmetric convolution layer with orbit-based parameter sharing.

Implements convolutions that preserve discrete symmetries by sharing parameters among equivalent lattice sites (orbits). This reduces the parameter count while maintaining the desired symmetries: a 3x3 kernel under D4, for example, has 3 independent weights per channel pair instead of 9. This makes the layer well suited to lattice field theory and other physics applications.
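The saving in parameters is simple arithmetic once the orbit layout is known. The sketch below assumes kernel parameters of shape `(num_orbits, in_channels, out_channels)`, as documented for unfold_kernel(), ignores bias terms, and uses a hand-written D4 orbit layout.

```python
import numpy as np

in_features, out_features = 1, 16
orbits = np.array([[0, 1, 0],
                   [1, 2, 1],
                   [0, 1, 0]])                 # D4 orbits of a 3x3 kernel
num_orbits = int(orbits.max()) + 1

shared = num_orbits * in_features * out_features    # orbit-shared weights: 48
dense = orbits.size * in_features * out_features    # unconstrained kernel: 144
```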

Note

The orbit_function determines the symmetry group. Common choices:
  • kernel_d4(): D4 dihedral group (rotations + reflections)

  • kernel_equidist(): Radial symmetry (distance-based orbits)

  • None: No symmetry (standard convolution)

Example

>>> # D4-symmetric convolution for 2D lattice
>>> conv = ConvSym(
...     in_features=1, out_features=16, kernel_size=(3, 3),
...     orbit_function=kernel_d4, rngs=rngs
... )
>>> y = conv(phi[..., None])  # Preserves rotations and reflections
Parameters:
  • in_features (int) – Number of input feature channels.

  • out_features (int) – Number of output feature channels.

  • kernel_size (Union[int, Sequence[int]]) – Spatial dimensions of convolution kernel.

  • orbit_function (Optional[Callable]) – Function to compute orbit decomposition (default: D4 symmetry).

  • strides (Union[int, Sequence[int]]) – Convolution stride in each spatial dimension.

  • padding (Union[str, int, Sequence[Union[int, tuple[int, int]]]]) – Padding strategy (‘CIRCULAR’ for periodic boundaries).

  • input_dilation (Union[int, Sequence[int]]) – Input dilation factors for each spatial dimension.

  • kernel_dilation (Union[int, Sequence[int]]) – Kernel dilation factors (atrous convolution).

  • feature_group_count (int) – Number of feature groups for grouped convolution.

  • use_bias (bool) – Whether to include learnable bias terms.

  • mask (Array | None) – Optional mask for weights during masked convolution.

  • dtype (Union[str, type[Any], dtype, SupportsDType, Any, None]) – Computation dtype (inferred if None).

  • param_dtype (Union[str, type[Any], dtype, SupportsDType, Any, None]) – Parameter initialization dtype.

  • precision (Union[None, str, Precision, tuple[str, str], tuple[Precision, Precision]]) – Numerical precision specification.

  • kernel_init (Union[Initializer, Callable[..., Any]]) – Kernel parameter initializer.

  • bias_init (Union[Initializer, Callable[..., Any]]) – Bias parameter initializer.

  • conv_general_dilated (Callable[..., Union[Array, Any]]) – Convolution implementation function.

  • promote_dtype (PromoteDtypeFn) – Dtype promotion function.

  • rngs (Rngs) – Random number generator state.

property kernel: Param[Array]

Construct full kernel from orbit-shared parameters.