Skip to content
This repository was archived by the owner on Nov 23, 2023. It is now read-only.

Commit b7cd33b

Browse files
time temperature and bias; rename midi-duet to generate.scd
1 parent 7536660 commit b7cd33b

4 files changed

Lines changed: 68 additions & 18 deletions

File tree

examples/notepredictor/autopitch.scd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ s.waitForBoot{
6767
};
6868
)
6969

70+
OSCdef.trace(false)
7071
// ~linn_reset.()
7172

7273
(
Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,11 @@ b.sendMsg("/predictor/predict", \pitch, 60+12.rand, \time, 0, \vel, 0);
4242
// duet with the model
4343
// feeds the model's predictions back to it as well as player input
4444
(
45+
~step = 0;
4546
~gate = 1;
46-
t = Process.elapsedTime;
47+
t = nil;
48+
~player_t = t;
49+
~machine_t = t;
4750
b.sendMsg("/predictor/reset");
4851

4952
// footswitch
@@ -58,6 +61,7 @@ MIDIdef.program(\switch, {
5861
b.sendMsg("/predictor/reset");
5962
y.release;
6063
SystemClock.clear;
64+
~step = 0;
6165
};
6266
~gate.postln;
6367
});
@@ -73,7 +77,12 @@ MIDIdef.noteOn(\input, {
7377
SystemClock.clear;
7478

7579
//get a new prediction in light of current note
76-
b.sendMsg("/predictor/predict", \pitch, num, \time, dt, \vel, val);
80+
b.sendMsg("/predictor/predict",
81+
\pitch, num, \time, dt, \vel, val,
82+
\allow_start, false, \allow_end, false,
83+
\time_temp, 0, \min_time, 0.1, \max_time, 5
84+
// \fix_time, 9
85+
);
7786

7887
// release the previous note
7988
y.release(0.1);
@@ -86,6 +95,9 @@ MIDIdef.noteOn(\input, {
8695

8796
// mark time of current note
8897
t = t2;
98+
~player_t = t;
99+
100+
~step = ~step + 1;
89101
});
90102

91103

@@ -98,8 +110,8 @@ OSCdef(\return, {
98110

99111
// time-to-next note gets 'censored' by the model
100112
// when over a threshold, in this case 10 seconds,
101-
// meaning it just predicts 10s rather than a any longer time
102-
var censor = dt==10.0;
113+
// meaning it just predicts 10s rather than any longer time
114+
var censor = dt>10.0;
103115

104116
censor.if{
105117
// if the predicted time is > 10 seconds, don't schedule it, just stop.
@@ -108,6 +120,8 @@ OSCdef(\return, {
108120
// schedule the predicted note
109121
SystemClock.sched(dt-~delay, {
110122
(~gate>0).if{
123+
var t2 = Process.elapsedTime;
124+
var dt_actual = t2 - t;
111125
(num==129).if{
112126
// 129 is the 'stop token', meaning 'end-of-performance'
113127
// in this case don't schedule a note, and reset the model
@@ -124,7 +138,13 @@ OSCdef(\return, {
124138
SystemClock.clear;
125139
// feed model its own prediction as input
126140
b.sendMsg("/predictor/predict",
127-
\pitch, num, \time, dt, \vel, val);
141+
\pitch, num, \time, dt_actual, \vel, val,
142+
\allow_start, false, \allow_end, false,
143+
\time_temp, 0.1, \min_time, 0.1, \max_time, 5
144+
// \fix_time, (~step%4==0).if{0.6}{0} // tetrachords
145+
// \fix_time, (~step%8)*0.1 // specific rhythm
146+
147+
);
128148
// release the previous note
129149
(dt<3e-2).if{
130150
// if the time delay is very small, slow release for chord
@@ -138,14 +158,17 @@ OSCdef(\return, {
138158
\freq, num.midicps, \vel, val/127]);//.release(1);
139159
// post the current note
140160
[\model, dt, num, val].postln;
141-
// mark the time of current note
142-
t = Process.elapsedTime;
161+
// mark the actual time of current note
162+
t = t2;
163+
~machine_t = t;
143164
// crudely draw note on piano GUI
144165
~gui.if{
145166
AppClock.sched(0,{k.keyDown(num)});
146167
AppClock.sched(0.2,{k.keyUp(num)});
147168
}
148-
}
169+
};
170+
~step = ~step+1;
171+
[\late, dt_actual-dt].postln;
149172
}
150173
})};
151174

notepredictor/notepredictor/distributions.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,11 +115,17 @@ def cdf_components(self, loc, s, x):
115115
x_ = (x[...,None] - loc) * s
116116
return x_.sigmoid()
117117

118-
def sample(self, h, truncate=None, shape=None):
118+
def sample(self, h, truncate=None, shape=None, temp=None, bias=None):
119119
"""
120120
Args:
121121
h: Tensor[...,n_params]
122-
shape: additional sample shape to be prepended to dims or None
122+
truncate: Optional[Tuple[2]]. lower and upper bound for truncation.
123+
shape: Optional[int]. additional sample shape to be prepended to dims.
124+
temp: Optional[float]. pseudo-temperature (temperature of each mixture
125+
component). default is 1. 0 would sample component location only,
126+
ignoring sharpness.
127+
bias: applied outside of truncation but inside of clamping,
128+
useful e.g. for latency correction when sampling delta-time
123129
Returns:
124130
Tensor[*shape,...] (h without last dimension, prepended with `shape`)
125131
"""
@@ -133,6 +139,12 @@ def sample(self, h, truncate=None, shape=None):
133139
truncate = (-np.inf, np.inf)
134140
truncate = torch.tensor(truncate)
135141

142+
if temp is None:
143+
temp = 1
144+
145+
if bias is None:
146+
bias = 0
147+
136148
log_pi, loc, s = self.get_params(h)
137149
scale = 1/s
138150

@@ -154,7 +166,7 @@ def sample(self, h, truncate=None, shape=None):
154166
u = u * (upper-lower) + lower
155167

156168
# x = loc + scale * (u.log() - (1 - u).log())
157-
x = loc - scale * (1/u - 1).log()
169+
x = loc + bias - scale * temp * (1/u - 1).log()
158170
x = x.clamp(self.lo, self.hi)
159171
return x[0] if unwrap else x
160172

notepredictor/notepredictor/model.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,11 @@ def cell_state(self):
200200

201201
def get_samplers(self,
202202
pitch_topk=None, index_pitch=None, allow_start=False, allow_end=False,
203-
sweep_time=False, trunc_time=None):
203+
sweep_time=False, min_time=None, max_time=None, bias_time=None, time_temp=None):
204+
"""
205+
this method converts the many arguments to `predict` into functions for
206+
sampling each note modality (e.g. pitch, time, velocity)
207+
"""
204208

205209
def sample_pitch(x):
206210
if not allow_start:
@@ -217,7 +221,7 @@ def sample_pitch(x):
217221
def sample_time(x):
218222
# TODO: respect min_time/max_time truncation when sweep_time is True
219223
if sweep_time:
220-
if trunc_time is not None:
224+
if min_time is not None or max_time is not None:
221225
raise NotImplementedError("""
222226
trunc_time with sweep_time needs implementation
223227
""")
@@ -228,7 +232,11 @@ def sample_time(x):
228232
# print(loc.shape)
229233
return loc
230234
else:
231-
return self.time_dist.sample(x, truncate=trunc_time)
235+
trunc = (
236+
-np.inf if min_time is None else min_time,
237+
np.inf if max_time is None else max_time)
238+
return self.time_dist.sample(x,
239+
truncate=trunc, temp=time_temp, bias=bias_time)
232240

233241
return (
234242
sample_pitch,
@@ -316,7 +324,7 @@ def predict(self,
316324
pitch, time, vel,
317325
fix_pitch=None, fix_time=None, fix_vel=None,
318326
pitch_topk=None, index_pitch=None, allow_start=False, allow_end=False,
319-
sweep_time=False, trunc_time=None):
327+
sweep_time=False, min_time=None, max_time=None, bias_time=None, time_temp=None):
320328
"""
321329
consume the most recent note and return a prediction for the next note.
322330
@@ -335,7 +343,8 @@ def predict(self,
335343
allow_end: if False, zero probability for sampling the end token
336344
sweep_time: if True, instead of sampling time, choose a diverse set of
337345
times and stack along the batch dimension
338-
trunc_time: if not None, truncate the time distribution to (lo, hi)
346+
min_time, max_time: if not None, truncate the time distribution
347+
time_temp: if not None, apply pseudo-temperature to the time distribution.
339348
340349
Returns: dict of
341350
'pitch': int. predicted MIDI number of next note.
@@ -366,7 +375,7 @@ def predict(self,
366375
self.projections,
367376
self.get_samplers(
368377
pitch_topk, index_pitch, allow_start, allow_end,
369-
sweep_time, trunc_time),
378+
sweep_time, min_time, max_time, bias_time, time_temp),
370379
self.embeddings,
371380
))
372381

@@ -388,7 +397,9 @@ def predict(self,
388397
for i,(item, embed) in enumerate(zip(fix, self.embeddings)):
389398
if item is None:
390399
if (
391-
i==1 and (sweep_time or (trunc_time is not None)) or
400+
i==1 and (sweep_time
401+
or (min_time is not None) or (max_time is not None)
402+
or (time_temp is not None)) or
392403
i==0 and pitch_topk
393404
):
394405
cons_idx.append(i)
@@ -403,6 +414,9 @@ def predict(self,
403414
perm = det_idx + undet_idx # permutation from the canonical order
404415
iperm = np.argsort(perm) # inverse permutation back to canonical order
405416

417+
md = ['pitch', 'time', 'vel']
418+
print([md[i] for i in perm])
419+
406420
# for each undetermined modality,
407421
# sample a new value conditioned on already determined ones
408422

0 commit comments

Comments
 (0)