Skip to content

Commit 17f2883

Browse files
Update step.h
1 parent baa3909 commit 17f2883

1 file changed

Lines changed: 138 additions & 28 deletions

File tree

src/gravity/step.h

Lines changed: 138 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -16,82 +16,192 @@
#include <algorithm>
#include <memory>
#include <vector>

#include <omp.h>
#ifdef NEXT_MPI
#include <mpi.h>
#endif
1922

2023
/**
2124
* @brief Performs a complete Leapfrog Step (Kick-Drift-Kick) using SoA data.
2225
*/
2326
inline void Step(ParticleSystem &ps, real dt) {
2427
if (ps.size() == 0) return;
2528

26-
const real theta = 0.5;
27-
const real half = dt * real(0.5);
28-
const int N = static_cast<int>(ps.size());
29+
const real theta = real(0.5);
30+
const real half = dt * real(0.5);
31+
const int N = static_cast<int>(ps.size());
32+
33+
// ---------------------------------------------------------
34+
// MPI SETUP
35+
// ---------------------------------------------------------
36+
#ifdef NEXT_MPI
37+
int rank, size;
38+
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
39+
MPI_Comm_size(MPI_COMM_WORLD, &size);
40+
41+
// Select MPI datatype matching
42+
MPI_Datatype MPI_REAL_T;
43+
# ifdef NEXT_FP64
44+
MPI_REAL_T = MPI_DOUBLE;
45+
# elif defined(NEXT_FP32)
46+
MPI_REAL_T = MPI_FLOAT;
47+
# else
48+
# error "Define NEXT_FP32 or NEXT_FP64 for 'real' type."
49+
# endif
50+
#else
51+
int rank = 0;
52+
int size = 1;
53+
#endif
54+
55+
// ---------------------------------------------------------
56+
// DOMAIN DECOMPOSITION
57+
// ---------------------------------------------------------
58+
const int start = (rank * N) / size;
59+
const int end = ((rank + 1) * N) / size;
60+
61+
#ifdef NEXT_MPI
62+
// Precompute counts and displacements for Allgatherv
63+
std::vector<int> counts(size), displs(size);
64+
for (int r = 0; r < size; ++r) {
65+
const int s = (r * N) / size;
66+
const int e = ((r + 1) * N) / size;
67+
counts[r] = e - s;
68+
displs[r] = s;
69+
}
70+
#endif
2971

30-
// Helper lambda to build the tree using the ParticleSystem indices
72+
// ---------------------------------------------------------
73+
// TREE BUILDER
74+
// ---------------------------------------------------------
3175
auto buildTree = [&]() -> std::unique_ptr<Octree> {
32-
real minx = 1e30, miny = 1e30, minz = 1e30;
33-
real maxx = -1e30, maxy = -1e30, maxz = -1e30;
76+
struct BBox { real minx, miny, minz, maxx, maxy, maxz; };
77+
BBox local{ real(1e30), real(1e30), real(1e30),
78+
real(-1e30), real(-1e30), real(-1e30) };
3479

35-
// Bounding box calculation (SoA access is very fast here)
3680
for (int i = 0; i < N; ++i) {
37-
minx = std::min(minx, ps.x[i]); miny = std::min(miny, ps.y[i]); minz = std::min(minz, ps.z[i]);
38-
maxx = std::max(maxx, ps.x[i]); maxy = std::max(maxy, ps.y[i]); maxz = std::max(maxz, ps.z[i]);
81+
local.minx = std::min(local.minx, ps.x[i]);
82+
local.miny = std::min(local.miny, ps.y[i]);
83+
local.minz = std::min(local.minz, ps.z[i]);
84+
local.maxx = std::max(local.maxx, ps.x[i]);
85+
local.maxy = std::max(local.maxy, ps.y[i]);
86+
local.maxz = std::max(local.maxz, ps.z[i]);
3987
}
4088

41-
real cx = (minx + maxx) * 0.5;
42-
real cy = (miny + maxy) * 0.5;
43-
real cz = (minz + maxz) * 0.5;
44-
real size = std::max({maxx - minx, maxy - miny, maxz - minz}) * 0.5;
89+
#ifdef NEXT_MPI
90+
BBox global;
91+
real mins[3] = {local.minx, local.miny, local.minz};
92+
real maxs[3] = {local.maxx, local.maxy, local.maxz};
93+
94+
MPI_Allreduce(MPI_IN_PLACE, mins, 3, MPI_REAL_T, MPI_MIN, MPI_COMM_WORLD);
95+
MPI_Allreduce(MPI_IN_PLACE, maxs, 3, MPI_REAL_T, MPI_MAX, MPI_COMM_WORLD);
96+
97+
global.minx = mins[0]; global.miny = mins[1]; global.minz = mins[2];
98+
global.maxx = maxs[0]; global.maxy = maxs[1]; global.maxz = maxs[2];
4599

46-
if (size <= 0) size = 1.0;
100+
#else
101+
BBox global = local;
102+
#endif
103+
104+
const real cx = (global.minx + global.maxx) * real(0.5);
105+
const real cy = (global.miny + global.maxy) * real(0.5);
106+
const real cz = (global.minz + global.maxz) * real(0.5);
107+
real size = std::max({global.maxx - global.minx,
108+
global.maxy - global.miny,
109+
global.maxz - global.minz}) * real(0.5);
110+
111+
if (size <= real(0)) size = real(1.0);
47112

48113
auto root = std::make_unique<Octree>(cx, cy, cz, size);
49114

50-
// Insert particle indices 0 to N-1
51-
for (int i = 0; i < N; ++i) {
115+
for (int i = 0; i < N; ++i)
52116
root->insert(i, ps);
53-
}
54117

55118
root->computeMass(ps);
56119
return root;
57120
};
58121

59-
// --- First Kick (dt/2) ---
122+
// ---------------------------------------------------------
123+
// FIRST KICK
124+
// ---------------------------------------------------------
60125
{
61-
std::unique_ptr<Octree> root = buildTree();
126+
auto root = buildTree();
62127

63128
#pragma omp parallel for schedule(dynamic, 64)
64-
for (int i = 0; i < N; ++i) {
65-
real ax = 0, ay = 0, az = 0;
129+
for (int i = start; i < end; ++i) {
130+
real ax = real(0), ay = real(0), az = real(0);
66131
bhAccel(root.get(), i, ps, theta, ax, ay, az);
67132

68133
ps.vx[i] += ax * half;
69134
ps.vy[i] += ay * half;
70135
ps.vz[i] += az * half;
71136
}
137+
138+
#ifdef NEXT_MPI
139+
MPI_Request reqs[3];
140+
MPI_Iallgatherv(ps.vx.data() + start, end - start, MPI_REAL_T,
141+
ps.vx.data(), counts.data(), displs.data(), MPI_REAL_T,
142+
MPI_COMM_WORLD, &reqs[0]);
143+
MPI_Iallgatherv(ps.vy.data() + start, end - start, MPI_REAL_T,
144+
ps.vy.data(), counts.data(), displs.data(), MPI_REAL_T,
145+
MPI_COMM_WORLD, &reqs[1]);
146+
MPI_Iallgatherv(ps.vz.data() + start, end - start, MPI_REAL_T,
147+
ps.vz.data(), counts.data(), displs.data(), MPI_REAL_T,
148+
MPI_COMM_WORLD, &reqs[2]);
149+
MPI_Waitall(3, reqs, MPI_STATUSES_IGNORE);
150+
#endif
72151
}
73152

74-
// --- Drift (dt) ---
75-
// Contiguous memory access makes this loop ideal for SIMD
153+
// ---------------------------------------------------------
154+
// DRIFT
155+
// ---------------------------------------------------------
76156
#pragma omp parallel for schedule(static)
77-
for (int i = 0; i < N; ++i) {
157+
for (int i = start; i < end; ++i) {
78158
ps.x[i] += ps.vx[i] * dt;
79159
ps.y[i] += ps.vy[i] * dt;
80160
ps.z[i] += ps.vz[i] * dt;
81161
}
82162

83-
// --- Second Kick (dt/2) ---
163+
#ifdef NEXT_MPI
164+
MPI_Request reqs[3];
165+
MPI_Iallgatherv(ps.x.data() + start, end - start, MPI_REAL_T,
166+
ps.x.data(), counts.data(), displs.data(), MPI_REAL_T,
167+
MPI_COMM_WORLD, &reqs[0]);
168+
MPI_Iallgatherv(ps.y.data() + start, end - start, MPI_REAL_T,
169+
ps.y.data(), counts.data(), displs.data(), MPI_REAL_T,
170+
MPI_COMM_WORLD, &reqs[1]);
171+
MPI_Iallgatherv(ps.z.data() + start, end - start, MPI_REAL_T,
172+
ps.z.data(), counts.data(), displs.data(), MPI_REAL_T,
173+
MPI_COMM_WORLD, &reqs[2]);
174+
MPI_Waitall(3, reqs, MPI_STATUSES_IGNORE);
175+
#endif
176+
177+
// ---------------------------------------------------------
178+
// SECOND KICK
179+
// ---------------------------------------------------------
84180
{
85-
std::unique_ptr<Octree> root = buildTree();
181+
auto root = buildTree();
86182

87183
#pragma omp parallel for schedule(dynamic, 64)
88-
for (int i = 0; i < N; ++i) {
89-
real ax = 0, ay = 0, az = 0;
184+
for (int i = start; i < end; ++i) {
185+
real ax = real(0), ay = real(0), az = real(0);
90186
bhAccel(root.get(), i, ps, theta, ax, ay, az);
91187

92188
ps.vx[i] += ax * half;
93189
ps.vy[i] += ay * half;
94190
ps.vz[i] += az * half;
95191
}
192+
193+
#ifdef NEXT_MPI
194+
MPI_Request reqs2[3];
195+
MPI_Iallgatherv(ps.vx.data() + start, end - start, MPI_REAL_T,
196+
ps.vx.data(), counts.data(), displs.data(), MPI_REAL_T,
197+
MPI_COMM_WORLD, &reqs2[0]);
198+
MPI_Iallgatherv(ps.vy.data() + start, end - start, MPI_REAL_T,
199+
ps.vy.data(), counts.data(), displs.data(), MPI_REAL_T,
200+
MPI_COMM_WORLD, &reqs2[1]);
201+
MPI_Iallgatherv(ps.vz.data() + start, end - start, MPI_REAL_T,
202+
ps.vz.data(), counts.data(), displs.data(), MPI_REAL_T,
203+
MPI_COMM_WORLD, &reqs2[2]);
204+
MPI_Waitall(3, reqs2, MPI_STATUSES_IGNORE);
205+
#endif
96206
}
97207
}

0 commit comments

Comments
 (0)