import bisect
import functools as ft
import itertools as it
import random
from collections import defaultdict
from multiprocessing import Pool
from pathlib import Path

import fire
import mido
# from pretty_midi import PrettyMIDI
import torch
from tqdm import tqdm
13+
def pitch_collision(p1, p2):
    """Return True when the two parts have at least one pitch in common.

    Both arguments are part dicts whose ``'pitches'`` entry is a set of
    pitch values.
    """
    shared_nothing = p1['pitches'].isdisjoint(p2['pitches'])
    return not shared_nothing
16+
def time_collision(p1, p2):
    """Return True when the time spans of the two parts overlap.

    Each part dict carries ``'bounds'`` as a ``(first, last)`` pair.
    Spans that merely touch (one ends exactly where the other starts)
    count as overlapping.
    """
    first_start, first_end = p1['bounds'][0], p1['bounds'][1]
    second_start, second_end = p2['bounds'][0], p2['bounds'][1]
    # the spans are disjoint only if one ends strictly before the other starts
    disjoint = first_end < second_start or second_end < first_start
    return not disjoint
24+
def note_collision(p1, p2):
    """Return True if both parts ever hold the same pitch at the same time.

    Each part dict provides ``'pitches'``, ``'bounds'`` and ``'events'``
    (parallel ``'time'``, ``'vel'`` and ``'pitch'`` lists, where a
    velocity of 0 means note-off).
    """
    # cheap rejections first: no shared pitch, or no overlap in time
    if not (pitch_collision(p1, p2) and time_collision(p1, p2)):
        return False

    # one combined timeline, each event tagged with the part it came from;
    # tuple ordering places note-offs (vel 0) before note-ons at equal times
    timeline = sorted(
        [(t, v, p, 1) for t, v, p in zip(
            p1['events']['time'], p1['events']['vel'], p1['events']['pitch'])]
        + [(t, v, p, 2) for t, v, p in zip(
            p2['events']['time'], p2['events']['vel'], p2['events']['pitch'])])

    sounding = {1: set(), 2: set()}
    for _, velocity, pitch, source in timeline:
        if velocity > 0:
            sounding[source].add(pitch)
        else:
            sounding[source].discard(pitch)
        if pitch in sounding[1] and pitch in sounding[2]:
            return True
    return False
47+
def repair_events(part):
    """Return a part's events with unmatched note-ons/note-offs repaired.

    Args:
        part: dict with ``'inst'`` (instrument id attached to every output
            event) and ``'events'``, a mapping holding the parallel lists
            ``'time'``, ``'vel'`` and ``'pitch'``. A velocity of 0 is a
            note-off.

    Returns:
        A sorted list of ``(time, velocity, pitch, inst)`` tuples in which
        a note-on for an already-held pitch is preceded by an inserted
        note-off, and a note-off for a pitch that is not held is dropped.
    """
    inst = part['inst']
    raw = part['events']
    # Parts merged in `process` are concatenated, not interleaved, so the
    # stored order is not necessarily chronological. Order the events
    # before tracking held notes; plain tuple ordering puts note-offs
    # (vel 0) ahead of note-ons at the same time, the same ordering
    # `note_collision` uses when it decides two parts can share a voice.
    events = sorted(zip(raw['time'], raw['vel'], raw['pitch']))

    new_events = []
    held = set()
    for t, v, p in events:
        if v > 0:
            if p in held:
                # close the still-sounding note before re-triggering it
                new_events.append((t, 0, p, inst))
            new_events.append((t, v, p, inst))
            held.add(p)
        elif p in held:
            new_events.append((t, v, p, inst))
            held.remove(p)
        # otherwise: a note-off with nothing held is dropped

    new_events.sort()
    return new_events
79+
class AnonTracks:
    """Callable counter for channels with no program change.

    Every call increments ``n`` and returns ``256 + n``, so the first
    value handed out is 257.
    """

    def __init__(self):
        # how many ids have been handed out so far
        self.n = 0

    def __call__(self):
        issued = self.n + 1
        self.n = issued
        return 256 + issued
87+
def _collect_parts(mid):
    """Gather note events per (track, channel, instrument) plus tempo changes.

    Returns ``(parts, tick_tempos)``: ``parts`` maps a
    ``(track_idx, channel, instrument)`` key to parallel ``'pitch'``,
    ``'time'`` (absolute ticks) and ``'vel'`` lists (0 for note-off);
    ``tick_tempos`` is a list of ``(absolute_tick, tempo)`` pairs in the
    order they were encountered.
    """
    parts = defaultdict(lambda: defaultdict(list))
    next_anon = AnonTracks()
    # 500_000 is the default us/beat, in effect until the first set_tempo
    tick_tempos = [(0, 500_000)]

    for track_idx, track in enumerate(mid.tracks):
        # channels without a program change get a fresh anonymous id
        channel_instruments = defaultdict(next_anon)
        # delta times are within-track, so restart the clock per track
        time_ticks = 0

        for msg in track:
            time_ticks += msg.time

            if msg.type == 'program_change':
                # programs are shifted to start at 1; channel 9 gets its own
                # range 128 higher (presumably the percussion channel)
                channel_instruments[msg.channel] = (
                    msg.program + 1 + 128 * int(msg.channel == 9))

            elif msg.type == 'set_tempo':
                tick_tempos.append((time_ticks, msg.tempo))

            elif msg.type in ('note_on', 'note_off'):
                if msg.channel not in channel_instruments and msg.channel == 9:
                    channel_instruments[msg.channel] = 129
                part = (track_idx, msg.channel, channel_instruments[msg.channel])
                # note_off messages are stored as velocity 0
                vel = msg.velocity if msg.type == 'note_on' else 0

                parts[part]['pitch'].append(msg.note)
                parts[part]['time'].append(time_ticks)
                parts[part]['vel'].append(vel)

    return parts, tick_tempos


