@@ -53,7 +53,7 @@ def jaccard_distance_multiset(A: torch.Tensor,
5353 jaccard_similarity = min_sum / (max_sum + eps )
5454 return jaccard_similarity
5555
56- def recon_loss_diag (data , pos_edge_index : Tensor , decoder = None , poslossmod = 1 , neglossmod = 1 , plddt = False , nclamp = 30 , key = None , nbins = 8 ) -> Tensor :
56+ def recon_loss_diag (data , pos_edge_index : Tensor , decoder = None , poslossmod = 1 , neglossmod = 1 , plddt = False , nclamp = 30 , key = None , nbins = 8 , plddt_thresh = 0.3 ) -> Tensor :
5757 # Remove the diagonal
5858 pos_edge_index = pos_edge_index [:, pos_edge_index [0 ] != pos_edge_index [1 ]]
5959 res = decoder (data , pos_edge_index )
@@ -71,13 +71,13 @@ def recon_loss_diag(data, pos_edge_index: Tensor, decoder=None, poslossmod=1, ne
7171
7272 if 'edge_logits' in res and res ['edge_logits' ] is not None :
7373 #apply recon loss disto
74- disto_loss_pos = recon_loss_disto (data , res , pos_edge_index , plddt = plddt , key = 'edge_logits' , no_bins = nbins )
74+ disto_loss_pos = recon_loss_disto (data , res , pos_edge_index , plddt = plddt , key = 'edge_logits' , no_bins = nbins , plddt_thresh = plddt_thresh )
7575
7676 if plddt == True :
7777 c1 = data ['plddt' ].x [pos_edge_index [0 ]].squeeze (1 )
7878 c2 = data ['plddt' ].x [pos_edge_index [1 ]].squeeze (1 )
79- c1 = c1 > .30
80- c2 = c2 > .30
79+ c1 = c1 > plddt_thresh
80+ c2 = c2 > plddt_thresh
8181 mask = c1 & c2
8282 mask = mask .squeeze (0 ) # Ensure mask is 1D
8383 pos_loss = pos_loss [mask ]
@@ -98,20 +98,20 @@ def recon_loss_diag(data, pos_edge_index: Tensor, decoder=None, poslossmod=1, ne
9898 if plddt == True :
9999 c1 = data ['plddt' ].x [neg_edge_index [0 ]].squeeze (1 )
100100 c2 = data ['plddt' ].x [neg_edge_index [1 ]].squeeze (1 )
101- c1 = c1 > .30
102- c2 = c2 > .30
101+ c1 = c1 > plddt_thresh
102+ c2 = c2 > plddt_thresh
103103 mask = c1 & c2
104104 mask = mask .squeeze (0 ) # Ensure mask is 1D
105105 neg_loss = neg_loss [mask ]
106106
107107 neg_loss = neg_loss .mean ()
108108 if 'edge_logits' in res and res ['edge_logits' ] is not None :
109109 #apply recon loss disto
110- disto_loss_neg = recon_loss_disto (data , res , neg_edge_index , plddt = plddt , key = 'edge_logits' , no_bins = nbins )
110+ disto_loss_neg = recon_loss_disto (data , res , neg_edge_index , plddt = plddt , key = 'edge_logits' , no_bins = nbins , plddt_thresh = plddt_thresh )
111111
112112 return poslossmod * pos_loss + neglossmod * neg_loss , disto_loss_pos * poslossmod + disto_loss_neg * neglossmod
113113
114- def prody_reconstruction_loss (data , decoder = None , poslossmod = 1 , neglossmod = 1 , plddt = False , nclamp = 30 , key = None ) -> Tensor :
114+ def prody_reconstruction_loss (data , decoder = None , poslossmod = 1 , neglossmod = 1 , plddt = False , nclamp = 30 , key = None , plddt_thresh = 0.3 ) -> Tensor :
115115 for interaction_type in []:
116116 # Remove the diagonal
117117 pos_edge_index = data [f'{ interaction_type } _edge_index' ]
@@ -161,11 +161,14 @@ def angles_reconstruction_loss(true, pred):
161161 return (1.0 - torch.cos(delta)).mean()
162162"""
163163
164- def angles_reconstruction_loss (true , pred , beta = 0.5 , plddt_mask = None ):
164+ def angles_reconstruction_loss (true , pred , beta = 0.5 , plddt_mask = None , plddt_thresh = 0.3 ):
165165 delta = torch .atan2 (torch .sin (pred - true ), torch .cos (pred - true ))
166- loss = F .smooth_l1_loss (delta , torch .zeros_like (delta ), beta = beta )
167166 if plddt_mask is not None :
168- loss = loss * plddt_mask
167+ mask = plddt_mask > plddt_thresh
168+ mask = mask .squeeze (1 ) # Ensure mask is 1D
169+ delta = delta [mask ]
170+ loss = F .smooth_l1_loss (delta , torch .zeros_like (delta ), beta = beta )
171+
169172 return loss .mean ()
170173
171174
@@ -179,7 +182,7 @@ def gaussian_loss(mu , logvar , beta= 1.5):
179182 return beta * kl_loss
180183
181184
182- def recon_loss_disto (data , res , edge_index : Tensor , plddt = True , nclamp = 30 ,no_bins = 8 , key = None ) -> Tensor :
185+ def recon_loss_disto (data , res , edge_index : Tensor , plddt = True , nclamp = 30 ,no_bins = 8 , key = None , plddt_thresh = 0.3 ) -> Tensor :
183186
184187 '''
185188 Calculates a reconstruction loss based on predicted and true coordinates, with optional filtering by pLDDT confidence and off-diagonal weighting.
@@ -204,8 +207,8 @@ def recon_loss_disto(data , res , edge_index: Tensor, plddt = True , nclamp =
204207 c1 = data ['plddt' ].x [edge_index [0 ]].view (- 1 ,1 )
205208 c2 = data ['plddt' ].x [edge_index [1 ]].view (- 1 ,1 )
206209 #both have to be above .5, binary and operation
207- c1 = c1 > .30
208- c2 = c2 > .30
210+ c1 = c1 > plddt_thresh
211+ c2 = c2 > plddt_thresh
209212 mask = c1 & c2
210213 mask = mask .squeeze (1 ) # Ensure mask is 1D
211214 disto_loss = disto_loss [mask ]
0 commit comments