From 3b820abe3ebfe0f4f1800fa9bee64a070fd6e0e0 Mon Sep 17 00:00:00 2001 From: Thomas Coratger Date: Mon, 6 Apr 2026 22:18:30 +0200 Subject: [PATCH] perf: cache Poseidon1 engines in PoseidonXmss to avoid redundant rebuilds compress() and sponge() were creating a new Poseidon1 engine on every call, rebuilding the circulant MDS matrix and converting round constants to numpy arrays each time. Cache engines as PrivateAttr so they are built once per PoseidonXmss instance and reused across all calls. Measured 2-2.5x speedup on pure Poseidon/XMSS tests. Co-Authored-By: Claude Opus 4.6 (1M context) --- src/lean_spec/subspecs/xmss/poseidon.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/src/lean_spec/subspecs/xmss/poseidon.py b/src/lean_spec/subspecs/xmss/poseidon.py index 790314b0..3f1ef7a5 100644 --- a/src/lean_spec/subspecs/xmss/poseidon.py +++ b/src/lean_spec/subspecs/xmss/poseidon.py @@ -22,7 +22,7 @@ from __future__ import annotations -from pydantic import model_validator +from pydantic import PrivateAttr, model_validator from lean_spec.types import StrictBaseModel @@ -46,12 +46,25 @@ class PoseidonXmss(StrictBaseModel): params24: Poseidon1Params """Poseidon1 parameters for 24-width permutation.""" + _engine16: Poseidon1 | None = PrivateAttr(default=None) + _engine24: Poseidon1 | None = PrivateAttr(default=None) + @model_validator(mode="after") def _validate_strict_types(self) -> PoseidonXmss: """Reject subclasses to prevent type confusion attacks.""" enforce_strict_types(self, params16=Poseidon1Params, params24=Poseidon1Params) return self + def _get_engine(self, width: int) -> Poseidon1: + """Return a cached Poseidon1 engine for the given width.""" + if width == 16: + if self._engine16 is None: + self._engine16 = Poseidon1(self.params16) + return self._engine16 + if self._engine24 is None: + self._engine24 = Poseidon1(self.params24) + return self._engine24 + def compress(self, input_vec: list[Fp], width: int, output_len: int) -> list[Fp]: """ Implements the Poseidon1 hash in **compression mode**. @@ -85,8 +98,7 @@ def compress(self, input_vec: list[Fp], width: int, output_len: int) -> list[Fp] # Select the correct permutation parameters based on the state width. if width not in (16, 24): raise ValueError(f"Width must be 16 or 24, got {width}") - params = self.params16 if width == 16 else self.params24 - engine = Poseidon1(params) + engine = self._get_engine(width) # Create a padded input by extending with zeros to match the state width. padded_input = list(input_vec) + [Fp(value=0)] * (width - len(input_vec)) @@ -175,7 +187,7 @@ def sponge( # Determine the permutation parameters and the size of the rate. if width not in (16, 24): raise ValueError(f"Width must be 16 or 24, got {width}") - params = self.params16 if width == 16 else self.params24 + engine = self._get_engine(width) rate = width - len(capacity_value) # Pad the input vector with zeros to be an exact multiple of the rate size. @@ -189,9 +201,6 @@ def sponge( state = [Fp(value=0)] * width state[:cap_len] = capacity_value - # Create the engine once for efficiency. - engine = Poseidon1(params) - # Absorb the input in rate-sized chunks via replacement. for i in range(0, len(padded_input), rate): chunk = padded_input[i : i + rate]