Skip to content

Commit 2c7871d

Browse files
tcoratgerclaude
andauthored
perf: cache Poseidon1 engines in PoseidonXmss to avoid redundant rebuilds (#511)
compress() and sponge() were creating a new Poseidon1 engine on every call, rebuilding the circulant MDS matrix and converting round constants to numpy arrays each time. Cache engines as PrivateAttr so they are built once per PoseidonXmss instance and reused across all calls. Measured 2-2.5x speedup on pure Poseidon/XMSS tests. Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 1806938 commit 2c7871d

1 file changed

Lines changed: 16 additions & 7 deletions

File tree

src/lean_spec/subspecs/xmss/poseidon.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
from __future__ import annotations
2424

25-
from pydantic import model_validator
25+
from pydantic import PrivateAttr, model_validator
2626

2727
from lean_spec.types import StrictBaseModel
2828

@@ -46,12 +46,25 @@ class PoseidonXmss(StrictBaseModel):
4646
params24: Poseidon1Params
4747
"""Poseidon1 parameters for 24-width permutation."""
4848

49+
_engine16: Poseidon1 | None = PrivateAttr(default=None)
50+
_engine24: Poseidon1 | None = PrivateAttr(default=None)
51+
4952
@model_validator(mode="after")
5053
def _validate_strict_types(self) -> PoseidonXmss:
5154
"""Reject subclasses to prevent type confusion attacks."""
5255
enforce_strict_types(self, params16=Poseidon1Params, params24=Poseidon1Params)
5356
return self
5457

58+
def _get_engine(self, width: int) -> Poseidon1:
59+
"""Return a cached Poseidon1 engine for the given width."""
60+
if width == 16:
61+
if self._engine16 is None:
62+
self._engine16 = Poseidon1(self.params16)
63+
return self._engine16
64+
if self._engine24 is None:
65+
self._engine24 = Poseidon1(self.params24)
66+
return self._engine24
67+
5568
def compress(self, input_vec: list[Fp], width: int, output_len: int) -> list[Fp]:
5669
"""
5770
Implements the Poseidon1 hash in **compression mode**.
@@ -85,8 +98,7 @@ def compress(self, input_vec: list[Fp], width: int, output_len: int) -> list[Fp]
8598
# Select the correct permutation parameters based on the state width.
8699
if width not in (16, 24):
87100
raise ValueError(f"Width must be 16 or 24, got {width}")
88-
params = self.params16 if width == 16 else self.params24
89-
engine = Poseidon1(params)
101+
engine = self._get_engine(width)
90102

91103
# Create a padded input by extending with zeros to match the state width.
92104
padded_input = list(input_vec) + [Fp(value=0)] * (width - len(input_vec))
@@ -175,7 +187,7 @@ def sponge(
175187
# Determine the permutation parameters and the size of the rate.
176188
if width not in (16, 24):
177189
raise ValueError(f"Width must be 16 or 24, got {width}")
178-
params = self.params16 if width == 16 else self.params24
190+
engine = self._get_engine(width)
179191
rate = width - len(capacity_value)
180192

181193
# Pad the input vector with zeros to be an exact multiple of the rate size.
@@ -189,9 +201,6 @@ def sponge(
189201
state = [Fp(value=0)] * width
190202
state[:cap_len] = capacity_value
191203

192-
# Create the engine once for efficiency.
193-
engine = Poseidon1(params)
194-
195204
# Absorb the input in rate-sized chunks via replacement.
196205
for i in range(0, len(padded_input), rate):
197206
chunk = padded_input[i : i + rate]

0 commit comments

Comments
 (0)