Status: Active · License: GPL-3.0 · Python ≥ 3.9 · JAX ≥ 0.4.26 · Version: v2605

underPINN-v2605

A modular, GPU-accelerated Physics-Informed Neural Network framework built on JAX + Flax + Optax — from harmonic oscillators to 3-D turbulent Navier-Stokes, all JIT-compiled via XLA.

  • 16 physics examples
  • 5 PDE solvers
  • 12+ CLI runners
  • Up to 500× fewer Python dispatches on GPU (lax.scan)
  • 3-D Navier-Stokes
  • Automatic restart/resume

Everything you need for PINN research

From simple ODEs to full 3-D turbulence — one coherent API, no boilerplate. Automatic GPU memory management, fault-tolerant restart, and YAML-driven experiments out of the box.

Network Architectures

MLP

Standard multi-layer perceptron with tanh activations. Configurable depth and width via a simple layer list — e.g. [2, 64, 64, 64, 1] for a space-time Burgers network.
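For readers new to the layer-list convention, the following pure-JAX sketch shows what a tanh MLP with layers [2, 64, 64, 64, 1] computes. It is not underPINN's MLP class (which is Flax-based); init_mlp and mlp_apply are hypothetical names, and the Glorot-style initialisation is an assumption.

```python
import jax
import jax.numpy as jnp

def init_mlp(key, layers):
    """Create one (W, b) pair per layer transition, e.g. 2 -> 64 -> 64 -> 64 -> 1."""
    params = []
    keys = jax.random.split(key, len(layers) - 1)
    for k, n_in, n_out in zip(keys, layers[:-1], layers[1:]):
        # Glorot-style normal init (assumption); biases start at zero
        w = jax.random.normal(k, (n_in, n_out)) * jnp.sqrt(2.0 / (n_in + n_out))
        params.append((w, jnp.zeros(n_out)))
    return params

def mlp_apply(params, x):
    """tanh on every hidden layer, linear output layer."""
    for w, b in params[:-1]:
        x = jnp.tanh(x @ w + b)
    w, b = params[-1]
    return x @ w + b

params = init_mlp(jax.random.PRNGKey(0), [2, 64, 64, 64, 1])
xt = jax.random.uniform(jax.random.PRNGKey(1), (8, 2))  # 8 (x, t) points
u = mlp_apply(params, xt)                               # shape (8, 1)
```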


FourierMLP

Trainable random Fourier feature embeddings prepended to a standard MLP. Essential for oscillatory solutions — Helmholtz, wave, high-Re flows — where plain MLPs exhibit spectral bias.
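The standard random Fourier feature mapping is γ(x) = [sin(2π xB), cos(2π xB)] with B drawn from a Gaussian. A minimal sketch, assuming that mapping; init_fourier, fourier_embed and the scale sigma are hypothetical, and underPINN's FourierMLP may initialise or train B differently (a trainable variant simply keeps B in the parameter tree).

```python
import jax
import jax.numpy as jnp

def init_fourier(key, in_dim, n_features, sigma=1.0):
    # B ~ N(0, sigma^2); larger sigma exposes higher frequencies to the MLP
    return sigma * jax.random.normal(key, (in_dim, n_features))

def fourier_embed(B, x):
    proj = 2.0 * jnp.pi * x @ B                                     # (N, n_features)
    return jnp.concatenate([jnp.sin(proj), jnp.cos(proj)], axis=-1)  # (N, 2 * n_features)

B = init_fourier(jax.random.PRNGKey(0), in_dim=2, n_features=64, sigma=5.0)
x = jax.random.uniform(jax.random.PRNGKey(1), (16, 2))
z = fourier_embed(B, x)   # (16, 128), then fed into an ordinary MLP
```

Larger sigma lets the downstream MLP represent higher frequencies, which is the lever against spectral bias.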


FBPINN + SimpleGate

Overlapping subdomain decomposition with sigmoid partition-of-unity windows. For complex geometries, each subdomain network can use HybridAttention or SimpleGate gated residual blocks.
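The partition-of-unity idea can be sketched in 1-D: each subdomain gets a window that is close to 1 inside its interval and decays smoothly outside, and the windows are normalised so they sum to 1 everywhere. The subdomain bounds, the sharpness value and the helper names below are assumptions, and the stand-in lambdas replace the per-subdomain networks.

```python
import jax
import jax.numpy as jnp

def sigmoid_window(x, lo, hi, sharpness):
    # ~1 inside [lo, hi], smoothly ~0 outside
    return jax.nn.sigmoid(sharpness * (x - lo)) * jax.nn.sigmoid(sharpness * (hi - x))

def partition_of_unity(x, bounds, sharpness=20.0):
    w = jnp.stack([sigmoid_window(x, lo, hi, sharpness) for lo, hi in bounds], axis=-1)
    return w / jnp.sum(w, axis=-1, keepdims=True)   # rows sum to 1

# Three overlapping subdomains covering [-1, 1]
bounds = [(-1.2, -0.2), (-0.5, 0.5), (0.2, 1.2)]
x = jnp.linspace(-1.0, 1.0, 101)
w = partition_of_unity(x, bounds)                   # (101, 3)

# FBPINN-style blend u(x) = sum_j w_j(x) * NN_j(x); lambdas stand in for subdomain networks
subnets = [lambda s: jnp.sin(s), lambda s: s ** 2, lambda s: jnp.cos(s)]
u = sum(w[:, j] * f(x) for j, f in enumerate(subnets))
```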

Training

lax.scan XLA Fusion

Fuse N gradient steps into a single compiled XLA program, so Python dispatches once per N steps (50–500× fewer dispatches for n_scan_steps between 50 and 500). Works with all solvers; callbacks fire once per outer step, i.e. every N epochs.

Cosine LR Decay

Built on optax.cosine_decay_schedule and passed to solvers through TrainingConfig.lr_schedule. Recommended for runs longer than 2 000 epochs, where it typically improves final accuracy at no extra cost.

RAR-D Adaptive Resampling

Periodically replaces collocation points with samples drawn proportional to |residual|^k (Lu et al., 2021). Focuses compute on high-error regions without changing total batch size.
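The sampling rule can be sketched as: score a candidate pool by |r|^k, normalise to probabilities, and draw a new batch without replacement. In this sketch the pool size of 5 × batch_r mirrors the documented resample_candidates default; rar_d_resample_sketch and the toy residual are hypothetical and are not underPINN's rar_d_resample.

```python
import jax
import jax.numpy as jnp

def rar_d_resample_sketch(key, residual_fn, candidates, n_points, k=1.0):
    """Draw n_points rows of candidates with probability p_i proportional to |r_i|**k."""
    r = jnp.abs(residual_fn(candidates)) ** k
    p = r / jnp.sum(r)
    idx = jax.random.choice(key, candidates.shape[0], shape=(n_points,), replace=False, p=p)
    return candidates[idx]

# Toy residual with a sharp peak at x = 0 (mimics a shock) plus a small floor
def residual_fn(x):
    return jnp.exp(-(x[:, 0] / 0.05) ** 2) + 1e-3

key_c, key_s = jax.random.split(jax.random.PRNGKey(0))
batch_r = 256
candidates = jax.random.uniform(key_c, (5 * batch_r, 1), minval=-1.0, maxval=1.0)
new_batch = rar_d_resample_sketch(key_s, residual_fn, candidates, batch_r, k=1.0)
# new_batch keeps the batch size but concentrates points near x = 0
```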

RBA Element-wise Weighting

Residual-based adaptivity assigns per-point loss weights so boundary and collocation losses are automatically balanced. Especially effective for stiff boundary conditions.
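The card does not spell out the weight update. One widely used residual-based weighting keeps an exponential moving average of each point's normalised residual magnitude, λᵢ ← γ λᵢ + η |rᵢ| / maxⱼ |rⱼ|, and scales the residual by λᵢ in the loss. The sketch assumes that rule with γ = 0.999 and η = 0.01; underPINN's RBA may use a different update or constants, and rba_update and weighted_pde_loss are hypothetical names.

```python
import jax
import jax.numpy as jnp

def rba_update(lam, residuals, gamma=0.999, eta=0.01):
    """EMA of the normalised |residual| per collocation point (assumed rule)."""
    r = jnp.abs(residuals)
    return gamma * lam + eta * r / jnp.max(r)

def weighted_pde_loss(lam, residuals):
    # Weights act as constants during backprop: no gradient flows into lam
    return jnp.mean((jax.lax.stop_gradient(lam) * residuals) ** 2)

residuals = jnp.array([0.01, 0.1, 1.0, -2.0])   # toy per-point residuals
lam = jnp.ones_like(residuals)                  # start from uniform weights
lam = jax.lax.fori_loop(0, 1000, lambda i, l: rba_update(l, residuals), lam)
loss = weighted_pde_loss(lam, residuals)
```

With this rule each weight is bounded by η / (1 − γ) = 10, so no single point can dominate the loss without limit.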


EarlyStopping

Monitors a metric (default: total loss) and halts training after patience epochs without improvement. Works correctly inside lax.scan loops — fires at the outer-step boundary.
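The patience logic can be sketched in plain Python. In this sketch the check runs only at outer-step boundaries, so with n_scan_steps = 100 the stop epoch is a multiple of 100; EarlyStoppingSketch is hypothetical and not underPINN's EarlyStopping class.

```python
class EarlyStoppingSketch:
    """Stop once `patience` epochs pass without the loss improving by more than min_delta."""

    def __init__(self, patience=400, min_delta=1e-8):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.best_epoch = 0

    def should_stop(self, epoch, loss):
        if loss < self.best - self.min_delta:
            self.best, self.best_epoch = loss, epoch
        return epoch - self.best_epoch >= self.patience

n_scan_steps, epochs = 100, 5000
stopper = EarlyStoppingSketch(patience=400)
stopped_at = None
for epoch in range(n_scan_steps, epochs + 1, n_scan_steps):   # outer-step boundaries only
    loss = 1.0 / min(epoch, 1000)    # toy loss: improves until epoch 1000, then plateaus
    if stopper.should_stop(epoch, loss):
        stopped_at = epoch           # 1400: 400 epochs after the last improvement
        break
```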

TrainingConfig Dataclass

Centralises all hyperparameters with runtime validation. A single TrainingConfig object is passed to every solver — no kwargs scattered across multiple calls.

Infrastructure

Restart / Resume

RestartManager snapshots params, optimizer state and loss histories every N epochs. Re-run an interrupted job and it picks up from the last snapshot. Solvers do not check the config on resume; run python -m underPINN resume config.yaml first to guard against stale resumption.


GPU Memory Management

Sets XLA_PYTHON_CLIENT_PREALLOCATE=false automatically before import jax. A Burgers MLP with three hidden layers ([2, 64, 64, 64, 1]) uses ~200 MB instead of reserving ~73 GB on an 80 GB A100.


Checkpoint & Inference

Save Flax params as params.msgpack with a JSON sidecar. Reload in one line with ModelPredictor.from_meta(path) — no need to re-specify the architecture.


Transfer Learning

Parameter transfer (different Re / diffusivity) and temporal transfer (extended time horizon). Warm-start any solver from a pre-trained model with one call.


Inverse Problems

Joint optimisation of network weights + physics parameters. Log-parameterisation ensures positivity without constraints. Demonstrated on thermal diffusivity recovery from sparse noisy data.

YAML-Driven CLI + Sweeps

Every hyperparameter lives in a YAML config. Cartesian-product sweeps are built in, and the run, sweep, bench, list, show, resume and version subcommands cover the whole workflow with no code changes.

Get underPINN running

Pure Python — install in editable mode from the repo root so all examples and the CLI resolve imports automatically.

CPU / Development

Terminal
pip install jax flax optax matplotlib scipy \
            shapely pandas pyyaml

GPU (CUDA 12)

Terminal
pip install -U "jax[cuda12]"
pip install -r requirements-gpu.txt

From Source (recommended)

Terminal
git clone https://github.com/Aeroscience-Computations-Analysis-Lab/underPINN.git underPINN-v2605
cd underPINN-v2605
pip install -e .

# Verify GPU is visible
python -c "import jax; print(jax.devices())"
# Expected on GPU: [CudaDevice(id=0)]

Requirements

  • jax[cpu] ≥ 0.4.26: JIT compilation, autodiff, PRNG
  • flax ≥ 0.8.0: neural network layers and parameter trees
  • optax ≥ 0.2.0: Adam, cosine decay, gradient clipping
  • numpy, scipy, matplotlib: numerics, exact solutions, plotting
  • shapely ≥ 2.0: arbitrary polygon geometry support
  • pyyaml ≥ 6.0: YAML config loading and merging

Why does nvidia-smi show ~73 GB as soon as JAX starts?

By default, JAX's XLA BFC allocator pre-reserves most of the GPU's memory as soon as the GPU backend is initialised (on the first JAX operation, before your model allocates anything), which minimises allocation overhead and memory fragmentation during training. underPINN disables this automatically.

Default JAX behaviour (NOT what underPINN uses)

On an 80 GB A100, the first JAX operation reserves ~73 GB of VRAM even if your model only needs 200 MB. This blocks other processes from using the GPU and makes it look as if your job has consumed the entire card.

underPINN sets XLA_PYTHON_CLIENT_PREALLOCATE=false automatically

The environment variable is set in underPINN/__main__.py (for CLI runs) and at the top of every example script (for direct python examples/... runs) before import jax. You get on-demand GPU memory growth out of the box — no configuration needed.

Manual control via environment variables

bash
# On-demand growth (default in underPINN) — frees unreserved VRAM for other jobs
export XLA_PYTHON_CLIENT_PREALLOCATE=false

# Hard cap — useful when sharing a node; limits to e.g. 20% of VRAM
export XLA_PYTHON_CLIENT_PREALLOCATE=false
export XLA_PYTHON_CLIENT_MEM_FRACTION=0.20

# Platform allocator: allocate and free exactly on demand, no XLA pool (very slow; minimal footprint, useful for debugging OOMs)
export XLA_PYTHON_CLIENT_ALLOCATOR=platform

# Multi-GPU: restrict to a single device (e.g. GPU 1)
export CUDA_VISIBLE_DEVICES=1

Programmatic override (must be BEFORE import jax)

Python
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.15"
import jax  # now uses at most 15% of VRAM

Typical actual VRAM usage (preallocation disabled)

| Problem | Network | VRAM (approx.) |
|---|---|---|
| Burgers 1-D | MLP [2, 64, 64, 64, 1] | ~200 MB |
| Wave 1-D | FourierMLP | ~300 MB |
| Helmholtz 2-D | FourierMLP | ~400 MB |
| LDC 2-D | FBPINN | ~800 MB |
| Airfoil 2-D | MLP [2, 128, 128, 128, 3] | ~1.2 GB |
| Pipe Flow 3-D | MLP [3, 64, 64, 64, 64, 4] | ~2.0 GB |
| Compressible Ramp | MLP [2, 80, 80, 80, 80, 80, 4] | ~1.8 GB |
| k-ε Turbulence | FBPINN | ~3.0 GB |

Up and running in 5 minutes

Six patterns — pick the one that fits your use case. All examples set XLA_PYTHON_CLIENT_PREALLOCATE=false automatically.

Terminal — CLI (zero Python)
# Single run — point at any registered YAML config
python -m underPINN run  examples/burgers/config.yaml
python -m underPINN run  examples/wave/config.yaml
python -m underPINN run  examples/pipe_flow/pipe_flow.yaml
python -m underPINN run  examples/ramp/config.yaml

# Hyperparameter sweep (Cartesian product)
python -m underPINN sweep examples/burgers/burgers_nu_sweep.yaml

# Benchmark all problems
python -m underPINN bench

# List all registered runners
python -m underPINN list

# Print resolved config without training
python -m underPINN show examples/wave/config.yaml

# Print framework version
python -m underPINN version
Python API — Burgers PINN with restart
import os
os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")
import jax, optax
from underPINN.nn.mlp                  import MLP
from underPINN.pde.burgers              import BurgersPDE
from underPINN.losses.loss               import PINNLoss
from underPINN.solver.fbpinn             import FBPINNSolver
from underPINN.core.config              import TrainingConfig
from underPINN.callbacks.logging         import ConsoleLogger
from underPINN.callbacks.early_stopping  import EarlyStopping

model  = MLP(layers=[2, 64, 64, 64, 1])
pde    = BurgersPDE(model, nu=0.01)
loss   = PINNLoss(model, pde, ic_weight=100.0, bc_weight=10.0, rba=True)
solver = FBPINNSolver(model, pde, loss=loss)
solver.init(jax.random.PRNGKey(0))

config = TrainingConfig(
    epochs             = 5000,
    lr                 = 1e-3,
    lr_schedule        = optax.cosine_decay_schedule(1e-3, 5000, alpha=1e-2),
    batch_r            = 2048,
    log_every          = 500,
    out_dir            = "outputs/burgers",   # enables auto-restart
    save_restart_every = 500,
    callbacks = [
        ConsoleLogger(log_every=500),
        EarlyStopping(patience=400),
    ],
)
# data: this problem's training arrays (construction not shown in this snippet)
solver.train(*data, config=config)
examples/burgers/config.yaml
problem: burgers       # selects the runner

network:
  type  : mlp
  layers: [2, 64, 64, 64, 1]

physics:
  nu: 0.01

data:
  T: 2.0
  n_collocation: 6000
  n_ic: 200
  n_bc: 200

training:
  epochs                  : 5000
  lr                      : 1.0e-3
  early_stopping_patience : 400
  save_restart_every      : 500  # snapshot every 500 epochs

loss:
  ic_weight: 100.0
  bc_weight: 10.0
  rba      : true

output:
  dir        : outputs/burgers
  save_params: true
$ python -m underPINN run examples/burgers/config.yaml
Checkpoint & Inference
from underPINN.utils.checkpoint import ModelPredictor
import jax.numpy as jnp

# Option A — auto-rebuild model from saved metadata (zero boilerplate)
predictor = ModelPredictor.from_meta("outputs/burgers/")

# Option B — supply model explicitly
from underPINN.nn.mlp import MLP
predictor = ModelPredictor.from_checkpoint(
    MLP(layers=[2, 64, 64, 64, 1]),
    "outputs/burgers/",
)

# Inference
x_test = jnp.linspace(-1.0, 1.0, 500)
t_test = jnp.full(500, 0.8)
u = predictor.predict(jnp.stack([x_test, t_test], axis=1))
Transfer Learning
# Phase 1: train source model (e.g. Burgers nu=0.1)
solver_src.train(*data_src, config=cfg_src)
solver_src.save_checkpoint("outputs/source/")

# Phase 2: warm-start target from source weights (e.g. nu=0.01)
solver_tgt.load_params(solver_src.params)       # or restore_checkpoint(...)
solver_tgt.train(*data_tgt, config=cfg_tgt)     # lower lr recommended (3e-4)
# Converges 2-3x faster than training from scratch
GPU Acceleration — lax.scan + RAR-D
# Fuse 100 gradient steps per XLA kernel + adaptive resampling
config = TrainingConfig(
    epochs          = 5000,
    lr              = 1e-3,
    n_scan_steps    = 100,   # 50 Python calls instead of 5000
    resample_period = 5,     # RAR-D every 5 outer steps (= 500 epochs)
    resample_k      = 1.0,   # probability proportional to |residual|^1
    callbacks       = [ConsoleLogger(log_every=500)],
)
solver.train(*data, config=config)

TrainingConfig — full field reference

A single dataclass centralises all hyperparameters. Pass it to any solver's train() method.

| Field | Type | Default | Description |
|---|---|---|---|
| epochs | int | 1000 | Total training epochs |
| lr | float | 1e-3 | Base learning rate |
| lr_schedule | optax schedule | None | Overrides lr when set (e.g. optax.cosine_decay_schedule) |
| batch_r | int | 4096 | Collocation mini-batch size |
| batch_i | int | 512 | Initial-condition mini-batch size |
| batch_b | int | 512 | Boundary-condition mini-batch size |
| log_every | int | 100 | Print interval in epochs (used by ConsoleLogger) |
| seed | int | 0 | PRNG seed |
| callbacks | list | [] | List of Callback objects |
| n_scan_steps | int | 1 | Steps fused into one XLA program (1 = plain Python loop) |
| resample_period | int | 0 | RAR-D resampling every N outer steps (0 = off) |
| resample_candidates | int | 0 | Candidate pool size (0 → 5 × batch_r) |
| resample_k | float | 1.0 | Exponent k in p ∝ (residual magnitude)^k |
| out_dir | str | "" | Output directory; enables auto-restart when non-empty |
| save_restart_every | int | 500 | Snapshot interval in epochs (0 = off) |
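The table maps naturally onto a Python dataclass. The sketch below mirrors the documented fields and defaults; the validation rules, resolving resample_candidates = 0 in __post_init__, and the n_outer_steps and restart_enabled helpers are assumptions for illustration, not underPINN's actual TrainingConfig.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Optional

@dataclass
class TrainingConfigSketch:
    """Hypothetical mirror of the documented TrainingConfig fields."""
    epochs: int = 1000
    lr: float = 1e-3
    lr_schedule: Optional[Callable[[int], Any]] = None
    batch_r: int = 4096
    batch_i: int = 512
    batch_b: int = 512
    log_every: int = 100
    seed: int = 0
    callbacks: list = field(default_factory=list)
    n_scan_steps: int = 1
    resample_period: int = 0
    resample_candidates: int = 0
    resample_k: float = 1.0
    out_dir: str = ""
    save_restart_every: int = 500

    def __post_init__(self):
        # Runtime validation (assumed rules)
        if self.epochs <= 0 or self.lr <= 0:
            raise ValueError("epochs and lr must be positive")
        if self.n_scan_steps < 1:
            raise ValueError("n_scan_steps must be >= 1")
        if self.resample_candidates == 0:
            self.resample_candidates = 5 * self.batch_r   # documented 0 -> 5 x batch_r

    @property
    def n_outer_steps(self):
        return -(-self.epochs // self.n_scan_steps)   # ceiling division

    @property
    def restart_enabled(self):
        return bool(self.out_dir) and self.save_restart_every > 0

cfg = TrainingConfigSketch(epochs=5000, n_scan_steps=100, out_dir="outputs/burgers")
```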

Callbacks

ConsoleLogger + EarlyStopping
from underPINN.callbacks.logging        import ConsoleLogger
from underPINN.callbacks.early_stopping import EarlyStopping

# Prints loss every 500 epochs
ConsoleLogger(log_every=500)

# Halt after 400 epochs without improvement
EarlyStopping(
    patience  = 400,
    monitor   = "loss",
    min_delta = 1e-8,
)
ModelCheckpoint
from underPINN.callbacks.checkpoint import ModelCheckpoint

ModelCheckpoint(
    out_dir       = "outputs/burgers/",
    monitor       = "loss",
    mode          = "min",
    save_best_only= True,
    metadata      = {
        "problem": "burgers",
        "network": {
            "type"  : "mlp",
            "layers": [2, 64, 64, 64, 1],
        },
    },
)

lax.scan acceleration — n_scan_steps

| n_scan_steps | Python calls per 5 000 epochs | Callback granularity | Use case |
|---|---|---|---|
| 1 (default) | 5 000 | every epoch | Development / debugging |
| 100 | 50 | every 100 epochs | GPU training, medium runs |
| 500 | 10 | every 500 epochs | Long GPU runs, production |

Fault-tolerant training — resume from any interruption

Set save_restart_every: 500 in your YAML (or TrainingConfig) together with a non-empty output directory, and interrupted runs automatically resume from the last snapshot. No code changes are needed.

How resumption is gated

Solvers check the done flag in meta.json. If done: false (the run was interrupted), the snapshot is restored automatically, regardless of config changes. To check config integrity before resuming an interrupted run, use python -m underPINN resume config.yaml. To force a fresh start, delete <out_dir>/restart/ manually.

  1. Snapshot written every N epochs

    Every save_restart_every epochs, RestartManager writes params.msgpack, opt_state.msgpack, hists.npz, and meta.json to <out_dir>/restart/.

  2. Re-run checks the done flag and resumes

    On re-run, RestartManager reads meta.json. If done: false (run was interrupted), params, optimizer state, and loss histories are restored — training continues from the saved epoch. Plots stay continuous across restarts.

  3. Config-change safety via resume command

    Solvers do not hash-check configs — an interrupted run resumes even if you changed lr, epochs, etc. Run python -m underPINN resume config.yaml to detect config changes before resuming. To force a fresh start, delete <out_dir>/restart/.

  4. Completion marks the snapshot as done

    After training finishes — normally or via early stopping — done() writes "done": true to meta.json. The next run with the same config starts fresh rather than re-resuming a completed run.

Snapshot directory contents

| File | Contents |
|---|---|
| params.msgpack | Flax-serialised model parameters at the snapshot epoch |
| opt_state.msgpack | Flax-serialised optimizer state (Adam moments, step count) |
| hists.npz | All loss-history arrays accumulated so far (loss_hist, pde_hist, etc.) |
| meta.json | {"epoch": N, "cfg_hash": null, "done": false}; cfg_hash is null when written by a solver and populated by the resume CLI |
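The resume decision described above reduces to reading meta.json. A minimal standard-library sketch: write_meta and resume_epoch are hypothetical helpers, only meta.json is handled (the msgpack and npz files are omitted), and the write-to-temp-then-rename step is an assumption rather than documented RestartManager behaviour.

```python
import json
import os
import tempfile

def write_meta(out_dir, epoch, done=False):
    """Write <out_dir>/restart/meta.json in the documented shape."""
    restart_dir = os.path.join(out_dir, "restart")
    os.makedirs(restart_dir, exist_ok=True)
    path = os.path.join(restart_dir, "meta.json")
    tmp = path + ".tmp"
    with open(tmp, "w") as f:
        json.dump({"epoch": epoch, "cfg_hash": None, "done": done}, f)
    os.replace(tmp, path)   # rename so a crash never leaves a half-written meta.json

def resume_epoch(out_dir):
    """Epoch to resume from: the snapshot epoch if interrupted, else 0 (fresh start)."""
    path = os.path.join(out_dir, "restart", "meta.json")
    if not os.path.exists(path):
        return 0
    with open(path) as f:
        meta = json.load(f)
    return 0 if meta["done"] else meta["epoch"]

out_dir = tempfile.mkdtemp()
write_meta(out_dir, epoch=3500)              # last snapshot before a kill at epoch 3700
start = resume_epoch(out_dir)                # 3500
write_meta(out_dir, epoch=10000, done=True)  # training finished
start_after_done = resume_epoch(out_dir)     # 0: completed runs start fresh
```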

Configuration

YAML (the only change needed)
training:
  save_restart_every: 500   # 0 to disable
Python API
config = TrainingConfig(
    epochs             = 10000,
    out_dir            = "outputs/burgers",
    save_restart_every = 500,
)
solver.train(*data, config=config)
# If killed at epoch 3700, next run resumes
# from epoch 3500 (last snapshot) automatically.

16 worked examples across 7 physics domains

Each example folder is self-contained: one script + one YAML config. Run directly (python examples/burgers/burgers.py) or via the CLI. Saves predictions, plots, and a params.msgpack checkpoint automatically.

| Problem | PDE | Network | Key features | Config |
|---|---|---|---|---|
| Exponential Decay | du/dt + λu = 0 | MLP [1,32,32,1] | ODESolver, TrainingConfig | examples/ode/config.yaml |
| Harmonic Oscillator | d²u/dt² + ω²u = 0 | MLP [1,32,32,1] | ODESolver, IC derivative | examples/ode/config.yaml |
| 1-D Burgers | u_t + uu_x = νu_xx | MLP [2,64,64,64,1] | FBPINN, RBA, cosine LR | examples/burgers/config.yaml |
| 1-D Heat (forward) | u_t = αu_xx | MLP [2,64,64,64,1] | FBPINNSolver, exact Gaussian | examples/heat/heat_forward.yaml |
| 1-D Heat (inverse) | u_t = αu_xx | MLP [2,64,64,64,1] | recover α, 50 noisy obs., log-param | examples/heat/heat_inverse.yaml |
| 1-D Wave Equation | u_tt = c²u_xx | FourierMLP [2,128,128,1] | FourierMLP, dual IC (u, u_t) | examples/wave/config.yaml |
| 2-D Helmholtz | Δu + k²u = f | FourierMLP [2,128,128,1] | FourierMLP, k=4, manufactured source | examples/helmholtz/config.yaml |
| 2-D Diffusion Inverse | u_t = α∇²u | MLP [3,64,64,64,1] | log-param, joint opt., SteadySolver | examples/inverse/config.yaml |
| 2-D Lid-Driven Cavity | Steady N-S, Re=100 | FBPINN + SimpleGate | LDCSolver, SimpleGate attention | examples/LDC/config.yaml |
| 2-D RANS k-ε | RANS k-ε, Re=10000 | FBPINN | RANSSolver, RBA, turbulent channel | examples/K-Epsilon/config.yaml |
| 2-D Compressible Ramp | Steady Euler, M=3 | MLP [2,80,80,80,80,80,4] | oblique shock θ=10°, ramp geometry | examples/ramp/config.yaml |
| NACA 0012 Airfoil | Steady N-S, Re=200 | MLP [2,128,128,128,3] | exterior geometry, Cp curve, CL, RAR-D | examples/airfoil/config.yaml |
| 3-D Pipe Flow | Steady 3-D N-S | MLP [3,64,64,64,64,4] | double-jacfwd Hessian, Pipe geometry | examples/pipe_flow/pipe_flow.yaml |
| 3-D Unsteady Pipe (transfer) | u_t = G + ν∇²u | MLP [3,64,64,64,64,1] | Bessel exact solution, Re & temporal TL | examples/pipe_flow/pipe_flow_unsteady_transfer.yaml |
| Burgers Transfer | Burgers | MLP [2,64,64,64,1] | param transfer (ν), temporal transfer | examples/transfer/burgers_transfer.yaml |
| Heat 2-D Transfer | 2-D unsteady heat | MLP [3,64,64,64,1] | cross-diffusivity TL, temporal transfer | examples/transfer/heat2d_transfer.yaml |

Core abstractions & module layout

Three abstract base classes form the backbone. Every PDE, loss, and solver subclasses one of them, so new components slot in without rewriting existing code.

Package Layout

  • core/ — BasePDE, BaseLoss, BaseSolver, TrainingConfig
  • nn/ — MLP, FourierMLP, FBPINN, HybridAttention, SimpleGate
  • pde/ — Burgers, Wave, Helmholtz, Heat, N-S 2D/3D, k-ε, Euler, ODE
  • solver/ — FBPINNSolver, SteadySolver, ODESolver, LDCSolver, RANSSolver
  • losses/ — PINNLoss (RBA), ODELoss, SteadyLoss
  • callbacks/ — ConsoleLogger, EarlyStopping, ModelCheckpoint
  • geometry/ — Interval, Rectangle, Airfoil, Pipe, Ramp, Composite, ShapelyGeom
  • training/ — rar_d_resample (RAR-D adaptive collocation)
  • runner/ — dispatch.py path-registry + importlib loader
  • utils/ — save_predictions, checkpoint, restart, ModelPredictor, timing, metrics
  • benchmark_utils/ — BenchmarkRunner, evaluators, report

Core Abstractions

underPINN/core/base.py
from abc import ABC, abstractmethod

class BasePDE(ABC):
    # Every PDE implements residual()
    @abstractmethod
    def residual(self, params, *args): ...

class BaseLoss(ABC):
    # Returns (total_loss, aux_tuple)
    @abstractmethod
    def __call__(self, params, *args): ...

class BaseSolver(ABC):
    # Concrete helpers on every solver (further arguments elided):
    def save_checkpoint(self, out_dir, **kwargs): ...
    def restore_checkpoint(self, path): ...
    def load_params(self, params): ...  # TL warm-start
PDE + Geometry reference
# Every PDE:
class BurgersPDE(BasePDE):
    def residual(self, params, x, t):
        # returns |u_t + u*u_x - nu*u_xx|
        ...

# Every geometry:
class Pipe:
    def sample_interior(self, n, key): ...
    def sample_wall(self, n, key): ...
    def sample_inlet(self, n, key): ...
    def sample_outlet(self, n, key): ...
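The residual() contract can be illustrated with JAX autodiff. This sketch computes the signed Burgers residual u_t + u·u_x − ν·u_xx pointwise (a loss would square it), vectorised with jax.vmap; make_burgers_residual is hypothetical, not underPINN's BurgersPDE, and a closed-form u(x, t) = a·x / (1 + t) stands in for the network so the result can be checked: it solves Burgers exactly when a = 1.

```python
import jax
import jax.numpy as jnp

def make_burgers_residual(u_fn, nu):
    """u_fn(params, x, t) -> scalar u; returns a batched residual(params, x, t)."""
    u_x_fn = jax.grad(u_fn, argnums=1)

    def residual(params, x, t):
        u = u_fn(params, x, t)
        u_t = jax.grad(u_fn, argnums=2)(params, x, t)
        u_x = u_x_fn(params, x, t)
        u_xx = jax.grad(u_x_fn, argnums=1)(params, x, t)   # second derivative in x
        return u_t + u * u_x - nu * u_xx

    # vmap over collocation points; params are shared
    return jax.vmap(residual, in_axes=(None, 0, 0))

def u_closed_form(a, x, t):
    # a = 1 gives an exact Burgers solution (u_xx = 0, so nu drops out)
    return a * x / (1.0 + t)

residual = make_burgers_residual(u_closed_form, nu=0.01)
x = jnp.linspace(-1.0, 1.0, 11)
t = jnp.full(11, 0.5)
r_exact = residual(1.0, x, t)   # ~0 everywhere
r_wrong = residual(2.0, x, t)   # equals 2x / (1 + t)^2
```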

Adding a new case — one line

underPINN/runner/dispatch.py
# 1. Create examples/mycase/mycase.py  — define run_mycase(cfg) -> dict
# 2. Create examples/mycase/config.yaml  — set problem: mycase
# 3. Add ONE line here:
_REGISTRY = {
    "burgers"  : ("examples/burgers/burgers.py",  "run_burgers"),
    "wave"     : ("examples/wave/wave.py",        "run_wave"),
    "mycase"   : ("examples/mycase/mycase.py",    "run_mycase"),  # ← add this
    # ... no other files need to change
}
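A path registry like this one can be resolved with importlib from the standard library. The sketch assumes each entry is (script path relative to the repo root, function name) as shown above; load_runner is a hypothetical name and underPINN's actual loader may differ. The usage writes a throwaway mycase.py into a temporary directory so the example runs anywhere.

```python
import importlib.util
import os
import tempfile

def load_runner(registry, problem, root="."):
    """Resolve problem -> (script path, function name) and import the script by path."""
    if problem not in registry:
        raise KeyError(f"unknown problem {problem!r}; registered: {sorted(registry)}")
    rel_path, func_name = registry[problem]
    spec = importlib.util.spec_from_file_location(f"runner_{problem}", os.path.join(root, rel_path))
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)   # runs the script's top level (imports, definitions)
    return getattr(module, func_name)

# Usage: a throwaway examples/mycase/mycase.py in a temporary repo root
root = tempfile.mkdtemp()
os.makedirs(os.path.join(root, "examples", "mycase"))
with open(os.path.join(root, "examples", "mycase", "mycase.py"), "w") as f:
    f.write("def run_mycase(cfg):\n    return {'problem': cfg['problem'], 'ok': True}\n")

registry = {"mycase": ("examples/mycase/mycase.py", "run_mycase")}
run = load_runner(registry, "mycase", root=root)
result = run({"problem": "mycase"})
```

Because the module is imported under a name other than "__main__", any if __name__ == "__main__": block in the example script is not executed, so scripts can stay directly runnable.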

YAML-driven experiments, zero code changes

Run scripts directly or via the CLI — both point at the same YAML. Adding a new problem = one script + one YAML + one line in dispatch.py.

Single Run

Point at a YAML — or run the script directly. Both work identically.

$ python -m underPINN run examples/burgers/config.yaml
$ python -m underPINN run examples/wave/config.yaml
$ python -m underPINN run examples/helmholtz/config.yaml
$ python -m underPINN run examples/ramp/config.yaml
$ python -m underPINN run examples/airfoil/config.yaml
$ python -m underPINN run examples/pipe_flow/pipe_flow.yaml

Hyperparameter Sweep

Cartesian product across any config key. Each run gets its own sub-directory.

$ python -m underPINN sweep \
  examples/burgers/burgers_nu_sweep.yaml
 
# sweep YAML anatomy:
# base: { problem: burgers, ... }
# sweep:
#   physics.nu: [0.1, 0.05, 0.01]
#   training.epochs: [3000, 5000]
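The Cartesian-product expansion can be sketched with itertools.product over the sweep lists, writing each value into a deep copy of base at its dotted key. expand_sweep and set_dotted are hypothetical helpers, and the nu=0.1_epochs=3000 sub-directory naming is an assumption; the CLI's actual naming scheme may differ.

```python
import copy
import itertools

def set_dotted(cfg, dotted_key, value):
    """Set cfg['physics']['nu'] for 'physics.nu', creating nested dicts as needed."""
    *parents, leaf = dotted_key.split(".")
    node = cfg
    for p in parents:
        node = node.setdefault(p, {})
    node[leaf] = value

def expand_sweep(base, sweep):
    keys = list(sweep)
    runs = []
    for values in itertools.product(*(sweep[k] for k in keys)):
        cfg = copy.deepcopy(base)   # base itself is never modified
        for k, v in zip(keys, values):
            set_dotted(cfg, k, v)
        # One sub-directory per run, named after the swept values
        tag = "_".join(f"{k.split('.')[-1]}={v}" for k, v in zip(keys, values))
        cfg.setdefault("output", {})["dir"] = f"outputs/{base['problem']}/{tag}"
        runs.append(cfg)
    return runs

base = {"problem": "burgers", "physics": {"nu": 0.01}, "training": {"epochs": 5000}}
sweep = {"physics.nu": [0.1, 0.05, 0.01], "training.epochs": [3000, 5000]}
runs = expand_sweep(base, sweep)   # 3 x 2 = 6 resolved configs
```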

Inspect & List

Print the resolved config, check a restart snapshot before resuming, or list all registered runners.

$ python -m underPINN show examples/wave/config.yaml
$ python -m underPINN resume examples/burgers/config.yaml
$ python -m underPINN list
$ python -m underPINN version
 
Registered runners: burgers, wave,
helmholtz, heat_forward, heat_inverse,
ode, ldc, airfoil, pipe_flow, ramp,
burgers_transfer, pipe_flow_unsteady_transfer,
inverse_diffusion, ...

Benchmark Suite

Accuracy vs. epoch budget analysis across all problems.

$ python -m underPINN bench
$ python -m underPINN bench \
  --problems burgers wave ode_exp \
  --epochs 500 2000 5000 \
  --output outputs/bench
$ python -m underPINN bench --all
$ python -m underPINN bench \
  --from-json outputs/bench/results.json

Systematic accuracy-vs-epoch analysis

One command trains every registered problem across multiple epoch budgets and produces a complete analysis package — plots, CSV, Markdown, and reusable JSON.

Programmatic usage
from underPINN.benchmark_utils import BenchmarkRunner, generate_report

runner = BenchmarkRunner(
    problems      = ["burgers", "wave", "ode_exp", "helmholtz"],
    epoch_budgets = [500, 1000, 2000, 5000],
    seed          = 0, fast_only=True, verbose=True,
)
results = runner.run(out_dir="outputs/bench")
runner.save_json("outputs/bench/results.json")
generate_report(results, runner, out_dir="outputs/bench")
  • accuracy_vs_epochs.png: log-log relative L² error vs epoch budget, one curve per problem
  • accuracy_summary_bar.png: grouped bar chart at each epoch budget
  • wall_time_vs_epochs.png: training time vs epoch budget
  • ms_per_epoch.png: bar chart of training throughput (ms per epoch) per problem
  • loss_grid.png: convergence curves for every problem
  • benchmark_results.csv: full raw data table, importable into pandas
  • benchmark_summary.md: Markdown table, one row per problem at the largest epoch budget
  • results.json: reusable for --from-json replays without re-training
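Assuming the rel-L² in these outputs is the usual relative L2 error, ‖u_pred − u_exact‖₂ / ‖u_exact‖₂ over the evaluation points, it can be computed as below; rel_l2 is a hypothetical helper, not part of benchmark_utils.

```python
import jax.numpy as jnp

def rel_l2(u_pred, u_exact):
    """Relative L2 error ||u_pred - u_exact||_2 / ||u_exact||_2."""
    return jnp.linalg.norm(u_pred - u_exact) / jnp.linalg.norm(u_exact)

x = jnp.linspace(0.0, 1.0, 101)
u_exact = jnp.sin(jnp.pi * x)
err = rel_l2(1.01 * u_exact, u_exact)   # a uniform 1% overshoot gives rel-L2 = 0.01
```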

Engineered for GPU throughput

Two orthogonal optimisations stack cleanly: scan-based XLA fusion eliminates Python overhead, RAR-D concentrates compute on hard regions. Both are composable with all solvers.

  • 500×: fewer Python dispatches on GPU with n_scan_steps=500 (10 calls instead of 5 000 for a 5 000-epoch run)
  • |r|^k: RAR-D resampling probability proportional to residual magnitude
  • float32: all arrays are cast to float32, which is much faster than float64 on GPUs; do not enable x64
  • No idle reservation: VRAM is allocated on demand by default via XLA_PYTHON_CLIENT_PREALLOCATE=false


Performance tips summary

GPU Memory

Already handled — underPINN sets XLA_PYTHON_CLIENT_PREALLOCATE=false in every entry point before import jax. If writing a new script, add this at the very top.

lax.scan on GPU

Use n_scan_steps=100 for medium GPU runs. For long runs (>5 000 epochs), use 500. On CPU, leave at 1 for full callback granularity.

RAR-D for sharp solutions

Enable RAR-D (resample_period=5, resample_k=1.0) when the solution has sharp gradients or shocks — Burgers at low ν, Euler ramp, wave at high frequency.

Early stopping patience

Fast ODEs: patience=200. Medium PDEs: 400–800. Complex PDEs (LDC, airfoil, 3-D): 1 000–2 000. Combine with cosine LR decay for best results.

Float32 — do not use float64

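Do not call jax.config.update("jax_enable_x64", True). Float64 is markedly slower on CUDA devices (about half the FP32 rate on data-center GPUs such as the A100, and far lower on consumer cards), doubles memory use, and is generally not needed for PINN training.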

Multi-GPU

Use CUDA_VISIBLE_DEVICES=1 to restrict to a specific GPU. Full multi-GPU pmap training is not currently implemented; use the fastest single device.

Training time reporting

Every solver prints timing stats at completion: 45s [JIT≈12s + 3.3ms/ep]. The JIT≈… component appears when XLA compilation overhead is detected, cleanly separating compile time from per-epoch cost.