Skip to content
This repository was archived by the owner on Nov 23, 2023. It is now read-only.

Commit 4fe3e1a

Browse files
velocity prediction and mini-transformer with random permutations
1 parent 5313d3e commit 4fe3e1a

1 file changed

Lines changed: 152 additions & 133 deletions

File tree

notepredictor/notepredictor/model.py

Lines changed: 152 additions & 133 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import math
22

3+
import numpy as np
4+
35
import torch
46
from torch import nn
57
import torch.nn.functional as F
@@ -19,10 +21,11 @@ def __init__(self, n, w0=1e-3, interval=1.08):
1921
super().__init__()
2022
self.n = n
2123
self.register_buffer('fs', interval**(-torch.arange(n)) / w0 * 2 * math.pi)
24+
self.proj = nn.Linear(n,n)
2225

2326
def forward(self, x):
2427
x = x[...,None] * self.fs
25-
return x.sin()
28+
return self.proj(x.sin())
2629

2730
class MixEmbedding(nn.Module):
2831
def __init__(self, n, domain=(0,1)):
@@ -42,49 +45,85 @@ def forward(self, x):
4245
Returns:
4346
Tensor[...,n]
4447
"""
45-
x = (x - self.domain[0])*(self.domain[1] - self.domain[0])
48+
x = (x - self.domain[0])/(self.domain[1] - self.domain[0])
4649
x = x[...,None]
4750
return self.hi * x + self.lo * (1-x)
4851

49-
class SelfGated(nn.Module):
50-
def __init__(self):
52+
# class SelfGated(nn.Module):
53+
# def __init__(self):
54+
# super().__init__()
55+
56+
# def forward(self, x):
57+
# a, b = x.chunk(2, -1)
58+
# return a * b.sigmoid()
59+
60+
# class SelfGatedMLP(nn.Module):
61+
# def __init__(self, input, hidden, output, layers, dropout=0):
62+
# super().__init__()
63+
# h = input
64+
# def get_dropout():
65+
# if dropout > 0:
66+
# return (nn.Dropout(dropout),)
67+
# else:
68+
# return tuple()
69+
# self.net = []
70+
# for _ in range(layers):
71+
# self.net.append(nn.Sequential(
72+
# *get_dropout(), nn.Linear(h, hidden*2), SelfGated()))
73+
# h = hidden
74+
# self.net.append(nn.Linear(hidden, output))
75+
# self.net = nn.Sequential(*self.net)
76+
77+
# with torch.no_grad():
78+
# self.net[-1].weight.mul_(1e-2)
79+
80+
# def forward(self, x):
81+
# return self.net(x)
82+
83+
class ModalityTransformer(nn.Module):
84+
"""Model joint distribution of modalities autoregressively with random permutations"""
85+
def __init__(self, input_size, hidden_size, heads=4, layers=1):
5186
super().__init__()
87+
self.net = nn.TransformerEncoder(
88+
nn.TransformerEncoderLayer(
89+
input_size, heads, hidden_size, norm_first=False
90+
), layers)
5291

53-
def forward(self, x):
54-
a, b = x.chunk(2, -1)
55-
return a * b.sigmoid()
56-
57-
class SelfGatedMLP(nn.Module):
58-
def __init__(self, input, hidden, output, layers, dropout=0):
59-
super().__init__()
60-
h = input
61-
def get_dropout():
62-
if dropout > 0:
63-
return (nn.Dropout(dropout),)
64-
else:
65-
return tuple()
66-
self.net = []
67-
for _ in range(layers):
68-
self.net.append(nn.Sequential(
69-
*get_dropout(), nn.Linear(h, hidden*2), SelfGated()))
70-
h = hidden
71-
self.net.append(nn.Linear(hidden, output))
72-
self.net = nn.Sequential(*self.net)
73-
74-
with torch.no_grad():
75-
self.net[-1].weight.mul_(1e-2)
92+
def forward(self, h, modes):
93+
"""
94+
Args:
95+
modes: each a Tensor[batch x time x input_size]
96+
"""
97+
x = [h]+modes
98+
batch_size = h.shape[0]*h.shape[1]
99+
# fold time into batch, stack modes
100+
x = torch.stack([
101+
item.reshape(batch_size,-1)
102+
for item in x
103+
],0)
104+
# now "time"(mode) x "batch"(+time) x channel
105+
106+
# generate a mask
107+
# upper triangular (i.e. diagonal and above is True, meaning masked)
108+
# except h position should attend to self
109+
n = len(modes)+1
110+
mask = x.new_ones((n,n), dtype=bool).triu()
111+
mask[0,0] = False
112+
113+
x = self.net(x, mask)
114+
return list(x.reshape(n, *h.shape).unbind(0))[1:]
76115

77-
def forward(self, x):
78-
return self.net(x)
79116

80117
class NotePredictor(nn.Module):
81118
# note: use named arguments only for benefit of training script
82119
def __init__(self,
83-
pitch_emb_size=128, time_emb_size=128, vel_emb_size=128,
84-
hidden_size=512, num_layers=1, kind='gru', dropout=0,
120+
emb_size=256,
121+
rnn_hidden=2048, rnn_layers=1, kind='gru',
122+
ar_hidden=2048, ar_layers=1, ar_heads=4,
123+
dropout=0.1,
85124
num_pitches=128,
86-
time_bounds=(0,10), time_components=16, time_res=1e-2,
87-
vel_components=8
125+
time_bounds=(0,10), time_components=32, time_res=1e-2,
126+
vel_components=16
88127
):
89128
"""
90129
"""
@@ -102,32 +141,34 @@ def __init__(self,
102141
vel_components, 1.0, lo=0, hi=127, init='velocity')
103142

104143
# embeddings for inputs
105-
self.pitch_emb = nn.Embedding(self.pitch_domain, pitch_emb_size)
106-
self.time_emb = SineEmbedding(time_emb_size)
107-
self.vel_emb = MixEmbedding(vel_emb_size, (0, 127))
108-
109-
self.pitch_missing = nn.Parameter(torch.randn(pitch_emb_size))
110-
self.time_missing = nn.Parameter(torch.randn(time_emb_size))
111-
self.vel_missing = nn.Parameter(torch.randn(vel_emb_size))
144+
self.pitch_emb = nn.Embedding(self.pitch_domain, emb_size)
145+
self.time_emb = SineEmbedding(emb_size)
146+
self.vel_emb = MixEmbedding(emb_size, (0, 127))
112147

113148
# RNN backbone
114149
self.rnn = GenericRNN(kind,
115-
pitch_emb_size+time_emb_size+vel_emb_size, hidden_size,
116-
num_layers=num_layers, batch_first=True, dropout=dropout)
150+
3*emb_size, rnn_hidden,
151+
num_layers=rnn_layers, batch_first=True, dropout=dropout)
117152

118153
# learnable initial RNN state
119154
self.initial_state = nn.ParameterList([
120155
# layer x batch x hidden
121-
nn.Parameter(torch.randn(num_layers,1,hidden_size)*hidden_size**-0.5)
156+
nn.Parameter(torch.randn(rnn_layers,1,rnn_hidden)*rnn_hidden**-0.5)
122157
for _ in range(2 if kind=='lstm' else 1)
123158
])
124159

125160
# projection from RNN state to distribution parameters
126-
self.param_proj = SelfGatedMLP(
127-
pitch_emb_size+time_emb_size+vel_emb_size+hidden_size,
128-
hidden_size//2,
129-
self.pitch_domain+self.time_dist.n_params+self.vel_dist.n_params,
130-
layers=2, dropout=dropout)
161+
self.h_proj = nn.Linear(rnn_hidden, emb_size)
162+
self.projections = nn.ModuleList([
163+
nn.Linear(emb_size, self.pitch_domain),
164+
nn.Linear(emb_size, self.time_dist.n_params, bias=False),
165+
nn.Linear(emb_size, self.vel_dist.n_params, bias=False)
166+
])
167+
for p in self.projections:
168+
with torch.no_grad():
169+
p.weight.mul_(1e-2)
170+
171+
self.xformer = ModalityTransformer(emb_size, ar_hidden, ar_heads, ar_layers)
131172

132173
# persistent RNN state for inference
133174
for n,t in zip(self.cell_state_names(), self.initial_state):
@@ -151,82 +192,43 @@ def forward(self, pitches, times, velocities, validation=False):
151192
"""
152193
batch_size, batch_len = pitches.shape
153194

