16 | 16 | #include <memory> |
17 | 17 | #include <algorithm> |
18 | 18 | #include <omp.h> |
| 19 | +#ifdef NEXT_MPI |
| 20 | + #include <mpi.h> |
| | + #include <vector> // std::vector for gather counts/displacements below |
| 21 | +#endif |
19 | 22 |
20 | 23 | /** |
21 | 24 | * @brief Performs a complete Leapfrog Step (Kick-Drift-Kick) using SoA data. |
22 | 25 | */ |
23 | 26 | inline void Step(ParticleSystem &ps, real dt) { |
24 | 27 | if (ps.size() == 0) return; |
25 | 28 |
26 | | - const real theta = 0.5; |
27 | | - const real half = dt * real(0.5); |
28 | | - const int N = static_cast<int>(ps.size()); |
| 29 | + const real theta = real(0.5); |
| 30 | + const real half = dt * real(0.5); |
| 31 | + const int N = static_cast<int>(ps.size()); |
| 32 | + |
| 33 | + // --------------------------------------------------------- |
| 34 | + // MPI SETUP |
| 35 | + // --------------------------------------------------------- |
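| | + // Assumes the driver already called MPI_Init / MPI_Init_thread; all |
| | + // MPI calls below happen outside OpenMP parallel regions, so |
| | + // MPI_THREAD_FUNNELED support is sufficient. |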
| 36 | +#ifdef NEXT_MPI |
| 37 | + int rank, nranks; |
| 38 | + MPI_Comm_rank(MPI_COMM_WORLD, &rank); |
| 39 | + MPI_Comm_size(MPI_COMM_WORLD, &nranks); |
| 40 | + |
| 41 | + // Select the MPI datatype matching 'real' ("MPI_" names are reserved) |
| 42 | + MPI_Datatype mpi_real; |
| 43 | +# ifdef NEXT_FP64 |
| 44 | + mpi_real = MPI_DOUBLE; |
| 45 | +# elif defined(NEXT_FP32) |
| 46 | + mpi_real = MPI_FLOAT; |
| 47 | +# else |
| 48 | +# error "Define NEXT_FP32 or NEXT_FP64 for 'real' type." |
| 49 | +# endif |
| 50 | +#else |
| 51 | + int rank = 0; |
| 52 | + int nranks = 1; |
| 53 | +#endif |
| 54 | + |
| 55 | + // --------------------------------------------------------- |
| 56 | + // DOMAIN DECOMPOSITION |
| 57 | + // --------------------------------------------------------- |
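| | + // Block decomposition: rank r owns the contiguous index range |
| | + // [r*N/nranks, (r+1)*N/nranks), which spreads any remainder evenly |
| | + // and reduces to [0, N) when nranks == 1. |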
| 58 | + const int start = (rank * N) / nranks; |
| 59 | + const int end = ((rank + 1) * N) / nranks; |
| 60 | + |
| 61 | +#ifdef NEXT_MPI |
| 62 | + // Precompute counts and displacements for Allgatherv |
| 63 | + std::vector<int> counts(nranks), displs(nranks); |
| 64 | + for (int r = 0; r < nranks; ++r) { |
| 65 | + const int s = (r * N) / nranks; |
| 66 | + const int e = ((r + 1) * N) / nranks; |
| 67 | + counts[r] = e - s; |
| 68 | + displs[r] = s; |
| 69 | + } |
| 70 | +#endif |
29 | 71 |
30 | | - // Helper lambda to build the tree using the ParticleSystem indices |
| 72 | + // --------------------------------------------------------- |
| 73 | + // TREE BUILDER |
| 74 | + // --------------------------------------------------------- |
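| | + // Note: every rank builds the full octree over all N particles |
| | + // (positions are replicated), so the build needs no communication. |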
31 | 75 | auto buildTree = [&]() -> std::unique_ptr<Octree> { |
32 | | - real minx = 1e30, miny = 1e30, minz = 1e30; |
33 | | - real maxx = -1e30, maxy = -1e30, maxz = -1e30; |
| 76 | + struct BBox { real minx, miny, minz, maxx, maxy, maxz; }; |
| 77 | + BBox local{ real(1e30), real(1e30), real(1e30), |
| 78 | + real(-1e30), real(-1e30), real(-1e30) }; |
34 | 79 |
35 | | - // Bounding box calculation (SoA access is very fast here) |
36 | 80 | for (int i = 0; i < N; ++i) { |
37 | | - minx = std::min(minx, ps.x[i]); miny = std::min(miny, ps.y[i]); minz = std::min(minz, ps.z[i]); |
38 | | - maxx = std::max(maxx, ps.x[i]); maxy = std::max(maxy, ps.y[i]); maxz = std::max(maxz, ps.z[i]); |
| 81 | + local.minx = std::min(local.minx, ps.x[i]); |
| 82 | + local.miny = std::min(local.miny, ps.y[i]); |
| 83 | + local.minz = std::min(local.minz, ps.z[i]); |
| 84 | + local.maxx = std::max(local.maxx, ps.x[i]); |
| 85 | + local.maxy = std::max(local.maxy, ps.y[i]); |
| 86 | + local.maxz = std::max(local.maxz, ps.z[i]); |
39 | 87 | } |
40 | 88 |
41 | | - real cx = (minx + maxx) * 0.5; |
42 | | - real cy = (miny + maxy) * 0.5; |
43 | | - real cz = (minz + maxz) * 0.5; |
44 | | - real size = std::max({maxx - minx, maxy - miny, maxz - minz}) * 0.5; |
| 89 | +#ifdef NEXT_MPI |
| 90 | + BBox global; |
| 91 | + real mins[3] = {local.minx, local.miny, local.minz}; |
| 92 | + real maxs[3] = {local.maxx, local.maxy, local.maxz}; |
| 93 | + |
| 94 | + MPI_Allreduce(MPI_IN_PLACE, mins, 3, mpi_real, MPI_MIN, MPI_COMM_WORLD); |
| 95 | + MPI_Allreduce(MPI_IN_PLACE, maxs, 3, mpi_real, MPI_MAX, MPI_COMM_WORLD); |
| 96 | + |
| 97 | + global.minx = mins[0]; global.miny = mins[1]; global.minz = mins[2]; |
| 98 | + global.maxx = maxs[0]; global.maxy = maxs[1]; global.maxz = maxs[2]; |
45 | 99 |
46 | | - if (size <= 0) size = 1.0; |
| 100 | +#else |
| 101 | + BBox global = local; |
| 102 | +#endif |
| 103 | + |
| 104 | + const real cx = (global.minx + global.maxx) * real(0.5); |
| 105 | + const real cy = (global.miny + global.maxy) * real(0.5); |
| 106 | + const real cz = (global.minz + global.maxz) * real(0.5); |
| 107 | + real size = std::max({global.maxx - global.minx, |
| 108 | + global.maxy - global.miny, |
| 109 | + global.maxz - global.minz}) * real(0.5); |
| 110 | + |
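| | + // Guard against a degenerate box (all particles coincident). |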
| 111 | + if (size <= real(0)) size = real(1.0); |
47 | 112 |
48 | 113 | auto root = std::make_unique<Octree>(cx, cy, cz, size); |
49 | 114 |
50 | | - // Insert particle indices 0 to N-1 |
51 | | - for (int i = 0; i < N; ++i) { |
| 115 | + for (int i = 0; i < N; ++i) |
52 | 116 | root->insert(i, ps); |
53 | | - } |
54 | 117 |
55 | 118 | root->computeMass(ps); |
56 | 119 | return root; |
57 | 120 | }; |
58 | 121 |
59 | | - // --- First Kick (dt/2) --- |
| 122 | + // --------------------------------------------------------- |
| 123 | + // FIRST KICK |
| 124 | + // --------------------------------------------------------- |
60 | 125 | { |
61 | | - std::unique_ptr<Octree> root = buildTree(); |
| 126 | + auto root = buildTree(); |
62 | 127 |
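| | + // Dynamic schedule: the tree-walk cost varies strongly per particle. |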
63 | 128 | #pragma omp parallel for schedule(dynamic, 64) |
64 | | - for (int i = 0; i < N; ++i) { |
65 | | - real ax = 0, ay = 0, az = 0; |
| 129 | + for (int i = start; i < end; ++i) { |
| 130 | + real ax = real(0), ay = real(0), az = real(0); |
66 | 131 | bhAccel(root.get(), i, ps, theta, ax, ay, az); |
67 | 132 |
68 | 133 | ps.vx[i] += ax * half; |
69 | 134 | ps.vy[i] += ay * half; |
70 | 135 | ps.vz[i] += az * half; |
71 | 136 | } |
| 137 | + |
| 138 | +#ifdef NEXT_MPI |
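| | + // Each rank kicked only its block [start, end); gather the blocks |
| | + // in place so every rank again holds all velocities. MPI_IN_PLACE |
| | + // avoids aliasing the send and receive buffers, which MPI forbids. |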
| 139 | + MPI_Request reqs[3]; |
| 140 | + MPI_Iallgatherv(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, |
| 141 | + ps.vx.data(), counts.data(), displs.data(), mpi_real, |
| 142 | + MPI_COMM_WORLD, &reqs[0]); |
| 143 | + MPI_Iallgatherv(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, |
| 144 | + ps.vy.data(), counts.data(), displs.data(), mpi_real, |
| 145 | + MPI_COMM_WORLD, &reqs[1]); |
| 146 | + MPI_Iallgatherv(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, |
| 147 | + ps.vz.data(), counts.data(), displs.data(), mpi_real, |
| 148 | + MPI_COMM_WORLD, &reqs[2]); |
| 149 | + MPI_Waitall(3, reqs, MPI_STATUSES_IGNORE); |
| 150 | +#endif |
72 | 151 | } |
73 | 152 |
74 | | - // --- Drift (dt) --- |
75 | | - // Contiguous memory access makes this loop ideal for SIMD |
| 153 | + // --------------------------------------------------------- |
| 154 | + // DRIFT |
| 155 | + // --------------------------------------------------------- |
76 | 156 | #pragma omp parallel for schedule(static) |
77 | | - for (int i = 0; i < N; ++i) { |
| 157 | + for (int i = start; i < end; ++i) { |
78 | 158 | ps.x[i] += ps.vx[i] * dt; |
79 | 159 | ps.y[i] += ps.vy[i] * dt; |
80 | 160 | ps.z[i] += ps.vz[i] * dt; |
81 | 161 | } |
82 | 162 |
83 | | - // --- Second Kick (dt/2) --- |
| 163 | +#ifdef NEXT_MPI |
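| | + // Positions were drifted only for the local block; gather them in |
| | + // place so the next tree build sees all N updated particles. |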
| 164 | + MPI_Request reqs[3]; |
| 165 | + MPI_Iallgatherv(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, |
| 166 | + ps.x.data(), counts.data(), displs.data(), mpi_real, |
| 167 | + MPI_COMM_WORLD, &reqs[0]); |
| 168 | + MPI_Iallgatherv(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, |
| 169 | + ps.y.data(), counts.data(), displs.data(), mpi_real, |
| 170 | + MPI_COMM_WORLD, &reqs[1]); |
| 171 | + MPI_Iallgatherv(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, |
| 172 | + ps.z.data(), counts.data(), displs.data(), mpi_real, |
| 173 | + MPI_COMM_WORLD, &reqs[2]); |
| 174 | + MPI_Waitall(3, reqs, MPI_STATUSES_IGNORE); |
| 175 | +#endif |
| 176 | + |
| 177 | + // --------------------------------------------------------- |
| 178 | + // SECOND KICK |
| 179 | + // --------------------------------------------------------- |
84 | 180 | { |
85 | | - std::unique_ptr<Octree> root = buildTree(); |
| 181 | + auto root = buildTree(); |
86 | 182 |
87 | 183 | #pragma omp parallel for schedule(dynamic, 64) |
88 | | - for (int i = 0; i < N; ++i) { |
89 | | - real ax = 0, ay = 0, az = 0; |
| 184 | + for (int i = start; i < end; ++i) { |
| 185 | + real ax = real(0), ay = real(0), az = real(0); |
90 | 186 | bhAccel(root.get(), i, ps, theta, ax, ay, az); |
91 | 187 |
92 | 188 | ps.vx[i] += ax * half; |
93 | 189 | ps.vy[i] += ay * half; |
94 | 190 | ps.vz[i] += az * half; |
95 | 191 | } |
| 192 | + |
| 193 | +#ifdef NEXT_MPI |
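| | + // Final in-place gather of velocities leaves every rank with the |
| | + // complete, consistent state for the next step. |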
| 194 | + MPI_Request reqs2[3]; |
| 195 | + MPI_Iallgatherv(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, |
| 196 | + ps.vx.data(), counts.data(), displs.data(), mpi_real, |
| 197 | + MPI_COMM_WORLD, &reqs2[0]); |
| 198 | + MPI_Iallgatherv(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, |
| 199 | + ps.vy.data(), counts.data(), displs.data(), mpi_real, |
| 200 | + MPI_COMM_WORLD, &reqs2[1]); |
| 201 | + MPI_Iallgatherv(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, |
| 202 | + ps.vz.data(), counts.data(), displs.data(), mpi_real, |
| 203 | + MPI_COMM_WORLD, &reqs2[2]); |
| 204 | + MPI_Waitall(3, reqs2, MPI_STATUSES_IGNORE); |
| 205 | +#endif |
96 | 206 | } |
97 | 207 | } |