The learn_monodecoder.py script has been updated to match the exact training logic from the test_monodecoders.ipynb notebook. This ensures consistency between interactive development and production training runs.
- Added: `torch.cuda.amp.autocast` and `GradScaler` support
- CLI Flag: `--mixed-precision` (default: True)
- Benefit: Faster training with reduced memory usage
- Implementation: Wraps the forward pass in an autocast context and scales gradients (see the sketch below)
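A minimal sketch of the pattern, assuming a generic `model`, `loader`, `optimizer`, and `loss_fn` (the script's actual loop computes many loss components):

```python
import torch
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

for batch in loader:
    optimizer.zero_grad()
    with autocast():                  # forward pass runs in mixed precision
        out = model(batch)
        loss = loss_fn(out, batch)
    scaler.scale(loss).backward()     # scale loss to avoid fp16 gradient underflow
    scaler.step(optimizer)            # unscales gradients, then steps
    scaler.update()                   # adapt the scale factor for the next step
```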
- Added: Full support for Muon optimizer with parameter grouping
- CLI Flags:
  - `--use-muon`: Enable Muon optimizer
  - `--use-muon-encoder`: Use mk1_MuonEncoder
  - `--use-muon-decoders`: Use Muon-compatible decoders
  - `--muon-lr`: Learning rate for Muon (default: 0.02)
  - `--adamw-lr`: Learning rate for AdamW when using Muon (default: 3e-4)
- Implementation: Automatically separates parameters into (see the sketch below):
  - Hidden weights (2D+): Use Muon's momentum-based Newton-Schulz updates
  - Hidden gains/biases (1D): Use AdamW
  - Non-hidden params (input/head): Use AdamW
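A sketch of how such a grouping can be built; the name-based heuristic for input/head parameters is illustrative, and the `Muon` class is assumed to be importable from a Muon implementation:

```python
import torch

def split_param_groups(model):
    """Route 2D+ hidden weight matrices to Muon; everything else to AdamW."""
    muon_params, adamw_params = [], []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        # Illustrative heuristic: input embeddings and output heads stay on AdamW
        is_hidden = not any(k in name for k in ("input", "embed", "head"))
        if is_hidden and p.ndim >= 2:
            muon_params.append(p)      # hidden weight matrices
        else:
            adamw_params.append(p)     # 1D gains/biases and non-hidden params
    return muon_params, adamw_params

muon_params, adamw_params = split_param_groups(model)
optimizers = [
    Muon(muon_params, lr=0.02),                # --muon-lr
    torch.optim.AdamW(adamw_params, lr=3e-4),  # --adamw-lr
]
```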
- Added: Automatic initialization for ReduceLROnPlateau scheduler
- Implementation: Creates a single-process group in non-distributed mode
- Fix: Resolves "Default process group has not been initialized" error
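A sketch of that initialization, assuming the plateau scheduler path calls into `torch.distributed`:

```python
import os
import torch.distributed as dist

def ensure_process_group():
    """Create a one-rank gloo process group when running non-distributed,
    so torch.distributed calls don't raise on an uninitialized group."""
    if not dist.is_initialized():
        os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
        os.environ.setdefault("MASTER_PORT", "29500")
        dist.init_process_group(backend="gloo", rank=0, world_size=1)
```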
- Added: Ability to mask low-confidence residues in loss calculations
- CLI Flags:
  - `--mask-plddt`: Enable pLDDT masking
  - `--plddt-threshold`: Threshold value (default: 0.3)
- Applies To:
  - Angles loss: `angles_reconstruction_loss(..., plddt_mask=data['plddt'].x)`
  - Secondary structure loss: Masks residues with plddt < threshold (see the sketch below)
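A sketch of the masking idea, with assumed tensor names and shapes (the script's `angles_reconstruction_loss` takes the mask directly, as shown above):

```python
plddt = data['plddt'].x.squeeze(-1)        # per-residue confidence, assumed in [0, 1]
mask = (plddt >= plddt_threshold).float()  # 1 where the residue is kept

# Assumed per-residue angle error of shape [n_residues]
per_residue = (pred_angles - true_angles).pow(2).mean(dim=-1)
angles_loss = (per_residue * mask).sum() / mask.sum().clamp(min=1)
```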
- Added: Secondary structure prediction loss tracking
- Loss Weight: `ss_weight = 0.25`
- Implementation: Cross-entropy loss on 3-class SS prediction (helix/sheet/coil)
- Logging: Added to TensorBoard and console output
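A sketch of the masked 3-class term, reusing the `mask` from the pLDDT sketch above (logit and label names are assumptions):

```python
import torch.nn.functional as F

ss_weight = 0.25
# ss_logits: [n_residues, 3]; ss_labels: [n_residues] with classes helix/sheet/coil
per_residue = F.cross_entropy(ss_logits, ss_labels, reduction="none")
ss_loss = (per_residue * mask).sum() / mask.sum().clamp(min=1)
```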
Updated loss weights:

```python
edgeweight = 0.25     # was 0.05
logitweight = 0.25    # was 0.08
xweight = 1           # unchanged
fft2weight = 0.01     # unchanged
vqweight = 0.1        # was 0.001
angles_weight = 0.05  # was 0.001
ss_weight = 0.25      # new
```
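With these weights, the combined objective takes the form below (the component loss names are illustrative):

```python
total_loss = (
    xweight * x_loss
    + edgeweight * edge_loss
    + logitweight * logit_loss
    + fft2weight * fft2_loss
    + vqweight * vq_loss
    + angles_weight * angles_loss
    + ss_weight * ss_loss
)
```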
- Encoder:
  - Uses mk1_MuonEncoder when `--use-muon-encoder` is set; standard mk1_Encoder otherwise
  - Updated nheads from 8 to 16
  - Updated hidden_channels from 2 layers to 3 layers
  - Added concat_positions=True
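Purely for illustration, an instantiation reflecting these settings; the constructor signature and channel widths are hypothetical, so check the mk1_Encoder definition for the real interface:

```python
# Hypothetical call: argument names mirror the settings above,
# hidden channel widths are placeholders.
encoder = mk1_Encoder(
    nheads=16,                        # was 8
    hidden_channels=[128, 128, 128],  # 3 layers, was 2
    concat_positions=True,            # new
)
```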
- Decoder:
  - Uses Muon-compatible decoders when `--use-muon-decoders` is set
  - Updated decoder_type specifications in configs
  - Aligned all hyperparameters with the notebook
- Removed: All burn-in-related code
- Reason: Not present in the notebook; simplifies the training logic
- Added: `get_scheduler()` helper function
- Features:
  - Process group initialization for the plateau scheduler
  - Support for all scheduler types (linear, cosine, plateau, etc.)
  - Returns both the scheduler and its step mode ('step' or 'epoch')
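A sketch of what such a helper can look like, trimmed to three scheduler types and reusing the `ensure_process_group()` helper sketched earlier; the step modes chosen here are assumptions:

```python
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, ReduceLROnPlateau

def get_scheduler(optimizer, schedule, epochs):
    """Return (scheduler, step_mode); step_mode says whether the loop should
    call scheduler.step() per batch ('step') or per epoch ('epoch')."""
    if schedule == "plateau":
        ensure_process_group()  # plateau path needs an initialized process group
        return ReduceLROnPlateau(optimizer, mode="min", patience=5), "epoch"
    if schedule == "cosine":
        return CosineAnnealingLR(optimizer, T_max=epochs), "epoch"
    if schedule == "linear":
        return LinearLR(optimizer, start_factor=1.0, end_factor=0.1,
                        total_iters=epochs), "epoch"
    raise ValueError(f"unknown schedule: {schedule}")
```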
- Mixed Precision: Proper gradient scaling with scaler.step()
- Gradient Accumulation: Cleanup at epoch end for incomplete batches
- Loss Calculation: Exact match with notebook logic, including:
  - All loss components properly computed
  - pLDDT masking applied correctly
  - Secondary structure handling
- Metrics Tracking: Added ss_loss to all tracking and logging
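A sketch of the accumulation pattern with the end-of-epoch cleanup, reusing the `scaler` from the mixed-precision sketch (`accum_steps` and `compute_loss` are assumed names):

```python
optimizer.zero_grad()
for i, batch in enumerate(loader):
    with autocast():
        loss = compute_loss(model, batch) / accum_steps  # average over the window
    scaler.scale(loss).backward()
    if (i + 1) % accum_steps == 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

# Cleanup: flush a trailing partial window so its gradients aren't dropped
if (i + 1) % accum_steps != 0:
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()
```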
Basic training:

```bash
python foldtree2/learn_monodecoder.py \
  -d structs_traininffttest.h5 \
  -e 100 \
  -bs 20 \
  -lr 1e-4 \
  -o ./models/
```

Muon optimizer with mixed precision:

```bash
python foldtree2/learn_monodecoder.py \
  -d structs_traininffttest.h5 \
  -e 100 \
  -bs 20 \
  --use-muon \
  --use-muon-encoder \
  --use-muon-decoders \
  --muon-lr 0.02 \
  --adamw-lr 3e-4 \
  --mixed-precision \
  -o ./models/
```

pLDDT masking:

```bash
python foldtree2/learn_monodecoder.py \
  -d structs_traininffttest.h5 \
  -e 100 \
  -bs 20 \
  --mask-plddt \
  --plddt-threshold 0.3 \
  -lr 1e-4 \
  -o ./models/
```

Using a config file:

```bash
python foldtree2/learn_monodecoder.py \
  --config config_with_warmup.yaml \
  --use-muon \
  --use-muon-encoder \
  --use-muon-decoders
```

The script now supports YAML and JSON config files. CLI arguments override config file values.
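A sketch of the override behavior (the helper and its defaults are assumptions, not the script's exact code):

```python
import argparse, json
import yaml

def load_config(path):
    with open(path) as fh:
        return json.load(fh) if path.endswith(".json") else yaml.safe_load(fh)

parser = argparse.ArgumentParser()
parser.add_argument("--config")
parser.add_argument("--use-muon", action="store_true", default=None)
# ... remaining flags, each with default=None so an unset flag is detectable
args = parser.parse_args()

config = load_config(args.config) if args.config else {}
for key, value in vars(args).items():
    if value is not None:   # a flag given on the CLI wins over the config file
        config[key] = value
```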
Example YAML config:
```yaml
# config_muon_training.yaml
dataset: "structs_traininffttest.h5"
epochs: 100
batch_size: 20
learning_rate: 1e-4
use_muon: true
use_muon_encoder: true
use_muon_decoders: true
muon_lr: 0.02
adamw_lr: 0.0003
mixed_precision: true
mask_plddt: true
plddt_threshold: 0.3
lr_schedule: "plateau"
gradient_accumulation_steps: 1
clip_grad: true
EMA: true
commitment_cost: 0.9
```

To verify the changes, compare the following sections:
- Loss Calculation: Lines containing loss computation should match notebook cell #27
- Optimizer Setup: Lines containing optimizer initialization should match notebook cell #16
- Scheduler Setup: get_scheduler function should match notebook cell #12
- Model Initialization: Encoder/decoder configs should match notebook cells #14-15
Run a quick test to ensure everything works:
```bash
python foldtree2/learn_monodecoder.py \
  -d structs_traininffttest.h5 \
  -e 1 \
  -bs 5 \
  --use-muon \
  --use-muon-encoder \
  --use-muon-decoders \
  --mixed-precision \
  --mask-plddt \
  -o ./models/test/
```

This should run 1 epoch without errors and save a checkpoint.
- ✅ Script now matches notebook training logic
- Test with a full training run
- Compare loss curves between script and notebook
- Validate Muon optimizer performance
- Benchmark mixed precision speedup
- Test gradient accumulation with larger effective batch sizes
- All code changes preserve backward compatibility
- Default arguments match notebook settings where applicable
- Mixed precision is enabled by default for faster training and lower memory use
- Process group initialization is automatic when needed