diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 97e74a2..d7cd3ce 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -19,9 +19,9 @@ jobs: with: python-version: ${{ matrix.python-version }} - name: Install alphapulldown-input-parser - # common.smk imports parse_fold_chains from the parser package; >=0.5.0 - # is required for the public helper. - run: pip install --quiet "alphapulldown-input-parser>=0.5.0" + # common.smk imports parse_fold_chains from the parser package; >=0.5.1 + # is required (public helper + ".json" preservation for AF3 inputs, #41). + run: pip install --quiet "alphapulldown-input-parser>=0.5.1" - name: Byte-compile common.smk # common.smk carries the memory/length logic; delegates parsing to the parser. run: python -m py_compile workflow/rules/common.smk diff --git a/README.md b/README.md index 9ff4249..e44fe64 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ conda env create \ conda activate snake ``` -This environment file installs Snakemake and all required plugins via conda and pulls in `alphapulldown-input-parser>=0.5.0` from PyPI in a single step. +This environment file installs Snakemake and all required plugins via conda and pulls in `alphapulldown-input-parser>=0.5.1` from PyPI in a single step. That's it, you're done! @@ -76,7 +76,7 @@ The original residue IDs are written to the mmCIF author-numbering fields (`auth_seq_id` and `pdbx_PDB_ins_code`); overlapping IDs are disambiguated with insertion codes such as `2A`, `2B`, and so on. Make sure the prediction container or runtime environment includes a matching -AlphaPulldown build together with `alphapulldown-input-parser>=0.5.0`. +AlphaPulldown build together with `alphapulldown-input-parser>=0.5.1`. diff --git a/test/test_memory_resources.py b/test/test_memory_resources.py index 777c094..56b80da 100644 --- a/test/test_memory_resources.py +++ b/test/test_memory_resources.py @@ -331,6 +331,90 @@ class _Job: assert f"--mem {mem}" in cmd, cmd +# --------------------------------------------------------------------------- +# AF3 JSON inputs (ligands etc.) — issue #41: a `*.json` token in a fold must be +# treated as a direct AF3 input, never as a protein to download / build features for. +# --------------------------------------------------------------------------- + + +def _write_af3_json(directory: str, name: str, *, protein_len: int = 0, ligand=None): + """Write a minimal AF3 input JSON; optional protein sequence and/or ligand.""" + sequences = [] + if protein_len: + sequences.append({"protein": {"id": "A", "sequence": "A" * protein_len}}) + if ligand: + sequences.append({"ligand": {"id": "L", "ccdCodes": [ligand]}}) + path = os.path.join(directory, name) + with open(path, "w") as handle: + json.dump({"name": name, "sequences": sequences}, handle) + return path + + +def test_is_json_input_detects_json_tokens(): + assert common.is_json_input("ligand.json") + assert common.is_json_input("/path/to/LIGAND.JSON") # case-insensitive + assert not common.is_json_input("P12345") + assert not common.is_json_input("Prot.fasta") + + +def test_split_fold_inputs_separates_proteins_and_json(): + # The reported case: protein + ligand JSON with a copy number. + assert common.split_fold_inputs("P12345+ligand.json:80") == ( + ["P12345"], + ["ligand.json"], + ) + # Pure protein folds yield no JSON inputs; copies/regions are stripped. + assert common.split_fold_inputs("P01258+P0AEZ3:2") == (["P01258", "P0AEZ3"], []) + # Paths are reduced to a base (protein) / basename (json). + assert common.split_fold_inputs("/p/Prot.fasta+/q/lig.json") == ( + ["Prot"], + ["lig.json"], + ) + # De-duplication, first-seen order preserved. + assert common.split_fold_inputs("A+A+lig.json+lig.json") == (["A"], ["lig.json"]) + + +def test_format_af3_requested_fold_passes_json_through(): + # Regression for #41: protein -> generated feature JSON; *.json left untouched. + assert ( + common.format_af3_requested_fold("P12345+ligand.json:80") + == "P12345_af3_input.json+ligand.json:80" + ) + assert ( + common.format_af3_requested_fold("P01258+P0AEZ3:2") + == "P01258_af3_input.json+P0AEZ3_af3_input.json:2" + ) + assert common.format_af3_requested_fold("P01258:1-100:2") == ( + "P01258_af3_input.json:1-100:2" + ) + + +def test_chain_residue_count_reads_json_input(): + common._AF3_INPUT_COUNT_CACHE.clear() + with tempfile.TemporaryDirectory() as d: + # Ligand-only JSON has no polymer sequence -> contributes 0. + _write_af3_json(d, "ligand.json", ligand="ATP") + assert common.chain_residue_count("ligand.json", d, d, is_af3=True) == 0 + # A JSON carrying a protein sequence is counted by its polymer length. + _write_af3_json(d, "complex.json", protein_len=150) + assert common.chain_residue_count("complex.json", d, d, is_af3=True) == 150 + + +def test_fold_total_tokens_counts_protein_not_ligand_json(): + common._RESIDUE_COUNT_CACHE.clear() + common._AF3_INPUT_COUNT_CACHE.clear() + with tempfile.TemporaryDirectory() as d: + _write_fasta(d, "P12345", 200) + _write_af3_json(d, "ligand.json", ligand="ATP") + # Protein counted; ligand JSON adds 0 and does not error. + assert ( + common.fold_total_tokens( + "P12345+ligand.json:80", d, "+", features_dir=d, is_af3=True + ) + == 200 + ) + + def _run_all(): fns = [v for k, v in sorted(globals().items()) if k.startswith("test_") and callable(v)] for fn in fns: diff --git a/workflow/Snakefile b/workflow/Snakefile index a7abfc4..76e80a6 100644 --- a/workflow/Snakefile +++ b/workflow/Snakefile @@ -12,7 +12,7 @@ include: "rules/common.smk" from pathlib import Path from os import makedirs, listdir, symlink, remove -from os.path import abspath, join, splitext, basename, exists +from os.path import abspath, join, splitext, basename, exists, lexists, realpath from alphapulldown_input_parser.parser import FoldDataset, generate_fold_specifications @@ -51,35 +51,10 @@ def feature_name(base: str) -> str: return f"{base}_af3_input.json" if IS_AF3 else f"{base}.{FEATURE_SUFFIX}" -def _format_af3_requested_fold(fold: str, delimiter: str) -> str: - """Convert a logical fold specification into AF3 JSON feature inputs. - - Examples: - "P01258+P0AEZ3:2" -> "P01258_af3_input.json+P0AEZ3_af3_input.json:2" - "P01258:1-100:2" -> "P01258_af3_input.json:1-100:2" - "P01258" -> "P01258_af3_input.json" - - Rationale: - - Features are generated as "_af3_input.json". - - Copy numbers and region ranges apply to the logical protein, not the file name. - - alphapulldown-input-parser accepts those suffixes after the JSON filename. - """ - - converted_parts: list[str] = [] - for token in fold.split(delimiter): - token = token.strip() - if not token: - continue - - parts = [p.strip() for p in token.split(":") if p.strip()] - base = parts[0] - suffix = ":".join(parts[1:]) if len(parts) > 1 else "" - - json_name = f"{base}_af3_input.json" - converted = f"{json_name}:{suffix}" if suffix else json_name - converted_parts.append(converted) - - return delimiter.join(converted_parts) +# AF3 ``--input`` formatting (protein chains -> generated feature JSON; direct +# ``*.json`` inputs passed through) and protein/JSON partitioning live in +# common.smk (``format_af3_requested_fold`` / ``split_fold_inputs``) so they are +# unit-testable; both are imported via the ``include`` above. protein_delimiter = config.get("protein_delimiter", "+") exclude_permutations = config.get("exclude_permutations", True) @@ -183,6 +158,11 @@ if MAX_PROTEIN_LENGTH > 0 or MAX_TOTAL_LENGTH > 0: _features_dir = join(config["output_directory"], "features") def _resolve_protein_length(name): + if is_json_input(name): + # Direct AF3 JSON input (e.g. a ligand): not a UniProt protein, so do + # not query UniProt. Count polymer residues from the file if present, + # else 0 (a known length, so the fold is kept rather than flagged). + return chain_residue_count(name, _data_dir, _features_dir, IS_AF3) if name in sequence_length_cache: return sequence_length_cache[name] length = 0 @@ -248,14 +228,29 @@ if MAX_PROTEIN_LENGTH > 0 or MAX_TOTAL_LENGTH > 0: f"{sorted(_unknown_proteins)[:5]}" ) -# Proteins required by the surviving folds (matches FoldDataset dedup ordering). +# Feature requirements of the surviving folds. Protein chains need features +# (downloaded + generated, or symlinked when precomputed); direct AF3 ``*.json`` +# inputs (e.g. ligands) are supplied via feature_directory and are NEVER downloaded +# or generated -- they are required as the JSON file itself. Both lists preserve +# first-seen order and are de-duplicated. kept_proteins = [] +required_feature_files = [] _seen_proteins = set() +_seen_features = set() for fold in kept_folds: - for _name in dataset.sequences_by_fold.get(fold, ()): + _protein_bases, _json_basenames = split_fold_inputs(fold, protein_delimiter) + for _name in _protein_bases: if _name not in _seen_proteins: _seen_proteins.add(_name) kept_proteins.append(_name) + _feat = feature_name(_name) + if _feat not in _seen_features: + _seen_features.add(_feat) + required_feature_files.append(_feat) + for _json in _json_basenames: + if _json not in _seen_features: + _seen_features.add(_json) + required_feature_files.append(_json) required_folds = [ join(config["output_directory"], "predictions", fold, "completed_fold.txt") @@ -271,8 +266,8 @@ RECURSIVE_REPORT = ( ) required_feature_paths = [ - join(config["output_directory"], "features", feature_name(fasta_basename)) - for fasta_basename in kept_proteins + join(config["output_directory"], "features", feature_basename) + for feature_basename in required_feature_files ] if config.get("only_generate_features", False): required_targets = required_feature_paths @@ -416,9 +411,13 @@ rule symlink_features: **linear_resources(mem=800, runtime=10), run: for in_file, out_file in zip(input, output): - if exists(out_file): + source = realpath(in_file) + destination = abspath(out_file) + if source == destination: + continue + if lexists(out_file): remove(out_file) - symlink(abspath(in_file), out_file) + symlink(source, out_file) rule create_features: input: @@ -471,9 +470,14 @@ rule create_features: """ def lookup_features(wildcards): + # Inputs for inference: generated/precomputed protein features plus any direct + # AF3 JSON inputs (e.g. ligands), which are required as the JSON file itself + # rather than as a generated _af3_input.json. + protein_bases, json_basenames = split_fold_inputs(wildcards.fold, protein_delimiter) + feature_files = [feature_name(base) for base in protein_bases] + list(json_basenames) return [ - join(config["output_directory"], "features", feature_name(feature)) - for feature in dataset.sequences_by_fold[wildcards.fold] + join(config["output_directory"], "features", feature_basename) + for feature_basename in feature_files ] rule structure_inference: @@ -489,7 +493,7 @@ rule structure_inference: for individual_fold in wildcards.fold.split(" ") ], requested_fold = ( - lambda wc: _format_af3_requested_fold(wc.fold, protein_delimiter) + lambda wc: format_af3_requested_fold(wc.fold, protein_delimiter) if IS_AF3 else wc.fold ), diff --git a/workflow/envs/alphapulldown.yaml b/workflow/envs/alphapulldown.yaml index 8f45cd4..264facb 100644 --- a/workflow/envs/alphapulldown.yaml +++ b/workflow/envs/alphapulldown.yaml @@ -14,4 +14,6 @@ dependencies: - pip: # >=0.5.0 exposes the public parse_fold_chains() used by the length-aware # memory model and the length filter at workflow-parse time. - - alphapulldown-input-parser>=0.5.0 + # >=0.5.1 preserves ".json" tokens in FoldDataset normalization so AF3 JSON + # inputs (e.g. ligands) are not mistaken for proteins (issue #41). + - alphapulldown-input-parser>=0.5.1 diff --git a/workflow/rules/common.smk b/workflow/rules/common.smk index f8b0969..f5ba68d 100644 --- a/workflow/rules/common.smk +++ b/workflow/rules/common.smk @@ -102,6 +102,83 @@ def parse_fold_chains(fold: str, delimiter: str = "+") -> list[tuple[str, int]]: ] +def is_json_input(name: str) -> bool: + """True if a fold token names a direct AF3 JSON input (e.g. a ``ligand.json``). + + Such tokens are AlphaFold 3 inputs supplied as-is via ``feature_directory``; + they are *not* proteins and must never be downloaded or sent through feature + generation. Everything else is treated as a protein chain reference. + """ + return str(name).lower().endswith(".json") + + +def split_fold_inputs( + fold: str, delimiter: str = "+" +) -> tuple[list[str], list[str]]: + """Partition a fold spec into protein chains and direct AF3 JSON inputs. + + Returns ``(protein_bases, json_basenames)``: + + - ``protein_bases`` -- base names (path + extension stripped) of chains that + need feature generation/download, mirroring the parser's stem handling. + - ``json_basenames`` -- basenames of ``*.json`` tokens supplied directly as AF3 + inputs (e.g. ligands), which are provided via ``feature_directory`` and never + generated. + + Both lists preserve first-seen order and are de-duplicated. Copy numbers and + region ranges (``ligand.json:80``, ``A:1-100``) are stripped by the underlying + chain parser, so only the chain name survives here. + """ + protein_bases: list[str] = [] + json_basenames: list[str] = [] + seen_proteins: set[str] = set() + seen_json: set[str] = set() + for name, _copies in parse_fold_chains(fold, delimiter): + if is_json_input(name): + base = os.path.basename(name) + if base not in seen_json: + seen_json.add(base) + json_basenames.append(base) + else: + base = os.path.splitext(os.path.basename(name))[0] + if base not in seen_proteins: + seen_proteins.add(base) + protein_bases.append(base) + return protein_bases, json_basenames + + +def format_af3_requested_fold(fold: str, delimiter: str = "+") -> str: + """Convert a logical fold spec into AlphaFold 3 inference ``--input`` tokens. + + Protein chains map to their generated feature file ``_af3_input.json``; + tokens that are already ``*.json`` (direct AF3 JSON inputs such as ligands) are + passed through unchanged. Copy numbers and region ranges are preserved after the + file name. + + Examples: + ``P01258+P0AEZ3:2`` -> ``P01258_af3_input.json+P0AEZ3_af3_input.json:2`` + ``P01258+ligand.json:80`` -> ``P01258_af3_input.json+ligand.json:80`` + ``P01258:1-100:2`` -> ``P01258_af3_input.json:1-100:2`` + + Rationale: + - Protein features are generated as ``_af3_input.json``. + - JSON inputs are already AF3 inputs and must not get a second suffix. + - Copy numbers / region ranges apply to the logical chain, not the file + name; ``alphapulldown-input-parser`` accepts them after the JSON filename. + """ + converted_parts: list[str] = [] + for token in str(fold).split(delimiter): + token = token.strip() + if not token: + continue + parts = [p.strip() for p in token.split(":") if p.strip()] + base = parts[0] + suffix = ":".join(parts[1:]) if len(parts) > 1 else "" + json_name = base if is_json_input(base) else f"{base}_af3_input.json" + converted_parts.append(f"{json_name}:{suffix}" if suffix else json_name) + return delimiter.join(converted_parts) + + @functools.lru_cache(maxsize=None) def fetch_uniprot_length(uniprot_id: str, timeout: float = 30.0) -> int: """Residue length of a UniProt entry via the REST API; 0 on any failure. @@ -137,7 +214,17 @@ def chain_residue_count( parse-time length table, which covers the AF2 precomputed-feature case where neither a FASTA nor an AF3 JSON exists). Returns 0 when length is unknown so sizing degrades to the base allocation plus retry escalation. + + A direct AF3 JSON input (``ligand.json``) is read from the file itself in + ``features_dir``; ligand-only inputs have no polymer ``sequence`` and so + contribute 0 (consistent with AF3 ligand atoms not being counted as tokens). """ + if is_json_input(name): + if features_dir: + return af3_input_residue_count( + os.path.join(features_dir, os.path.basename(name)) + ) + return 0 length = residue_count(os.path.join(data_dir, f"{name}.fasta")) if length == 0 and is_af3 and features_dir: length = af3_input_residue_count(