We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 45221f9 commit d95478bCopy full SHA for d95478b
1 file changed
vbll/layers/regression.py
@@ -404,7 +404,7 @@ def loss_fn(y):
404
# compute expected KL
405
kl_term_ll = torch.mean(grad_correction * expected_gaussian_kl(W, self.prior_scale, expect_sigma_inv))
406
kl_term_noise = torch.mean(grad_correction * gaussian_kl(M, self.noise_prior_scale))
407
- total_elbo -= self.regularization_weight * (kl_term_noise + kl_term_ll)
+ total_elbo -= self.regularization_weight * kl_term_noise + kl_term_ll
408
return -total_elbo
409
410
return loss_fn
0 commit comments