Architectures
Common end-to-end model families (dense, separable, polynomial, and complex-valued).
Note
Key notes:

- `MLP` is a standard feed-forward network with an optional residual connection.
- `KAN` replaces fixed activations with polynomial edge functions.
- `FeynmaNN` builds complex hidden states with a sum-over-paths block.
- `MLP`, `KAN`, `FeynmaNN`, `FNO1d`, and `FNO2d` support `scan=True` to use a scan-over-depth execution path when the topology is compatible; `scan=True` is primarily a compile-time optimization for deeper repeated blocks.
phydrax.nn.MLP
Multi-Layer Perceptron (MLP).
For input \(x\in\mathbb{R}^{d_\text{in}}\) this model applies a sequence of affine maps and nonlinearities. Writing \(h^{(0)}=x\), a depth-\(L\) network is

\[
h^{(\ell)} = \sigma\!\left(W^{(\ell)} h^{(\ell-1)} + b^{(\ell)}\right), \qquad \ell = 1, \dots, L,
\]

where hidden layers use `activation` \(\sigma\), the final `Linear` layer uses the identity nonlinearity, and the output activation \(\phi\) (`final_activation`) is applied outside the last layer:

\[
y = \phi\!\left(W^{(L+1)} h^{(L)} + b^{(L+1)}\right).
\]
If `skip_connection=True` then a residual term is added before \(\phi\):

\[
y = \phi\!\left(W^{(L+1)} h^{(L)} + b^{(L+1)} + P x\right),
\]

where \(P\) is the identity when \(d_\text{in}=d_\text{out}\) and otherwise a learned linear projection.
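The recursion above can be sketched in plain NumPy (the function and argument names here are illustrative stand-ins, not part of the phydrax API):

```python
import numpy as np

def mlp_forward(x, hidden, final, phi=lambda u: u, sigma=np.tanh, P=None):
    """Sketch of the MLP recursion: hidden affine maps with activation sigma,
    a final Linear layer with identity nonlinearity, and phi applied outside.
    `hidden` is a list of (W, b) pairs; `final` is the last (W, b)."""
    h = x
    for W, b in hidden:          # h^(l) = sigma(W h^(l-1) + b)
        h = sigma(W @ h + b)
    W, b = final
    u = W @ h + b                # final Linear layer, identity nonlinearity
    if P is not None:            # skip_connection=True: residual added before phi
        u = u + P @ x
    return phi(u)

rng = np.random.default_rng(0)
d_in, width, d_out = 3, 8, 3
hidden = [(rng.normal(size=(width, d_in)), np.zeros(width)),
          (rng.normal(size=(width, width)), np.zeros(width))]
final = (rng.normal(size=(d_out, width)), np.zeros(d_out))
x = rng.normal(size=d_in)
y = mlp_forward(x, hidden, final, P=np.eye(d_out))  # d_in == d_out, so P = I
```

With \(\phi\) the identity, the residual path simply shifts the pre-activation output by \(Px\).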
__init__(*, in_size: typing.Union[int, collections.abc.Sequence[int], typing.Literal['scalar']], out_size: typing.Union[int, collections.abc.Sequence[int], typing.Literal['scalar']], width_size: int | None = None, depth: int | None = None, hidden_sizes: collections.abc.Sequence[int] | None = None, activation: Callable = tanh, final_activation: collections.abc.Callable | None = None, skip_connection: bool = False, rwf: bool | tuple[float, float] = False, use_bias: bool = True, use_final_bias: bool = True, initializer: str = 'glorot_normal', scan: bool = False, key: Key[Array, ''] = jr.key(0))
Construct an MLP.
You may specify the hidden layout either with `(width_size, depth)` or with
an explicit `hidden_sizes` sequence.
Arguments:
- `in_size`: Input value size: `"scalar"`, `d` (vector), or `(..., ...)` (tensor).
- `out_size`: Output value size: `"scalar"`, `m` (vector), or `(..., ...)` (tensor).
- `width_size`: Uniform hidden width (mutually exclusive with `hidden_sizes`).
- `depth`: Number of hidden layers (mutually exclusive with `hidden_sizes`).
- `hidden_sizes`: Explicit hidden layer sizes.
- `activation`: Hidden-layer activation (callable).
- `final_activation`: Output activation (default: identity).
- `skip_connection`: If `True`, adds a residual connection to the pre-activation output.
- `rwf`: Random Weight Factorization for `Linear` layers; if \((\mu,\sigma)\), initializes \(s\sim\mathcal{N}(\mu,\sigma^2)\).
- `use_bias`: Whether to use biases in hidden `Linear` layers.
- `use_final_bias`: Whether to use a bias in the final `Linear` layer.
- `initializer`: Weight initializer name for `Linear` layers.
- `scan`: If `True`, uses `jax.lax.scan` over repeated hidden layers when their topology is compatible. If not compatible, falls back to the standard loop path.
- `key`: PRNG key.
__call__(x: Array, /, *, key: Key[Array, ''] = jr.key(0)) -> Array
Evaluate the MLP at `x`.
Arguments:
- `x`: Input with trailing value shape implied by `in_size`. Leading axes are free.
- `key`: PRNG key forwarded to layers (most layers are deterministic and ignore it; it is present for API consistency).
Returns:
- Output with trailing value shape implied by `out_size`. If `out_size == "scalar"`, returns a scalar per leading index (no trailing value axis).
phydrax.nn.KAN
Kolmogorov-Arnold Network (KAN) with orthogonal polynomial edge functions.
Stacks `KANLayer` blocks; each edge applies a degree-`degree` orthogonal polynomial
expansion (Chebyshev by default).
Enable `use_tanh=True` to map pre-activations into \([-1,1]\) before evaluating the basis.
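One such layer might look like the following NumPy sketch (the names and the exact coefficient layout are illustrative assumptions, not the phydrax implementation):

```python
import numpy as np

def chebyshev_basis(x, degree):
    """T_0..T_degree evaluated elementwise via the recurrence
    T_{k+1}(x) = 2x T_k(x) - T_{k-1}(x); assumes x lies in [-1, 1]."""
    T = [np.ones_like(x), x]
    for _ in range(degree - 1):
        T.append(2 * x * T[-1] - T[-2])
    return np.stack(T[: degree + 1], axis=-1)     # (..., d_in, degree + 1)

def kan_layer(x, coeffs, use_tanh=True):
    """One KAN layer: each edge (i, j) applies its own polynomial to x_j,
    and edge outputs are summed over inputs.
    coeffs: (d_out, d_in, degree + 1)."""
    if use_tanh:                                  # map pre-activations into [-1, 1]
        x = np.tanh(x)
    basis = chebyshev_basis(x, coeffs.shape[-1] - 1)   # (d_in, degree + 1)
    return np.einsum("oid,id->o", coeffs, basis)

rng = np.random.default_rng(0)
coeffs = rng.normal(size=(2, 3, 6))               # d_out=2, d_in=3, degree=5
y = kan_layer(rng.normal(size=3), coeffs)
```

The `use_tanh` branch is exactly the \([-1,1]\) squashing mentioned above, which keeps the Chebyshev recurrence well-behaved.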
__init__(*, in_size: typing.Union[int, collections.abc.Sequence[int], typing.Literal['scalar']], out_size: typing.Union[int, collections.abc.Sequence[int], typing.Literal['scalar']], width_size: int | None = None, depth: int | None = None, hidden_sizes: collections.abc.Sequence[int] | None = None, degree: int | collections.abc.Sequence[int] = 5, use_tanh: bool = False, scale_mode: typing.Literal['edge', 'input', 'none'] = 'edge', init: typing.Literal['default', 'identity'] = 'default', autoscale: bool = False, final_activation: collections.abc.Callable | None = None, skip_connection: bool = True, use_bias: bool = True, poly: str = 'chebyshev', poly_params: dict | None = None, scan: bool = False, key: Key[Array, ''] = jr.key(0))
__call__(x: Array, /, *, key: Key[Array, ''] = jr.key(0)) -> Array
Evaluate the KAN at `x`.
Applies the stacked `KANLayer`s and an optional residual connection, then
applies `final_activation` to the result.
phydrax.nn.FeynmaNN
Feynman path-integral style network with complex hidden blocks.
This model builds a complex-valued hidden state and updates it with a sum-over-paths block. For a hidden vector \(z\) and \(K\) paths, one block computes

\[
z' = \sum_{k=1}^{K} g_k\, e^{i\alpha_k(z)}\, W_k z,
\]

where \(g=\mathrm{softmax}(\text{logits})\) are learned gates and
\(\alpha(z)\) is produced by a small real action network (scaled by
`phase_scale`).
The nonlinearity is ModReLU:

\[
\mathrm{ModReLU}(z) = \mathrm{ReLU}(|z| + b)\,\frac{z}{|z|}.
\]
Stacking `depth` blocks yields a complex latent representation, which is
mapped to the requested output either by a real readout
(concatenating \(\Re z\) and \(\Im z\)) or by a complex linear projection.
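A minimal NumPy sketch of one such block follows. The per-path linear maps `Ws` and externally supplied `alphas` are simplifying assumptions; in the model itself, the phases come from the action network:

```python
import numpy as np

def modrelu(z, b):
    """ModReLU: rescale the modulus by ReLU(|z| + b), keep the phase."""
    mag = np.abs(z)
    return np.maximum(mag + b, 0.0) * z / np.maximum(mag, 1e-12)

def sum_over_paths_block(z, Ws, logits, alphas, phase_scale=1.0):
    """One sum-over-paths update: K gated, phase-rotated linear maps of z,
    followed by the ModReLU nonlinearity."""
    g = np.exp(logits) / np.exp(logits).sum()      # softmax gates
    z_new = sum(g_k * np.exp(1j * phase_scale * a_k) * (W_k @ z)
                for g_k, a_k, W_k in zip(g, alphas, Ws))
    return modrelu(z_new, b=0.0)

rng = np.random.default_rng(0)
width, K = 4, 3
z = rng.normal(size=width) + 1j * rng.normal(size=width)
Ws = rng.normal(size=(K, width, width))
out = sum_over_paths_block(z, Ws, logits=np.zeros(K),
                           alphas=rng.normal(size=K))
```

Note that ModReLU acts only on the modulus, so the phase information accumulated by the gated paths is preserved through the nonlinearity.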
__init__(*, in_size: typing.Union[int, typing.Literal['scalar']], out_size: typing.Union[int, typing.Literal['scalar']], width_size: int, depth: int, num_paths: int = 4, width_action: int = 32, phase_scale: float = 1.0, final_activation: collections.abc.Callable | None = None, modrelu_bias_init: float = 0.0, learn_gates: bool = True, key: Key[Array, ''] = jr.key(0), rwf: bool | tuple[float, float] = False, keep_output_complex: bool = False, scan: bool = False)
__call__(x: Array, /, *, key: Key[Array, ''] = jr.key(0)) -> Array
Operator networks
Operator-learning models typically consume structured inputs from product domains, e.g. \((\text{data}, x)\) for a dataset factor and a spatial geometry.
Phydrax's `Domain.Model(...)` wrapper will pass dependency arguments to models as either:
- point inputs (dense sampling): coordinate arrays with leading batch axes, or
- coord-separable inputs: tuples of 1D coordinate axes (from
`CoordSeparableBatch`).
Note
Input conventions:

- `DeepONet` expects `(branch_input, coords...)`, where `coords` are either `coord_dim` separate 1D axes (grid mode) or a single array with trailing dimension `coord_dim` (point mode).
- `FNO1d` expects `(grid_values, x_axis)` and requires grid evaluation (`x_axis` must be 1D with length > 1).
- `FNO2d` expects `(grid_values, x_axis, y_axis)` and requires grid evaluation (`x_axis` and `y_axis` 1D with length > 1).
- These operator models are minimal implementations intended as building blocks; for production use you may want additional features (padding/dealiasing, normalization, richer positional encodings, batching utilities, etc.).
phydrax.nn.DeepONet
Minimal Deep Operator Network (DeepONet) for structured inputs.
DeepONet represents an operator \(G\) by a low-rank expansion

\[
G(f)(x) \approx \sum_{k=1}^{p} b_k(f)\, t_k(x),
\]

where the branch network maps the input function/data \(f\) to coefficients \(b(f)\in\mathbb{R}^p\) and the trunk network maps coordinates \(x\) to basis values \(t(x)\in\mathbb{R}^p\) (optionally replicated for vector outputs).
Expects a tuple input `(branch_input, coords...)`. For point inputs, `coords`
may be a single array with trailing dimension `coord_dim`. For coord-separable
grid inputs, `coords` should be `coord_dim` separate 1D axis arrays.
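The low-rank combination itself is a single contraction, sketched here in NumPy with precomputed branch and trunk outputs (names are illustrative):

```python
import numpy as np

def deeponet_eval(branch_out, trunk_out):
    """Low-rank DeepONet combination: G(f)(x) ~ sum_k b_k(f) t_k(x).
    branch_out: (p,) coefficients from the branch net for one input f;
    trunk_out: (..., p) basis values from the trunk net at coordinates x."""
    return trunk_out @ branch_out

rng = np.random.default_rng(0)
p = 8
b = rng.normal(size=p)            # b(f): branch coefficients
t = rng.normal(size=(100, p))     # t(x): trunk basis at 100 query points
u = deeponet_eval(b, t)           # operator output at each query point
```

Because the branch output is computed once per input function, evaluating at many coordinates only costs additional trunk evaluations plus this dot product.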
phydrax.nn.FNO1d
Minimal 1D Fourier Neural Operator for coord-separable grid evaluation.
Input convention (structured tuple):
- `data`: grid values with leading axis length `n` (shape `(n,)` or `(n, c_in)`),
- `x_axis`: 1D coordinate axis of shape `(n,)` (must have `n > 1`).
The axis values are used for sanity checking and to enforce "grid mode" usage.
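The core spectral convolution that an FNO layer applies between pointwise channel mixings can be sketched in NumPy (an illustrative stand-in, not the phydrax implementation):

```python
import numpy as np

def spectral_conv1d(v, weights):
    """FFT along the grid axis, keep the lowest `modes` frequencies,
    multiply by learned complex weights per (mode, in-channel, out-channel),
    and transform back. v: (n, c_in); weights: (modes, c_in, c_out) complex."""
    n = v.shape[0]
    v_hat = np.fft.rfft(v, axis=0)                        # (n//2 + 1, c_in)
    modes = weights.shape[0]
    out_hat = np.zeros((v_hat.shape[0], weights.shape[-1]), dtype=complex)
    out_hat[:modes] = np.einsum("mi,mio->mo", v_hat[:modes], weights)
    return np.fft.irfft(out_hat, n=n, axis=0)             # back to the grid

rng = np.random.default_rng(0)
n, c_in, c_out, modes = 64, 2, 2, 16
v = rng.normal(size=(n, c_in))
w = (rng.normal(size=(modes, c_in, c_out))
     + 1j * rng.normal(size=(modes, c_in, c_out)))
out = spectral_conv1d(v, w)
```

Truncating to the lowest `modes` frequencies is what makes the layer resolution-independent: the learned weights never depend on `n`.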
__init__(*, in_channels: int | Literal['scalar'] = 'scalar', out_channels: int | Literal['scalar'] = 'scalar', width: int = 32, depth: int = 4, modes: int = 16, scan: bool = False, key: Key[Array, ''] = jr.key(0))
__call__(x: Array | tuple[Array, ...], /, *, key: Key[Array, ''] = jr.key(0)) -> Array
phydrax.nn.FNO2d
Minimal 2D Fourier Neural Operator for coord-separable grid evaluation.
Input convention (structured tuple):
- `data`: grid values with leading axes `(n_x, n_y)` (optionally with channels),
- `x_axis`: 1D coordinate axis of shape `(n_x,)` (must have `n_x > 1`),
- `y_axis`: 1D coordinate axis of shape `(n_y,)` (must have `n_y > 1`).
The axis values are used for sanity checking and to enforce "grid mode" usage.
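The 2D spectral convolution at the heart of such a layer keeps a low-frequency block along both axes; since a real 2D FFT halves only the last axis, both signs of the first-axis frequency must be retained. Again an illustrative NumPy sketch, not the phydrax implementation:

```python
import numpy as np

def spectral_conv2d(v, w_pos, w_neg):
    """2D FFT, keep m_x x m_y low modes (both signs of the first-axis
    frequency), multiply by learned weights, inverse FFT.
    v: (n_x, n_y, c_in); w_pos/w_neg: (m_x, m_y, c_in, c_out) complex,
    for the positive and negative first-axis frequency blocks."""
    n_x, n_y = v.shape[:2]
    v_hat = np.fft.rfft2(v, axes=(0, 1))          # (n_x, n_y//2 + 1, c_in)
    m_x, m_y = w_pos.shape[:2]
    out_hat = np.zeros(v_hat.shape[:2] + (w_pos.shape[-1],), dtype=complex)
    out_hat[:m_x, :m_y] = np.einsum("xyi,xyio->xyo", v_hat[:m_x, :m_y], w_pos)
    out_hat[-m_x:, :m_y] = np.einsum("xyi,xyio->xyo", v_hat[-m_x:, :m_y], w_neg)
    return np.fft.irfft2(out_hat, s=(n_x, n_y), axes=(0, 1))

rng = np.random.default_rng(0)
v = rng.normal(size=(32, 32, 1))
shape = (8, 8, 1, 1)
w_pos = rng.normal(size=shape) + 1j * rng.normal(size=shape)
w_neg = rng.normal(size=shape) + 1j * rng.normal(size=shape)
out = spectral_conv2d(v, w_pos, w_neg)
```

As in 1D, the mode truncation decouples the learned weights from the grid resolution `(n_x, n_y)`.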