Skip to content
This repository was archived by the owner on Nov 23, 2023. It is now read-only.

Commit 51c3c3a

Browse files
fix transposition for anonymous instruments; remove old code
1 parent 756717f commit 51c3c3a

2 files changed

Lines changed: 10 additions & 63 deletions

File tree

notochord/notochord/data.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,22 @@ def __init__(self, data_dir, batch_len, transpose=5, speed=0.1, glob='**/*.pkl',
2525

2626
def __len__(self):
2727
return len(self.files)
28+
29+
def is_melodic(self, program):
30+
orig_program = program%1000
31+
return (orig_program<=128) | (orig_program>256)
32+
33+
def is_anon(self, program):
34+
return program > 256
2835

2936
def _remap_anonymous_instruments(self, program: torch.Tensor) -> torch.Tensor:
3037
"""
3138
Randomly map instruments to additional ‘anonymous’ melodic and drum identities
3239
with a probability of 10% per instrument, without replacement.
3340
Also map any parts > 256 to appropriate anonymous ids.
3441
"""
35-
orig_program = program%1000
36-
is_melodic = (orig_program<=128) | (orig_program>256)
37-
is_anon = (program > 256)
42+
is_melodic = self.is_melodic(program)
43+
is_anon = self.is_anon(program)
3844
named_melodic = list(program.masked_select(is_melodic & ~is_anon).unique())
3945
anon_melodic = list(program.masked_select(is_melodic & is_anon).unique())
4046
named_drum = list(program.masked_select(~is_melodic & ~is_anon).unique())
@@ -92,7 +98,7 @@ def __getitem__(self, idx):
9298
transpose_up = min(self.transpose, 127-pitch.max())
9399
transpose = (
94100
random.randint(-transpose_down, transpose_up)
95-
* (program<128) # don't transpose drums
101+
* self.is_melodic(program).long() # don't transpose drums
96102
)
97103
pitch = pitch + transpose
98104

notochord/notochord/model.py

Lines changed: 0 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -91,65 +91,6 @@ def get_norm():
9191
def forward(self, x):
9292
return self.net(x)
9393

94-
# class ModalityTransformer(nn.Module):
95-
# """
96-
# Model joint distribution of note modalities (e.g. pitch, time, velocity).
97-
98-
# This is an autoregressive Transformer model for the *internal* structure of notes.
99-
# It is *not* autoregressive in time, but in modality.
100-
# At training time, it executes in parallel over all timesteps and modalities, with
101-
# time dependencies provided via the RNN backbone.
102-
103-
# At sampling time it is called serially, one modality at a time,
104-
# repeatedly at each time step.
105-
106-
# Inspired by XLNet: http://arxiv.org/abs/1906.08237
107-
# """
108-
# def __init__(self, input_size, hidden_size, heads=4, layers=1):
109-
# super().__init__()
110-
# self.net = nn.TransformerDecoder(
111-
# nn.TransformerDecoderLayer(
112-
# input_size, heads, hidden_size, norm_first=False
113-
# ), layers)
114-
115-
# def forward(self, ctx, h_ctx, h_tgt):
116-
# """
117-
# Args:
118-
# ctx: list of Tensor[batch x time x input_size], length note_dim-1
119-
# these are the embedded ground truth values
120-
# h_ctx: Tensor[batch x time x input_size]
121-
# projection of RNN state (need something to attend to when ctx is empty)
122-
# h_tgt: list of Tensor[batch x time x input_size], length note_dim
123-
# these are projections of the RNN state for each target,
124-
# which the Transformer will map to distribution parameters.
125-
# """
126-
# # explicitly broadcast
127-
# h_ctx, *ctx = torch.broadcast_tensors(h_ctx, *ctx)
128-
# h_ctx, *h_tgt = torch.broadcast_tensors(h_ctx, *h_tgt)
129-
130-
# # h_tgt is 'target' w.r.t TransformerDecoder
131-
# # h_ctx and context are 'memory'
132-
# batch_size = h_ctx.shape[0]*h_ctx.shape[1]
133-
# # fold time into batch, stack modes
134-
# tgt = torch.stack([
135-
# item.reshape(batch_size,-1)
136-
# for item in h_tgt
137-
# ],0)
138-
# mem = torch.stack([
139-
# item.reshape(batch_size,-1)
140-
# for item in [h_ctx, *ctx]
141-
# ],0)
142-
# # now "time"(mode) x "batch"(+time) x channel
143-
144-
# # generate a mask
145-
# # this is both the target and memory mask
146-
# # masking is such that each target can only depend on "previous" context
147-
# n = len(h_tgt)
148-
# mask = ~tgt.new_ones((n,n), dtype=bool).tril()
149-
150-
# x = self.net(tgt, mem, mask, mask)
151-
# return list(x.reshape(n, *h_ctx.shape).unbind(0))
152-
15394

15495
class Notochord(nn.Module):
15596
# note: use named arguments only for benefit of training script

0 commit comments

Comments (0)