@@ -537,11 +537,11 @@ def training_step(self, batch, batch_idx):
537537
538538 # lDDT loss
539539 lddt_loss = torch .tensor (0.0 , device = self .device )
540- if (self .args .lddt_weight > 0 or getattr (self .args , 'lddt_loss' , False )) and out .get ('coords' ) is not None and hasattr (data , 'coords' ) and hasattr (data ['coords' ], 'x' ):
540+ if (self .args .lddt_weight > 0 or getattr (self .args , 'lddt_loss' , False )) and out .get ('quat_pred' ) is not None and out .get ('trans_pred' ) is not None and hasattr (data , 'coords' ) and hasattr (data ['coords' ], 'x' ):
541541 from foldtree2 .src .losses .losses import batch_lddt_loss
542542 lddt_loss = batch_lddt_loss (
543- pred_q = out .get ('quat' , None ),
544- pred_t = out .get ('trans' , None ),
543+ pred_q = out .get ('quat_pred' ),
544+ pred_t = out .get ('trans_pred' ),
545545 true_coords = data ['coords' ].x ,
546546 batch = getattr (data ['res' ], 'batch' , None ),
547547 plddt = data ['plddt' ].x if self .args .mask_plddt else None ,
@@ -550,26 +550,37 @@ def training_step(self, batch, batch_idx):
550550
551551 # FAPE loss
552552 fape_loss = torch .tensor (0.0 , device = self .device )
553- if (self .args .fape_weight > 0 or getattr (self .args , 'fape_loss' , False )) and out .get ('quat' ) is not None and out .get ('trans' ) is not None and hasattr (data , 'quat' ) and hasattr (data ['quat' ], 'x' ) and hasattr (data , 'trans' ) and hasattr (data ['trans' ], 'x' ):
553+ if (self .args .fape_weight > 0 or getattr (self .args , 'fape_loss' , False )) and out .get ('quat_pred' ) is not None and out .get ('trans_pred' ) is not None and hasattr (data , 'q_true' ) and hasattr (data ['q_true' ], 'x' ) and hasattr (data , 'coords' ) and hasattr (data ['coords' ], 'x' ):
554554 from foldtree2 .src .losses .losses import batch_fape_loss
555+ _fape_batch = getattr (data ['res' ], 'batch' , None )
556+ _pred_disp = out ['trans_pred' ]
557+ # Convert CA-to-CA displacements to CA positions (cumsum per structure)
558+ # FAPE requires absolute positions; cumsum from origin is translation-invariant
559+ if _fape_batch is not None :
560+ _pred_pos = torch .zeros_like (_pred_disp )
561+ for _b in torch .unique (_fape_batch ):
562+ _m = (_fape_batch == _b ).nonzero (as_tuple = True )[0 ]
563+ _pred_pos [_m ] = torch .cumsum (_pred_disp [_m ], dim = 0 )
564+ else :
565+ _pred_pos = torch .cumsum (_pred_disp , dim = 0 )
555566 fape_loss = batch_fape_loss (
556- true_q = data ['quat' ].x ,
557- true_t = data ['trans' ].x ,
558- pred_q = out ['quat' ],
559- pred_t = out ['trans' ],
560- batch = getattr (data ['res' ], 'batch' , None ),
567+ true_q = data ['q_true' ].x ,
568+ true_t = data ['coords' ].x ,
569+ pred_q = out ['quat_pred' ],
570+ pred_t = _pred_pos ,
571+ batch = _fape_batch ,
561572 )
562573
563574 # Delta loss
564575 delta_loss_val = torch .tensor (0.0 , device = self .device )
565- if (self .args .delta_weight > 0 or getattr (self .args , 'delta_loss' , False )) and out .get ('coords' ) is not None and hasattr (data , 'coords' ) and hasattr (data ['coords' ], 'x' ):
576+ if (self .args .delta_weight > 0 or getattr (self .args , 'delta_loss' , False )) and (out .get ('quat_pred' ) is not None or out .get ('coords' ) is not None ) and hasattr (data , 'coords' ) and hasattr (data ['coords' ], 'x' ):
566577 from foldtree2 .src .losses .losses import batch_delta_loss
567578 try :
568- if out .get ('quat' ) is not None and out .get ('trans' ) is not None :
579+ if out .get ('quat_pred' ) is not None and out .get ('trans_pred' ) is not None :
569580 delta_loss_val = batch_delta_loss (
570581 true_ca = data ['coords' ].x ,
571- pred_q = out ['quat' ],
572- pred_t = out ['trans' ],
582+ pred_q = out ['quat_pred' ],
583+ pred_t = out ['trans_pred' ],
573584 batch = getattr (data ['res' ], 'batch' , None ),
574585 plddt = data ['plddt' ].x if self .args .mask_plddt else None ,
575586 plddt_thresh = self .args .plddt_threshold if self .args .mask_plddt else 0.0 ,
0 commit comments