Skip to content
This repository was archived by the owner on Nov 23, 2023. It is now read-only.

Commit 5313d3e

Browse files
add velocity and multi-factorization
1 parent 3cfb0fe commit 5313d3e

2 files changed

Lines changed: 15 additions & 6 deletions

File tree

notepredictor/notepredictor/distributions.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,12 @@ def get_params(self, h):
4848

4949

5050
def forward(self, h, x):
51-
"""log prob of x under distribution parameterized by h"""
51+
"""log prob of x under distribution parameterized by h
52+
Args:
53+
h: Tensor[...,n_params]
54+
x: Tensor[...]
55+
"..." dims must broadcast
56+
"""
5257
log_pi, loc, s = self.get_params(h)
5358

5459
d = self.res/2
@@ -59,8 +64,8 @@ def forward(self, h, x):
5964
# # censoring
6065
lo_cens = x <= self.lo+d
6166
hi_cens = x >= self.hi-d
62-
ones = torch.ones_like(x_)
63-
zeros = torch.zeros_like(x_)
67+
ones = torch.ones_like(s)
68+
zeros = torch.zeros_like(s)
6469

6570
diff_term = torch.where(lo_cens | hi_cens,
6671
ones, sd.exp() - (-sd).exp()

notepredictor/notepredictor/model.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -206,16 +206,20 @@ def mask_cat(missing, present, mask):
206206
# chunk into 7 and discard unmasked positions;
207207
# stack the masked positions along new first dim
208208
pitch_params, time_params, vel_params = (
209-
torch.cat([
210-
ch[None].expand(m, -1, -1, -1)
209+
torch.stack([
210+
ch
211211
for m,ch in zip(mask, dp.chunk(7, 0)) if m>0
212212
], 0)
213213
for mask,dp in zip(masks, dist_params)
214214
)
215215

216+
#TODO: weighting
217+
# weights = np.log([[m for m in mask if m>0] for mask in masks]) # 3 x 4
218+
216219
# get likelihoods
217220
pitch_logits = F.log_softmax(pitch_params, -1)
218-
pitch_targets = pitches[None,:,1:,None] #1, batch, time-1, 1
221+
# TODO: is gather working right with extra dim?
222+
pitch_targets = pitches[None,:,1:,None].expand(4, -1, -1, -1) #1, batch, time-1, 1
219223
pitch_log_probs = pitch_logits.gather(-1, pitch_targets)[...,0]
220224

221225
time_targets = times[:,1:]# batch, time-1

0 commit comments

Comments
 (0)