Skip to content
This repository was archived by the owner on Nov 23, 2023. It is now read-only.

Commit bf4a59a

Browse files
multifactorization inference WIP
1 parent 0fd5cfb commit bf4a59a

1 file changed

Lines changed: 81 additions & 32 deletions

File tree

notepredictor/notepredictor/model.py

Lines changed: 81 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -297,42 +297,91 @@ def predict(self, pitch, time, vel, force=(None, None, None)):
297297

298298
h = self.h_proj(h).chunk(self.note_dim+1, -1)
299299

300-
# TODO: permutations
301-
# TODO: refactor with common distribution API
302-
pitch_h, = self.xformer(h[:2], [])
303-
304-
pitch_params = self.projections[0](pitch_h)
305-
pred_pitch = D.Categorical(logits=pitch_params).sample()
306-
307-
embs[0] = self.pitch_emb(pred_pitch)
308-
_, time_h = self.xformer(h[:3], embs[:1])
309-
310-
time_params = self.projections[1](time_h)
311-
pred_time = self.time_dist.sample(time_params)
312-
### TODO: generalize, move into sample
313-
# pi only, fewer zeros:
314-
# log_pi, loc, s = (
315-
# t for t in self.time_dist.get_params(time_params))
316-
# bias = float('inf')
317-
# log_pi = torch.where(loc <= self.time_dist.res, log_pi-bias, log_pi)
318-
# idx = D.Categorical(logits=log_pi).sample().item()
319-
# pred_time = loc[...,idx].clamp(0,10)
320-
###
321-
322-
embs[1] = self.time_emb(pred_time)
323-
_, _, vel_h = self.xformer(h, embs[:2])
300+
modalities = [
301+
(
302+
self.projections[0],
303+
lambda x: D.Categorical(logits=x).sample(),
304+
self.pitch_emb
305+
),
306+
(
307+
self.projections[1],
308+
lambda x: self.time_dist.sample(x),
309+
self.time_emb
310+
),
311+
(
312+
self.projections[2],
313+
lambda x: self.vel_dist.sample(x),
314+
self.vel_emb
315+
)
316+
]
324317

325-
vel_params = self.projections[2](vel_h)
326-
pred_vel = self.vel_dist.sample(vel_params)
318+
# TODO: refactor for this:
319+
# modalities = list(zip(
320+
# self.projections,
321+
# self.distributions,
322+
# self.embeddings,
323+
# ))
324+
325+
# TODO: permute h[1:], embs, modalities
326+
327+
condition = []
328+
predicted = []
329+
params = []
330+
for i, (project, sample, embed) in enumerate(modalities):
331+
hidden = self.xformer(h[:i+2], condition)[i]
332+
params.append(project(hidden))
333+
predicted.append(sample(params[-1]))
334+
if i<len(modalities)-1:
335+
condition.append(embed(predicted[-1]))
336+
337+
# TODO: unpermute
327338

328339
return {
329-
'pitch': pred_pitch.item(),
330-
'time': pred_time.item(),
331-
'velocity': pred_vel.item(),
332-
'pitch_params': pitch_params,
333-
'time_params': time_params,
334-
'vel_params': vel_params
340+
'pitch': predicted[0].item(),
341+
'time': predicted[1].item(),
342+
'velocity': predicted[2].item(),
343+
'pitch_params': params[0],
344+
'time_params': params[1],
345+
'vel_params': params[2]
335346
}
347+
348+
349+
# TODO: permutations
350+
# TODO: refactor with common distribution API
351+
# pitch_h, = self.xformer(h[:2], [])
352+
353+
# pitch_params = self.projections[0](pitch_h)
354+
# pred_pitch = D.Categorical(logits=pitch_params).sample()
355+
356+
# embs[0] = self.pitch_emb(pred_pitch)
357+
# _, time_h = self.xformer(h[:3], embs[:1])
358+
359+
# time_params = self.projections[1](time_h)
360+
# pred_time = self.time_dist.sample(time_params)
361+
# ### TODO: generalize, move into sample
362+
# # pi only, fewer zeros:
363+
# # log_pi, loc, s = (
364+
# # t for t in self.time_dist.get_params(time_params))
365+
# # bias = float('inf')
366+
# # log_pi = torch.where(loc <= self.time_dist.res, log_pi-bias, log_pi)
367+
# # idx = D.Categorical(logits=log_pi).sample().item()
368+
# # pred_time = loc[...,idx].clamp(0,10)
369+
# ###
370+
371+
# embs[1] = self.time_emb(pred_time)
372+
# _, _, vel_h = self.xformer(h, embs[:2])
373+
374+
# vel_params = self.projections[2](vel_h)
375+
# pred_vel = self.vel_dist.sample(vel_params)
376+
377+
# return {
378+
# 'pitch': pred_pitch.item(),
379+
# 'time': pred_time.item(),
380+
# 'velocity': pred_vel.item(),
381+
# 'pitch_params': pitch_params,
382+
# 'time_params': time_params,
383+
# 'vel_params': vel_params
384+
# }
336385

337386
# TODO: start velocity
338387
def reset(self, start=True):

0 commit comments

Comments
 (0)