@@ -37,9 +37,7 @@ def __init__(
3737 rnn_klass_str = rnn_type .upper ().split ("-" )[0 ]
3838 if rnn_klass_str not in ["GRU" , "LSTM" , "RNN" ]:
3939 raise ValueError (f"Unrecognized rnn_type: { rnn_type } " )
40- rnn_klass = {"GRU" : torch .nn .GRU , "LSTM" : torch .nn .LSTM , "RNN" : torch .nn .RNN }[
41- rnn_klass_str
42- ]
40+ rnn_klass = {"GRU" : torch .nn .GRU , "LSTM" : torch .nn .LSTM , "RNN" : torch .nn .RNN }[rnn_klass_str ]
4341 rnn_kwargs = {}
4442 if rnn_klass_str == "RNN" :
4543 rnn_kwargs ["nonlinearity" ] = rnn_type .lower ().split ("-" )[- 1 ]
@@ -57,16 +55,11 @@ def __init__(
5755 if isinstance (output_size , int ):
5856 output_size = {"output" : output_size }
5957 self .heads = torch .nn .ModuleDict (
60- {
61- name : torch .nn .Linear (hidden_size , size )
62- for name , size in output_size .items ()
63- }
58+ {name : torch .nn .Linear (hidden_size , size ) for name , size in output_size .items ()}
6459 )
6560
6661 @classmethod
67- def infer_config_from_state_dict (
68- cls , state_dict : dict , rnn_type : str = "GRU"
69- ) -> dict [str , int | float ]:
62+ def infer_config_from_state_dict (cls , state_dict : dict , rnn_type : str = "GRU" ) -> dict [str , int | float ]:
7063 """
7164 This method is specific to each processor.
7265
@@ -88,8 +81,7 @@ def infer_config_from_state_dict(
8881 # Infer input_size from linear_embeddings.weight (shape: [input_size, input_size])
8982 "input_size" : state_dict ["linear_embeddings.weight" ].shape [1 ],
9083 # Infer hidden_size from rnn.weight_ih_l0 (shape: [hidden_size * gate_count, input_size]; gate_count is 3 for the default GRU)
91- "hidden_size" : state_dict ["rnn.weight_ih_l0" ].shape [0 ]
92- // cls ._get_gate_count (rnn_type ),
84+ "hidden_size" : state_dict ["rnn.weight_ih_l0" ].shape [0 ] // cls ._get_gate_count (rnn_type ),
9385 # Infer num_layers by counting rnn layers in state_dict (e.g., weight_ih_l<k>)
9486 "num_layers" : sum (1 for key in state_dict if "rnn.weight_ih_l" in key ),
9587 "output_size" : output_size ,
@@ -134,27 +126,25 @@ def forward(
134126 ) -> tuple [dict [str , torch .Tensor ], torch .Tensor | tuple ]:
135127 """
136128 Forward pass through the RNN model.
129+
137130 Args:
138131 x (torch.Tensor): Input tensor of shape (batch, seq_len, input_size).
139132 input_lens (Optional[torch.Tensor]): Optional tensor of lengths for each sequence in the batch.
140133 If provided, sequences will be packed before passing through the RNN.
141134 hx (Optional[torch.Tensor | tuple[torch.Tensor, torch.Tensor]]): Optional initial hidden state for the RNN.
135+
142136 Returns:
143137 tuple[dict[str, torch.Tensor], torch.Tensor | tuple]:
144- A dictionary mapping head names to output tensors of shape (batch, seq_len, output_size).
145- If the RNN is LSTM, the second element is the hidden state (h_n, c_n) or just h_n if GRU.
138+ A dictionary mapping head names to output tensors of shape (batch, seq_len, output_size).
139+ The second element is the final hidden state: (h_n, c_n) for an LSTM, or just h_n for a GRU or RNN.
146140 """
147141 x = self .linear_embeddings (x )
148142 x = self .dropout_input (x )
149143 total_length = x .shape [1 ]
150144 if input_lens is not None :
151- x = torch .nn .utils .rnn .pack_padded_sequence (
152- x , input_lens , batch_first = True , enforce_sorted = False
153- )
145+ x = torch .nn .utils .rnn .pack_padded_sequence (x , input_lens , batch_first = True , enforce_sorted = False )
154146 x_out , hx_out = self .rnn (x , hx )
155147 if input_lens is not None :
156- x_out , _ = torch .nn .utils .rnn .pad_packed_sequence (
157- x_out , batch_first = True , total_length = total_length
158- )
148+ x_out , _ = torch .nn .utils .rnn .pad_packed_sequence (x_out , batch_first = True , total_length = total_length )
159149 x_out = self .output_dropout (x_out )
160150 return {name : head (x_out ) for name , head in self .heads .items ()}, hx_out
0 commit comments