@@ -261,13 +261,24 @@ def print_about():
261261 help = 'Learning rate for AdamW when using Muon (default: 3e-4)' )
262262
# Mixed precision and pLDDT masking
# NOTE(review): the old combination action='store_true', default=True made
# --mixed-precision a no-op — the flag was always True and could never be
# switched off from the command line. BooleanOptionalAction keeps
# --mixed-precision working and additionally accepts --no-mixed-precision.
parser.add_argument('--mixed-precision', action=argparse.BooleanOptionalAction, default=True,
                    help='Use mixed precision training (default: True)')
parser.add_argument('--mask-plddt', action='store_true',
                    help='Mask low pLDDT residues in loss calculations')
parser.add_argument('--plddt-threshold', type=float, default=0.3,
                    help='pLDDT threshold for masking (default: 0.3)')

# lDDT and FAPE loss arguments
# NOTE(review): the weights default to 0.0, so enabling --lddt-loss or
# --fape-loss without also raising the matching --*-weight has no effect on
# the total loss — presumably intentional, but worth confirming.
parser.add_argument('--lddt-loss', action='store_true', default=False,
                    help='Enable lDDT loss during training')
parser.add_argument('--lddt-weight', type=float, default=0.0,
                    help='Weight for lDDT loss (default: 0.0)')
parser.add_argument('--fape-loss', action='store_true', default=False,
                    help='Enable FAPE loss during training')
parser.add_argument('--fape-weight', type=float, default=0.0,
                    help='Weight for FAPE loss (default: 0.0)')

# Multi-GPU settings
parser.add_argument('--gpus', type=int, default=-1,
                    help='Number of GPUs to use (default: -1, use all available GPUs; set to specific number to limit)')
