# NOTE(review): source arrived as a pasted unified diff with line numbers
# fused into the code; this is the reconstructed post-commit import block
# (the commit adds the numpy import).
import math

import numpy as np

import torch
from torch import nn
import torch.nn.functional as F
def __init__(self, n, w0=1e-3, interval=1.08):
    """Sinusoidal embedding with a learned output projection.

    Args:
        n: number of sinusoid channels (embedding size).
        w0: base scale of the angular frequencies; smaller w0 means
            higher frequencies (fs is proportional to 1/w0).
        interval: geometric ratio between successive frequencies
            (interval**(-k) decays with channel index k).
    """
    super().__init__()
    self.n = n
    # Geometric ladder of angular frequencies. Stored via register_buffer
    # so it is saved with the state dict and moved across devices, but is
    # not a trainable parameter.
    self.register_buffer('fs', interval ** (-torch.arange(n)) / w0 * 2 * math.pi)
    # Added by this commit: mixes the raw sinusoid channels; applied to
    # sin(x * fs) in forward().
    self.proj = nn.Linear(n, n)
2225
def forward(self, x):
    """Embed scalar inputs as projected sinusoidal features.

    Args:
        x: Tensor[...] of scalar values.
    Returns:
        Tensor[..., n]: self.proj applied to sin(x * fs).
    """
    # Broadcast each scalar against all n frequencies: (...,) -> (..., n).
    phases = x[..., None] * self.fs
    # Post-commit behavior: project the sinusoid channels instead of
    # returning them raw.
    return self.proj(phases.sin())
2629
2730class MixEmbedding (nn .Module ):
2831 def __init__ (self , n , domain = (0 ,1 )):
def forward(self, x):
    """Linearly interpolate between the learned `lo` and `hi` embeddings.

    Args:
        x: Tensor[...] of values, expected to lie in `self.domain`
            (values outside the domain extrapolate linearly).
    Returns:
        Tensor[..., n]
    """
    # Normalize x from `domain` to [0, 1]. Dividing by the domain width is
    # the fix this commit makes (the old code multiplied).
    x = (x - self.domain[0]) / (self.domain[1] - self.domain[0])
    x = x[..., None]
    # Convex combination: returns lo at x == domain[0], hi at x == domain[1].
    return self.hi * x + self.lo * (1 - x)
