@@ -200,7 +200,8 @@ def cell_state(self):
200200
201201 def get_samplers (self ,
202202 pitch_topk = None , index_pitch = None , allow_start = False , allow_end = False ,
203- sweep_time = False ):
203+ sweep_time = False , trunc_time = None ):
204+
204205 def sample_pitch (x ):
205206 if not allow_start :
206207 x [...,self .start_token ] = - np .inf
@@ -214,15 +215,20 @@ def sample_pitch(x):
214215 return D .Categorical (logits = x ).sample ()
215216
216217 def sample_time (x ):
218+ # TODO: respect trunc_time when sweep_time is True
217219 if sweep_time :
220+ if trunc_time is not None :
221+ raise NotImplementedError ("""
222+ trunc_time with sweep_time needs implementation
223+ """ )
218224 assert x .shape [0 ]== 1 , "batch size should be 1 here"
219225 log_pi , loc , s = self .time_dist .get_params (x )
220226 idx = log_pi .squeeze ().argsort ()[:9 ]
221227 loc = loc .squeeze ()[idx ].sort ().values [...,None ] # multiple times in batch dim
222228 # print(loc.shape)
223229 return loc
224230 else :
225- return self .time_dist .sample (x )
231+ return self .time_dist .sample (x , truncate = trunc_time )
226232
227233 return (
228234 sample_pitch ,
@@ -310,7 +316,7 @@ def predict(self,
310316 pitch , time , vel ,
311317 fix_pitch = None , fix_time = None , fix_vel = None ,
312318 pitch_topk = None , index_pitch = None , allow_start = False , allow_end = False ,
313- sweep_time = False ):
319+ sweep_time = False , trunc_time = None ):
314320 """
315321 consume the most recent note and return a prediction for the next note.
316322
@@ -329,6 +335,7 @@ def predict(self,
329335 allow_end: if False, zero probability for sampling the end token
330336 sweep_time: if True, instead of sampling time, choose a diverse set of
331337 times and stack along the batch dimension
338+ trunc_time: if not None, truncate the time distribution to (lo, hi)
332339
333340 Returns: dict of
334341 'pitch': int. predicted MIDI number of next note.
@@ -358,7 +365,8 @@ def predict(self,
358365 modalities = list (zip (
359366 self .projections ,
360367 self .get_samplers (
361- pitch_topk , index_pitch , allow_start , allow_end , sweep_time ),
368+ pitch_topk , index_pitch , allow_start , allow_end ,
369+ sweep_time , trunc_time ),
362370 self .embeddings ,
363371 ))
364372
@@ -379,7 +387,10 @@ def predict(self,
379387 det_idx , cons_idx , uncons_idx = [], [], []
380388 for i ,(item , embed ) in enumerate (zip (fix , self .embeddings )):
381389 if item is None :
382- if (i == 1 and sweep_time ) or (i == 0 and pitch_topk ):
390+ if (
391+ i == 1 and (sweep_time or (trunc_time is not None )) or
392+ i == 0 and pitch_topk
393+ ):
383394 cons_idx .append (i )
384395 else :
385396 uncons_idx .append (i )
@@ -398,8 +409,9 @@ def predict(self,
398409 # TODO: allow constraints;
399410 # attempt to sort the strongest constraints first
400411 # constraints can be:
401- # discrete set, in which case evaluate probs and then sample categorical
402- # range, in which case truncate
412+ # discrete set, in which case evaluate probs and then sample categorical;
413+ # range, in which case truncate;
414+ # temperature?
403415
404416 perm_h_tgt = [h_tgt [i ] for i in perm ]
405417 while len (undet_idx ):
0 commit comments