
Structured models ¤

Models that exploit product-domain structure via low-rank factorization.

Note

Key notes:

  • Separable implements a CP-style expansion \(u=\sum_\ell\prod_i g_i^\ell\).
  • LatentContractionModel generalizes this to named factor models and flexible inputs.
  • LatentExecutionPolicy controls grouped-vs-flat planning preferences and fallback behavior. Supported topology modes are grouped, flat, best_effort_flat, and strict_flat.
  • LatentContractionModel supports layout hints auto, dense_points, coord_separable, hybrid, and full_tensor.
  • Any automatic fallback can be configured to warn, error, or stay silent.
  • For LatentContractionModel, partial_n / dt_n / laplacian can use an exact latent-factor derivative contraction path under backend="jet"; if that path is unavailable, execution falls back according to LatentExecutionPolicy.fallback.
  • SeparableMLP, SeparableKAN, and SeparableFeynmaNN forward the scan flag to their internal scalar submodels.

phydrax.nn.Separable ¤

Separable wrapper using pre-initialized scalar models per coordinate.

Each coordinate model maps a scalar \(x_i\) to latent_size * out_size features (reshaped to \((L,m)\)). The wrapper multiplies per-coordinate features elementwise and sums over the latent axis:

\[ u_o(x)=\sum_{\ell=1}^{L}\prod_{i=1}^{d} g_{i,\ell,o}(x_i). \]

Supports regular array inputs and separable tuple inputs (a tuple of 1D coordinate arrays).
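
The contraction itself is compact. Here is a minimal standalone JAX sketch (independent of phydrax) with stand-in feature arrays in place of the coordinate models:

```python
import jax.numpy as jnp

# Stand-in per-coordinate features: d coordinates, rank L, output size m.
# In the wrapper these would be g_i(x_i), reshaped from latent_size * out_size.
d, L, m = 3, 8, 2
feats = [jnp.ones((L, m)) for _ in range(d)]

G = jnp.stack(feats)                 # (d, L, m)
u = jnp.prod(G, axis=0).sum(axis=0)  # u_o = sum_l prod_i g_{i,l,o}; shape (m,)
```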

__init__(*, in_size: typing.Union[int, typing.Literal['scalar']], out_size: typing.Union[int, typing.Literal['scalar']], latent_size: int, models: collections.abc.Sequence[phydrax.nn.models.core._base._AbstractBaseModel], output_activation: collections.abc.Callable | None = None, keep_outputs_complex: bool = False, split_input: int | None = None, scan: bool = False, key: Key[Array, ''] = jr.key(0)) ¤

Create a separable wrapper.

Keyword arguments:

  • in_size: Input dimension \(d\) (or "scalar").
  • out_size: Output size \(m\) (or "scalar").
  • latent_size: Rank \(L\) in the separable expansion.
  • models: Sequence of scalar-input models, one per coordinate (and per split_input clone), each returning latent_size * out_size features.
  • output_activation: Optional activation applied after the contraction (wrap it yourself if you want adaptive behavior).
  • keep_outputs_complex: If True, keeps complex outputs when the coordinate models are complex-valued; otherwise returns the real part.
  • split_input: If provided and in_size="scalar", replicates the scalar input across split_input coordinate models.
  • scan: If True, uses jax.lax.scan for compatible regular and clone-group execution paths; otherwise falls back to loops.

__call__(x: Array | tuple[Array, ...], /, *, key: Key[Array, ''] = jr.key(0), **kwargs: typing.Any) -> Array ¤

Evaluate the separable model.

For vector inputs \(x=(x_1,\dots,x_d)\) this computes

\[ u_o(x)=\sum_{\ell=1}^{L}\prod_{i=1}^{d} g_{i,\ell,o}(x_i), \]

where \(g_i\) are the per-coordinate scalar models (including any split_input replication).

Inputs:

  • x: either a single point of shape (d,) (or a scalar () in the replicated scalar-input case), or a separable tuple (x_1,...,x_d) of 1D coordinate arrays.
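
The separable tuple form is what makes this structure cheap on grids: the model needs only one batch of scalar evaluations per axis. A standalone JAX sketch of the resulting grid contraction for \(d=2\), again with stand-in feature arrays:

```python
import jax.numpy as jnp

# G1: g_1 evaluated at n1 grid points; G2: g_2 at n2 grid points.
n1, n2, L, m = 64, 32, 8, 2
G1 = jnp.ones((n1, L, m))
G2 = jnp.ones((n2, L, m))

# u[a, b, o] = sum_l G1[a, l, o] * G2[b, l, o]: the full (n1, n2, m) grid
# from only n1 + n2 coordinate-model evaluations.
u = jnp.einsum("alo,blo->abo", G1, G2)
```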

phydrax.nn.SeparableMLP ¤

Separable MLP over coordinate-wise scalar submodels.

This builds one scalar-input MLP per coordinate (and per split_input clone), then wraps them in phydrax.nn.Separable to form a low-rank separable approximation. With latent size \(L\) and output size \(m\), the resulting model has the form

\[ u_o(x)=\sum_{\ell=1}^{L}\prod_{i=1}^{d} g_{i,\ell,o}(x_i), \]

where each coordinate model \(g_i\) maps a scalar \(x_i\) to \(L\cdot m\) features (reshaped to \((L,m)\)).

__init__(*, in_size: typing.Union[int, typing.Literal['scalar']], out_size: typing.Union[int, typing.Literal['scalar']], latent_size: int = 32, output_activation: collections.abc.Callable | None = None, keep_outputs_complex: bool = False, split_input: int | None = None, width_size: int | None = 20, depth: int | None = 6, hidden_sizes: collections.abc.Sequence[int] | None = None, activation: Callable = tanh, final_activation: collections.abc.Callable | None = None, skip_connection: bool = False, rwf: bool | tuple[float, float] = False, use_bias: bool = True, use_final_bias: bool = True, initializer: str = 'glorot_normal', scan: bool = False, key: Key[Array, ''] = jr.key(0)) ¤

Create a separable MLP.

SeparableMLP forwards MLP hyperparameters to each internal scalar coordinate model (including scan). The coordinate models output latent_size * out_size features so the wrapper can reshape them to \((L,m)\) and contract.

__call__(x: Array | tuple[Array, ...], /, *, key: Key[Array, ''] = jr.key(0)) -> Array ¤

Evaluate the separable MLP.

Accepts either a vector input (d,) or a separable tuple (x_1,...,x_d) of 1D coordinate arrays (see phydrax.nn.Separable).
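
A usage sketch based on the signature above (hyperparameter values are illustrative, not recommendations):

```python
import jax.numpy as jnp
import jax.random as jr
import phydrax

model = phydrax.nn.SeparableMLP(
    in_size=2,
    out_size="scalar",
    latent_size=16,
    width_size=32,
    depth=4,
    key=jr.key(0),
)

u = model(jnp.array([0.3, 0.7]))  # single point, shape (d,) = (2,)
xs = (jnp.linspace(0.0, 1.0, 64), jnp.linspace(0.0, 1.0, 64))
u_grid = model(xs)                # separable tuple of 1D coordinate arrays
```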


phydrax.nn.SeparableKAN ¤

Separable KAN over coordinate-wise scalar submodels.

This builds one scalar-input KAN per coordinate (and per split_input clone), then wraps them in phydrax.nn.Separable to form a low-rank separable approximation. With latent size \(L\) and output size \(m\), the resulting model has the form

\[ u_o(x)=\sum_{\ell=1}^{L}\prod_{i=1}^{d} g_{i,\ell,o}(x_i), \]

where each coordinate model \(g_i\) maps a scalar \(x_i\) to \(L\cdot m\) features (reshaped to \((L,m)\)).

__init__(*, in_size: typing.Union[int, typing.Literal['scalar']], out_size: typing.Union[int, typing.Literal['scalar']], latent_size: int = 32, output_activation: collections.abc.Callable | None = None, keep_outputs_complex: bool = False, split_input: int | None = None, width_size: int | None = 20, depth: int | None = 6, hidden_sizes: collections.abc.Sequence[int] | None = None, degree: int | collections.abc.Sequence[int] = 5, use_tanh: bool = False, scale_mode: typing.Literal['edge', 'input', 'none'] = 'edge', init: typing.Literal['default', 'identity'] = 'default', autoscale: bool = False, final_activation: collections.abc.Callable | None = None, skip_connection: bool = True, use_bias: bool = True, poly: str = 'chebyshev', poly_params: dict | None = None, scan: bool = False, key: Key[Array, ''] = jr.key(0)) ¤

Create a separable KAN.

SeparableKAN forwards KAN hyperparameters to each internal scalar coordinate model (including scan). The coordinate models output latent_size * out_size features so the wrapper can reshape them to \((L,m)\) and contract.

__call__(x: Array | tuple[Array, ...], /, *, key: Key[Array, ''] = jr.key(0)) -> Array ¤

Evaluate the separable KAN.

Accepts either a vector input (d,) or a separable tuple (x_1,...,x_d) of 1D coordinate arrays (see phydrax.nn.Separable).
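
A usage sketch based on the signature above, showing KAN-specific hyperparameters being forwarded to the scalar submodels (values are illustrative):

```python
import jax.numpy as jnp
import jax.random as jr
import phydrax

model = phydrax.nn.SeparableKAN(
    in_size=3,
    out_size=2,
    latent_size=16,
    degree=5,          # forwarded to each scalar coordinate KAN
    poly="chebyshev",  # likewise
    scan=True,         # scan over the internal scalar submodels
    key=jr.key(0),
)

u = model(jnp.array([0.1, 0.2, 0.3]))  # shape (2,)
```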


phydrax.nn.SeparableFeynmaNN ¤

Separable FeynmaNN over coordinate-wise scalar submodels.

This builds one scalar-input FeynmaNN per coordinate (and per split_input clone), then wraps them in phydrax.nn.Separable to form a low-rank separable approximation. With latent size \(L\) and output size \(m\), the resulting model has the form

\[ u_o(x)=\sum_{\ell=1}^{L}\prod_{i=1}^{d} g_{i,\ell,o}(x_i), \]

where each coordinate model \(g_i\) maps a scalar \(x_i\) to \(L\cdot m\) features (reshaped to \((L,m)\)).

__init__(*, in_size: typing.Union[int, typing.Literal['scalar']], out_size: typing.Union[int, typing.Literal['scalar']], latent_size: int = 32, output_activation: collections.abc.Callable | None = None, keep_outputs_complex: bool = False, split_input: int | None = None, width_size: int = 20, depth: int = 6, num_paths: int = 4, width_action: int = 32, phase_scale: float = 1.0, final_activation: collections.abc.Callable | None = None, modrelu_bias_init: float = 0.0, learn_gates: bool = True, rwf: bool | tuple[float, float] = False, keep_output_complex: bool = False, scan: bool = False, key: Key[Array, ''] = jr.key(0)) ¤

Create a separable FeynmaNN.

SeparableFeynmaNN forwards FeynmaNN hyperparameters to each internal scalar coordinate model (including scan). The coordinate models output latent_size * out_size features so the wrapper can reshape them to \((L,m)\) and contract.

__call__(x: Array | tuple[Array, ...], /, *, key: Key[Array, ''] = jr.key(0)) -> Array ¤

Evaluate the separable FeynmaNN.

Accepts either a vector input (d,) or a separable tuple (x_1,...,x_d) of 1D coordinate arrays (see phydrax.nn.Separable).
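
A usage sketch based on the signature above (values illustrative); with keep_outputs_complex=True the complex contraction result is presumably kept rather than reduced to its real part, mirroring the behavior documented for LatentContractionModel below:

```python
import jax.numpy as jnp
import jax.random as jr
import phydrax

model = phydrax.nn.SeparableFeynmaNN(
    in_size=2,
    out_size="scalar",
    latent_size=16,
    num_paths=4,                # forwarded to each scalar coordinate FeynmaNN
    phase_scale=1.0,            # likewise
    keep_outputs_complex=True,  # keep the complex-valued contraction
    key=jr.key(0),
)

u = model(jnp.array([0.5, 0.5]))
```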


phydrax.nn.LatentContractionModel ¤

Latent contraction wrapper for product-domain factor models.

This implements a low-rank (CP-style) factorization over a product input space. For factors \(x=(x^{(1)},\dots,x^{(d)})\) and latent size \(L\), each factor model returns features that can be reshaped to \(g_i(x^{(i)})\in\mathbb{R}^{L\times m}\) (with \(m=\texttt{out\_size}\)). The contraction returns

\[ u_o(x)=\sum_{\ell=1}^{L}\prod_{i=1}^{d} g_{i,\ell,o}(x^{(i)}). \]

Each factor model may return either:

  • \(L\cdot m\) features (interpreted as \((L,m)\)), or
  • \(L\) features (broadcast across the \(m\) outputs).

When this model is wrapped with domain.Model(...), differential operators (partial_n, dt_n, laplacian) can use an exact latent-factor derivative contraction path. This preserves derivative semantics while avoiding differentiation through the full contracted output graph when possible.

__init__(*, latent_size: int, out_size: typing.Union[int, typing.Literal['scalar']], factors: collections.abc.Mapping[str, phydrax.nn.models.core._base._AbstractBaseModel] | None = None, output_activation: collections.abc.Callable | None = None, keep_outputs_complex: bool = False, execution_policy: phydrax.nn.models.wrappers._separable_wrappers.LatentExecutionPolicy | None = None, scan: bool = False, key: Key[Array, ''] = jr.key(0), **factor_models: _AbstractBaseModel) ¤

Create a latent contraction model.

Keyword arguments:

  • latent_size: Rank \(L\) of the factorization.
  • out_size: Output size \(m\) (or "scalar").
  • factors / **factor_models: Factor models \(g_i\) mapping factor inputs to latent features.
  • output_activation: Optional activation applied after contraction (wrap it yourself if you want adaptive behavior).
  • keep_outputs_complex: If True, keeps complex outputs when the factors are complex-valued; otherwise returns the real part.
  • execution_policy: Optional phydrax.nn.LatentExecutionPolicy controlling planning preferences and fallback behavior (see below).
  • scan: If True, uses jax.lax.scan for aligned factor execution when factor model structure is compatible; otherwise falls back to the loop path.

Each factor model should return either \(L\) features or \(L\cdot m\) features so the wrapper can reshape to \((L,m)\).

__call__(x: Array | tuple[Array, ...] | collections.abc.Mapping[str, typing.Any], /, *, key: Key[Array, ''] = jr.key(0), **kwargs: typing.Any) -> Array ¤

Evaluate the latent contraction model.

Accepts a single array, a separable tuple of factor inputs, or a mapping from factor names to factor inputs.
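
A construction sketch. The factor models here are an assumption: a plain phydrax.nn.MLP is used as a stand-in (any model returning \(L\cdot m\) or \(L\) features per factor input works); the mapping call form is keyed by factor name:

```python
import jax.numpy as jnp
import jax.random as jr
import phydrax

L, m = 8, 3
k1, k2 = jr.split(jr.key(0))
# ASSUMPTION: phydrax.nn.MLP with in_size/out_size/key kwargs; substitute
# any factor model that returns L * m (or L) features.
space_net = phydrax.nn.MLP(in_size=2, out_size=L * m, key=k1)      # -> (L, m)
time_net = phydrax.nn.MLP(in_size="scalar", out_size=L, key=k2)    # broadcast over m

model = phydrax.nn.LatentContractionModel(
    latent_size=L,
    out_size=m,
    space=space_net,  # named factors via **factor_models
    time=time_net,
)

u = model({"space": jnp.array([0.1, 0.2]), "time": jnp.array(0.5)})  # shape (m,)
```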

phydrax.nn.LatentExecutionPolicy ¤

Execution policy for latent contraction product models.

This controls high-level planning preferences and fallback behavior.

__init__(*, topology: typing.Literal['grouped', 'flat', 'best_effort_flat', 'strict_flat'] = 'grouped', layout: typing.Literal['auto', 'dense_points', 'coord_separable', 'hybrid', 'full_tensor'] = 'auto', fallback: typing.Literal['warn', 'error', 'silent'] = 'warn') ¤
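
A construction sketch using only values from the signature above; the policy is passed to LatentContractionModel via its execution_policy argument:

```python
import phydrax

# Prefer a flat plan when the factor structure allows it, let the planner
# pick the layout, and warn (rather than error) on any automatic fallback.
policy = phydrax.nn.LatentExecutionPolicy(
    topology="best_effort_flat",
    layout="auto",
    fallback="warn",
)
```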

phydrax.nn.ConcatenatedModel ¤

Concatenate outputs from multiple models.

Given models \(\{m_k\}_{k=1}^K\) with shared input space, this wrapper returns

\[ y(x)=\operatorname{concat}\big(m_1(x),\dots,m_K(x)\big), \]

where concatenation is performed along the axis argument (by default -1, the feature axis). Scalar outputs (out_size="scalar") are treated as length-1 feature vectors for concatenation.

__init__(models: collections.abc.Sequence[phydrax.nn.models.core._base._AbstractBaseModel], *, axis: int = -1) ¤
__call__(x: Array | tuple[Array, ...], /, *, key: Key[Array, ''] = jr.key(0)) -> Array ¤

Evaluate all child models at x and concatenate their outputs.

If more than one model is present, the input key is split so each child receives its own subkey.
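
A usage sketch built from the SeparableMLP constructor documented above: two models over the same 2D input, concatenated on the feature axis (3 + 1 = 4 output features, with the scalar output promoted to length 1):

```python
import jax.numpy as jnp
import jax.random as jr
import phydrax

k1, k2 = jr.split(jr.key(0))
m1 = phydrax.nn.SeparableMLP(in_size=2, out_size=3, latent_size=8, key=k1)
m2 = phydrax.nn.SeparableMLP(in_size=2, out_size="scalar", latent_size=8, key=k2)

combined = phydrax.nn.ConcatenatedModel([m1, m2], axis=-1)
y = combined(jnp.array([0.1, 0.2]))  # shape (4,)
```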