Skip to content

Commit b6cfd7b

Browse files
author
dmoi
committed
messing with encoder
1 parent fd314f8 commit b6cfd7b

4 files changed

Lines changed: 562 additions & 106 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ models/
9898
**/checkpoints/
9999
***/logs/
100100
**/**__pycache__/
101+
**/aster/
101102

102103

103104
# Data files

foldtree2/notebooks/experiments/test_monodecoders.ipynb

Lines changed: 544 additions & 92 deletions
Large diffs are not rendered by default.
168 Bytes
Binary file not shown.

foldtree2/src/losses/losses.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def jaccard_distance_multiset(A: torch.Tensor,
5353
jaccard_similarity = min_sum / (max_sum + eps)
5454
return jaccard_similarity
5555

56-
def recon_loss_diag(data, pos_edge_index: Tensor, decoder=None, poslossmod=1, neglossmod=1, plddt=False, nclamp=30, key=None , nbins=8) -> Tensor:
56+
def recon_loss_diag(data, pos_edge_index: Tensor, decoder=None, poslossmod=1, neglossmod=1, plddt=False, nclamp=30, key=None , nbins=8 , plddt_thresh=0.3) -> Tensor:
5757
# Remove the diagonal
5858
pos_edge_index = pos_edge_index[:, pos_edge_index[0] != pos_edge_index[1]]
5959
res = decoder(data, pos_edge_index)
@@ -71,13 +71,13 @@ def recon_loss_diag(data, pos_edge_index: Tensor, decoder=None, poslossmod=1, ne
7171

7272
if 'edge_logits' in res and res['edge_logits'] is not None:
7373
#apply recon loss disto
74-
disto_loss_pos = recon_loss_disto(data, res, pos_edge_index, plddt=plddt, key='edge_logits', no_bins=nbins)
74+
disto_loss_pos = recon_loss_disto(data, res, pos_edge_index, plddt=plddt, key='edge_logits', no_bins=nbins , plddt_thresh=plddt_thresh)
7575

7676
if plddt == True:
7777
c1 = data['plddt'].x[pos_edge_index[0]].squeeze(1)
7878
c2 = data['plddt'].x[pos_edge_index[1]].squeeze(1)
79-
c1 = c1 > .30
80-
c2 = c2 > .30
79+
c1 = c1 > plddt_thresh
80+
c2 = c2 > plddt_thresh
8181
mask = c1 & c2
8282
mask = mask.squeeze(0) # Ensure mask is 1D
8383
pos_loss = pos_loss[mask]
@@ -98,20 +98,20 @@ def recon_loss_diag(data, pos_edge_index: Tensor, decoder=None, poslossmod=1, ne
9898
if plddt == True:
9999
c1 = data['plddt'].x[neg_edge_index[0]].squeeze(1)
100100
c2 = data['plddt'].x[neg_edge_index[1]].squeeze(1)
101-
c1 = c1 > .30
102-
c2 = c2 > .30
101+
c1 = c1 > plddt_thresh
102+
c2 = c2 > plddt_thresh
103103
mask = c1 & c2
104104
mask = mask.squeeze(0) # Ensure mask is 1D
105105
neg_loss = neg_loss[mask]
106106

107107
neg_loss = neg_loss.mean()
108108
if 'edge_logits' in res and res['edge_logits'] is not None:
109109
#apply recon loss disto
110-
disto_loss_neg = recon_loss_disto(data, res, neg_edge_index, plddt=plddt, key='edge_logits' , no_bins=nbins)
110+
disto_loss_neg = recon_loss_disto(data, res, neg_edge_index, plddt=plddt, key='edge_logits' , no_bins=nbins , plddt_thresh=plddt_thresh)
111111

112112
return poslossmod*pos_loss + neglossmod*neg_loss, disto_loss_pos * poslossmod + disto_loss_neg * neglossmod
113113

114-
def prody_reconstruction_loss(data, decoder=None, poslossmod=1, neglossmod=1, plddt=False, nclamp=30, key=None) -> Tensor:
114+
def prody_reconstruction_loss(data, decoder=None, poslossmod=1, neglossmod=1, plddt=False, nclamp=30, key=None , plddt_thresh=0.3) -> Tensor:
115115
for interaction_type in []:
116116
# Remove the diagonal
117117
pos_edge_index = data[f'{interaction_type}_edge_index']
@@ -161,11 +161,14 @@ def angles_reconstruction_loss(true, pred):
161161
return (1.0 - torch.cos(delta)).mean()
162162
"""
163163

164-
def angles_reconstruction_loss(true, pred, beta=0.5 , plddt_mask = None):
164+
def angles_reconstruction_loss(true, pred, beta=0.5 , plddt_mask = None , plddt_thresh = 0.3):
165165
delta = torch.atan2(torch.sin(pred - true), torch.cos(pred - true))
166-
loss = F.smooth_l1_loss(delta, torch.zeros_like(delta), beta=beta)
167166
if plddt_mask is not None:
168-
loss = loss * plddt_mask
167+
mask = plddt_mask > plddt_thresh
168+
mask = mask.squeeze(1) # Ensure mask is 1D
169+
delta = delta[mask]
170+
loss = F.smooth_l1_loss(delta, torch.zeros_like(delta), beta=beta)
171+
169172
return loss.mean()
170173

171174

@@ -179,7 +182,7 @@ def gaussian_loss(mu , logvar , beta= 1.5):
179182
return beta*kl_loss
180183

181184

182-
def recon_loss_disto(data , res , edge_index: Tensor, plddt = True , nclamp = 30 ,no_bins = 8 , key = None) -> Tensor:
185+
def recon_loss_disto(data , res , edge_index: Tensor, plddt = True , nclamp = 30 ,no_bins = 8 , key = None , plddt_thresh=0.3) -> Tensor:
183186

184187
'''
185188
Calculates a reconstruction loss based on predicted and true coordinates, with optional filtering by pLDDT confidence and off-diagonal weighting.
@@ -204,8 +207,8 @@ def recon_loss_disto(data , res , edge_index: Tensor, plddt = True , nclamp =
204207
c1 = data['plddt'].x[edge_index[0]].view(-1,1)
205208
c2 = data['plddt'].x[edge_index[1]].view(-1,1)
206209
#both have to be above .5, binary and operation
207-
c1 = c1 > .30
208-
c2 = c2 > .30
210+
c1 = c1 > plddt_thresh
211+
c2 = c2 > plddt_thresh
209212
mask = c1 & c2
210213
mask = mask.squeeze(1) # Ensure mask is 1D
211214
disto_loss = disto_loss[mask]

0 commit comments

Comments
 (0)