Skip to content
This repository was archived by the owner on Nov 23, 2023. It is now read-only.

Commit 53eb691

Browse files
pitch_index argument to select pitch with nth greatest likelihood
1 parent 69e5b48 commit 53eb691

2 files changed

Lines changed: 10 additions & 10 deletions

File tree

examples/notepredictor/autopitch.scd

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ MIDIdef.noteOn(\input, {
9292
// fixing dt and vel to performed values so just pitch is predicted
9393
b.sendMsg("/predictor/predict",
9494
\pitch, ~last_pitch, \time, ~last_dt, \vel, ~last_vel,
95-
\fix_time, dt, \fix_vel, vel);
95+
\index_pitch, pitch, \fix_time, dt, \fix_vel, vel);
9696

9797
~last_dt = dt;
9898
~last_vel = vel;
@@ -122,12 +122,9 @@ OSCdef(\return, {
122122

123123
~reset.()
124124
// send a note manually if you don't have a midi controller
125-
MIDIdef.all[\input].func.(64)
125+
MIDIdef.all[\input].func.(64, 16) //velocity, "pitch"
126126

127-
(
128-
MIDIdef.all[\input].func.(64);
129-
MIDIdef.all[\input].func.(64);
130-
)
127+
"abc"+0
131128

132129
// load another model
133130
// b.sendMsg("/predictor/load", "/path/to/checkpoint");

notepredictor/notepredictor/model.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -193,13 +193,16 @@ def cell_state_names(self):
193193
def cell_state(self):
194194
return tuple(getattr(self, n) for n in self.cell_state_names())
195195

196-
def get_samplers(self, allow_start=False, allow_end=False):
196+
def get_samplers(self, index_pitch=None, allow_start=False, allow_end=False):
197197
def sample_pitch(x):
198198
if not allow_start:
199199
x[...,self.start_token] = -np.inf
200200
if not allow_end:
201201
x[...,self.end_token] = -np.inf
202-
return D.Categorical(logits=x).sample()
202+
if index_pitch is not None:
203+
return x.argsort(-1, True)[...,index_pitch]
204+
else:
205+
return D.Categorical(logits=x).sample()
203206

204207
return (
205208
sample_pitch,
@@ -294,7 +297,7 @@ def forward(self, pitches, times, velocities, validation=False):
294297
def predict(self,
295298
pitch, time, vel,
296299
fix_pitch=None, fix_time=None, fix_vel=None,
297-
allow_end=False, allow_start=False):
300+
index_pitch=None, allow_start=False, allow_end=False):
298301
"""
299302
supply the most recent note and return a prediction for the next note.
300303
@@ -331,7 +334,7 @@ def predict(self,
331334

332335
modalities = list(zip(
333336
self.projections,
334-
self.get_samplers(allow_start=allow_start, allow_end=allow_end),
337+
self.get_samplers(index_pitch, allow_start, allow_end),
335338
self.embeddings,
336339
))
337340

0 commit comments

Comments
 (0)