Skip to content
This repository was archived by the owner on Nov 23, 2023. It is now read-only.

Commit 69e5b48

Browse files
predict arguments; autopitch example
1 parent 4c08e6b commit 69e5b48

2 files changed

Lines changed: 159 additions & 10 deletions

File tree

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
(
2+
~gui = false;
3+
MIDIIn.connectAll;
4+
b = NetAddr.new("127.0.0.1", 9999);
5+
Server.default.options.inDevice_("Built-in Microph");
6+
Server.default.options.outDevice_("Built-in Output");
7+
// Server.default.options.inDevice_("mic-buds");
8+
// Server.default.options.outDevice_("mic-buds");
9+
s.boot;
10+
~gui.if{
11+
k = MIDIKeyboard.new(bounds: Rect(0, 0, 500, 100), octaves:11, startnote:0)
12+
};
13+
)
14+
15+
(
16+
SynthDef(\pluck, {
17+
var vel = \vel.kr;
18+
var signal = Saw.ar(\freq.kr(20), 3e-2) * EnvGate.new(1);
19+
var fr = 2.pow(Decay.ar(Impulse.ar(0), 3)*6*vel+8);
20+
signal = BLowPass.ar(signal, fr)*vel;
21+
Out.ar([0,1], signal);
22+
}).add
23+
)
24+
25+
26+
// measure round-trip latency
27+
(
28+
OSCdef(\return, {
29+
arg msg, time, addr, recvPort;
30+
(Process.elapsedTime - t).postln;
31+
}, '/prediction', nil);
32+
t = Process.elapsedTime;
33+
b.sendMsg("/predictor/predict",
34+
\pitch, 60+12.rand, \time, 0, \vel, 0, \fix_time, 0, \fix_vel, 0);
35+
)
36+
37+
// set the delay for more precise timing
38+
~delay = 0.01;
39+
40+
41+
// NetAddr.localAddr // retrieve the current IP and port
42+
// thisProcess.openPorts; // list all open ports
43+
44+
// model chooses pitches
45+
(
46+
~gate = 1;
47+
48+
~reset = {
49+
~last_pitch = nil;
50+
~last_dt = nil;
51+
~last_vel = nil;
52+
t = Process.elapsedTime;
53+
b.sendMsg("/predictor/reset");
54+
y!?{y.free};
55+
y = nil;
56+
b.sendMsg("/predictor/predict", \pitch, 128, \time, 0, \vel, 0);
57+
58+
};
59+
60+
~reset.();
61+
62+
// footswitch
63+
MIDIdef.program(\switch, {
64+
arg num, chan, src;
65+
num.switch
66+
{1}{~gate = 0}
67+
{2}{~gate = 1}
68+
{3}{
69+
~gate = 0;
70+
SystemClock.clear;
71+
b.sendMsg("/predictor/reset");
72+
y.release;
73+
SystemClock.clear;
74+
};
75+
~gate.postln;
76+
});
77+
78+
79+
// MIDI from controller
80+
MIDIdef.noteOn(\input, {
81+
arg vel, pitch, chan, src;
82+
var t2 = Process.elapsedTime;
83+
var dt = t2-(t?t2); //time since last note
84+
85+
// release the previous note
86+
y.release(0.1);
87+
88+
// attack the current note with the old pitch
89+
y = Synth(\pluck, [\freq, ~last_pitch.midicps, \vel, vel/127]);
90+
91+
// get a new prediction in light of last note,
92+
// fixing dt and vel to performed values so just pitch is predicted
93+
b.sendMsg("/predictor/predict",
94+
\pitch, ~last_pitch, \time, ~last_dt, \vel, ~last_vel,
95+
\fix_time, dt, \fix_vel, vel);
96+
97+
~last_dt = dt;
98+
~last_vel = vel;
99+
100+
// mark time of current note
101+
t = t2;
102+
});
103+
104+
// OSC return from python
105+
OSCdef(\return, {
106+
arg msg, time, addr, recvPort;
107+
var pitch = msg[1]; // MIDI number of predicted note
108+
var dt = msg[2]; // time to predicted note
109+
var vel = msg[3]; // velocity 0-127
110+
111+
// store the pitch and immediately set (unless there is no synth,
112+
// indicating this is the first note)
113+
~last_pitch = pitch;
114+
~last_dt.isNil.if{~last_dt = dt};
115+
~last_vel.isNil.if{~last_vel = vel};
116+
y!?{y.set(\freq, ~last_pitch.midicps)};
117+
118+
[pitch, dt, vel].postln;
119+
120+
}, "/prediction", nil);
121+
)
122+
123+
~reset.()
124+
// send a note manually if you don't have a midi controller
125+
MIDIdef.all[\input].func.(64)
126+
127+
(
128+
MIDIdef.all[\input].func.(64);
129+
MIDIdef.all[\input].func.(64);
130+
)
131+
132+
// load another model
133+
// b.sendMsg("/predictor/load", "/path/to/checkpoint");

