From 3b46dad9ee16d97ce81d42c62f55705106fa3f55 Mon Sep 17 00:00:00 2001 From: Fausto Milletari Date: Tue, 9 Jun 2026 16:28:28 +0000 Subject: [PATCH] Derive msa_subsample_at_inference from nullable msa_max_depth Make msa_max_depth nullable (int | None) in the local fold(), matching the forge FoldingConfig surface. The model's msa_subsample_at_inference kwarg is no longer exposed directly; instead it's derived: None depth => use the full MSA (no subsampling), otherwise subsample. The transformers model is unchanged (it already no-ops when max_depth is None or subsampling is disabled). Co-Authored-By: Claude Opus 4.8 (1M context) --- esm/models/esmfold2/processor.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/esm/models/esmfold2/processor.py b/esm/models/esmfold2/processor.py index 9914a55a..21f89d58 100644 --- a/esm/models/esmfold2/processor.py +++ b/esm/models/esmfold2/processor.py @@ -339,7 +339,7 @@ def fold( lm_mask_pct: float | None = None, early_exit: bool = False, lm_dropout: float | None = 0.3, - msa_max_depth: int = 1024, + msa_max_depth: int | None = 1024, msa_column_mask_rate: float = 0.1, complex_id: str = "pred", ) -> MolecularComplexResult | list[MolecularComplexResult]: @@ -364,9 +364,11 @@ def fold( LM-embedding dropout for this fold (fresh mask per loop → diverse ensemble on repeated folds). Defaults to ``0.3`` (paper folding-eval value); ``0``/``None`` disables. - msa_max_depth : int + msa_max_depth : int, optional Maximum number of MSA rows kept per loop (row subsampling - is drawn fresh per loop). Only affects inputs that carry an MSA. + is drawn fresh per loop). When ``None``, MSA row subsampling is + disabled and the full MSA depth is used. Only affects inputs that + carry an MSA. msa_column_mask_rate : float Fraction of MSA columns masked once before the loop (shared across loops). Only affects inputs that carry an MSA. @@ -403,6 +405,8 @@ def fold( early_exit=early_exit, msa_max_depth=msa_max_depth, msa_column_mask_rate=msa_column_mask_rate, + # A null depth means "use the full MSA" => no subsampling. + msa_subsample_at_inference=msa_max_depth is not None, **sampler_kwargs, )