@@ -17,7 +17,7 @@ def __init__(self, data_dir, batch_len, transpose=5, speed=0.1, glob='**/*.pkl',
1717 self .transpose = transpose
1818 self .speed = speed
1919 self .start_token = 128
20- self .n_anon = 8
20+ self .n_anon = 32
2121 self .prog_start_token = 0
2222 # self.clamp_time = clamp_time
2323 self .testing = False
@@ -26,47 +26,64 @@ def __init__(self, data_dir, batch_len, transpose=5, speed=0.1, glob='**/*.pkl',
2626 def __len__ (self ):
2727 return len (self .files )
2828
29- def _random_map_anonymous_instruments (self , program : torch .Tensor ) -> torch .Tensor :
29+ def _remap_anonymous_instruments (self , program : torch .Tensor ) -> torch .Tensor :
3030 """
31- Randomly map instruments to eight additional ‘anonymous’ melodic and drum identities
32- with a probability of 10% per instrument, without replacement.
33-
34- The input program should contain melodic instruments from MIDI note numbers 0-127 and
35- drum instruments from 128-255. Anonymous instruments are mapped to subsequent note numbers.
31+ Randomly map instruments to additional ‘anonymous’ melodic and drum identities
32+ with a probability of 10% per instrument, without replacement.
33+ Also map any parts > 256 to appropriate anonymous ids.
3634 """
37- unique_melodic = program .masked_select (program < 128 ).unique ()
38- unique_drum = program .masked_select (program >= 128 ).unique ()
39-
40- anon_melodic_start = 256
35+ orig_program = program % 1000
36+ is_melodic = (orig_program <= 128 ) | (orig_program > 256 )
37+ is_anon = (program > 256 )
38+ named_melodic = list (program .masked_select (is_melodic & ~ is_anon ).unique ())
39+ anon_melodic = list (program .masked_select (is_melodic & is_anon ).unique ())
40+ named_drum = list (program .masked_select (~ is_melodic & ~ is_anon ).unique ())
41+ anon_drum = list (program .masked_select (~ is_melodic & is_anon ).unique ())
42+
43+ anon_melodic_start = 257
4144 anon_drum_start = anon_melodic_start + self .n_anon
42- anon_melodic = torch .randperm (self .n_anon ) + anon_melodic_start # array of anon melodic programs
43- anon_drum = torch .randperm (self .n_anon ) + anon_drum_start # array of anon drum programs
45+ perm_anon_melodic = torch .randperm (self .n_anon ) + anon_melodic_start
46+ perm_anon_drum = torch .randperm (self .n_anon ) + anon_drum_start
4447
45- i = 0
46- for pr in unique_melodic :
48+ for pr in named_melodic :
4749 if torch .rand ((1 ,)) < 0.1 :
48- program [program == pr ] = anon_melodic [i ]
49- i += 1
50- if i >= len (anon_melodic ): # no more anon instruments to write to
51- break
52- i = 0
53- for pr in unique_drum :
50+ anon_melodic .append (pr )
51+ for pr in named_drum :
5452 if torch .rand ((1 ,)) < 0.1 :
55- program [program == pr ] = anon_drum [i ]
56- i += 1
57- if i >= len (anon_drum ): # no more anon instruments to write to
58- break
53+ anon_drum .append (pr )
54+
55+ new_program = program .clone ()
56+
57+ if len (anon_melodic )> self .n_anon :
58+ print (f'warning: { anon_melodic } > { self .n_anon } anon melodic instruments' )
59+ if len (anon_drum )> self .n_anon :
60+ print (f'warning: { anon_drum } > { self .n_anon } anon drum instruments' )
5961
60- return program
62+ i = 0
63+ for pr in anon_melodic :
64+ new_program [program == pr ] = perm_anon_melodic [i % self .n_anon ]
65+ i += 1
66+ i = 0
67+ for pr in anon_drum :
68+ new_program [program == pr ] = perm_anon_drum [i % self .n_anon ]
69+ i += 1
70+
71+ # print(new_program.unique())
72+
73+ return new_program
6174
6275 def __getitem__ (self , idx ):
6376 f = self .files [idx ]
6477 item = torch .load (f )
65- program = item ['program' ] # 1-d LongTensor of MIDI programs 0-255
66- # (128-255 are drums)
78+ program = item ['program' ] # 1-d LongTensor of MIDI programs
79+ # 0 is unused
80+ # (129-256 are drums; 1-128 are melodic)
81+ # 257+ are 'true anonymous' (no program change on track)
82+ # (drums with no PC are just mapped to 129)
83+ # N + 1000*K is the Kth additional part for instrument N
6784 pitch = item ['pitch' ] # 1-d LongTensor of MIDI pitches 0-127
68- time = item ['time' ]
69- velocity = item ['velocity' ]
85+ time = item ['time' ] # 1-d DoubleTensor of absolute times in seconds
86+ velocity = item ['velocity' ] # 1-d LongTensor of MIDI velocities 0-127
7087
7188 assert len (pitch ) == len (time )
7289
@@ -79,24 +96,24 @@ def __getitem__(self, idx):
7996 )
8097 pitch = pitch + transpose
8198
82- # randomly map instruments to 'anonymous melodic' and 'anonymous drum'
83- program = self ._random_map_anonymous_instruments (program )
84-
85- # shift from 0-index to general MIDI 1-index; reserve 0 for start token
86- program += 1
99+ # scramble anonymous and extra parts to 'anonymous melodic' and 'anonymous drum' parts
100+ program = self ._remap_anonymous_instruments (program )
87101
88- time_margin = 1e-3 # hardcoded since it should match prep script
102+ time_margin = 1e-3
89103
90104 # dequantize: add noise up to +/- margin
91- time = time + (torch .rand_like (time )* 2 - 1 )* time_margin
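105+ # as written: velocity == 0 events (note-offs) move later, all others (note-ons) earlier -- NOTE(review): this comment previously stated the opposite direction; confirm which is intended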
106+ time = (time +
107+ torch .rand_like (time ) * ((velocity == 0 ).double ()* 2 - 1 ) * time_margin
108+ )
92109 # random augment tempo
93110 time = time * (1 + random .random ()* self .speed * 2 - self .speed )
94111
95112 # dequantize velocity
96113 velocity = velocity .float ()
97114 velocity = (
98115 velocity +
99- (torch .rand_like (time )- 0.5 ) * ((velocity > 0 ) & (velocity < 127 )).float ()
116+ (torch .rand_like (time , dtype = torch . float )- 0.5 ) * ((velocity > 0 ) & (velocity < 127 )).float ()
100117 ).clamp (0. , 127. )
101118 # random velocity curve
102119 # take care not to map any positive values closer to 0 than 1
@@ -110,7 +127,7 @@ def __getitem__(self, idx):
110127 # sort (using argsort on time and indexing the rest)
111128 # compute delta time
112129 time , idx = time .sort ()
113- time = torch .cat ((time .new_zeros ((1 ,)), time )).diff (1 )
130+ time = torch .cat ((time .new_zeros ((1 ,)), time )).diff (1 ). float ()
114131 program = program [idx ]
115132 pitch = pitch [idx ]
116133 velocity = velocity [idx ]
0 commit comments