Skip to content

Commit 1b501ac

Browse files
authored
Add precision-parameterized dense normal distribution
1 parent 140568c commit 1b501ac

1 file changed

Lines changed: 76 additions & 0 deletions

File tree

vbll/utils/distributions.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from collections.abc import Callable
77
import abc
88
import warnings
9+
import fannypack.utils as fu
910

1011
def get_parameterization(p):
1112
if p in cov_param_dict:
@@ -139,6 +140,80 @@ def __matmul__(self, inp):
139140
def squeeze(self, idx):
140141
return DenseNormal(self.loc.squeeze(idx), self.scale_tril.squeeze(idx))
141142

143+
class DenseNormalPrec(torch.distributions.MultivariateNormal):
    """A multivariate normal parameterized by the mean and the Cholesky
    factor of the precision matrix.

    This class also includes a ``recursive_update`` method which performs a
    recursive linear regression update with efficient Cholesky factor updates.

    Args:
        loc: Mean of the distribution.
        cholesky: Lower-triangular Cholesky factor ``L`` of the precision
            matrix, i.e. ``precision = L @ L.T``.
        validate_args: Forwarded to ``MultivariateNormal``.
    """
    def __init__(self, loc, cholesky, validate_args=False):
        # Reconstruct the full precision matrix for the base-class init.
        prec = cholesky @ tp(cholesky)
        super(DenseNormalPrec, self).__init__(loc, precision_matrix=prec, validate_args=validate_args)
        # NOTE: self.tril is the Cholesky factor of the *precision*, not the
        # covariance (unlike `scale_tril` on the base class).
        self.tril = cholesky

    @property
    def mean(self):
        return self.loc

    @property
    def chol_covariance(self):
        # The covariance Cholesky factor is not available in this
        # parameterization without an explicit (expensive) inversion.
        raise NotImplementedError()

    @property
    def covariance(self):
        warnings.warn("Direct matrix inverse for dense covariances is O(N^3), consider using eg inverse weighted inner product")
        # cov = (L L^T)^{-1}; torch.cholesky_inverse computes this directly
        # from the precision Cholesky factor (resolves the old TODO).
        return torch.cholesky_inverse(self.tril)

    @property
    def inverse_covariance(self):
        return self.precision_matrix

    @property
    def logdet_covariance(self):
        # logdet(cov) = -logdet(precision) = -2 * sum(log(diag(L_prec))).
        # Fixed: previously read `self.scale_tril` (covariance factor), which
        # both triggered an expensive precision->covariance conversion and
        # carried the wrong sign for that factor.
        return -2. * torch.diagonal(self.tril, dim1=-2, dim2=-1).log().sum(-1)

    @property
    def trace_covariance(self):
        # trace(cov) = ||L^{-1}||_F^2, computed as the squared Frobenius norm.
        return (torch.inverse(self.tril)**2).sum(-1).sum(-1)

    def covariance_weighted_inner_prod(self, b, reduce_dim=True):
        """Compute b^T cov b = ||L^{-1} b||^2 without forming the covariance."""
        assert b.shape[-1] == 1
        prod = (torch.linalg.solve(self.tril, b)**2).sum(-2)
        return prod.squeeze(-1) if reduce_dim else prod

    def precision_weighted_inner_prod(self, b, reduce_dim=True):
        """Compute b^T prec b = ||L^T b||^2."""
        assert b.shape[-1] == 1
        prod = ((tp(self.tril) @ b)**2).sum(-2)
        return prod.squeeze(-1) if reduce_dim else prod

    def __matmul__(self, inp):
        """Push the distribution through a fixed column vector, giving a Normal."""
        assert inp.shape[-2] == self.loc.shape[-1]
        assert inp.shape[-1] == 1
        new_cov = self.covariance_weighted_inner_prod(inp.unsqueeze(-3), reduce_dim = False)
        # Clip so sqrt never sees a (numerically) non-positive variance.
        return Normal(self.loc @ inp, torch.sqrt(torch.clip(new_cov, min = 1e-12)))

    def squeeze(self, idx):
        # Fixed: previously returned `DenseNormalPrecision`, an undefined name
        # (NameError); the class is `DenseNormalPrec`.
        return DenseNormalPrec(self.loc.squeeze(idx), self.tril.squeeze(idx))

    def recursive_update(self, X, y, noise_cov):
        """Recursive Bayesian linear-regression update.

        Args:
            X: Feature matrix (num_points x feat_dim) — TODO confirm layout.
            y: Targets (num_points x out_dim) — TODO confirm layout.
            noise_cov: Per-output observation noise variances.

        Returns:
            Tuple ``(chol, new_loc)``: updated precision Cholesky factor and
            updated posterior mean.
        """
        noise_cov = noise_cov.unsqueeze(-1)
        prec = self.inverse_covariance  # out_dim * feat_dim * feat_dim
        chol = self.tril

        XTy = (tp(y) @ X) / noise_cov  # out_dim * feat_dim

        # Recursively fold each data point into the Cholesky factor
        # (rank-1 update, avoids refactorizing the full precision matrix).
        for i in range(X.shape[0]):
            x = X[i].unsqueeze(-2) / torch.sqrt(noise_cov)  # out_dim * feat_dim
            chol = fu.cholupdate(chol, x)

        cov_update = (prec @ self.loc.unsqueeze(-1))  # out_dim * feat_dim * 1
        cov_update += XTy.unsqueeze(-1)  # out_dim * feat_dim * 1
        new_loc = (torch.cholesky_inverse(chol) @ cov_update).squeeze(-1)  # out_dim * feat_dim

        return chol, new_loc
216+
142217

143218
class LowRankNormal(torch.distributions.LowRankMultivariateNormal):
144219
def __init__(self, loc, cov_factor, diag):
@@ -198,6 +273,7 @@ def squeeze(self, idx):
198273

199274
# Maps a covariance-parameterization name (as accepted by
# get_parameterization) to the distribution class implementing it.
cov_param_dict = {
    'dense': DenseNormal,
    'dense_precision': DenseNormalPrec,
    'diagonal': Normal,
    'lowrank': LowRankNormal
}

0 commit comments

Comments
 (0)