diff --git a/esm/models/esmfold2/processor.py b/esm/models/esmfold2/processor.py index 6432336d..9914a55a 100644 --- a/esm/models/esmfold2/processor.py +++ b/esm/models/esmfold2/processor.py @@ -339,6 +339,8 @@ def fold( lm_mask_pct: float | None = None, early_exit: bool = False, lm_dropout: float | None = 0.3, + msa_max_depth: int = 1024, + msa_column_mask_rate: float = 0.1, complex_id: str = "pred", ) -> MolecularComplexResult | list[MolecularComplexResult]: """Fold a structure end-to-end: encode → model → decode. @@ -362,6 +364,12 @@ def fold( LM-embedding dropout for this fold (fresh mask per loop → diverse ensemble on repeated folds). Defaults to ``0.3`` (paper folding-eval value); ``0``/``None`` disables. + msa_max_depth : int + Maximum number of MSA rows kept per loop (row subsampling + is drawn fresh per loop). Only affects inputs that carry an MSA. + msa_column_mask_rate : float + Fraction of MSA columns masked once before the loop + (shared across loops). Only affects inputs that carry an MSA. complex_id : str Identifier assigned to the predicted MolecularComplex(es). @@ -393,6 +401,8 @@ def fold( num_sampling_steps=num_sampling_steps, num_diffusion_samples=num_diffusion_samples, early_exit=early_exit, + msa_max_depth=msa_max_depth, + msa_column_mask_rate=msa_column_mask_rate, **sampler_kwargs, )