154-
pitch_emb = self.pitch_emb(pitches) # batch, time, pitch_emb_size
155-
time_emb = self.time_emb(times) # batch, time, time_emb_size
156-
vel_emb = self.vel_emb(velocities) # batch, time, vel_emb_size
195+
pitch_emb = self.pitch_emb(pitches) # batch, time, emb_size
196+
time_emb = self.time_emb(times) # batch, time, emb_size
197+
vel_emb = self.vel_emb(velocities) # batch, time, emb_size
157198

158-
x = torch.cat((pitch_emb, time_emb, vel_emb), -1)
199+
embs = (pitch_emb, time_emb, vel_emb)
200+
201+
x = torch.cat(embs, -1)[:,:-1] # skip last time position
159202
## broadcast intial state to batch size
160203
initial_state = tuple(
161204
t.expand(self.rnn.num_layers, x.shape[0], -1).contiguous() # 1 x batch x hidden
162205
for t in self.initial_state)
163206
h, _ = self.rnn(x, initial_state) #batch, time, hidden_size
164207

165-
# IDEA: fit all factorizations at once.
166-
# add 'missing' value for time / pitch / velocity as model parameters
167-
# expand the batch to 6x wide with ~/~/~, T/~/~, ~/P/~, ~/~/V, T/P/~, ~/P/V, T/~/V inputs
168-
# the factorizations are:
169-
# _~~ -> T_~ -> TP_
170-
# _~~ -> T~_ -> T_V
171-
# ~_~ -> _P~ -> TP_
172-
# ~_~ -> ~P_ -> _PV
173-
# ~~_ -> _~V -> T_V
174-
# ~~_ -> ~_V -> _PV
175-
# i.e. the fully masked positions are counted 2x,
176-
# the single-masked positions are counted 2x,
177-
# and each double-masked position is counted once
178-
179-
masks = [
180-
[2, 0, 0, 2, 0, 1, 1], #pitch
181-
[2, 2, 0, 0, 1, 0, 1], #time
182-
[2, 0, 2, 0, 1, 1, 0] #velocity
183-
]
184-
185-
def mask_cat(missing, present, mask):
186-
missing = missing[None,None].expand(batch_size, batch_len-1, -1)
187-
return torch.cat([
188-
present if m==0 else missing for m in mask
189-
], 0)
190-
191-
pitch_features = mask_cat(self.pitch_missing, pitch_emb[:,1:], masks[0])
192-
time_features = mask_cat(self.time_missing, time_emb[:,1:], masks[1])
193-
vel_features = mask_cat(self.vel_missing, vel_emb[:,1:], masks[2])
194-
195-
features = torch.cat((
196-
pitch_features, time_features, vel_features, h[:,:-1].repeat(7,1,1)
197-
), -1) # cat along feature dim
198-
199-
dist_params = self.param_proj(features) # combine features with h
200-
201-
# split again into time/pitch/vel params
202-
dist_params = dist_params.split([
203-
self.pitch_domain, self.time_dist.n_params, self.vel_dist.n_params
204-
], -1)
205-
206-
# chunk into 7 and discard unmasked positions;
207-
# stack the masked positions along new first dim
208-
pitch_params, time_params, vel_params = (
209-
torch.stack([
210-
ch
211-
for m,ch in zip(mask, dp.chunk(7, 0)) if m>0
212-
], 0)
213-
for mask,dp in zip(masks, dist_params)
214-
)
215-
216-
#TODO: weighting
217-
# weights = np.log([[m for m in mask if m>0] for mask in masks]) # 3 x 4
208+
# include initial hidden state for predicting first note
209+
h = torch.cat((
210+
self.initial_state[0][-1][None].expand(batch_size, 1, -1),
211+
h), -2)
212+
213+
# fit all note factorizations at once.
214+
perm = torch.randperm(3)
215+
embs = [embs[i] for i in perm]
216+
mode_hs = self.xformer(self.h_proj(h), embs)
217+
mode_hs = [mode_hs[perm[i]] for i in perm]
218+
219+
pitch_params, time_params, vel_params = [
220+
proj(h) for proj,h in zip(self.projections, mode_hs)]
218221

