@@ -290,45 +290,35 @@ def backward(
290290 Mathematical Background:
291291 -----------------------
292292 For the von Mises-Fisher distribution, the gradient of log C_m(κ) with
293- For the von Mises-Fisher distribution, the gradient of log C_m(κ) with
294293 respect to κ is given by the ratio of modified Bessel functions:
295294
296-
297295 ∂/∂κ log C_m(κ) = (m/2-1)/κ - I_{m/2}(κ)/I_{m/2-1}(κ)
298296
299-
300297 For m=3, this simplifies to the exact formula:
301298 ∂/∂κ log C_3(κ) = 1/κ - 1/tanh(κ)
302299
303-
304300 For small κ values, we use the Taylor series approximation:
305301 f(κ) = -κ/3 + κ³/45 - 2κ⁵/945 + O(κ⁷)
306302
307- The first-order approximation -κ/3 provides sufficient accuracy for
308-
309303 The first-order approximation -κ/3 provides sufficient accuracy for
310304 |κ| < 1e-6, with truncation error bounded by |κ|³/45 ≲ O(10⁻²¹).
311305
312-
313306 Implementation Details:
314307 ----------------------
315308 Uses boolean masking to avoid double evaluation and RuntimeWarnings:
316309 - Small κ: |κ| < 1e-6 → gradient = -κ/3 (Taylor approximation)
317310 - Large κ: |κ| ≥ 1e-6 → gradient = 1/κ - 1/tanh(κ) (exact formula)
318311
319-
320312 References:
321313 ----------
322314 [1] von Mises-Fisher distribution: Wikipedia
323315 [2] arXiv:1812.04616, Section 8.2
324316 [3] MIT License (c) 2019 Max Ryabinin - Modified for GraphNeT
325317
326-
327318 Args:
328319 ctx: Autograd context containing saved tensors and metadata.
329320 grad_output: Gradient with respect to the output tensor.
330321
331-
332322 Returns:
333323 Tuple of gradients: (None for m, gradient w.r.t. κ).
334324 """
@@ -340,8 +330,7 @@ def backward(
340330 # Initialize gradient array
341331 grads = np.zeros_like(kappa)
342332
343- # Handle small kappa values (including zero)
344- # to avoid division by zero
333+ # Handle small kappa values (including zero) to avoid division by zero
345334 small_mask = np.abs(kappa) < 1e-6
346335 grads[small_mask] = -kappa[small_mask] / 3
347336
@@ -350,7 +339,6 @@ def backward(
350339 if np.any(large_mask):
351340 kappa_large = kappa[large_mask]
352341 grads[large_mask] = 1 / kappa_large - 1 / np.tanh(kappa_large)
353- grads[large_mask] = 1 / kappa_large - 1 / np.tanh(kappa_large)
354342 else:
355343 grads = -(
356344 (scipy.special.iv(m / 2.0, kappa))
0 commit comments