diff --git a/src/lean_spec/subspecs/xmss/poseidon.py b/src/lean_spec/subspecs/xmss/poseidon.py index 790314b0..3f1ef7a5 100644 --- a/src/lean_spec/subspecs/xmss/poseidon.py +++ b/src/lean_spec/subspecs/xmss/poseidon.py @@ -22,7 +22,7 @@ from __future__ import annotations -from pydantic import model_validator +from pydantic import PrivateAttr, model_validator from lean_spec.types import StrictBaseModel @@ -46,12 +46,25 @@ class PoseidonXmss(StrictBaseModel): params24: Poseidon1Params """Poseidon1 parameters for 24-width permutation.""" + _engine16: Poseidon1 | None = PrivateAttr(default=None) + _engine24: Poseidon1 | None = PrivateAttr(default=None) + @model_validator(mode="after") def _validate_strict_types(self) -> PoseidonXmss: """Reject subclasses to prevent type confusion attacks.""" enforce_strict_types(self, params16=Poseidon1Params, params24=Poseidon1Params) return self + def _get_engine(self, width: int) -> Poseidon1: + """Return a cached Poseidon1 engine for the given width.""" + if width == 16: + if self._engine16 is None: + self._engine16 = Poseidon1(self.params16) + return self._engine16 + if self._engine24 is None: + self._engine24 = Poseidon1(self.params24) + return self._engine24 + def compress(self, input_vec: list[Fp], width: int, output_len: int) -> list[Fp]: """ Implements the Poseidon1 hash in **compression mode**. @@ -85,8 +98,7 @@ def compress(self, input_vec: list[Fp], width: int, output_len: int) -> list[Fp] # Select the correct permutation parameters based on the state width. if width not in (16, 24): raise ValueError(f"Width must be 16 or 24, got {width}") - params = self.params16 if width == 16 else self.params24 - engine = Poseidon1(params) + engine = self._get_engine(width) # Create a padded input by extending with zeros to match the state width. padded_input = list(input_vec) + [Fp(value=0)] * (width - len(input_vec)) @@ -175,7 +187,7 @@ def sponge( # Determine the permutation parameters and the size of the rate. if width not in (16, 24): raise ValueError(f"Width must be 16 or 24, got {width}") - params = self.params16 if width == 16 else self.params24 + engine = self._get_engine(width) rate = width - len(capacity_value) # Pad the input vector with zeros to be an exact multiple of the rate size. @@ -189,9 +201,6 @@ def sponge( state = [Fp(value=0)] * width state[:cap_len] = capacity_value - # Create the engine once for efficiency. - engine = Poseidon1(params) - # Absorb the input in rate-sized chunks via replacement. for i in range(0, len(padded_input), rate): chunk = padded_input[i : i + rate]