Skip to content

Commit 2c97c33

Browse files
committed
fix discrete cat grad and losses
1 parent e5e4646 commit 2c97c33

3 files changed

Lines changed: 76 additions & 107 deletions

File tree

foldtree2/notebooks/experiments/test_monodecoders.ipynb

Lines changed: 59 additions & 91 deletions
Large diffs are not rendered by default.

foldtree2/src/losses/losses.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -529,8 +529,8 @@ def aa_reconstruction_loss(x, recon_x , normalize = False):
529529
The function expects one-hot encoded targets (not class indices).
530530
The targets are reduced to class indices with argmax before being passed to cross-entropy.
531531
"""
532-
533-
return F.cross_entropy(recon_x, x)
532+
target = x.argmax(dim=-1) # integers in [0..19], no gradients needed
533+
return F.cross_entropy(recon_x, target)
534534

535535

536536
def ss_reconstruction_loss(ss, recon_ss, mask_plddt=False, plddt_threshold=0.3 , plddt_mask = None , normalize = False):
@@ -570,18 +570,20 @@ def ss_reconstruction_loss(ss, recon_ss, mask_plddt=False, plddt_threshold=0.3 ,
570570
Masking by pLDDT is recommended for AlphaFold structures where low-confidence
571571
regions may have unreliable secondary structure assignments.
572572
"""
573+
574+
target = ss.argmax(dim=-1)
573575
if mask_plddt:
574576
# Create boolean mask for high-confidence residues
575577
mask = (plddt_mask > plddt_threshold).squeeze()
576578
if mask.sum() > 0:
577579
# Compute loss only on masked residues
578-
ss_loss = F.cross_entropy(recon_ss[mask], ss[mask])
580+
ss_loss = F.cross_entropy(recon_ss[mask], target[mask])
579581
else:
580582
# No residues pass threshold - return zero loss to prevent NaN
581583
ss_loss = torch.tensor(0.0, device=recon_ss.device)
582584
else:
583585
# Compute loss on all residues
584-
ss_loss = F.cross_entropy(recon_ss, ss)
586+
ss_loss = F.cross_entropy(recon_ss, target)
585587
return ss_loss
586588

587589

foldtree2/src/mono_decoders.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -486,7 +486,7 @@ def __init__(self, in_channels={'res': 10, 'godnode4decoder': 5, 'foldx': 23},
486486
nn.Linear(anglesdecoder_hidden[0], anglesdecoder_hidden[1]),
487487
nn.GELU(),
488488
nn.Linear(anglesdecoder_hidden[1], 3),
489-
nn.Softmax(dim=-1)
489+
#nn.Softmax(dim=-1)
490490
)
491491

492492
# Bond angles prediction
@@ -758,7 +758,7 @@ def __init__(
758758
nn.Conv1d(AAdecoder_hidden[1], AAdecoder_hidden[2] if len(AAdecoder_hidden) > 2 else AAdecoder_hidden[1], kernel_size=3, padding=1),
759759
nn.GELU(),
760760
nn.Conv1d(AAdecoder_hidden[2] if len(AAdecoder_hidden) > 2 else AAdecoder_hidden[1], 20, kernel_size=1), # 20 amino acids
761-
nn.Softmax(dim=1) # Probabilities for amino acid classes
761+
#nn.Softmax(dim=1) # Probabilities for amino acid classes
762762
)
763763

764764
# Optional secondary structure prediction head
@@ -772,7 +772,7 @@ def __init__(
772772
nn.Conv1d(ssdecoder_hidden[1], ssdecoder_hidden[2] if len(ssdecoder_hidden) > 2 else ssdecoder_hidden[1], kernel_size=3, padding=1),
773773
nn.GELU(),
774774
nn.Conv1d(ssdecoder_hidden[2] if len(ssdecoder_hidden) > 2 else ssdecoder_hidden[1], 3, kernel_size=1), # 3 SS classes
775-
nn.Softmax(dim=1) # Probabilities for SS classes
775+
#nn.Softmax(dim=1) # Probabilities for SS classes
776776
)
777777

778778
def forward(self, data, **kwargs):
@@ -862,7 +862,7 @@ def forward(self, data, **kwargs):
862862
# AA prediction
863863
aa_out = self.head['aa_cnn'](x_cnn) # (1, 20, seq_len)
864864
aa_out = aa_out.permute(2, 0, 1).squeeze(1) # (seq_len, 20)
865-
aa_list.append(F.log_softmax(aa_out, dim=-1))
865+
aa_list.append(aa_out)
866866

