 from collections.abc import Callable
 import abc
 import warnings
+import fannypack.utils as fu
 
 def get_parameterization(p):
     if p in cov_param_dict:
@@ -139,6 +140,80 @@ def __matmul__(self, inp): |
     def squeeze(self, idx):
         return DenseNormal(self.loc.squeeze(idx), self.scale_tril.squeeze(idx))
 
+class DenseNormalPrec(torch.distributions.MultivariateNormal):
+    """A DenseNormal parameterized by the mean and the Cholesky factor of the precision matrix.
+
+    This class also provides a recursive_update method, which performs a recursive
+    linear regression update with efficient Cholesky factor updates.
+    """
+    def __init__(self, loc, cholesky, validate_args=False):
+        # cholesky is the lower-triangular factor L of the precision, so prec = L @ L.T
+        prec = cholesky @ tp(cholesky)
+        super(DenseNormalPrec, self).__init__(loc, precision_matrix=prec, validate_args=validate_args)
+        self.tril = cholesky
+
+    @property
+    def mean(self):
+        return self.loc
+
+    @property
+    def chol_covariance(self):
+        raise NotImplementedError()
+
+    @property
+    def covariance(self):
+        warnings.warn("Direct matrix inversion for dense covariances is O(N^3); consider using e.g. covariance_weighted_inner_prod instead.")
+        # TODO: replace fu.cholesky_inverse with torch.cholesky_inverse
+        return fu.cholesky_inverse(self.tril)
+
+    @property
+    def inverse_covariance(self):
+        return self.precision_matrix
+
+    @property
+    def logdet_covariance(self):
+        # logdet(cov) = -logdet(prec) = -2 * sum(log(diag(L))), where L is the precision factor
+        return -2. * torch.diagonal(self.tril, dim1=-2, dim2=-1).log().sum(-1)
+
+    @property
+    def trace_covariance(self):
+        # trace(cov) = ||L^{-1}||_F^2 since cov = (L @ L.T)^{-1}, so compute as squared Frobenius norm
+        return (torch.inverse(self.tril)**2).sum(-1).sum(-1)
+
+    def covariance_weighted_inner_prod(self, b, reduce_dim=True):
+        assert b.shape[-1] == 1
+        # b.T @ cov @ b = ||L^{-1} @ b||^2, computed via a triangular solve
+        prod = (torch.linalg.solve(self.tril, b)**2).sum(-2)
+        return prod.squeeze(-1) if reduce_dim else prod
+
+    def precision_weighted_inner_prod(self, b, reduce_dim=True):
+        assert b.shape[-1] == 1
+        # b.T @ prec @ b = ||L.T @ b||^2
+        prod = ((tp(self.tril) @ b)**2).sum(-2)
+        return prod.squeeze(-1) if reduce_dim else prod
+
+    def __matmul__(self, inp):
+        assert inp.shape[-2] == self.loc.shape[-1]
+        assert inp.shape[-1] == 1
+        new_cov = self.covariance_weighted_inner_prod(inp.unsqueeze(-3), reduce_dim=False)
+        return Normal(self.loc @ inp, torch.sqrt(torch.clip(new_cov, min=1e-12)))
+
+    def squeeze(self, idx):
+        return DenseNormalPrec(self.loc.squeeze(idx), self.tril.squeeze(idx))
+
+    def recursive_update(self, X, y, noise_cov):
+        noise_cov = noise_cov.unsqueeze(-1)
+        prec = self.inverse_covariance  # out_dim * feat_dim * feat_dim
+        chol = self.tril
+
+        XTy = (tp(y) @ X) / noise_cov  # out_dim * feat_dim
+
+        # recursively update the Cholesky factor of the precision, one datapoint at a time
+        for i in range(X.shape[0]):
+            x = X[i].unsqueeze(-2) / torch.sqrt(noise_cov)  # out_dim * feat_dim
+            chol = fu.cholupdate(chol, x)
+
+        # right-hand side of the posterior mean equation: prec @ loc + X.T @ y / noise
+        cov_update = (prec @ self.loc.unsqueeze(-1))  # out_dim * feat_dim * 1
+        cov_update += XTy.unsqueeze(-1)  # out_dim * feat_dim * 1
+        new_loc = (fu.cholesky_inverse(chol) @ cov_update).squeeze(-1)  # out_dim * feat_dim
+
+        return chol, new_loc
+
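The loop above implements the standard recursive Bayesian linear-regression update: each datapoint adds a rank-one term x x^T / noise to the precision via fu.cholupdate, and the new mean solves new_prec @ new_loc = prec @ loc + X^T y / noise. A minimal single-output reference in plain PyTorch (illustrative only, not part of this commit; reference_update and its unbatched shapes are my assumptions):

import torch

def reference_update(chol_prec, loc, X, y, noise_var):
    # prior precision P = L @ L.T; the data add X.T @ X / noise_var
    new_prec = chol_prec @ chol_prec.T + X.T @ X / noise_var
    new_chol = torch.linalg.cholesky(new_prec)
    # posterior mean solves: new_prec @ new_loc = P @ loc + X.T @ y / noise_var
    rhs = chol_prec @ (chol_prec.T @ loc) + X.T @ y / noise_var
    new_loc = torch.cholesky_solve(rhs.unsqueeze(-1), new_chol).squeeze(-1)
    return new_chol, new_loc

The cholupdate path in the method avoids this full re-factorization, costing O(d^2) per datapoint rather than O(d^3) per batch.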
 
 class LowRankNormal(torch.distributions.LowRankMultivariateNormal):
     def __init__(self, loc, cov_factor, diag):
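Stepping back from the diff, a quick usage sketch for the new DenseNormalPrec class (illustrative; the tolerance and the assumption that fu.cholesky_inverse behaves like torch.cholesky_inverse are mine):

import torch

feat_dim = 4
# a well-conditioned lower-triangular Cholesky factor of the precision
L = torch.tril(torch.randn(feat_dim, feat_dim)) + 3. * torch.eye(feat_dim)
dist = DenseNormalPrec(torch.zeros(feat_dim), L)

# the covariance property should match a direct inverse of the precision
assert torch.allclose(dist.covariance, torch.inverse(L @ L.T), atol=1e-5)

# __matmul__ returns the predictive Normal of the linear functional loc @ x
x = torch.randn(feat_dim, 1)
pred = dist @ x  # Normal(loc @ x, sqrt(x.T @ cov @ x))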
@@ -198,6 +273,7 @@ def squeeze(self, idx): |
 
 cov_param_dict = {
     'dense': DenseNormal,
+    'dense_precision': DenseNormalPrec,
     'diagonal': Normal,
     'lowrank': LowRankNormal
 }
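With the new entry registered, the parameterization is reachable by name through get_parameterization (illustrative, assuming the lookup returns the dict entry as the snippet at the top suggests):

cls = get_parameterization('dense_precision')  # -> DenseNormalPrec
dist = cls(torch.zeros(3), torch.eye(3))       # identity precision factor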