Skip to content

Commit 48f4018

Browse files
辰言辰言
authored andcommitted
fix(machine_learning): prevent log(0) and divide-by-zero warnings/NaN in kullback_leibler_divergence
1 parent e3b01ec commit 48f4018

1 file changed

Lines changed: 16 additions & 2 deletions

File tree

machine_learning/loss_functions.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -655,12 +655,26 @@ def kullback_leibler_divergence(y_true: np.ndarray, y_pred: np.ndarray) -> float
655655
Traceback (most recent call last):
656656
...
657657
ValueError: Input arrays must have the same length.
658+
>>> # Zero values in y_true and y_pred are handled correctly without warnings
659+
>>> true_labels = np.array([0.0, 1.0])
660+
>>> predicted_probs = np.array([0.1, 0.9])
661+
>>> float(kullback_leibler_divergence(true_labels, predicted_probs))
662+
0.10536051565782635
663+
>>> true_labels = np.array([0.5, 0.5])
664+
>>> predicted_probs = np.array([0.0, 1.0])
665+
>>> float(kullback_leibler_divergence(true_labels, predicted_probs))
666+
16.576241016895395
658667
"""
659668
if len(y_true) != len(y_pred):
660669
raise ValueError("Input arrays must have the same length.")
661670

662-
kl_loss = y_true * np.log(y_true / y_pred)
663-
return np.sum(kl_loss)
671+
kl_loss = np.zeros_like(y_true, dtype=float)
672+
mask = y_true > 0
673+
if np.any(mask):
674+
kl_loss[mask] = y_true[mask] * np.log(
675+
y_true[mask] / np.clip(y_pred[mask], 1e-15, 1.0)
676+
)
677+
return float(np.sum(kl_loss))
664678

665679

666680
if __name__ == "__main__":

0 commit comments

Comments
 (0)