Structured models¤
Models that exploit product-domain structure via low-rank factorization.
Note

Key notes:

- `Separable` implements a CP-style expansion \(u=\sum_\ell\prod_i g_i^\ell\).
- `LatentContractionModel` generalizes this to named factor models and flexible inputs.
- `LatentExecutionPolicy` controls grouped-vs-flat planning preferences and fallback behavior. Supported topology modes are `grouped`, `flat`, `best_effort_flat`, and `strict_flat`.
- `LatentContractionModel` supports layout hints `auto`, `dense_points`, `coord_separable`, `hybrid`, and `full_tensor`.
- Any automatic fallback can be configured to warn, error, or stay silent.
- For `LatentContractionModel`, `partial_n`/`dt_n`/`laplacian` can use an exact latent-factor derivative contraction path under `backend="jet"`; if that path is unavailable, execution falls back according to `LatentExecutionPolicy.fallback`.
- `SeparableMLP`, `SeparableKAN`, and `SeparableFeynmaNN` forward `scan` to their internal scalar submodels.
phydrax.nn.Separable
¤
Separable wrapper using pre-initialized scalar models per coordinate.
Each coordinate model maps a scalar \(x_i\) to `latent_size * out_size` features
(reshaped to \((L,m)\)). The wrapper multiplies per-coordinate features
elementwise and sums over the latent axis:

\[ u(x) = \sum_{\ell=1}^{L} \prod_{i=1}^{d} g_i^{\ell}(x_i), \qquad g_i^{\ell}(x_i) \in \mathbb{R}^{m}. \]
Supports regular array inputs and separable tuple inputs (a tuple of 1D coordinate arrays).
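As an illustration of the contraction, here is a minimal numpy sketch (not the phydrax implementation; the linear feature maps standing in for the per-coordinate scalar models are made up):

```python
import numpy as np

# Toy setup: d coordinates, rank L, output size m.
d, L, m = 2, 3, 2
rng = np.random.default_rng(0)
weights = [rng.normal(size=(L * m,)) for _ in range(d)]

def g(i, x_i):
    # Hypothetical scalar model g_i: maps a scalar x_i to
    # latent_size * out_size features, reshaped to (L, m).
    return (weights[i] * x_i).reshape(L, m)

x = np.array([0.5, -1.2])  # a single point of shape (d,)
# Elementwise product over coordinates, then sum over the latent axis.
features = np.stack([g(i, x[i]) for i in range(d)])  # (d, L, m)
u = features.prod(axis=0).sum(axis=0)                # (m,)
```

The last two lines are exactly the separable expansion \(u=\sum_\ell\prod_i g_i^\ell(x_i)\), written with array operations.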
__init__(*, in_size: typing.Union[int, typing.Literal['scalar']], out_size: typing.Union[int, typing.Literal['scalar']], latent_size: int, models: collections.abc.Sequence[phydrax.nn.models.core._base._AbstractBaseModel], output_activation: collections.abc.Callable | None = None, keep_outputs_complex: bool = False, split_input: int | None = None, scan: bool = False, key: Key[Array, ''] = jr.key(0))
¤
Create a separable wrapper.
Keyword arguments:
- `in_size`: Input dimension \(d\) (or `"scalar"`).
- `out_size`: Output size \(m\) (or `"scalar"`).
- `latent_size`: Rank \(L\) in the separable expansion.
- `models`: Sequence of scalar-input models, one per coordinate (and per `split_input` clone), each returning `latent_size * out_size` features.
- `output_activation`: Optional activation applied after the contraction (wrap it yourself if you want adaptive behavior).
- `split_input`: If provided and `in_size="scalar"`, replicates the scalar input across `split_input` coordinate models.
- `scan`: If `True`, uses `jax.lax.scan` for compatible regular and clone-group execution paths; otherwise falls back to loops.
__call__(x: Array | tuple[Array, ...], /, *, key: Key[Array, ''] = jr.key(0), **kwargs: typing.Any) -> Array
¤
Evaluate the separable model.
For vector inputs \(x=(x_1,\dots,x_d)\) this computes

\[ u(x) = \sum_{\ell=1}^{L} \prod_{i=1}^{d} g_i^{\ell}(x_i), \]

where \(g_i\) are the per-coordinate scalar models (including any
split_input replication).
Inputs:
- `x`: either a single point of shape `(d,)` (or scalar `()` in the replicated scalar-input case), or a separable tuple `(x_1, ..., x_d)` of 1D coordinate arrays.
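The payoff of the separable tuple path can be sketched with numpy broadcasting (illustrative only; the toy feature functions are made up and phydrax's actual batching is internal): evaluating on a tuple of 1D coordinate arrays yields the full tensor-product grid without materializing the d-dimensional point cloud.

```python
import numpy as np

L = 3
# Hypothetical per-coordinate latent features for a scalar output
# (m = 1 suppressed): each maps an array of n_i scalars to (n_i, L).
g1 = lambda x: np.stack([x, x**2, np.sin(x)], axis=-1)
g2 = lambda y: np.stack([np.cos(y), y, np.ones_like(y)], axis=-1)

xs = np.linspace(0.0, 1.0, 4)   # 1D coordinate array, n1 = 4
ys = np.linspace(-1.0, 1.0, 5)  # 1D coordinate array, n2 = 5
# Contract over the latent axis to get values on the full grid:
# u[a, b] = sum_ell g1(xs)[a, ell] * g2(ys)[b, ell]
u = np.einsum("al,bl->ab", g1(xs), g2(ys))  # shape (4, 5)
```

Each coordinate model runs on its own 1D array once; the \(4\times 5\) grid emerges from the contraction.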
phydrax.nn.SeparableMLP
¤
Separable MLP over coordinate-wise scalar submodels.
This builds one scalar-input MLP per coordinate (and per `split_input`
clone), then wraps them in `phydrax.nn.Separable` to form a low-rank separable
approximation. With latent size \(L\) and output size \(m\), the resulting model
has the form

\[ u(x) = \sum_{\ell=1}^{L} \prod_{i=1}^{d} g_i^{\ell}(x_i), \]

where each coordinate model \(g_i\) maps a scalar \(x_i\) to \(L\cdot m\) features (reshaped to \((L,m)\)).
__init__(*, in_size: typing.Union[int, typing.Literal['scalar']], out_size: typing.Union[int, typing.Literal['scalar']], latent_size: int = 32, output_activation: collections.abc.Callable | None = None, keep_outputs_complex: bool = False, split_input: int | None = None, width_size: int | None = 20, depth: int | None = 6, hidden_sizes: collections.abc.Sequence[int] | None = None, activation: Callable = tanh, final_activation: collections.abc.Callable | None = None, skip_connection: bool = False, rwf: bool | tuple[float, float] = False, use_bias: bool = True, use_final_bias: bool = True, initializer: str = 'glorot_normal', scan: bool = False, key: Key[Array, ''] = jr.key(0))
¤
Create a separable MLP.
`SeparableMLP` forwards MLP hyperparameters to each internal scalar
coordinate model (including `scan`). The coordinate models output
`latent_size * out_size` features so the wrapper can reshape them to
\((L,m)\) and contract.
__call__(x: Array | tuple[Array, ...], /, *, key: Key[Array, ''] = jr.key(0)) -> Array
¤
Evaluate the separable MLP.
Accepts either a vector input `(d,)` or a separable tuple `(x_1, ..., x_d)`
of 1D coordinate arrays (see `phydrax.nn.Separable`).
phydrax.nn.SeparableKAN
¤
Separable KAN over coordinate-wise scalar submodels.
This builds one scalar-input KAN per coordinate (and per `split_input`
clone), then wraps them in `phydrax.nn.Separable` to form a low-rank separable
approximation. With latent size \(L\) and output size \(m\), the resulting model
has the form

\[ u(x) = \sum_{\ell=1}^{L} \prod_{i=1}^{d} g_i^{\ell}(x_i), \]

where each coordinate model \(g_i\) maps a scalar \(x_i\) to \(L\cdot m\) features
(reshaped to \((L,m)\)).
__init__(*, in_size: typing.Union[int, typing.Literal['scalar']], out_size: typing.Union[int, typing.Literal['scalar']], latent_size: int = 32, output_activation: collections.abc.Callable | None = None, keep_outputs_complex: bool = False, split_input: int | None = None, width_size: int | None = 20, depth: int | None = 6, hidden_sizes: collections.abc.Sequence[int] | None = None, degree: int | collections.abc.Sequence[int] = 5, use_tanh: bool = False, scale_mode: typing.Literal['edge', 'input', 'none'] = 'edge', init: typing.Literal['default', 'identity'] = 'default', autoscale: bool = False, final_activation: collections.abc.Callable | None = None, skip_connection: bool = True, use_bias: bool = True, poly: str = 'chebyshev', poly_params: dict | None = None, scan: bool = False, key: Key[Array, ''] = jr.key(0))
¤
Create a separable KAN.
`SeparableKAN` forwards KAN hyperparameters to each internal scalar
coordinate model (including `scan`). The coordinate models output
`latent_size * out_size` features so the wrapper can reshape them to
\((L,m)\) and contract.
__call__(x: Array | tuple[Array, ...], /, *, key: Key[Array, ''] = jr.key(0)) -> Array
¤
Evaluate the separable KAN.
Accepts either a vector input `(d,)` or a separable tuple `(x_1, ..., x_d)`
of 1D coordinate arrays (see `phydrax.nn.Separable`).
phydrax.nn.SeparableFeynmaNN
¤
Separable FeynmaNN over coordinate-wise scalar submodels.
This builds one scalar-input FeynmaNN per coordinate (and per `split_input`
clone), then wraps them in `phydrax.nn.Separable` to form a low-rank separable
approximation. With latent size \(L\) and output size \(m\), the resulting model
has the form

\[ u(x) = \sum_{\ell=1}^{L} \prod_{i=1}^{d} g_i^{\ell}(x_i), \]

where each coordinate model \(g_i\) maps a scalar \(x_i\) to \(L\cdot m\) features (reshaped to \((L,m)\)).
__init__(*, in_size: typing.Union[int, typing.Literal['scalar']], out_size: typing.Union[int, typing.Literal['scalar']], latent_size: int = 32, output_activation: collections.abc.Callable | None = None, keep_outputs_complex: bool = False, split_input: int | None = None, width_size: int = 20, depth: int = 6, num_paths: int = 4, width_action: int = 32, phase_scale: float = 1.0, final_activation: collections.abc.Callable | None = None, modrelu_bias_init: float = 0.0, learn_gates: bool = True, rwf: bool | tuple[float, float] = False, keep_output_complex: bool = False, scan: bool = False, key: Key[Array, ''] = jr.key(0))
¤
Create a separable FeynmaNN.
`SeparableFeynmaNN` forwards FeynmaNN hyperparameters to each internal
scalar coordinate model (including `scan`). The coordinate models output
`latent_size * out_size` features so the wrapper can reshape them to
\((L,m)\) and contract.
__call__(x: Array | tuple[Array, ...], /, *, key: Key[Array, ''] = jr.key(0)) -> Array
¤
Evaluate the separable FeynmaNN.
Accepts either a vector input `(d,)` or a separable tuple `(x_1, ..., x_d)`
of 1D coordinate arrays (see `phydrax.nn.Separable`).
phydrax.nn.LatentContractionModel
¤
Latent contraction wrapper for product-domain factor models.
This implements a low-rank (CP-style) factorization over a product input space. For factors \(x=(x^{(1)},\dots,x^{(d)})\) and latent size \(L\), each factor model returns features that can be reshaped to \(g_i(x^{(i)})\in\mathbb{R}^{L\times m}\) (with \(m=\texttt{out\_size}\)). The contraction returns

\[ u(x) = \sum_{\ell=1}^{L} \prod_{i=1}^{d} g_i^{\ell}(x^{(i)}) \in \mathbb{R}^{m}. \]

Each factor model may return either:

- \(L\cdot m\) features (interpreted as \((L,m)\)), or
- \(L\) features (broadcast across the \(m\) outputs).
When this model is wrapped with `domain.Model(...)`, differential operators
(`partial_n`, `dt_n`, `laplacian`) can use an exact latent-factor derivative
contraction path. This preserves derivative semantics while avoiding
differentiation through the full contracted output graph when possible.
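The idea behind this path can be sketched in numpy (an assumed simplification, not phydrax's implementation; the factor functions and their derivatives are made up): because the contraction is multilinear in the factors, a partial derivative with respect to one factor's input differentiates only that factor, and the remaining factors' features are reused unchanged.

```python
import numpy as np

L = 2
# Hypothetical scalar-output factors with known analytic derivatives.
g1  = lambda x: np.array([np.sin(x), x**2])   # (L,) latent features
dg1 = lambda x: np.array([np.cos(x), 2 * x])  # d g1 / dx
g2  = lambda t: np.array([np.exp(-t), t])     # (L,) latent features

x, t = 0.7, 0.3
u = np.sum(g1(x) * g2(t))  # u = sum_ell g1^ell(x) * g2^ell(t)
# Latent-factor derivative: differentiate the x-factor only, then
# contract against the unchanged t-factor features.
du_dx = np.sum(dg1(x) * g2(t))
# Central finite difference of the full contraction, for comparison.
h = 1e-6
fd = (np.sum(g1(x + h) * g2(t)) - np.sum(g1(x - h) * g2(t))) / (2 * h)
```

The two derivative values agree, which is why the factored path can preserve derivative semantics while never differentiating through the contracted output graph.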
__init__(*, latent_size: int, out_size: typing.Union[int, typing.Literal['scalar']], factors: collections.abc.Mapping[str, phydrax.nn.models.core._base._AbstractBaseModel] | None = None, output_activation: collections.abc.Callable | None = None, keep_outputs_complex: bool = False, execution_policy: phydrax.nn.models.wrappers._separable_wrappers.LatentExecutionPolicy | None = None, scan: bool = False, key: Key[Array, ''] = jr.key(0), **factor_models: _AbstractBaseModel)
¤
Create a latent contraction model.
Keyword arguments:
- `latent_size`: Rank \(L\) of the factorization.
- `out_size`: Output size \(m\) (or `"scalar"`).
- `factors` / `**factor_models`: Factor models \(g_i\) mapping factor inputs to latent features.
- `output_activation`: Optional activation applied after contraction (wrap it yourself if you want adaptive behavior).
- `keep_outputs_complex`: If `True`, keeps complex outputs when the factors are complex-valued; otherwise returns the real part.
- `scan`: If `True`, uses `jax.lax.scan` for aligned factor execution when factor model structure is compatible; otherwise falls back to the loop path.
Each factor model should return either \(L\) features or \(L\cdot m\) features so the wrapper can reshape to \((L,m)\).
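The two accepted factor shapes can be sketched in numpy (illustrative only; the factor names and random features are made up): a factor returning \(L\cdot m\) features is reshaped to \((L,m)\), while a factor returning only \(L\) features is broadcast across the \(m\) outputs before the contraction.

```python
import numpy as np

L, m = 3, 2
rng = np.random.default_rng(1)
# A hypothetical "space" factor returns L*m features -> reshape to (L, m).
space_feats = rng.normal(size=(L * m,)).reshape(L, m)
# A hypothetical "time" factor returns only L features -> broadcast over m.
time_feats = rng.normal(size=(L,))[:, None]  # (L, 1), broadcasts to (L, m)
# Contraction: elementwise product of factor features, summed over L.
u = (space_feats * time_feats).sum(axis=0)   # (m,)
```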
__call__(x: Array | tuple[Array, ...] | collections.abc.Mapping[str, typing.Any], /, *, key: Key[Array, ''] = jr.key(0), **kwargs: typing.Any) -> Array
¤
phydrax.nn.LatentExecutionPolicy
¤
Execution policy for latent contraction product models.
This controls high-level planning preferences and fallback behavior.
__init__(*, topology: typing.Literal['grouped', 'flat', 'best_effort_flat', 'strict_flat'] = 'grouped', layout: typing.Literal['auto', 'dense_points', 'coord_separable', 'hybrid', 'full_tensor'] = 'auto', fallback: typing.Literal['warn', 'error', 'silent'] = 'warn')
¤
phydrax.nn.ConcatenatedModel
¤
Concatenate outputs from multiple models.
Given models \(\{m_k\}_{k=1}^K\) with a shared input space, this wrapper returns

\[ y(x) = \operatorname{concat}\big(m_1(x), \dots, m_K(x)\big), \]

where concatenation is performed along `axis` (by default the feature axis).
Scalar outputs (`out_size="scalar"`) are treated as length-1 feature vectors
for concatenation.
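The concatenation rule, with scalar outputs promoted to length-1 feature vectors, can be sketched in numpy (toy functions standing in for the child models):

```python
import numpy as np

x = np.array([1.0, 2.0])
# Toy stand-ins for child models sharing the input space.
m1 = lambda x: x * 2.0    # vector output, shape (2,)
m2 = lambda x: np.sum(x)  # "scalar" output
# Scalar outputs become length-1 feature vectors, then concatenate
# along the feature axis.
outs = [np.atleast_1d(mk(x)) for mk in (m1, m2)]
y = np.concatenate(outs, axis=-1)  # shape (3,)
```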
__init__(models: collections.abc.Sequence[phydrax.nn.models.core._base._AbstractBaseModel], *, axis: int = -1)
¤
__call__(x: Array | tuple[Array, ...], /, *, key: Key[Array, ''] = jr.key(0)) -> Array
¤
Evaluate all child models at x and concatenate their outputs.
If more than one model is present, the input key is split so each child
receives its own subkey.