Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
- name: Install alphapulldown-input-parser
# common.smk imports parse_fold_chains from the parser package; >=0.5.0
# is required for the public helper.
run: pip install --quiet "alphapulldown-input-parser>=0.5.0"
# common.smk imports parse_fold_chains from the parser package; >=0.5.1
# is required (public helper + ".json" preservation for AF3 inputs, #41).
run: pip install --quiet "alphapulldown-input-parser>=0.5.1"
- name: Byte-compile common.smk
# common.smk carries the memory/length logic; delegates parsing to the parser.
run: python -m py_compile workflow/rules/common.smk
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ conda env create \
conda activate snake
```

This environment file installs Snakemake and all required plugins via conda and pulls in `alphapulldown-input-parser>=0.5.0` from PyPI in a single step.
This environment file installs Snakemake and all required plugins via conda and pulls in `alphapulldown-input-parser>=0.5.1` from PyPI in a single step.

That's it, you're done!

Expand Down Expand Up @@ -76,7 +76,7 @@ The original residue IDs are written to the mmCIF author-numbering fields
(`auth_seq_id` and `pdbx_PDB_ins_code`); overlapping IDs are disambiguated with
insertion codes such as `2A`, `2B`, and so on.
Make sure the prediction container or runtime environment includes a matching
AlphaPulldown build together with `alphapulldown-input-parser>=0.5.0`.
AlphaPulldown build together with `alphapulldown-input-parser>=0.5.1`.

</details>

Expand Down
84 changes: 84 additions & 0 deletions test/test_memory_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,90 @@ class _Job:
assert f"--mem {mem}" in cmd, cmd


# ---------------------------------------------------------------------------
# AF3 JSON inputs (ligands etc.) — issue #41: a `*.json` token in a fold must be
# treated as a direct AF3 input, never as a protein to download / build features for.
# ---------------------------------------------------------------------------


def _write_af3_json(directory: str, name: str, *, protein_len: int = 0, ligand=None):
    """Create a minimal AF3 input JSON file and return its path.

    The file is written to ``directory/name`` and holds ``{"name": name,
    "sequences": [...]}``. A protein entry (chain ``A``, poly-alanine of
    ``protein_len`` residues) is added when ``protein_len`` is truthy, and a
    ligand entry (chain ``L``, a single CCD code) when ``ligand`` is truthy;
    with neither, ``sequences`` is an empty list.
    """
    entries = []
    if protein_len:
        polymer = {"id": "A", "sequence": "A" * protein_len}
        entries.append({"protein": polymer})
    if ligand:
        small_molecule = {"id": "L", "ccdCodes": [ligand]}
        entries.append({"ligand": small_molecule})

    payload = {"name": name, "sequences": entries}
    target = os.path.join(directory, name)
    with open(target, "w") as out:
        json.dump(payload, out)
    return target


def test_is_json_input_detects_json_tokens():
    """``is_json_input`` accepts ``.json`` names (any case, any path) and nothing else."""
    json_tokens = (
        "ligand.json",
        "/path/to/LIGAND.JSON",  # case-insensitive
    )
    protein_tokens = ("P12345", "Prot.fasta")

    for token in json_tokens:
        assert common.is_json_input(token)
    for token in protein_tokens:
        assert not common.is_json_input(token)


def test_split_fold_inputs_separates_proteins_and_json():
    """``split_fold_inputs`` returns ``(protein_bases, json_basenames)`` per fold."""
    expectations = {
        # The reported case: protein + ligand JSON with a copy number.
        "P12345+ligand.json:80": (["P12345"], ["ligand.json"]),
        # Pure protein folds yield no JSON inputs; copies/regions are stripped.
        "P01258+P0AEZ3:2": (["P01258", "P0AEZ3"], []),
        # Paths are reduced to a base (protein) / basename (json).
        "/p/Prot.fasta+/q/lig.json": (["Prot"], ["lig.json"]),
        # De-duplication, first-seen order preserved.
        "A+A+lig.json+lig.json": (["A"], ["lig.json"]),
    }
    for fold, expected in expectations.items():
        assert common.split_fold_inputs(fold) == expected


def test_format_af3_requested_fold_passes_json_through():
    """Regression for #41: proteins map to generated feature JSONs; ``*.json`` is untouched."""
    cases = [
        # Direct JSON input keeps its name and its copy number.
        ("P12345+ligand.json:80", "P12345_af3_input.json+ligand.json:80"),
        # Plain protein folds: every chain gets the generated-feature file name.
        ("P01258+P0AEZ3:2", "P01258_af3_input.json+P0AEZ3_af3_input.json:2"),
        # Region range and copy number both stay after the file name.
        ("P01258:1-100:2", "P01258_af3_input.json:1-100:2"),
    ]
    for fold, expected in cases:
        assert common.format_af3_requested_fold(fold) == expected


def test_chain_residue_count_reads_json_input():
    """``chain_residue_count`` takes a JSON input's length from the file itself."""
    common._AF3_INPUT_COUNT_CACHE.clear()
    with tempfile.TemporaryDirectory() as workdir:
        # Ligand-only JSON has no polymer sequence -> contributes 0.
        _write_af3_json(workdir, "ligand.json", ligand="ATP")
        ligand_count = common.chain_residue_count(
            "ligand.json", workdir, workdir, is_af3=True
        )
        assert ligand_count == 0

        # A JSON carrying a protein sequence is counted by its polymer length.
        _write_af3_json(workdir, "complex.json", protein_len=150)
        complex_count = common.chain_residue_count(
            "complex.json", workdir, workdir, is_af3=True
        )
        assert complex_count == 150


def test_fold_total_tokens_counts_protein_not_ligand_json():
    """A protein + ligand-JSON fold totals the protein length only."""
    for cache in (common._RESIDUE_COUNT_CACHE, common._AF3_INPUT_COUNT_CACHE):
        cache.clear()
    with tempfile.TemporaryDirectory() as workdir:
        _write_fasta(workdir, "P12345", 200)
        _write_af3_json(workdir, "ligand.json", ligand="ATP")
        # Protein counted; ligand JSON adds 0 and does not error.
        total = common.fold_total_tokens(
            "P12345+ligand.json:80", workdir, "+", features_dir=workdir, is_af3=True
        )
        assert total == 200


def _run_all():
fns = [v for k, v in sorted(globals().items()) if k.startswith("test_") and callable(v)]
for fn in fns:
Expand Down
82 changes: 43 additions & 39 deletions workflow/Snakefile
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ include: "rules/common.smk"

from pathlib import Path
from os import makedirs, listdir, symlink, remove
from os.path import abspath, join, splitext, basename, exists
from os.path import abspath, join, splitext, basename, exists, lexists, realpath

from alphapulldown_input_parser.parser import FoldDataset, generate_fold_specifications

Expand Down Expand Up @@ -51,35 +51,10 @@ def feature_name(base: str) -> str:
return f"{base}_af3_input.json" if IS_AF3 else f"{base}.{FEATURE_SUFFIX}"


def _format_af3_requested_fold(fold: str, delimiter: str) -> str:
    """Convert a logical fold specification into AF3 JSON feature inputs.

    Examples:
        "P01258+P0AEZ3:2"       -> "P01258_af3_input.json+P0AEZ3_af3_input.json:2"
        "P01258:1-100:2"        -> "P01258_af3_input.json:1-100:2"
        "P01258"                -> "P01258_af3_input.json"
        "P01258+ligand.json:80" -> "P01258_af3_input.json+ligand.json:80"

    Rationale:
        - Features are generated as "<base>_af3_input.json".
        - A token that already names a "*.json" file is a direct AF3 input and
          must not receive a second "_af3_input.json" suffix.
        - Copy numbers and region ranges apply to the logical protein, not the file name.
        - alphapulldown-input-parser accepts those suffixes after the JSON filename.

    Args:
        fold: Fold specification; chains are separated by ``delimiter`` and each
            chain may carry ":"-separated copy-number / region suffixes.
        delimiter: Separator between chains (e.g. "+").

    Returns:
        The fold specification with every chain name replaced by its input file
        name. Tokens without a chain name (empty, whitespace, or only colons)
        are dropped instead of raising ``IndexError``.
    """
    converted_parts: list[str] = []
    for token in str(fold).split(delimiter):
        parts = [p.strip() for p in token.split(":") if p.strip()]
        if not parts:
            # Nothing but separators/whitespace: there is no chain to convert.
            continue

        base, *modifiers = parts
        # Same case-insensitive ".json" test as is_json_input() in common.smk.
        if base.lower().endswith(".json"):
            file_name = base
        else:
            file_name = f"{base}_af3_input.json"
        converted_parts.append(":".join([file_name, *modifiers]))

    return delimiter.join(converted_parts)
# AF3 ``--input`` formatting (protein chains -> generated feature JSON; direct
# ``*.json`` inputs passed through) and protein/JSON partitioning live in
# common.smk (``format_af3_requested_fold`` / ``split_fold_inputs``) so they are
# unit-testable; both are imported via the ``include`` above.

protein_delimiter = config.get("protein_delimiter", "+")
exclude_permutations = config.get("exclude_permutations", True)
Expand Down Expand Up @@ -183,6 +158,11 @@ if MAX_PROTEIN_LENGTH > 0 or MAX_TOTAL_LENGTH > 0:
_features_dir = join(config["output_directory"], "features")

def _resolve_protein_length(name):
if is_json_input(name):
# Direct AF3 JSON input (e.g. a ligand): not a UniProt protein, so do
# not query UniProt. Count polymer residues from the file if present,
# else 0 (a known length, so the fold is kept rather than flagged).
return chain_residue_count(name, _data_dir, _features_dir, IS_AF3)
if name in sequence_length_cache:
return sequence_length_cache[name]
length = 0
Expand Down Expand Up @@ -248,14 +228,29 @@ if MAX_PROTEIN_LENGTH > 0 or MAX_TOTAL_LENGTH > 0:
f"{sorted(_unknown_proteins)[:5]}"
)

# Proteins required by the surviving folds (matches FoldDataset dedup ordering).
# Feature requirements of the surviving folds. Protein chains need features
# (downloaded + generated, or symlinked when precomputed); direct AF3 ``*.json``
# inputs (e.g. ligands) are supplied via feature_directory and are NEVER downloaded
# or generated -- they are required as the JSON file itself. Both lists preserve
# first-seen order and are de-duplicated.
kept_proteins = []
required_feature_files = []
_seen_proteins = set()
_seen_features = set()
for fold in kept_folds:
for _name in dataset.sequences_by_fold.get(fold, ()):
_protein_bases, _json_basenames = split_fold_inputs(fold, protein_delimiter)
for _name in _protein_bases:
if _name not in _seen_proteins:
_seen_proteins.add(_name)
kept_proteins.append(_name)
_feat = feature_name(_name)
if _feat not in _seen_features:
_seen_features.add(_feat)
required_feature_files.append(_feat)
for _json in _json_basenames:
if _json not in _seen_features:
_seen_features.add(_json)
required_feature_files.append(_json)

required_folds = [
join(config["output_directory"], "predictions", fold, "completed_fold.txt")
Expand All @@ -271,8 +266,8 @@ RECURSIVE_REPORT = (
)

required_feature_paths = [
join(config["output_directory"], "features", feature_name(fasta_basename))
for fasta_basename in kept_proteins
join(config["output_directory"], "features", feature_basename)
for feature_basename in required_feature_files
]
if config.get("only_generate_features", False):
required_targets = required_feature_paths
Expand Down Expand Up @@ -416,9 +411,13 @@ rule symlink_features:
**linear_resources(mem=800, runtime=10),
run:
for in_file, out_file in zip(input, output):
if exists(out_file):
source = realpath(in_file)
destination = abspath(out_file)
if source == destination:
continue
if lexists(out_file):
remove(out_file)
symlink(abspath(in_file), out_file)
symlink(source, out_file)

rule create_features:
input:
Expand Down Expand Up @@ -471,9 +470,14 @@ rule create_features:
"""

def lookup_features(wildcards):
    """Return the feature-file paths required to run inference for one fold.

    Inputs for inference are the generated/precomputed protein features plus any
    direct AF3 JSON inputs (e.g. ligands), which are required as the JSON file
    itself rather than as a generated ``<name>_af3_input.json``.

    Args:
        wildcards: Snakemake wildcards object; only ``wildcards.fold`` is read.

    Returns:
        Paths under ``<output_directory>/features``: protein feature files first
        (named by ``feature_name``), then the direct JSON inputs, each group in
        the order produced by ``split_fold_inputs``.
    """
    protein_bases, json_basenames = split_fold_inputs(wildcards.fold, protein_delimiter)
    feature_files = [feature_name(base) for base in protein_bases]
    feature_files.extend(json_basenames)
    return [
        join(config["output_directory"], "features", feature_basename)
        for feature_basename in feature_files
    ]

rule structure_inference:
Expand All @@ -489,7 +493,7 @@ rule structure_inference:
for individual_fold in wildcards.fold.split(" ")
],
requested_fold = (
lambda wc: _format_af3_requested_fold(wc.fold, protein_delimiter)
lambda wc: format_af3_requested_fold(wc.fold, protein_delimiter)
if IS_AF3
else wc.fold
),
Expand Down
4 changes: 3 additions & 1 deletion workflow/envs/alphapulldown.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,6 @@ dependencies:
- pip:
# >=0.5.0 exposes the public parse_fold_chains() used by the length-aware
# memory model and the length filter at workflow-parse time.
- alphapulldown-input-parser>=0.5.0
# >=0.5.1 preserves ".json" tokens in FoldDataset normalization so AF3 JSON
# inputs (e.g. ligands) are not mistaken for proteins (issue #41).
- alphapulldown-input-parser>=0.5.1
87 changes: 87 additions & 0 deletions workflow/rules/common.smk
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,83 @@ def parse_fold_chains(fold: str, delimiter: str = "+") -> list[tuple[str, int]]:
]


def is_json_input(name: str) -> bool:
    """Report whether a fold token names a direct AF3 JSON input (e.g. ``ligand.json``).

    The check is purely textual: ``name`` is converted with ``str()`` and
    matched case-insensitively against the ``.json`` ending. Such tokens are
    AlphaFold 3 inputs supplied as-is via ``feature_directory``; they are *not*
    proteins and must never be downloaded or sent through feature generation.
    Every other token is treated as a protein chain reference.
    """
    lowered = str(name).lower()
    return lowered.endswith(".json")


def split_fold_inputs(
    fold: str, delimiter: str = "+"
) -> tuple[list[str], list[str]]:
    """Split a fold spec into protein chain names and direct AF3 JSON inputs.

    Chain names come from ``parse_fold_chains(fold, delimiter)``, which yields
    ``(name, copies)`` pairs; the copy count is ignored here.

    Returns ``(protein_bases, json_basenames)``:

    - ``protein_bases`` -- for every chain that is not a ``*.json`` token, its
      basename with the extension removed. These are the chains that need
      feature generation/download.
    - ``json_basenames`` -- for every ``*.json`` token, its basename with the
      extension kept. These are supplied directly as AF3 inputs (e.g. ligands)
      via ``feature_directory`` and are never generated.

    Each list is de-duplicated and keeps first-seen order.
    """
    # Dicts serve as insertion-ordered sets: keys stay unique and ordered.
    proteins: dict[str, None] = {}
    json_inputs: dict[str, None] = {}
    for chain_name, _ in parse_fold_chains(fold, delimiter):
        leaf = os.path.basename(chain_name)
        if is_json_input(chain_name):
            json_inputs.setdefault(leaf, None)
        else:
            stem = os.path.splitext(leaf)[0]
            proteins.setdefault(stem, None)
    return list(proteins), list(json_inputs)


def format_af3_requested_fold(fold: str, delimiter: str = "+") -> str:
    """Convert a logical fold spec into AlphaFold 3 inference ``--input`` tokens.

    Protein chains map to their generated feature file ``<base>_af3_input.json``;
    tokens that are already ``*.json`` (direct AF3 JSON inputs such as ligands) are
    passed through unchanged. Copy numbers and region ranges are preserved after the
    file name.

    Examples:
        ``P01258+P0AEZ3:2``       -> ``P01258_af3_input.json+P0AEZ3_af3_input.json:2``
        ``P01258+ligand.json:80`` -> ``P01258_af3_input.json+ligand.json:80``
        ``P01258:1-100:2``        -> ``P01258_af3_input.json:1-100:2``

    Rationale:
        - Protein features are generated as ``<base>_af3_input.json``.
        - JSON inputs are already AF3 inputs and must not get a second suffix.
        - Copy numbers / region ranges apply to the logical chain, not the file
          name; ``alphapulldown-input-parser`` accepts them after the JSON filename.

    Args:
        fold: Fold specification; converted with ``str()`` before splitting.
        delimiter: Separator between chains.

    Returns:
        The converted specification joined with ``delimiter``. Tokens without a
        chain name (empty, whitespace, or only colons) are dropped instead of
        raising ``IndexError``.
    """
    converted_parts: list[str] = []
    for token in str(fold).split(delimiter):
        parts = [p.strip() for p in token.split(":") if p.strip()]
        if not parts:
            # Nothing but separators/whitespace: there is no chain to convert.
            continue
        base, *modifiers = parts
        file_name = base if is_json_input(base) else f"{base}_af3_input.json"
        converted_parts.append(":".join([file_name, *modifiers]))
    return delimiter.join(converted_parts)


@functools.lru_cache(maxsize=None)
def fetch_uniprot_length(uniprot_id: str, timeout: float = 30.0) -> int:
"""Residue length of a UniProt entry via the REST API; 0 on any failure.
Expand Down Expand Up @@ -137,7 +214,17 @@ def chain_residue_count(
parse-time length table, which covers the AF2 precomputed-feature case where
neither a FASTA nor an AF3 JSON exists). Returns 0 when length is unknown so
sizing degrades to the base allocation plus retry escalation.

A direct AF3 JSON input (``ligand.json``) is read from the file itself in
``features_dir``; ligand-only inputs have no polymer ``sequence`` and so
contribute 0 (consistent with AF3 ligand atoms not being counted as tokens).
"""
if is_json_input(name):
if features_dir:
return af3_input_residue_count(
os.path.join(features_dir, os.path.basename(name))
)
return 0
length = residue_count(os.path.join(data_dir, f"{name}.fasta"))
if length == 0 and is_af3 and features_dir:
length = af3_input_residue_count(
Expand Down
Loading