219222
# get likelihoods
220223
pitch_logits = F.log_softmax(pitch_params, -1)
221-
# TODO: is gather working right with extra dim?
222-
pitch_targets = pitches[None,:,1:,None].expand(4, -1, -1, -1) #1, batch, time-1, 1
224+
pitch_targets = pitches[:,1:,None] #batch, time, 1
223225
pitch_log_probs = pitch_logits.gather(-1, pitch_targets)[...,0]
224226

225-
time_targets = times[:,1:]# batch, time-1
227+
time_targets = times# batch, time
226228
time_result = self.time_dist(time_params, time_targets)
227229
time_log_probs = time_result.pop('log_prob')
228230

229-
vel_targets = velocities[:,1:]# batch, time-1
231+
vel_targets = velocities # batch, time
230232
vel_result = self.vel_dist(vel_params, vel_targets)
231233
vel_log_probs = vel_result.pop('log_prob')
232234

@@ -269,41 +271,58 @@ def predict(self, pitch, time, vel, force=(None, None, None)):
269271
with torch.no_grad():
270272
pitch = torch.LongTensor([[pitch]]) # 1x1 (batch, time)
271273
time = torch.FloatTensor([[time]]) # 1x1 (batch, time)
272-
x = torch.cat((
273-
self.pitch_emb(pitch), # 1, 1, pitch_emb_size
274-
self.time_emb(time)# 1, 1, time_emb_size
275-
), -1)
274+
vel = torch.FloatTensor([[vel]]) # 1x1 (batch, time)
275+
276+
embs = [
277+
self.pitch_emb(pitch), # 1, 1, emb_size
278+
self.time_emb(time),# 1, 1, emb_size
279+
self.vel_emb(vel)# 1, 1, emb_size
280+
]
281+
x = torch.cat(embs, -1)
276282

