@@ -486,7 +486,7 @@ def __init__(self, in_channels={'res': 10, 'godnode4decoder': 5, 'foldx': 23},
486486 nn .Linear (anglesdecoder_hidden [0 ], anglesdecoder_hidden [1 ]),
487487 nn .GELU (),
488488 nn .Linear (anglesdecoder_hidden [1 ], 3 ),
489- nn .Softmax (dim = - 1 )
489+ # nn.Softmax(dim=-1)
490490 )
491491
492492 # Bond angles prediction
@@ -758,7 +758,7 @@ def __init__(
758758 nn .Conv1d (AAdecoder_hidden [1 ], AAdecoder_hidden [2 ] if len (AAdecoder_hidden ) > 2 else AAdecoder_hidden [1 ], kernel_size = 3 , padding = 1 ),
759759 nn .GELU (),
760760 nn .Conv1d (AAdecoder_hidden [2 ] if len (AAdecoder_hidden ) > 2 else AAdecoder_hidden [1 ], 20 , kernel_size = 1 ), # 20 amino acids
761- nn .Softmax (dim = 1 ) # Probabilities for amino acid classes
761+ # nn.Softmax(dim=1) # Probabilities for amino acid classes
762762 )
763763
764764 # Optional secondary structure prediction head
@@ -772,7 +772,7 @@ def __init__(
772772 nn .Conv1d (ssdecoder_hidden [1 ], ssdecoder_hidden [2 ] if len (ssdecoder_hidden ) > 2 else ssdecoder_hidden [1 ], kernel_size = 3 , padding = 1 ),
773773 nn .GELU (),
774774 nn .Conv1d (ssdecoder_hidden [2 ] if len (ssdecoder_hidden ) > 2 else ssdecoder_hidden [1 ], 3 , kernel_size = 1 ), # 3 SS classes
775- nn .Softmax (dim = 1 ) # Probabilities for SS classes
775+ # nn.Softmax(dim=1) # Probabilities for SS classes
776776 )
777777
778778 def forward (self , data , ** kwargs ):
@@ -862,7 +862,7 @@ def forward(self, data, **kwargs):
862862 # AA prediction
863863 aa_out = self .head ['aa_cnn' ](x_cnn ) # (1, 20, seq_len)
864864 aa_out = aa_out .permute (2 , 0 , 1 ).squeeze (1 ) # (seq_len, 20)
865- aa_list .append (F . log_softmax ( aa_out , dim = - 1 ) )
865+ aa_list .append (aa_out )
866866
867867 # SS prediction if enabled
868868 if self .output_ss and 'ss_cnn' in self .head :
@@ -880,7 +880,6 @@ def forward(self, data, **kwargs):
880880 # AA prediction
881881 aa = self .head ['aa_cnn' ](x_cnn ) # (1, 20, seq_len)
882882 aa = aa .permute (2 , 0 , 1 ).squeeze (1 ) # (seq_len, 20)
883- aa = F .log_softmax (aa , dim = - 1 )
884883
885884 # SS prediction if enabled
886885 if self .output_ss and 'ss_cnn' in self .head :
@@ -995,7 +994,7 @@ def __init__(
995994 nn .Conv1d (AAdecoder_hidden [1 ], AAdecoder_hidden [2 ], kernel_size = 3 , padding = 1 ),
996995 nn .GELU (),
997996 nn .Conv1d (AAdecoder_hidden [2 ], 20 , kernel_size = 1 ),
998- nn .Softmax (dim = - 1 )
997+ # nn.Softmax(dim=-1)
999998 )
1000999 else :
10011000 # DNN decoder for amino acid prediction
@@ -1007,7 +1006,7 @@ def __init__(
10071006 nn .Linear (AAdecoder_hidden [1 ], AAdecoder_hidden [2 ]),
10081007 nn .GELU (),
10091008 nn .Linear (AAdecoder_hidden [2 ], 20 ),
1010- nn .Softmax (dim = - 1 )
1009+ # nn.Softmax(dim=-1)
10111010 )
10121011
10131012 # Optional secondary structure prediction head
@@ -1020,7 +1019,7 @@ def __init__(
10201019 nn .Linear (AAdecoder_hidden [1 ], AAdecoder_hidden [2 ]),
10211020 nn .GELU (),
10221021 nn .Linear (AAdecoder_hidden [2 ], 3 ),
1023- nn .Softmax (dim = - 1 )
1022+ # nn.Softmax(dim=-1)
10241023 )
10251024
10261025 def forward (self , data , ** kwargs ):
@@ -1089,7 +1088,7 @@ def forward(self, data, **kwargs):
10891088 xi_cnn = xi .permute (1 , 0 ).unsqueeze (0 ) # (1, d_model, seq_len)
10901089 xi_cnn = self .head ['cnn_decoder' ](xi_cnn ) # (1, 20, seq_len)
10911090 xi_cnn = xi_cnn .permute (2 , 0 , 1 ).squeeze (1 ) # (seq_len, 20)
1092- aa_list .append (F . log_softmax ( xi_cnn [:seq_len , :], dim = - 1 ) )
1091+ aa_list .append (xi_cnn [:seq_len , :])
10931092 else :
10941093 aa_list .append (self .head ['dnn_decoder' ](xi [:seq_len , 0 ]))
10951094 if 'ss_head' in self .head :
@@ -1108,7 +1107,7 @@ def forward(self, data, **kwargs):
11081107 x_cnn = x .permute (1 , 2 , 0 ) # (batch, d_model, seq_len)
11091108 x_cnn = self .head ['cnn_decoder' ](x_cnn ) # (batch, 20, seq_len)
11101109 x_cnn = x_cnn .permute (2 , 0 , 1 ) # (seq_len, batch, 20)
1111- aa = F . log_softmax ( x_cnn , dim = - 1 )
1110+ aa = x_cnn . squeeze ( 1 ) # (seq_len, 20 )
11121111 else :
11131112 aa = self .head ['dnn_decoder' ](x )
11141113 if 'ss_head' in self .head :
@@ -1261,7 +1260,7 @@ def __init__(
12611260 nn .Conv1d (ssdecoder_hidden [1 ], ssdecoder_hidden [2 ] if len (ssdecoder_hidden ) > 2 else ssdecoder_hidden [1 ], kernel_size = 3 , padding = 1 ),
12621261 nn .GELU (),
12631262 nn .Conv1d (ssdecoder_hidden [2 ] if len (ssdecoder_hidden ) > 2 else ssdecoder_hidden [1 ], 3 , kernel_size = 1 ),
1264- nn .LogSoftmax (dim = - 1 )
1263+ # nn.LogSoftmax(dim=-1)
12651264
12661265 )
12671266 else :
@@ -1274,7 +1273,7 @@ def __init__(
12741273 nn .Linear (ssdecoder_hidden [1 ], ssdecoder_hidden [2 ] if len (ssdecoder_hidden ) > 2 else ssdecoder_hidden [1 ]),
12751274 nn .GELU (),
12761275 nn .Linear (ssdecoder_hidden [2 ] if len (ssdecoder_hidden ) > 2 else ssdecoder_hidden [1 ], 3 ),
1277- nn .LogSoftmax (dim = - 1 )
1276+ # nn.LogSoftmax(dim=-1)
12781277 )
12791278
12801279 # Bond angles prediction head (phi, psi, omega)
0 commit comments