bijx.load_shapes_magic

bijx.load_shapes_magic()

Load IPython magic command for inspecting JAX pytree shapes.

Registers the %shapes line magic, which displays the shape structure of a JAX pytree (or of any expression that evaluates to one). This is useful for debugging array dimensions in notebooks.

# In IPython/Jupyter:
# %shapes some_complex_pytree_or_expression
# Displays nested shape structure
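The "nested shape structure" is the pytree's own structure with every leaf array replaced by its shape. The sketch below shows that mapping using `jax.tree_util.tree_map`. It is an illustration under that assumption, not the bijx implementation. The helper name `pytree_shapes` is hypothetical, and the real magic may format its output differently.

```python
import jax
import jax.numpy as jnp
import numpy as np


def pytree_shapes(tree):
    """Hypothetical helper: same pytree structure, each leaf replaced by its shape."""
    # np.shape reads `.shape` from JAX/NumPy arrays and returns () for Python scalars.
    return jax.tree_util.tree_map(np.shape, tree)


# Example pytree: nested dicts and a list holding arrays and a scalar.
params = {
    "dense": {"w": jnp.zeros((2, 3)), "b": jnp.zeros(3)},
    "scale": 1.0,
    "layers": [jnp.ones((4,)), jnp.ones((4, 4))],
}
print(pytree_shapes(params))
```

Because `tree_map` rebuilds the original container types (dicts, lists, tuples, registered custom nodes), the result mirrors the nesting of the input. This makes dimension mismatches easy to spot.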

Note

Only works in IPython/Jupyter environments. Prints a warning if IPython is not available.
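The environment check described in the note can be sketched as follows. This is a minimal sketch, not the bijx source, and `load_shapes_magic_sketch` is a hypothetical name. It assumes IPython's `get_ipython()` (returns `None` outside a running shell), `InteractiveShell.register_magic_function`, and `InteractiveShell.ev` (evaluates an expression in the user namespace). Outside IPython it only warns and returns `False`.

```python
import warnings


def load_shapes_magic_sketch() -> bool:
    """Hypothetical loader: register a %shapes line magic if IPython is running.

    Returns True if the magic was registered, False otherwise.
    """
    try:
        from IPython import get_ipython
    except ImportError:
        warnings.warn("IPython is not installed; %shapes magic not registered.")
        return False

    ip = get_ipython()  # None when not running inside an IPython shell
    if ip is None:
        warnings.warn("Not running in IPython; %shapes magic not registered.")
        return False

    def shapes(line: str) -> None:
        # Imported lazily so the loader itself does not require JAX or NumPy.
        import jax
        import numpy as np

        value = ip.ev(line)  # evaluate the expression in the user's namespace
        print(jax.tree_util.tree_map(np.shape, value))

    ip.register_magic_function(shapes, magic_kind="line", magic_name="shapes")
    return True


# In a plain Python process this warns and returns False.
registered = load_shapes_magic_sketch()
```

Returning a boolean instead of raising keeps the call harmless in scripts. The warning explains why `%shapes` is unavailable without interrupting execution.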