
Commit 6776c2f
predict delta-time: training
1 parent 5ebc6a6 commit 6776c2f

3 files changed: 233 additions & 55 deletions


Lines changed: 99 additions & 0 deletions
@@ -0,0 +1,99 @@
+import math
+
+import torch
+from torch import nn
+import torch.distributions as D
+import torch.nn.functional as F
+
+class CensoredMixturePointyBoi(nn.Module):
+    def __init__(self, n, res=1e-2, lo='-inf', hi='inf', max_sharp=1e3):
+        super().__init__()
+        self.n = n
+        self.res = res
+        self.register_buffer('max_sharp', torch.tensor(float(max_sharp)))
+        self.register_buffer('lo', torch.tensor(float(lo)))
+        self.register_buffer('hi', torch.tensor(float(hi)))
+        self.bias = nn.Parameter(torch.cat((
+            torch.zeros(n), torch.linspace(0,1,n), -torch.ones(n)
+        )))
+
+    @property
+    def n_params(self):
+        return self.n*3
+
+    def get_params(self, h):
+        assert h.shape[-1] == self.n_params
+        h = h + self.bias
+        # get parameters from unconstrained hidden state:
+        logit_pi, loc, log_s = torch.chunk(h, 3, -1)
+        # mixture coefficients (normalize over the component dimension)
+        log_pi = logit_pi - logit_pi.logsumexp(-1, keepdim=True)
+        # sharpness
+        # s = log_s.exp()
+        s = torch.min(F.softplus(log_s), self.max_sharp)
+        return log_pi, loc, s
+
+    def forward(self, h, x):
+        """log prob of x under distribution parameterized by h"""
+        log_pi, loc, s = self.get_params(h)
+
+        x = x.clamp(self.lo, self.hi)[...,None]
+        xp, xm = x+self.res/2, x-self.res/2
+
+        # numerical crimes follow
+
+        # truncation: censored points absorb all tail mass beyond the boundary
+        lo_cens = x <= self.lo
+        xm_ = torch.where(lo_cens, -h.new_ones([]), (xm-loc)*s)
+        axm_ = torch.where(lo_cens, h.new_zeros([]), xm_.abs())
+        hi_cens = x >= self.hi
+        xp_ = torch.where(hi_cens, h.new_ones([]), (xp-loc)*s)
+        axp_ = torch.where(hi_cens, h.new_zeros([]), xp_.abs())
+
+        # log(CDF(xp_) - CDF(xm_)) where CDF(z) = z/(1+|z|)/2 + 1/2
+        log_delta_cdf = (
+            (xp_ - xm_ + xp_*axm_ - axp_*xm_).log()
+            # (2*self.res + xp_*axm_ - axp_*xm_).log()
+            - (axp_ + axm_ + axp_*axm_).log1p()
+            - math.log(2))
+
+        # log prob
+        r = {
+            'log_prob': (log_pi + log_delta_cdf).logsumexp(-1)
+        }
+        with torch.no_grad():
+            r |= {
+                'max_sharpness': s.max(),
+                'min_sharpness': s.min(),
+                'min_entropy': D.Categorical(logits=log_pi).entropy().min(),
+                'min_loc': loc.min(),
+                'max_loc': loc.max()
+            }
+        return r
+
+    def cdf(self, h, x):
+        log_pi, loc, s = self.get_params(h)
+        x_ = (x[...,None] - loc) * s
+        cdfs = x_ / (1+x_.abs()) * 0.5 + 0.5
+        cdf = (cdfs * log_pi.softmax(-1)).sum(-1)
+        return cdf
+
+    def sample(self, h, shape=1):
+        """
+        Args:
+            shape: additional sample shape to be prepended to dims
+        """
+        # if shape is None: shape = []
+        log_pi, loc, s = self.get_params(h)
+        c = D.Categorical(logits=log_pi).sample((shape,))
+        # move component dimension first and select the sampled components
+        loc = loc.movedim(-1, 0).gather(0, c)
+        s = s.movedim(-1, 0).gather(0, c)
+
+        # inverse-CDF sampling of the base kernel
+        u = torch.rand(shape, *h.shape[:-1])*2-1
+        x_ = u / (1 - u.abs())
+        x = x_ / s + loc
+
+        return x.clamp(self.lo, self.hi)
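In plain terms, each mixture component is a heavy-tailed "pointy" kernel with CDF F(z) = z/(2(1+|z|)) + 1/2, so forward() scores a target by the probability mass in a bin of width res around it, with everything beyond lo or hi folded into the censored endpoint bins. A minimal smoke test of the class as defined above; the shapes and variable names here are illustrative, not part of the commit:

    import torch

    dist = CensoredMixturePointyBoi(n=5, res=1e-2, lo=0, hi=10)
    h = torch.randn(8, dist.n_params)     # unconstrained params, e.g. from an nn.Linear
    x = torch.rand(8) * 10                # targets in [lo, hi]

    r = dist(h, x)                        # dict: 'log_prob' plus detached diagnostics
    print(r['log_prob'].shape)            # torch.Size([8])
    print(dist.cdf(h, x).shape)           # torch.Size([8])
    print(dist.sample(h, shape=4).shape)  # torch.Size([4, 8]), clamped to [lo, hi]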

notepredictor/notepredictor/model.py

Lines changed: 79 additions & 26 deletions
@@ -1,37 +1,65 @@
+import math
+
 import torch
 from torch import nn
 import torch.nn.functional as F
 
 from .rnn import GenericRNN
+from .distributions import CensoredMixturePointyBoi
+
+class SineEmbedding(nn.Module):
+    def __init__(self, n, f0=1e-3, interval=2):
+        super().__init__()
+        self.n = n
+        self.register_buffer('fs', f0 * interval**torch.arange(n) * 2 * math.pi)
+
+    def forward(self, x):
+        x = x[...,None] * self.fs
+        return x.sin()
 
-class PitchPredictor(nn.Module):
+class NotePredictor(nn.Module):
     # note: use named arguments only for benefit of training script
-    def __init__(self, emb_size=128, hidden_size=512, domain_size=128,
-            num_layers=1, kind='gru', dropout=0):
+    def __init__(self,
+            pitch_emb_size=128, time_emb_size=16, hidden_size=512,
+            num_layers=1, kind='gru', dropout=0,
+            num_pitches=128,
+            time_components=5, time_res=1e-2,
+            ):
         """
         """
         super().__init__()
 
-        self.start_token = domain_size-2
-        self.end_token = domain_size-1
+        self.start_token = num_pitches
+        self.end_token = num_pitches+1
 
-        self.emb = nn.Embedding(domain_size, emb_size)
-        self.proj = nn.Linear(hidden_size, domain_size)
-        #### DEBUG
-        with torch.no_grad():
-            self.proj.weight.mul_(1e-2)
+        self.pitch_domain = num_pitches+2
+
+        # TODO: upper truncation?
+        self.time_dist = CensoredMixturePointyBoi(time_components, time_res, 0, 10)
 
-        self.rnn = GenericRNN(kind, emb_size, hidden_size,
+        # embeddings for inputs
+        self.pitch_emb = nn.Embedding(self.pitch_domain, pitch_emb_size)
+        self.time_emb = SineEmbedding(time_emb_size)
+
+        # RNN backbone
+        self.rnn = GenericRNN(kind, pitch_emb_size+time_emb_size, hidden_size,
             num_layers=num_layers, batch_first=True, dropout=dropout)
-
-        # learnable initial state
+
+        # learnable initial RNN state
         self.initial_state = nn.ParameterList([
             # layer x batch x hidden
             nn.Parameter(torch.randn(num_layers,1,hidden_size)*hidden_size**-0.5)
             for _ in range(2 if kind=='lstm' else 1)
         ])
 
-        # persistent state for inference
+        # projection from RNN state to distribution parameters
+        self.time_proj = nn.Linear(hidden_size, self.time_dist.n_params, bias=False)
+        self.pitch_proj = nn.Linear(hidden_size + time_emb_size, self.pitch_domain)
+        with torch.no_grad():
+            self.time_proj.weight.mul_(1e-2)
+            self.pitch_proj.weight.mul_(1e-2)
+
+        # persistent RNN state for inference
         for n,t in zip(self.cell_state_names(), self.initial_state):
             self.register_buffer(n, t.clone())
 
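The new SineEmbedding featurizes a scalar delta-time as n sinusoids at geometrically spaced angular frequencies 2π·f0·interval**k for k = 0…n-1, giving the RNN a multi-scale view of time. A quick shape check under the definitions above (sizes are illustrative):

    import torch

    emb = SineEmbedding(16)        # n=16 features; defaults f0=1e-3, interval=2
    t = torch.rand(8, 100) * 10    # (batch, time) delta-times in seconds
    print(emb(t).shape)            # torch.Size([8, 100, 16])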
@@ -42,36 +70,61 @@ def cell_state_names(self):
     def cell_state(self):
         return tuple(getattr(self, n) for n in self.cell_state_names())
 
-    def forward(self, notes):
+    def forward(self, pitches, times):
         """
         Args:
-            notes: LongTensor[batch, time]
+            pitches: LongTensor[batch, time]
+            times: FloatTensor[batch, time]
         """
-        x = self.emb(notes) # batch, time, emb_size
+
+        time_emb = self.time_emb(times) # batch, time, time_emb_size
+        pitch_emb = self.pitch_emb(pitches) # batch, time, pitch_emb_size
+
+        x = torch.cat((pitch_emb, time_emb), -1)
         ## broadcast initial state to batch size
         initial_state = tuple(
             t.expand(self.rnn.num_layers, x.shape[0], -1).contiguous() # layer x batch x hidden
             for t in self.initial_state)
         h, _ = self.rnn(x, initial_state) # batch, time, hidden_size
 
-        logits = self.proj(h[:,:-1]) # batch, time-1, 128
-        logits = F.log_softmax(logits, -1) # logits = logits - logits.logsumexp(-1, keepdim=True)
-        targets = notes[:,1:,None] # batch, time-1, 1
-        return {
-            'log_probs': logits.gather(-1, targets)[...,0],
-            'logits': logits
+        # RNN hidden state -> time prediction
+        time_params = self.time_proj(h[:,:-1]) # batch, time-1, time_params
+        time_targets = times[:,1:] # batch, time-1
+        time_result = self.time_dist(time_params, time_targets)
+        time_log_probs = time_result.pop('log_prob')
+
+        # RNN hidden state, time -> pitch prediction
+        # pitch_params = h[...,:self.pitch_domain] + self.pitch_bias # CI
+        pitch_params = self.pitch_proj(torch.cat((h[:,:-1], time_emb[:,1:]), -1))
+        pitch_logits = F.log_softmax(pitch_params, -1)
+        pitch_targets = pitches[:,1:,None] # batch, time-1, 1
+        pitch_log_probs = pitch_logits.gather(-1, pitch_targets)[...,0]
+
+        r = {
+            'pitch_log_probs': pitch_log_probs,
+            'time_log_probs': time_log_probs,
+            **time_result
         }
+        with torch.no_grad():
+            # diagnostic: probability mass within 30ms of the target delta-time
+            r['time_acc_30ms'] = (
+                self.time_dist.cdf(time_params, time_targets + 0.03)
+                - torch.where(time_targets - 0.03 >= 0,
+                    self.time_dist.cdf(time_params, time_targets - 0.03),
+                    time_targets.new_zeros([]))
+            )
+        return r
 
-    def predict(self, note, sample=True):
+    # TODO: time
+    def predict(self, note, time, sample=True):
         """
         Args:
             note: int
             sample: bool
         Returns:
-            int if `sample` else Tensor[domain_size]
+            int if `sample` else Tensor[num_pitches+2]
         """
         note = torch.LongTensor([[note]]) # 1x1 (batch, time)
-        x = self.emb(note) # 1, 1, emb_size
+        x = self.pitch_emb(note) # 1, 1, emb_size
 
         h, new_state = self.rnn(x, self.cell_state)
         for t,new_t in zip(self.cell_state, new_state):
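Given the commit message ("predict delta-time: training"), a plausible training step combines the two log-likelihoods that forward() returns into a single negative log-likelihood. A sketch, where model, opt, pitches, and times are assumed names not shown in this commit:

    r = model(pitches, times)      # NotePredictor.forward
    loss = -(r['pitch_log_probs'] + r['time_log_probs']).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    # entries like r['time_acc_30ms'] are computed under no_grad, for logging only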
