Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 16 additions & 7 deletions src/lean_spec/subspecs/xmss/poseidon.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from __future__ import annotations

from pydantic import model_validator
from pydantic import PrivateAttr, model_validator

from lean_spec.types import StrictBaseModel

Expand All @@ -46,12 +46,25 @@ class PoseidonXmss(StrictBaseModel):
params24: Poseidon1Params
"""Poseidon1 parameters for 24-width permutation."""

_engine16: Poseidon1 | None = PrivateAttr(default=None)
_engine24: Poseidon1 | None = PrivateAttr(default=None)

@model_validator(mode="after")
def _validate_strict_types(self) -> PoseidonXmss:
"""Reject subclasses to prevent type confusion attacks."""
enforce_strict_types(self, params16=Poseidon1Params, params24=Poseidon1Params)
return self

def _get_engine(self, width: int) -> Poseidon1:
"""Return a cached Poseidon1 engine for the given width."""
if width == 16:
if self._engine16 is None:
self._engine16 = Poseidon1(self.params16)
return self._engine16
if self._engine24 is None:
self._engine24 = Poseidon1(self.params24)
return self._engine24

def compress(self, input_vec: list[Fp], width: int, output_len: int) -> list[Fp]:
"""
Implements the Poseidon1 hash in **compression mode**.
Expand Down Expand Up @@ -85,8 +98,7 @@ def compress(self, input_vec: list[Fp], width: int, output_len: int) -> list[Fp]
# Select the correct permutation parameters based on the state width.
if width not in (16, 24):
raise ValueError(f"Width must be 16 or 24, got {width}")
params = self.params16 if width == 16 else self.params24
engine = Poseidon1(params)
engine = self._get_engine(width)

# Create a padded input by extending with zeros to match the state width.
padded_input = list(input_vec) + [Fp(value=0)] * (width - len(input_vec))
Expand Down Expand Up @@ -175,7 +187,7 @@ def sponge(
# Determine the permutation parameters and the size of the rate.
if width not in (16, 24):
raise ValueError(f"Width must be 16 or 24, got {width}")
params = self.params16 if width == 16 else self.params24
engine = self._get_engine(width)
rate = width - len(capacity_value)

# Pad the input vector with zeros to be an exact multiple of the rate size.
Expand All @@ -189,9 +201,6 @@ def sponge(
state = [Fp(value=0)] * width
state[:cap_len] = capacity_value

# Create the engine once for efficiency.
engine = Poseidon1(params)

# Absorb the input in rate-sized chunks via replacement.
for i in range(0, len(padded_input), rate):
chunk = padded_input[i : i + rate]
Expand Down
Loading