Solvers and training
This guide explains how FunctionalSolver evaluates losses and how solve() updates parameters.
What FunctionalSolver does
A FunctionalSolver is a lightweight orchestrator that holds:
- functions: a mapping {name: DomainFunction} of the current fields,
- constraints: a list/tuple of constraint objects, each producing a scalar loss,
- constraint_pipelines (optional): enforced-constraint pipelines that replace raw fields with ansatz functions satisfying selected conditions exactly.
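The total objective is the sum of the constraint losses:

$$\mathcal{L}_{\text{total}} = \sum_{i=1}^{N} \mathcal{L}_i,$$

where $N$ is the number of constraints and $\mathcal{L}_i$ is the scalar loss produced by constraint $i$.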
Loss evaluation (loss(...))
When you call solver.loss(key=...):
1) If enforced pipelines are configured, the current functions mapping is transformed into ansatz functions via solver.ansatz_functions().
2) The provided PRNG key is split into one subkey per constraint.
3) Each constraint loss is evaluated and summed.
Additional keyword arguments are forwarded to each constraint's .loss(...) method.
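Steps 2 and 3 above can be sketched in plain JAX. This is a minimal illustration, not phydrax's implementation: `ToyConstraint` and `total_loss` are hypothetical names, step 1 (enforced pipelines) is omitted, and each toy constraint simply samples points in [0, 1) and returns the mean squared value of one field.

```python
import jax
import jax.numpy as jnp


class ToyConstraint:
    """Hypothetical constraint: samples x in [0, 1) and returns the mean
    squared value of the field `name` at those points."""

    def __init__(self, name, num_points=64):
        self.name = name
        self.num_points = num_points

    def loss(self, functions, *, key, **kwargs):
        # Extra keyword arguments (e.g. iter_) are accepted but unused here.
        x = jax.random.uniform(key, (self.num_points,))
        return jnp.mean(functions[self.name](x) ** 2)


def total_loss(functions, constraints, *, key, **kwargs):
    """Hypothetical helper mirroring steps 2 and 3 (no enforced pipelines)."""
    # Step 2: one independent subkey per constraint.
    subkeys = jax.random.split(key, len(constraints))
    # Step 3: evaluate each constraint and sum, forwarding extra kwargs.
    return sum(
        c.loss(functions, key=k, **kwargs) for c, k in zip(constraints, subkeys)
    )


functions = {"u": lambda x: 2.0 * x, "v": lambda x: jnp.zeros_like(x)}
constraints = [ToyConstraint("u"), ToyConstraint("v")]
loss = total_loss(
    functions, constraints, key=jax.random.PRNGKey(0), iter_=jnp.asarray(0)
)
```

Because every constraint receives its own subkey, the same input key always reproduces the same total loss, while the constraints still draw independent samples.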
Enforced-constraint pipelines
Enforced pipelines are optional, but common when you want to enforce boundary/initial conditions exactly by construction (rather than penalizing violations).
Pipelines are applied before any soft constraints are evaluated, so all residuals see the post-processed (enforced) fields.
See API → Solver → Enforced constraint pipelines for the pipeline types and constructors.
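To show what "satisfying a condition exactly by construction" means, here is a sketch of the standard lift-plus-vanishing-factor ansatz for Dirichlet conditions on [0, 1]. It is a generic technique chosen for illustration, not one of phydrax's pipeline constructors; `enforce_dirichlet` is a hypothetical helper. Whatever the raw field outputs, the wrapped field hits the boundary values exactly, so soft constraints evaluated on it never see a boundary violation.

```python
import jax.numpy as jnp


def enforce_dirichlet(raw_fn, a, b):
    """Hypothetical helper: wrap raw_fn so that u(0) = a and u(1) = b exactly.

    The lift (1 - x) * a + x * b matches the boundary values, and the factor
    x * (1 - x) vanishes at both endpoints, so raw_fn cannot change them.
    """

    def ansatz(x):
        lift = (1.0 - x) * a + x * b
        return lift + x * (1.0 - x) * raw_fn(x)

    return ansatz


# An arbitrary "untrained" raw field that ignores the boundary values.
raw = lambda x: jnp.sin(3.0 * x) + 5.0
u = enforce_dirichlet(raw, a=1.0, b=-2.0)
```

In the interior the raw field still shapes the solution, so training only has to reduce the remaining (soft) residuals.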
Training (solve(...))
FunctionalSolver.solve(...) runs an optimization loop over the parameters contained inside solver.functions. Under the hood it uses Equinox to split the function PyTree into:
- trainable parameters: inexact arrays (floating/complex arrays),
- static part: everything else.
Optimizer support
optim= can be:
- an Optax GradientTransformation (standard first-order optimizers),
- an Optax GradientTransformationExtraArgs (line-search style optimizers),
- an evosax algorithm instance (evolutionary strategies).
Iteration counter (iter_)
During training, the current epoch index is passed to each constraint's loss as iter_ (a JAX scalar), so constraints can implement schedules (annealing, curriculum weights, etc.).
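As an example of such a schedule, the sketch below ramps a loss weight linearly from 0 to a final value over a warm-up period. `warmup_weight` and `scheduled_loss` are hypothetical names, not phydrax API. Since iter_ arrives as a JAX scalar, the schedule uses `jnp` operations instead of a Python `if`, so it stays traceable when the step is compiled.

```python
import jax
import jax.numpy as jnp


def warmup_weight(iter_, warmup_iters=100, final_weight=10.0):
    """Hypothetical schedule: ramp linearly from 0 to final_weight, then hold."""
    # iter_ is a JAX scalar, so use jnp ops (not Python `if`) to stay traceable.
    frac = jnp.clip(iter_ / warmup_iters, 0.0, 1.0)
    return final_weight * frac


def scheduled_loss(raw_loss, *, iter_):
    # Hypothetical: how a constraint's loss(...) might use iter_.
    return warmup_weight(iter_) * raw_loss


# The schedule also works under jit, where iter_ is an abstract tracer.
jitted_weight = jax.jit(warmup_weight)
```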
jit and keep_best
- If jit=True, the per-step update is JIT-compiled when using standard Optax optimizers. (Line-search optimizers are not JIT-wrapped.)
- If keep_best=True, the returned solver uses the best parameter set observed over all epochs (by objective value); otherwise it returns the final parameters.
Minimal example
```python
import jax.random as jr
import optax
import phydrax as phx

geom = phx.domain.Interval1d(0.0, 1.0)

# Trainable scalar field u_theta(x)
model = phx.nn.MLP(
    in_size=1,
    out_size="scalar",
    width_size=16,
    depth=2,
    key=jr.key(0),
)
u = geom.Model("x")(model)
structure = phx.domain.ProductStructure((("x",),))

# A toy interior objective that encourages u(x) ≈ 0 in Ω (replace with a PDE operator in real use).
constraint = phx.constraints.ContinuousPointwiseInteriorConstraint(
    "u",
    geom,
    operator=lambda f: f,
    num_points=128,
    structure=structure,
    reduction="mean",
)

solver = phx.solver.FunctionalSolver(functions={"u": u}, constraints=[constraint])

# Use the same key before and after training so both evaluations draw the same
# random samples and the comparison reflects only the parameter update.
eval_key = jr.key(1)
loss0 = solver.loss(key=eval_key)
solver = solver.solve(num_iter=20, optim=optax.adam(1e-3), seed=0)
loss1 = solver.loss(key=eval_key)
print(loss0, loss1)
```
EvoSax example (gradient-free)
To use evolutionary strategies, pass an evosax algorithm instance as optim=...:
```python
import equinox as eqx
from evosax import algorithms as evo_algos
import phydrax as phx

# Continuing from the minimal example above:
# solver = phx.solver.FunctionalSolver(functions={"u": u}, constraints=[constraint])

# evosax expects a "solution" PyTree matching the trainable parameter structure.
params, _ = eqx.partition(solver.functions, eqx.is_inexact_array)
algo = evo_algos.Open_ES(population_size=8, solution=params)
solver = solver.solve(num_iter=20, optim=algo, seed=0)
```