Commit 525a7d8

Add dense_precision distribution handling in regression models
1 parent 1b501ac commit 525a7d8

1 file changed: vbll/layers/regression.py (16 additions & 8 deletions)

@@ -53,7 +53,7 @@ class Regression(nn.Module):
     regularization_weight : float
         Weight on regularization term in ELBO
     parameterization : str
-        Parameterization of covariance matrix. Currently supports {'dense', 'diagonal', 'lowrank'}
+        Parameterization of covariance matrix. Currently supports {'dense', 'diagonal', 'lowrank', 'dense_precision'}
     prior_scale : float
         Scale of prior covariance matrix
     wishart_scale : float
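
For orientation: the docstring now lists 'dense_precision' alongside the existing options, and the constructor below dispatches on it via get_parameterization(). A hedged sketch of what that dispatch presumably looks like — the diff shows the four distribution class names but not the function body, so the import path and the exact mapping are assumptions:

    # Hypothetical sketch; only the four class names are grounded in this diff.
    from vbll.utils.distributions import (
        Normal, DenseNormal, LowRankNormal, DenseNormalPrecision)

    def get_parameterization(p):
        # 'diagonal' -> Normal matches the W_dist == Normal check in W() below.
        return {
            'diagonal': Normal,
            'dense': DenseNormal,
            'lowrank': LowRankNormal,
            'dense_precision': DenseNormalPrecision,
        }[p]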
@@ -77,27 +77,33 @@ def __init__(self,
         self.regularization_weight = regularization_weight

         # define prior, currently fixing zero mean and arbitrarily scaled cov
-        self.prior_scale = prior_scale * (2. / in_features) # kaiming init
+        self.prior_scale = prior_scale * (1. / in_features)

         # noise distribution
         self.noise_mean = nn.Parameter(torch.zeros(out_features), requires_grad = False)
-        self.noise_logdiag = nn.Parameter(torch.randn(out_features) - 1)
+        self.noise_logdiag = nn.Parameter(torch.randn(out_features) * (np.log(wishart_scale)))

         # last layer distribution
         self.W_dist = get_parameterization(parameterization)
         self.W_mean = nn.Parameter(torch.randn(out_features, in_features))
-
-        self.W_logdiag = nn.Parameter(torch.randn(out_features, in_features) - 0.5 * np.log(in_features))
-        if parameterization == 'dense':
+
+        if parameterization == 'diagonal':
+            self.W_logdiag = nn.Parameter(torch.randn(out_features, in_features) - 0.5 * np.log(in_features))
+        elif parameterization == 'dense':
+            self.W_logdiag = nn.Parameter(torch.randn(out_features, in_features) - 0.5 * np.log(in_features))
             self.W_offdiag = nn.Parameter(torch.randn(out_features, in_features, in_features)/in_features)
+        elif parameterization == 'dense_precision':
+            self.W_logdiag = nn.Parameter(torch.randn(out_features, in_features) + 0.5 * np.log(in_features))
+            self.W_offdiag = nn.Parameter(torch.randn(out_features, in_features, in_features)*0.0)
         elif parameterization == 'lowrank':
+            self.W_logdiag = nn.Parameter(torch.randn(out_features, in_features) - 0.5 * np.log(in_features))
             self.W_offdiag = nn.Parameter(torch.randn(out_features, in_features, cov_rank)/in_features)

     def W(self):
         cov_diag = torch.exp(self.W_logdiag)
         if self.W_dist == Normal:
             cov = self.W_dist(self.W_mean, cov_diag)
-        elif self.W_dist == DenseNormal:
+        elif (self.W_dist == DenseNormal) or (self.W_dist == DenseNormalPrecision):
             tril = torch.tril(self.W_offdiag, diagonal=-1) + torch.diag_embed(cov_diag)
             cov = self.W_dist(self.W_mean, tril)
         elif self.W_dist == LowRankNormal:
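
The dense_precision branch reuses the DenseNormal code path in W(): a strictly lower-triangular factor plus an exponentiated diagonal. The sign flip in the init (+0.5 * np.log(in_features) rather than -0.5) is consistent with reading that factor as the Cholesky factor of the precision rather than of the covariance: a diagonal precision of roughly in_features implies a covariance of roughly 1/in_features, matching both the dense init and the new prior_scale of 1/in_features. A minimal, self-contained sketch of the construction (the precision reading is an assumption inferred from the class name DenseNormalPrecision, not stated in the diff):

    import torch
    import numpy as np

    out_features, in_features = 2, 4

    # Initialization from the dense_precision branch above: note +0.5*log
    # (vs. -0.5*log for the covariance parameterizations) and the zeroed
    # off-diagonal, so the posterior starts exactly diagonal.
    W_logdiag = torch.randn(out_features, in_features) + 0.5 * np.log(in_features)
    W_offdiag = torch.zeros(out_features, in_features, in_features)

    # Same construction as in W(): strictly lower triangle + exp'd diagonal.
    tril = torch.tril(W_offdiag, diagonal=-1) + torch.diag_embed(torch.exp(W_logdiag))

    # Under the precision reading, tril is the Cholesky factor of P = L @ L^T,
    # so the implied covariance is P^{-1} (batched over out_features).
    cov = torch.cholesky_inverse(tril)  # shape (2, 4, 4)

Initializing W_offdiag to exact zeros also means training starts from an uncorrelated posterior, rather than with random correlations as in the dense branch.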
@@ -199,9 +205,11 @@ def __init__(self,
         self.W_logdiag = nn.Parameter(torch.randn(out_features, in_features))
         if parameterization == 'dense':
             self.W_offdiag = nn.Parameter(torch.randn(out_features, in_features, in_features))
-
         elif parameterization == 'lowrank':
             self.W_offdiag = nn.Parameter(torch.randn(out_features, in_features, cov_rank))
+        elif parameterization == 'dense_precision':
+            raise NotImplementedError()
+

     @property
     def W(self):
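
Taken together, a hedged end-to-end sketch of the new option on the first class in this file; the import path and constructor keywords below are taken from the file location and the docstring hunk above, and are otherwise unverified assumptions:

    # Hypothetical usage; only the parameters visible in this diff
    # (regularization_weight, parameterization, prior_scale, wishart_scale,
    # cov_rank) are grounded — defaults and ordering are assumptions.
    from vbll.layers.regression import Regression

    layer = Regression(
        in_features=64,
        out_features=1,
        regularization_weight=1.0 / 1000,    # e.g. 1 / dataset size
        parameterization='dense_precision',  # the option this commit adds
    )

Note that the second constructor in this file (the last hunk above) explicitly raises NotImplementedError for 'dense_precision', so callers of that class now fail fast at construction time; the new parameterization is only supported by the first Regression class for now.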
