@@ -242,11 +242,6 @@ def forward(self, pitches, times, velocities, validation=False):
242242 for t in self .initial_state )
243243 h , _ = self .rnn (x , initial_state ) #batch, time, hidden_size
244244
245- # include initial hidden state for predicting first note
246- # h = torch.cat((
247- # self.initial_state[0][-1][None].expand(batch_size, 1, -1),
248- # h), -2)
249-
250245 # fit all note factorizations at once.
251246 # TODO: perm each batch item independently?
252247 perm = torch .randperm (self .note_dim )
@@ -273,9 +268,6 @@ def forward(self, pitches, times, velocities, validation=False):
273268 vel_result = self .vel_dist (vel_params , vel_targets )
274269 vel_log_probs = vel_result .pop ('log_prob' )
275270
276- # should reduce over chunk dim with logsumexp?
277- # i.e. average likelihood over factorizations, not LL?
278-
279271 r = {
280272 'pitch_log_probs' : pitch_log_probs ,
281273 'time_log_probs' : time_log_probs ,
@@ -344,9 +336,9 @@ def predict(self,
344336 self .embeddings ,
345337 ))
346338
347- context = []
348- predicted = []
349- params = []
339+ context = [] # embedded outputs for autoregressive prediction
340+ predicted = [] # raw outputs
341+ params = [] # distribution parameters for visualization
350342
351343 fix = [
352344 None if item is None else torch .tensor ([[item ]], dtype = dtype )
@@ -355,7 +347,7 @@ def predict(self,
355347 [torch .long , torch .float , torch .float ])]
356348
357349 # permute h_tgt, embs, modalities
358- # if any modalities are determined, embed them;
350+ # if any modalities are determined, embed them
359351 det_idx , undet_idx = [], []
360352 for i ,(item , embed ) in enumerate (zip (fix , self .embeddings )):
361353 if item is None :
@@ -365,10 +357,8 @@ def predict(self,
365357 context .append (embed (item ))
366358 predicted .append (item .item ())
367359 params .append (None )
368- perm = det_idx + undet_idx
369- iperm = np .argsort (perm )
370-
371- perm_h_tgt = [h_tgt [i ] for i in perm ]
360+ perm = det_idx + undet_idx # permutation from the canonical order
361+ iperm = np .argsort (perm ) # inverse permutation back to canonical order
372362
373363 # for each undetermined modality,
374364 # sample a new value conditioned on alteady determined ones
@@ -378,7 +368,8 @@ def predict(self,
378368 # constraints can be:
379369 # discrete set, in which case evaluate probs and then sample categorical
380370 # range, in which case truncate
381-
371+
372+ perm_h_tgt = [h_tgt [i ] for i in perm ]
382373 while len (undet_idx ):
383374 i = undet_idx .pop (0 ) # index of modality to determine
384375 j = len (det_idx ) # number already determined
0 commit comments