Skip to content
This repository was archived by the owner on Nov 23, 2023. It is now read-only.

Commit 722776f

Browse files
linn display
1 parent 7367944 commit 722776f

2 files changed

Lines changed: 70 additions & 16 deletions

File tree

notepredictor/notepredictor/distributions.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,11 +103,16 @@ def cdf(self, h, x):
103103
cdf = (cdfs * log_pi.softmax(-1)).sum(-1)
104104
return cdf
105105

106-
def sample(self, h, shape=1):
106+
def sample(self, h, shape=None):
107107
"""
108108
Args:
109109
shape: additional sample shape to be prepended to dims
110110
"""
111+
if shape is None:
112+
unwrap = True
113+
shape = 1
114+
else:
115+
unwrap = False
111116
log_pi, loc, s = self.get_params(h)
112117
scale = 1/s
113118

@@ -119,7 +124,8 @@ def sample(self, h, shape=1):
119124
u = torch.rand(shape, *h.shape[:-1])
120125

121126
x = loc + scale * (u.log() - (1 - u).log())
122-
return x.clamp(self.lo, self.hi)
127+
x = x.clamp(self.lo, self.hi)
128+
return x[0] if unwrap else x
123129

124130

125131
class CensoredMixturePointyBoi(nn.Module):

notepredictor/notepredictor/model.py

