+import math
+
 import torch
 from torch import nn
 import torch.nn.functional as F

 from .rnn import GenericRNN
+from .distributions import CensoredMixturePointyBoi
+
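+# SineEmbedding featurizes a scalar (here: inter-event time in seconds) as n
+# sinusoids with geometrically spaced frequencies f0 * interval**k, k = 0..n-1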
+class SineEmbedding(nn.Module):
+    def __init__(self, n, f0=1e-3, interval=2):
+        super().__init__()
+        self.n = n
+        self.register_buffer('fs', f0 * interval**torch.arange(n) * 2 * math.pi)
+
+    def forward(self, x):
+        x = x[..., None] * self.fs
+        return x.sin()
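+# e.g. SineEmbedding(16)(torch.rand(8, 100)) has shape (8, 100, 16)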

-class PitchPredictor(nn.Module):
+class NotePredictor(nn.Module):
     # note: use named arguments only for benefit of training script
-    def __init__(self, emb_size=128, hidden_size=512, domain_size=128,
-            num_layers=1, kind='gru', dropout=0):
+    def __init__(self,
+            pitch_emb_size=128, time_emb_size=16, hidden_size=512,
+            num_layers=1, kind='gru', dropout=0,
+            num_pitches=128,
+            time_components=5, time_res=1e-2,
+            ):
         """
         """
         super().__init__()

-        self.start_token = domain_size - 2
-        self.end_token = domain_size - 1
+        self.start_token = num_pitches
+        self.end_token = num_pitches + 1
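+        # assumption: the two tokens past the pitch range (e.g. the 128 MIDI
+        # pitches) mark sequence start and end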

-        self.emb = nn.Embedding(domain_size, emb_size)
-        self.proj = nn.Linear(hidden_size, domain_size)
-        #### DEBUG
-        with torch.no_grad():
-            self.proj.weight.mul_(1e-2)
+        self.pitch_domain = num_pitches + 2
+
+        # TODO: upper truncation?
+        self.time_dist = CensoredMixturePointyBoi(time_components, time_res, 0, 10)
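+        # assumption: positional args are (n mixture components, resolution,
+        # lower bound, upper bound), i.e. times censored to [0, 10] seconds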

-        self.rnn = GenericRNN(kind, emb_size, hidden_size,
+        # embeddings for inputs
+        self.pitch_emb = nn.Embedding(self.pitch_domain, pitch_emb_size)
+        self.time_emb = SineEmbedding(time_emb_size)
+
+        # RNN backbone
+        self.rnn = GenericRNN(kind, pitch_emb_size + time_emb_size, hidden_size,
             num_layers=num_layers, batch_first=True, dropout=dropout)
-
-        # learnable initial state
+
+        # learnable initial RNN state
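+        # (randn scaled by hidden_size**-0.5, presumably so the initial state
+        # matches the scale of typical RNN hidden activations)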
         self.initial_state = nn.ParameterList([
             # layer x batch x hidden
             nn.Parameter(torch.randn(num_layers, 1, hidden_size) * hidden_size**-0.5)
             for _ in range(2 if kind == 'lstm' else 1)
         ])

-        # persistent state for inference
+        # projection from RNN state to distribution parameters
+        self.time_proj = nn.Linear(hidden_size, self.time_dist.n_params, bias=False)
+        self.pitch_proj = nn.Linear(hidden_size + time_emb_size, self.pitch_domain)
+        with torch.no_grad():
+            self.time_proj.weight.mul_(1e-2)
+            self.pitch_proj.weight.mul_(1e-2)
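+        # shrinking the projection weights presumably keeps the initial time
+        # and pitch distributions close to uninformative early in training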
+
+        # persistent RNN state for inference
         for n, t in zip(self.cell_state_names(), self.initial_state):
             self.register_buffer(n, t.clone())

@@ -42,36 +70,61 @@ def cell_state_names(self):
     def cell_state(self):
         return tuple(getattr(self, n) for n in self.cell_state_names())

-    def forward(self, notes):
+    def forward(self, pitches, times):
         """
         Args:
-            notes: LongTensor[batch, time]
+            pitches: LongTensor[batch, time]
+            times: FloatTensor[batch, time]
         """
-        x = self.emb(notes) # batch, time, emb_size
+
+        time_emb = self.time_emb(times) # batch, time, time_emb_size
+        pitch_emb = self.pitch_emb(pitches) # batch, time, pitch_emb_size
+
+        x = torch.cat((pitch_emb, time_emb), -1)
         ## broadcast initial state to batch size
         initial_state = tuple(
             t.expand(self.rnn.num_layers, x.shape[0], -1).contiguous() # layers x batch x hidden
             for t in self.initial_state)
         h, _ = self.rnn(x, initial_state) # batch, time, hidden_size

-        logits = self.proj(h[:,:-1]) # batch, time-1, 128
-        logits = F.log_softmax(logits, -1) # logits = logits - logits.logsumexp(-1, keepdim=True)
-        targets = notes[:,1:,None] # batch, time-1, 1
-        return {
-            'log_probs': logits.gather(-1, targets)[...,0],
-            'logits': logits
+        # RNN hidden state -> time prediction
+        time_params = self.time_proj(h[:,:-1]) # batch, time-1, time_params
+        time_targets = times[:,1:] # batch, time-1
+        time_result = self.time_dist(time_params, time_targets)
+        time_log_probs = time_result.pop('log_prob')
+
+        # RNN hidden state, time -> pitch prediction
+        # pitch_params = h[...,:self.pitch_domain] + self.pitch_bias # CI
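+        # conditioning the pitch head on the embedded target time factors the
+        # joint as p(time, pitch | h) = p(time | h) * p(pitch | h, time)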
+        pitch_params = self.pitch_proj(torch.cat((h[:,:-1], time_emb[:,1:]), -1))
+        pitch_logits = F.log_softmax(pitch_params, -1)
+        pitch_targets = pitches[:,1:,None] # batch, time-1, 1
+        pitch_log_probs = pitch_logits.gather(-1, pitch_targets)[...,0]
+
+        r = {
+            'pitch_log_probs': pitch_log_probs,
+            'time_log_probs': time_log_probs,
+            **time_result
         }
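+        # diagnostic: probability mass the time model places within +/-30ms of
+        # the target, with the lower edge clamped to 0 (times are nonnegative)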
+        with torch.no_grad():
+            r['time_acc_30ms'] = (
+                self.time_dist.cdf(time_params, time_targets + 0.03)
+                - torch.where(time_targets - 0.03 >= 0,
+                    self.time_dist.cdf(time_params, time_targets - 0.03),
+                    time_targets.new_zeros([]))
+            )
+        return r

-    def predict(self, note, sample=True):
+    # TODO: time
+    def predict(self, note, time, sample=True):
         """
         Args:
             note: int
             sample: bool
         Returns:
-            int if `sample` else Tensor[domain_size]
+            int if `sample` else Tensor[num_pitches+2]
         """
         note = torch.LongTensor([[note]]) # 1x1 (batch, time)
-        x = self.emb(note) # 1, 1, emb_size
+        x = self.pitch_emb(note) # 1, 1, pitch_emb_size

         h, new_state = self.rnn(x, self.cell_state)
         for t, new_t in zip(self.cell_state, new_state):