@@ -297,42 +297,91 @@ def predict(self, pitch, time, vel, force=(None, None, None)):
297297
298298 h = self .h_proj (h ).chunk (self .note_dim + 1 , - 1 )
299299
300- # TODO: permutations
301- # TODO: refactor with common distribution API
302- pitch_h , = self .xformer (h [:2 ], [])
303-
304- pitch_params = self .projections [0 ](pitch_h )
305- pred_pitch = D .Categorical (logits = pitch_params ).sample ()
306-
307- embs [0 ] = self .pitch_emb (pred_pitch )
308- _ , time_h = self .xformer (h [:3 ], embs [:1 ])
309-
310- time_params = self .projections [1 ](time_h )
311- pred_time = self .time_dist .sample (time_params )
312- ### TODO: generalize, move into sample
313- # pi only, fewer zeros:
314- # log_pi, loc, s = (
315- # t for t in self.time_dist.get_params(time_params))
316- # bias = float('inf')
317- # log_pi = torch.where(loc <= self.time_dist.res, log_pi-bias, log_pi)
318- # idx = D.Categorical(logits=log_pi).sample().item()
319- # pred_time = loc[...,idx].clamp(0,10)
320- ###
321-
322- embs [1 ] = self .time_emb (pred_time )
323- _ , _ , vel_h = self .xformer (h , embs [:2 ])
300+ modalities = [
301+ (
302+ self .projections [0 ],
303+ lambda x : D .Categorical (logits = x ).sample (),
304+ self .pitch_emb
305+ ),
306+ (
307+ self .projections [1 ],
308+ lambda x : self .time_dist .sample (x ),
309+ self .time_emb
310+ ),
311+ (
312+ self .projections [2 ],
313+ lambda x : self .vel_dist .sample (x ),
314+ self .vel_emb
315+ )
316+ ]
324317
325- vel_params = self .projections [2 ](vel_h )
326- pred_vel = self .vel_dist .sample (vel_params )
318+ # TODO: refactor for this:
319+ # modalities = list(zip(
320+ # self.projections,
321+ # self.distributions,
322+ # self.embeddings,
323+ # ))
324+
325+ # TODO: permute h[1:], embs, modalities
326+
327+ condition = []
328+ predicted = []
329+ params = []
330+ for i , (project , sample , embed ) in enumerate (modalities ):
331+ hidden = self .xformer (h [:i + 2 ], condition )[i ]
332+ params .append (project (hidden ))
333+ predicted .append (sample (params [- 1 ]))
334+ if i < len (modalities )- 1 :
335+ condition .append (embed (predicted [- 1 ]))
336+
337+ # TODO: unpermute
327338
328339 return {
329- 'pitch' : pred_pitch .item (),
330- 'time' : pred_time .item (),
331- 'velocity' : pred_vel .item (),
332- 'pitch_params' : pitch_params ,
333- 'time_params' : time_params ,
334- 'vel_params' : vel_params
340+ 'pitch' : predicted [ 0 ] .item (),
341+ 'time' : predicted [ 1 ] .item (),
342+ 'velocity' : predicted [ 2 ] .item (),
343+ 'pitch_params' : params [ 0 ] ,
344+ 'time_params' : params [ 1 ] ,
345+ 'vel_params' : params [ 2 ]
335346 }
347+
348+
349+ # TODO: permutations
350+ # TODO: refactor with common distribution API
351+ # pitch_h, = self.xformer(h[:2], [])
352+
353+ # pitch_params = self.projections[0](pitch_h)
354+ # pred_pitch = D.Categorical(logits=pitch_params).sample()
355+
356+ # embs[0] = self.pitch_emb(pred_pitch)
357+ # _, time_h = self.xformer(h[:3], embs[:1])
358+
359+ # time_params = self.projections[1](time_h)
360+ # pred_time = self.time_dist.sample(time_params)
361+ # ### TODO: generalize, move into sample
362+ # # pi only, fewer zeros:
363+ # # log_pi, loc, s = (
364+ # # t for t in self.time_dist.get_params(time_params))
365+ # # bias = float('inf')
366+ # # log_pi = torch.where(loc <= self.time_dist.res, log_pi-bias, log_pi)
367+ # # idx = D.Categorical(logits=log_pi).sample().item()
368+ # # pred_time = loc[...,idx].clamp(0,10)
369+ # ###
370+
371+ # embs[1] = self.time_emb(pred_time)
372+ # _, _, vel_h = self.xformer(h, embs[:2])
373+
374+ # vel_params = self.projections[2](vel_h)
375+ # pred_vel = self.vel_dist.sample(vel_params)
376+
377+ # return {
378+ # 'pitch': pred_pitch.item(),
379+ # 'time': pred_time.item(),
380+ # 'velocity': pred_vel.item(),
381+ # 'pitch_params': pitch_params,
382+ # 'time_params': time_params,
383+ # 'vel_params': vel_params
384+ # }
336385
337386 # TODO: start velocity
338387 def reset (self , start = True ):
0 commit comments