277283
h, new_state = self.rnn(x, self.cell_state)
278284
for t,new_t in zip(self.cell_state, new_state):
279285
t[:] = new_t
280286

281-
pitch_params = self.pitch_proj(h)
287+
h = self.h_proj(h)
288+
289+
# TODO: permutations
290+
# TODO: optimize by removing unused positions
291+
# TODO: refactor with common distribution API
292+
pitch_h, = self.xformer(h, embs[:1])
293+
294+
pitch_params = self.projections[0](pitch_h)
282295
pred_pitch = D.Categorical(logits=pitch_params).sample()
283-
284-
time_params = self.time_proj(h*self.cond_proj(self.pitch_emb(pred_pitch)).sigmoid())
285-
# time_params = self.time_proj(torch.cat((
286-
# h, self.pitch_emb(pred_pitch)
287-
# ), -1)) # 1, 1, time_params
288-
# TODO: importance sampling?
289-
pred_time = self.time_dist.sample(time_params).squeeze(0)
290-
291-
### TODO: generalize, move into sample
292-
### DEBUG
293-
# pi only, fewer zeros:
294-
log_pi, loc, s = (
295-
t.squeeze() for t in self.time_dist.get_params(time_params))
296-
bias = 2#float('inf')
297-
log_pi = torch.where(loc <= self.time_dist.res, log_pi-bias, log_pi)
298-
idx = D.Categorical(logits=log_pi).sample()
299-
pred_time = loc[idx].clamp(0,10)
296+
297+
embs[0] = self.pitch_emb(pred_pitch)
298+
_, time_h = self.xformer(h, embs[:2])
299+
300+
time_params = self.projections[1](time_h)
301+
pred_time = self.time_dist.sample(time_params)
302+
303+
embs[1] = self.time_emb(pred_time)
304+
_, _, vel_h = self.xformer(h, embs)
305+
306+
vel_params = self.projections[2](vel_h)
307+
pred_vel = self.vel_dist.sample(vel_params)
308+
309+
# ### TODO: generalize, move into sample
310+
# ### DEBUG
311+
# # pi only, fewer zeros:
312+
# log_pi, loc, s = (
313+
# t.squeeze() for t in self.time_dist.get_params(time_params))
314+
# bias = 2#float('inf')
315+
# log_pi = torch.where(loc <= self.time_dist.res, log_pi-bias, log_pi)
316+
# idx = D.Categorical(logits=log_pi).sample()
317+
# pred_time = loc[idx].clamp(0,10)
300318

301319
return {
302320
'pitch': pred_pitch.item(),
303321
'time': pred_time.item(),
304-
'velocity': None,
322+
'velocity': pred_vel.item(),
305323
'pitch_params': pitch_params,
306-
'time_params': time_params
324+
'time_params': time_params,
325+
'vel_params': vel_params
307326
}
308327

309328
# TODO: start velocity

0 commit comments

Comments
 (0)