diff --git a/docker/alphafold2.dockerfile b/docker/alphafold2.dockerfile index 6d484dc5..a8ccc85c 100644 --- a/docker/alphafold2.dockerfile +++ b/docker/alphafold2.dockerfile @@ -64,6 +64,14 @@ RUN pip3 install --upgrade pip --no-cache-dir \ && pip3 install --upgrade --no-cache-dir \ "jax[cuda12]"==0.5.3 +# `pip install --upgrade "jax[cuda12]"==0.5.3` above drags in numpy 2.x, which makes +# AlphaFold multimer fail at runtime: alphafold/data/msa_pairing.py does +# `np.sum(x for x in feats)` and numpy>=2 raises +# `TypeError: Calling np.sum(generator) is deprecated`. The numpy<2 pins earlier in +# this file (conda + pyproject) run BEFORE the jax upgrade so they do not stick; re-pin +# numpy<2 as the LAST dependency step. jax 0.5.3 runs fine with numpy 1.26.x. +RUN pip install --no-cache-dir "numpy<2" + # Strip Python caches to reduce layer size RUN find /opt/conda -type d -name "__pycache__" -prune -exec rm -rf {} + \ && find /opt/conda -type f -name "*.pyc" -delete \