Functional solver¤
FunctionalSolver is the main entry point for turning a set of fields and constraints
into a differentiable objective.
For a conceptual overview (loss evaluation, enforced pipelines, training loop behavior), see Guides → Solvers and training.
Note

Key notes:

- `loss(...)` evaluates the total objective at the current parameters.
- `ansatz_functions()` returns fields after applying enforced pipelines (if configured).
- `solve(...)` updates parameters inside `functions` using Optax or evosax optimizers.
Typical usage¤
```python
import jax.random as jr
import optax

import phydrax as phx

geom = phx.domain.Interval1d(0.0, 1.0)
model = phx.nn.MLP(in_size=1, out_size="scalar", width_size=16, depth=2, key=jr.key(0))
u = geom.Model("x")(model)

structure = phx.domain.ProductStructure((("x",),))
constraint = phx.constraints.ContinuousPointwiseInteriorConstraint(
    "u",
    geom,
    operator=lambda f: f,
    num_points=128,
    structure=structure,
    reduction="mean",
)

solver = phx.solver.FunctionalSolver(functions={"u": u}, constraints=[constraint])
loss0 = solver.loss(key=jr.key(0))
solver = solver.solve(num_iter=20, optim=optax.adam(1e-3), seed=0)
loss1 = solver.loss(key=jr.key(1))
```
phydrax.solver.FunctionalSolver¤
Assemble constraints into a differentiable scalar loss.
A `FunctionalSolver` holds:

- a mapping of named fields (as `DomainFunction`s), e.g. \(u_\theta\);
- a collection of constraints \(\ell_i\) producing scalar penalties.

The solver loss is the (weighted) sum \(L = \sum_i \ell_i\).
Optionally, enforced constraint pipelines can be applied to replace the raw fields with ansatz functions that satisfy selected boundary/initial conditions exactly.
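A common construction behind such ansatz functions (a generic sketch for Dirichlet data on the unit interval, not the phydrax implementation; `ansatz` and its arguments are hypothetical names) blends the raw field with a weight that vanishes on the boundary, so the condition holds exactly regardless of the network output:

```python
import jax.numpy as jnp

def ansatz(x, raw_u, boundary_value=0.0):
    # Weight w(x) = x * (1 - x) vanishes at the endpoints x = 0 and x = 1,
    # so the ansatz equals boundary_value there exactly.
    w = x * (1.0 - x)
    return boundary_value + w * raw_u(x)

u = lambda x: jnp.ones_like(x)  # stand-in for a raw network field
x = jnp.array([0.0, 0.5, 1.0])
vals = ansatz(x, u)  # boundary values are exact; interior is blended
```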
Evaluation
- `ansatz_functions()` applies any enforced pipelines and returns the effective field mapping used by constraints.
- `loss(key=...)` splits the provided PRNG key into one subkey per constraint and sums the resulting scalar losses.
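The key-splitting behavior can be sketched in plain JAX (the constraint callables here are hypothetical stand-ins, not phydrax objects):

```python
import jax.random as jr
import jax.numpy as jnp

def total_loss(key, constraint_losses):
    # Split the PRNG key into one subkey per constraint, then sum the scalars.
    subkeys = jr.split(key, len(constraint_losses))
    return sum(loss_fn(k) for loss_fn, k in zip(constraint_losses, subkeys))

# Toy "constraints": each draws its own sample points from its subkey.
c1 = lambda k: jnp.mean(jr.uniform(k, (8,)) ** 2)
c2 = lambda k: jnp.mean(jr.uniform(k, (8,)))
val = total_loss(jr.key(0), [c1, c2])
```

Because each constraint receives its own subkey, the sample points drawn by one constraint are independent of the others.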
Training
`solve(...)` optimizes the inexact-array leaves inside `functions` (via Equinox partitioning), and passes an `iter_` counter through to constraint losses so that constraints can implement schedules.
```python
__init__(*, functions: Mapping[str, DomainFunction], constraints: AbstractConstraint | Sequence[AbstractConstraint], constraint_pipelines: EnforcedConstraintPipelines | None = None, constraint_terms: Sequence[SingleFieldEnforcedConstraint | MultiFieldEnforcedConstraint] = (), interior_data_terms: Sequence[EnforcedInteriorData] = (), evolution_var: str = 't', include_identity_remainder: bool = True, boundary_weight_num_reference: int = 500000, boundary_weight_sampler: str = 'latin_hypercube', boundary_weight_key: Key[Array, ''] = jr.key(0))
```
Create a functional solver.
Arguments:
- `functions`: Mapping `{name: DomainFunction}` defining the fields.
- `constraints`: One or more `AbstractConstraint` instances.
- `constraint_pipelines`: Optional pre-built enforced constraint pipelines. If provided, do not also pass `constraint_terms`/`interior_data_terms`.
- `constraint_terms`: Enforced constraint terms used to build `EnforcedConstraintPipelines` (boundary/initial ansätze).
- `interior_data_terms`: Enforced interior data sources used to build `EnforcedConstraintPipelines`.
- `evolution_var`: Name of the time-like label used for initial staging (default `"t"`).
- `include_identity_remainder`: Boundary blending option for enforced pipelines.
- `boundary_weight_num_reference`: Number of reference samples used for boundary blending weights.
- `boundary_weight_sampler`: Sampler used to draw boundary blending references.
- `boundary_weight_key`: PRNG key used to draw boundary blending references.
```python
ansatz_functions() -> frozendict[str, DomainFunction]
```
Return the current field mapping after applying enforced pipelines (if configured).
```python
__getitem__(var: str) -> DomainFunction
```
Convenience accessor: return the (ansatz) field named `var`.
```python
loss(*, key: Key[Array, ''] = jr.key(0), **kwargs: Any) -> Array
```
Evaluate the total loss \(L=\sum_i \ell_i\) over all configured constraints.
This:

1) applies enforced pipelines (if configured),
2) splits `key` into one subkey per constraint,
3) sums `constraint.loss(...)` over all constraints.
Any additional keyword arguments are forwarded to each constraint.
```python
solve(*, num_iter: int, optim: optax.GradientTransformation | optax.GradientTransformationExtraArgs | Any | None = None, seed: int = 0, jit: bool = True, keep_best: bool = True, log_every: int = 1, log_constraints: bool = True, log_path: str | Path | None = None) -> FunctionalSolver
```
Run the training loop and return an updated solver.
The optimization updates the inexact-array leaves of `self.functions`.
- If `optim` is an Optax `GradientTransformation`, a standard gradient step is used.
- If `optim` is an Optax `GradientTransformationExtraArgs`, a line-search style update is used.
- Otherwise, `optim` is treated as an evosax algorithm instance.
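The gradient-based branch can be sketched as a toy training step in plain JAX (a hypothetical quadratic loss and hand-rolled SGD, standing in for the solver's Optax/Equinox machinery):

```python
import jax
import jax.numpy as jnp

def sgd_step(params, lr=0.1):
    # One gradient step on a toy quadratic loss, mimicking the update of
    # array-valued parameters that solve(...) performs each iteration.
    loss_fn = lambda p: jnp.sum(p ** 2)
    grads = jax.grad(loss_fn)(params)
    return params - lr * grads

p = jnp.array([1.0, -2.0])
for _ in range(50):
    p = sgd_step(p)
# After repeated steps, p contracts toward the minimizer at zero.
```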
During training, each constraint loss receives an `iter_` keyword argument (the 1-based iteration index as a JAX scalar) to enable schedules.
Logging:
- If `log_every > 0`, prints a progress line every `log_every` iterations.
- If `log_constraints=True`, also prints the per-constraint loss breakdown.
- If `log_path` is provided, logs are written to that file instead of stdout.