bijx.GroupedParam

class bijx.GroupedParam

Bases: Module

Index-sharing parameter wrapper.

Stores a compressed underlying parameter of shape (*lead, K) and exposes it, gathered to the full event shape, via get_value(), which returns param[..., fund_idx]. The fundamental (group) axis must be the last axis of the underlying array. Event positions that share a label therefore read the same underlying value.
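The gather described above can be reproduced with plain jax.numpy indexing. This is a minimal sketch of the documented semantics of param[..., fund_idx], not bijx's own code. The label array and parameter values are made up for illustration.

```python
import jax.numpy as jnp

# Illustrative 2x3 event whose entries fall into K = 3 groups (labels already compressed).
fund_idx = jnp.array([[0, 1, 2],
                      [2, 1, 0]])

# Underlying compressed parameter with one leading axis: shape (*lead, K) = (2, 3).
param = jnp.array([[10.0, 20.0, 30.0],
                   [-1.0, -2.0, -3.0]])

# Integer-array indexing on the last (fundamental) axis replaces it with the event shape.
full = param[..., fund_idx]   # shape (*lead, *event_shape) = (2, 2, 3)

# Positions (0, 0) and (1, 2) both carry label 0, so they read the same values.
shared = jnp.array_equal(full[:, 0, 0], full[:, 1, 2])
```

Because only K values are stored per leading index, any gradient flowing into full is accumulated onto the shared entries of param.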

Construct it directly when you already have a compressed value, or use from_int_index() to build one from an event-shape integer label array and an initializer.

Parameters:
  • value – Underlying parameter array, shape (*lead, K).

  • fund_idx – Integer label array (already compressed to 0..K-1) with shape equal to the target event shape.

  • n_fund (int) – Number of distinct groups K.
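The three constructor arguments have to agree with each other. The helper below, check_grouped_args, is hypothetical (it is not part of bijx) and simply spells out the shape contract stated in the parameter list. It additionally assumes that "compressed to 0..K-1" means every group id in that range actually occurs.

```python
import jax.numpy as jnp

def check_grouped_args(value, fund_idx, n_fund, event_shape):
    """Hypothetical validator (not a bijx API) for GroupedParam-style arguments.

    Returns the shape that gathering value[..., fund_idx] would produce."""
    fund_idx = jnp.asarray(fund_idx)
    # The fundamental axis is the last axis of the underlying array and has length K.
    if value.shape[-1] != n_fund:
        raise ValueError(f"last axis of value is {value.shape[-1]}, expected K={n_fund}")
    # The label array must have exactly the target event shape.
    if fund_idx.shape != tuple(event_shape):
        raise ValueError(f"fund_idx shape {fund_idx.shape} != event shape {tuple(event_shape)}")
    if not jnp.issubdtype(fund_idx.dtype, jnp.integer):
        raise TypeError("fund_idx must be an integer array")
    # Assumption: compressed labels use every id in 0..K-1 and nothing else.
    if set(jnp.unique(fund_idx).tolist()) != set(range(n_fund)):
        raise ValueError("fund_idx is not compressed to 0..K-1")
    return value[..., fund_idx].shape
```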

__init__(value, fund_idx, n_fund)

Parameters:
  • n_fund (int)

get_value()

Return the full-shape gathered array param[..., fund_idx].

classmethod from_int_index(int_index, init_fn, rngs, full_shape, *, leading_shape=())

Build a GroupedParam from a label array and an initializer.

The labels in int_index are compressed to consecutive integers 0..K-1, where K is the number of distinct labels. The initializer is called with shape (*leading_shape, K). If it instead returns the full layout (*leading_shape, *full_shape) (e.g. a precomputed spectrum), the values are mean-pooled per group and the within-group population variances are stored on the instance as init_pool_var.
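Label compression and per-group pooling can be sketched in plain JAX. This is an illustration of the behaviour described above under stated assumptions, not bijx's implementation: compress_labels and pool_full_layout are hypothetical helper names, compression is assumed to order groups by sorted label value (what jnp.unique does), and "population variance" is taken to mean dividing by the group size rather than size minus one.

```python
import jax
import jax.numpy as jnp

def compress_labels(int_index):
    """Hypothetical helper: map arbitrary integer labels to ids 0..K-1 (sorted by label)."""
    uniq, inv = jnp.unique(int_index, return_inverse=True)
    return inv.reshape(int_index.shape), uniq.size

def pool_full_layout(full_values, fund_idx, n_fund, n_leading):
    """Hypothetical helper: mean-pool a (*leading_shape, *full_shape) array per group.

    Returns (means, pop_vars), both of shape (*leading_shape, K)."""
    lead = full_values.shape[:n_leading]
    flat = full_values.reshape(lead + (-1,))        # (*lead, N) with N event positions
    ids = fund_idx.reshape(-1)                      # (N,) group id of each position
    counts = jnp.bincount(ids, length=n_fund)       # group sizes
    # segment_sum reduces over the first axis, so move the event axis to the front.
    moved = jnp.moveaxis(flat, -1, 0)               # (N, *lead)
    denom = counts.reshape((n_fund,) + (1,) * len(lead))
    means = jax.ops.segment_sum(moved, ids, num_segments=n_fund) / denom
    # Population variance: mean squared deviation from the group mean.
    dev = moved - means[ids]
    pop_vars = jax.ops.segment_sum(dev ** 2, ids, num_segments=n_fund) / denom
    return jnp.moveaxis(means, 0, -1), jnp.moveaxis(pop_vars, 0, -1)
```

For example, labels [[7, 3], [3, 7]] compress to [[1, 0], [0, 1]]; pooling the values [[1, 2], [4, 5]] then gives means [3, 3] and population variances [1, 4].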

Parameters:
  • int_index – Event-shape integer label array.

  • init_fn – Initializer (rng, shape) -> array.

  • rngs – nnx random number generators.

  • full_shape – Target event shape (must match int_index.shape).

  • leading_shape – Optional leading axes (e.g. (2,) for a real/imaginary shift); the fundamental axis is placed after these.
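The shapes involved when leading_shape is used can be traced with plain JAX. This sketch only mirrors the documented shapes; it does not call bijx. jax.random.normal already has the (rng, shape) -> array form expected of init_fn, and the 0.01 scale, the label values, and leading_shape=(2,) are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

def init_fn(rng, shape):
    # Any (rng, shape) -> array initializer fits; a small-scale normal is used here.
    return 0.01 * jax.random.normal(rng, shape)

int_index = jnp.array([[5, 5, 9],
                       [9, 1, 1]])            # event-shape labels
full_shape = int_index.shape                  # (2, 3), must match int_index.shape
leading_shape = (2,)                          # e.g. real and imaginary parts

# Compress labels to 0..K-1 (here labels 1, 5, 9 become ids 0, 1, 2).
uniq, inv = jnp.unique(int_index, return_inverse=True)
fund_idx = inv.reshape(full_shape)
K = uniq.size

# The fundamental axis comes after the leading axes: (*leading_shape, K).
value = init_fn(jax.random.PRNGKey(0), leading_shape + (K,))
full = value[..., fund_idx]                   # (*leading_shape, *full_shape)
```

Only 2 * K = 6 numbers are stored here, while the gathered array has 2 * 2 * 3 = 12 entries.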