From 13f3fb035fd439a3dfb2493ff88ca26bb37895a3 Mon Sep 17 00:00:00 2001 From: Dima Molodenskiy Date: Mon, 8 Jun 2026 16:02:54 +0200 Subject: [PATCH] Re-pin numpy<2 after jax install in alphafold2 dockerfile The `pip install --upgrade "jax[cuda12]"==0.5.3` step pulls in numpy 2.x, which breaks AlphaFold multimer at runtime: alphafold/data/msa_pairing.py does `np.sum(x for x in feats)` and numpy>=2 raises `TypeError: Calling np.sum(generator) is deprecated`, so every multimer prediction fails during feature pairing (before inference). The existing numpy<2 pins (conda block + pyproject) run before the jax upgrade and therefore do not stick. Commit fff63051 ("Tests (#600)") removed the trailing `RUN pip install "numpy<2"` (added by 69af382a) that previously force-downgraded numpy as the final dependency step. Restore it. jax 0.5.3 runs fine with numpy 1.26.x. Co-Authored-By: Claude Opus 4.8 --- docker/alphafold2.dockerfile | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/docker/alphafold2.dockerfile b/docker/alphafold2.dockerfile index 6d484dc5..a8ccc85c 100644 --- a/docker/alphafold2.dockerfile +++ b/docker/alphafold2.dockerfile @@ -64,6 +64,14 @@ RUN pip3 install --upgrade pip --no-cache-dir \ && pip3 install --upgrade --no-cache-dir \ "jax[cuda12]"==0.5.3 +# `pip install --upgrade "jax[cuda12]"==0.5.3` above drags in numpy 2.x, which makes +# AlphaFold multimer fail at runtime: alphafold/data/msa_pairing.py does +# `np.sum(x for x in feats)` and numpy>=2 raises +# `TypeError: Calling np.sum(generator) is deprecated`. The numpy<2 pins earlier in +# this file (conda + pyproject) run BEFORE the jax upgrade so they do not stick; re-pin +# numpy<2 as the LAST dependency step. jax 0.5.3 runs fine with numpy 1.26.x. +RUN pip install --no-cache-dir "numpy<2" + # Strip Python caches to reduce layer size RUN find /opt/conda -type d -name "__pycache__" -prune -exec rm -rf {} + \ && find /opt/conda -type f -name "*.pyc" -delete \