Architectures¤

Common end-to-end model families (dense, separable, polynomial, and complex-valued).

Note

Key notes:

  • MLP is a standard feed-forward network with optional residual connection.
  • KAN replaces activations with polynomial edge functions.
  • FeynmaNN builds complex hidden states with a sum-over-paths block.
  • MLP, KAN, FeynmaNN, FNO1d, and FNO2d support scan=True to use a scan-over-depth execution path when topology is compatible.
  • scan=True is primarily a compile-time optimization for deeper repeated blocks.

phydrax.nn.MLP ¤

Multi-Layer Perceptron (MLP).

For input \(x\in\mathbb{R}^{d_\text{in}}\) this model applies a sequence of affine maps and nonlinearities. Writing \(h^{(0)}=x\), a depth-\(L\) network is

\[ h^{(k)}=\sigma_k\!\left(W_k h^{(k-1)}+b_k\right),\qquad k=1,\dots,L, \]

where hidden layers use activation, the final Linear layer uses the identity nonlinearity, and the output activation \(\phi\) (final_activation) is applied after the last layer:

\[ y=\phi\!\left(h^{(L)}\right). \]

If skip_connection=True then a residual term is added before \(\phi\):

\[ h^{(L)}\leftarrow h^{(L)} + P x, \]

where \(P\) is the identity when \(d_\text{in}=d_\text{out}\) and otherwise a learned linear projection.
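The equations above can be sketched directly in NumPy. This is an illustration of the math only (plain weight lists standing in for phydrax's Linear layers), not the phydrax implementation:

```python
# Minimal NumPy sketch of the MLP equations above (illustrative, not phydrax).
import numpy as np

def mlp_forward(x, weights, biases, activation=np.tanh, final_activation=None,
                skip_connection=False, P=None):
    """h_k = sigma_k(W_k h_{k-1} + b_k); identity nonlinearity on the last layer."""
    h = x
    for k, (W, b) in enumerate(zip(weights, biases)):
        h = W @ h + b
        if k < len(weights) - 1:        # hidden layers use the activation
            h = activation(h)
    if skip_connection:                 # residual term h_L + P x, added before phi
        h = h + (x if P is None else P @ x)
    return final_activation(h) if final_activation is not None else h

# Tiny example: a 2 -> 3 -> 2 network with an identity residual (d_in == d_out).
rng = np.random.default_rng(0)
Ws = [rng.normal(size=(3, 2)), rng.normal(size=(2, 3))]
bs = [np.zeros(3), np.zeros(2)]
y = mlp_forward(np.ones(2), Ws, bs, skip_connection=True)
```

Since \(d_\text{in}=d_\text{out}\) here, \(P\) is the identity and the residual simply adds the input to the last pre-activation output.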

__init__(*, in_size: typing.Union[int, collections.abc.Sequence[int], typing.Literal['scalar']], out_size: typing.Union[int, collections.abc.Sequence[int], typing.Literal['scalar']], width_size: int | None = None, depth: int | None = None, hidden_sizes: collections.abc.Sequence[int] | None = None, activation: Callable = tanh, final_activation: collections.abc.Callable | None = None, skip_connection: bool = False, rwf: bool | tuple[float, float] = False, use_bias: bool = True, use_final_bias: bool = True, initializer: str = 'glorot_normal', scan: bool = False, key: Key[Array, ''] = jr.key(0)) ¤

Construct an MLP.

You may specify the hidden layout either with (width_size, depth) or with an explicit hidden_sizes sequence.

Arguments:

  • in_size: Input value size: "scalar", d (vector), or (..., ...) (tensor).
  • out_size: Output value size: "scalar", m (vector), or (..., ...) (tensor).
  • width_size: Uniform hidden width (mutually exclusive with hidden_sizes).
  • depth: Number of hidden layers (mutually exclusive with hidden_sizes).
  • hidden_sizes: Explicit hidden layer sizes.
  • activation: Hidden-layer activation (callable).
  • final_activation: Output activation (default: identity).
  • skip_connection: If True, adds a residual connection to the pre-activation output.
  • rwf: Random Weight Factorization for Linear layers; if a tuple \((\mu,\sigma)\) is given, initializes \(s\sim\mathcal{N}(\mu,\sigma^2)\).
  • use_bias: Whether to use biases in hidden Linear layers.
  • use_final_bias: Whether to use a bias in the final Linear layer.
  • initializer: Weight initializer name for Linear layers.
  • scan: If True, uses jax.lax.scan over repeated hidden layers when their topology is compatible. If not compatible, falls back to the standard loop path.
  • key: PRNG key.
__call__(x: Array, /, *, key: Key[Array, ''] = jr.key(0)) -> Array ¤

Evaluate the MLP at x.

Arguments:

  • x: Input with trailing value shape implied by in_size. Leading axes are free.
  • key: PRNG key forwarded to layers (most layers are deterministic and ignore it; it is present for API consistency).

Returns:

  • Output with trailing value shape implied by out_size. If out_size == "scalar", returns a scalar per leading index (no trailing value axis).

phydrax.nn.KAN ¤

Kolmogorov-Arnold Network (KAN) with orthogonal polynomial edge functions.

Stacks KANLayer blocks; each edge uses an orthogonal polynomial expansion of degree degree (Chebyshev by default).

Enable use_tanh=True to map pre-activations into \([-1,1]\) before evaluating the basis.
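A single edge function can be sketched in NumPy as a learned Chebyshev expansion with the optional tanh squashing. This illustrates the math of one edge only; the actual KANLayer vectorizes this over all input-output edge pairs:

```python
# NumPy sketch of one KAN edge function (illustrative, not the phydrax KANLayer).
import numpy as np

def cheb_basis(t, degree):
    """T_0..T_degree via the recurrence T_{n+1}(t) = 2 t T_n(t) - T_{n-1}(t)."""
    T = [np.ones_like(t), t]
    for _ in range(degree - 1):
        T.append(2.0 * t * T[-1] - T[-2])
    return np.stack(T[: degree + 1])

def kan_edge(x, coeffs, use_tanh=True):
    """One edge: map x into [-1, 1] (if use_tanh), then a weighted basis sum."""
    t = np.tanh(x) if use_tanh else x
    return coeffs @ cheb_basis(t, len(coeffs) - 1)

coeffs = np.array([0.0, 1.0, 0.5])      # degree-2 edge: T_1 + 0.5 T_2
val = kan_edge(np.array(0.3), coeffs)
```

The tanh squashing matters because Chebyshev polynomials are orthogonal (and well behaved) only on \([-1,1]\); without it, pre-activations outside that interval can make the basis grow rapidly.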

__init__(*, in_size: typing.Union[int, collections.abc.Sequence[int], typing.Literal['scalar']], out_size: typing.Union[int, collections.abc.Sequence[int], typing.Literal['scalar']], width_size: int | None = None, depth: int | None = None, hidden_sizes: collections.abc.Sequence[int] | None = None, degree: int | collections.abc.Sequence[int] = 5, use_tanh: bool = False, scale_mode: typing.Literal['edge', 'input', 'none'] = 'edge', init: typing.Literal['default', 'identity'] = 'default', autoscale: bool = False, final_activation: collections.abc.Callable | None = None, skip_connection: bool = True, use_bias: bool = True, poly: str = 'chebyshev', poly_params: dict | None = None, scan: bool = False, key: Key[Array, ''] = jr.key(0)) ¤
__call__(x: Array, /, *, key: Key[Array, ''] = jr.key(0)) -> Array ¤

Evaluate the KAN at x.

Applies the stacked KANLayers and an optional residual connection, then applies final_activation to the result.


phydrax.nn.FeynmaNN ¤

Feynman path-integral style network with complex hidden blocks.

This model builds a complex-valued hidden state and updates it with a sum-over-paths block. For a hidden vector \(z\) and \(K\) paths, one block computes

\[ \text{Block}(z) = \sum_{k=1}^{K} g_k\;e^{i\,\alpha_k(z)}\,(W_k z + b_k), \]

where \(g=\text{softmax}(\text{logits})\) are learned gates and \(\alpha(z)\) is produced by a small real action network (scaled by phase_scale).

The nonlinearity is ModReLU:

\[ \text{ModReLU}(z)=\max\!\left(|z|+b,\,0\right)\frac{z}{|z|+\varepsilon}. \]

Stacking depth blocks yields a complex latent representation which is mapped to the requested output either by a real readout (concatenating \(\Re z\) and \(\Im z\)) or by a complex linear projection.
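One block and the ModReLU nonlinearity can be sketched in NumPy following the formulas above. This is an illustration of the math only; a plain linear map stands in for the small real action network:

```python
# NumPy sketch of one sum-over-paths block and ModReLU (illustrative, not phydrax).
import numpy as np

def softmax(v):
    e = np.exp(v - v.max())
    return e / e.sum()

def modrelu(z, b, eps=1e-6):
    """ModReLU(z) = max(|z| + b, 0) * z / (|z| + eps)."""
    mag = np.abs(z)
    return np.maximum(mag + b, 0.0) * z / (mag + eps)

def sum_over_paths(z, Ws, bs, logits, action, phase_scale=1.0):
    """Block(z) = sum_k g_k * exp(i * alpha_k(z)) * (W_k z + b_k)."""
    g = softmax(logits)                 # learned gates
    alpha = phase_scale * action(z)     # one real phase per path
    out = np.zeros_like(z)
    for k in range(len(Ws)):
        out = out + g[k] * np.exp(1j * alpha[k]) * (Ws[k] @ z + bs[k])
    return out

# Toy example: width 2, K = 2 paths, a fixed linear "action" producing the phases.
rng = np.random.default_rng(1)
z0 = rng.normal(size=2) + 1j * rng.normal(size=2)
Ws = [rng.normal(size=(2, 2)) for _ in range(2)]
bs = [np.zeros(2), np.zeros(2)]
A = rng.normal(size=(2, 2))
block_out = sum_over_paths(z0, Ws, bs, logits=np.zeros(2),
                           action=lambda z: A @ np.real(z))
h = modrelu(block_out, b=0.1)
```

Note how ModReLU shrinks the magnitude \(|z|\) by the learned bias while preserving the phase \(z/|z|\), which is what makes it a natural ReLU analogue for complex states.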

__init__(*, in_size: typing.Union[int, typing.Literal['scalar']], out_size: typing.Union[int, typing.Literal['scalar']], width_size: int, depth: int, num_paths: int = 4, width_action: int = 32, phase_scale: float = 1.0, final_activation: collections.abc.Callable | None = None, modrelu_bias_init: float = 0.0, learn_gates: bool = True, key: Key[Array, ''] = jr.key(0), rwf: bool | tuple[float, float] = False, keep_output_complex: bool = False, scan: bool = False) ¤
__call__(x: Array, /, *, key: Key[Array, ''] = jr.key(0)) -> Array ¤

Operator networks¤

Operator-learning models typically consume structured inputs from product domains, e.g. \((\text{data}, x)\) for a dataset factor and a spatial geometry.

Phydrax's Domain.Model(...) wrapper will pass dependency arguments to models as either:

  • point inputs (dense sampling): coordinate arrays with leading batch axes, or
  • coord-separable inputs: tuples of 1D coordinate axes (from CoordSeparableBatch).

Note

Input conventions:

  • DeepONet expects (branch_input, coords...), where coords are either coord_dim separate 1D axes (grid mode) or a single array with trailing dimension coord_dim (point mode).
  • FNO1d expects (grid_values, x_axis) and requires grid evaluation (x_axis must be 1D with length > 1).
  • FNO2d expects (grid_values, x_axis, y_axis) and requires grid evaluation (x_axis, y_axis 1D with length > 1).
  • These operator models are minimal implementations intended as building blocks; for production use you may want additional features (padding/dealiasing, normalization, richer positional encodings, batching utilities, etc.).
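The input conventions above can be made concrete with plain NumPy arrays (the sizes here are made up for illustration; these are shapes only, not phydrax calls):

```python
# Illustrative shapes for the structured-input conventions (sizes are arbitrary).
import numpy as np

rng = np.random.default_rng(0)
branch_input = rng.normal(size=(100,))       # sampled input function f

# DeepONet, point mode: a single coords array with trailing dimension coord_dim = 2.
point_coords = rng.normal(size=(50, 2))

# DeepONet grid mode / FNO2d: coord_dim separate 1D coordinate axes.
x_axis = np.linspace(0.0, 1.0, 16)
y_axis = np.linspace(0.0, 1.0, 24)
fno2d_input = (rng.normal(size=(16, 24)), x_axis, y_axis)    # (grid_values, x, y)
```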

phydrax.nn.DeepONet ¤

Minimal Deep Operator Network (DeepONet) for structured inputs.

DeepONet represents an operator \(G\) by a low-rank expansion

\[ (Gf)(x)\;\approx\;\sum_{k=1}^p b_k(f)\,t_k(x), \]

where the branch network maps the input function/data \(f\) to coefficients \(b(f)\in\mathbb{R}^p\) and the trunk network maps coordinates \(x\) to basis values \(t(x)\in\mathbb{R}^p\) (optionally replicated for vector outputs).

Expects a tuple input (branch_input, coords...). For point inputs, coords may be a single array with trailing dimension coord_dim. For coord-separable grid inputs, coords should be coord_dim separate 1D axis arrays.
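The low-rank expansion itself is a simple contraction, sketched here in NumPy with plain functions standing in for the branch and trunk networks (illustrative only, not the phydrax DeepONet):

```python
# NumPy sketch of (Gf)(x) ~= sum_k b_k(f) t_k(x), point mode (illustrative).
import numpy as np

def deeponet_eval(branch, trunk, f_samples, coords):
    """Contract branch coefficients with trunk basis values at each query point."""
    b = branch(f_samples)        # (p,) coefficients from the input function
    t = trunk(coords)            # (n, p) basis values at n query points
    return t @ b                 # (n,) one scalar output per point

# Toy example with p = 3 latent features and a polynomial "trunk".
branch = lambda f: np.array([f.mean(), f.max(), f.min()])     # (p,)
trunk = lambda x: np.stack([x**0, x, x**2], axis=-1)          # (n, p)
f_samples = np.sin(np.linspace(0, np.pi, 10))
y = deeponet_eval(branch, trunk, f_samples, np.array([0.0, 0.5, 1.0]))
```

The key design point is the separation: the branch sees only \(f\), the trunk sees only \(x\), and they meet in a single inner product over the latent dimension \(p\) (latent_size).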

__init__(*, branch: _AbstractBaseModel, trunk: _AbstractBaseModel, coord_dim: int, latent_size: int, out_size: int | Literal['scalar'] = 'scalar', in_size: int | Literal['scalar'] = 'scalar') ¤
__call__(x: Array | tuple[Array, ...], /, *, key: Key[Array, ''] = jr.key(0)) -> Array ¤

phydrax.nn.FNO1d ¤

Minimal 1D Fourier Neural Operator for coord-separable grid evaluation.

Input convention (structured tuple):

  • data: grid values with leading axis length n (shape (n,) or (n, c_in)),
  • x_axis: 1D coordinate axis of shape (n,) (must have n>1).

The axis values are used for sanity checking and to enforce "grid mode" usage.
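The core operation of an FNO layer, the spectral convolution, can be sketched in NumPy: transform to Fourier space, keep the lowest modes frequencies, multiply by learned complex weights, and transform back. This illustrates a single channel only; the real layer also mixes channels and adds a pointwise linear path:

```python
# NumPy sketch of a 1D spectral convolution (illustrative, not the phydrax FNO1d).
import numpy as np

def spectral_conv_1d(v, weights):
    """v: (n,) real grid values; weights: (modes,) complex mode multipliers."""
    n = v.shape[0]
    v_hat = np.fft.rfft(v)               # (n//2 + 1,) Fourier coefficients
    out_hat = np.zeros_like(v_hat)
    m = len(weights)
    out_hat[:m] = v_hat[:m] * weights    # truncate to the lowest `modes` modes
    return np.fft.irfft(out_hat, n=n)    # back to the grid, length n

n, modes = 64, 8
x_axis = np.linspace(0.0, 1.0, n, endpoint=False)
v = np.sin(2 * np.pi * x_axis)
out = spectral_conv_1d(v, np.ones(modes, dtype=complex))
```

With identity weights and a band-limited input like this sine, the layer reproduces the input, which makes the truncation behavior easy to check.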

__init__(*, in_channels: int | Literal['scalar'] = 'scalar', out_channels: int | Literal['scalar'] = 'scalar', width: int = 32, depth: int = 4, modes: int = 16, scan: bool = False, key: Key[Array, ''] = jr.key(0)) ¤
__call__(x: Array | tuple[Array, ...], /, *, key: Key[Array, ''] = jr.key(0)) -> Array ¤

phydrax.nn.FNO2d ¤

Minimal 2D Fourier Neural Operator for coord-separable grid evaluation.

Input convention (structured tuple):

  • data: grid values with leading axes (n_x, n_y) (optionally with channels),
  • x_axis: 1D coordinate axis of shape (n_x,) (must have n_x>1),
  • y_axis: 1D coordinate axis of shape (n_y,) (must have n_y>1).

The axis values are used for sanity checking and to enforce "grid mode" usage.
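The 2D analogue of the spectral convolution keeps a low-frequency corner in both axes (the documented modes and modes_y). A NumPy sketch with identity weights (illustrative only, not the phydrax FNO2d):

```python
# NumPy sketch of a 2D spectral truncation (illustrative, identity mode weights).
import numpy as np

def spectral_conv_2d(v, modes_x, modes_y):
    v_hat = np.fft.rfft2(v)                                  # (n_x, n_y//2 + 1)
    kept = np.zeros_like(v_hat)
    kept[:modes_x, :modes_y] = v_hat[:modes_x, :modes_y]     # low positive k_x
    kept[-modes_x:, :modes_y] = v_hat[-modes_x:, :modes_y]   # low negative k_x
    return np.fft.irfft2(kept, s=v.shape)

nx, ny = 32, 32
X, Y = np.meshgrid(np.linspace(0, 1, nx, endpoint=False),
                   np.linspace(0, 1, ny, endpoint=False), indexing="ij")
v = np.cos(2 * np.pi * X) * np.sin(2 * np.pi * Y)
out = spectral_conv_2d(v, modes_x=4, modes_y=4)
```

Because rfft2 halves only the last axis, negative \(k_x\) frequencies live at the end of the first axis, hence the two corner slices.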

__init__(*, in_channels: int | Literal['scalar'] = 'scalar', out_channels: int | Literal['scalar'] = 'scalar', width: int = 32, depth: int = 4, modes: int = 12, modes_y: int | None = None, scan: bool = False, key: Key[Array, ''] = jr.key(0)) ¤
__call__(x: Array | tuple[Array, ...], /, *, key: Key[Array, ''] = jr.key(0)) -> Array ¤