 from torch import nn
 import torch.nn.functional as F

-# import geotorch
-
-# shim torch RNN, GRU classes to have same API as LSTM
-def rnn_shim(cls):
-    """LSTM API for GRU and RNN.
-
-    hidden state is first element of state tuple"""
-    class shim(cls):
-        def forward(self, input, states):
-            assert len(states) == 1
-            out, h = super().forward(input, *states)
-            return out, (h,)
-    return shim
-
-GRU = rnn_shim(nn.GRU)
-RNN = rnn_shim(nn.RNN)
-LSTM = nn.LSTM
-
-class ExpRNN(nn.Module):
-    pass
-
-# class ExpRNNCell(nn.Module):
-#     pass
-
-class GenericRNN(nn.Module):
-    kind_cls = {
-        'gru': GRU,
-        'lstm': LSTM,
-        'elman': RNN,
-        'exprnn': ExpRNN
-    }
-    # desiderata:
-    # support geotorch constraints
-    # clean API for multiple layers, multiple cell states (e.g. LSTM)
-    def __init__(self, kind, *a, **kw):
-        super().__init__()
-        if kw.get('bidirectional'): raise ValueError("""
-            bidirectional GenericRNN not supported.
-            """)
-        cls = GenericRNN.kind_cls[kind]
-        self.rnn = cls(*a, **kw)
-
-    def __getattr__(self, a):
-        try:
-            return super().__getattr__(a)
-        except AttributeError:
-            return getattr(self.rnn, a)
-
-    def forward(self, x, initial_state):
-        """
-        Args:
-            x: Tensor[batch x time x channel] if batch_first else [time x batch x channel]
-            initial_state: List[Tensor[layers x batch x hidden]], list of components
-                with 0 being hidden state (e.g. 1 is cell state for LSTM).
-        Returns:
-            hidden: hidden states of the top layer, Tensor[batch x time x hidden]
-                or [time x batch x hidden]
-            new_states: List[Tensor[layers x batch x hidden]]
-        """
-        hidden, final_state = self.rnn.forward(x, initial_state) # forward or __call__?
-        return hidden, final_state
-
-    ## NOTE: individual time-step API might be useful, not actually needed yet though
-    # def step(self, x, state):
-    #     """
-    #     Args:
-    #         x: Tensor[batch x channel]
-    #         state: List[Tensor[layers x batch x hidden]], list of components
-    #             with 0 being hidden state (e.g. 1 is cell state for LSTM).
-    #     Returns:
-    #         hidden: hidden state of top layer [batch x hidden]
-    #         new_states: List[Tensor[layers x batch x hidden]]
-    #     """
-    #     time_idx = 1 if self.rnn.batch_first else 0
-    #     x = x.unsqueeze(time_idx)
-    #     hidden, state = self.forward(x, state)
-    #     return hidden.squeeze(time_idx), state
-
+from .rnn import GenericRNN

 class PitchPredictor(nn.Module):
     # note: use named arguments only for benefit of training script
-    def __init__(self, emb_size=128, hidden_size=512, domain_size=128, num_layers=1):
+    def __init__(self, emb_size=128, hidden_size=512, domain_size=128,
+            num_layers=1, kind='gru', dropout=0):
         """
         """
         super().__init__()
@@ -93,14 +17,18 @@ def __init__(self, emb_size=128, hidden_size=512, domain_size=128, num_layers=1)

         self.emb = nn.Embedding(domain_size, emb_size)
         self.proj = nn.Linear(hidden_size, domain_size)
+        #### DEBUG
+        with torch.no_grad():
+            self.proj.weight.mul_(1e-2)

-        self.rnn = GenericRNN('gru', emb_size, hidden_size,
-            num_layers=num_layers, batch_first=True)
+        self.rnn = GenericRNN(kind, emb_size, hidden_size,
+            num_layers=num_layers, batch_first=True, dropout=dropout)

         # learnable initial state
         self.initial_state = nn.ParameterList([
             # layer x batch x hidden
-            nn.Parameter(torch.randn(num_layers, 1, hidden_size) * hidden_size**-0.5),
+            nn.Parameter(torch.randn(num_layers, 1, hidden_size) * hidden_size**-0.5)
+            for _ in range(2 if kind == 'lstm' else 1)
         ])

         # persistent state for inference
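
A note on the refactor above: `GenericRNN` (moved into the new `rnn.py` module) gives every cell kind the LSTM-style tuple-of-state-components API, so callers like `PitchPredictor` never special-case the cell type. A minimal usage sketch, assuming the moved module keeps the API of the removed code shown above (the relative import only works from inside the package):

```python
import torch
from .rnn import GenericRNN  # module introduced by this commit

layers, batch, time, feat, hidden = 2, 4, 16, 32, 64
x = torch.randn(batch, time, feat)

# 'lstm' carries two state components (hidden, cell);
# 'gru'/'elman' carry one, wrapped in a 1-tuple by rnn_shim.
rnn = GenericRNN('lstm', feat, hidden, num_layers=layers, batch_first=True)
state = tuple(torch.zeros(layers, batch, hidden) for _ in range(2))

out, new_state = rnn(x, state)   # out: [batch, time, hidden]
assert out.shape == (batch, time, hidden)
assert len(new_state) == 2
```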
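The `kind`-dependent comprehension at the end determines how many learnable initial-state components the model owns: two for LSTM (hidden and cell), one otherwise. Each is `[num_layers, 1, hidden]` and presumably gets expanded along the batch dimension at forward time (that code is outside this hunk). The `#### DEBUG` block shrinks the output projection at init, which keeps the initial logits near zero and the predicted distribution near uniform. A hypothetical check of the state-count behavior:

```python
model = PitchPredictor(kind='lstm')
assert len(model.initial_state) == 2   # hidden + cell state
assert model.initial_state[0].shape == (model.rnn.num_layers, 1, model.rnn.hidden_size)

model = PitchPredictor(kind='gru')
assert len(model.initial_state) == 1   # hidden state only
```

(Torch itself warns when `dropout > 0` is passed with `num_layers=1`, so the new `dropout=0` default is the safe one for the single-layer case.)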