Skip to content
This repository was archived by the owner on Nov 23, 2023. It is now read-only.

Commit 3cfb0fe

Browse files
add velocity and multi-factorization
1 parent 0074dc4 commit 3cfb0fe

4 files changed

Lines changed: 205 additions & 93 deletions

File tree

notepredictor/notepredictor/data.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def __getitem__(self, idx):
2828
item = torch.load(f)
2929
pitch = item['pitch'] # 1-d LongTensor of MIDI pitches 0-127
3030
time = item['time']
31+
velocity = item['velocity']
3132
assert len(pitch) == len(time)
3233

3334
# random transpose avoiding out of range notes
@@ -45,6 +46,11 @@ def __getitem__(self, idx):
4546
time + (torch.rand_like(time)-0.5)*2e-3
4647
).clamp(0., float('inf'))
4748

49+
velocity = (
50+
velocity +
51+
(torch.rand_like(time)-0.5) * ((velocity>0) & (velocity<127)).float()
52+
).clamp(0., 127.)
53+
4854
# pad with start, end tokens
4955
pad = max(1, self.batch_len-len(pitch))
5056
pitch = torch.cat((
@@ -55,15 +61,21 @@ def __getitem__(self, idx):
5561
time.new_zeros((1,)),
5662
time,
5763
time.new_zeros((pad,))))
64+
velocity = torch.cat((
65+
velocity.new_zeros((1,)),
66+
velocity,
67+
velocity.new_zeros((pad,))))
5868

5969
# random slice
6070
i = random.randint(0, len(pitch)-self.batch_len)
6171
pitch = pitch[i:i+self.batch_len]
6272
time = time[i:i+self.batch_len]
73+
velocity = velocity[i:i+self.batch_len]
6374

6475
# time = time.clamp(*self.clamp_time)
6576

6677
return {
6778
'pitch':pitch,
68-
'time':time
79+
'time':time,
80+
'velocity':velocity
6981
}

notepredictor/notepredictor/distributions.py

Lines changed: 38 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,28 @@
66
import torch.nn.functional as F
77

88
class CensoredMixtureLogistic(nn.Module):
9-
def __init__(self, n, res=1e-2, lo='-inf', hi='inf', sharp_bounds=(1e-4,2e3)):
9+
def __init__(self, n, res=1e-2, lo='-inf', hi='inf',
10+
sharp_bounds=(1e-4,2e3), init=None):
1011
super().__init__()
1112
self.n = n
1213
self.res = res
1314
self.sharp_bounds = sharp_bounds
14-
# self.base_bounds = tuple(reversed([math.exp(-s) for s in sharp_bounds]))
15-
# self.register_buffer('max_sharp', torch.tensor(float(max_sharp)))
1615
self.register_buffer('lo', torch.tensor(float(lo)))
1716
self.register_buffer('hi', torch.tensor(float(hi)))
1817
# TODO: init is not general-purpose
1918
#TODO
20-
self.bias = nn.Parameter(torch.cat((
21-
torch.zeros(n), torch.logspace(-3,1,n), torch.zeros(n)
22-
)))
19+
if init=='time':
20+
self.bias = nn.Parameter(torch.cat((
21+
torch.zeros(n), torch.logspace(-3,1,n), torch.zeros(n)
22+
)))
23+
elif init=='velocity':
24+
self.bias = nn.Parameter(torch.cat((
25+
torch.zeros(n), torch.linspace(1,126,n), torch.zeros(n)
26+
)))
27+
else:
28+
self.bias = nn.Parameter(torch.cat((
29+
torch.zeros(n), torch.randn(n), torch.zeros(n)
30+
)))
2331

2432
@property
2533
def n_params(self):
@@ -28,7 +36,7 @@ def n_params(self):
2836
def get_params(self, h):
2937
assert h.shape[-1] == self.n_params
3038
h = h+self.bias
31-
# get parameters fron unconstrained hidden state:
39+
# get parameters from unconstrained hidden state:
3240
logit_pi, loc, log_s = torch.chunk(h, 3, -1)
3341
# mixture coefficients
3442
log_pi = logit_pi - logit_pi.logsumexp(-1,keepdim=True)
@@ -37,48 +45,30 @@ def get_params(self, h):
3745
# sharpness
3846
s = F.softplus(log_s).clamp(*self.sharp_bounds)
3947
return log_pi, loc, s
40-
# exp of negative sharpness
41-
# b = logit_b.sigmoid().clamp(*self.base_bounds)
42-
# log_b = F.logsigmoid(logit_b).clamp(
43-
# math.log(self.base_bounds[0]), math.log(self.base_bounds[1]))
44-
# b = ((logit_b / (logit_b.abs() + 1))*.5+.5).clamp(*self.base_bounds)
45-
# log_b = b.log()
46-
# return log_pi, loc, b, log_b
48+
4749

