Wrappers¤
Composable model transforms that add structure or change output interpretation.
Note

Key notes:

- `EquinoxModel`/`EquinoxStructuredModel` adapt arbitrary Equinox/JAX callables into Phydrax models by attaching `in_size`/`out_size`.
- `ComplexOutputModel` packs/unpacks real/imag parts into complex outputs.
- `Sequential` chains models so outputs of stage `i` feed stage `i+1`.
Equinox adapters¤
Use these wrappers when you already have an `equinox.Module` (or any JAX callable) and
want it to participate in Phydrax's solver/training APIs.
`layout="value"` (the default for `EquinoxModel`) treats `in_size`/`out_size` as the value shape
of a single (unbatched) sample. Inputs are flattened to a vector, the wrapped module is called,
and outputs are reshaped back to the declared value shape.
```python
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr

import phydrax as phx

key = jr.key(0)
mlp = eqx.nn.MLP(
    in_size=4,
    out_size=6,
    width_size=64,
    depth=2,
    activation=jax.nn.tanh,
    key=key,
)
# Declare value shapes: 2×2 -> 3×2 (both flatten to lengths 4 and 6 internally).
model = phx.nn.EquinoxModel(mlp, in_size=(2, 2), out_size=(3, 2))
x = jnp.zeros((2, 2))
y = model(x, key=key)
assert y.shape == (3, 2)
```
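Under the hood, the value layout amounts to a flatten/call/reshape round trip. A rough sketch with a toy helper and a toy vector-to-vector map (hypothetical names, not the actual phydrax internals):

```python
import jax.numpy as jnp

# Hypothetical sketch of the "value" layout round trip:
# flatten value axes -> call wrapped module -> reshape to the declared out_size.
def value_layout_call(module, x, out_shape):
    flat = x.reshape(-1)            # (2, 2) value -> length-4 vector
    y = module(flat)                # wrapped module maps vectors to vectors
    return y.reshape(out_shape)     # length-6 vector -> (3, 2) value

doubler = lambda v: jnp.concatenate([v, 0.5 * v[:2]])  # toy length-4 -> length-6 map
y = value_layout_call(doubler, jnp.ones((2, 2)), (3, 2))
assert y.shape == (3, 2)
```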
`layout="passthrough"` forwards inputs and outputs unchanged (the wrapper only supplies metadata).
This is useful when your wrapped module already owns its input/output layout.
```python
import equinox as eqx
import jax.numpy as jnp
import jax.random as jr

import phydrax as phx

key = jr.key(0)
drop = eqx.nn.Dropout(p=0.1)
model = phx.nn.EquinoxModel(drop, in_size=4, out_size=4, layout="passthrough")
x = jnp.zeros((4,))
y = model(x, key=key, inference=True)
```
For structured inputs (e.g. product domains), use `EquinoxStructuredModel`. With
`layout="passthrough"` it forwards tuples unchanged:
```python
import jax.numpy as jnp
import jax.random as jr

import phydrax as phx

key = jr.key(0)

def stack_pair(inp, *, key=None):
    del key
    a, b = inp
    return jnp.stack([a, b])

model = phx.nn.EquinoxStructuredModel(stack_pair, in_size=2, out_size=2, layout="passthrough")
y = model((1.0, 2.0), key=key)
assert y.shape == (2,)
```
With `layout="value"`, tuple parts are concatenated into a single vector before calling the
wrapped module:
```python
import equinox as eqx
import jax.numpy as jnp
import jax.random as jr

import phydrax as phx

key = jr.key(0)
lin = eqx.nn.Linear(in_features=5, out_features=4, key=key)
model = phx.nn.EquinoxStructuredModel(lin, in_size=5, out_size=4, layout="value")
x = (jnp.ones((2,)), jnp.ones((3,)))
y = model(x, key=key)
assert y.shape == (4,)
```
Note

- These wrappers are pointwise by default; use `jax.vmap` for batching.
- `iter_=` is accepted for interface compatibility but is not forwarded to the wrapped callable.
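Since the wrappers are pointwise, batching is a plain `jax.vmap` over a leading sample axis. A minimal sketch with a toy pointwise function standing in for a wrapped model:

```python
import jax
import jax.numpy as jnp

# Toy pointwise "model": maps a single (2, 2) value to a (3, 2) value.
def pointwise(x):
    flat = x.reshape(-1)
    out = jnp.concatenate([flat, 0.5 * flat[:2]])
    return out.reshape(3, 2)

batched = jax.vmap(pointwise)        # vectorize over a leading batch axis
ys = batched(jnp.ones((10, 2, 2)))   # batch of 10 samples
assert ys.shape == (10, 3, 2)
```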
phydrax.nn.EquinoxModel
¤
Adapter for arbitrary Equinox/JAX callables with Phydrax model metadata.
Default `layout="value"` treats `in_size`/`out_size` as value shapes and performs:
(flatten value axes) -> (call wrapped module) -> (reshape back to value axes).
Use `layout="passthrough"` to forward inputs/outputs unchanged.
phydrax.nn.EquinoxStructuredModel
¤
Equinox/JAX callable adapter that supports structured (tuple) inputs.
Model transforms¤
`Sequential` is useful for embedding pipelines, for example
`RandomFourierFeatureEmbeddings -> MLP`, which can then be reused inside separable wrappers.
```python
import jax.random as jr

import phydrax as phx

branch = phx.nn.Sequential(
    (
        phx.nn.RandomFourierFeatureEmbeddings(
            in_size="scalar",
            out_size=64,
            key=jr.key(0),
        ),
        phx.nn.MLP(
            in_size=64,
            out_size=16,
            width_size=64,
            depth=2,
            key=jr.key(1),
        ),
    )
)
```
phydrax.nn.Sequential
¤
Compose models in sequence.
Given models \((m_1,\dots,m_K)\), this wrapper evaluates

\[
y = (m_K \circ \cdots \circ m_1)(x) = m_K\bigl(m_{K-1}(\cdots m_1(x)\cdots)\bigr).
\]
Adjacent models must have compatible sizes:
`canonical(prev.out_size) == canonical(next.in_size)`.
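The chaining itself is just left-to-right function composition; a minimal pure-Python sketch (hypothetical `compose` helper, not the phydrax API):

```python
# Hypothetical sketch of sequential chaining: output of stage i feeds stage i+1.
def compose(*models):
    def call(x):
        for m in models:
            x = m(x)
        return x
    return call

double = lambda v: 2 * v
inc = lambda v: v + 1
f = compose(double, inc)
assert f(3) == 7  # inc(double(3))
```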
phydrax.nn.MagnitudeDirectionModel
¤
Combine a magnitude model and a direction model.
Given a scalar magnitude \(m(x)\) and a direction field \(d(x)\), returns

\[
y(x) = m(x)\,\frac{d(x)}{\|d(x)\|}.
\]

If `direction_model.out_size == "scalar"`, no normalization is applied.
__init__(magnitude_model: _AbstractBaseModel, direction_model: _AbstractBaseModel)
¤
__call__(x: Array | tuple[Array, ...], /, *, key: Key[Array, ''] = jr.key(0)) -> Array
¤
Evaluate \(y(x)=m(x)\,\frac{d(x)}{\|d(x)\|}\).
Uses a safe normalization \(\|d\|_{\text{safe}}=\max(\|d\|,1)\) to avoid
division by zero. If `direction_model.out_size == "scalar"`, the
normalization is skipped.
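The safe normalization can be sketched directly in `jax.numpy` (toy standalone function, not the phydrax implementation):

```python
import jax.numpy as jnp

def magnitude_direction(m, d):
    # Safe norm: max(||d||, 1) guards the division when d is (near) zero.
    norm = jnp.linalg.norm(d, axis=-1, keepdims=True)
    return m * d / jnp.maximum(norm, 1.0)

y = magnitude_direction(2.0, jnp.array([3.0, 4.0]))  # ||d|| = 5
assert jnp.allclose(y, jnp.array([1.2, 1.6]))
```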
phydrax.nn.ComplexOutputModel
¤
Wrap model(s) and return complex outputs.
Two modes:
- Single model with real output size \(2k\): split the last axis into
\((\Re z,\Im z)\) and return \(z=\Re z + i\,\Im z\).
- Pair (real_model, imag_model) with real output size \(k\): return
\(z=u_{\text{re}}+i\,u_{\text{im}}\).
If \(k=1\) (or the output size is "scalar"), returns complex scalars (the trailing unit feature axis is squeezed).
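A sketch of the single-model packing convention, assuming the last axis splits into first-half real and second-half imaginary parts (toy function, not the phydrax implementation):

```python
import jax.numpy as jnp

def pack_complex(u):
    # Assumed layout on the last axis: [Re z, Im z] (first k real, last k imaginary).
    k = u.shape[-1] // 2
    z = u[..., :k] + 1j * u[..., k:]
    return z[..., 0] if k == 1 else z  # squeeze trailing unit feature axis

z = pack_complex(jnp.array([1.0, 2.0, 3.0, 4.0]))
assert jnp.allclose(z, jnp.array([1.0 + 3.0j, 2.0 + 4.0j]))
```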
__init__(model_or_models: phydrax.nn.models.core._base._AbstractBaseModel | tuple[phydrax.nn.models.core._base._AbstractBaseModel, phydrax.nn.models.core._base._AbstractBaseModel])
¤
__call__(x: Array | tuple[Array, ...], /, *, key: Key[Array, ''] = jr.key(0)) -> Array
¤
Evaluate the wrapped model(s) and return complex outputs.
- Single-model mode: if the wrapped model returns \([u_{\text{re}},u_{\text{im}}]\) (feature axis size \(2k\)), this returns \(u_{\text{re}}+i\,u_{\text{im}}\).
- Two-model mode: returns \(u_{\text{re}}(x)+i\,u_{\text{im}}(x)\).