Skip to content

Commit 64cc96f

Browse files
committed
Tweak inits
1 parent 4621dee commit 64cc96f

3 files changed

Lines changed: 6 additions & 5 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "vbll"
3-
version = "0.3.0"
3+
version = "0.3.1"
44
description = ""
55
authors = ["John Willes <johnwilles@gmail.com>"]
66
readme = "README.md"

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
setup(
55
name="vbll",
6-
version="0.3.0",
6+
version="0.3.1",
77
packages=find_packages(),
88
install_requires=["torch"],
99
)

vbll/layers/classification.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ def __init__(self,
219219
return_ood=False,
220220
prior_scale=1.,
221221
wishart_scale=1.,
222-
dof=1.,
222+
dof=2.,
223223
cov_rank=None,
224224
kn_alpha=None,
225225
):
@@ -232,7 +232,8 @@ def __init__(self,
232232
# define prior, currently fixing zero mean and arbitrarily scaled cov
233233
self.prior_dof = dof
234234
self.prior_rate = 1./wishart_scale
235-
self.prior_scale = prior_scale * (2. / in_features) # kaiming init
235+
exp_cov = self.prior_rate/(self.prior_dof - 1)
236+
self.prior_scale = prior_scale * 2. / (exp_cov * in_features)
236237

237238
# variational posterior over noise params
238239
self.noise_log_dof = nn.Parameter(torch.ones(out_features) * np.log(self.prior_dof))
@@ -259,7 +260,7 @@ def __init__(self,
259260
elif softmax_bound == 'reduced_kn':
260261
self.softmax_bound = self.reduced_kn
261262
if kn_alpha is None:
262-
self.alpha = nn.Parameter(0.1 * torch.randn(out_features))
263+
self.alpha = nn.Parameter(0.0 * torch.ones(out_features))
263264
else:
264265
self.alpha = nn.Parameter(torch.ones(out_features) * kn_alpha, requires_grad=False)
265266
else:

0 commit comments

Comments
 (0)