4850
def forward(self, h, x):
4951
"""log prob of x under distribution parameterized by h"""
50-
# log_pi, loc, b, log_b = self.get_params(h)
5152
log_pi, loc, s = self.get_params(h)
5253

5354
d = self.res/2
54-
x = x.clamp(self.lo-d, self.hi+d)[...,None]
55+
x = x.clamp(self.lo, self.hi)[...,None]
5556
x_ = (x - loc) * s
56-
57-
# numerical crimes follow
58-
# q = b ** -x_
59-
q = x_.exp()
6057
sd = s*d
61-
# bdp, bdm = b**d, b**-d
62-
sdm, sdp = (-sd).exp(), sd.exp()
63-
# # censoring
64-
lo_cens = x <= self.lo
65-
hi_cens = x >= self.hi
66-
ones = torch.ones_like(q)
67-
zeros = torch.zeros_like(q)
68-
69-
diff_term = torch.where(
70-
lo_cens | hi_cens, ones, sdp - sdm).log()
71-
# sdm_term = torch.where(
72-
# hi_cens, ones, (q + sdm)).log()
73-
sdm_term = torch.where(hi_cens, zeros, x_ + F.softplus(-sd-x_))
74-
# sdp_term = torch.where(
75-
# lo_cens, ones, (q + sdp)).log()
76-
sdp_term = torch.where(lo_cens, zeros, x_ + F.softplus(sd-x_))
77-
x_or_sd = torch.where(hi_cens, sd, x_)
7858

79-
log_delta_cdf = (
80-
x_or_sd + diff_term - sdm_term - sdp_term
81-
)
59+
# # censoring
60+
lo_cens = x <= self.lo+d
61+
hi_cens = x >= self.hi-d
62+
ones = torch.ones_like(x_)
63+
zeros = torch.zeros_like(x_)
64+
65+
diff_term = torch.where(lo_cens | hi_cens,
66+
ones, sd.exp() - (-sd).exp()
67+
).log()
68+
minus_sp_term = torch.where(hi_cens, -sd, F.softplus(-sd-x_))
69+
plus_sp_term = torch.where(lo_cens, zeros, x_ + F.softplus(sd-x_))
70+
71+
log_delta_cdf = diff_term - minus_sp_term - plus_sp_term
8272

8373
# log prob
8474
r = {
@@ -87,23 +77,23 @@ def forward(self, h, x):
8777
# diagnostics
8878
with torch.no_grad():
8979
ent = D.Categorical(logits=log_pi).entropy()
90-
# s = -b.log()
9180
r |= {
92-
'min_sharpness': s.min(),
81+
# 'min_sharpness': s.min(),
9382
'max_sharpness': s.max(),
94-
'min_entropy': ent.min(),
95-
'max_entropy': ent.max(),
96-
'marginal_entropy': D.Categorical(
83+
'mean_sharpness': (s*log_pi.exp()).sum(-1).mean(),
84+
# 'min_entropy': ent.min(),
85+
# 'max_entropy': ent.max(),
86+
'mean_cmp_entropy': ent.mean(),
87+
'marginal_cmp_entropy': D.Categorical(
9788
log_pi.exp().mean(list(range(log_pi.ndim-1)))).entropy(),
98-
'min_loc': loc.min(),
99-
'max_loc': loc.max()
89+
# 'min_loc': loc.min(),
90+
# 'max_loc': loc.max()
10091
}
10192
return r
10293

10394
def cdf(self, h, x):
10495
log_pi, loc, s = self.get_params(h)
10596
x_ = (x[...,None] - loc) * s
106-
# cdfs = 1 / (1 + b ** x_)
10797
cdfs = x_.sigmoid()
10898
cdf = (cdfs * log_pi.softmax(-1)).sum(-1)
10999
return cdf
@@ -113,9 +103,7 @@ def sample(self, h, shape=1):
113103
Args:
114104
shape: additional sample shape to be prepended to dims
115105
"""
116-
# log_pi, loc, _, log_b = self.get_params(h)
117106
log_pi, loc, s = self.get_params(h)
118-
# scale = -1/log_b
119107
scale = 1/s
120108

121109
c = D.Categorical(logits=log_pi).sample((shape,))

0 commit comments

Comments (0)