Skip to content

Commit 44de813

Browse files
committed
fix overflow with precision adjustment.
1 parent 5dfdcdd commit 44de813

9 files changed

Lines changed: 1362 additions & 235 deletions

File tree

foldtree2/learn_folding.py

Lines changed: 647 additions & 0 deletions
Large diffs are not rendered by default.

foldtree2/learn_lightning.py

Lines changed: 88 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -261,13 +261,24 @@ def print_about():
261261
help='Learning rate for AdamW when using Muon (default: 3e-4)')
262262

263263
# Mixed precision and pLDDT masking
264+
264265
parser.add_argument('--mixed-precision', action='store_true', default=True,
265266
help='Use mixed precision training (default: True)')
266267
parser.add_argument('--mask-plddt', action='store_true',
267268
help='Mask low pLDDT residues in loss calculations')
268269
parser.add_argument('--plddt-threshold', type=float, default=0.3,
269270
help='pLDDT threshold for masking (default: 0.3)')
270271

272+
# lDDT and FAPE loss arguments
273+
parser.add_argument('--lddt-loss', action='store_true', default=False,
274+
help='Enable lDDT loss during training')
275+
parser.add_argument('--lddt-weight', type=float, default=0.0,
276+
help='Weight for lDDT loss (default: 0.0)')
277+
parser.add_argument('--fape-loss', action='store_true', default=False,
278+
help='Enable FAPE loss during training')
279+
parser.add_argument('--fape-weight', type=float, default=0.0,
280+
help='Weight for FAPE loss (default: 0.0)')
281+
271282
# Multi-GPU settings
272283
parser.add_argument('--gpus', type=int, default=-1,
273284
help='Number of GPUs to use (default: -1, use all available GPUs; set to specific number to limit)')
@@ -470,36 +481,50 @@ def forward(self, data):
470481
data['res'].x = z
471482
out = self.decoder(data, None)
472483
return out, vqloss
484+
485+
@staticmethod
486+
def _batch_has_invalid_inputs(data):
487+
for node_type in getattr(data, 'node_types', []):
488+
node_x = getattr(data[node_type], 'x', None)
489+
if node_x is not None and (torch.isnan(node_x).any() or torch.isinf(node_x).any()):
490+
return True
491+
return False
473492

