Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file not shown.
74 changes: 74 additions & 0 deletions neural-pde-surrogates-for-3d-electrostatics/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Rapid Design Exploration with Neural PDE Surrogates
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we generate markdown from one or both of the transolver/meshgraphnet examples and find a way to include it in the repo, either in this readme or separate markdowns? I think it's important for examples to be readable on GitHub without requring someone to clone the repo and open the files in MATLAB just to read what the examples are doing.


This demo shows how neural partial differential equation (PDE) surrogates can accelerate design exploration for a 3D electrostatics problem.
The goal is to predict the electric potential field throughout a transformer bushing insulator,
a cylindrical component with fins that provides electrical insulation between a high-voltage conductor and a grounded enclosure,
for varying bushing geometries. Two deep learning architectures are trained on finite element analysis (FEA) simulations to learn this mapping directly from geometry, enabling rapid evaluation of new designs without running a full simulation.

## Architectures

Two neural PDE surrogate architectures are demonstrated:

- **Transolver** [\[1\]](README.md#references) — A transformer-based architecture for learning PDE solutions on unstructured meshes. Rather than applying attention across all mesh nodes, Transolver learns to group nodes into a small number of physics slices and applies attention across these slices, reducing the quadratic cost of global attention. Inputs are the 3D coordinates and material properties (relative permittivity) at each mesh node.

- **MeshGraphNet** [\[2\]](README.md#references) — A graph neural network that represents the FEA mesh as a graph (nodes = mesh nodes, edges = element connectivity). Through message-passing layers, the network learns how information propagates through the geometry. Inputs are the 3D coordinates of the mesh nodes within the bushing insulator.

Both architectures output the electric potential at every input node.

## Training Data
The problem setup follows the documentation example [Electrostatic Analysis of Transformer Bushing Insulator](https://www.mathworks.com/help/pde/ug/electrostatic-analysis-of-transformer-bushing-insulator.html).
Boundary conditions are held fixed; only the transformer bushing insulator geometry varies between samples.

The training dataset consists of 75 electrostatic simulations, each on a different procedurally generated transformer bushing geometry. For each simulation, we extract:
- FEA mesh nodes and element connectivity
- Material properties (relative permittivity at each node)
- Electric potential at each node
- Electric field at each node (for optional gradient regularization)

## Results

Below are example predictions from two trained models on two geometries.

**Transolver** predictions:
![Transolver prediction on bushing geometry 39: side-by-side comparison of FEA solution, AI prediction with 2.8% relative L2 error, and relative error map](README_media/transolver_test_39.png)

![Transolver prediction on bushing geometry 57: side-by-side comparison of FEA solution, AI prediction with 2.4% relative L2 error, and relative error map](README_media/transolver_test_57.png)

**MeshGraphNet** predictions:
![MeshGraphNet prediction on bushing geometry 39: side-by-side comparison of FEA solution, AI prediction with 2.1% relative L2 error, and relative error map](README_media/meshgraphnet_test_39.png)

![MeshGraphNet prediction on bushing geometry 57: side-by-side comparison of FEA solution, AI prediction with 2.2% relative L2 error, and relative error map](README_media/meshgraphnet_test_57.png)


# Getting Started

**Setup:**
- Run [startup.m](startup.m) to set the MATLAB path. This also creates the `data/`, `STL/`, and `results/` subdirectories.

**Data generation:**
- Run [generate\_data.m](generate_data.m) to generate the training dataset (~5 minutes with 8 parallel workers). This saves simulation results to `data/` and STL geometry files to `STL/`.

**Training:**
- **Transolver:** Run [transolver\_script.m](transolver_script.m) (includes optional gradient regularization fine-tuning).
- **MeshGraphNet:** Run [meshgraphnet\_script.m](meshgraphnet_script.m).
- Set `doTrain = true` to train from scratch (GPU recommended). The trained model is saved to `results/`.
- Set `doTrain = false` and specify a filename in `loadFile` to load a pretrained model.
- Set `doFinetune = true` to finetune a trained model with gradient regularization. This adds a loss term that penalizes discrepancies between the predicted electric field $-\nabla V$ and the reference FEA electric field, where the spatial gradients are computed via automatic differentiation. This encourages spatially smooth predictions and improves derived electric field accuracy (Transolver only).

**Inference app:**
- Type `NeuralPDEInferenceApp` in the Command Window.
- Click "Load AI Model" and select a model from the `results/` subdirectory.
- Click "Load Geometry" and choose an STL file from the `STL/` folder.
- Click "Predict!" to display the predicted electric potential.
- Toggle the switch to view electric field magnitude, which is numerically derived from the predicted potential.

# Required Products
- MATLAB® (tested on R2026a)
- PDE Toolbox™
- Deep Learning Toolbox™
- Parallel Computing Toolbox™
- Statistics and Machine Learning Toolbox™

# References
1. Wu, Haixu, Huakun Luo, Haowen Wang, Jianmin Wang, and Mingsheng Long. "Transolver: A Fast Transformer Solver for PDEs on General Geometries." arXiv, June 1, 2024. https://arxiv.org/abs/2402.02366.
2. Pfaff, Tobias, Meire Fortunato, Alvaro Sanchez-Gonzalez, and Peter W. Battaglia. "Learning Mesh-Based Simulation with Graph Networks." arXiv, June 18, 2021. https://arxiv.org/abs/2010.03409.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
209 changes: 209 additions & 0 deletions neural-pde-surrogates-for-3d-electrostatics/createBushing.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
function [gm, faceIDs] = createBushing(nFins, finRadiusBase, finRadiusTop, finWidth, tubeRadiusBase, tubeRadiusTop, boreRadius, totalLength)
%createBushing Create a parametric transformer bushing geometry.
% [gm, faceIDs] = createBushing(nFins, finRadiusBase, finRadiusTop,
% finWidth, tubeRadiusBase, tubeRadiusTop, boreRadius, totalLength)
% creates an axisymmetric transformer bushing with nFins cooling fin
% rings. The fins have a raised-cosine profile and are evenly spaced.
% Fin radii are linearly interpolated from finRadiusBase (first fin) to
% finRadiusTop (last fin). The tube tapers linearly from tubeRadiusBase
% at Z=0 to tubeRadiusTop at Z=totalLength.
%
% The geometry axis is along Z. The last (topmost) fin has a flat annular
% face at its peak for applying a boundary condition.
%
% Inputs:
% nFins - Number of cooling fin rings
% finRadiusBase - Outer radius of the first (bottom) fin
% finRadiusTop - Outer radius of the last (top) fin
% finWidth - Axial width of each fin (full width of cosine bump)
% tubeRadiusBase - Tube outer radius at Z=0 (bottom)
% tubeRadiusTop - Tube outer radius at Z=totalLength (top)
% boreRadius - Inner bore radius (central hole)
% totalLength - Total axial length of the bushing
%
% Outputs:
% gm - fegeometry object
% faceIDs - struct with fields:
% .bore - face ID for the inner bore surface
% .flatAnnular - face ID for the flat annular ring at
% the top of the last fin

%% Compute fin positions and radii
% Asymmetric margins: larger at the bottom, small stub at the top.
% endStub is the distance from the last fin CENTER to the top of the
% bushing — this is the visible tube length above the split face.
endStub = 0.06 * totalLength; % short tube stub above the last fin peak
startMargin = 0.12 * totalLength; % longer tube section at the bottom
finPositions = linspace(startMargin + finWidth/2, ...
totalLength - endStub, nFins);
finRadii = linspace(finRadiusBase, finRadiusTop, nFins);

% Clamp finWidth so adjacent fins don't overlap (leave a small tube gap)
if nFins > 1
spacing = finPositions(2) - finPositions(1);
maxFinWidth = 0.9 * spacing;
if finWidth > maxFinWidth
finWidth = maxFinWidth;
end
end

%% Helper: tube radius at any axial position (linear taper)
tubeRadiusAt = @(z) tubeRadiusBase + (tubeRadiusTop - tubeRadiusBase) * z / totalLength;

%% Build the outer radius profile
nArc = 8; % points per fin bump (raised cosine)

z_all = [];
r_all = [];

% Tube section before first fin
z_start = linspace(0, finPositions(1) - finWidth/2, 3);
z_all = z_start;
r_all = tubeRadiusAt(z_start);

for i = 1:nFins
hw = finWidth / 2;
center = finPositions(i);

% Raised cosine bump for the fin, sitting on the local tube radius
zFin = linspace(center - hw, center + hw, nArc);
localTubeR = tubeRadiusAt(zFin);
bumpHeight = finRadii(i) - localTubeR;
rFin = localTubeR + 0.5 .* bumpHeight .* ...
(1 + cos(pi * (zFin - center) / hw));

z_all = [z_all, zFin]; %#ok<AGROW>
r_all = [r_all, rFin]; %#ok<AGROW>

% Tube section after this fin
if i < nFins
zTube = linspace(center + hw + 0.002, ...
finPositions(i+1) - hw - 0.002, 3);
else
zTube = linspace(center + hw + 0.002, totalLength, 3);
end
z_all = [z_all, zTube]; %#ok<AGROW>
r_all = [r_all, tubeRadiusAt(zTube)]; %#ok<AGROW>
end

% Remove duplicates and sort
[z_all, ia] = unique(z_all);
r_all = r_all(ia);

%% Split at peak of last fin to create the flat annular face
lastFinCenter = finPositions(end);
[~, splitIdx] = min(abs(z_all - lastFinCenter));

%% Revolve lower body (all fins up to peak of last fin)
xv_low = [0, r_all(1:splitIdx), 0];
yv_low = [z_all(1), z_all(1:splitIdx), z_all(splitIdx)];
gm_lower = fegeometry(revolveProfile(xv_low, yv_low));

%% Revolve upper body (tube stub above last fin)
% Clamp radii to the local tube radius (remove any residual fin contribution)
localTubeAtSplit = tubeRadiusAt(z_all(splitIdx));
r_upper_raw = r_all(splitIdx+1:end);
r_upper_clamped = min(r_upper_raw, tubeRadiusAt(z_all(splitIdx+1:end)));
r_upper = [localTubeAtSplit, r_upper_clamped];
ax_upper = [z_all(splitIdx), z_all(splitIdx+1:end)];
xv_up = [0, r_upper, 0];
yv_up = [ax_upper(1), ax_upper, ax_upper(end)];
gm_upper = fegeometry(revolveProfile(xv_up, yv_up));

%% Union the two halves and subtract the bore
gm = union(gm_lower, gm_upper);
bore = fegeometry(multicylinder(boreRadius, totalLength));
gm = subtract(gm, bore);

%% Find face IDs programmatically
z_split = z_all(splitIdx);
mid_r = (localTubeAtSplit + r_all(splitIdx)) / 2;
faceIDs.flatAnnular = nearestFace(gm, [mid_r, 0, z_split]);
faceIDs.bore = nearestFace(gm, [boreRadius, 0, totalLength/2]);

end


function tri = revolveProfile(rProfile, zProfile, nAngles)
%revolveProfile Create a triangulated solid of revolution.
% tri = revolveProfile(rProfile, zProfile) revolves the closed 2D
% polygon defined by (rProfile, zProfile) around the Z axis to produce
% a closed triangulated surface mesh. rProfile contains radial
% coordinates (>= 0) and zProfile contains axial coordinates.
% Vertices with r ≈ 0 are treated as on-axis pole points.
%
% tri = revolveProfile(rProfile, zProfile, nAngles) specifies the
% number of angular divisions (default: 72, i.e. 5-degree steps).

if nargin < 3
nAngles = 72;
end

theta = linspace(0, 2*pi, nAngles + 1);
theta(end) = []; % remove duplicate at 2*pi

nPts = numel(rProfile);
onAxis = rProfile < 1e-10;
nOff = sum(~onAxis);
nOn = sum(onAxis);

% Preallocate vertices
nVerts = nOff * nAngles + nOn;
verts = zeros(nVerts, 3);
vertMap = zeros(nPts, nAngles);

k = 0;
for i = 1:nPts
if onAxis(i)
k = k + 1;
verts(k,:) = [0, 0, zProfile(i)];
vertMap(i,:) = k; % all angles map to same pole vertex
else
for j = 1:nAngles
k = k + 1;
verts(k,:) = [rProfile(i)*cos(theta(j)), ...
rProfile(i)*sin(theta(j)), ...
zProfile(i)];
vertMap(i,j) = k;
end
end
end

% Preallocate faces (upper bound: 2 triangles per quad)
faces = zeros(2 * nPts * nAngles, 3);
f = 0;

for i = 1:nPts
inext = mod(i, nPts) + 1;
for j = 1:nAngles
jnext = mod(j, nAngles) + 1;

v1 = vertMap(i, j);
v2 = vertMap(i, jnext);
v3 = vertMap(inext, j);
v4 = vertMap(inext, jnext);

if v1 == v2 && v3 == v4
% Both on axis — degenerate, skip
continue;
elseif v1 == v2
% Current point on axis — fan triangle
f = f + 1;
faces(f,:) = [v1, v4, v3];
elseif v3 == v4
% Next point on axis — fan triangle
f = f + 1;
faces(f,:) = [v1, v3, v2];
else
% Quad — split into two triangles
f = f + 1;
faces(f,:) = [v1, v4, v3];
f = f + 1;
faces(f,:) = [v1, v2, v4];
end
end
end

faces = faces(1:f,:);
tri = triangulation(faces, verts);
end
30 changes: 30 additions & 0 deletions neural-pde-surrogates-for-3d-electrostatics/findProjectRoot.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
function rootDir = findProjectRoot(marker)
%findProjectRoot Locate the root directory of the project.
% rootDir = findProjectRoot(marker) searches upward from the current
% directory to find a directory containing the file or folder specified
% by marker (e.g., 'startup.m', '.git').
%
% Inputs:
% marker - Name of a file or folder that identifies the project root
%
% Outputs:
% rootDir - Absolute path to the project root directory

if nargin < 1
marker = 'startup.m'; % Default marker file
end

currentDir = pwd;
while true
if exist(fullfile(currentDir, marker), 'file') || ...
exist(fullfile(currentDir, marker), 'dir')
rootDir = currentDir;
return;
end
[parentDir, ~, ~] = fileparts(currentDir);
if strcmp(currentDir, parentDir)
error('Project root not found. Marker "%s" not found in any parent directory.', marker);
end
currentDir = parentDir;
end
end
Loading