Lines changed: 62 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,13 @@ def forward(self, ctx, h_ctx, h_tgt):
9999
h_tgt: list of Tensor[batch x time x input_size], length note_dim
100100
these are projections of the RNN state
101101
"""
102-
h_tgt = list(h_tgt)
103-
ctx = list(ctx)
102+
# h_tgt = list(h_tgt)
103+
# ctx = list(ctx)
104+
105+
# explicitly broadcast
106+
h_ctx, *ctx = torch.broadcast_tensors(h_ctx, *ctx)
107+
h_ctx, *h_tgt = torch.broadcast_tensors(h_ctx, *h_tgt)
108+
104109
# h_tgt is 'target' w.r.t TransformerDecoder
105110
# h_ctx and context are 'memory'
106111
batch_size = h_ctx.shape[0]*h_ctx.shape[1]
@@ -193,20 +198,35 @@ def cell_state_names(self):
193198
def cell_state(self):
194199
return tuple(getattr(self, n) for n in self.cell_state_names())
195200

196-
def get_samplers(self, index_pitch=None, allow_start=False, allow_end=False):
201+
def get_samplers(self,
202+
pitch_topk=None, index_pitch=None, allow_start=False, allow_end=False,
203+
sweep_time=False):
197204
def sample_pitch(x):
198205
if not allow_start:
199206
x[...,self.start_token] = -np.inf
200207
if not allow_end:
201208
x[...,self.end_token] = -np.inf
202209
if index_pitch is not None:
203210
return x.argsort(-1, True)[...,index_pitch]
211+
elif pitch_topk is not None:
212+
return x.argsort(-1, True)[...,:pitch_topk].transpose(0,-1)
204213
else:
205214
return D.Categorical(logits=x).sample()
206215

216+
def sample_time(x):
217+
if sweep_time:
218+
assert x.shape[0]==1, "batch size should be 1 here"
219+
log_pi, loc, s = self.time_dist.get_params(x)
220+
idx = log_pi.squeeze().argsort()[:9]
221+
loc = loc.squeeze()[idx].sort().values[...,None] # multiple times in batch dim
222+
# print(loc.shape)
223+
return loc
224+
else:
225+
return self.time_dist.sample(x)
226+
207227
return (
208228
sample_pitch,
209-
lambda x: self.time_dist.sample(x),
229+
sample_time,
210230
lambda x: self.vel_dist.sample(x),
211231
)
212232

@@ -289,9 +309,10 @@ def forward(self, pitches, times, velocities, validation=False):
289309
def predict(self,
290310
pitch, time, vel,
291311
fix_pitch=None, fix_time=None, fix_vel=None,
292-
index_pitch=None, allow_start=False, allow_end=False):
312+
pitch_topk=None, index_pitch=None, allow_start=False, allow_end=False,
313+
sweep_time=False):
293314
"""
294-
supply the most recent note and return a prediction for the next note.
315+
consume the most recent note and return a prediction for the next note.
295316
296317
various constraints can be enforced on the next note.
297318
@@ -304,6 +325,8 @@ def predict(self,
304325
most likely pitch instead of sampling.
305326
allow_start: if False, zero probability for sampling the start token
306327
allow_end: if False, zero probability for sampling the end token
328+
sweep_time: if True, instead of sampling time, choose a diverse set of
329+
times and stack along the batch dimension
307330
308331
Returns: dict of
309332
'pitch': int. predicted MIDI number of next note.
@@ -332,7 +355,8 @@ def predict(self,
332355

333356
modalities = list(zip(
334357
self.projections,
335-
self.get_samplers(index_pitch, allow_start, allow_end),
358+
self.get_samplers(
359+
pitch_topk, index_pitch, allow_start, allow_end, sweep_time),
336360
self.embeddings,
337361
))
338362

@@ -348,15 +372,21 @@ def predict(self,
348372

349373
# permute h_tgt, embs, modalities
350374
# if any modalities are determined, embed them
351-
det_idx, undet_idx = [], []
375+
# sort constrained modalities before unconstrained
376+
# TODO: option to skip modalities
377+
det_idx, cons_idx, uncons_idx = [], [], []
352378
for i,(item, embed) in enumerate(zip(fix, self.embeddings)):
353379
if item is None:
354-
undet_idx.append(i)
380+
if (i==1 and sweep_time) or (i==0 and pitch_topk):
381+
cons_idx.append(i)
382+
else:
383+
uncons_idx.append(i)
355384
else:
356385
det_idx.append(i)
357386
context.append(embed(item))
358-
predicted.append(item.item())
387+
predicted.append(item)
359388
params.append(None)
389+
undet_idx = cons_idx + uncons_idx
360390
perm = det_idx + undet_idx # permutation from the canonical order
361391
iperm = np.argsort(perm) # inverse permutation back to canonical order
362392

@@ -378,16 +408,34 @@ def predict(self,
378408
hidden = self.xformer(context, h_ctx, perm_h_tgt[:j+1])[j]
379409
params.append(project(hidden))
380410
pred = sample(params[-1])
381-
predicted.append(pred.item())
411+
predicted.append(pred)
382412
# prepare for next iteration
383413
if len(undet_idx):
384414
context.append(embed(pred))
385415
det_idx.append(i)
386416

417+
418+
pred_pitch = predicted[iperm[0]]
419+
pred_time = predicted[iperm[1]]
420+
pred_vel = predicted[iperm[2]]
421+
422+
print(pred_time.shape)
423+
print(pred_pitch.shape)
424+
print(pred_vel.shape)
425+
426+
if sweep_time or pitch_topk:
427+
pred_pitch = [x.item() for x in pred_pitch]
428+
pred_time = [x.item() for x in pred_time]
429+
pred_vel = [x.item() for x in pred_vel]
430+
print(pred_time, pred_pitch, pred_vel)
431+
else:
432+
pred_pitch = pred_pitch.item()
433+
pred_time = pred_time.item()
434+
pred_vel = pred_vel.item()
387435
return {
388-
'pitch': predicted[iperm[0]],
389-
'time': predicted[iperm[1]],
390-
'velocity': predicted[iperm[2]],
436+
'pitch': pred_pitch,
437+
'time': pred_time,
438+
'velocity': pred_vel,
391439
'pitch_params': params[iperm[0]],
392440
'time_params': params[iperm[1]],
393441
'vel_params': params[iperm[2]]

0 commit comments

Comments (0)