
Commit 0a89fa8

Fixing merging issue
1 parent ad342c3 commit 0a89fa8

1 file changed

Lines changed: 1 addition & 13 deletions


src/graphnet/training/loss_functions.py

@@ -290,45 +290,35 @@ def backward(
 Mathematical Background:
 -----------------------
 For the von Mises-Fisher distribution, the gradient of log C_m(κ) with
-For the von Mises-Fisher distribution, the gradient of log C_m(κ) with
 respect to κ is given by the ratio of modified Bessel functions:
 
-
 ∂/∂κ log C_m(κ) = (m/2-1)/κ - I_{m/2}(κ)/I_{m/2-1}(κ)
 
-
 For m=3, this simplifies to the exact formula:
 ∂/∂κ log C_3(κ) = 1/κ - 1/tanh(κ)
 
-
 For small κ values, we use the Taylor series approximation:
 f(κ) = -κ/3 + κ³/45 - 2κ⁵/945 + O(κ⁷)
 
-The first-order approximation -κ/3 provides sufficient accuracy for
-
 The first-order approximation -κ/3 provides sufficient accuracy for
 |κ| < 1e-6, with truncation error bounded by |κ|³/45 ≲ O(10⁻²¹).
 
-
 Implementation Details:
 ----------------------
 Uses boolean masking to avoid double evaluation and RuntimeWarnings:
 - Small κ: |κ| < 1e-6 → gradient = -κ/3 (Taylor approximation)
 - Large κ: |κ| ≥ 1e-6 → gradient = 1/κ - 1/tanh(κ) (exact formula)
 
-
 References:
 ----------
 [1] von Mises-Fisher distribution: Wikipedia
 [2] arXiv:1812.04616, Section 8.2
 [3] MIT License (c) 2019 Max Ryabinin - Modified for GraphNeT
 
-
 Args:
     ctx: Autograd context containing saved tensors and metadata.
     grad_output: Gradient with respect to the output tensor.
 
-
 Returns:
     Tuple of gradients: (None for m, gradient w.r.t. κ).
 """
@@ -340,8 +330,7 @@ def backward(
 # Initialize gradient array
 grads = np.zeros_like(kappa)
 
-# Handle small kappa values (including zero)
-# to avoid division by zero
+# Handle small kappa values (including zero) to avoid division by zero
 small_mask = np.abs(kappa) < 1e-6
 grads[small_mask] = -kappa[small_mask] / 3
 
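The "boolean masking to avoid double evaluation and RuntimeWarnings" remark in the docstring refers to the fact that a vectorised alternative such as np.where builds both branch arrays over the full input before selecting, so 1/κ and 1/tanh(κ) would still be evaluated at κ = 0. An illustrative comparison, assuming plain NumPy (this is not code from this file):

import numpy as np

kappa = np.array([0.0, 1e-8, 0.5, 2.0])

# np.where evaluates both branches first, so the exact formula is still
# computed at kappa == 0 (divide-by-zero and invalid-value RuntimeWarnings),
# even though the Taylor value is what gets selected there.
with np.errstate(divide="ignore", invalid="ignore"):
    via_where = np.where(
        np.abs(kappa) < 1e-6,
        -kappa / 3,
        1 / kappa - 1 / np.tanh(kappa),
    )

# Boolean masking evaluates each formula only where it is valid, so no
# warnings are raised and no suppression is needed.
grads = np.zeros_like(kappa)
small_mask = np.abs(kappa) < 1e-6
grads[small_mask] = -kappa[small_mask] / 3
grads[~small_mask] = 1 / kappa[~small_mask] - 1 / np.tanh(kappa[~small_mask])

print(via_where)  # same numbers, but computing them triggered the suppressed warnings
print(grads)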

@@ -350,7 +339,6 @@ def backward(
 if np.any(large_mask):
     kappa_large = kappa[large_mask]
     grads[large_mask] = 1 / kappa_large - 1 / np.tanh(kappa_large)
-    grads[large_mask] = 1 / kappa_large - 1 / np.tanh(kappa_large)
 else:
     grads = -(
         (scipy.special.iv(m / 2.0, kappa))
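The general-m else branch is truncated in this hunk. For orientation only, here is a sketch of the ratio-of-Bessel-functions form the docstring references; it is an assumed reconstruction, not necessarily the exact expression in loss_functions.py, together with a check that it reproduces the closed-form m = 3 branch.

import numpy as np
import scipy.special

def bessel_ratio_gradient(m, kappa):
    """Gradient of log C_m(kappa) written as a ratio of modified Bessel
    functions, -I_{m/2}(kappa) / I_{m/2 - 1}(kappa) (assumed form)."""
    return -(
        scipy.special.iv(m / 2.0, kappa)
        / scipy.special.iv(m / 2.0 - 1.0, kappa)
    )

# Consistency check against the m = 3 branch above, using the identity
# I_{3/2}(k) / I_{1/2}(k) = coth(k) - 1/k:
kappa = np.array([0.5, 1.0, 5.0])
print(bessel_ratio_gradient(3, kappa))   # ≈ [-0.164, -0.313, -0.800]
print(1 / kappa - 1 / np.tanh(kappa))    # same values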