867867
# SS prediction if enabled
868868
if self.output_ss and 'ss_cnn' in self.head:
@@ -880,7 +880,6 @@ def forward(self, data, **kwargs):
880880
# AA prediction
881881
aa = self.head['aa_cnn'](x_cnn) # (1, 20, seq_len)
882882
aa = aa.permute(2, 0, 1).squeeze(1) # (seq_len, 20)
883-
aa = F.log_softmax(aa, dim=-1)
884883

885884
# SS prediction if enabled
886885
if self.output_ss and 'ss_cnn' in self.head:
@@ -995,7 +994,7 @@ def __init__(
995994
nn.Conv1d(AAdecoder_hidden[1], AAdecoder_hidden[2], kernel_size=3, padding=1),
996995
nn.GELU(),
997996
nn.Conv1d(AAdecoder_hidden[2], 20, kernel_size=1),
998-
nn.Softmax(dim=-1)
997+
#nn.Softmax(dim=-1)
999998
)
1000999
else:
10011000
# DNN decoder for amino acid prediction
@@ -1007,7 +1006,7 @@ def __init__(
10071006
nn.Linear(AAdecoder_hidden[1], AAdecoder_hidden[2]),
10081007
nn.GELU(),
10091008
nn.Linear(AAdecoder_hidden[2], 20),
1010-
nn.Softmax(dim=-1)
1009+
#nn.Softmax(dim=-1)
10111010
)
10121011

10131012
# Optional secondary structure prediction head
@@ -1020,7 +1019,7 @@ def __init__(
10201019
nn.Linear(AAdecoder_hidden[1], AAdecoder_hidden[2]),
10211020
nn.GELU(),
10221021
nn.Linear(AAdecoder_hidden[2], 3),
1023-
nn.Softmax(dim=-1)
1022+
#nn.Softmax(dim=-1)
10241023
)
10251024

10261025
def forward(self, data, **kwargs):
@@ -1089,7 +1088,7 @@ def forward(self, data, **kwargs):
10891088
xi_cnn = xi.permute(1, 0).unsqueeze(0) # (1, d_model, seq_len)
10901089
xi_cnn = self.head['cnn_decoder'](xi_cnn) # (1, 20, seq_len)
10911090
xi_cnn = xi_cnn.permute(2, 0, 1).squeeze(1) # (seq_len, 20)
1092-
aa_list.append(F.log_softmax(xi_cnn[:seq_len, :], dim=-1))
1091+
aa_list.append(xi_cnn[:seq_len, :])
10931092
else:
10941093
aa_list.append(self.head['dnn_decoder'](xi[:seq_len, 0]))
10951094
if 'ss_head' in self.head:
@@ -1108,7 +1107,7 @@ def forward(self, data, **kwargs):
11081107
x_cnn = x.permute(1, 2, 0) # (batch, d_model, seq_len)
11091108
x_cnn = self.head['cnn_decoder'](x_cnn) # (batch, 20, seq_len)
11101109
x_cnn = x_cnn.permute(2, 0, 1) # (seq_len, batch, 20)
1111-
aa = F.log_softmax(x_cnn, dim=-1)
1110+
aa = x_cnn.squeeze(1) # (seq_len, 20)
11121111
else:
11131112
aa = self.head['dnn_decoder'](x)
11141113
if 'ss_head' in self.head:
@@ -1261,7 +1260,7 @@ def __init__(
12611260
nn.Conv1d(ssdecoder_hidden[1], ssdecoder_hidden[2] if len(ssdecoder_hidden) > 2 else ssdecoder_hidden[1], kernel_size=3, padding=1),
12621261
nn.GELU(),
12631262
nn.Conv1d(ssdecoder_hidden[2] if len(ssdecoder_hidden) > 2 else ssdecoder_hidden[1], 3, kernel_size=1),
1264-
nn.LogSoftmax(dim=-1)
1263+
#nn.LogSoftmax(dim=-1)
12651264

12661265
)
12671266
else:
@@ -1274,7 +1273,7 @@ def __init__(
12741273
nn.Linear(ssdecoder_hidden[1], ssdecoder_hidden[2] if len(ssdecoder_hidden) > 2 else ssdecoder_hidden[1]),
12751274
nn.GELU(),
12761275
nn.Linear(ssdecoder_hidden[2] if len(ssdecoder_hidden) > 2 else ssdecoder_hidden[1], 3),
1277-
nn.LogSoftmax(dim=-1)
1276+
#nn.LogSoftmax(dim=-1)
12781277
)
12791278

12801279
# Bond angles prediction head (phi, psi, omega)

0 commit comments

Comments
 (0)