4851
49- class SelfGated (nn .Module ):
50- def __init__ (self ):
52+ # class SelfGated(nn.Module):
53+ # def __init__(self):
54+ # super().__init__()
55+
56+ # def forward(self, x):
57+ # a, b = x.chunk(2, -1)
58+ # return a * b.sigmoid()
59+
60+ # class SelfGatedMLP(nn.Module):
61+ # def __init__(self, input, hidden, output, layers, dropout=0):
62+ # super().__init__()
63+ # h = input
64+ # def get_dropout():
65+ # if dropout > 0:
66+ # return (nn.Dropout(dropout),)
67+ # else:
68+ # return tuple()
69+ # self.net = []
70+ # for _ in range(layers):
71+ # self.net.append(nn.Sequential(
72+ # *get_dropout(), nn.Linear(h, hidden*2), SelfGated()))
73+ # h = hidden
74+ # self.net.append(nn.Linear(hidden, output))
75+ # self.net = nn.Sequential(*self.net)
76+
77+ # with torch.no_grad():
78+ # self.net[-1].weight.mul_(1e-2)
79+
80+ # def forward(self, x):
81+ # return self.net(x)
82+
83+ class ModalityTransformer (nn .Module ):
84+ """Model joint distribution of modalities autoregressively with random permutations"""
85+ def __init__ (self , input_size , hidden_size , heads = 4 , layers = 1 ):
5186 super ().__init__ ()
87+ self .net = nn .TransformerEncoder (
88+ nn .TransformerEncoderLayer (
89+ input_size , heads , hidden_size , norm_first = False
90+ ), layers )
5291
53- def forward (self , x ):
54- a , b = x . chunk ( 2 , - 1 )
55- return a * b . sigmoid ()
56-
57- class SelfGatedMLP ( nn . Module ):
58- def __init__ ( self , input , hidden , output , layers , dropout = 0 ):
59- super (). __init__ ()
60- h = input
61- def get_dropout ():
62- if dropout > 0 :
63- return ( nn . Dropout ( dropout ),)
64- else :
65- return tuple ()
66- self . net = []
67- for _ in range ( layers ):
68- self . net . append ( nn . Sequential (
69- * get_dropout (), nn . Linear ( h , hidden * 2 ), SelfGated ()))
70- h = hidden
71- self . net . append ( nn . Linear ( hidden , output ) )
72- self . net = nn . Sequential ( * self . net )
73-
74- with torch . no_grad ():
75- self . net [ - 1 ]. weight . mul_ ( 1e-2 )
92+ def forward (self , h , modes ):
93+ """
94+ Args:
95+ modes: each a Tensor[batch x time x input_size]
96+ """
97+ x = [ h ] + modes
98+ batch_size = h . shape [ 0 ] * h . shape [ 1 ]
99+ # fold time into batch, stack modes
100+ x = torch . stack ([
101+ item . reshape ( batch_size , - 1 )
102+ for item in x
103+ ], 0 )
104+ # now "time"(mode) x "batch"(+time) x channel
105+
106+ # generate a mask
107+ # upper triangular (i.e. diagonal and above is True, meaning masked)
108+ # except h position should attend to self
109+ n = len ( modes ) + 1
110+ mask = x . new_ones (( n , n ), dtype = bool ). triu ( )
111+ mask [ 0 , 0 ] = False
112+
113+ x = self . net ( x , mask )
114+ return list ( x . reshape ( n , * h . shape ). unbind ( 0 ))[ 1 :]
76115
77- def forward (self , x ):
78- return self .net (x )
79116
80117class NotePredictor (nn .Module ):
81118 # note: use named arguments only for benefit of training script
82119 def __init__ (self ,
83- pitch_emb_size = 128 , time_emb_size = 128 , vel_emb_size = 128 ,
84- hidden_size = 512 , num_layers = 1 , kind = 'gru' , dropout = 0 ,
120+ emb_size = 256 ,
121+ rnn_hidden = 2048 , rnn_layers = 1 , kind = 'gru' ,
122+ ar_hidden = 2048 , ar_layers = 1 , ar_heads = 4 ,
123+ dropout = 0.1 ,
85124 num_pitches = 128 ,
86- time_bounds = (0 ,10 ), time_components = 16 , time_res = 1e-2 ,
87- vel_components = 8
125+ time_bounds = (0 ,10 ), time_components = 32 , time_res = 1e-2 ,
126+ vel_components = 16
88127 ):
89128 """
90129 """
@@ -102,32 +141,34 @@ def __init__(self,
102141 vel_components , 1.0 , lo = 0 , hi = 127 , init = 'velocity' )
103142
104143 # embeddings for inputs
105- self .pitch_emb = nn .Embedding (self .pitch_domain , pitch_emb_size )
106- self .time_emb = SineEmbedding (time_emb_size )
107- self .vel_emb = MixEmbedding (vel_emb_size , (0 , 127 ))
108-
109- self .pitch_missing = nn .Parameter (torch .randn (pitch_emb_size ))
110- self .time_missing = nn .Parameter (torch .randn (time_emb_size ))
111- self .vel_missing = nn .Parameter (torch .randn (vel_emb_size ))
144+ self .pitch_emb = nn .Embedding (self .pitch_domain , emb_size )
145+ self .time_emb = SineEmbedding (emb_size )
146+ self .vel_emb = MixEmbedding (emb_size , (0 , 127 ))
112147
113148 # RNN backbone
114149 self .rnn = GenericRNN (kind ,
115- pitch_emb_size + time_emb_size + vel_emb_size , hidden_size ,
116- num_layers = num_layers , batch_first = True , dropout = dropout )
150+ 3 * emb_size , rnn_hidden ,
151+ num_layers = rnn_layers , batch_first = True , dropout = dropout )
117152
118153 # learnable initial RNN state
119154 self .initial_state = nn .ParameterList ([
120155 # layer x batch x hidden
121- nn .Parameter (torch .randn (num_layers ,1 ,hidden_size ) * hidden_size ** - 0.5 )
156+ nn .Parameter (torch .randn (rnn_layers ,1 ,rnn_hidden ) * rnn_hidden ** - 0.5 )
122157 for _ in range (2 if kind == 'lstm' else 1 )
123158 ])
124159
125160 # projection from RNN state to distribution parameters
126- self .param_proj = SelfGatedMLP (
127- pitch_emb_size + time_emb_size + vel_emb_size + hidden_size ,
128- hidden_size // 2 ,
129- self .pitch_domain + self .time_dist .n_params + self .vel_dist .n_params ,
130- layers = 2 , dropout = dropout )
161+ self .h_proj = nn .Linear (rnn_hidden , emb_size )
162+ self .projections = nn .ModuleList ([
163+ nn .Linear (emb_size , self .pitch_domain ),
164+ nn .Linear (emb_size , self .time_dist .n_params , bias = False ),
165+ nn .Linear (emb_size , self .vel_dist .n_params , bias = False )
166+ ])
167+ for p in self .projections :
168+ with torch .no_grad ():
169+ p .weight .mul_ (1e-2 )
170+
171+ self .xformer = ModalityTransformer (emb_size , ar_hidden , ar_heads , ar_layers )
131172
132173 # persistent RNN state for inference
133174 for n ,t in zip (self .cell_state_names (), self .initial_state ):
@@ -151,82 +192,43 @@ def forward(self, pitches, times, velocities, validation=False):
151192 """
152193 batch_size , batch_len = pitches .shape
153194
154- pitch_emb = self .pitch_emb (pitches ) # batch, time, pitch_emb_size
155- time_emb = self .time_emb (times ) # batch, time, time_emb_size
156- vel_emb = self .vel_emb (velocities ) # batch, time, vel_emb_size
195+ pitch_emb = self .pitch_emb (pitches ) # batch, time, emb_size
196+ time_emb = self .time_emb (times ) # batch, time, emb_size
197+ vel_emb = self .vel_emb (velocities ) # batch, time, emb_size
157198
158- x = torch .cat ((pitch_emb , time_emb , vel_emb ), - 1 )
199+ embs = (pitch_emb , time_emb , vel_emb )
200+
201+ x = torch .cat (embs , - 1 )[:,:- 1 ] # skip last time position
159202 ## broadcast intial state to batch size
160203 initial_state = tuple (
161204 t .expand (self .rnn .num_layers , x .shape [0 ], - 1 ).contiguous () # 1 x batch x hidden
162205 for t in self .initial_state )
163206 h , _ = self .rnn (x , initial_state ) #batch, time, hidden_size
164207
165- # IDEA: fit all factorizations at once.
166- # add 'missing' value for time / pitch / velocity as model parameters
167- # expand the batch to 6x wide with ~/~/~, T/~/~, ~/P/~, ~/~/V, T/P/~, ~/P/V, T/~/V inputs
168- # the factorizations are:
169- # _~~ -> T_~ -> TP_
170- # _~~ -> T~_ -> T_V
171- # ~_~ -> _P~ -> TP_
172- # ~_~ -> ~P_ -> _PV
173- # ~~_ -> _~V -> T_V
174- # ~~_ -> ~_V -> _PV
175- # i.e. the fully masked positions are counted 2x,
176- # the single-masked positions are counted 2x,
177- # and each double-masked position is counted once
178-
179- masks = [
180- [2 , 0 , 0 , 2 , 0 , 1 , 1 ], #pitch
181- [2 , 2 , 0 , 0 , 1 , 0 , 1 ], #time
182- [2 , 0 , 2 , 0 , 1 , 1 , 0 ] #velocity
183- ]
184-
185- def mask_cat (missing , present , mask ):
186- missing = missing [None ,None ].expand (batch_size , batch_len - 1 , - 1 )
187- return torch .cat ([
188- present if m == 0 else missing for m in mask
189- ], 0 )
190-
191- pitch_features = mask_cat (self .pitch_missing , pitch_emb [:,1 :], masks [0 ])
192- time_features = mask_cat (self .time_missing , time_emb [:,1 :], masks [1 ])
193- vel_features = mask_cat (self .vel_missing , vel_emb [:,1 :], masks [2 ])
194-
195- features = torch .cat ((
196- pitch_features , time_features , vel_features , h [:,:- 1 ].repeat (7 ,1 ,1 )
197- ), - 1 ) # cat along feature dim
198-
199- dist_params = self .param_proj (features ) # combine features with h
200-
201- # split again into time/pitch/vel params
202- dist_params = dist_params .split ([
203- self .pitch_domain , self .time_dist .n_params , self .vel_dist .n_params
204- ], - 1 )
205-
206- # chunk into 7 and discard unmasked positions;
207- # stack the masked positions along new first dim
208- pitch_params , time_params , vel_params = (
209- torch .stack ([
210- ch
211- for m ,ch in zip (mask , dp .chunk (7 , 0 )) if m > 0
212- ], 0 )
213- for mask ,dp in zip (masks , dist_params )
214- )
215-
216- #TODO: weighting
217- # weights = np.log([[m for m in mask if m>0] for mask in masks]) # 3 x 4
208+ # include initial hidden state for predicting first note
209+ h = torch .cat ((
210+ self .initial_state [0 ][- 1 ][None ].expand (batch_size , 1 , - 1 ),
211+ h ), - 2 )
212+
213+ # fit all note factorizations at once.
214+ perm = torch .randperm (3 )
215+ embs = [embs [i ] for i in perm ]
216+ mode_hs = self .xformer (self .h_proj (h ), embs )
217+ mode_hs = [mode_hs [perm [i ]] for i in perm ]
218+
219+ pitch_params , time_params , vel_params = [
220+ proj (h ) for proj ,h in zip (self .projections , mode_hs )]
218221
219222 # get likelihoods
220223 pitch_logits = F .log_softmax (pitch_params , - 1 )
221- # TODO: is gather working right with extra dim?
222- pitch_targets = pitches [None ,:,1 :,None ].expand (4 , - 1 , - 1 , - 1 ) #1, batch, time-1, 1
224+ pitch_targets = pitches [:,1 :,None ] #batch, time, 1
223225 pitch_log_probs = pitch_logits .gather (- 1 , pitch_targets )[...,0 ]
224226
225- time_targets = times [:, 1 :] # batch, time-1
227+ time_targets = times # batch, time
226228 time_result = self .time_dist (time_params , time_targets )
227229 time_log_probs = time_result .pop ('log_prob' )
228230
229- vel_targets = velocities [:, 1 :] # batch, time-1
231+ vel_targets = velocities # batch, time
230232 vel_result = self .vel_dist (vel_params , vel_targets )
231233 vel_log_probs = vel_result .pop ('log_prob' )
232234
@@ -269,41 +271,58 @@ def predict(self, pitch, time, vel, force=(None, None, None)):
269271 with torch .no_grad ():
270272 pitch = torch .LongTensor ([[pitch ]]) # 1x1 (batch, time)
271273 time = torch .FloatTensor ([[time ]]) # 1x1 (batch, time)
272- x = torch .cat ((
273- self .pitch_emb (pitch ), # 1, 1, pitch_emb_size
274- self .time_emb (time )# 1, 1, time_emb_size
275- ), - 1 )
274+ vel = torch .FloatTensor ([[vel ]]) # 1x1 (batch, time)
275+
276+ embs = [
277+ self .pitch_emb (pitch ), # 1, 1, emb_size
278+ self .time_emb (time ),# 1, 1, emb_size
279+ self .vel_emb (vel )# 1, 1, emb_size
280+ ]
281+ x = torch .cat (embs , - 1 )
276282
277283 h , new_state = self .rnn (x , self .cell_state )
278284 for t ,new_t in zip (self .cell_state , new_state ):
279285 t [:] = new_t
280286
281- pitch_params = self .pitch_proj (h )
287+ h = self .h_proj (h )
288+
289+ # TODO: permutations
290+ # TODO: optimize by removing unused positions
291+ # TODO: refactor with common distribution API
292+ pitch_h , = self .xformer (h , embs [:1 ])
293+
294+ pitch_params = self .projections [0 ](pitch_h )
282295 pred_pitch = D .Categorical (logits = pitch_params ).sample ()
283-
284- time_params = self .time_proj (h * self .cond_proj (self .pitch_emb (pred_pitch )).sigmoid ())
285- # time_params = self.time_proj(torch.cat((
286- # h, self.pitch_emb(pred_pitch)
287- # ), -1)) # 1, 1, time_params
288- # TODO: importance sampling?
289- pred_time = self .time_dist .sample (time_params ).squeeze (0 )
290-
291- ### TODO: generalize, move into sample
292- ### DEBUG
293- # pi only, fewer zeros:
294- log_pi , loc , s = (
295- t .squeeze () for t in self .time_dist .get_params (time_params ))
296- bias = 2 #float('inf')
297- log_pi = torch .where (loc <= self .time_dist .res , log_pi - bias , log_pi )
298- idx = D .Categorical (logits = log_pi ).sample ()
299- pred_time = loc [idx ].clamp (0 ,10 )
296+
297+ embs [0 ] = self .pitch_emb (pred_pitch )
298+ _ , time_h = self .xformer (h , embs [:2 ])
299+
300+ time_params = self .projections [1 ](time_h )
301+ pred_time = self .time_dist .sample (time_params )
302+
303+ embs [1 ] = self .time_emb (pred_time )
304+ _ , _ , vel_h = self .xformer (h , embs )
305+
306+ vel_params = self .projections [2 ](vel_h )
307+ pred_vel = self .vel_dist .sample (vel_params )
308+
309+ # ### TODO: generalize, move into sample
310+ # ### DEBUG
311+ # # pi only, fewer zeros:
312+ # log_pi, loc, s = (
313+ # t.squeeze() for t in self.time_dist.get_params(time_params))
314+ # bias = 2#float('inf')
315+ # log_pi = torch.where(loc <= self.time_dist.res, log_pi-bias, log_pi)
316+ # idx = D.Categorical(logits=log_pi).sample()
317+ # pred_time = loc[idx].clamp(0,10)
300318
301319 return {
302320 'pitch' : pred_pitch .item (),
303321 'time' : pred_time .item (),
304- 'velocity' : None ,
322+ 'velocity' : pred_vel . item () ,
305323 'pitch_params' : pitch_params ,
306- 'time_params' : time_params
324+ 'time_params' : time_params ,
325+ 'vel_params' : vel_params
307326 }
308327
309328 # TODO: start velocity
0 commit comments