def _tick_to_seconds_converter(tick_tempos, ticks_per_beat):
    """Build a function mapping absolute ticks to absolute seconds.

    The tempo in effect at a tick is the last change located strictly
    before it. Seconds at each tempo change are precomputed iteratively,
    so the number of tempo changes is not limited by the recursion depth.
    """
    # Tempo changes are collected track by track, so they are not
    # necessarily in time order. The sort is stable and keyed on the tick
    # only, so changes at the same tick keep their order and the last wins.
    tempo_map = sorted(tick_tempos, key=lambda change: change[0])
    change_ticks = [tick for tick, _ in tempo_map]
    change_seconds = []

    def to_seconds(ticks):
        if ticks == 0:
            return 0
        # index of the last tempo change strictly before `ticks`
        idx = bisect.bisect_left(change_ticks, ticks) - 1
        last_change_ticks, tempo = tempo_map[idx]
        return (
            mido.tick2second(ticks - last_change_ticks, ticks_per_beat, tempo)
            + change_seconds[idx])

    # each entry only depends on entries before it, which are already filled
    for tick in change_ticks:
        change_seconds.append(to_seconds(tick))

    return to_seconds


def _assign_and_repair(by_inst):
    """Merge non-colliding parts of each instrument and repair their events.

    Returns a list with one repaired event list per resulting part.
    """
    complete_parts = []
    for inst, candidates in by_inst.items():
        assigned_parts = []
        # sort by number of events; popping from the end takes largest first
        candidates.sort(key=lambda e: e['size'])

        while candidates:
            part = candidates.pop()
            # first already-assigned part this one can share a voice with
            target = next(
                (ap for ap in assigned_parts if not note_collision(part, ap)),
                None)

            if target is None:
                # collides with everything: becomes its own instrument
                part['inst'] = len(assigned_parts) * 1000 + inst
                assigned_parts.append(part)
            else:
                for k in ('pitch', 'time', 'vel'):
                    target['events'][k].extend(part['events'][k])
                # keep the summary fields in sync with the merged events,
                # otherwise later collision checks early-out on stale data
                target['pitches'] |= part['pitches']
                target['bounds'] = (
                    min(target['bounds'][0], part['bounds'][0]),
                    max(target['bounds'][1], part['bounds'][1]))
                target['size'] += part['size']

        complete_parts.extend(repair_events(p) for p in assigned_parts)
    return complete_parts


def process(fnames):
    """Convert one MIDI file into a saved dict of event tensors.

    Args:
        fnames: pair ``(f, g)``; ``f`` is the MIDI file to read and ``g``
            is a path whose suffix is replaced by ``.pkl`` for the output.

    Returns:
        None. Nothing is written when the file cannot be opened, is a
        type 2 file, or yields fewer than 64 events. Otherwise a dict
        with ``time`` (seconds), ``pitch``, ``velocity`` and ``program``
        tensors is stored with ``torch.save``.
    """
    f, g = fnames

    try:
        mid = mido.MidiFile(f)
    except Exception:
        # best effort: report unreadable files and move on
        tqdm.write(f'error opening {f} ')
        return

    if mid.type == 2:
        tqdm.write(f'type 2 file {f} ')
        return

    parts, tick_tempos = _collect_parts(mid)

    # abs ticks -> abs seconds; tempo changes affect all tracks
    to_seconds = _tick_to_seconds_converter(tick_tempos, mid.ticks_per_beat)
    for events in parts.values():
        events['time'] = [to_seconds(t) for t in events['time']]

    # group parts by instrument, with summaries used by the collision checks
    by_inst = defaultdict(list)
    for (track, channel, inst), events in parts.items():
        by_inst[inst].append({
            'size': len(events['pitch']),
            'part': (track, channel),
            'pitches': set(events['pitch']),
            'bounds': (min(events['time']), max(events['time'])),
            'events': events,
        })

    complete_parts = _assign_and_repair(by_inst)
    events = list(it.chain.from_iterable(complete_parts))

    if len(events) < 64:
        return

    events.sort()

    time, vel, pitch, prog = zip(*events)
    torch.save(dict(
        time=torch.DoubleTensor(time),
        pitch=torch.LongTensor(pitch),
        velocity=torch.LongTensor(vel),
        program=torch.LongTensor(prog)
    ), g.with_suffix('.pkl'))
252+
253+
def main(data_path, dest_path, n_jobs=4):
    """Process every ``.mid`` file under ``data_path`` in parallel.

    Args:
        data_path: directory searched recursively for ``*.mid`` files.
        dest_path: root of the output tree; the sub-directory layout of
            ``data_path`` is mirrored beneath it.
        n_jobs: number of worker processes in the pool.
    """
    data_dir = Path(data_path)
    dest_dir = Path(dest_path)
    files = list(data_dir.glob('**/*.mid'))
    files_out = [dest_dir / f.relative_to(data_dir) for f in files]

    # create all output directories up front, before the workers start
    for parent in {g.parent for g in files_out}:
        parent.mkdir(parents=True, exist_ok=True)

    with Pool(n_jobs) as pool:
        results = pool.imap_unordered(process, zip(files, files_out), 32)
        # the imap iterator has no length, so give tqdm the total explicitly
        for _ in tqdm(results, total=len(files)):
            pass
269+
270+
# Script entry point: hand `main` to python-fire, which is expected to map
# command-line arguments onto its parameters.
if __name__ == "__main__":
    fire.Fire(main)
273+
0 commit comments