Skip to content

Commit 0027d58

Browse files
tcoratger and claude authored
refactor: move attestation aggregation logic to natural homes (#512)
Move `State.aggregate()` and `State._select_proofs_greedily()` out of State, which only used `self.validators` as a lookup table.

- `select_greedily` → `AggregatedSignatureProof.select_greedily()` completes the proof lifecycle (select, aggregate, verify)
- Aggregation orchestration → `Store.aggregate()` which owns all three input pools (gossip sigs, new payloads, known payloads)
- Proof compaction → inlined in `State.build_block()` where it is used

No new files or types introduced. Tests updated to go through Store.

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 2c7871d commit 0027d58

4 files changed

Lines changed: 225 additions & 182 deletions

File tree

src/lean_spec/subspecs/containers/state/state.py

Lines changed: 17 additions & 133 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,6 @@
44

55
from collections.abc import Iterable
66
from collections.abc import Set as AbstractSet
7-
from typing import TYPE_CHECKING
87

98
from lean_spec.subspecs.ssz.hash import hash_tree_root
109
from lean_spec.subspecs.xmss.aggregation import AggregatedSignatureProof
@@ -16,13 +15,13 @@
1615
Uint64,
1716
)
1817

19-
from ..attestation import AggregatedAttestation, AggregationBits, AttestationData
18+
from ..attestation import AggregatedAttestation, AttestationData
2019
from ..block import Block, BlockBody, BlockHeader
2120
from ..block.types import AggregatedAttestations
2221
from ..checkpoint import Checkpoint
2322
from ..config import Config
2423
from ..slot import Slot
25-
from ..validator import ValidatorIndex, ValidatorIndices
24+
from ..validator import ValidatorIndex
2625
from .types import (
2726
HistoricalBlockHashes,
2827
JustificationRoots,
@@ -31,9 +30,6 @@
3130
Validators,
3231
)
3332

34-
if TYPE_CHECKING:
35-
from lean_spec.subspecs.forkchoice import AttestationSignatureEntry
36-
3733

