Skip to content

Commit ff84f3b

Browse files
committed
remove flash attn guard
1 parent 2bb5fce commit ff84f3b

1 file changed

Lines changed: 31 additions & 14 deletions

File tree

foldtree2/learn_monodecoder.py

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,8 @@ def print_about():
197197
help='Learning rate (default: 1e-4)')
198198
parser.add_argument('--batch-size', '-bs', type=int, default=10,
199199
help='Batch size (default: 10)')
200+
parser.add_argument('--max-residues', type=int, default=None,
201+
help='Maximum residue count per structure; larger entries are skipped (default: None)')
200202
parser.add_argument('--output-dir', '-o', type=str, default='./models/',
201203
help='Directory to save models/results (default: ./models/)')
202204
parser.add_argument('--model-name', type=str, default='monodecoder_model',
@@ -650,6 +652,10 @@ def decode_batch_reconstruction(encoder, decoder, z_batch, device, converter, ve
650652
print(f" Angles Weight: {angles_weight}")
651653
print(f" SS Weight: {ss_weight}")
652654

655+
# Per-epoch weight values (for optional scheduling/visualization)
656+
xweight_epoch = xweight
657+
ss_weight_epoch = ss_weight
658+
653659
# Save configuration if requested
654660
if args.save_config:
655661
config_dict = vars(args).copy()
@@ -673,6 +679,19 @@ def decode_batch_reconstruction(encoder, decoder, z_batch, device, converter, ve
673679
converter = pdbgraph.PDB2PyG(aapropcsv='./foldtree2/config/aaindex1.csv')
674680
struct_dat = pdbgraph.StructureDataset(dataset_path)
675681

682+
# Filter by maximum protein size (residue count) if requested
683+
if args.max_residues is not None:
684+
max_residues = args.max_residues
685+
print(f"Filtering dataset to max_residues={max_residues}")
686+
keep_indices = []
687+
for i in tqdm.tqdm(range(len(struct_dat)), desc='Filtering structs by residue count', leave=False):
688+
sample = struct_dat[i]
689+
n_res = sample['res'].x.shape[0] if 'res' in sample else 0
690+
if n_res <= max_residues:
691+
keep_indices.append(i)
692+
print(f"Kept {len(keep_indices)} / {len(struct_dat)} structures after max_residues filter")
693+
struct_dat = torch.utils.data.Subset(struct_dat, keep_indices)
694+
676695
# Create train/validation split
677696
torch.manual_seed(args.val_seed)
678697
val_size = int(len(struct_dat) * args.val_split)
@@ -1240,15 +1259,6 @@ def validate(encoder, decoder, val_loader, device, args):
12401259
best_loss = float('inf')
12411260
global_step = 0 # Track global training steps for warmup and scheduling
12421261

1243-
# Mirror notebook stability settings: avoid flash/mem-efficient SDP kernels.
1244-
if torch.cuda.is_available() and hasattr(torch.backends, 'cuda'):
1245-
try:
1246-
torch.backends.cuda.enable_flash_sdp(False)
1247-
torch.backends.cuda.enable_mem_efficient_sdp(False)
1248-
torch.backends.cuda.enable_math_sdp(True)
1249-
print("Using math SDP kernel (flash and mem-efficient disabled for stability)")
1250-
except Exception as e:
1251-
print(f"Warning: could not configure SDP kernels: {e}")
12521262

12531263
amp_dtype = torch.float16
12541264
if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
@@ -1289,6 +1299,10 @@ def validate(encoder, decoder, val_loader, device, args):
12891299
early_stop_wait = 0
12901300

12911301
for epoch in range(args.epochs):
1302+
# Keep per-epoch weights accessible for scheduling and TensorBoard logging
1303+
xweight_epoch = xweight
1304+
ss_weight_epoch = ss_weight
1305+
12921306
total_loss_x = 0
12931307
total_loss_edge = 0
12941308
total_vq = 0
@@ -1299,10 +1313,6 @@ def validate(encoder, decoder, val_loader, device, args):
12991313
total_lddt_loss = 0
13001314
total_fape_loss = 0
13011315
total_delta_loss = 0
1302-
1303-
# Notebook parity: allow coarse reweighting after burn-in-like epochs.
1304-
xweight_epoch = max(xweight, 0.5) if (args.jump_aa_loss is not None and epoch >= args.jump_aa_loss) else xweight
1305-
ss_weight_epoch = max(ss_weight, 0.5) if (args.jump_ss_loss is not None and epoch >= args.jump_ss_loss) else ss_weight
13061316

13071317
for batch_idx, data in enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{args.epochs}")):
13081318
# Periodically clear CUDA cache to avoid OOM errors
@@ -1565,6 +1575,8 @@ def validate(encoder, decoder, val_loader, device, args):
15651575
if scheduler is not None and scheduler_step_mode == 'step':
15661576
scheduler.step()
15671577

1578+
torch.cuda.empty_cache() # Clear cache after each update to reduce fragmentation
1579+
gc.collect() # Run garbage collection to free memory
15681580
global_step += 1
15691581

15701582

@@ -1618,8 +1630,13 @@ def validate(encoder, decoder, val_loader, device, args):
16181630
gc.collect()
16191631

16201632
# Run quick validation on 10 random proteins
1621-
val_metrics = quick_validate(encoder, decoder, val_dataset, device, args, n_samples=10)
1633+
val_metrics = quick_validate(encoder, decoder, val_dataset, device, args, n_samples=5)
16221634

1635+
# Clear CUDA cache
1636+
torch.cuda.empty_cache()
1637+
gc.collect()
1638+
1639+
16231640
# Update learning rate scheduler (for epoch-based schedulers)
16241641
if scheduler is not None and scheduler_step_mode == 'epoch':
16251642
if args.lr_schedule == 'plateau':

0 commit comments

Comments (0)