@@ -53,7 +53,7 @@ class Regression(nn.Module):
5353 regularization_weight : float
5454 Weight on regularization term in ELBO
5555 parameterization : str
56- Parameterization of covariance matrix. Currently supports {'dense', 'diagonal', 'lowrank'}
56+ Parameterization of covariance matrix. Currently supports {'dense', 'diagonal', 'lowrank', 'dense_precision'}
5757 prior_scale : float
5858 Scale of prior covariance matrix
5959 wishart_scale : float
@@ -77,27 +77,33 @@ def __init__(self,
7777 self .regularization_weight = regularization_weight
7878
7979 # define prior, currently fixing zero mean and arbitrarily scaled cov
80- self .prior_scale = prior_scale * (2 . / in_features ) # kaiming init
80+ self .prior_scale = prior_scale * (1 . / in_features )
8181
8282 # noise distribution
8383 self .noise_mean = nn .Parameter (torch .zeros (out_features ), requires_grad = False )
84- self .noise_logdiag = nn .Parameter (torch .randn (out_features ) - 1 )
84+ self .noise_logdiag = nn .Parameter (torch .randn (out_features ) * ( np . log ( wishart_scale )) )
8585
8686 # last layer distribution
8787 self .W_dist = get_parameterization (parameterization )
8888 self .W_mean = nn .Parameter (torch .randn (out_features , in_features ))
89-
90- self .W_logdiag = nn .Parameter (torch .randn (out_features , in_features ) - 0.5 * np .log (in_features ))
91- if parameterization == 'dense' :
89+
90+ if parameterization == 'diagonal' :
91+ self .W_logdiag = nn .Parameter (torch .randn (out_features , in_features ) - 0.5 * np .log (in_features ))
92+ elif parameterization == 'dense' :
93+ self .W_logdiag = nn .Parameter (torch .randn (out_features , in_features ) - 0.5 * np .log (in_features ))
9294 self .W_offdiag = nn .Parameter (torch .randn (out_features , in_features , in_features )/ in_features )
95+ elif parameterization == 'dense_precision' :
96+ self .W_logdiag = nn .Parameter (torch .randn (out_features , in_features ) + 0.5 * np .log (in_features ))
97+ self .W_offdiag = nn .Parameter (torch .randn (out_features , in_features , in_features )* 0.0 )
9398 elif parameterization == 'lowrank' :
99+ self .W_logdiag = nn .Parameter (torch .randn (out_features , in_features ) - 0.5 * np .log (in_features ))
94100 self .W_offdiag = nn .Parameter (torch .randn (out_features , in_features , cov_rank )/ in_features )
95101
96102 def W (self ):
97103 cov_diag = torch .exp (self .W_logdiag )
98104 if self .W_dist == Normal :
99105 cov = self .W_dist (self .W_mean , cov_diag )
100- elif self .W_dist == DenseNormal :
106+ elif ( self .W_dist == DenseNormal ) or ( self . W_dist == DenseNormalPrecision ) :
101107 tril = torch .tril (self .W_offdiag , diagonal = - 1 ) + torch .diag_embed (cov_diag )
102108 cov = self .W_dist (self .W_mean , tril )
103109 elif self .W_dist == LowRankNormal :
@@ -199,9 +205,11 @@ def __init__(self,
199205 self .W_logdiag = nn .Parameter (torch .randn (out_features , in_features ))
200206 if parameterization == 'dense' :
201207 self .W_offdiag = nn .Parameter (torch .randn (out_features , in_features , in_features ))
202-
203208 elif parameterization == 'lowrank' :
204209 self .W_offdiag = nn .Parameter (torch .randn (out_features , in_features , cov_rank ))
210+ elif parameterization == 'dense_precision' :
211+ raise NotImplementedError ()
212+
205213
206214 @property
207215 def W (self ):
0 commit comments