notepredictor/notepredictor/model.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -193,10 +193,16 @@ def cell_state_names(self):
193193
def cell_state(self):
194194
return tuple(getattr(self, n) for n in self.cell_state_names())
195195

196-
@property
197-
def samplers(self):
196+
def get_samplers(self, allow_start=False, allow_end=False):
197+
def sample_pitch(x):
198+
if not allow_start:
199+
x[...,self.start_token] = -np.inf
200+
if not allow_end:
201+
x[...,self.end_token] = -np.inf
202+
return D.Categorical(logits=x).sample()
203+
198204
return (
199-
lambda x: D.Categorical(logits=x).sample(),
205+
sample_pitch,
200206
lambda x: self.time_dist.sample(x),
201207
lambda x: self.vel_dist.sample(x),
202208
)
@@ -285,15 +291,18 @@ def forward(self, pitches, times, velocities, validation=False):
285291
return r
286292

287293
# TODO: force
288-
def predict(self, pitch, time, vel, force=(None, None, None)):
294+
def predict(self,
295+
pitch, time, vel,
296+
fix_pitch=None, fix_time=None, fix_vel=None,
297+
allow_end=False, allow_start=False):
289298
"""
290299
supply the most recent note and return a prediction for the next note.
291300
292301
Args:
293302
pitch: int. MIDI number of current note.
294303
time: float. elapsed time since previous note.
295304
vel: float. (possibly dequantized) MIDI velocity from 0-127 inclusive.
296-
force: Tuple[Optional[Number]].
305+
fix_*: same as above, but to fix a value for the predicted note
297306
298307
Returns: dict of
299308
'pitch': int. predicted MIDI number of next note.
@@ -322,23 +331,24 @@ def predict(self, pitch, time, vel, force=(None, None, None)):
322331

323332
modalities = list(zip(
324333
self.projections,
325-
self.samplers,
334+
self.get_samplers(allow_start=allow_start, allow_end=allow_end),
326335
self.embeddings,
327336
))
328337

329338
context = []
330339
predicted = []
331340
params = []
332341

333-
force = [
342+
fix = [
334343
None if item is None else torch.tensor([[item]], dtype=dtype)
335344
for item, dtype in zip(
336-
force, [torch.long, torch.float, torch.float])]
345+
[fix_pitch, fix_time, fix_vel],
346+
[torch.long, torch.float, torch.float])]
337347

338348
# permute h_tgt, embs, modalities
339349
# if any modalities are determined, embed them;
340350
det_idx, undet_idx = [], []
341-
for i,(item, embed) in enumerate(zip(force, self.embeddings)):
351+
for i,(item, embed) in enumerate(zip(fix, self.embeddings)):
342352
if item is None:
343353
undet_idx.append(i)
344354
else:
@@ -354,11 +364,17 @@ def predict(self, pitch, time, vel, force=(None, None, None)):
354364
# for each undetermined modality,
355365
# sample a new value conditioned on alteady determined ones
356366

367+
# TODO: allow constraints;
368+
# attempt to sort the strongest constraints first
369+
# constraints can be:
370+
# discrete set, in which case evaluate probs and then sample categorical
371+
# range, in which case truncate
372+
357373
while len(undet_idx):
358374
i = undet_idx.pop(0) # index of modality to determine
359375
j = len(det_idx) # number already determined
360376
project, sample, embed = modalities[i]
361-
# determine the next modality
377+
# determine value for the next modality
362378
hidden = self.xformer(context, h_ctx, perm_h_tgt[:j+1])[j]
363379
params.append(project(hidden))
364380
pred = sample(params[-1])

0 commit comments

Comments
 (0)