3834
class State(Container):
3935
"""The main consensus state object."""
@@ -692,7 +688,7 @@ def build_block(
692688

693689
found_entries = True
694690

695-
selected, _ = self._select_proofs_greedily(proofs)
691+
selected, _ = AggregatedSignatureProof.select_greedily(proofs)
696692
aggregated_signatures.extend(selected)
697693
for proof in selected:
698694
aggregated_attestations.append(
@@ -725,16 +721,23 @@ def build_block(
725721

726722
# Compact: merge all proofs sharing the same AttestationData into one
727723
# using recursive children aggregation.
724+
#
725+
# During the fixed-point loop above, multiple proofs may have been
726+
# selected for the same AttestationData across iterations. Group them
727+
# and merge each group into a single recursive proof.
728728
proof_groups: dict[AttestationData, list[AggregatedSignatureProof]] = {}
729729
for att, sig in zip(aggregated_attestations, aggregated_signatures, strict=True):
730730
proof_groups.setdefault(att.data, []).append(sig)
731731

732-
compacted_attestations: list[AggregatedAttestation] = []
733-
compacted_signatures: list[AggregatedSignatureProof] = []
732+
aggregated_attestations = []
733+
aggregated_signatures = []
734734
for att_data, proofs in proof_groups.items():
735735
if len(proofs) == 1:
736-
compacted_signatures.append(proofs[0])
736+
sig = proofs[0]
737737
else:
738+
# Multiple proofs for the same data were aggregated separately.
739+
# Merge them into one recursive proof using children-only
740+
# aggregation (no new raw signatures).
738741
children = [
739742
(
740743
proof,
@@ -745,24 +748,18 @@ def build_block(
745748
)
746749
for proof in proofs
747750
]
748-
merged = AggregatedSignatureProof.aggregate(
751+
sig = AggregatedSignatureProof.aggregate(
749752
xmss_participants=None,
750753
children=children,
751754
raw_xmss=[],
752755
message=att_data.data_root_bytes(),
753756
slot=att_data.slot,
754757
)
755-
compacted_signatures.append(merged)
756-
compacted_attestations.append(
757-
AggregatedAttestation(
758-
aggregation_bits=compacted_signatures[-1].participants,
759-
data=att_data,
760-
)
758+
aggregated_signatures.append(sig)
759+
aggregated_attestations.append(
760+
AggregatedAttestation(aggregation_bits=sig.participants, data=att_data)
761761
)
762762

763-
aggregated_attestations = compacted_attestations
764-
aggregated_signatures = compacted_signatures
765-
766763
# Create the final block with selected attestations.
767764
final_block = Block(
768765
slot=slot,
@@ -779,116 +776,3 @@ def build_block(
779776
final_block = final_block.model_copy(update={"state_root": hash_tree_root(post_state)})
780777

781778
return final_block, post_state, aggregated_attestations, aggregated_signatures
782-
783-
@staticmethod
784-
def _select_proofs_greedily(
785-
*proof_sets: set[AggregatedSignatureProof] | None,
786-
) -> tuple[list[AggregatedSignatureProof], set[ValidatorIndex]]:
787-
"""
788-
Greedy set-cover selection of signature proofs to maximize validator coverage.
789-
790-
Repeatedly selects the proof covering the most uncovered validators until
791-
no proof adds new coverage. Earlier proof sets are prioritized.
792-
793-
Args:
794-
proof_sets: Candidate proof sets in priority order.
795-
796-
Returns:
797-
Selected proofs and the set of covered validator indices.
798-
"""
799-
selected: list[AggregatedSignatureProof] = []
800-
covered: set[ValidatorIndex] = set()
801-
for proofs in proof_sets:
802-
if not proofs:
803-
continue
804-
remaining = list(proofs)
805-
while remaining:
806-
# Pick the proof that covers the most new validators.
807-
best = max(
808-
remaining,
809-
key=lambda p: len(set(p.participants.to_validator_indices()) - covered),
810-
)
811-
new_coverage = set(best.participants.to_validator_indices()) - covered
812-
# Stop when no proof in this set adds new coverage.
813-
if not new_coverage:
814-
break
815-
selected.append(best)
816-
covered.update(new_coverage)
817-
remaining.remove(best)
818-
return selected, covered
819-
820-
def aggregate(
821-
self,
822-
attestation_signatures: dict[AttestationData, set[AttestationSignatureEntry]] | None = None,
823-
new_payloads: dict[AttestationData, set[AggregatedSignatureProof]] | None = None,
824-
known_payloads: dict[AttestationData, set[AggregatedSignatureProof]] | None = None,
825-
) -> list[tuple[AggregatedAttestation, AggregatedSignatureProof]]:
826-
"""
827-
Aggregate gossip signatures using new payloads, with known payloads as helpers.
828-
829-
Args:
830-
attestation_signatures: Raw XMSS signatures from gossip, keyed by attestation data.
831-
new_payloads: Aggregated proofs pending processing (child proofs).
832-
known_payloads: Known aggregated proofs already accepted.
833-
834-
Returns:
835-
List of (attestation, proof) pairs from aggregation.
836-
"""
837-
gossip_sigs = attestation_signatures or {}
838-
new = new_payloads or {}
839-
known = known_payloads or {}
840-
841-
attestation_keys = new.keys() | gossip_sigs.keys()
842-
if not attestation_keys:
843-
return []
844-
845-
results: list[tuple[AggregatedAttestation, AggregatedSignatureProof]] = []
846-
847-
for data in attestation_keys:
848-
# Phase 1: Greedily select child proofs for maximum validator coverage.
849-
# New payloads are prioritized over known payloads.
850-
child_proofs, covered = self._select_proofs_greedily(new.get(data), known.get(data))
851-
852-
# Phase 2: Collect raw XMSS signatures for validators not yet covered.
853-
# Sorted by validator index for deterministic output.
854-
raw_entries = [
855-
(
856-
e.validator_id,
857-
self.validators[e.validator_id].get_attestation_pubkey(),
858-
e.signature,
859-
)
860-
for e in sorted(gossip_sigs.get(data, set()), key=lambda e: e.validator_id)
861-
if e.validator_id not in covered
862-
]
863-
864-
# Need at least one raw signature, or two child proofs to aggregate.
865-
if not raw_entries and len(child_proofs) < 2:
866-
continue
867-
868-
xmss_participants = AggregationBits.from_validator_indices(
869-
ValidatorIndices(data=[vid for vid, _, _ in raw_entries])
870-
)
871-
raw_xmss = [(pk, sig) for _, pk, sig in raw_entries]
872-
873-
# Phase 3: Build recursive children with their public keys from the registry.
874-
children = [
875-
(
876-
child,
877-
[
878-
self.validators[vid].get_attestation_pubkey()
879-
for vid in child.participants.to_validator_indices()
880-
],
881-
)
882-
for child in child_proofs
883-
]
884-
proof = AggregatedSignatureProof.aggregate(
885-
xmss_participants=xmss_participants,
886-
children=children,
887-
raw_xmss=raw_xmss,
888-
message=data.data_root_bytes(),
889-
slot=data.slot,
890-
)
891-
attestation = AggregatedAttestation(aggregation_bits=proof.participants, data=data)
892-
results.append((attestation, proof))
893-
894-
return results

src/lean_spec/subspecs/forkchoice/store.py

Lines changed: 122 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@
1616
JUSTIFICATION_LOOKBACK_SLOTS,
1717
)
1818
from lean_spec.subspecs.containers import (
19+
AggregationBits,
1920
AttestationData,
2021
Block,
2122
Checkpoint,
@@ -28,6 +29,7 @@
2829
from lean_spec.subspecs.containers.attestation.attestation import SignedAggregatedAttestation
2930
from lean_spec.subspecs.containers.block import BlockLookup
3031
from lean_spec.subspecs.containers.slot import Slot
32+
from lean_spec.subspecs.containers.validator import ValidatorIndices
3133
from lean_spec.subspecs.metrics import registry as metrics
3234
from lean_spec.subspecs.ssz.hash import hash_tree_root
3335
from lean_spec.subspecs.xmss.aggregation import (
@@ -928,32 +930,134 @@ def update_safe_target(self) -> "Store":
928930

929931
def aggregate(self) -> tuple["Store", list[SignedAggregatedAttestation]]:
930932
"""
931-
Aggregate committee signatures and payloads together.
933+
Turn raw validator votes into compact aggregated attestations.
932934
933-
This method aggregates signatures from the attestation_signatures map.
935+
Validators cast individual signatures over gossip. Before those
936+
votes can influence fork choice or be included in a block, they
937+
must be combined into compact cryptographic proofs.
938+
939+
The store holds three pools of attestation evidence:
940+
941+
- **Gossip signatures**: individual validator votes arriving in real-time.
942+
- **New payloads**: aggregated proofs from the current round, not yet
943+
committed to the chain.
944+
- **Known payloads**: previously accepted proofs, reusable as building
945+
blocks for deeper aggregation.
946+
947+
For each unique piece of attestation data the algorithm proceeds in three phases:
948+
949+
1. **Select** — greedily pick existing proofs that maximize
950+
validator coverage (new before known).
951+
2. **Fill** — collect raw gossip signatures for any validators
952+
not yet covered.
953+
3. **Aggregate** — delegate to the XMSS subspec to produce a
954+
single cryptographic proof.
955+
956+
After aggregation the store is updated:
957+
958+
- Consumed gossip signatures are removed.
959+
- Newly produced proofs are recorded for future reuse.
934960
935961
Returns:
936-
Tuple of (new Store with updated payloads, list of new SignedAggregatedAttestation).
962+
Updated store and the list of freshly produced signed attestations.
937963
"""
938-
head_state = self.states[self.head]
939-
940-
aggregated_results = head_state.aggregate(
941-
attestation_signatures=self.attestation_signatures,
942-
new_payloads=self.latest_new_aggregated_payloads,
943-
known_payloads=self.latest_known_aggregated_payloads,
944-
)
964+
validators = self.states[self.head].validators
965+
gossip_sigs = self.attestation_signatures
966+
new = self.latest_new_aggregated_payloads
967+
known = self.latest_known_aggregated_payloads
945968

946969
new_aggregates: list[SignedAggregatedAttestation] = []
970+
971+
# Only attestation data with a new payload or a raw gossip signature
972+
# can trigger aggregation. Known payloads alone cannot — they exist
973+
# only to help extend coverage when combined with fresh evidence.
974+
for data in new.keys() | gossip_sigs.keys():
975+
# Phase 1: Select
976+
#
977+
# Start with the cheapest option: reuse proofs that already
978+
# cover many validators.
979+
#
980+
# Child proofs are aggregated signatures from prior rounds.
981+
# Selecting them first keeps the final proof tree shallow
982+
# and avoids redundant cryptographic work.
983+
#
984+
# New payloads go first because they represent uncommitted
985+
# work — known payloads fill remaining gaps.
986+
child_proofs, covered = AggregatedSignatureProof.select_greedily(
987+
new.get(data), known.get(data)
988+
)
989+
990+
# Phase 2: Fill
991+
#
992+
# For every validator not yet covered by a child proof,
993+
# include its individual gossip signature.
994+
#
995+
# Sorting by validator index guarantees deterministic proof
996+
# construction regardless of network arrival order.
997+
raw_entries = [
998+
(
999+
e.validator_id,
1000+
validators[e.validator_id].get_attestation_pubkey(),
1001+
e.signature,
1002+
)
1003+
for e in sorted(gossip_sigs.get(data, set()), key=lambda e: e.validator_id)
1004+
if e.validator_id not in covered
1005+
]
1006+
1007+
# The XMSS layer enforces a minimum: either at least one raw
1008+
# signature, or at least two child proofs to merge.
1009+
#
1010+
# A lone child proof is already a valid proof — nothing to do.
1011+
if not raw_entries and len(child_proofs) < 2:
1012+
continue
1013+
1014+
# Encode the set of raw signers as a compact bitfield.
1015+
xmss_participants = AggregationBits.from_validator_indices(
1016+
ValidatorIndices(data=[vid for vid, _, _ in raw_entries])
1017+
)
1018+
raw_xmss = [(pk, sig) for _, pk, sig in raw_entries]
1019+
1020+
# Phase 3: Aggregate
1021+
#
1022+
# Build the recursive proof tree.
1023+
#
1024+
# Each child proof needs its participants' public keys so
1025+
# the XMSS prover can verify inner proofs while constructing
1026+
# the outer one.
1027+
children = [
1028+
(
1029+
child,
1030+
[
1031+
validators[vid].get_attestation_pubkey()
1032+
for vid in child.participants.to_validator_indices()
1033+
],
1034+
)
1035+
for child in child_proofs
1036+
]
1037+
1038+
# Hand everything to the XMSS subspec.
1039+
# Out comes a single proof covering all selected validators.
1040+
proof = AggregatedSignatureProof.aggregate(
1041+
xmss_participants=xmss_participants,
1042+
children=children,
1043+
raw_xmss=raw_xmss,
1044+
message=data.data_root_bytes(),
1045+
slot=data.slot,
1046+
)
1047+
new_aggregates.append(SignedAggregatedAttestation(data=data, proof=proof))
1048+
1049+
# ── Store bookkeeping ────────────────────────────────────────
1050+
#
1051+
# Record freshly produced proofs so future rounds can reuse them.
1052+
# Remove gossip signatures that were consumed by this aggregation.
9471053
new_aggregated_payloads: dict[AttestationData, set[AggregatedSignatureProof]] = {}
948-
aggregated_attestation_data: set[AttestationData] = set()
949-
for att, proof in aggregated_results:
950-
aggregated_attestation_data.add(att.data)
951-
new_aggregates.append(SignedAggregatedAttestation(data=att.data, proof=proof))
952-
new_aggregated_payloads.setdefault(att.data, set()).add(proof)
1054+
for signed_att in new_aggregates:
1055+
new_aggregated_payloads.setdefault(signed_att.data, set()).add(signed_att.proof)
1056+
9531057
remaining_attestation_signatures = {
954-
attestation_data: signatures
955-
for attestation_data, signatures in self.attestation_signatures.items()
956-
if attestation_data not in aggregated_attestation_data
1058+
data: sigs
1059+
for data, sigs in self.attestation_signatures.items()
1060+
if data not in new_aggregated_payloads
9571061
}
9581062

9591063
return self.model_copy(

0 commit comments

Comments
 (0)