
Wrappers

Composable model transforms that add structure or change output interpretation.

Note

Key notes:

  • EquinoxModel / EquinoxStructuredModel adapt arbitrary Equinox/JAX callables into Phydrax models by attaching in_size / out_size.
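  • ComplexOutputModel combines real and imaginary parts into complex outputs.
  • MagnitudeDirectionModel scales a (safely normalized) direction field by a scalar magnitude.
  • Sequential chains models so that the output of stage i feeds stage i+1.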

Equinox adapters

Use these wrappers when you already have an equinox.Module (or any JAX callable) and want it to participate in Phydrax's solver/training APIs.

layout="value" (default for EquinoxModel) treats in_size/out_size as the value shape of a single (unbatched) sample. Inputs are flattened to a vector, the wrapped module is called, and outputs are reshaped back to the declared value shape.

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr
import phydrax as phx

key = jr.key(0)

mlp = eqx.nn.MLP(
    in_size=4,
    out_size=6,
    width_size=64,
    depth=2,
    activation=jax.nn.tanh,
    key=key,
)

# Declare value shapes: 2×2 -> 3×2 (both flatten to lengths 4 and 6 internally).
model = phx.nn.EquinoxModel(mlp, in_size=(2, 2), out_size=(3, 2))

x = jnp.zeros((2, 2))
y = model(x, key=key)
assert y.shape == (3, 2)
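The value layout can be pictured with a few lines of plain JAX. This is a sketch of the described behaviour, not Phydrax's implementation: `value_layout_call` is a hypothetical helper, and a fixed linear map stands in for the wrapped MLP so the result is easy to check by hand.

```python
import math

import jax.numpy as jnp


def value_layout_call(fn, x, in_shape, out_shape):
    """Hypothetical helper: flatten -> call -> reshape, as described for layout="value"."""
    assert x.shape == in_shape, f"expected value shape {in_shape}, got {x.shape}"
    flat_in = jnp.reshape(x, (math.prod(in_shape),))  # (2, 2) -> (4,)
    flat_out = fn(flat_in)  # the wrapped callable only ever sees a flat vector
    return jnp.reshape(flat_out, out_shape)  # (6,) -> (3, 2)


# Stand-in for the wrapped MLP: a fixed linear map from R^4 to R^6.
W = jnp.arange(24.0).reshape(6, 4)


def stand_in(v):
    return W @ v


x = jnp.ones((2, 2))
y = value_layout_call(stand_in, x, in_shape=(2, 2), out_shape=(3, 2))
print(y.shape)  # (3, 2)
```

The wrapped callable never needs to know about the 2×2 or 3×2 value shapes; only the flattened lengths (4 and 6) have to match its own sizes.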

layout="passthrough" forwards inputs and outputs unchanged; the wrapper only supplies the in_size/out_size metadata. This is useful when the wrapped module already handles its own input/output shapes.

import equinox as eqx
import jax.numpy as jnp
import jax.random as jr
import phydrax as phx

key = jr.key(0)

drop = eqx.nn.Dropout(p=0.1)
model = phx.nn.EquinoxModel(drop, in_size=4, out_size=4, layout="passthrough")

x = jnp.zeros((4,))
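# Extra keyword arguments (here inference=True) are passed on to Dropout, which then returns x unchanged.
y = model(x, key=key, inference=True)
assert y.shape == (4,)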

For structured inputs (e.g. product domains), use EquinoxStructuredModel. With layout="passthrough" (its default), it forwards tuples unchanged:

import jax.numpy as jnp
import jax.random as jr
import phydrax as phx

key = jr.key(0)

def stack_pair(inp, *, key=None):
    del key
    a, b = inp
    return jnp.stack([a, b])

model = phx.nn.EquinoxStructuredModel(stack_pair, in_size=2, out_size=2, layout="passthrough")
y = model((1.0, 2.0), key=key)
assert y.shape == (2,)

With layout="value", tuple parts are concatenated into a single vector before calling the wrapped module:

import equinox as eqx
import jax.numpy as jnp
import jax.random as jr
import phydrax as phx

key = jr.key(0)

lin = eqx.nn.Linear(in_features=5, out_features=4, key=key)
model = phx.nn.EquinoxStructuredModel(lin, in_size=5, out_size=4, layout="value")

x = (jnp.ones((2,)), jnp.ones((3,)))
y = model(x, key=key)
assert y.shape == (4,)
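The concatenation step can be sketched in plain JAX. `concat_parts` is a hypothetical helper, and flattening each part before joining is an assumption; for the 1-D parts in the example above, both readings give the same vector. The concatenated length has to match in_size (5 here).

```python
import jax.numpy as jnp


def concat_parts(parts):
    """Hypothetical helper: flatten each tuple part and join them into one vector."""
    return jnp.concatenate([jnp.ravel(jnp.asarray(p)) for p in parts])


x = (jnp.ones((2,)), 2.0 * jnp.ones((3,)))
v = concat_parts(x)  # lengths 2 + 3 -> one vector of length 5
print(v.shape)  # (5,)
```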

Note

  • These wrappers are pointwise by default; use jax.vmap for batching.
  • The iter_ argument is accepted for interface compatibility but is not forwarded to the wrapped callable.
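Because the wrappers act on one sample at a time, a batch is handled by mapping over its leading axis. In this sketch `pointwise` is a hypothetical stand-in with the same call shape as the examples above (x positional, key keyword); assuming a Phydrax wrapper is called as model(x, key=key), the same lambda pattern should apply to it.

```python
import jax
import jax.numpy as jnp
import jax.random as jr


def pointwise(x, *, key=None):
    """Stand-in for a pointwise wrapper: maps one sample of shape (4,) to shape (2,)."""
    del key  # unused; kept so the call signature mirrors the wrappers
    return jnp.stack([x.sum(), (x**2).sum()])


key = jr.key(0)
xs = jnp.arange(12.0).reshape(3, 4)  # a batch of 3 samples

# Map over the leading batch axis; the key is closed over and shared by all samples.
ys = jax.vmap(lambda xi: pointwise(xi, key=key))(xs)
print(ys.shape)  # (3, 2)
```

If the wrapped module is stochastic (e.g. dropout in training mode), you would typically draw per-sample keys with jr.split(key, n) and map over them alongside the inputs instead of sharing one key.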

phydrax.nn.EquinoxModel

Adapter for arbitrary Equinox/JAX callables with Phydrax model metadata.

Default layout="value" treats in_size/out_size as value shapes and performs: (flatten value axes) -> (call wrapped module) -> (reshape back to value axes).

Use layout="passthrough" to forward inputs/outputs unchanged.

__init__(module: Any, /, *, in_size: SizeLike, out_size: SizeLike, layout: _Layout = 'value')
__call__(x: Array, /, *, key: Key[Array, ''] = jr.key(0), iter_: Array | None = None, **kwargs: Any) -> Array

phydrax.nn.EquinoxStructuredModel

Equinox/JAX callable adapter that supports structured (tuple) inputs.

__init__(module: Any, /, *, in_size: SizeLike, out_size: SizeLike, layout: _Layout = 'passthrough')
__call__(x: Array | tuple[Array, ...], /, *, key: Key[Array, ''] = jr.key(0), iter_: Array | None = None, **kwargs: Any) -> Array

Model transforms

Sequential is useful for building embedding pipelines, for example RandomFourierFeatureEmbeddings -> MLP, which can then be reused inside separable wrappers.

import jax.random as jr
import phydrax as phx

branch = phx.nn.Sequential(
    (
        phx.nn.RandomFourierFeatureEmbeddings(
            in_size="scalar",
            out_size=64,
            key=jr.key(0),
        ),
        phx.nn.MLP(
            in_size=64,
            out_size=16,
            width_size=64,
            depth=2,
            key=jr.key(1),
        ),
    )
)

phydrax.nn.Sequential

Compose models in sequence.

Given models \((m_1,\dots,m_K)\), this wrapper evaluates

\[ y(x)=m_K(\cdots m_2(m_1(x))\cdots). \]

Adjacent models must have compatible sizes: canonical(prev.out_size) == canonical(next.in_size).

__init__(models: collections.abc.Sequence[phydrax.nn.models.core._base._AbstractBaseModel])
__call__(x: Array | tuple[Array, ...], /, *, key: Key[Array, ''] = jr.key(0)) -> Array

Evaluate the model pipeline.

If tuple input is provided, the first stage must support structured input.
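The composition \(y(x)=m_K(\cdots m_2(m_1(x))\cdots)\) and the adjacent-size check can be mimicked in plain JAX. This is a sketch, not Phydrax's code: `compose` is a hypothetical helper, stages are (function, in_size, out_size) triples with integer sizes, and canonical(...) is approximated by plain equality.

```python
from functools import reduce

import jax.numpy as jnp


def compose(stages):
    """Hypothetical helper: chain (fn, in_size, out_size) stages left to right."""
    for (_, _, prev_out), (_, next_in, _) in zip(stages, stages[1:]):
        if prev_out != next_in:
            raise ValueError(f"size mismatch: {prev_out} -> {next_in}")

    def pipeline(x):
        # y = m_K(... m_2(m_1(x)) ...): feed each stage's output into the next
        return reduce(lambda h, stage: stage[0](h), stages, x)

    return pipeline


stages = [
    (lambda x: jnp.concatenate([jnp.sin(x), jnp.cos(x)]), 1, 2),  # m_1: R^1 -> R^2
    (lambda h: jnp.array([h.sum()]), 2, 1),  # m_2: R^2 -> R^1
]
f = compose(stages)
y = f(jnp.array([0.0]))  # sin(0) + cos(0)
```

In this sketch the size check runs once, when the pipeline is built, so a mismatched chain fails before any data flows through it.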


phydrax.nn.MagnitudeDirectionModel

Combine a magnitude model and a direction model.

Given a scalar magnitude \(m(x)\) and a direction field \(d(x)\), returns

\[ y(x)=m(x)\,\frac{d(x)}{\|d(x)\|}. \]

If direction_model.out_size == "scalar", no normalization is applied.

__init__(magnitude_model: _AbstractBaseModel, direction_model: _AbstractBaseModel)
__call__(x: Array | tuple[Array, ...], /, *, key: Key[Array, ''] = jr.key(0)) -> Array

Evaluate \(y(x)=m(x)\,\frac{d(x)}{\|d(x)\|}\).

Uses a safe normalization \(\|d\|_\text{safe}=\mathop{\text{max}}(\|d\|,1)\) to avoid division by zero. If direction_model.out_size == "scalar", the normalization is skipped.
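The formula with the safe norm can be written out directly. This sketch takes the magnitude and direction values as already evaluated (a float and a vector), and `magnitude_direction` and its `direction_is_scalar` flag are hypothetical names; it follows the floor of 1 exactly as stated above, which means direction vectors shorter than 1 are scaled by \(m\) without being normalized.

```python
import jax.numpy as jnp


def magnitude_direction(m, d, *, direction_is_scalar=False):
    """Sketch of y = m * d / max(||d||, 1); no normalization for a scalar direction."""
    if direction_is_scalar:
        return m * d
    safe_norm = jnp.maximum(jnp.linalg.norm(d), 1.0)  # ||d||_safe = max(||d||, 1)
    return m * d / safe_norm


y = magnitude_direction(2.0, jnp.array([3.0, 4.0]))  # ||d|| = 5: unit direction times 2
y_small = magnitude_direction(2.0, jnp.array([0.3, 0.4]))  # ||d|| = 0.5 < 1: plain m * d
```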


phydrax.nn.ComplexOutputModel

Wrap model(s) and return complex outputs.

Two modes:

  • Single model with real output size \(2k\): split the last axis into \((\Re z,\Im z)\) and return \(z=\Re z + i\,\Im z\).
  • Pair (real_model, imag_model), each with real output size \(k\): return \(z=u_{\text{re}}+i\,u_{\text{im}}\).

If \(k=1\) (or the output size is "scalar"), complex scalars are returned (a trailing unit feature axis is squeezed).

__init__(model_or_models: phydrax.nn.models.core._base._AbstractBaseModel | tuple[phydrax.nn.models.core._base._AbstractBaseModel, phydrax.nn.models.core._base._AbstractBaseModel])
__call__(x: Array | tuple[Array, ...], /, *, key: Key[Array, ''] = jr.key(0)) -> Array

Evaluate the wrapped model(s) and return complex outputs.

  • Single-model mode: if the wrapped model returns \([u_{\text{re}},u_{\text{im}}]\) (feature axis size \(2k\)), this returns \(u_{\text{re}}+i\,u_{\text{im}}\).
  • Two-model mode: returns \(u_{\text{re}}(x)+i\,u_{\text{im}}(x)\).
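Single-model mode can be sketched on a raw output vector. `to_complex` is a hypothetical helper, and it assumes the feature axis is laid out as two contiguous halves \([u_{\text{re}},u_{\text{im}}]\) (first \(k\) entries real, last \(k\) imaginary) rather than interleaved pairs.

```python
import jax.numpy as jnp


def to_complex(u):
    """Single-model mode sketch: split a last axis of size 2k into (Re, Im) halves."""
    two_k = u.shape[-1]
    if two_k % 2 != 0:
        raise ValueError(f"expected an even feature axis, got {two_k}")
    re, im = jnp.split(u, 2, axis=-1)  # assumed layout: first k real, last k imaginary
    z = re + 1j * im
    if z.shape[-1] == 1:  # k = 1: squeeze the trailing unit axis to a complex scalar
        z = jnp.squeeze(z, axis=-1)
    return z


z = to_complex(jnp.array([1.0, 2.0, 3.0, 4.0]))  # k = 2 -> [1+3j, 2+4j]
s = to_complex(jnp.array([0.5, -1.0]))  # k = 1 -> complex scalar 0.5-1j
```

Two-model mode is the same final step with re and im coming from separate models, each producing \(k\) features.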