Skip to content

Commit 40ee8ef

Browse files
committed
Fix hard-coded water model in uniform_random_rotation kernel.
1 parent 7f5e4de commit 40ee8ef

1 file changed

Lines changed: 37 additions & 40 deletions

File tree

src/loch/_kernels.py

Lines changed: 37 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@
126126
}
127127
128128
// Perform a random rotation about a unit sphere.
129-
DEVICE void uniform_random_rotation(float* v, float r0, float r1, float r2)
129+
DEVICE void uniform_random_rotation(float* v, int num_points, float r0, float r1, float r2)
130130
{
131131
/* Adapted from:
132132
https://www.blopig.com/blog/2021/08/uniformly-sampled-3d-rotation-matrices/
@@ -167,50 +167,47 @@
167167
168168
// Now compute M = -(H @ R), i.e. rotate all points around the x axis.
169169
float M[3][3];
170-
M[0][0] = -(H[0][0] * R[0][0] + H[0][1] * R[1][0] + H[0][2] * R[2][0]);
171-
M[0][1] = -(H[0][0] * R[0][1] + H[0][1] * R[1][1] + H[0][2] * R[2][1]);
172-
M[0][2] = -(H[0][0] * R[0][2] + H[0][1] * R[1][2] + H[0][2] * R[2][2]);
173-
M[1][0] = -(H[1][0] * R[0][0] + H[1][1] * R[1][0] + H[1][2] * R[2][0]);
174-
M[1][1] = -(H[1][0] * R[0][1] + H[1][1] * R[1][1] + H[1][2] * R[2][1]);
175-
M[1][2] = -(H[1][0] * R[0][2] + H[1][1] * R[1][2] + H[1][2] * R[2][2]);
176-
M[2][0] = -(H[2][0] * R[0][0] + H[2][1] * R[1][0] + H[2][2] * R[2][0]);
177-
M[2][1] = -(H[2][0] * R[0][1] + H[2][1] * R[1][1] + H[2][2] * R[2][1]);
178-
M[2][2] = -(H[2][0] * R[0][2] + H[2][1] * R[1][2] + H[2][2] * R[2][2]);
170+
for (int i = 0; i < 3; i++)
171+
{
172+
for (int j = 0; j < 3; j++)
173+
{
174+
M[i][j] = -(H[i][0] * R[0][j] + H[i][1] * R[1][j] + H[i][2] * R[2][j]);
175+
}
176+
}
179177
180178
// Compute the mean coordinate over all num_points points of the molecule.
181179
float mean_coord[3];
182-
mean_coord[0] = (v[0] + v[3] + v[6]) / 3.0f;
183-
mean_coord[1] = (v[1] + v[4] + v[7]) / 3.0f;
184-
mean_coord[2] = (v[2] + v[5] + v[8]) / 3.0f;
180+
mean_coord[0] = 0.0f;
181+
mean_coord[1] = 0.0f;
182+
mean_coord[2] = 0.0f;
183+
for (int i = 0; i < num_points; i++)
184+
{
185+
mean_coord[0] += v[i * 3];
186+
mean_coord[1] += v[i * 3 + 1];
187+
mean_coord[2] += v[i * 3 + 2];
188+
}
189+
mean_coord[0] /= (float)num_points;
190+
mean_coord[1] /= (float)num_points;
191+
mean_coord[2] /= (float)num_points;
185192
186193
// Precompute mean_coord @ M (avoids redundant calculations).
187194
float mean_M[3];
188-
mean_M[0] = fmaf(mean_coord[0], M[0][0], fmaf(mean_coord[1], M[1][0], mean_coord[2] * M[2][0]));
189-
mean_M[1] = fmaf(mean_coord[0], M[0][1], fmaf(mean_coord[1], M[1][1], mean_coord[2] * M[2][1]));
190-
mean_M[2] = fmaf(mean_coord[0], M[0][2], fmaf(mean_coord[1], M[1][2], mean_coord[2] * M[2][2]));
191-
192-
// Now compute ((v - mean_coord) @ M) + mean_M.
193-
float x[3][3];
194-
x[0][0] = v[0] - mean_coord[0];
195-
x[0][1] = v[1] - mean_coord[1];
196-
x[0][2] = v[2] - mean_coord[2];
197-
x[1][0] = v[3] - mean_coord[0];
198-
x[1][1] = v[4] - mean_coord[1];
199-
x[1][2] = v[5] - mean_coord[2];
200-
x[2][0] = v[6] - mean_coord[0];
201-
x[2][1] = v[7] - mean_coord[1];
202-
x[2][2] = v[8] - mean_coord[2];
203-
204-
// Compute the rotated coordinates using fma.
205-
v[0] = fmaf(x[0][0], M[0][0], fmaf(x[0][1], M[1][0], fmaf(x[0][2], M[2][0], mean_M[0])));
206-
v[1] = fmaf(x[0][0], M[0][1], fmaf(x[0][1], M[1][1], fmaf(x[0][2], M[2][1], mean_M[1])));
207-
v[2] = fmaf(x[0][0], M[0][2], fmaf(x[0][1], M[1][2], fmaf(x[0][2], M[2][2], mean_M[2])));
208-
v[3] = fmaf(x[1][0], M[0][0], fmaf(x[1][1], M[1][0], fmaf(x[1][2], M[2][0], mean_M[0])));
209-
v[4] = fmaf(x[1][0], M[0][1], fmaf(x[1][1], M[1][1], fmaf(x[1][2], M[2][1], mean_M[1])));
210-
v[5] = fmaf(x[1][0], M[0][2], fmaf(x[1][1], M[1][2], fmaf(x[1][2], M[2][2], mean_M[2])));
211-
v[6] = fmaf(x[2][0], M[0][0], fmaf(x[2][1], M[1][0], fmaf(x[2][2], M[2][0], mean_M[0])));
212-
v[7] = fmaf(x[2][0], M[0][1], fmaf(x[2][1], M[1][1], fmaf(x[2][2], M[2][1], mean_M[1])));
213-
v[8] = fmaf(x[2][0], M[0][2], fmaf(x[2][1], M[1][2], fmaf(x[2][2], M[2][2], mean_M[2])));
195+
for (int j = 0; j < 3; j++)
196+
{
197+
mean_M[j] = fmaf(mean_coord[0], M[0][j], fmaf(mean_coord[1], M[1][j], mean_coord[2] * M[2][j]));
198+
}
199+
200+
// Compute ((v - mean_coord) @ M) + mean_M for each atom.
201+
for (int i = 0; i < num_points; i++)
202+
{
203+
float dx = v[i * 3] - mean_coord[0];
204+
float dy = v[i * 3 + 1] - mean_coord[1];
205+
float dz = v[i * 3 + 2] - mean_coord[2];
206+
207+
v[i * 3] = fmaf(dx, M[0][0], fmaf(dy, M[1][0], fmaf(dz, M[2][0], mean_M[0])));
208+
v[i * 3 + 1] = fmaf(dx, M[0][1], fmaf(dy, M[1][1], fmaf(dz, M[2][1], mean_M[1])));
209+
v[i * 3 + 2] = fmaf(dx, M[0][2], fmaf(dy, M[1][2], fmaf(dz, M[2][2], mean_M[2])));
210+
}
214211
}
215212
216213
// Update a single water.
@@ -302,7 +299,7 @@
302299
}
303300
304301
// Rotate the water randomly using pre-generated randoms.
305-
uniform_random_rotation(water,
302+
uniform_random_rotation(water, num_points,
306303
randoms_rotation[tidx * 3],
307304
randoms_rotation[tidx * 3 + 1],
308305
randoms_rotation[tidx * 3 + 2]);

0 commit comments

Comments
 (0)