@@ -99,8 +99,13 @@ def forward(self, ctx, h_ctx, h_tgt):
9999 h_tgt: list of Tensor[batch x time x input_size], length note_dim
100100 these are projections of the RNN state
101101 """
102- h_tgt = list (h_tgt )
103- ctx = list (ctx )
102+ # h_tgt = list(h_tgt)
103+ # ctx = list(ctx)
104+
105+ # explicitly broadcast
106+ h_ctx , * ctx = torch .broadcast_tensors (h_ctx , * ctx )
107+ h_ctx , * h_tgt = torch .broadcast_tensors (h_ctx , * h_tgt )
108+
104109 # h_tgt is 'target' w.r.t TransformerDecoder
105110 # h_ctx and context are 'memory'
106111 batch_size = h_ctx .shape [0 ]* h_ctx .shape [1 ]
@@ -193,20 +198,35 @@ def cell_state_names(self):
def cell_state(self):
    """Gather the current recurrent cell-state tensors.

    Returns:
        tuple of the attributes named by ``cell_state_names()``, in the
        same order that method lists them.
    """
    names = self.cell_state_names()
    return tuple(getattr(self, name) for name in names)
195200
def get_samplers(self,
        pitch_topk=None, index_pitch=None, allow_start=False, allow_end=False,
        sweep_time=False, sweep_n=9):
    """Build the per-modality sampling functions.

    Args:
        pitch_topk: if not None, return the top-k pitches by logit
            (moved onto the leading, batch-like dim) instead of sampling.
        index_pitch: if not None, deterministically return the pitch at
            this rank of the descending-sorted logits instead of sampling.
        allow_start: if False, zero probability for the start token.
        allow_end: if False, zero probability for the end token.
        sweep_time: if True, instead of sampling time, return a set of
            mixture-component locations stacked along the batch dim.
        sweep_n: number of mixture components kept when sweep_time is
            True (default 9, the previously hard-coded value).

    Returns:
        (sample_pitch, sample_time, sample_vel) callables, each mapping
        distribution parameters to a sampled/selected value.
    """
    def sample_pitch(x):
        # NOTE: masks disallowed tokens by mutating x in place
        if not allow_start:
            x[..., self.start_token] = -np.inf
        if not allow_end:
            x[..., self.end_token] = -np.inf
        if index_pitch is not None:
            return x.argsort(-1, True)[..., index_pitch]
        elif pitch_topk is not None:
            # top-k pitches, transposed onto the leading dim
            return x.argsort(-1, True)[..., :pitch_topk].transpose(0, -1)
        else:
            return D.Categorical(logits=x).sample()

    def sample_time(x):
        if sweep_time:
            assert x.shape[0] == 1, "batch size should be 1 here"
            log_pi, loc, s = self.time_dist.get_params(x)
            # NOTE(review): ascending argsort keeps the sweep_n *lowest*
            # mixture weights — confirm whether the most-likely components
            # were intended (argsort(descending=True) or [-sweep_n:])
            idx = log_pi.squeeze().argsort()[:sweep_n]
            # multiple times in the batch dim, sorted ascending
            loc = loc.squeeze()[idx].sort().values[..., None]
            return loc
        else:
            return self.time_dist.sample(x)

    return (
        sample_pitch,
        sample_time,
        lambda x: self.vel_dist.sample(x),
    )
212232
@@ -289,9 +309,10 @@ def forward(self, pitches, times, velocities, validation=False):
289309 def predict (self ,
290310 pitch , time , vel ,
291311 fix_pitch = None , fix_time = None , fix_vel = None ,
292- index_pitch = None , allow_start = False , allow_end = False ):
312+ pitch_topk = None , index_pitch = None , allow_start = False , allow_end = False ,
313+ sweep_time = False ):
293314 """
294- supply the most recent note and return a prediction for the next note.
315+ consume the most recent note and return a prediction for the next note.
295316
296317 various constraints can be enforced on the next note.
297318
@@ -304,6 +325,8 @@ def predict(self,
304325 most likely pitch instead of sampling.
305326 allow_start: if False, zero probability for sampling the start token
306327 allow_end: if False, zero probability for sampling the end token
328+ sweep_time: if True, instead of sampling time, choose a diverse set of
329+ times and stack along the batch dimension
307330
308331 Returns: dict of
309332 'pitch': int. predicted MIDI number of next note.
@@ -332,7 +355,8 @@ def predict(self,
332355
333356 modalities = list (zip (
334357 self .projections ,
335- self .get_samplers (index_pitch , allow_start , allow_end ),
358+ self .get_samplers (
359+ pitch_topk , index_pitch , allow_start , allow_end , sweep_time ),
336360 self .embeddings ,
337361 ))
338362
@@ -348,15 +372,21 @@ def predict(self,
348372
349373 # permute h_tgt, embs, modalities
350374 # if any modalities are determined, embed them
351- det_idx , undet_idx = [], []
376+ # sort constrained modalities before unconstrained
376+ # TODO: option to skip modalities
377+ det_idx , cons_idx , uncons_idx = [], [], []
352378 for i ,(item , embed ) in enumerate (zip (fix , self .embeddings )):
353379 if item is None :
354- undet_idx .append (i )
380+ if (i == 1 and sweep_time ) or (i == 0 and pitch_topk ):
381+ cons_idx .append (i )
382+ else :
383+ uncons_idx .append (i )
355384 else :
356385 det_idx .append (i )
357386 context .append (embed (item ))
358- predicted .append (item . item () )
387+ predicted .append (item )
359388 params .append (None )
389+ undet_idx = cons_idx + uncons_idx
360390 perm = det_idx + undet_idx # permutation from the canonical order
361391 iperm = np .argsort (perm ) # inverse permutation back to canonical order
362392
@@ -378,16 +408,34 @@ def predict(self,
378408 hidden = self .xformer (context , h_ctx , perm_h_tgt [:j + 1 ])[j ]
379409 params .append (project (hidden ))
380410 pred = sample (params [- 1 ])
381- predicted .append (pred . item () )
411+ predicted .append (pred )
382412 # prepare for next iteration
383413 if len (undet_idx ):
384414 context .append (embed (pred ))
385415 det_idx .append (i )
386416
417+
418+ pred_pitch = predicted [iperm [0 ]]
419+ pred_time = predicted [iperm [1 ]]
420+ pred_vel = predicted [iperm [2 ]]
421+
422+ print (pred_time .shape )
423+ print (pred_pitch .shape )
424+ print (pred_vel .shape )
425+
426+ if sweep_time or pitch_topk :
427+ pred_pitch = [x .item () for x in pred_pitch ]
428+ pred_time = [x .item () for x in pred_time ]
429+ pred_vel = [x .item () for x in pred_vel ]
430+ print (pred_time , pred_pitch , pred_vel )
431+ else :
432+ pred_pitch = pred_pitch .item ()
433+ pred_time = pred_time .item ()
434+ pred_vel = pred_vel .item ()
387435 return {
388- 'pitch' : predicted [ iperm [ 0 ]] ,
389- 'time' : predicted [ iperm [ 1 ]] ,
390- 'velocity' : predicted [ iperm [ 2 ]] ,
436+ 'pitch' : pred_pitch ,
437+ 'time' : pred_time ,
438+ 'velocity' : pred_vel ,
391439 'pitch_params' : params [iperm [0 ]],
392440 'time_params' : params [iperm [1 ]],
393441 'vel_params' : params [iperm [2 ]]
0 commit comments