Skip to content
This repository was archived by the owner on Nov 23, 2023. It is now read-only.

Commit 404f1cf

Browse files
new dataloading, move to 32 anon mel/drum, validate with fixed length, fixes
1 parent 865950b commit 404f1cf

3 files changed

Lines changed: 66 additions & 45 deletions

File tree

notochord/notochord/data.py

Lines changed: 56 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def __init__(self, data_dir, batch_len, transpose=5, speed=0.1, glob='**/*.pkl',
1717
self.transpose = transpose
1818
self.speed = speed
1919
self.start_token = 128
20-
self.n_anon = 8
20+
self.n_anon = 32
2121
self.prog_start_token = 0
2222
# self.clamp_time = clamp_time
2323
self.testing = False
@@ -26,47 +26,64 @@ def __init__(self, data_dir, batch_len, transpose=5, speed=0.1, glob='**/*.pkl',
2626
def __len__(self):
2727
return len(self.files)
2828

29-
def _random_map_anonymous_instruments(self, program: torch.Tensor) -> torch.Tensor:
29+
def _remap_anonymous_instruments(self, program: torch.Tensor) -> torch.Tensor:
3030
"""
31-
Randomly map instruments to eight additional ‘anonymous’ melodic and drum identities
32-
with a probability of 10% per instrument, without replacement.
33-
34-
The input program should contain melodic instruments from MIDI note numbers 0-127 and
35-
drum instruments from 128-255. Anonymous instruments are mapped to subsequent note numbers.
31+
Randomly map instruments to additional ‘anonymous’ melodic and drum identities
32+
with a probability of 10% per instrument, without replacement.
33+
Also map any parts > 256 to appropriate anonymous ids.
3634
"""
37-
unique_melodic = program.masked_select(program<128).unique()
38-
unique_drum = program.masked_select(program>=128).unique()
39-
40-
anon_melodic_start = 256
35+
orig_program = program%1000
36+
is_melodic = (orig_program<=128) | (orig_program>256)
37+
is_anon = (program > 256)
38+
named_melodic = list(program.masked_select(is_melodic & ~is_anon).unique())
39+
anon_melodic = list(program.masked_select(is_melodic & is_anon).unique())
40+
named_drum = list(program.masked_select(~is_melodic & ~is_anon).unique())
41+
anon_drum = list(program.masked_select(~is_melodic & is_anon).unique())
42+
43+
anon_melodic_start = 257
4144
anon_drum_start = anon_melodic_start + self.n_anon
42-
anon_melodic = torch.randperm(self.n_anon) + anon_melodic_start # array of anon melodic programs
43-
anon_drum = torch.randperm(self.n_anon) + anon_drum_start # array of anon drum programs
45+
perm_anon_melodic = torch.randperm(self.n_anon) + anon_melodic_start
46+
perm_anon_drum = torch.randperm(self.n_anon) + anon_drum_start
4447

45-
i = 0
46-
for pr in unique_melodic:
48+
for pr in named_melodic:
4749
if torch.rand((1,)) < 0.1:
48-
program[program==pr] = anon_melodic[i]
49-
i += 1
50-
if i >= len(anon_melodic): # no more anon instruments to write to
51-
break
52-
i = 0
53-
for pr in unique_drum:
50+
anon_melodic.append(pr)
51+
for pr in named_drum:
5452
if torch.rand((1,)) < 0.1:
55-
program[program==pr] = anon_drum[i]
56-
i += 1
57-
if i >= len(anon_drum): # no more anon instruments to write to
58-
break
53+
anon_drum.append(pr)
54+
55+
new_program = program.clone()
56+
57+
if len(anon_melodic)>self.n_anon:
58+
print(f'warning: {anon_melodic} > {self.n_anon} anon melodic instruments')
59+
if len(anon_drum)>self.n_anon:
60+
print(f'warning: {anon_drum} > {self.n_anon} anon drum instruments')
5961

60-
return program
62+
i = 0
63+
for pr in anon_melodic:
64+
new_program[program==pr] = perm_anon_melodic[i%self.n_anon]
65+
i += 1
66+
i = 0
67+
for pr in anon_drum:
68+
new_program[program==pr] = perm_anon_drum[i%self.n_anon]
69+
i += 1
70+
71+
# print(new_program.unique())
72+
73+
return new_program
6174

6275
def __getitem__(self, idx):
6376
f = self.files[idx]
6477
item = torch.load(f)
65-
program = item['program'] # 1-d LongTensor of MIDI programs 0-255
66-
# (128-255 are drums)
78+
program = item['program'] # 1-d LongTensor of MIDI programs
79+
# 0 is unused
80+
# (128-256 are drums)
81+
# 257+ are 'true anonymous' (no program change on track)
82+
# (drums with no PC are just mapped to 129)
83+
# N + 1000*K is the Kth additional part for instrument N
6784
pitch = item['pitch'] # 1-d LongTensor of MIDI pitches 0-127
68-
time = item['time']
69-
velocity = item['velocity']
85+
time = item['time'] # 1-d DoubleTensor of absolute times in seconds
86+
velocity = item['velocity'] # 1-d LongTensor of MIDI velocities 0-127
7087

7188
assert len(pitch) == len(time)
7289

@@ -79,24 +96,24 @@ def __getitem__(self, idx):
7996
)
8097
pitch = pitch + transpose
8198

82-
# randomly map instruments to 'anonymous melodic' and 'anonymous drum'
83-
program = self._random_map_anonymous_instruments(program)
84-
85-
# shift from 0-index to general MIDI 1-index; reserve 0 for start token
86-
program += 1
99+
# scramble anonymous and extra parts to 'anonymous melodic' and 'anonymous drum' parts
100+
program = self._remap_anonymous_instruments(program)
87101

88-
time_margin = 1e-3 # hardcoded since it should match prep script
102+
time_margin = 1e-3
89103

90104
# dequantize: add noise up to +/- margin
91-
time = time + (torch.rand_like(time)*2-1)*time_margin
105+
# move note-ons later, note-offs earlier
106+
time = (time +
107+
torch.rand_like(time) * ((velocity==0).double()*2-1) * time_margin
108+
)
92109
# random augment tempo
93110
time = time * (1 + random.random()*self.speed*2 - self.speed)
94111

95112
# dequantize velocity
96113
velocity = velocity.float()
97114
velocity = (
98115
velocity +
99-
(torch.rand_like(time)-0.5) * ((velocity>0) & (velocity<127)).float()
116+
(torch.rand_like(time, dtype=torch.float)-0.5) * ((velocity>0) & (velocity<127)).float()
100117
).clamp(0., 127.)
101118
# random velocity curve
102119
# take care not to map any positive values closer to 0 than 1
@@ -110,7 +127,7 @@ def __getitem__(self, idx):
110127
# sort (using argsort on time and indexing the rest)
111128
# compute delta time
112129
time, idx = time.sort()
113-
time = torch.cat((time.new_zeros((1,)), time)).diff(1)
130+
time = torch.cat((time.new_zeros((1,)), time)).diff(1).float()
114131
program = program[idx]
115132
pitch = pitch[idx]
116133
velocity = velocity[idx]

notochord/notochord/model.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def __init__(self,
159159
mlp_layers=0,
160160
dropout=0.1, norm=None,
161161
num_pitches=128,
162-
num_instruments=272,
162+
num_instruments=320,
163163
time_sines=128, vel_sines=128,
164164
time_bounds=(0,10), time_components=32, time_res=1e-2,
165165
vel_components=16
@@ -188,10 +188,12 @@ def __init__(self,
188188
# embeddings for inputs
189189
self.instrument_emb = nn.Embedding(self.instrument_domain, emb_size)
190190
self.pitch_emb = nn.Embedding(self.pitch_domain, emb_size)
191-
self.time_emb = torch.jit.script(SineEmbedding(
191+
self.time_emb = torch.jit.script(
192+
SineEmbedding(
192193
time_sines, emb_size, 1e-3, 30, scale='log'))
193194
# self.vel_emb = MixEmbedding(emb_size, (0, 127))
194-
self.vel_emb = torch.jit.script(SineEmbedding(
195+
self.vel_emb = torch.jit.script(
196+
SineEmbedding(
195197
vel_sines, emb_size, 2, 512, scale='lin'))
196198

197199
# RNN backbone
@@ -436,7 +438,7 @@ def forward(self, instruments, pitches, times, velocities, ends,
436438
def is_drum(self, inst):
437439
# TODO: add a constructor argument to specify which are drums
438440
# hardcoded for now
439-
return inst > 128 and inst < 257 or inst > 264
441+
return inst > 128 and inst < 257 or inst > 288
440442

441443

442444
def feed(self, inst, pitch, time, vel):

notochord/notochord/train.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def __init__(self,
5454
"""
5555
kw['model'] = model = get_class_defaults(model_cls) | model
5656
model['num_pitches'] = 128
57-
model['num_instruments'] = 272
57+
model['num_instruments'] = 320
5858
# model['time_bounds'] = clamp_time
5959

6060
# assign all arguments to self by default
@@ -237,7 +237,8 @@ def train(self):
237237

238238
##### validation loop
239239
def run_validation():
240-
logs = self._validate(valid_loader)['logs']
240+
self.dataset.batch_len = self.dataset.max_test_len
241+
logs = self._validate(valid_loader, testing=False)['logs']
241242
self.log('valid', logs)
242243

243244
epoch_size = self.epoch_size or len(train_loader)
@@ -251,6 +252,7 @@ def run_validation():
251252
##### training loop
252253
self.model.train()
253254
self.dataset.testing = False
255+
self.dataset.batch_len = self.batch_len
254256
for batch in tqdm(it.islice(train_loader, epoch_size),
255257
desc=f'training epoch {self.epoch}', total=epoch_size):
256258
mask = batch['mask'].to(self.device, non_blocking=True)

0 commit comments

Comments (0)