Skip to content

Commit e84bf65

Browse files
authored
bug fix kl weight, regression
1 parent 8bc0dee commit e84bf65

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

vbll/layers/regression.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -404,7 +404,7 @@ def loss_fn(y):
404404
# compute expected KL
405405
kl_term_ll = torch.mean(grad_correction * expected_gaussian_kl(W, self.prior_scale, expect_sigma_inv))
406406
kl_term_noise = torch.mean(grad_correction * gaussian_kl(M, self.noise_prior_scale))
407-
total_elbo -= self.regularization_weight * kl_term_noise + kl_term_ll
407+
total_elbo -= self.regularization_weight * kl_term_ll + kl_term_noise
408408
return -total_elbo
409409

410410
return loss_fn

0 commit comments

Comments
 (0)