@@ -197,6 +197,8 @@ def print_about():
197197 help = 'Learning rate (default: 1e-4)' )
198198parser .add_argument ('--batch-size' , '-bs' , type = int , default = 10 ,
199199 help = 'Batch size (default: 10)' )
200+ parser .add_argument ('--max-residues' , type = int , default = None ,
201+ help = 'Maximum residue count per structure; larger entries are skipped (default: None)' )
200202parser .add_argument ('--output-dir' , '-o' , type = str , default = './models/' ,
201203 help = 'Directory to save models/results (default: ./models/)' )
202204parser .add_argument ('--model-name' , type = str , default = 'monodecoder_model' ,
@@ -650,6 +652,10 @@ def decode_batch_reconstruction(encoder, decoder, z_batch, device, converter, ve
650652print (f" Angles Weight: { angles_weight } " )
651653print (f" SS Weight: { ss_weight } " )
652654
 655+	 # Initial per-epoch weight values; NOTE(review): these are reassigned at the
 656+	 # start of every epoch below, so these pre-loop copies are effectively unused.
 657+	 xweight_epoch = xweight
 658+	 ss_weight_epoch = ss_weight
658+
653659# Save configuration if requested
654660if args .save_config :
655661 config_dict = vars (args ).copy ()
@@ -673,6 +679,19 @@ def decode_batch_reconstruction(encoder, decoder, z_batch, device, converter, ve
673679converter = pdbgraph .PDB2PyG (aapropcsv = './foldtree2/config/aaindex1.csv' )
674680struct_dat = pdbgraph .StructureDataset (dataset_path )
675681
682+ # Filter by maximum protein size (residue count) if requested
683+ if args .max_residues is not None :
684+ max_residues = args .max_residues
685+ print (f"Filtering dataset to max_residues={ max_residues } " )
686+ keep_indices = []
687+ for i in tqdm .tqdm (range (len (struct_dat )), desc = 'Filtering structs by residue count' , leave = False ):
688+ sample = struct_dat [i ]
689+ n_res = sample ['res' ].x .shape [0 ] if 'res' in sample else 0
690+ if n_res <= max_residues :
691+ keep_indices .append (i )
692+ print (f"Kept { len (keep_indices )} / { len (struct_dat )} structures after max_residues filter" )
693+ struct_dat = torch .utils .data .Subset (struct_dat , keep_indices )
694+
676695# Create train/validation split
677696torch .manual_seed (args .val_seed )
678697val_size = int (len (struct_dat ) * args .val_split )
@@ -1240,15 +1259,6 @@ def validate(encoder, decoder, val_loader, device, args):
12401259best_loss = float ('inf' )
12411260global_step = 0 # Track global training steps for warmup and scheduling
12421261
1243- # Mirror notebook stability settings: avoid flash/mem-efficient SDP kernels.
1244- if torch .cuda .is_available () and hasattr (torch .backends , 'cuda' ):
1245- try :
1246- torch .backends .cuda .enable_flash_sdp (False )
1247- torch .backends .cuda .enable_mem_efficient_sdp (False )
1248- torch .backends .cuda .enable_math_sdp (True )
1249- print ("Using math SDP kernel (flash and mem-efficient disabled for stability)" )
1250- except Exception as e :
1251- print (f"Warning: could not configure SDP kernels: { e } " )
12521262
12531263amp_dtype = torch .float16
12541264if torch .cuda .is_available () and torch .cuda .is_bf16_supported ():
@@ -1289,6 +1299,10 @@ def validate(encoder, decoder, val_loader, device, args):
12891299early_stop_wait = 0
12901300
12911301for epoch in range (args .epochs ):
 1302+	 # Per-epoch loss weights; currently constant copies of xweight/ss_weight
 1303+	 # (the jump_aa_loss/jump_ss_loss reweighting was removed in this change).
 1304+	 xweight_epoch = xweight
 1305+	 ss_weight_epoch = ss_weight
1305+
12921306 total_loss_x = 0
12931307 total_loss_edge = 0
12941308 total_vq = 0
@@ -1299,10 +1313,6 @@ def validate(encoder, decoder, val_loader, device, args):
12991313 total_lddt_loss = 0
13001314 total_fape_loss = 0
13011315 total_delta_loss = 0
1302-
1303- # Notebook parity: allow coarse reweighting after burn-in-like epochs.
1304- xweight_epoch = max (xweight , 0.5 ) if (args .jump_aa_loss is not None and epoch >= args .jump_aa_loss ) else xweight
1305- ss_weight_epoch = max (ss_weight , 0.5 ) if (args .jump_ss_loss is not None and epoch >= args .jump_ss_loss ) else ss_weight
13061316
13071317 for batch_idx , data in enumerate (tqdm .tqdm (train_loader , desc = f"Epoch { epoch + 1 } /{ args .epochs } " )):
13081318 # Periodically clear CUDA cache to avoid OOM errors
@@ -1565,6 +1575,8 @@ def validate(encoder, decoder, val_loader, device, args):
15651575 if scheduler is not None and scheduler_step_mode == 'step' :
15661576 scheduler .step ()
15671577
1578+ torch .cuda .empty_cache () # Clear cache after each update to reduce fragmentation
1579+ gc .collect () # Run garbage collection to free memory
15681580 global_step += 1
15691581
15701582
@@ -1618,8 +1630,13 @@ def validate(encoder, decoder, val_loader, device, args):
16181630 gc .collect ()
16191631
 1621-	 # Run quick validation on 10 random proteins
 1622-	 val_metrics = quick_validate (encoder , decoder , val_dataset , device , args , n_samples = 10 )
 1633+	 # Run quick validation on 5 random proteins
 1634+	 val_metrics = quick_validate (encoder , decoder , val_dataset , device , args , n_samples = 5 )
16221634
1635+ # Clear CUDA cache
1636+ torch .cuda .empty_cache ()
1637+ gc .collect ()
1638+
1639+
16231640 # Update learning rate scheduler (for epoch-based schedulers)
16241641 if scheduler is not None and scheduler_step_mode == 'epoch' :
16251642 if args .lr_schedule == 'plateau' :
0 commit comments