A modular differentiable multi-physics framework in JAX for coupled thermal-fluidic-structural topology optimization.
- Couples Stokes-Brinkman, conjugate heat transfer, and linear thermoelasticity
- Reverse-mode AD with custom VJP rules (no manual adjoint derivation)
- Sparse solver backend: O(n^1.1) scaling in 2D, GPU-native CG for 3D
- Warm-start strategy enables three-physics optimization where cold-start fails
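The custom-VJP approach to differentiating through a linear solve can be illustrated with a minimal sketch. This is not the code in `sparse_solve.py` or `gpu_solve.py`: the names `linear_solve` and `compliance` are hypothetical, a dense `jnp.linalg.solve` stands in for the sparse and CG backends, and the stiffness matrix is a toy with no physical meaning.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # double precision for the toy system


@jax.custom_vjp
def linear_solve(A, b):
    """Solve A x = b. Dense stand-in for a sparse or CG backend (assumption)."""
    return jnp.linalg.solve(A, b)


def _solve_fwd(A, b):
    x = jnp.linalg.solve(A, b)
    # Keep A and x for the backward pass; no solver internals are stored.
    return x, (A, x)


def _solve_bwd(residuals, x_bar):
    A, x = residuals
    # Adjoint system: A^T lam = x_bar (one extra solve).
    lam = jnp.linalg.solve(A.T, x_bar)
    A_bar = -jnp.outer(lam, x)  # cotangent w.r.t. A
    b_bar = lam                 # cotangent w.r.t. b
    return A_bar, b_bar


linear_solve.defvjp(_solve_fwd, _solve_bwd)


def compliance(rho):
    """Toy objective f^T u with a density-dependent SPD matrix (hypothetical)."""
    n = rho.size
    K = jnp.diag(1e-3 + rho**3) + 0.1 * jnp.ones((n, n))
    f = jnp.ones(n)
    u = linear_solve(K, f)
    return f @ u


rho = jnp.array([0.3, 0.5, 0.8])
grad = jax.grad(compliance)(rho)  # reverse-mode AD routes through _solve_bwd
```

With a rule of this shape, the backward pass costs one transposed solve, however many iterations the forward solver took, and the adjoint never has to be derived by hand for each objective.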
pip install jax==0.4.30 jaxlib==0.4.30 scipy==1.12.0 numpy==1.26.4 pyamg==5.2.1 matplotlib
# Run 2D warm-start benchmark (80x40 mesh, ~3 min)
python src/benchmarks/run_benchmark_v2.py
# Run aluminum material validation
python src/benchmarks/run_aluminum_validation.py
# Validate Stokes assumption
python src/benchmarks/validate_stokes.py
src/
├── solvers/ # Physics solver modules
│ ├── fem.py # 2D FEM infrastructure (Q4 elements)
│ ├── fem3d.py # 3D FEM (H8 elements)
│ ├── fem3d_q2q1.py # 3D Taylor-Hood Q2-Q1 elements
│ ├── fluid.py # Stokes-Brinkman solver (2D)
│ ├── fluid_3d.py # Stokes-Brinkman solver (3D)
│ ├── thermal.py # Conjugate heat transfer
│ ├── structural.py # Linear elasticity + thermal stress
│ ├── coupled.py # 2D coupled multi-physics pipeline
│ ├── heat_sink_solver.py # 2D integrated multi-physics
│ ├── heat_sink_3d.py # 3D integrated multi-physics
│ ├── sparse_solve.py # Sparse solver + custom VJP (CPU)
│ ├── gpu_solve.py # GPU-native CG solver + custom VJP
│ ├── gpu_stokes.py # GPU Schur complement for Q2-Q1
│ └── stokes_coupled.py # 3D Schur complement (CPU)
├── optimization/ # SIMP, filtering, MMA, warm-start
│ ├── optimizer.py # Main optimization loop + warm-start scheduler
│ ├── mma_optimizer.py # Method of Moving Asymptotes
│ ├── al_optimizer.py # Augmented Lagrangian optimizer
│ └── filters.py # PDE-based filtering + Heaviside projection
├── benchmarks/ # Reproducible experiment scripts
│ ├── run_benchmark_v2.py # Primary 2D warm-start benchmark
│ ├── run_3d_optimization.py
│ ├── run_3d_large_gpu.py
│ ├── run_beta_ablation.py
│ ├── run_3d_beta_ablation.py
│ ├── run_aluminum_validation.py
│ ├── validate_stokes.py
│ ├── run_robustness_study.py
│ ├── run_warmstart_benchmark.py
│ ├── run_gpu_benchmark.py
│ ├── run_profiling.py
│ ├── benchmark_mesh_convergence.py
│ └── ... # Additional benchmarks and tests
├── analysis/ # Paper figure generation
│ ├── generate_paper_figures_v4.py
│ ├── generate_methodology_figures.py
│ └── generate_arch_figure.py
├── surrogate/ # Surrogate model experiments
│ ├── generate_data.py
│ ├── train_surrogate.py
│ └── generate_figures.py
└── utils/ # Visualization, I/O
    └── visualization.py
results/ # Pre-generated data (JSON, NPY)
paper/ # LaTeX manuscript + figures
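For readers unfamiliar with the density pipeline named in the tree (SIMP, Heaviside projection), the sketch below shows the forms most often used in topology optimization. It assumes the tanh-based projection and the standard SIMP power law. The function names, default parameters, and the roles of `beta` and `eta` are illustrative assumptions and may differ from what `filters.py` and `optimizer.py` implement.

```python
import jax.numpy as jnp


def heaviside_projection(rho, beta=8.0, eta=0.5):
    """Smooth tanh projection: pushes densities toward 0 or 1 around threshold eta."""
    num = jnp.tanh(beta * eta) + jnp.tanh(beta * (rho - eta))
    den = jnp.tanh(beta * eta) + jnp.tanh(beta * (1.0 - eta))
    return num / den


def simp_stiffness(rho, e_solid=1.0, e_void=1e-9, penal=3.0):
    """SIMP interpolation: penalizes intermediate densities with a power law."""
    return e_void + rho**penal * (e_solid - e_void)


rho = jnp.linspace(0.0, 1.0, 5)       # filtered densities in [0, 1]
rho_proj = heaviside_projection(rho)  # sharper 0/1 field, still differentiable
stiffness = simp_stiffness(rho_proj)  # element stiffness scaling
```

Larger `beta` gives a sharper projection at the cost of steeper gradients, which is why it is commonly increased gradually during optimization.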
Total reproduction time: ~10-12 hours (CPU), ~14 hours for 3D GPU case.
See src/benchmarks/ for individual experiment scripts.
- CPU: Intel Xeon Gold 6248R (3.00 GHz, 48 cores)
- GPU: NVIDIA Quadro RTX 5000 (16 GB) -- for 3D 40x40x20 case
[Paper citation to be added upon publication]
License: MIT