diff --git a/esm/models/esmfold2/processor.py b/esm/models/esmfold2/processor.py index 9914a55a..21f89d58 100644 --- a/esm/models/esmfold2/processor.py +++ b/esm/models/esmfold2/processor.py @@ -339,7 +339,7 @@ def fold( lm_mask_pct: float | None = None, early_exit: bool = False, lm_dropout: float | None = 0.3, - msa_max_depth: int = 1024, + msa_max_depth: int | None = 1024, msa_column_mask_rate: float = 0.1, complex_id: str = "pred", ) -> MolecularComplexResult | list[MolecularComplexResult]: @@ -364,9 +364,11 @@ def fold( LM-embedding dropout for this fold (fresh mask per loop → diverse ensemble on repeated folds). Defaults to ``0.3`` (paper folding-eval value); ``0``/``None`` disables. - msa_max_depth : int + msa_max_depth : int, optional Maximum number of MSA rows kept per loop (row subsampling - is drawn fresh per loop). Only affects inputs that carry an MSA. + is drawn fresh per loop). When ``None``, MSA row subsampling is + disabled and the full MSA depth is used. Only affects inputs that + carry an MSA. msa_column_mask_rate : float Fraction of MSA columns masked once before the loop (shared across loops). Only affects inputs that carry an MSA. @@ -403,6 +405,8 @@ def fold( early_exit=early_exit, msa_max_depth=msa_max_depth, msa_column_mask_rate=msa_column_mask_rate, + # A null depth means "use the full MSA" => no subsampling. + msa_subsample_at_inference=msa_max_depth is not None, **sampler_kwargs, )