Skip to content

Latest commit

 

History

History
206 lines (175 loc) · 6.24 KB

File metadata and controls

206 lines (175 loc) · 6.24 KB

learn_monodecoder.py Synchronization with test_monodecoders.ipynb

Summary

The learn_monodecoder.py script has been updated to match the exact training logic from the test_monodecoders.ipynb notebook. This ensures consistency between interactive development and production training runs.

Key Changes Made

1. Mixed Precision Training Support

  • Added: torch.cuda.amp.autocast and GradScaler support
  • CLI Flag: --mixed-precision (default: True)
  • Benefit: Faster training with reduced memory usage
  • Implementation: Wraps forward pass in autocast context, scales gradients

2. Muon Optimizer Support

  • Added: Full support for Muon optimizer with parameter grouping
  • CLI Flags:
    • --use-muon: Enable Muon optimizer
    • --use-muon-encoder: Use mk1_MuonEncoder
    • --use-muon-decoders: Use Muon-compatible decoders
    • --muon-lr: Learning rate for Muon (default: 0.02)
    • --adamw-lr: Learning rate for AdamW when using Muon (default: 3e-4)
  • Implementation: Automatically separates parameters into:
    • Hidden weights (2D+): Use Muon momentum-based Newton updates
    • Hidden gains/biases (1D): Use AdamW
    • Non-hidden params (input/head): Use AdamW

3. Process Group Initialization

  • Added: Automatic initialization for ReduceLROnPlateau scheduler
  • Implementation: Creates single-process group in non-distributed mode
  • Fix: Resolves "Default process group has not been initialized" error

4. pLDDT Masking Support

  • Added: Ability to mask low-confidence residues in loss calculations
  • CLI Flags:
    • --mask-plddt: Enable pLDDT masking
    • --plddt-threshold: Threshold value (default: 0.3)
  • Applies To:
    • Angles loss: angles_reconstruction_loss(..., plddt_mask=data['plddt'].x)
    • Secondary structure loss: Masks residues with plddt < threshold

5. Secondary Structure Loss

  • Added: Secondary structure prediction loss tracking
  • Loss Weight: ss_weight = 0.25
  • Implementation: Cross-entropy loss on 3-class SS prediction (helix/sheet/coil)
  • Logging: Added to TensorBoard and console output

6. Updated Loss Weights (Matching Notebook)

edgeweight = 0.25      # was 0.05
logitweight = 0.25     # was 0.08
xweight = 1            # unchanged
fft2weight = 0.01      # unchanged
vqweight = 0.1         # was 0.001
angles_weight = 0.05   # was 0.001
ss_weight = 0.25       # new

7. Encoder/Decoder Configuration Updates

  • Encoder:

    • Uses mk1_MuonEncoder when --use-muon-encoder is set
    • Standard mk1_Encoder otherwise
    • Updated nheads from 8 to 16
    • Updated hidden_channels from 2 layers to 3 layers
    • Added concat_positions=True
  • Decoder:

    • Uses Muon-compatible decoders when --use-muon-decoders is set
    • Updated decoder_type specifications in configs
    • Aligned all hyperparameters with notebook

8. Removed Burn-in Logic

  • Removed: All burn-in related code
  • Reason: Not present in notebook, simpler training logic

9. Improved Scheduler Function

  • Added: get_scheduler() helper function
  • Features:
    • Process group initialization for plateau scheduler
    • Support for all scheduler types (linear, cosine, plateau, etc.)
    • Returns both scheduler and step mode ('step' or 'epoch')

10. Training Loop Improvements

  • Mixed Precision: Proper gradient scaling with scaler.step()
  • Gradient Accumulation: Cleanup at epoch end for incomplete batches
  • Loss Calculation: Exact match with notebook logic including:
    • All loss components properly computed
    • pLDDT masking applied correctly
    • Secondary structure handling
  • Metrics Tracking: Added ss_loss to all tracking and logging

Usage Examples

Basic Training

python foldtree2/learn_monodecoder.py \
  -d structs_traininffttest.h5 \
  -e 100 \
  -bs 20 \
  -lr 1e-4 \
  -o ./models/

Training with Muon Optimizer

python foldtree2/learn_monodecoder.py \
  -d structs_traininffttest.h5 \
  -e 100 \
  -bs 20 \
  --use-muon \
  --use-muon-encoder \
  --use-muon-decoders \
  --muon-lr 0.02 \
  --adamw-lr 3e-4 \
  --mixed-precision \
  -o ./models/

Training with pLDDT Masking

python foldtree2/learn_monodecoder.py \
  -d structs_traininffttest.h5 \
  -e 100 \
  -bs 20 \
  --mask-plddt \
  --plddt-threshold 0.3 \
  -lr 1e-4 \
  -o ./models/

Training with Config File

python foldtree2/learn_monodecoder.py \
  --config config_with_warmup.yaml \
  --use-muon \
  --use-muon-encoder \
  --use-muon-decoders

Configuration File Support

The script now supports YAML and JSON config files. CLI arguments override config file values.

Example YAML config:

# config_muon_training.yaml
dataset: "structs_traininffttest.h5"
epochs: 100
batch_size: 20
learning_rate: 1e-4
use_muon: true
use_muon_encoder: true
use_muon_decoders: true
muon_lr: 0.02
adamw_lr: 0.0003
mixed_precision: true
mask_plddt: true
plddt_threshold: 0.3
lr_schedule: "plateau"
gradient_accumulation_steps: 1
clip_grad: true
EMA: true
commitment_cost: 0.9

Verification

To verify the changes, compare the following sections:

  1. Loss Calculation: Lines containing loss computation should match notebook cell #27
  2. Optimizer Setup: Lines containing optimizer initialization should match notebook cell #16
  3. Scheduler Setup: get_scheduler function should match notebook cell #12
  4. Model Initialization: Encoder/decoder configs should match notebook cells #14-15

Testing

Run a quick test to ensure everything works:

python foldtree2/learn_monodecoder.py \
  -d structs_traininffttest.h5 \
  -e 1 \
  -bs 5 \
  --use-muon \
  --use-muon-encoder \
  --use-muon-decoders \
  --mixed-precision \
  --mask-plddt \
  -o ./models/test/

This should run 1 epoch without errors and save a checkpoint.

Next Steps

  1. ✅ Script now matches notebook training logic
  2. Test with full training run
  3. Compare loss curves between script and notebook
  4. Validate Muon optimizer performance
  5. Benchmark mixed precision speedup
  6. Test gradient accumulation with larger effective batch sizes

Notes

  • All code changes preserve backward compatibility
  • Default arguments match notebook settings where applicable
  • Mixed precision is enabled by default for optimal performance
  • Process group initialization is automatic when needed