@@ -470,36 +481,50 @@ def forward(self, data):
470481 data ['res' ].x = z
471482 out = self .decoder (data , None )
472483 return out , vqloss
484+
485+ @staticmethod
486+ def _batch_has_invalid_inputs (data ):
487+ for node_type in getattr (data , 'node_types' , []):
488+ node_x = getattr (data [node_type ], 'x' , None )
489+ if node_x is not None and (torch .isnan (node_x ).any () or torch .isinf (node_x ).any ()):
490+ return True
491+ return False
473492
474493 def training_step (self , batch , batch_idx ):
475494 data = batch
476-
495+
496+ batch_size = data ['res' ].batch .max ().item () + 1 if hasattr (data ['res' ], 'batch' ) and data ['res' ].batch is not None else 1
497+
498+ if self ._batch_has_invalid_inputs (data ):
499+ self .log ('train/skipped_bad_batch' , 1.0 , on_step = True , on_epoch = True , batch_size = batch_size )
500+ return torch .zeros ((), device = self .device , requires_grad = True )
501+
477502 # Forward pass
478503 out , vqloss = self (data )
479-
504+
480505 # Get edge index
481506 edge_index = data .edge_index_dict .get (('res' , 'contactPoints' , 'res' )) if hasattr (data , 'edge_index_dict' ) else None
482-
507+
483508 # Edge reconstruction loss
484509 logitloss = torch .tensor (0.0 , device = self .device )
485510 edgeloss = torch .tensor (0.0 , device = self .device )
486511 if edge_index is not None :
487512 edgeloss , logitloss = recon_loss_diag (data , edge_index , self .decoder , plddt = self .args .mask_plddt , key = 'edge_probs' )
488-
513+
489514 # Amino acid reconstruction loss
490515 xloss = aa_reconstruction_loss (data ['AA' ].x , out ['aa' ])
491-
516+
492517 # FFT2 loss
493518 fft2loss = torch .tensor (0.0 , device = self .device )
494519 if 'fft2pred' in out and out ['fft2pred' ] is not None :
495520 fft2loss = F .smooth_l1_loss (torch .cat ([data ['fourier2dr' ].x , data ['fourier2di' ].x ], axis = 1 ), out ['fft2pred' ])
496-
521+
497522 # Angles loss
498523 angles_loss = torch .tensor (0.0 , device = self .device )
499524 if out .get ('angles' ) is not None :
500525 angles_loss = angles_reconstruction_loss (out ['angles' ], data ['bondangles' ].x ,
501526 plddt_mask = data ['plddt' ].x if self .args .mask_plddt else None )
502-
527+
503528 # Secondary structure loss
504529 ss_loss = torch .tensor (0.0 , device = self .device )
505530 if out .get ('ss_pred' ) is not None :
@@ -509,16 +534,40 @@ def training_step(self, batch, batch_idx):
509534 ss_loss = F .cross_entropy (out ['ss_pred' ][mask ], data ['ss' ].x [mask ])
510535 else :
511536 ss_loss = F .cross_entropy (out ['ss_pred' ], data ['ss' ].x )
512-
537+
538+ # lDDT loss
539+ lddt_loss = torch .tensor (0.0 , device = self .device )
540+ if getattr (self .args , 'lddt_loss' , False ):
541+ from foldtree2 .src .losses .losses import lddt_reconstruction_loss
542+ # Use predicted and true coordinates (assume out['coords'] and data['coords'].x)
543+ if out .get ('coords' ) is not None and hasattr (data , 'coords' ) and hasattr (data ['coords' ], 'x' ):
544+ lddt_loss = lddt_reconstruction_loss (
545+ out ['coords' ], data ['coords' ].x ,
546+ plddt = data ['plddt' ].x if self .args .mask_plddt else None ,
547+ plddt_thresh = self .args .plddt_threshold if self .args .mask_plddt else 0.0
548+ )
549+
550+ # FAPE loss
551+ fape_loss = torch .tensor (0.0 , device = self .device )
552+ if getattr (self .args , 'fape_loss' , False ):
553+ from foldtree2 .src .losses .losses import quaternion_fape_loss
554+ # Use predicted and true quaternion frames (assume out['quat'], out['trans'], data['quat'].x, data['trans'].x)
555+ if all ([out .get ('quat' ) is not None , out .get ('trans' ) is not None , hasattr (data , 'quat' ), hasattr (data ['quat' ], 'x' ), hasattr (data , 'trans' ), hasattr (data ['trans' ], 'x' )]):
556+ fape_loss = quaternion_fape_loss (
557+ data ['quat' ].x , data ['trans' ].x ,
558+ out ['quat' ], out ['trans' ]
559+ )
560+
513561 # Total loss
514562 loss = (self .xweight * xloss + self .edgeweight * edgeloss + self .vqweight * vqloss +
515563 self .fft2weight * fft2loss + self .angles_weight * angles_loss +
516- self .ss_weight * ss_loss + self .logitweight * logitloss )
517-
518- # Get batch size from PyG batch (number of graphs in batch)
519- # For PyTorch Geometric, we need to count unique batch indices
520- batch_size = data ['res' ].batch .max ().item () + 1 if hasattr (data ['res' ], 'batch' ) and data ['res' ].batch is not None else 1
521-
564+ self .ss_weight * ss_loss + self .logitweight * logitloss +
565+ self .args .lddt_weight * lddt_loss + self .args .fape_weight * fape_loss )
566+
567+ if not torch .isfinite (loss ):
568+ self .log ('train/skipped_nonfinite_loss' , 1.0 , on_step = True , on_epoch = True , batch_size = batch_size )
569+ return torch .zeros ((), device = self .device , requires_grad = True )
570+
522571 # Log metrics with explicit batch_size to avoid iteration over FeatureStore
523572 self .log ('train/loss' , loss , on_step = True , on_epoch = True , prog_bar = True , batch_size = batch_size )
524573 self .log ('train/aa_loss' , xloss , on_step = False , on_epoch = True , batch_size = batch_size )
@@ -528,13 +577,15 @@ def training_step(self, batch, batch_idx):
528577 self .log ('train/angles_loss' , angles_loss , on_step = False , on_epoch = True , batch_size = batch_size )
529578 self .log ('train/ss_loss' , ss_loss , on_step = False , on_epoch = True , batch_size = batch_size )
530579 self .log ('train/logit_loss' , logitloss , on_step = False , on_epoch = True , batch_size = batch_size )
531-
580+ self .log ('train/lddt_loss' , lddt_loss , on_step = False , on_epoch = True , batch_size = batch_size )
581+ self .log ('train/fape_loss' , fape_loss , on_step = False , on_epoch = True , batch_size = batch_size )
582+
532583 # Log commitment cost if using scheduling
533584 if self .args .use_commitment_scheduling and hasattr (self .encoder , 'vector_quantizer' ):
534585 current_commitment = self .encoder .vector_quantizer .get_commitment_cost ()
535586 self .log ('train/commitment_cost' , current_commitment , on_step = False , on_epoch = True , batch_size = batch_size )
536-
537-
587+
588+ return loss
538589 # Clear CUDA cache
539590 torch .cuda .empty_cache ()
540591 gc .collect ()
@@ -884,12 +935,29 @@ def has_modular_structure(model):
884935else :
885936 strategy = 'auto'
886937
# Mirror notebook stability settings: avoid flash/mem-efficient SDP kernels.
# Only attempted when CUDA is present; wrapped in try/except because older
# torch builds may lack some of these toggles.
if torch.cuda.is_available() and hasattr(torch.backends, 'cuda'):
    try:
        torch.backends.cuda.enable_flash_sdp(False)
        torch.backends.cuda.enable_mem_efficient_sdp(False)
        torch.backends.cuda.enable_math_sdp(True)
        print("Using math SDP kernel (flash and mem-efficient disabled for stability)")
    except Exception as e:
        print(f"Warning: could not configure SDP kernels: {e}")

# Pick the Lightning precision mode for the Trainer.
# bf16 is preferred wherever supported (fp32 dynamic range, no loss scaling).
# BUGFIX(review): the previous code fell back to '16-mixed' when CUDA was
# unavailable, but fp16 mixed precision requires a CUDA device — Lightning
# rejects it on a CPU accelerator. On CPU we now use 'bf16-mixed', which is
# the supported autocast dtype there.
trainer_precision = 32
if args.mixed_precision:
    if torch.cuda.is_available():
        trainer_precision = 'bf16-mixed' if torch.cuda.is_bf16_supported() else '16-mixed'
    else:
        trainer_precision = 'bf16-mixed'
887955trainer = pl .Trainer (
888956 max_epochs = args .epochs ,
889957 accelerator = 'gpu' if torch .cuda .is_available () else 'cpu' ,
890958 devices = devices ,
891959 strategy = strategy ,
892- precision = '16-mixed' if args . mixed_precision else 32 ,
960+ precision = trainer_precision ,
893961 gradient_clip_val = 1.0 if args .clip_grad else 0 ,
894962 # Lightning automatically passes this to DeepSpeed config as gradient_accumulation_steps
895963 accumulate_grad_batches = args .gradient_accumulation_steps ,
@@ -908,6 +976,8 @@ def has_modular_structure(model):
# Echo the masking / precision / clipping configuration to stdout.
if args.mask_plddt:
    print(f" pLDDT threshold: {args.plddt_threshold}")
print(f" Mixed precision: {args.mixed_precision}")
if args.mixed_precision:
    print(f" Trainer precision mode: {trainer_precision}")
print(f" Gradient clipping: {args.clip_grad}")
print()
0 commit comments