A modular, GPU-accelerated Physics-Informed Neural Network framework built on JAX + Flax + Optax — from harmonic oscillators to 3-D turbulent Navier-Stokes, all JIT-compiled via XLA.
From simple ODEs to full 3-D turbulence — one coherent API, no boilerplate. Automatic GPU memory management, fault-tolerant restart, and YAML-driven experiments out of the box.
Standard multi-layer perceptron with tanh activations. Configurable depth and width via a simple layer list — e.g. [2, 64, 64, 64, 1] for a space-time Burgers network.
Trainable random Fourier feature embeddings prepended to a standard MLP. Essential for oscillatory solutions — Helmholtz, wave, high-Re flows — where plain MLPs exhibit spectral bias.
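The embedding itself is a small function. The sketch below assumes the common random-Fourier-feature form γ(x) = [cos(2πxB), sin(2πxB)] with a Gaussian frequency matrix B. The names init_fourier_features and fourier_embed are hypothetical and are not underPINN's FourierMLP API; sigma is an assumed scale hyperparameter.

```python
import jax
import jax.numpy as jnp


def init_fourier_features(key, in_dim, n_features, sigma=1.0):
    """Sample a Gaussian frequency matrix B of shape (in_dim, n_features)."""
    return sigma * jax.random.normal(key, (in_dim, n_features))


def fourier_embed(B, x):
    """Map x of shape (N, in_dim) to [cos(2*pi*x@B), sin(2*pi*x@B)] of shape (N, 2*n_features)."""
    proj = 2.0 * jnp.pi * (x @ B)
    return jnp.concatenate([jnp.cos(proj), jnp.sin(proj)], axis=-1)


B = init_fourier_features(jax.random.PRNGKey(0), in_dim=2, n_features=64, sigma=5.0)
xt = jnp.stack([jnp.linspace(-1.0, 1.0, 10), jnp.zeros(10)], axis=1)  # (x, t) points
features = fourier_embed(B, xt)  # shape (10, 128); this is what the MLP receives
```

Keeping B inside the parameter pytree lets the optimiser update it together with the MLP weights, which is one way to make the features trainable.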
Overlapping subdomain decomposition with sigmoid partition-of-unity windows. Each subdomain network uses HybridAttention and SimpleGate gated residual blocks, aimed at complex geometries.
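A minimal 1-D sketch of sigmoid partition-of-unity windows, assuming each window is the product of a rising and a falling sigmoid and that the windows are normalised to sum to one at every point. underPINN's window shape, overlap, and parameter names may differ; sigmoid_windows and fbpinn_combine are hypothetical helpers.

```python
import jax
import jax.numpy as jnp


def sigmoid_windows(x, centers, half_width, sharpness):
    """Smooth, overlapping 1-D windows normalised to a partition of unity.

    Returns an array of shape (len(x), len(centers)) whose rows sum to 1.
    """
    lo = centers - half_width  # left edge of each subdomain
    hi = centers + half_width  # right edge of each subdomain
    rise = jax.nn.sigmoid((x[:, None] - lo) / sharpness)
    fall = jax.nn.sigmoid((hi - x[:, None]) / sharpness)
    raw = rise * fall  # smooth "box" per subdomain
    return raw / jnp.sum(raw, axis=1, keepdims=True)


def fbpinn_combine(windows, sub_outputs):
    """Blend per-subdomain predictions of shape (N, n_sub) into one field of shape (N,)."""
    return jnp.sum(windows * sub_outputs, axis=1)


x = jnp.linspace(-1.0, 1.0, 201)
centers = jnp.linspace(-1.0, 1.0, 4)  # 4 overlapping subdomains
w = sigmoid_windows(x, centers, half_width=0.5, sharpness=0.05)
u = fbpinn_combine(w, jnp.full((201, 4), 3.0))  # identical sub-networks give back 3.0
```

Normalising the windows is what makes the blend exact: if every subdomain network predicts the same value, the combined field reproduces it.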
Fuse N gradient steps into a single compiled XLA call with lax.scan, so Python dispatches N× fewer calls (50 instead of 5 000 for a 5 000-epoch run at N = 100). Works with all solvers; callbacks fire once per outer step, i.e. every N epochs.
Cosine learning-rate decay via optax.cosine_decay_schedule, passed to TrainingConfig as lr_schedule. Recommended for runs longer than 2 000 epochs, where it typically improves final accuracy at no extra cost.
Periodically replaces collocation points with samples drawn proportional to |residual|^k (Lu et al., 2021). Focuses compute on high-error regions without changing total batch size.
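The sampling rule can be sketched as below, assuming a residual function that maps a batch of points to per-point residuals. rar_d_resample is a hypothetical helper, not underPINN's implementation. The tiny epsilon only guards against zero probabilities; some formulations add a larger constant to the density so no region is starved of points.

```python
import jax
import jax.numpy as jnp


def rar_d_resample(key, candidates, residual_fn, n_points, k=1.0):
    """Draw n_points distinct rows of `candidates` with probability proportional to |residual|^k."""
    r = jnp.abs(residual_fn(candidates)) ** k
    p = (r + 1e-12) / jnp.sum(r + 1e-12)  # epsilon only guards against all-zero residuals
    idx = jax.random.choice(key, candidates.shape[0], shape=(n_points,), replace=False, p=p)
    return candidates[idx]


key_cand, key_pick = jax.random.split(jax.random.PRNGKey(0))
candidates = jax.random.uniform(key_cand, (5000, 1), minval=-1.0, maxval=1.0)  # pool = 5 x batch


def toy_residual(pts):
    # Stand-in residual with a sharp peak at x = 0, like a shock.
    return jnp.exp(-50.0 * pts[:, 0] ** 2)


new_points = rar_d_resample(key_pick, candidates, toy_residual, n_points=1000, k=1.0)
```

The batch size stays fixed at n_points; only where the points sit changes, clustering around the high-residual peak.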
Residual-based adaptivity assigns per-point loss weights so boundary and collocation losses are automatically balanced. Especially effective for stiff boundary conditions.
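The exact weighting rule is not spelled out here. One widely used residual-based attention update, sketched below as an assumption, decays each point's weight and adds its residual normalised by the largest residual. rba_update, weighted_residual_loss, gamma, and eta are hypothetical names.

```python
import jax
import jax.numpy as jnp


def rba_update(weights, residuals, gamma=0.999, eta=0.01):
    """Decay the previous weights and add |residual| normalised by its maximum."""
    r = jnp.abs(residuals)
    return gamma * weights + eta * r / (jnp.max(r) + 1e-12)


def weighted_residual_loss(weights, residuals):
    """Mean squared residual with per-point weights applied to the residuals."""
    return jnp.mean((weights * residuals) ** 2)


residuals = jnp.array([0.1, 1.0])  # an easy point and a hard point
w1 = rba_update(jnp.zeros(2), residuals)  # one update from zero weights
w_many = jax.lax.fori_loop(0, 1000, lambda i, w: rba_update(w, residuals), jnp.zeros(2))
loss = weighted_residual_loss(w_many, residuals)
```

At a fixed residual the weights saturate at eta / (1 - gamma) times the normalised residual, so gamma and eta bound how strongly hard points are emphasised.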
Monitors a metric (default: total loss) and halts training after patience epochs without improvement. Works correctly inside lax.scan loops — fires at the outer-step boundary.
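The patience logic amounts to tracking the best value and the epoch at which it was reached. In this hypothetical sketch (not underPINN's EarlyStopping class) the check runs only at outer-step boundaries, as it would inside a lax.scan loop with 100 fused steps per outer step.

```python
class PatienceStopper:
    """Hypothetical early-stopping helper: stop after `patience` epochs without improvement."""

    def __init__(self, patience=400, min_delta=1e-8):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.best_epoch = 0

    def should_stop(self, epoch, value):
        # An improvement must beat the best value by more than min_delta.
        if value < self.best - self.min_delta:
            self.best, self.best_epoch = value, epoch
        return epoch - self.best_epoch >= self.patience


stopper = PatienceStopper(patience=400)
stop_epoch = None
for epoch in range(0, 5000, 100):  # checks happen every 100 epochs (one outer step)
    loss = max(1.0 - epoch / 1000, 0.0)  # toy loss: improves until epoch 1000, then flat
    if stopper.should_stop(epoch, loss):
        stop_epoch = epoch
        break
```

Because checks only happen every 100 epochs, the stop lands on an outer-step boundary (epoch 1 400 here).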
Centralises all hyperparameters with runtime validation. A single TrainingConfig object is passed to every solver — no kwargs scattered across multiple calls.
RestartManager snapshots params, optimizer state, and loss histories every N epochs. Re-run an interrupted job and it resumes from the last snapshot. Solvers do not hash-check the config; run python -m underPINN resume to detect config changes before resuming.
Sets XLA_PYTHON_CLIENT_PREALLOCATE=false automatically before import jax. A 3-layer MLP uses ~200 MB instead of reserving 73 GB on an 80 GB A100.
Save Flax params as params.msgpack with a JSON sidecar. Reload in one line with ModelPredictor.from_meta(path) — no need to re-specify the architecture.
Parameter transfer (different Re / diffusivity) and temporal transfer (extended time horizon). Warm-start any solver from a pre-trained model with one call.
Joint optimisation of network weights + physics parameters. Log-parameterisation ensures positivity without constraints. Demonstrated on thermal diffusivity recovery from sparse noisy data.
Every hyperparameter lives in a YAML config, with built-in Cartesian-product sweeps. The run, sweep, bench, list, show, and version subcommands cover the whole workflow without code changes.
Pure Python — install in editable mode from the repo root so all examples and the CLI resolve imports automatically.
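```bash
# Core dependencies
pip install jax flax optax matplotlib scipy \
    shapely pandas pyyaml

# GPU: JAX with CUDA 12 support
pip install -U "jax[cuda12]"

# Or use the repo's GPU requirements file
pip install -r requirements-gpu.txt
```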
```bash
git clone https://github.com/Aeroscience-Computations-Analysis-Lab/underPINN.git underPINN-v2605
cd underPINN-v2605
pip install -e .

# Verify the GPU is visible
python -c "import jax; print(jax.devices())"
# Expected on GPU: [CudaDevice(id=0)]
```
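| Package | Role |
|---|---|
| jax | JIT compilation, autodiff, PRNG |
| flax | Neural network layers and parameter trees |
| optax | Adam, cosine decay, gradient clipping |
| scipy, matplotlib | Numerics, exact solutions, plotting |
| shapely | Arbitrary polygon geometry support |
| pyyaml | YAML config loading and merging |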
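By default, JAX pre-reserves a large pool of VRAM for XLA's BFC allocator as soon as the GPU backend initialises on the first JAX operation, before your model allocates anything: 75% of total memory in current JAX releases, 90% of free memory in older ones. Preallocation limits fragmentation and allocation overhead during training, but underPINN disables it automatically.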
On an 80 GB A100, this can reserve ~73 GB of VRAM even if your model only needs 200 MB. That blocks other processes from using the GPU and makes it look as if your job consumed the entire card.
The environment variable is set in underPINN/__main__.py (for CLI runs) and at the top of every example script (for direct python examples/... runs) before import jax. You get on-demand GPU memory growth out of the box — no configuration needed.
```bash
# On-demand growth (underPINN default): unused VRAM stays available to other jobs
export XLA_PYTHON_CLIENT_PREALLOCATE=false

# Hard cap, useful when sharing a node: limit JAX to e.g. 20% of VRAM
export XLA_PYTHON_CLIENT_PREALLOCATE=false
export XLA_PYTHON_CLIENT_MEM_FRACTION=0.20

# Platform allocator: no XLA pool, allocates and frees exactly what is needed
# (slowest option, smallest memory footprint)
export XLA_PYTHON_CLIENT_ALLOCATOR=platform

# Multi-GPU node: restrict JAX to a single device (e.g. GPU 1)
export CUDA_VISIBLE_DEVICES=1
```
In Python, set the variables before import jax:

```python
import os

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.15"

import jax  # now uses at most 15% of VRAM
```
| Problem | Network | VRAM (approx) |
|---|---|---|
| Burgers 1-D | [2,64,64,64,1] | ~200 MB |
| Wave 1-D | FourierMLP | ~300 MB |
| Helmholtz 2-D | FourierMLP | ~400 MB |
| LDC 2-D | FBPINN | ~800 MB |
| Airfoil 2-D | [2,128,128,128,3] | ~1.2 GB |
| Pipe Flow 3-D | [3,64,64,64,64,4] | ~2.0 GB |
| Compressible Ramp | [2,80,80,80,80,80,4] | ~1.8 GB |
| k-ε Turbulence | FBPINN | ~3.0 GB |
Six patterns — pick the one that fits your use case. All examples set XLA_PYTHON_CLIENT_PREALLOCATE=false automatically.
```bash
# Single run: point at any registered YAML config
python -m underPINN run examples/burgers/config.yaml
python -m underPINN run examples/wave/config.yaml
python -m underPINN run examples/pipe_flow/pipe_flow.yaml
python -m underPINN run examples/ramp/config.yaml

# Hyperparameter sweep (Cartesian product)
python -m underPINN sweep examples/burgers/burgers_nu_sweep.yaml

# Benchmark all problems
python -m underPINN bench

# List all registered runners
python -m underPINN list

# Print the resolved config without training
python -m underPINN show examples/wave/config.yaml

# Print the framework version
python -m underPINN version
```
```python
import os
os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")  # must precede import jax

import jax
import optax

from underPINN.nn.mlp import MLP
from underPINN.pde.burgers import BurgersPDE
from underPINN.losses.loss import PINNLoss
from underPINN.solver.fbpinn import FBPINNSolver
from underPINN.core.config import TrainingConfig
from underPINN.callbacks.logging import ConsoleLogger
from underPINN.callbacks.early_stopping import EarlyStopping

model = MLP(layers=[2, 64, 64, 64, 1])
pde = BurgersPDE(model, nu=0.01)
loss = PINNLoss(model, pde, ic_weight=100.0, bc_weight=10.0, rba=True)

solver = FBPINNSolver(model, pde, loss=loss)
solver.init(jax.random.PRNGKey(0))

config = TrainingConfig(
    epochs=5000,
    lr=1e-3,
    lr_schedule=optax.cosine_decay_schedule(1e-3, 5000, alpha=1e-2),
    batch_r=2048,
    log_every=500,
    out_dir="outputs/burgers",  # enables auto-restart
    save_restart_every=500,
    callbacks=[
        ConsoleLogger(log_every=500),
        EarlyStopping(patience=400),
    ],
)

# `data` holds the training point arrays; see examples/burgers/burgers.py for how it is built
solver.train(*data, config=config)
```
```yaml
problem: burgers          # selects the runner

network:
  type: mlp
  layers: [2, 64, 64, 64, 1]

physics:
  nu: 0.01

data:
  T: 2.0
  n_collocation: 6000
  n_ic: 200
  n_bc: 200

training:
  epochs: 5000
  lr: 1.0e-3
  early_stopping_patience: 400
  save_restart_every: 500   # snapshot every 500 epochs

loss:
  ic_weight: 100.0
  bc_weight: 10.0
  rba: true

output:
  dir: outputs/burgers
  save_params: true
```
```python
import jax.numpy as jnp

from underPINN.nn.mlp import MLP
from underPINN.utils.checkpoint import ModelPredictor

# Option A: rebuild the model automatically from the saved metadata (zero boilerplate)
predictor = ModelPredictor.from_meta("outputs/burgers/")

# Option B: supply the model explicitly
predictor = ModelPredictor.from_checkpoint(
    MLP(layers=[2, 64, 64, 64, 1]),
    "outputs/burgers/",
)

# Inference at t = 0.8 on 500 points in x ∈ [-1, 1]
x_test = jnp.linspace(-1.0, 1.0, 500)
t_test = jnp.full(500, 0.8)
u = predictor.predict(jnp.stack([x_test, t_test], axis=1))  # inputs of shape (500, 2)
```
```python
# Phase 1: train the source model (e.g. Burgers, nu = 0.1)
solver_src.train(*data_src, config=cfg_src)
solver_src.save_checkpoint("outputs/source/")

# Phase 2: warm-start the target (e.g. nu = 0.01) from the source weights
solver_tgt.load_params(solver_src.params)    # or restore_checkpoint(...)
solver_tgt.train(*data_tgt, config=cfg_tgt)  # lower lr recommended (3e-4)

# Converges 2-3x faster than training from scratch
```
```python
from underPINN.core.config import TrainingConfig
from underPINN.callbacks.logging import ConsoleLogger

# Fuse 100 gradient steps per XLA call and enable adaptive resampling
config = TrainingConfig(
    epochs=5000,
    lr=1e-3,
    n_scan_steps=100,    # 50 Python calls instead of 5000
    resample_period=5,   # RAR-D every 5 outer steps (= 500 epochs)
    resample_k=1.0,      # sampling probability proportional to |residual|^1
    callbacks=[ConsoleLogger(log_every=500)],
)
solver.train(*data, config=config)
```
TrainingConfig is a single dataclass that centralises all hyperparameters. Pass it to any solver's train() method.
| Field | Type | Default | Description |
|---|---|---|---|
| epochs | int | 1000 | Total training epochs |
| lr | float | 1e-3 | Base learning rate |
| lr_schedule | optax schedule | None | Overrides lr when set (e.g. optax.cosine_decay_schedule) |
| batch_r | int | 4096 | Collocation mini-batch size |
| batch_i | int | 512 | Initial-condition mini-batch size |
| batch_b | int | 512 | Boundary-condition mini-batch size |
| log_every | int | 100 | Print interval (used by ConsoleLogger) |
| seed | int | 0 | PRNG seed |
| callbacks | list | [] | List of Callback objects |
| n_scan_steps | int | 1 | Fuse N steps into one XLA call (1 = plain Python loop) |
| resample_period | int | 0 | RAR-D resampling every N outer steps (0 = off) |
| resample_candidates | int | 0 | Candidate pool size (0 → 5 × batch_r) |
| resample_k | float | 1.0 | Exponent k in p ∝ \|residual\|^k |
| out_dir | str | "" | Output directory; enables auto-restart when non-empty |
| save_restart_every | int | 500 | Snapshot interval in epochs (0 = off) |
```python
from underPINN.callbacks.logging import ConsoleLogger
from underPINN.callbacks.early_stopping import EarlyStopping

# Print the loss every 500 epochs
ConsoleLogger(log_every=500)

# Halt after 400 epochs without an improvement larger than min_delta
EarlyStopping(
    patience=400,
    monitor="loss",
    min_delta=1e-8,
)
```
```python
from underPINN.callbacks.checkpoint import ModelCheckpoint

ModelCheckpoint(
    out_dir="outputs/burgers/",
    monitor="loss",
    mode="min",
    save_best_only=True,
    metadata={
        "problem": "burgers",
        "network": {
            "type": "mlp",
            "layers": [2, 64, 64, 64, 1],
        },
    },
)
```
| n_scan_steps | Python calls / 5 000 epochs | Callback granularity | Use case |
|---|---|---|---|
| 1 (default) | 5 000 | every epoch | Development / debugging |
| 100 | 50 | every 100 epochs | GPU training, medium runs |
| 500 | 10 | every 500 epochs | Long GPU runs, production |
Set save_restart_every: 500 in your YAML (or TrainingConfig) together with a non-empty output directory, and interrupted runs resume automatically from the last snapshot. No code changes are needed.
Solvers check the done flag in meta.json. If done is false (the run was interrupted), the snapshot is restored automatically, regardless of config changes. To check for config changes before resuming an interrupted run, use python -m underPINN resume config.yaml. To force a fresh start, delete <out_dir>/restart/ manually.
Every save_restart_every epochs, RestartManager writes params.msgpack, opt_state.msgpack, hists.npz, and meta.json to <out_dir>/restart/.
On re-run, RestartManager reads meta.json. If done is false (the run was interrupted), params, optimizer state, and loss histories are restored, and training continues from the saved epoch. Plots stay continuous across restarts.
Solvers do not hash-check configs: an interrupted run resumes even if you changed lr, epochs, etc. Run python -m underPINN resume config.yaml to detect config changes before resuming. To force a fresh start, delete <out_dir>/restart/.
After training finishes, whether normally or via early stopping, done() writes "done": true to meta.json. The next run with the same config then starts fresh instead of resuming a completed run.
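The resume decision reduces to reading meta.json. The sketch below assumes the meta.json field layout listed in this section (epoch, cfg_hash, done); write_meta and resume_epoch are hypothetical helpers, not RestartManager's API, and the snapshot arrays themselves are omitted.

```python
import json
import os
import tempfile


def write_meta(out_dir, epoch, done):
    """Write <out_dir>/restart/meta.json (hypothetical helper)."""
    restart_dir = os.path.join(out_dir, "restart")
    os.makedirs(restart_dir, exist_ok=True)
    with open(os.path.join(restart_dir, "meta.json"), "w") as f:
        json.dump({"epoch": epoch, "cfg_hash": None, "done": done}, f)


def resume_epoch(out_dir):
    """Return the epoch to resume from, or 0 for a fresh start (hypothetical helper)."""
    path = os.path.join(out_dir, "restart", "meta.json")
    if not os.path.exists(path):
        return 0  # no snapshot yet: fresh start
    with open(path) as f:
        meta = json.load(f)
    if meta.get("done", False):
        return 0  # completed run: start fresh
    return int(meta["epoch"])  # interrupted run: resume from the snapshot


with tempfile.TemporaryDirectory() as out_dir:
    first_start = resume_epoch(out_dir)  # no snapshot yet
    write_meta(out_dir, epoch=3500, done=False)  # last snapshot before a crash at 3700
    resumed_start = resume_epoch(out_dir)
    write_meta(out_dir, epoch=10000, done=True)  # run completed
    after_done_start = resume_epoch(out_dir)
```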
| File | Contents |
|---|---|
| params.msgpack | Flax-serialised model parameters at the snapshot epoch |
| opt_state.msgpack | Flax-serialised optimizer state (Adam moments, step count) |
| hists.npz | All loss history arrays accumulated so far (loss_hist, pde_hist, etc.) |
| meta.json | {"epoch": N, "cfg_hash": null, "done": false}; cfg_hash is null when written by a solver and populated by the resume CLI |
```yaml
training:
  save_restart_every: 500   # 0 to disable
```
```python
config = TrainingConfig(
    epochs=10000,
    out_dir="outputs/burgers",
    save_restart_every=500,
)
solver.train(*data, config=config)

# If killed at epoch 3700, the next run resumes
# from epoch 3500 (the last snapshot) automatically.
```
Each example folder is self-contained: one script + one YAML config. Run directly (python examples/burgers/burgers.py) or via the CLI. Saves predictions, plots, and a params.msgpack checkpoint automatically.
| Problem | PDE | Network | Key Features | Config |
|---|---|---|---|---|
| Exponential Decay | du/dt + λu = 0 | MLP [1,32,32,1] | ODESolver; TrainingConfig | examples/ode/config.yaml |
| Harmonic Oscillator | d²u/dt² + ω²u = 0 | MLP [1,32,32,1] | ODESolver; IC derivative | examples/ode/config.yaml |
| 1-D Burgers | u_t + uu_x = νu_xx | MLP [2,64,64,64,1] | FBPINN; RBA; cosine LR | examples/burgers/config.yaml |
| 1-D Heat: Forward | u_t = αu_xx | MLP [2,64,64,64,1] | FBPINNSolver; exact Gaussian | examples/heat/heat_forward.yaml |
| 1-D Heat: Inverse | u_t = αu_xx | MLP [2,64,64,64,1] | Recover α; 50 noisy obs.; log-param | examples/heat/heat_inverse.yaml |
| 1-D Wave Equation | u_tt = c²u_xx | FourierMLP [2,128,128,1] | FourierMLP; dual IC (u, u_t) | examples/wave/config.yaml |
| 2-D Helmholtz | Δu + k²u = f | FourierMLP [2,128,128,1] | FourierMLP; k=4; manufactured source | examples/helmholtz/config.yaml |
| 2-D Diffusion Inverse | u_t = α∇²u | MLP [3,64,64,64,1] | log-param joint opt.; SteadySolver | examples/inverse/config.yaml |
| 2-D Lid-Driven Cavity | Steady N-S, Re=100 | FBPINN + SimpleGate | LDCSolver; SimpleGate attention | examples/LDC/config.yaml |
| 2-D RANS k-ε | RANS k-ε, Re=10000 | FBPINN | RANSSolver; RBA; turbulent channel | examples/K-Epsilon/config.yaml |
| 2-D Compressible Ramp | Steady Euler, M=3 | MLP [2,80,80,80,80,80,4] | oblique shock θ=10°; ramp geometry | examples/ramp/config.yaml |
| NACA 0012 Airfoil | Steady N-S, Re=200 | MLP [2,128,128,128,3] | exterior geometry; Cp curve, CL; RAR-D | examples/airfoil/config.yaml |
| 3-D Pipe Flow | Steady 3-D N-S | MLP [3,64,64,64,64,4] | double-jacfwd Hessian; Pipe geometry | examples/pipe_flow/pipe_flow.yaml |
| 3-D Unsteady Pipe: Transfer | u_t = G + ν∇²u | MLP [3,64,64,64,64,1] | Bessel exact; Re & temporal TL | examples/pipe_flow/pipe_flow_unsteady_transfer.yaml |
| Burgers Transfer | Burgers | MLP [2,64,64,64,1] | param transfer (ν); temporal transfer | examples/transfer/burgers_transfer.yaml |
| Heat 2-D Transfer | 2-D unsteady heat | MLP [3,64,64,64,1] | cross-diffusivity TL; temporal transfer | examples/transfer/heat2d_transfer.yaml |
Three abstract base classes form the backbone. Every PDE, loss, and solver inherits from one of them, so new components plug in without rewriting the rest of the framework.
```python
from abc import ABC, abstractmethod


class BasePDE(ABC):
    # Every PDE implements residual()
    @abstractmethod
    def residual(self, params, *args): ...


class BaseLoss(ABC):
    # Returns (total_loss, aux_tuple)
    @abstractmethod
    def __call__(self, params, *args): ...


class BaseSolver(ABC):
    # Concrete helpers available on every solver:
    def save_checkpoint(self, out_dir, *args, **kwargs): ...  # remaining arguments elided
    def restore_checkpoint(self, path): ...
    def load_params(self, params): ...  # transfer-learning warm-start
```
```python
# Every PDE subclasses BasePDE:
class BurgersPDE(BasePDE):
    def residual(self, params, x, t):
        # returns the residual u_t + u*u_x - nu*u_xx at the points (x, t)
        ...


# Every geometry exposes samplers:
class Pipe:
    def sample_interior(self, n, key): ...
    def sample_wall(self, n, key): ...
    def sample_inlet(self, n, key): ...
    def sample_outlet(self, n, key): ...
```
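As an illustration of what a residual() method computes, the sketch below evaluates the Burgers residual with JAX autodiff for a scalar network u(params, x, t) and vectorises it over collocation points with jax.vmap. make_burgers_residual and u_toy are hypothetical; BurgersPDE itself takes the model in its constructor and may differentiate differently.

```python
import jax
import jax.numpy as jnp


def make_burgers_residual(u, nu):
    """Build a batched residual r = u_t + u*u_x - nu*u_xx for a scalar network u(params, x, t)."""

    def residual(params, x, t):
        u_x_fn = jax.grad(u, argnums=1)  # du/dx as a function
        u_val = u(params, x, t)
        u_t = jax.grad(u, argnums=2)(params, x, t)
        u_x = u_x_fn(params, x, t)
        u_xx = jax.grad(u_x_fn, argnums=1)(params, x, t)  # second derivative in x
        return u_t + u_val * u_x - nu * u_xx

    # Map over collocation points; params are shared (not batched).
    return jax.vmap(residual, in_axes=(None, 0, 0))


def u_toy(params, x, t):
    # Toy "network" u = a * x^2 * t with a known residual, used only for checking.
    return params["a"] * x**2 * t


residual_fn = make_burgers_residual(u_toy, nu=0.1)
res = residual_fn({"a": jnp.array(1.0)}, jnp.array([0.5, 1.0]), jnp.array([2.0, 1.0]))
```

For u = x²t the residual is x² + 2x³t² - 2νt, which gives 0.85 and 2.8 at the two test points.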
```python
# 1. Create examples/mycase/mycase.py and define run_mycase(cfg) -> dict
# 2. Create examples/mycase/config.yaml and set problem: mycase
# 3. Add ONE line to _REGISTRY in dispatch.py:
_REGISTRY = {
    "burgers": ("examples/burgers/burgers.py", "run_burgers"),
    "wave":    ("examples/wave/wave.py",       "run_wave"),
    "mycase":  ("examples/mycase/mycase.py",   "run_mycase"),  # <- add this
    # ... no other files need to change
}
```
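How dispatch.py turns a registry entry into a callable is not shown here. One plausible approach, sketched with importlib as an assumption, loads the script by path and looks up the function by name; load_runner is a hypothetical helper.

```python
import importlib.util
import tempfile
from pathlib import Path


def load_runner(registry, problem, root="."):
    """Import the script registered for `problem` and return its run function (hypothetical)."""
    script, func_name = registry[problem]
    spec = importlib.util.spec_from_file_location(f"runner_{problem}", Path(root) / script)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)  # executes the example script's top level
    return getattr(module, func_name)


with tempfile.TemporaryDirectory() as root:
    script = Path(root) / "examples" / "mycase" / "mycase.py"
    script.parent.mkdir(parents=True)
    script.write_text("def run_mycase(cfg):\n    return {'problem': cfg['problem'], 'ok': True}\n")

    registry = {"mycase": ("examples/mycase/mycase.py", "run_mycase")}
    run = load_runner(registry, "mycase", root=root)
    result = run({"problem": "mycase"})
```

Loading by file path keeps the example folders free of package boilerplate, at the cost of running each script's top-level code on import.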
Run scripts directly or via the CLI — both point at the same YAML. Adding a new problem = one script + one YAML + one line in dispatch.py.
Point at a YAML — or run the script directly. Both work identically.
Cartesian product across any config key. Each run gets its own sub-directory.
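A Cartesian-product sweep can be sketched with itertools.product. The dotted-key sweep dictionary below is a hypothetical format, not the schema of burgers_nu_sweep.yaml, and expand_sweep is a hypothetical helper; each combination gets its own output sub-directory.

```python
import copy
import itertools


def expand_sweep(base_cfg, sweep):
    """Yield (run_name, cfg) for every combination of the dotted-key values in `sweep`."""
    keys = list(sweep)
    for values in itertools.product(*(sweep[k] for k in keys)):
        cfg = copy.deepcopy(base_cfg)  # never mutate the base config
        parts = []
        for dotted, value in zip(keys, values):
            *parents, leaf = dotted.split(".")
            node = cfg
            for p in parents:
                node = node.setdefault(p, {})
            node[leaf] = value
            parts.append(f"{leaf}={value}")
        yield "_".join(parts), cfg


base = {
    "problem": "burgers",
    "physics": {"nu": 0.01},
    "training": {"epochs": 5000, "lr": 1e-3},
    "output": {"dir": "outputs/burgers"},
}
sweep = {"physics.nu": [0.01, 0.005], "training.lr": [1e-3, 3e-4]}

runs = list(expand_sweep(base, sweep))
for name, cfg in runs:
    cfg["output"]["dir"] = f"{base['output']['dir']}/{name}"  # one sub-directory per run
```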
Print the resolved config or list all registered runners.
Accuracy vs. epoch budget analysis across all problems.
One command trains every registered problem across multiple epoch budgets and produces a complete analysis package — plots, CSV, Markdown, and reusable JSON.
```python
from underPINN.benchmark_utils import BenchmarkRunner, generate_report

runner = BenchmarkRunner(
    problems=["burgers", "wave", "ode_exp", "helmholtz"],
    epoch_budgets=[500, 1000, 2000, 5000],
    seed=0,
    fast_only=True,
    verbose=True,
)
results = runner.run(out_dir="outputs/bench")
runner.save_json("outputs/bench/results.json")
generate_report(results, runner, out_dir="outputs/bench")
```
--from-json replays saved results without re-training.

Two orthogonal optimisations stack cleanly: scan-based XLA fusion eliminates Python dispatch overhead, and RAR-D concentrates compute on high-residual regions. Both compose with all solvers.
Already handled: underPINN sets XLA_PYTHON_CLIENT_PREALLOCATE=false in every entry point before import jax. When writing a new script, add os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false") at the very top, after import os and before import jax.
Use n_scan_steps=100 for medium GPU runs. For long runs (>5 000 epochs), use 500. On CPU, leave at 1 for full callback granularity.
Enable RAR-D (resample_period=5, resample_k=1.0) when the solution has sharp gradients or shocks — Burgers at low ν, Euler ramp, wave at high frequency.
Fast ODEs: patience=200. Medium PDEs: 400–800. Complex PDEs (LDC, airfoil, 3-D): 1 000–2 000. Combine with cosine LR decay for best results.
Do not call jax.config.update("jax_enable_x64", True). Float64 runs at roughly half the float32 rate on data-centre GPUs such as the A100 and far slower on consumer GPUs, and it is not needed for PINN training.
Use CUDA_VISIBLE_DEVICES=1 to restrict to a specific GPU. Full multi-GPU pmap training is not currently implemented; use the fastest single device.
Every solver prints timing stats at completion: 45s [JIT≈12s + 3.3ms/ep]. The JIT≈… component appears when XLA compilation overhead is detected, cleanly separating compile time from per-epoch cost.
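For the sample line, 45 s minus 12 s of compilation leaves 33 s, which at 3.3 ms per epoch corresponds to about 10 000 epochs. The split can be approximated by timing the first call of a jitted function (trace, compile, execute) against later calls; the sketch below is a hypothetical illustration of that idea, not how underPINN's solvers measure it.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def step(x):
    # Stand-in for one compiled training step.
    return jnp.tanh(x @ x.T).sum()


def time_jit_and_steps(fn, arg, n_steps=100):
    """Return (seconds for the first call, mean seconds per later call)."""
    t0 = time.perf_counter()
    fn(arg).block_until_ready()  # first call: trace + XLA compile + execute
    first = time.perf_counter() - t0

    t0 = time.perf_counter()
    for _ in range(n_steps):
        out = fn(arg)  # cached executable, no recompilation
    out.block_until_ready()  # wait for asynchronous dispatch to finish
    per_step = (time.perf_counter() - t0) / n_steps
    return first, per_step


first, per_step = time_jit_and_steps(step, jnp.ones((256, 256)))
jit_estimate = max(first - per_step, 0.0)  # compile overhead ≈ first call minus one normal call
```

block_until_ready matters because JAX dispatches work asynchronously; without it the timer would stop before the GPU finishes.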