From 6d442c64df5bf8e9f1fca8d88a73607e8c0eb96e Mon Sep 17 00:00:00 2001 From: Fausto Milletari Date: Mon, 8 Jun 2026 20:26:05 +0000 Subject: [PATCH 1/2] Expose ESMFold2 MSA inference-diversity knobs in local fold() The transformers ESMFold2Model.forward accepts msa_max_depth, msa_column_mask_rate, and msa_subsample_at_inference, but the local ESMFold2InputBuilder.fold() entrypoint didn't surface them, so users running ESMFold2 locally couldn't access the feature. Add the three knobs to fold() and forward them to the model. Defaults match the model's hardcoded values (1024 / 0.1 / True), so behavior is unchanged unless set; they only take effect for inputs that carry an MSA. Co-Authored-By: Claude Opus 4.8 (1M context) --- esm/models/esmfold2/processor.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/esm/models/esmfold2/processor.py b/esm/models/esmfold2/processor.py index 6432336d..2ab65f1b 100644 --- a/esm/models/esmfold2/processor.py +++ b/esm/models/esmfold2/processor.py @@ -339,6 +339,9 @@ def fold( lm_mask_pct: float | None = None, early_exit: bool = False, lm_dropout: float | None = 0.3, + msa_max_depth: int = 1024, + msa_column_mask_rate: float = 0.1, + msa_subsample_at_inference: bool = True, complex_id: str = "pred", ) -> MolecularComplexResult | list[MolecularComplexResult]: """Fold a structure end-to-end: encode → model → decode. @@ -362,6 +365,15 @@ def fold( LM-embedding dropout for this fold (fresh mask per loop → diverse ensemble on repeated folds). Defaults to ``0.3`` (paper folding-eval value); ``0``/``None`` disables. + msa_max_depth : int + Maximum number of MSA rows kept per loop (row subsampling + is drawn fresh per loop). Only affects inputs that carry an MSA. + msa_column_mask_rate : float + Fraction of MSA columns masked once before the loop + (shared across loops). Only affects inputs that carry an MSA. + msa_subsample_at_inference : bool + Whether to subsample MSA rows at inference time. The query row is + always kept. Only affects inputs that carry an MSA. complex_id : str Identifier assigned to the predicted MolecularComplex(es). @@ -393,6 +405,9 @@ def fold( num_sampling_steps=num_sampling_steps, num_diffusion_samples=num_diffusion_samples, early_exit=early_exit, + msa_max_depth=msa_max_depth, + msa_column_mask_rate=msa_column_mask_rate, + msa_subsample_at_inference=msa_subsample_at_inference, **sampler_kwargs, ) From 868be038d6df4dffd3a28b56b950c39cc927dc87 Mon Sep 17 00:00:00 2001 From: Fausto Milletari Date: Mon, 8 Jun 2026 20:36:02 +0000 Subject: [PATCH 2/2] Drop msa_subsample_at_inference from local fold() Its model default is True and we always want it on, so omitting the knob is behavior-preserving and leaves fold() exposing exactly msa_max_depth + msa_column_mask_rate, matching the FoldingConfig (forge) interface. Co-Authored-By: Claude Opus 4.8 (1M context) --- esm/models/esmfold2/processor.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/esm/models/esmfold2/processor.py b/esm/models/esmfold2/processor.py index 2ab65f1b..9914a55a 100644 --- a/esm/models/esmfold2/processor.py +++ b/esm/models/esmfold2/processor.py @@ -341,7 +341,6 @@ def fold( lm_dropout: float | None = 0.3, msa_max_depth: int = 1024, msa_column_mask_rate: float = 0.1, - msa_subsample_at_inference: bool = True, complex_id: str = "pred", ) -> MolecularComplexResult | list[MolecularComplexResult]: """Fold a structure end-to-end: encode → model → decode. @@ -371,9 +370,6 @@ def fold( msa_column_mask_rate : float Fraction of MSA columns masked once before the loop (shared across loops). Only affects inputs that carry an MSA. - msa_subsample_at_inference : bool - Whether to subsample MSA rows at inference time. The query row is - always kept. Only affects inputs that carry an MSA. complex_id : str Identifier assigned to the predicted MolecularComplex(es). @@ -407,7 +403,6 @@ def fold( early_exit=early_exit, msa_max_depth=msa_max_depth, msa_column_mask_rate=msa_column_mask_rate, - msa_subsample_at_inference=msa_subsample_at_inference, **sampler_kwargs, )