Skip to content

Commit d95478b

Browse files
authored
Bug fix on KL weight for heteroscedastic
1 parent 45221f9 commit d95478b

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

vbll/layers/regression.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -404,7 +404,7 @@ def loss_fn(y):
404404
# compute expected KL
405405
kl_term_ll = torch.mean(grad_correction * expected_gaussian_kl(W, self.prior_scale, expect_sigma_inv))
406406
kl_term_noise = torch.mean(grad_correction * gaussian_kl(M, self.noise_prior_scale))
407-
total_elbo -= self.regularization_weight * (kl_term_noise + kl_term_ll)
407+
total_elbo -= self.regularization_weight * kl_term_noise + kl_term_ll
408408
return -total_elbo
409409

410410
return loss_fn

0 commit comments

Comments
 (0)