@@ -84,34 +84,40 @@ class ModalityTransformer(nn.Module):
8484 """Model joint distribution of modalities autoregressively with random permutations"""
8585 def __init__ (self , input_size , hidden_size , heads = 4 , layers = 1 ):
8686 super ().__init__ ()
87- self .net = nn .TransformerEncoder (
88- nn .TransformerEncoderLayer (
87+ self .net = nn .TransformerDecoder (
88+ nn .TransformerDecoderLayer (
8989 input_size , heads , hidden_size , norm_first = False
9090 ), layers )
9191
92- def forward (self , h , modes ):
92+ def forward (self , h , targets ):
9393 """
9494 Args:
95- modes: each a Tensor[batch x time x input_size]
95+ h: list of Tensor[batch x time x input_size], length note_dim+1
96+ targets: list of Tensor[batch x time x input_size], length note_dim-1
9697 """
97- x = [h ]+ modes
98- batch_size = h .shape [0 ]* h .shape [1 ]
98+ h = list (h )
99+ targets = list (targets )
100+ # h is 'target' w.r.t TransformerDecoder
101+ # targets is 'memory'
102+ batch_size = h [0 ].shape [0 ]* h [0 ].shape [1 ]
99103 # fold time into batch, stack modes
100- x = torch .stack ([
104+ tgt = torch .stack ([
101105 item .reshape (batch_size ,- 1 )
102- for item in x
106+ for item in h [1 :]
107+ ],0 )
108+ mem = torch .stack ([
109+ item .reshape (batch_size ,- 1 )
110+ for item in h [:1 ]+ targets
103111 ],0 )
104112 # now "time"(mode) x "batch"(+time) x channel
105113
106114 # generate a mask
107- # upper triangular (i.e. diagonal and above is True, meaning masked)
108- # except h position should attend to self
109- n = len (modes )+ 1
110- mask = x .new_ones ((n ,n ), dtype = bool ).triu ()
111- mask [0 ,0 ] = False
115+ # this is both the target and memory mask
116+ n = len (h )- 1
117+ mask = ~ tgt .new_ones ((n ,n ), dtype = bool ).tril ()
112118
113- x = self .net (x , mask )
114- return list (x .reshape (n , * h .shape ).unbind (0 ))[ 1 :]
119+ x = self .net (tgt , mem , mask , mask )
120+ return list (x .reshape (n , * h [ 0 ] .shape ).unbind (0 ))
115121
116122
117123class NotePredictor (nn .Module ):
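The new mask doubles as the decoder's self-attention mask and its cross-attention mask over the memory (the RNN context vector followed by the teacher-forcing targets, also length n). A minimal standalone sketch of the mask logic, assuming note_dim = 3 so n = 3 modes:

import torch

n = 3  # number of modes decoded autoregressively (pitch, time, velocity)
mask = ~torch.ones((n, n), dtype=torch.bool).tril()
print(mask)
# tensor([[False,  True,  True],
#         [False, False,  True],
#         [False, False, False]])
# Row i is the query for mode i; True means masked. Used as both tgt_mask
# and memory_mask, it lets mode i attend only to memory slots 0..i, i.e.
# the RNN context plus the first i ground-truth mode embeddings.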
@@ -129,6 +135,8 @@ def __init__(self,
129135 """
130136 super ().__init__ ()
131137
138+ self .note_dim = 3 # pitch, time, velocity
139+
132140 self .start_token = num_pitches
133141 self .end_token = num_pitches + 1
134142
@@ -147,7 +155,7 @@ def __init__(self,
 
         # RNN backbone
         self.rnn = GenericRNN(kind,
-            3*emb_size, rnn_hidden,
+            self.note_dim*emb_size, rnn_hidden,
             num_layers=rnn_layers, batch_first=True, dropout=dropout)
 
         # learnable initial RNN state
@@ -158,7 +166,7 @@ def __init__(self,
         ])
 
         # projection from RNN state to distribution parameters
-        self.h_proj = nn.Linear(rnn_hidden, emb_size)
+        self.h_proj = nn.Linear(rnn_hidden, emb_size*(1+self.note_dim))
         self.projections = nn.ModuleList([
             nn.Linear(emb_size, self.pitch_domain),
             nn.Linear(emb_size, self.time_dist.n_params, bias=False),
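The widened projection corresponds to one emb_size-wide vector per transformer position: a context vector for the memory slot plus one query vector per mode, split apart downstream with chunk(). A sketch with illustrative sizes (512 and 64 are placeholders, not the repo's defaults):

import torch
from torch import nn

rnn_hidden, emb_size, note_dim = 512, 64, 3
h_proj = nn.Linear(rnn_hidden, emb_size*(1+note_dim))
h = torch.randn(8, 100, rnn_hidden)    # batch x time x rnn_hidden
hs = h_proj(h).chunk(note_dim+1, -1)   # note_dim+1 tensors along the last dim
assert all(t.shape == (8, 100, emb_size) for t in hs)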
@@ -199,22 +207,25 @@ def forward(self, pitches, times, velocities, validation=False):
         embs = (pitch_emb, time_emb, vel_emb)
 
         x = torch.cat(embs, -1)[:,:-1] # skip last time position
-        ## broadcast intial state to batch size
+        ## broadcast initial state to batch size
         initial_state = tuple(
             t.expand(self.rnn.num_layers, x.shape[0], -1).contiguous() # 1 x batch x hidden
             for t in self.initial_state)
         h, _ = self.rnn(x, initial_state) #batch, time, hidden_size
 
         # include initial hidden state for predicting first note
-        h = torch.cat((
-            self.initial_state[0][-1][None].expand(batch_size, 1, -1),
-            h), -2)
+        # h = torch.cat((
+        #     self.initial_state[0][-1][None].expand(batch_size, 1, -1),
+        #     h), -2)
 
         # fit all note factorizations at once.
-        perm = torch.randperm(3)
-        embs = [embs[i] for i in perm]
-        mode_hs = self.xformer(self.h_proj(h), embs)
-        mode_hs = [mode_hs[perm[i]] for i in perm]
+        # TODO: perm each batch item independently?
+        perm = torch.randperm(self.note_dim)
+        hs = list(self.h_proj(h).chunk(self.note_dim+1, -1))
+        hs = hs[:1] + [hs[i+1] for i in perm]
+        embs = [embs[i][:,1:] for i in perm[:-1]]
+        mode_hs = self.xformer(hs, embs)
+        mode_hs = [mode_hs[i] for i in perm.argsort()]
 
         pitch_params, time_params, vel_params = [
             proj(h) for proj, h in zip(self.projections, mode_hs)]
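The decoder outputs come back in permuted order, so perm.argsort() restores the fixed (pitch, time, velocity) order before the per-mode projections. A standalone check that argsort inverts a permutation:

import torch

perm = torch.randperm(3)                    # e.g. tensor([2, 0, 1])
modes = ['pitch', 'time', 'velocity']
shuffled = [modes[i] for i in perm]         # order the transformer decodes in
restored = [shuffled[i] for i in perm.argsort()]
assert restored == modes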
@@ -224,11 +235,11 @@ def forward(self, pitches, times, velocities, validation=False):
         pitch_targets = pitches[:,1:,None] #batch, time, 1
         pitch_log_probs = pitch_logits.gather(-1, pitch_targets)[...,0]
 
-        time_targets = times # batch, time
+        time_targets = times[:,1:] # batch, time
         time_result = self.time_dist(time_params, time_targets)
         time_log_probs = time_result.pop('log_prob')
 
-        vel_targets = velocities # batch, time
+        vel_targets = velocities[:,1:] # batch, time
         vel_result = self.vel_dist(vel_params, vel_targets)
         vel_log_probs = vel_result.pop('log_prob')
 
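The [:,1:] slices keep the targets aligned with the RNN, which consumes positions 0..T-2 and therefore predicts positions 1..T-1. An illustrative alignment check (T is arbitrary):

import torch

T = 16
notes = torch.arange(T)    # stand-in for one modality's sequence
inputs = notes[:-1]        # what the RNN sees (x = ...[:,:-1] above)
targets = notes[1:]        # what the losses score ([:,1:] above)
assert len(inputs) == len(targets) == T - 1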
@@ -252,7 +263,7 @@ def forward(self, pitches, times, velocities, validation=False):
         )
         return r
 
-    # TODO: vel
+    # TODO: force
     def predict(self, pitch, time, vel, force=(None, None, None)):
         """
         supply the most recent note and return a prediction for the next note.
@@ -284,38 +295,37 @@ def predict(self, pitch, time, vel, force=(None, None, None)):
         for t, new_t in zip(self.cell_state, new_state):
             t[:] = new_t
 
-        h = self.h_proj(h)
+        h = self.h_proj(h).chunk(self.note_dim+1, -1)
 
         # TODO: permutations
-        # TODO: optimize by removing unused positions
         # TODO: refactor with common distribution API
-        pitch_h, = self.xformer(h, embs[:1])
+        pitch_h, = self.xformer(h[:2], [])
 
         pitch_params = self.projections[0](pitch_h)
         pred_pitch = D.Categorical(logits=pitch_params).sample()
 
         embs[0] = self.pitch_emb(pred_pitch)
-        _, time_h = self.xformer(h, embs[:2])
+        _, time_h = self.xformer(h[:3], embs[:1])
 
         time_params = self.projections[1](time_h)
-        pred_time = self.time_dist.sample(time_params)
+        # pred_time = self.time_dist.sample(time_params)
+
+        ### TODO: generalize, move into sample
+        # pi only, fewer zeros:
+        log_pi, loc, s = (
+            t for t in self.time_dist.get_params(time_params))
+        bias = float('inf')
+        log_pi = torch.where(loc <= self.time_dist.res, log_pi-bias, log_pi)
+        idx = D.Categorical(logits=log_pi).sample().item()
+        pred_time = loc[...,idx].clamp(0,10)
+        ###
 
         embs[1] = self.time_emb(pred_time)
-        _, _, vel_h = self.xformer(h, embs)
+        _, _, vel_h = self.xformer(h, embs[:2])
 
         vel_params = self.projections[2](vel_h)
         pred_vel = self.vel_dist.sample(vel_params)
 
-        # ### TODO: generalize, move into sample
-        # ### DEBUG
-        # # pi only, fewer zeros:
-        # log_pi, loc, s = (
-        #     t.squeeze() for t in self.time_dist.get_params(time_params))
-        # bias = 2#float('inf')
-        # log_pi = torch.where(loc <= self.time_dist.res, log_pi-bias, log_pi)
-        # idx = D.Categorical(logits=log_pi).sample()
-        # pred_time = loc[idx].clamp(0,10)
-
         return {
             'pitch': pred_pitch.item(),
             'time': pred_time.item(),
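The inlined block above zeroes out mixture components at or below the time resolution, then returns the chosen component's location instead of drawing from it. A self-contained sketch with made-up parameters (three components; res = 0.01 is a placeholder, not the model's actual resolution):

import torch
import torch.distributions as D

log_pi = torch.tensor([0.5, 0.3, 0.2]).log()   # mixture logits
loc = torch.tensor([0.0, 0.12, 0.8])           # component locations
res = 0.01                                     # minimum representable time

# push near-zero components to probability 0, then pick a component
log_pi = torch.where(loc <= res, log_pi - float('inf'), log_pi)
idx = D.Categorical(logits=log_pi).sample().item()
pred_time = loc[idx].clamp(0, 10)              # component mean, not a sample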
@@ -336,7 +346,7 @@ def reset(self, start=True):
         for n, t in zip(self.cell_state_names(), self.initial_state):
             getattr(self, n)[:] = t.detach()
         if start:
-            self.predict(self.start_token, 0.)
+            self.predict(self.start_token, 0., 0.)
 
     @classmethod
     def from_checkpoint(cls, path):