Skip to content

Commit ae12135

Browse files
authored
Bug fix on KL weight for heteroscedastic classification
1 parent d95478b commit ae12135

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

vbll/layers/classification.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -525,7 +525,7 @@ def loss_fn(y):
525525
kl_term_noise = gaussian_kl(self.M, self.noise_prior_scale)
526526

527527
total_elbo = torch.mean(self.softmax_bound(x, y))
528-
total_elbo -= self.regularization_weight * (kl_term_noise + kl_term_ll)
528+
total_elbo -= self.regularization_weight * kl_term_noise + kl_term_ll
529529

530530
return -total_elbo
531531

0 commit comments

Comments
 (0)