474493
def training_step(self, batch, batch_idx):
475494
data = batch
476-
495+
496+
batch_size = data['res'].batch.max().item() + 1 if hasattr(data['res'], 'batch') and data['res'].batch is not None else 1
497+
498+
if self._batch_has_invalid_inputs(data):
499+
self.log('train/skipped_bad_batch', 1.0, on_step=True, on_epoch=True, batch_size=batch_size)
500+
return torch.zeros((), device=self.device, requires_grad=True)
501+
477502
# Forward pass
478503
out, vqloss = self(data)
479-
504+
480505
# Get edge index
481506
edge_index = data.edge_index_dict.get(('res', 'contactPoints', 'res')) if hasattr(data, 'edge_index_dict') else None
482-
507+
483508
# Edge reconstruction loss
484509
logitloss = torch.tensor(0.0, device=self.device)
485510
edgeloss = torch.tensor(0.0, device=self.device)
486511
if edge_index is not None:
487512
edgeloss, logitloss = recon_loss_diag(data, edge_index, self.decoder, plddt=self.args.mask_plddt, key='edge_probs')
488-
513+
489514
# Amino acid reconstruction loss
490515
xloss = aa_reconstruction_loss(data['AA'].x, out['aa'])
491-
516+
492517
# FFT2 loss
493518
fft2loss = torch.tensor(0.0, device=self.device)
494519
if 'fft2pred' in out and out['fft2pred'] is not None:
495520
fft2loss = F.smooth_l1_loss(torch.cat([data['fourier2dr'].x, data['fourier2di'].x], axis=1), out['fft2pred'])
496-
521+
497522
# Angles loss
498523
angles_loss = torch.tensor(0.0, device=self.device)
499524
if out.get('angles') is not None:
500525
angles_loss = angles_reconstruction_loss(out['angles'], data['bondangles'].x,
501526
plddt_mask=data['plddt'].x if self.args.mask_plddt else None)
502-
527+
503528
# Secondary structure loss
504529
ss_loss = torch.tensor(0.0, device=self.device)
505530
if out.get('ss_pred') is not None:
@@ -509,16 +534,40 @@ def training_step(self, batch, batch_idx):
509534
ss_loss = F.cross_entropy(out['ss_pred'][mask], data['ss'].x[mask])
510535
else:
511536
ss_loss = F.cross_entropy(out['ss_pred'], data['ss'].x)
512-
537+
538+
# lDDT loss
539+
lddt_loss = torch.tensor(0.0, device=self.device)
540+
if getattr(self.args, 'lddt_loss', False):
541+
from foldtree2.src.losses.losses import lddt_reconstruction_loss
542+
# Use predicted and true coordinates (assume out['coords'] and data['coords'].x)
543+
if out.get('coords') is not None and hasattr(data, 'coords') and hasattr(data['coords'], 'x'):
544+
lddt_loss = lddt_reconstruction_loss(
545+
out['coords'], data['coords'].x,
546+
plddt=data['plddt'].x if self.args.mask_plddt else None,
547+
plddt_thresh=self.args.plddt_threshold if self.args.mask_plddt else 0.0
548+
)
549+
550+
# FAPE loss
551+
fape_loss = torch.tensor(0.0, device=self.device)
552+
if getattr(self.args, 'fape_loss', False):
553+
from foldtree2.src.losses.losses import quaternion_fape_loss
554+
# Use predicted and true quaternion frames (assume out['quat'], out['trans'], data['quat'].x, data['trans'].x)
555+
if all([out.get('quat') is not None, out.get('trans') is not None, hasattr(data, 'quat'), hasattr(data['quat'], 'x'), hasattr(data, 'trans'), hasattr(data['trans'], 'x')]):
556+
fape_loss = quaternion_fape_loss(
557+
data['quat'].x, data['trans'].x,
558+
out['quat'], out['trans']
559+
)
560+
513561
# Total loss
514562
loss = (self.xweight * xloss + self.edgeweight * edgeloss + self.vqweight * vqloss +
515563
self.fft2weight * fft2loss + self.angles_weight * angles_loss +
516-
self.ss_weight * ss_loss + self.logitweight * logitloss)
517-
518-
# Get batch size from PyG batch (number of graphs in batch)
519-
# For PyTorch Geometric, we need to count unique batch indices
520-
batch_size = data['res'].batch.max().item() + 1 if hasattr(data['res'], 'batch') and data['res'].batch is not None else 1
521-
564+
self.ss_weight * ss_loss + self.logitweight * logitloss +
565+
self.args.lddt_weight * lddt_loss + self.args.fape_weight * fape_loss)
566+
567+
if not torch.isfinite(loss):
568+
self.log('train/skipped_nonfinite_loss', 1.0, on_step=True, on_epoch=True, batch_size=batch_size)
569+
return torch.zeros((), device=self.device, requires_grad=True)
570+
522571
# Log metrics with explicit batch_size to avoid iteration over FeatureStore
523572
self.log('train/loss', loss, on_step=True, on_epoch=True, prog_bar=True, batch_size=batch_size)
524573
self.log('train/aa_loss', xloss, on_step=False, on_epoch=True, batch_size=batch_size)
@@ -528,13 +577,15 @@ def training_step(self, batch, batch_idx):
528577
self.log('train/angles_loss', angles_loss, on_step=False, on_epoch=True, batch_size=batch_size)
529578
self.log('train/ss_loss', ss_loss, on_step=False, on_epoch=True, batch_size=batch_size)
530579
self.log('train/logit_loss', logitloss, on_step=False, on_epoch=True, batch_size=batch_size)
531-
580+
self.log('train/lddt_loss', lddt_loss, on_step=False, on_epoch=True, batch_size=batch_size)
581+
self.log('train/fape_loss', fape_loss, on_step=False, on_epoch=True, batch_size=batch_size)
582+
532583
# Log commitment cost if using scheduling
533584
if self.args.use_commitment_scheduling and hasattr(self.encoder, 'vector_quantizer'):
534585
current_commitment = self.encoder.vector_quantizer.get_commitment_cost()
535586
self.log('train/commitment_cost', current_commitment, on_step=False, on_epoch=True, batch_size=batch_size)
536-
537-
587+
588+
return loss
538589
# Clear CUDA cache
539590
torch.cuda.empty_cache()
540591
gc.collect()
@@ -884,12 +935,29 @@ def has_modular_structure(model):
884935
else:
885936
strategy = 'auto'
886937

938+
# Mirror notebook stability settings: avoid flash/mem-efficient SDP kernels.
939+
if torch.cuda.is_available() and hasattr(torch.backends, 'cuda'):
940+
try:
941+
torch.backends.cuda.enable_flash_sdp(False)
942+
torch.backends.cuda.enable_mem_efficient_sdp(False)
943+
torch.backends.cuda.enable_math_sdp(True)
944+
print("Using math SDP kernel (flash and mem-efficient disabled for stability)")
945+
except Exception as e:
946+
print(f"Warning: could not configure SDP kernels: {e}")
947+
948+
trainer_precision = 32
949+
if args.mixed_precision:
950+
if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
951+
trainer_precision = 'bf16-mixed'
952+
else:
953+
trainer_precision = '16-mixed'
954+
887955
trainer = pl.Trainer(
888956
max_epochs=args.epochs,
889957
accelerator='gpu' if torch.cuda.is_available() else 'cpu',
890958
devices=devices,
891959
strategy=strategy,
892-
precision='16-mixed' if args.mixed_precision else 32,
960+
precision=trainer_precision,
893961
gradient_clip_val=1.0 if args.clip_grad else 0,
894962
# Lightning automatically passes this to DeepSpeed config as gradient_accumulation_steps
895963
accumulate_grad_batches=args.gradient_accumulation_steps,
@@ -908,6 +976,8 @@ def has_modular_structure(model):
908976
if args.mask_plddt:
909977
print(f" pLDDT threshold: {args.plddt_threshold}")
910978
print(f" Mixed precision: {args.mixed_precision}")
979+
if args.mixed_precision:
980+
print(f" Trainer precision mode: {trainer_precision}")
911981
print(f" Gradient clipping: {args.clip_grad}")
912982
print()
913983

0 commit comments

Comments (0)