Skip to content
This repository was archived by the owner on Nov 23, 2023. It is now read-only.

Commit aed821f

Browse files
fixes
1 parent 4fe3e1a commit aed821f

3 files changed

Lines changed: 71 additions & 57 deletions

File tree

examples/notepredictor/midi-duet.scd

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ s.boot;
1414

1515
(
1616
SynthDef(\pluck, {
17+
var vel = \vel.kr;
1718
var signal = Saw.ar(\freq.kr, 0.2) * EnvGate.new(1);
1819
Out.ar([0,1], signal);
1920
}).add
@@ -26,7 +27,7 @@ OSCdef(\return, {
2627
(Process.elapsedTime - t).postln;
2728
}, '/prediction', nil);
2829
t = Process.elapsedTime;
29-
b.sendMsg("/predictor/predict", \pitch, 60+12.rand, \time, 0);
30+
b.sendMsg("/predictor/predict", \pitch, 60+12.rand, \time, 0, \vel, 0);
3031
)
3132

3233
// set the delay for more precise timing
@@ -51,16 +52,16 @@ MIDIdef.noteOn(\input, {
5152
SystemClock.clear;
5253

5354
//get a new prediction in light of current note
54-
b.sendMsg("/predictor/predict", \pitch, num, \time, dt);
55+
b.sendMsg("/predictor/predict", \pitch, num, \time, dt, \vel, val);
5556

5657
// release the previous note
57-
y.release(1.0);
58+
y.release(0.1);
5859

5960
// play the current note
60-
y = Synth(\pluck, [\freq, num.midicps]);//.release(1);
61+
y = Synth(\pluck, [\freq, num.midicps, \vel, val/127]);//.release(1);
6162

6263
// post the current note
63-
[\player, dt, num].postln;
64+
[\player, dt, num, val].postln;
6465

6566
// mark time of current note
6667
t = t2;
@@ -72,15 +73,16 @@ OSCdef(\return, {
7273
arg msg, time, addr, recvPort;
7374
var num = msg[1]; // MIDI number of predicted note
7475
var dt = msg[2]; // time to predicted note
76+
var val = msg[3]; // velocity 0-127
7577

7678
// time-to-next note gets 'censored' by the model
7779
// when over a threshold, in this case 10 seconds,
7880
// meaning it just predicts 10s rather than any longer time
79-
var censor = dt==10.0;
81+
var censor = true;//dt==10.0;
8082

8183
censor.if{
8284
// if the predicted time is > 10 seconds, don't schedule it, just stop.
83-
\censor.postln; y.release(3.0)
85+
\censor.postln; //y.release(3.0)
8486
}{
8587
// schedule the predicted note
8688
SystemClock.sched(dt-~delay, {
@@ -99,17 +101,19 @@ OSCdef(\return, {
99101
// be if there was a lot of fast MIDI input)
100102
SystemClock.clear;
101103
// feed model its own prediction as input
102-
b.sendMsg("/predictor/predict", \pitch, num, \time, dt);
104+
b.sendMsg("/predictor/predict",
105+
\pitch, num, \time, dt, \vel, val);
103106
// release the previous note
104107
(dt<3e-2).if{
105-
// if the time delay is very small, slowly release to play a chord
108+
// if the time delay is very small, slow release for chord
106109
y.release(1.0)
107110
}{
108111
// otherwise release fast to play a melody
109112
y.release(0.1)
110113
};
111114
// play the current note
112-
y = Synth(\pluck, [\freq, num.midicps]);//.release(1);
115+
y = Synth(\pluck, [
116+
\freq, num.midicps, \vel, val/127]);//.release(1);
113117
// post the current note
114118
[\model, dt, num].postln;
115119
// mark the time of current note
@@ -126,7 +130,7 @@ OSCdef(\return, {
126130
)
127131

128132
// send a note manually if you don't have a MIDI controller:
129-
b.sendMsg("/predictor/predict", \pitch, 70, \time, 0);
133+
// b.sendMsg("/predictor/predict", \pitch, 70, \time, 0, \vel, 64);
130134

131135
// load another model
132136
// b.sendMsg("/predictor/load", "/path/to/checkpoint");

examples/notepredictor/server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def _(address, **kw):
3939
print('no model loaded')
4040
else:
4141
r = predictor.predict(**kw)
42-
return '/prediction', r['pitch'], r['time']
42+
return '/prediction', r['pitch'], r['time'], r['velocity']
4343

4444
elif cmd=="reset":
4545
if predictor is None:

notepredictor/notepredictor/model.py

Lines changed: 55 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -84,34 +84,40 @@ class ModalityTransformer(nn.Module):
8484
"""Model joint distribution of modalities autoregressively with random permutations"""
8585
def __init__(self, input_size, hidden_size, heads=4, layers=1):
8686
super().__init__()
87-
self.net = nn.TransformerEncoder(
88-
nn.TransformerEncoderLayer(
87+
self.net = nn.TransformerDecoder(
88+
nn.TransformerDecoderLayer(
8989
input_size, heads, hidden_size, norm_first=False
9090
), layers)
9191

92-
def forward(self, h, modes):
92+
def forward(self, h, targets):
9393
"""
9494
Args:
95-
modes: each a Tensor[batch x time x input_size]
95+
h: list of Tensor[batch x time x input_size], length note_dim+1
96+
targets: list of Tensor[batch x time x input_size], length note_dim-1
9697
"""
97-
x = [h]+modes
98-
batch_size = h.shape[0]*h.shape[1]
98+
h = list(h)
99+
targets = list(targets)
100+
# h is 'target' w.r.t TransformerDecoder
101+
# targets is 'memory'
102+
batch_size = h[0].shape[0]*h[0].shape[1]
99103
# fold time into batch, stack modes
100-
x = torch.stack([
104+
tgt = torch.stack([
101105
item.reshape(batch_size,-1)
102-
for item in x
106+
for item in h[1:]
107+
],0)
108+
mem = torch.stack([
109+
item.reshape(batch_size,-1)
110+
for item in h[:1]+targets
103111
],0)
104112
# now "time"(mode) x "batch"(+time) x channel
105113

106114
# generate a mask
107-
# upper triangular (i.e. diagonal and above is True, meaning masked)
108-
# except h position should attend to self
109-
n = len(modes)+1
110-
mask = x.new_ones((n,n), dtype=bool).triu()
111-
mask[0,0] = False
115+
# this is both the target and memory mask
116+
n = len(h)-1
117+
mask = ~tgt.new_ones((n,n), dtype=bool).tril()
112118

113-
x = self.net(x, mask)
114-
return list(x.reshape(n, *h.shape).unbind(0))[1:]
119+
x = self.net(tgt, mem, mask, mask)
120+
return list(x.reshape(n, *h[0].shape).unbind(0))
115121

116122

117123
class NotePredictor(nn.Module):
@@ -129,6 +135,8 @@ def __init__(self,
129135
"""
130136
super().__init__()
131137

138+
self.note_dim = 3 # pitch, time, velocity
139+
132140
self.start_token = num_pitches
133141
self.end_token = num_pitches+1
134142

@@ -147,7 +155,7 @@ def __init__(self,
147155

148156
# RNN backbone
149157
self.rnn = GenericRNN(kind,
150-
3*emb_size, rnn_hidden,
158+
self.note_dim*emb_size, rnn_hidden,
151159
num_layers=rnn_layers, batch_first=True, dropout=dropout)
152160

153161
# learnable initial RNN state
@@ -158,7 +166,7 @@ def __init__(self,
158166
])
159167

160168
# projection from RNN state to distribution parameters
161-
self.h_proj = nn.Linear(rnn_hidden, emb_size)
169+
self.h_proj = nn.Linear(rnn_hidden, emb_size*(1+self.note_dim))
162170
self.projections = nn.ModuleList([
163171
nn.Linear(emb_size, self.pitch_domain),
164172
nn.Linear(emb_size, self.time_dist.n_params, bias=False),
@@ -199,22 +207,25 @@ def forward(self, pitches, times, velocities, validation=False):
199207
embs = (pitch_emb, time_emb, vel_emb)
200208

201209
x = torch.cat(embs, -1)[:,:-1] # skip last time position
202-
## broadcast intial state to batch size
210+
## broadcast initial state to batch size
203211
initial_state = tuple(
204212
t.expand(self.rnn.num_layers, x.shape[0], -1).contiguous() # 1 x batch x hidden
205213
for t in self.initial_state)
206214
h, _ = self.rnn(x, initial_state) #batch, time, hidden_size
207215

208216
# include initial hidden state for predicting first note
209-
h = torch.cat((
210-
self.initial_state[0][-1][None].expand(batch_size, 1, -1),
211-
h), -2)
217+
# h = torch.cat((
218+
# self.initial_state[0][-1][None].expand(batch_size, 1, -1),
219+
# h), -2)
212220

213221
# fit all note factorizations at once.
214-
perm = torch.randperm(3)
215-
embs = [embs[i] for i in perm]
216-
mode_hs = self.xformer(self.h_proj(h), embs)
217-
mode_hs = [mode_hs[perm[i]] for i in perm]
222+
# TODO: perm each batch item independently?
223+
perm = torch.randperm(self.note_dim)
224+
hs = list(self.h_proj(h).chunk(self.note_dim+1, -1))
225+
hs = hs[:1] + [hs[i+1] for i in perm]
226+
embs = [embs[i][:,1:] for i in perm[:-1]]
227+
mode_hs = self.xformer(hs, embs)
228+
mode_hs = [mode_hs[i] for i in perm.argsort()]
218229

219230
pitch_params, time_params, vel_params = [
220231
proj(h) for proj,h in zip(self.projections, mode_hs)]
@@ -224,11 +235,11 @@ def forward(self, pitches, times, velocities, validation=False):
224235
pitch_targets = pitches[:,1:,None] #batch, time, 1
225236
pitch_log_probs = pitch_logits.gather(-1, pitch_targets)[...,0]
226237

227-
time_targets = times# batch, time
238+
time_targets = times[:,1:] # batch, time
228239
time_result = self.time_dist(time_params, time_targets)
229240
time_log_probs = time_result.pop('log_prob')
230241

231-
vel_targets = velocities # batch, time
242+
vel_targets = velocities[:,1:] # batch, time
232243
vel_result = self.vel_dist(vel_params, vel_targets)
233244
vel_log_probs = vel_result.pop('log_prob')
234245

@@ -252,7 +263,7 @@ def forward(self, pitches, times, velocities, validation=False):
252263
)
253264
return r
254265

255-
# TODO: vel
266+
# TODO: force
256267
def predict(self, pitch, time, vel, force=(None, None, None)):
257268
"""
258269
supply the most recent note and return a prediction for the next note.
@@ -284,38 +295,37 @@ def predict(self, pitch, time, vel, force=(None, None, None)):
284295
for t,new_t in zip(self.cell_state, new_state):
285296
t[:] = new_t
286297

287-
h = self.h_proj(h)
298+
h = self.h_proj(h).chunk(self.note_dim+1, -1)
288299

289300
# TODO: permutations
290-
# TODO: optimize by removing unused positions
291301
# TODO: refactor with common distribution API
292-
pitch_h, = self.xformer(h, embs[:1])
302+
pitch_h, = self.xformer(h[:2], [])
293303

294304
pitch_params = self.projections[0](pitch_h)
295305
pred_pitch = D.Categorical(logits=pitch_params).sample()
296306

297307
embs[0] = self.pitch_emb(pred_pitch)
298-
_, time_h = self.xformer(h, embs[:2])
308+
_, time_h = self.xformer(h[:3], embs[:1])
299309

300310
time_params = self.projections[1](time_h)
301-
pred_time = self.time_dist.sample(time_params)
311+
# pred_time = self.time_dist.sample(time_params)
312+
313+
### TODO: generalize, move into sample
314+
# pi only, fewer zeros:
315+
log_pi, loc, s = (
316+
t for t in self.time_dist.get_params(time_params))
317+
bias = float('inf')
318+
log_pi = torch.where(loc <= self.time_dist.res, log_pi-bias, log_pi)
319+
idx = D.Categorical(logits=log_pi).sample().item()
320+
pred_time = loc[...,idx].clamp(0,10)
321+
###
302322

303323
embs[1] = self.time_emb(pred_time)
304-
_, _, vel_h = self.xformer(h, embs)
324+
_, _, vel_h = self.xformer(h, embs[:2])
305325

306326
vel_params = self.projections[2](vel_h)
307327
pred_vel = self.vel_dist.sample(vel_params)
308328

309-
# ### TODO: generalize, move into sample
310-
# ### DEBUG
311-
# # pi only, fewer zeros:
312-
# log_pi, loc, s = (
313-
# t.squeeze() for t in self.time_dist.get_params(time_params))
314-
# bias = 2#float('inf')
315-
# log_pi = torch.where(loc <= self.time_dist.res, log_pi-bias, log_pi)
316-
# idx = D.Categorical(logits=log_pi).sample()
317-
# pred_time = loc[idx].clamp(0,10)
318-
319329
return {
320330
'pitch': pred_pitch.item(),
321331
'time': pred_time.item(),
@@ -336,7 +346,7 @@ def reset(self, start=True):
336346
for n,t in zip(self.cell_state_names(), self.initial_state):
337347
getattr(self, n)[:] = t.detach()
338348
if start:
339-
self.predict(self.start_token, 0.)
349+
self.predict(self.start_token, 0., 0.)
340350

341351
@classmethod
342352
def from_checkpoint(cls, path):

0 commit comments

Comments
 (0)