From d1741c412ff5b98bc8c833f5e31aebbd61f9b981 Mon Sep 17 00:00:00 2001 From: Tom Date: Fri, 24 Oct 2025 12:06:38 +1100 Subject: [PATCH] Add storage to generic merkle tree - Adds storage trait - Adds default in memory storage - Adds results - closes #95 --- crates/merkle/src/fixed_sparse_merkle.rs | 2 +- crates/merkle/src/generic_sparse_merkle.rs | 342 +++++++++++++++------ 2 files changed, 250 insertions(+), 94 deletions(-) diff --git a/crates/merkle/src/fixed_sparse_merkle.rs b/crates/merkle/src/fixed_sparse_merkle.rs index bb5cb10..0279b3e 100644 --- a/crates/merkle/src/fixed_sparse_merkle.rs +++ b/crates/merkle/src/fixed_sparse_merkle.rs @@ -88,7 +88,7 @@ impl StorageMut for InMemoryStorage { /// /// * `H` - The hasher implementation to use for computing node hashes #[derive(Debug)] -pub struct SparseMerkleTree { +pub struct SparseMerkleTree { /// Underlying storage for nodes and leaves storage: S, /// Default hashes for each level (empty subtree roots) diff --git a/crates/merkle/src/generic_sparse_merkle.rs b/crates/merkle/src/generic_sparse_merkle.rs index aab2922..8e477ec 100644 --- a/crates/merkle/src/generic_sparse_merkle.rs +++ b/crates/merkle/src/generic_sparse_merkle.rs @@ -1,6 +1,6 @@ use alloy::primitives::{Address, B256, keccak256}; use serde::{Deserialize, Serialize}; -use std::{collections::HashMap, marker::PhantomData}; +use std::{collections::HashMap, convert::Infallible, marker::PhantomData}; /// Trait for hash function implementations used in Merkle trees. /// @@ -12,6 +12,54 @@ pub trait Hasher: Default { fn hash(self, data: &[&[u8]]) -> [u8; 32]; } +/// Trait for storage backends used by the sparse Merkle tree. +pub trait Storage { + /// Error type for storage operations. + type Error; + + /// Retrieves a leaf value at the given index. + /// + /// Returns `None` if no value exists at the index. + fn get_leaf(&self, index: &B256) -> Result, Self::Error>; + + /// Retrieves a node hash at the given level and index. + /// + /// Returns `None` if no node exists at the specified position. + fn get_node(&self, level: u8, index: B256) -> Result, Self::Error>; + + /// Checks if a leaf exists at the given index. + fn leaf_index_exists(&self, index: &B256) -> Result; + + /// Returns an iterator over all stored nodes. + /// + /// Each item is a tuple of (level, index, hash). + fn nodes(&self) -> Result + '_, Self::Error>; + + /// Returns an iterator over all stored leaf values. + /// + /// Each item is a tuple of (index, value). + fn leaves(&self) -> Result, Self::Error>; + /// Get number of leaves stored in the tree. + fn num_leaves(&self) -> Result; + /// Get number of nodes stored (for debugging/analysis). + fn num_nodes(&self) -> Result; +} + +/// Trait for storage backends used by the sparse Merkle tree. +pub trait StorageMut: Storage { + /// Inserts a leaf value at the given index. + fn insert_leaf(&mut self, index: B256, value: V) -> Result<(), Self::Error>; + + /// Deletes a leaf value at the given index. + fn delete_leaf(&mut self, index: &B256) -> Result<(), Self::Error>; + + /// Inserts a node hash at the given level and index. + fn insert_node(&mut self, level: u8, index: B256, hash: [u8; 32]) -> Result<(), Self::Error>; + + /// Deletes a node at the given level and index. + fn delete_node(&mut self, level: u8, index: B256) -> Result<(), Self::Error>; +} + /// Keccak256 hasher implementation for Merkle trees. /// /// This hasher uses the Keccak256 algorithm, which is the same hash function @@ -46,19 +94,26 @@ pub trait MerkleValue: Clone + std::fmt::Debug + PartialEq { } } -/// Sparse Merkle Tree with 256-bit keys. -/// Only stores non-default nodes to minimize storage -#[derive(Debug)] -pub struct GenericSparseMerkleTree { +/// In-memory storage backend for the sparse Merkle tree. +#[derive(Debug, Clone)] +pub struct InMemoryStorage { /// Storage for non-default nodes only: (level, index) -> hash /// Level 0 = leaves, Level 255 = root nodes: HashMap<(u8, B256), [u8; 32]>, /// Storage for leaf values (only non-empty values) leaves: HashMap, +} + +/// Sparse Merkle Tree with 256-bit keys. +/// Only stores non-default nodes to minimize storage +#[derive(Debug)] +pub struct GenericSparseMerkleTree> { + storage: S, /// Pre-computed default hashes for each level (256 levels: 0-255) default_hashes: Box<[[u8; 32]; 256]>, /// Phantom data for the hasher type _hasher: PhantomData, + _value: PhantomData, } /// Merkle proof for verifying inclusion or exclusion in the tree. @@ -72,86 +127,182 @@ pub struct MerkleProof { pub root: [u8; 32], } -impl Default for GenericSparseMerkleTree { +impl Storage for InMemoryStorage { + type Error = Infallible; + + fn get_leaf(&self, index: &B256) -> Result, Self::Error> { + Ok(self.leaves.get(index).cloned()) + } + + fn get_node(&self, level: u8, index: B256) -> Result, Self::Error> { + Ok(self.nodes.get(&(level, index)).cloned()) + } + + fn leaf_index_exists(&self, index: &B256) -> Result { + Ok(self.leaves.contains_key(index)) + } + + fn nodes(&self) -> Result + '_, Self::Error> { + Ok(self + .nodes + .iter() + .map(|(&(level, index), &hash)| (level, index, hash))) + } + + fn leaves(&self) -> Result, Self::Error> { + Ok(self + .leaves + .iter() + .map(|(&index, value)| (index, value.clone()))) + } + + /// Get number of values stored in the tree. + fn num_leaves(&self) -> Result { + Ok(self.leaves.len()) + } + + /// Get number of nodes stored (for debugging/analysis). + fn num_nodes(&self) -> Result { + Ok(self.nodes.len()) + } +} + +impl StorageMut for InMemoryStorage { + fn insert_leaf(&mut self, index: B256, value: V) -> Result<(), Self::Error> { + self.leaves.insert(index, value); + Ok(()) + } + + fn delete_leaf(&mut self, index: &B256) -> Result<(), Self::Error> { + self.leaves.remove(index); + Ok(()) + } + + fn insert_node(&mut self, level: u8, index: B256, hash: [u8; 32]) -> Result<(), Self::Error> { + self.nodes.insert((level, index), hash); + Ok(()) + } + + fn delete_node(&mut self, level: u8, index: B256) -> Result<(), Self::Error> { + self.nodes.remove(&(level, index)); + Ok(()) + } +} + +impl Default for GenericSparseMerkleTree { fn default() -> Self { - Self::new() + Self::new(S::default()) } } -impl GenericSparseMerkleTree { - /// Create a new Sparse Merkle Tree with 256 levels (0-255). - pub fn new() -> Self { - let mut default_hashes = Box::new([[0u8; 32]; 256]); +/// Generates default hashes for all 256 levels of the tree. +/// +/// Default hashes represent the hash of an empty subtree at each level. +/// Level 0 is the hash of an empty leaf value, and each subsequent level +/// is the hash of two identical children from the level below. +/// +/// # Type Parameters +/// +/// * `H` - The hasher implementation to use +/// * `V` - The value type that implements [`MerkleValue`] +/// +/// # Returns +/// +/// Returns a boxed array of 256 default hashes, one for each tree level. +pub fn generate_default_hashes() -> Box<[[u8; 32]; 256]> { + let mut default_hashes = Box::new([[0u8; 32]; 256]); + + // Level 0 (leaves) - default is hash of empty value + let empty_value = V::empty(); + default_hashes[0] = hash_leaf::(empty_value.to_bytes().as_ref()); + + // Build default hashes for each level up to the root (level 255) + let mut level = 1; + while level < 256 { + let child_hash = default_hashes[level - 1]; + default_hashes[level] = hash_pair::(&child_hash, &child_hash); + level += 1; + } - // Level 0 (leaves) - default is hash of empty value - let empty_value = V::empty(); - default_hashes[0] = hash_leaf::(empty_value.to_bytes().as_ref()); + default_hashes +} - // Build default hashes for each level up to the root (level 255) - for level in 1..256 { - let child_hash = default_hashes[level - 1]; - default_hashes[level] = hash_pair::(&child_hash, &child_hash); - } +impl GenericSparseMerkleTree { + /// Creates a new sparse Merkle tree with in-memory storage. + pub fn new_in_memory() -> Self { + Self::new(InMemoryStorage::default()) + } +} - GenericSparseMerkleTree { - nodes: HashMap::new(), - leaves: HashMap::new(), - default_hashes, +impl GenericSparseMerkleTree { + /// Create a new Sparse Merkle Tree with 256 levels (0-255). + pub fn new(storage: S) -> Self { + Self { + storage, + default_hashes: generate_default_hashes::(), _hasher: PhantomData, + _value: PhantomData, } } +} +impl GenericSparseMerkleTree +where + S: StorageMut, + H: Hasher, + V: MerkleValue, +{ /// Insert or update a value using user and token pair. - pub fn insert(&mut self, user: Address, token: Address, value: V) { + pub fn insert(&mut self, user: Address, token: Address, value: V) -> Result<(), S::Error> { let index = account_key(user, token); - self.insert_raw(index, value); + self.insert_raw(index, value) } /// Delete a value (sets it to empty). - pub fn delete(&mut self, user: Address, token: Address) { + pub fn delete(&mut self, user: Address, token: Address) -> Result<(), S::Error> { let index = account_key(user, token); - self.delete_raw(index); + self.delete_raw(index) } /// Insert with raw B256 index. - pub fn insert_raw(&mut self, index: B256, value: V) { + pub fn insert_raw(&mut self, index: B256, value: V) -> Result<(), S::Error> { if value.is_empty() { - self.delete_raw(index); - return; + self.delete_raw(index)?; + return Ok(()); } // Store the leaf value - self.leaves.insert(index, value.clone()); + self.storage.insert_leaf(index, value)?; // Update the tree by propagating changes up - self.update_path(index); + self.update_path(index) } /// Delete with raw B256 index. - pub fn delete_raw(&mut self, index: B256) { + pub fn delete_raw(&mut self, index: B256) -> Result<(), S::Error> { // Remove the leaf value - self.leaves.remove(&index); + self.storage.delete_leaf(&index)?; // Update the tree by propagating changes up - self.update_path(index); + self.update_path(index) } /// Update the path from a leaf to the root, storing only non-default nodes. - fn update_path(&mut self, index: B256) { + fn update_path(&mut self, index: B256) -> Result<(), S::Error> { let mut current_index = index; for level in 0..255 { // Determine current node's hash let current_hash = if level == 0 { // Leaf level - if let Some(value) = self.leaves.get(¤t_index) { + if let Some(value) = self.storage.get_leaf(¤t_index)? { hash_leaf::(value.to_bytes().as_ref()) } else { self.default_hashes[0] // Empty leaf } } else { // Internal node - get from storage or compute from children - if let Some(&hash) = self.nodes.get(&(level as u8, current_index)) { + if let Some(hash) = self.storage.get_node(level as u8, current_index)? { hash } else { self.default_hashes[level] @@ -167,16 +318,15 @@ impl GenericSparseMerkleTree { // Get sibling hash let sibling_hash = if level == 0 { // Sibling is a leaf - if let Some(sibling_value) = self.leaves.get(&sibling_index) { + if let Some(sibling_value) = self.storage.get_leaf(&sibling_index)? { hash_leaf::(sibling_value.to_bytes().as_ref()) } else { self.default_hashes[0] } } else { // Sibling is an internal node - self.nodes - .get(&(level as u8, sibling_index)) - .copied() + self.storage + .get_node(level as u8, sibling_index)? .unwrap_or(self.default_hashes[level]) }; @@ -190,50 +340,51 @@ impl GenericSparseMerkleTree { // Only store if different from default (this is the key sparse optimization!) if parent_hash == self.default_hashes[parent_level] { // Remove if it exists (node became default) - self.nodes.remove(&(parent_level as u8, parent_index)); + self.storage.delete_node(parent_level as u8, parent_index)?; } else { // Store non-default node - self.nodes - .insert((parent_level as u8, parent_index), parent_hash); + self.storage + .insert_node(parent_level as u8, parent_index, parent_hash)?; } current_index = parent_index; } + Ok(()) } /// Get the root hash of the tree. - pub fn root(&self) -> B256 { + pub fn root(&self) -> Result { // Root is at level 255 - find any node at this level // In our current implementation, there should be only one - for ((level, _idx), &hash) in &self.nodes { - if *level == 255 { - return B256::from(hash); + for (level, _idx, hash) in self.storage.nodes()? { + if level == 255 { + return Ok(B256::from(hash)); } } // If no root node exists, return the default hash for level 255 - B256::from(self.default_hashes[255]) + Ok(B256::from(self.default_hashes[255])) } /// Get a value by user and token. - pub fn get(&self, user: Address, token: Address) -> V { + pub fn get(&self, user: Address, token: Address) -> Result { let index = account_key(user, token); - self.leaves.get(&index).cloned().unwrap_or_else(V::empty) + Ok(self.storage.get_leaf(&index)?.unwrap_or_else(V::empty)) } /// Get a value by raw index. - pub fn get_by_index(&self, index: B256) -> V { - self.leaves.get(&index).cloned().unwrap_or_else(V::empty) + pub fn get_by_index(&self, index: &B256) -> Result { + Ok(self.storage.get_leaf(index)?.unwrap_or_else(V::empty)) } /// Generate a Merkle proof for a given account. - pub fn generate_proof(&self, user: Address, token: Address) -> MerkleProof { + pub fn generate_proof(&self, user: Address, token: Address) -> Result { let index = account_key(user, token); self.generate_proof_raw(index) } /// Generate proof with raw B256 index. - pub fn generate_proof_raw(&self, index: B256) -> MerkleProof { + pub fn generate_proof_raw(&self, index: B256) -> Result { let mut siblings = Vec::with_capacity(255); let mut current_index = index; @@ -244,16 +395,15 @@ impl GenericSparseMerkleTree { // Get sibling hash let sibling_hash = if level == 0 { // Sibling is a leaf - if let Some(sibling_value) = self.leaves.get(&sibling_index) { + if let Some(sibling_value) = self.storage.get_leaf(&sibling_index)? { hash_leaf::(sibling_value.to_bytes().as_ref()) } else { self.default_hashes[0] } } else { // Sibling is an internal node - self.nodes - .get(&(level as u8, sibling_index)) - .copied() + self.storage + .get_node(level as u8, sibling_index)? .unwrap_or(self.default_hashes[level]) }; @@ -262,23 +412,23 @@ impl GenericSparseMerkleTree { } // Get leaf hash - let leaf_hash = if let Some(value) = self.leaves.get(&index) { + let leaf_hash = if let Some(value) = self.storage.get_leaf(&index)? { hash_leaf::(value.to_bytes().as_ref()) } else { self.default_hashes[0] }; - MerkleProof { + Ok(MerkleProof { leaf: leaf_hash, siblings, - root: *self.root(), - } + root: *self.root()?, + }) } /// Check if a value exists for the given account. - pub fn has_value(&self, user: Address, token: Address) -> bool { + pub fn has_value(&self, user: Address, token: Address) -> Result { let index = account_key(user, token); - self.leaves.contains_key(&index) + self.storage.leaf_index_exists(&index) } /// Get the tree index for an account. @@ -287,21 +437,18 @@ impl GenericSparseMerkleTree { } /// Get all stored values with their indices. - pub fn get_all_values(&self) -> Vec<(B256, V)> { - self.leaves - .iter() - .map(|(&index, value)| (index, value.clone())) - .collect() + pub fn get_all_values(&self) -> Result, S::Error> { + Ok(self.storage.leaves()?.collect()) } /// Get number of values stored in the tree. - pub fn value_count(&self) -> usize { - self.leaves.len() + pub fn value_count(&self) -> Result { + self.storage.num_leaves() } /// Get number of nodes stored (for debugging/analysis). - pub fn node_count(&self) -> usize { - self.nodes.len() + pub fn node_count(&self) -> Result { + self.storage.num_nodes() } } @@ -427,6 +574,15 @@ pub fn verify_proof( computed_hash == proof.root } +impl Default for InMemoryStorage { + fn default() -> Self { + Self { + nodes: Default::default(), + leaves: Default::default(), + } + } +} + #[cfg(test)] mod tests { use super::*; @@ -447,34 +603,34 @@ mod tests { #[test] fn test_basic_operations() { - let mut tree = GenericSparseMerkleTree::::new(); + let mut tree = GenericSparseMerkleTree::::new_in_memory(); let user = Address::from([1u8; 20]); let token = Address::from([2u8; 20]); let value = TestValue(100); // Insert value - tree.insert(user, token, value.clone()); + tree.insert(user, token, value.clone()).unwrap(); // Retrieve value - let retrieved = tree.get(user, token); + let retrieved = tree.get(user, token).unwrap(); assert_eq!(retrieved, value); // Check root is not zero - assert_ne!(tree.root(), B256::ZERO); + assert_ne!(tree.root().unwrap(), B256::ZERO); } #[test] fn test_proof_generation_and_verification() { - let mut tree = GenericSparseMerkleTree::::new(); + let mut tree = GenericSparseMerkleTree::::new_in_memory(); let user = Address::from([1u8; 20]); let token = Address::from([2u8; 20]); let value = TestValue(100); - tree.insert(user, token, value.clone()); + tree.insert(user, token, value.clone()).unwrap(); - let proof = tree.generate_proof(user, token); + let proof = tree.generate_proof(user, token).unwrap(); let index = tree.get_index(user, token); assert!(verify_proof::( @@ -484,46 +640,46 @@ mod tests { #[test] fn test_sparse_storage() { - let mut tree = GenericSparseMerkleTree::::new(); + let mut tree = GenericSparseMerkleTree::::new_in_memory(); // Initially should have no stored nodes - assert_eq!(tree.node_count(), 0); - assert_eq!(tree.value_count(), 0); + assert_eq!(tree.node_count().unwrap(), 0); + assert_eq!(tree.value_count().unwrap(), 0); let user = Address::from([1u8; 20]); let token = Address::from([2u8; 20]); let value = TestValue(100); // Insert one value - tree.insert(user, token, value.clone()); + tree.insert(user, token, value.clone()).unwrap(); // Should store much fewer than 256 nodes (the sparse optimization!) - let node_count = tree.node_count(); + let node_count = tree.node_count().unwrap(); println!("Node count after 1 insert: {}", node_count); assert!( node_count < 256, "Should store much fewer than 256 nodes, got {}", node_count ); - assert_eq!(tree.value_count(), 1); + assert_eq!(tree.value_count().unwrap(), 1); // Verify we can retrieve the value - assert_eq!(tree.get(user, token), value); + assert_eq!(tree.get(user, token).unwrap(), value); // Delete the value - tree.delete(user, token); + tree.delete(user, token).unwrap(); // Should clean up and reduce node count - let node_count_after_delete = tree.node_count(); + let node_count_after_delete = tree.node_count().unwrap(); println!("Node count after delete: {}", node_count_after_delete); assert!( node_count_after_delete <= node_count, "Node count should not increase after delete" ); - assert_eq!(tree.value_count(), 0); + assert_eq!(tree.value_count().unwrap(), 0); // Should return empty value - assert_eq!(tree.get(user, token), TestValue::empty()); + assert_eq!(tree.get(user, token).unwrap(), TestValue::empty()); } #[test]