Skip to content
This repository was archived by the owner on Nov 23, 2023. It is now read-only.

Commit 7536660

Browse files
time truncation
1 parent 5c4f2bb commit 7536660

2 files changed

Lines changed: 56 additions & 13 deletions

File tree

notepredictor/notepredictor/distributions.py

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import math
2+
import numpy as np
23

34
import torch
45
from torch import nn
@@ -97,33 +98,63 @@ def forward(self, h, x):
9798
return r
9899

99100
def cdf(self, h, x):
101+
"""
102+
Args:
103+
h: Tensor[...,n_params]
104+
x: Tensor[...]
105+
`h` should broadcast with `x[...,None]`
106+
Returns:
107+
cdf: Tensor[...] (shape of `x` broadcasted with `h[...,0]`)
108+
"""
100109
log_pi, loc, s = self.get_params(h)
101-
x_ = (x[...,None] - loc) * s
102-
cdfs = x_.sigmoid()
110+
cdfs = self.cdf_components(loc, s, x)
103111
cdf = (cdfs * log_pi.softmax(-1)).sum(-1)
104112
return cdf
105113

106-
def sample(self, h, shape=None):
114+
def cdf_components(self, loc, s, x):
115+
x_ = (x[...,None] - loc) * s
116+
return x_.sigmoid()
117+
118+
def sample(self, h, truncate=None, shape=None):
107119
"""
108120
Args:
109-
shape: additional sample shape to be prepended to dims
121+
h: Tensor[...,n_params]
122+
shape: additional sample shape to be prepended to dims or None
123+
Returns:
124+
Tensor[*shape,...] (h without last dimension, prepended with `shape`)
110125
"""
111126
if shape is None:
112127
unwrap = True
113128
shape = 1
114129
else:
115130
unwrap = False
131+
132+
if truncate is None:
133+
truncate = (-np.inf, np.inf)
134+
truncate = torch.tensor(truncate)
135+
116136
log_pi, loc, s = self.get_params(h)
117137
scale = 1/s
118138

119-
c = D.Categorical(logits=log_pi).sample((shape,))
139+
# cdfs: [...,bound,component]
140+
cdfs = self.cdf_components(loc[...,None,:], s[...,None,:], truncate)
141+
# prob. mass of each component within bounds
142+
trunc_probs = cdfs[...,1,:] - cdfs[...,0,:] # [...,component]
143+
probs = log_pi.exp() * trunc_probs # reweighted mixture component probs
144+
145+
c = D.Categorical(probs).sample((shape,))
120146
# move sample dimension first
121147
loc = loc.movedim(-1, 0).gather(0, c)
122148
scale = scale.movedim(-1, 0).gather(0, c)
149+
upper = cdfs[...,1,:].movedim(-1, 0).gather(0, c)
150+
lower = cdfs[...,0,:].movedim(-1, 0).gather(0, c)
123151

124152
u = torch.rand(shape, *h.shape[:-1])
153+
# truncate
154+
u = u * (upper-lower) + lower
125155

126-
x = loc + scale * (u.log() - (1 - u).log())
156+
# x = loc + scale * (u.log() - (1 - u).log())
157+
x = loc - scale * (1/u - 1).log()
127158
x = x.clamp(self.lo, self.hi)
128159
return x[0] if unwrap else x
129160

notepredictor/notepredictor/model.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,8 @@ def cell_state(self):
200200

201201
def get_samplers(self,
202202
pitch_topk=None, index_pitch=None, allow_start=False, allow_end=False,
203-
sweep_time=False):
203+
sweep_time=False, trunc_time=None):
204+
204205
def sample_pitch(x):
205206
if not allow_start:
206207
x[...,self.start_token] = -np.inf
@@ -214,15 +215,20 @@ def sample_pitch(x):
214215
return D.Categorical(logits=x).sample()
215216

216217
def sample_time(x):
218+
# TODO: respect trunc_time when sweep_time is True
217219
if sweep_time:
220+
if trunc_time is not None:
221+
raise NotImplementedError("""
222+
trunc_time with sweep_time needs implementation
223+
""")
218224
assert x.shape[0]==1, "batch size should be 1 here"
219225
log_pi, loc, s = self.time_dist.get_params(x)
220226
idx = log_pi.squeeze().argsort()[:9]
221227
loc = loc.squeeze()[idx].sort().values[...,None] # multiple times in batch dim
222228
# print(loc.shape)
223229
return loc
224230
else:
225-
return self.time_dist.sample(x)
231+
return self.time_dist.sample(x, truncate=trunc_time)
226232

227233
return (
228234
sample_pitch,
@@ -310,7 +316,7 @@ def predict(self,
310316
pitch, time, vel,
311317
fix_pitch=None, fix_time=None, fix_vel=None,
312318
pitch_topk=None, index_pitch=None, allow_start=False, allow_end=False,
313-
sweep_time=False):
319+
sweep_time=False, trunc_time=None):
314320
"""
315321
consume the most recent note and return a prediction for the next note.
316322
@@ -329,6 +335,7 @@ def predict(self,
329335
allow_end: if False, zero probability for sampling the end token
330336
sweep_time: if True, instead of sampling time, choose a diverse set of
331337
times and stack along the batch dimension
338+
trunc_time: if not None, truncate the time distribution to (lo, hi)
332339
333340
Returns: dict of
334341
'pitch': int. predicted MIDI number of next note.
@@ -358,7 +365,8 @@ def predict(self,
358365
modalities = list(zip(
359366
self.projections,
360367
self.get_samplers(
361-
pitch_topk, index_pitch, allow_start, allow_end, sweep_time),
368+
pitch_topk, index_pitch, allow_start, allow_end,
369+
sweep_time, trunc_time),
362370
self.embeddings,
363371
))
364372

@@ -379,7 +387,10 @@ def predict(self,
379387
det_idx, cons_idx, uncons_idx = [], [], []
380388
for i,(item, embed) in enumerate(zip(fix, self.embeddings)):
381389
if item is None:
382-
if (i==1 and sweep_time) or (i==0 and pitch_topk):
390+
if (
391+
i==1 and (sweep_time or (trunc_time is not None)) or
392+
i==0 and pitch_topk
393+
):
383394
cons_idx.append(i)
384395
else:
385396
uncons_idx.append(i)
@@ -398,8 +409,9 @@ def predict(self,
398409
# TODO: allow constraints;
399410
# attempt to sort the strongest constraints first
400411
# constraints can be:
401-
# discrete set, in which case evaluate probs and then sample categorical
402-
# range, in which case truncate
412+
# discrete set, in which case evaluate probs and then sample categorical;
413+
# range, in which case truncate;
414+
# temperature?
403415

404416
perm_h_tgt = [h_tgt[i] for i in perm]
405417
while len(undet_idx):

0 commit comments

Comments
 (0)