Skip to content
This repository was archived by the owner on Nov 23, 2023. It is now read-only.

Commit 865950b

Browse files
add new prep script
1 parent 311d6da commit 865950b

1 file changed

Lines changed: 273 additions & 0 deletions

File tree

notochord/scripts/lakh_prep_2.py

Lines changed: 273 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,273 @@
1+
from pathlib import Path
2+
from multiprocessing import Pool
3+
import functools as ft
4+
import itertools as it
5+
from collections import defaultdict
6+
import random
7+
8+
from tqdm import tqdm
9+
import fire
10+
import mido
11+
# from pretty_midi import PrettyMIDI
12+
import torch
13+
14+
def pitch_collision(p1, p2):
15+
return len(p1['pitches'] & p2['pitches']) > 0
16+
17+
def time_collision(p1, p2):
18+
# there is a collision unless one ends before the other starts
19+
if p1['bounds'][1] < p2['bounds'][0]:
20+
return False
21+
if p2['bounds'][1] < p1['bounds'][0]:
22+
return False
23+
return True
24+
25+
def note_collision(p1, p2):
26+
parts = {1:p1, 2:p2}
27+
if not pitch_collision(p1, p2):
28+
# print('early out: pitch')
29+
return False
30+
if not time_collision(p1, p2):
31+
# print('early out: time')
32+
return False
33+
events = [
34+
*zip(p1['events']['time'], p1['events']['vel'], p1['events']['pitch'], it.repeat(1)),
35+
*zip(p2['events']['time'], p2['events']['vel'], p2['events']['pitch'], it.repeat(2))]
36+
events.sort()
37+
# print(events)
38+
held = {1:set(), 2:set()}
39+
for (t,v,p,part) in events:
40+
if v > 0:
41+
held[part].add(p)
42+
else:
43+
held[part].discard(p)
44+
if p in held[1] and p in held[2]:
45+
return True
46+
return False
47+
48+
def repair_events(part):
49+
events = zip(part['events']['time'], part['events']['vel'], part['events']['pitch'])
50+
i = part['inst']
51+
new_events = []
52+
# held = defaultdict(int)
53+
held = set()
54+
n_double_off = 0
55+
n_double_on = 0
56+
for (t,v,p) in events:
57+
if v > 0:
58+
if p in held:
59+
# add note-offs for extra note-ons
60+
n_double_on += 1
61+
# print(f'double note on: part {part["part"]} pitch {p} time {t}')
62+
new_events.append((t,0,p,i))
63+
new_events.append((t,v,p,i))
64+
held.add(p)
65+
else:
66+
if p in held:
67+
new_events.append((t,v,p,i))
68+
held.remove(p)
69+
else:
70+
# delete extra note-offs
71+
n_double_off += 1
72+
# print(f'double note off: part {part["part"]} pitch {p} time {t}')
73+
# if n_double_off:
74+
# print(f'double note off: part {part["part"]} count {n_double_off}')
75+
# if n_double_on:
76+
# print(f'double note on: part {part["part"]} count {n_double_on}')
77+
new_events.sort()
78+
return new_events
79+
80+
# number of channels with no program change
81+
class AnonTracks:
82+
def __init__(self):
83+
self.n = 0
84+
def __call__(self):
85+
self.n += 1
86+
return 256+self.n
87+
88+
def process(fnames):
89+
f,g = fnames
90+
91+
# fix overlapping notes and add a margin for
92+
# dequantization at data loading time
93+
# time_margin = 1e-3
94+
95+
try:
96+
mid = mido.MidiFile(f)
97+
except Exception:
98+
tqdm.write(f'error opening {f}')
99+
return
100+
101+
if mid.type==2:
102+
tqdm.write(f'type 2 file {f}')
103+
return
104+
105+
release_velocity_counts = defaultdict(int)
106+
channel_counts = defaultdict(int)
107+
note_ons = 0
108+
pseudo_note_offs = 0
109+
note_offs = 0
110+
111+
tempo_changes = 0
112+
113+
parts = defaultdict(lambda: defaultdict(list))
114+
115+
ticks_per_beat = mid.ticks_per_beat
116+
117+
next_anon = AnonTracks()
118+
119+
# apparently delta times are within-track,
120+
# but tempo changes affect all tracks?
121+
# 500_000 is default us/beat
122+
tick_tempos = [(0, 500_000)]
123+
for track_idx, track in enumerate(mid.tracks):
124+
channel_instruments = defaultdict(next_anon)
125+
time_ticks = 0
126+
127+
for msg in track:
128+
# time_seconds += mido.tick2second(msg.time, ticks_per_beat, tempo)
129+
time_ticks += msg.time
130+
131+
if msg.type=='program_change':
132+
channel_instruments[msg.channel] = msg.program + 1 + 128*int(msg.channel==9)
133+
134+
elif msg.type=='set_tempo':
135+
tick_tempos.append((time_ticks, msg.tempo))
136+
# print(f'tempo: {msg.tempo} at tick {time_ticks}')
137+
tempo_changes += 1
138+
139+
elif msg.type in ('note_on', 'note_off'):
140+
141+
# triple of track number, channel number, current instrument
142+
if msg.channel not in channel_instruments and msg.channel==9:
143+
channel_instruments[msg.channel] = 129
144+
part = (track_idx, msg.channel, channel_instruments[msg.channel])
145+
pitch = msg.note
146+
147+
if msg.type=='note_on':
148+
vel = msg.velocity
149+
if msg.velocity==0:
150+
pseudo_note_offs += 1
151+
else:
152+
note_ons += 1
153+
else:
154+
vel = 0
155+
note_offs += 1
156+
release_velocity_counts[msg.velocity] += 1
157+
158+
channel_counts[msg.channel] += 1
159+
160+
# absolute time, pitch, vel=0 for note-off
161+
parts[part]['pitch'].append(pitch)
162+
parts[part]['time'].append(time_ticks)#time_seconds)
163+
parts[part]['vel'].append(vel)
164+
165+
else:
166+
continue
167+
168+
# abs ticks -> abs seconds:
169+
# find tempo below ticks
170+
# abs seconds = convert(ticks-last_tempo_change_ticks) + abs_seconds(last_tempo_change_ticks)
171+
@ft.lru_cache(4096)
172+
def abs_ticks_to_seconds(ticks):
173+
if ticks==0: return 0
174+
last_change_ticks, tempo = next(filter(
175+
lambda x: x[0]<ticks,
176+
reversed(tick_tempos)))
177+
return (
178+
mido.tick2second(ticks-last_change_ticks, ticks_per_beat, tempo)
179+
+ abs_ticks_to_seconds(last_change_ticks))
180+
181+
try:
182+
for part, events in parts.items():
183+
events['time'] = [abs_ticks_to_seconds(t) for t in events['time']]
184+
except RecursionError:
185+
tqdm.write(f'too many tempo changes in {f}')
186+
187+
by_inst = defaultdict(list)
188+
for (track, channel, inst), events in parts.items():
189+
unique_pitch = set(events['pitch'])
190+
time_bounds = (min(events['time']), max(events['time']))
191+
# print(part, unique_pitch)
192+
by_inst[inst].append({
193+
'size':len(events['pitch']),
194+
'part':(track, channel),
195+
'pitches':unique_pitch,
196+
'bounds':time_bounds,
197+
'events':events
198+
})
199+
200+
complete_parts = []
201+
# within each instrument,
202+
for inst, parts in by_inst.items():
203+
assigned_parts = []
204+
# sort parts by number of events
205+
parts.sort(key=lambda e: e['size'])
206+
207+
n_parts = 0
208+
209+
# pop off parts and check for collisions
210+
while len(parts):
211+
no_collide = None
212+
part = parts.pop()
213+
# try for each assigned part:
214+
for ap in assigned_parts:
215+
# check for collision
216+
# (can speed this up by checking for pitch and bounding time collisions first)
217+
if not note_collision(part, ap):
218+
# print('no collision')
219+
no_collide = ap
220+
break
221+
# else:
222+
# print('collision')
223+
# if all collide, assign the part to a new anonymous instrument
224+
if no_collide is None:
225+
part['inst'] = n_parts*1000 + inst
226+
# print(f'assigned instrument {part["inst"]} for {part["part"]}')
227+
n_parts += 1
228+
assigned_parts.append(part)
229+
# otherwise merge with non-colliding part
230+
else:
231+
# print(f'merge {inst} {ap["part"]} < {part["part"]}')
232+
for k in ('pitch', 'time', 'vel'):
233+
ap['events'][k].extend(part['events'][k])
234+
235+
for p in assigned_parts:
236+
complete_parts.append(repair_events(p))
237+
238+
events = sum(complete_parts, start=[])
239+
240+
if len(events) < 64:
241+
return
242+
243+
events.sort()
244+
245+
time, vel, pitch, prog = zip(*events)
246+
torch.save(dict(
247+
time=torch.DoubleTensor(time),
248+
pitch=torch.LongTensor(pitch),
249+
velocity=torch.LongTensor(vel),
250+
program=torch.LongTensor(prog)
251+
), g.with_suffix('.pkl'))
252+
253+
254+
def main(data_path, dest_path, n_jobs=4):
255+
data_dir = Path(data_path)
256+
files = list(data_dir.glob('**/*.mid'))
257+
files_out = [
258+
(Path(dest_path) / f.relative_to(data_path))
259+
for f in files]
260+
parents = {g.parent for g in files_out}
261+
for parent in list(parents):
262+
parent.mkdir(parents=True, exist_ok=True)
263+
264+
# files = files[:1000]
265+
266+
with Pool(n_jobs) as pool:
267+
for _ in tqdm(pool.imap_unordered(process, zip(files, files_out), 32)):
268+
pass
269+
270+
271+
if __name__=="__main__":
272+
fire.Fire(main)
273+

0 commit comments

Comments
 (0)