Skip to content

Commit e79282e

Browse files
committed
Fix layer loading
1 parent 9e51684 commit e79282e

3 files changed

Lines changed: 4 additions & 4 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "vbll"
3-
version = "0.2.4"
3+
version = "0.2.5"
44
description = ""
55
authors = ["John Willes <johnwilles@gmail.com>"]
66
readme = "README.md"

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
setup(
55
name="vbll",
6-
version="0.2.4",
6+
version="0.2.5",
77
packages=find_packages(),
88
install_requires=["torch"],
99
)

vbll/layers/regression.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import numpy as np
22
import torch
33
from dataclasses import dataclass
4-
from vbll.utils.distributions import Normal, DenseNormal, LowRankNormal, get_parameterization
4+
from vbll.utils.distributions import Normal, DenseNormal, LowRankNormal, DenseNormalPrec, get_parameterization
55
from collections.abc import Callable
66
import torch.nn as nn
77

@@ -103,7 +103,7 @@ def W(self):
103103
cov_diag = torch.exp(self.W_logdiag)
104104
if self.W_dist == Normal:
105105
cov = self.W_dist(self.W_mean, cov_diag)
106-
elif (self.W_dist == DenseNormal) or (self.W_dist == DenseNormalPrecision):
106+
elif (self.W_dist == DenseNormal) or (self.W_dist == DenseNormalPrec):
107107
tril = torch.tril(self.W_offdiag, diagonal=-1) + torch.diag_embed(cov_diag)
108108
cov = self.W_dist(self.W_mean, tril)
109109
elif self.W_dist == LowRankNormal:

0 commit comments

Comments
 (0)