-
Notifications
You must be signed in to change notification settings - Fork 48
Adding 3D Electrostatics SciML example #21
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
mmarkows17
wants to merge
1
commit into
matlab-deep-learning:main
Choose a base branch
from
mmarkows17:publish/3D-electrostatics-example
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
Binary file added
BIN
+67 KB
neural-pde-surrogates-for-3d-electrostatics/NeuralPDEInferenceApp.mlapp
Binary file not shown.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,74 @@ | ||
| # Rapid Design Exploration with Neural PDE Surrogates | ||
|
|
||
| This demo shows how neural partial differential equation (PDE) surrogates can accelerate design exploration for a 3D electrostatics problem. | ||
| The goal is to predict the electric potential field throughout a transformer bushing insulator, | ||
| a cylindrical component with fins that provides electrical insulation between a high-voltage conductor and a grounded enclosure, | ||
| for varying bushing geometries. Two deep learning architectures are trained on finite element analysis (FEA) simulations to learn this mapping directly from geometry, enabling rapid evaluation of new designs without running a full simulation. | ||
|
|
||
| ## Architectures | ||
|
|
||
| Two neural PDE surrogate architectures are demonstrated: | ||
|
|
||
| - **Transolver** [\[1\]](README.md#references) — A transformer-based architecture for learning PDE solutions on unstructured meshes. Rather than applying attention across all mesh nodes, Transolver learns to group nodes into a small number of physics slices and applies attention across these slices, reducing the quadratic cost of global attention. Inputs are the 3D coordinates and material properties (relative permittivity) at each mesh node. | ||
|
|
||
| - **MeshGraphNet** [\[2\]](README.md#references) — A graph neural network that represents the FEA mesh as a graph (nodes = mesh nodes, edges = element connectivity). Through message-passing layers, the network learns how information propagates through the geometry. Inputs are the 3D coordinates of the mesh nodes within the bushing insulator. | ||
|
|
||
| Both architectures output the electric potential at every input node. | ||
|
|
||
| ## Training Data | ||
| The problem setup follows the documentation example [Electrostatic Analysis of Transformer Bushing Insulator](https://www.mathworks.com/help/pde/ug/electrostatic-analysis-of-transformer-bushing-insulator.html). | ||
| Boundary conditions are held fixed; only the transformer bushing insulator geometry varies between samples. | ||
|
|
||
| The training dataset consists of 75 electrostatic simulations, each on a different procedurally generated transformer bushing geometry. For each simulation, we extract: | ||
| - FEA mesh nodes and element connectivity | ||
| - Material properties (relative permittivity at each node) | ||
| - Electric potential at each node | ||
| - Electric field at each node (for optional gradient regularization) | ||
|
|
||
| ## Results | ||
|
|
||
| Below are example predictions from two trained models on two geometries. | ||
|
|
||
| **Transolver** predictions: | ||
|  | ||
|
|
||
|  | ||
|
|
||
| **MeshGraphNet** predictions: | ||
|  | ||
|
|
||
|  | ||
|
|
||
|
|
||
| # Getting Started | ||
|
|
||
| **Setup:** | ||
| - Run [startup.m](startup.m) to set the MATLAB path. This also creates the `data/`, `STL/`, and `results/` subdirectories. | ||
|
|
||
| **Data generation:** | ||
| - Run [generate\_data.m](generate_data.m) to generate the training dataset (~5 minutes with 8 parallel workers). This saves simulation results to `data/` and STL geometry files to `STL/`. | ||
|
|
||
| **Training:** | ||
| - **Transolver:** Run [transolver\_script.m](transolver_script.m) (includes optional gradient regularization fine-tuning). | ||
| - **MeshGraphNet:** Run [meshgraphnet\_script.m](meshgraphnet_script.m). | ||
| - Set `doTrain = true` to train from scratch (GPU recommended). The trained model is saved to `results/`. | ||
| - Set `doTrain = false` and specify a filename in `loadFile` to load a pretrained model. | ||
| - Set `doFinetune = true` to finetune a trained model with gradient regularization. This adds a loss term that penalizes discrepancies between the predicted electric field $-\nabla V$ and the reference FEA electric field, where the spatial gradients are computed via automatic differentiation. This encourages spatially smooth predictions and improves derived electric field accuracy (Transolver only). | ||
|
|
||
| **Inference app:** | ||
| - Type `NeuralPDEInferenceApp` in the Command Window. | ||
| - Click "Load AI Model" and select a model from the `results/` subdirectory. | ||
| - Click "Load Geometry" and choose an STL file from the `STL/` folder. | ||
| - Click "Predict!" to display the predicted electric potential. | ||
| - Toggle the switch to view electric field magnitude, which is numerically derived from the predicted potential. | ||
|
|
||
| # Required Products | ||
| - MATLAB® (tested on R2026a) | ||
| - PDE Toolbox™ | ||
| - Deep Learning Toolbox™ | ||
| - Parallel Computing Toolbox™ | ||
| - Statistics and Machine Learning Toolbox™ | ||
|
|
||
| # References | ||
| 1. Wu, Haixu, Huakun Luo, Haowen Wang, Jianmin Wang, and Mingsheng Long. "Transolver: A Fast Transformer Solver for PDEs on General Geometries." arXiv, June 1, 2024. https://arxiv.org/abs/2402.02366. | ||
| 2. Pfaff, Tobias, Meire Fortunato, Alvaro Sanchez-Gonzalez, and Peter W. Battaglia. "Learning Mesh-Based Simulation with Graph Networks." arXiv, June 18, 2021. https://arxiv.org/abs/2010.03409. | ||
Binary file added
BIN
+109 KB
neural-pde-surrogates-for-3d-electrostatics/README_media/meshgraphnet_test_39.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added
BIN
+108 KB
neural-pde-surrogates-for-3d-electrostatics/README_media/meshgraphnet_test_57.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added
BIN
+87.5 KB
neural-pde-surrogates-for-3d-electrostatics/README_media/transolver_test_39.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added
BIN
+88.1 KB
neural-pde-surrogates-for-3d-electrostatics/README_media/transolver_test_57.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
209 changes: 209 additions & 0 deletions
209
neural-pde-surrogates-for-3d-electrostatics/createBushing.m
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,209 @@ | ||
| function [gm, faceIDs] = createBushing(nFins, finRadiusBase, finRadiusTop, finWidth, tubeRadiusBase, tubeRadiusTop, boreRadius, totalLength) | ||
| %createBushing Create a parametric transformer bushing geometry. | ||
| % [gm, faceIDs] = createBushing(nFins, finRadiusBase, finRadiusTop, | ||
| % finWidth, tubeRadiusBase, tubeRadiusTop, boreRadius, totalLength) | ||
| % creates an axisymmetric transformer bushing with nFins cooling fin | ||
| % rings. The fins have a raised-cosine profile and are evenly spaced. | ||
| % Fin radii are linearly interpolated from finRadiusBase (first fin) to | ||
| % finRadiusTop (last fin). The tube tapers linearly from tubeRadiusBase | ||
| % at Z=0 to tubeRadiusTop at Z=totalLength. | ||
| % | ||
| % The geometry axis is along Z. The last (topmost) fin has a flat annular | ||
| % face at its peak for applying a boundary condition. | ||
| % | ||
| % Inputs: | ||
| % nFins - Number of cooling fin rings | ||
| % finRadiusBase - Outer radius of the first (bottom) fin | ||
| % finRadiusTop - Outer radius of the last (top) fin | ||
| % finWidth - Axial width of each fin (full width of cosine bump) | ||
| % tubeRadiusBase - Tube outer radius at Z=0 (bottom) | ||
| % tubeRadiusTop - Tube outer radius at Z=totalLength (top) | ||
| % boreRadius - Inner bore radius (central hole) | ||
| % totalLength - Total axial length of the bushing | ||
| % | ||
| % Outputs: | ||
| % gm - fegeometry object | ||
| % faceIDs - struct with fields: | ||
| % .bore - face ID for the inner bore surface | ||
| % .flatAnnular - face ID for the flat annular ring at | ||
| % the top of the last fin | ||
|
|
||
| %% Compute fin positions and radii | ||
| % Asymmetric margins: larger at the bottom, small stub at the top. | ||
| % endStub is the distance from the last fin CENTER to the top of the | ||
| % bushing — this is the visible tube length above the split face. | ||
| endStub = 0.06 * totalLength; % short tube stub above the last fin peak | ||
| startMargin = 0.12 * totalLength; % longer tube section at the bottom | ||
| finPositions = linspace(startMargin + finWidth/2, ... | ||
| totalLength - endStub, nFins); | ||
| finRadii = linspace(finRadiusBase, finRadiusTop, nFins); | ||
|
|
||
| % Clamp finWidth so adjacent fins don't overlap (leave a small tube gap) | ||
| if nFins > 1 | ||
| spacing = finPositions(2) - finPositions(1); | ||
| maxFinWidth = 0.9 * spacing; | ||
| if finWidth > maxFinWidth | ||
| finWidth = maxFinWidth; | ||
| end | ||
| end | ||
|
|
||
| %% Helper: tube radius at any axial position (linear taper) | ||
| tubeRadiusAt = @(z) tubeRadiusBase + (tubeRadiusTop - tubeRadiusBase) * z / totalLength; | ||
|
|
||
| %% Build the outer radius profile | ||
| nArc = 8; % points per fin bump (raised cosine) | ||
|
|
||
| z_all = []; | ||
| r_all = []; | ||
|
|
||
| % Tube section before first fin | ||
| z_start = linspace(0, finPositions(1) - finWidth/2, 3); | ||
| z_all = z_start; | ||
| r_all = tubeRadiusAt(z_start); | ||
|
|
||
| for i = 1:nFins | ||
| hw = finWidth / 2; | ||
| center = finPositions(i); | ||
|
|
||
| % Raised cosine bump for the fin, sitting on the local tube radius | ||
| zFin = linspace(center - hw, center + hw, nArc); | ||
| localTubeR = tubeRadiusAt(zFin); | ||
| bumpHeight = finRadii(i) - localTubeR; | ||
| rFin = localTubeR + 0.5 .* bumpHeight .* ... | ||
| (1 + cos(pi * (zFin - center) / hw)); | ||
|
|
||
| z_all = [z_all, zFin]; %#ok<AGROW> | ||
| r_all = [r_all, rFin]; %#ok<AGROW> | ||
|
|
||
| % Tube section after this fin | ||
| if i < nFins | ||
| zTube = linspace(center + hw + 0.002, ... | ||
| finPositions(i+1) - hw - 0.002, 3); | ||
| else | ||
| zTube = linspace(center + hw + 0.002, totalLength, 3); | ||
| end | ||
| z_all = [z_all, zTube]; %#ok<AGROW> | ||
| r_all = [r_all, tubeRadiusAt(zTube)]; %#ok<AGROW> | ||
| end | ||
|
|
||
| % Remove duplicates and sort | ||
| [z_all, ia] = unique(z_all); | ||
| r_all = r_all(ia); | ||
|
|
||
| %% Split at peak of last fin to create the flat annular face | ||
| lastFinCenter = finPositions(end); | ||
| [~, splitIdx] = min(abs(z_all - lastFinCenter)); | ||
|
|
||
| %% Revolve lower body (all fins up to peak of last fin) | ||
| xv_low = [0, r_all(1:splitIdx), 0]; | ||
| yv_low = [z_all(1), z_all(1:splitIdx), z_all(splitIdx)]; | ||
| gm_lower = fegeometry(revolveProfile(xv_low, yv_low)); | ||
|
|
||
| %% Revolve upper body (tube stub above last fin) | ||
| % Clamp radii to the local tube radius (remove any residual fin contribution) | ||
| localTubeAtSplit = tubeRadiusAt(z_all(splitIdx)); | ||
| r_upper_raw = r_all(splitIdx+1:end); | ||
| r_upper_clamped = min(r_upper_raw, tubeRadiusAt(z_all(splitIdx+1:end))); | ||
| r_upper = [localTubeAtSplit, r_upper_clamped]; | ||
| ax_upper = [z_all(splitIdx), z_all(splitIdx+1:end)]; | ||
| xv_up = [0, r_upper, 0]; | ||
| yv_up = [ax_upper(1), ax_upper, ax_upper(end)]; | ||
| gm_upper = fegeometry(revolveProfile(xv_up, yv_up)); | ||
|
|
||
| %% Union the two halves and subtract the bore | ||
| gm = union(gm_lower, gm_upper); | ||
| bore = fegeometry(multicylinder(boreRadius, totalLength)); | ||
| gm = subtract(gm, bore); | ||
|
|
||
| %% Find face IDs programmatically | ||
| z_split = z_all(splitIdx); | ||
| mid_r = (localTubeAtSplit + r_all(splitIdx)) / 2; | ||
| faceIDs.flatAnnular = nearestFace(gm, [mid_r, 0, z_split]); | ||
| faceIDs.bore = nearestFace(gm, [boreRadius, 0, totalLength/2]); | ||
|
|
||
| end | ||
|
|
||
|
|
||
| function tri = revolveProfile(rProfile, zProfile, nAngles) | ||
| %revolveProfile Create a triangulated solid of revolution. | ||
| % tri = revolveProfile(rProfile, zProfile) revolves the closed 2D | ||
| % polygon defined by (rProfile, zProfile) around the Z axis to produce | ||
| % a closed triangulated surface mesh. rProfile contains radial | ||
| % coordinates (>= 0) and zProfile contains axial coordinates. | ||
| % Vertices with r ≈ 0 are treated as on-axis pole points. | ||
| % | ||
| % tri = revolveProfile(rProfile, zProfile, nAngles) specifies the | ||
| % number of angular divisions (default: 72, i.e. 5-degree steps). | ||
|
|
||
| if nargin < 3 | ||
| nAngles = 72; | ||
| end | ||
|
|
||
| theta = linspace(0, 2*pi, nAngles + 1); | ||
| theta(end) = []; % remove duplicate at 2*pi | ||
|
|
||
| nPts = numel(rProfile); | ||
| onAxis = rProfile < 1e-10; | ||
| nOff = sum(~onAxis); | ||
| nOn = sum(onAxis); | ||
|
|
||
| % Preallocate vertices | ||
| nVerts = nOff * nAngles + nOn; | ||
| verts = zeros(nVerts, 3); | ||
| vertMap = zeros(nPts, nAngles); | ||
|
|
||
| k = 0; | ||
| for i = 1:nPts | ||
| if onAxis(i) | ||
| k = k + 1; | ||
| verts(k,:) = [0, 0, zProfile(i)]; | ||
| vertMap(i,:) = k; % all angles map to same pole vertex | ||
| else | ||
| for j = 1:nAngles | ||
| k = k + 1; | ||
| verts(k,:) = [rProfile(i)*cos(theta(j)), ... | ||
| rProfile(i)*sin(theta(j)), ... | ||
| zProfile(i)]; | ||
| vertMap(i,j) = k; | ||
| end | ||
| end | ||
| end | ||
|
|
||
| % Preallocate faces (upper bound: 2 triangles per quad) | ||
| faces = zeros(2 * nPts * nAngles, 3); | ||
| f = 0; | ||
|
|
||
| for i = 1:nPts | ||
| inext = mod(i, nPts) + 1; | ||
| for j = 1:nAngles | ||
| jnext = mod(j, nAngles) + 1; | ||
|
|
||
| v1 = vertMap(i, j); | ||
| v2 = vertMap(i, jnext); | ||
| v3 = vertMap(inext, j); | ||
| v4 = vertMap(inext, jnext); | ||
|
|
||
| if v1 == v2 && v3 == v4 | ||
| % Both on axis — degenerate, skip | ||
| continue; | ||
| elseif v1 == v2 | ||
| % Current point on axis — fan triangle | ||
| f = f + 1; | ||
| faces(f,:) = [v1, v4, v3]; | ||
| elseif v3 == v4 | ||
| % Next point on axis — fan triangle | ||
| f = f + 1; | ||
| faces(f,:) = [v1, v3, v2]; | ||
| else | ||
| % Quad — split into two triangles | ||
| f = f + 1; | ||
| faces(f,:) = [v1, v4, v3]; | ||
| f = f + 1; | ||
| faces(f,:) = [v1, v2, v4]; | ||
| end | ||
| end | ||
| end | ||
|
|
||
| faces = faces(1:f,:); | ||
| tri = triangulation(faces, verts); | ||
| end |
30 changes: 30 additions & 0 deletions
30
neural-pde-surrogates-for-3d-electrostatics/findProjectRoot.m
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,30 @@ | ||
| function rootDir = findProjectRoot(marker) | ||
| %findProjectRoot Locate the root directory of the project. | ||
| % rootDir = findProjectRoot(marker) searches upward from the current | ||
| % directory to find a directory containing the file or folder specified | ||
| % by marker (e.g., 'startup.m', '.git'). | ||
| % | ||
| % Inputs: | ||
| % marker - Name of a file or folder that identifies the project root | ||
| % | ||
| % Outputs: | ||
| % rootDir - Absolute path to the project root directory | ||
|
|
||
| if nargin < 1 | ||
| marker = 'startup.m'; % Default marker file | ||
| end | ||
|
|
||
| currentDir = pwd; | ||
| while true | ||
| if exist(fullfile(currentDir, marker), 'file') || ... | ||
| exist(fullfile(currentDir, marker), 'dir') | ||
| rootDir = currentDir; | ||
| return; | ||
| end | ||
| [parentDir, ~, ~] = fileparts(currentDir); | ||
| if strcmp(currentDir, parentDir) | ||
| error('Project root not found. Marker "%s" not found in any parent directory.', marker); | ||
| end | ||
| currentDir = parentDir; | ||
| end | ||
| end |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we generate markdown from one or both of the transolver/meshgraphnet examples and find a way to include it in the repo, either in this readme or separate markdowns? I think it's important for examples to be readable on GitHub without requring someone to clone the repo and open the files in MATLAB just to read what the examples are doing.