Skip to content
This repository was archived by the owner on Nov 23, 2023. It is now read-only.

Commit bba7a2b

Browse files
comments
1 parent 7dc638f commit bba7a2b

1 file changed

Lines changed: 8 additions & 17 deletions

File tree

notepredictor/notepredictor/model.py

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -242,11 +242,6 @@ def forward(self, pitches, times, velocities, validation=False):
242242
for t in self.initial_state)
243243
h, _ = self.rnn(x, initial_state) #batch, time, hidden_size
244244

245-
# include initial hidden state for predicting first note
246-
# h = torch.cat((
247-
# self.initial_state[0][-1][None].expand(batch_size, 1, -1),
248-
# h), -2)
249-
250245
# fit all note factorizations at once.
251246
# TODO: perm each batch item independently?
252247
perm = torch.randperm(self.note_dim)
@@ -273,9 +268,6 @@ def forward(self, pitches, times, velocities, validation=False):
273268
vel_result = self.vel_dist(vel_params, vel_targets)
274269
vel_log_probs = vel_result.pop('log_prob')
275270

276-
# should reduce over chunk dim with logsumexp?
277-
# i.e. average likelihood over factorizations, not LL?
278-
279271
r = {
280272
'pitch_log_probs': pitch_log_probs,
281273
'time_log_probs': time_log_probs,
@@ -344,9 +336,9 @@ def predict(self,
344336
self.embeddings,
345337
))
346338

347-
context = []
348-
predicted = []
349-
params = []
339+
context = [] # embedded outputs for autoregressive prediction
340+
predicted = [] # raw outputs
341+
params = [] # distribution parameters for visualization
350342

351343
fix = [
352344
None if item is None else torch.tensor([[item]], dtype=dtype)
@@ -355,7 +347,7 @@ def predict(self,
355347
[torch.long, torch.float, torch.float])]
356348

357349
# permute h_tgt, embs, modalities
358-
# if any modalities are determined, embed them;
350+
# if any modalities are determined, embed them
359351
det_idx, undet_idx = [], []
360352
for i,(item, embed) in enumerate(zip(fix, self.embeddings)):
361353
if item is None:
@@ -365,10 +357,8 @@ def predict(self,
365357
context.append(embed(item))
366358
predicted.append(item.item())
367359
params.append(None)
368-
perm = det_idx + undet_idx
369-
iperm = np.argsort(perm)
370-
371-
perm_h_tgt = [h_tgt[i] for i in perm]
360+
perm = det_idx + undet_idx # permutation from the canonical order
361+
iperm = np.argsort(perm) # inverse permutation back to canonical order
372362

373363
# for each undetermined modality,
374364
# sample a new value conditioned on already determined ones
@@ -378,7 +368,8 @@ def predict(self,
378368
# constraints can be:
379369
# discrete set, in which case evaluate probs and then sample categorical
380370
# range, in which case truncate
381-
371+
372+
perm_h_tgt = [h_tgt[i] for i in perm]
382373
while len(undet_idx):
383374
i = undet_idx.pop(0) # index of modality to determine
384375
j = len(det_idx) # number already determined

0 commit comments

Comments (0)