Functional solver¤

FunctionalSolver is the main entry point for turning a set of fields and constraints into a differentiable objective.

For a conceptual overview (loss evaluation, enforced pipelines, training loop behavior), see Guides → Solvers and training.

Note

  • loss(...) evaluates the total objective at the current parameters.
  • ansatz_functions() returns fields after applying enforced pipelines (if configured).
  • solve(...) updates parameters inside functions using Optax or evosax optimizers.

Typical usage¤

import jax.random as jr
import optax
import phydrax as phx

geom = phx.domain.Interval1d(0.0, 1.0)

model = phx.nn.MLP(in_size=1, out_size="scalar", width_size=16, depth=2, key=jr.key(0))
u = geom.Model("x")(model)

structure = phx.domain.ProductStructure((("x",),))
constraint = phx.constraints.ContinuousPointwiseInteriorConstraint(
    "u",
    geom,
    operator=lambda f: f,
    num_points=128,
    structure=structure,
    reduction="mean",
)

solver = phx.solver.FunctionalSolver(functions={"u": u}, constraints=[constraint])
loss0 = solver.loss(key=jr.key(0))
solver = solver.solve(num_iter=20, optim=optax.adam(1e-3), seed=0)
loss1 = solver.loss(key=jr.key(1))

phydrax.solver.FunctionalSolver ¤

Assemble constraints into a differentiable scalar loss.

A FunctionalSolver holds:

  • a mapping of named fields (as DomainFunctions), e.g. \(u_\theta\);
  • a collection of constraints \(\ell_i\) producing scalar penalties.

The solver loss is the sum

\[ L = \sum_i \ell_i, \]

where any constraint-specific weighting is carried inside each \(\ell_i\).

Optionally, enforced constraint pipelines can be applied to replace the raw fields with ansatz functions that satisfy selected boundary/initial conditions exactly.
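
For intuition, an enforced ansatz on the unit interval might multiply the raw network output by a profile that vanishes on the boundary, so a homogeneous Dirichlet condition holds by construction (a generic sketch, not phydrax's actual pipeline machinery):

def dirichlet_ansatz(x, net):
    # Illustrative only: the factor x * (1 - x) vanishes at x = 0 and
    # x = 1, so u(0) = u(1) = 0 holds exactly regardless of net.
    return x * (1.0 - x) * net(x)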

Evaluation

  • ansatz_functions() applies any enforced pipelines and returns the effective field mapping used by constraints.
  • loss(key=...) splits the provided PRNG key into one subkey per constraint and sums the resulting scalar losses.

Training

solve(...) optimizes the inexact-array leaves inside functions (via Equinox partitioning), and passes an iter_ counter through to constraint losses so that constraints can implement schedules.
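
Conceptually, solve(...) selects the trainable leaves with Equinox filtering (a minimal sketch; the solver.functions attribute access is assumed for illustration):

import equinox as eqx

# Split the field pytree into trainable floating-point array leaves
# and static structure (activations, hyperparameters, ...).
params, static = eqx.partition(solver.functions, eqx.is_inexact_array)
# The optimizer updates `params`; eqx.combine(params, static) rebuilds
# the full field mapping after each step.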

__init__(*, functions: Mapping[str, DomainFunction], constraints: AbstractConstraint | Sequence[AbstractConstraint], constraint_pipelines: EnforcedConstraintPipelines | None = None, constraint_terms: Sequence[SingleFieldEnforcedConstraint | MultiFieldEnforcedConstraint] = (), interior_data_terms: Sequence[EnforcedInteriorData] = (), evolution_var: str = 't', include_identity_remainder: bool = True, boundary_weight_num_reference: int = 500000, boundary_weight_sampler: str = 'latin_hypercube', boundary_weight_key: Key[Array, ''] = jr.key(0)) ¤

Create a functional solver.

Arguments:

  • functions: Mapping {name: DomainFunction} defining the fields.
  • constraints: One or more AbstractConstraint instances.
  • constraint_pipelines: Optional pre-built enforced constraint pipelines. If provided, do not also pass constraint_terms/interior_data_terms.
  • constraint_terms: Enforced constraint terms used to build EnforcedConstraintPipelines (boundary/initial ansätze).
  • interior_data_terms: Enforced interior data sources used to build EnforcedConstraintPipelines.
  • evolution_var: Name of the time-like label used for initial staging (default "t").
  • include_identity_remainder: Boundary blending option for enforced pipelines.
  • boundary_weight_num_reference: Number of reference samples used for boundary blending weights.
  • boundary_weight_sampler: Sampler used to draw boundary blending references.
  • boundary_weight_key: PRNG key used to draw boundary blending references.
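
For example, the documented defaults can be spelled out explicitly (all values below are the defaults listed above):

solver = phx.solver.FunctionalSolver(
    functions={"u": u},
    constraints=[constraint],
    evolution_var="t",                          # time-like label for initial staging
    include_identity_remainder=True,            # boundary blending option
    boundary_weight_num_reference=500_000,      # reference samples for blending weights
    boundary_weight_sampler="latin_hypercube",  # sampler for those references
)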

ansatz_functions() -> frozendict[str, DomainFunction] ¤

Return the current field mapping after applying enforced pipelines (if configured).
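
For example:

fields = solver.ansatz_functions()  # frozendict[str, DomainFunction]
u_hat = fields["u"]                 # field after enforced pipelines (if any)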

__getitem__(var: str) -> DomainFunction ¤

Convenience accessor: return the (ansatz) field named var.
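
For example:

u_hat = solver["u"]  # shorthand for solver.ansatz_functions()["u"]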

loss(*, key: Key[Array, ''] = jr.key(0), **kwargs: Any) -> Array ¤

Evaluate the total loss \(L=\sum_i \ell_i\) over all configured constraints.

This:

  1. applies enforced pipelines (if configured),
  2. splits key into one subkey per constraint,
  3. sums constraint.loss(...) over all constraints.

Any additional keyword arguments are forwarded to each constraint.
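
Conceptually (a hypothetical sketch; the exact constraint.loss signature is an assumption):

import jax.random as jr

def total_loss(solver, constraints, key, **kwargs):
    fields = solver.ansatz_functions()         # 1) apply enforced pipelines
    subkeys = jr.split(key, len(constraints))  # 2) one subkey per constraint
    # 3) sum the scalar constraint losses, forwarding extra keyword arguments.
    return sum(c.loss(fields, key=k, **kwargs) for c, k in zip(constraints, subkeys))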

solve(*, num_iter: int, optim: optax.GradientTransformation | optax.GradientTransformationExtraArgs | Any | None = None, seed: int = 0, jit: bool = True, keep_best: bool = True, log_every: int = 1, log_constraints: bool = True, log_path: str | Path | None = None) -> FunctionalSolver ¤

Run the training loop and return an updated solver.

The optimization updates the inexact-array leaves of self.functions.

  • If optim is an Optax GradientTransformation, a standard gradient step is used.
  • If optim is an Optax GradientTransformationExtraArgs, a line-search style update is used.
  • Otherwise, optim is treated as an evosax algorithm instance.
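
For example (optax.lbfgs returns a GradientTransformationExtraArgs in recent Optax releases):

# First-order gradient steps (GradientTransformation):
solver = solver.solve(num_iter=200, optim=optax.adam(1e-3))

# Line-search style updates (GradientTransformationExtraArgs):
solver = solver.solve(num_iter=50, optim=optax.lbfgs())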

During training, each constraint loss receives an iter_ keyword argument (the 1-based iteration index as a JAX scalar) to enable schedules.
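
As an illustration, a constraint could turn iter_ into a ramped penalty weight (a hypothetical schedule, not part of phydrax's API):

import jax.numpy as jnp

def ramp_weight(iter_, ramp_iters=1000):
    # Increase the weight linearly from 0 to 1 over the first
    # `ramp_iters` iterations, then hold it at 1.
    return jnp.minimum(iter_ / ramp_iters, 1.0)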

Logging:

  • If log_every > 0, prints a progress line every log_every iterations.
  • If log_constraints=True, also prints the per-constraint loss breakdown.
  • If log_path is provided, logs are written to that file instead of stdout.
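
For example:

solver = solver.solve(
    num_iter=500,
    optim=optax.adam(1e-3),
    log_every=10,          # progress line every 10 iterations
    log_constraints=True,  # include the per-constraint breakdown
    log_path="train.log",  # write to this file instead of stdout
)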