
Commit 7976dc3

Use two-level dedup in named data store (#18351)
Replace the single SHA-256 hash with a two-level approach:

1. Fast fingerprint (length + first 32 bytes) for cheap rejection.
2. SHA-256 only when the fingerprint matches, to confirm without full byte comparison.

For a 35B MoE model with ~29 GB of named data where most buffers are unique, the fingerprint rejects non-matches instantly. SHA-256 is only computed on the rare fingerprint match, avoiding the ~98 s cost of hashing everything upfront. Fingerprint collisions are handled by storing a list of candidate buffer indices per fingerprint, so no dedup opportunities are lost.

Test plan:
- All 12 tests pass in test_named_data_store.py.
- Added test_fingerprint_collision: same fingerprint, different content produces separate buffers.
- Added test_fingerprint_collision_with_dedup: after a collision, a true duplicate of an earlier blob still dedupes correctly.
1 parent: 4f80b77 · commit: 7976dc3

2 files changed: 73 additions & 28 deletions
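
The scheme is easiest to see in isolation. Below is a minimal Python sketch of the two-level dedup described above, distilled from the diff that follows; the names (dedup_add, fingerprint_to_idx, sha_cache) are illustrative only, not the module's API. As a sanity check on the quoted numbers: hashing ~29 GB at a typical few hundred MB/s of single-threaded SHA-256 throughput is indeed on the order of the ~98 s mentioned.

import hashlib
from typing import Dict, List, Tuple

# Illustrative names; the real implementation lives in
# exir/_serialize/_named_data_store.py (see the diff below).
buffers: List[bytes] = []
fingerprint_to_idx: Dict[Tuple[int, bytes], List[int]] = {}
sha_cache: Dict[int, bytes] = {}  # SHA-256 per buffer index, filled lazily


def dedup_add(data: bytes) -> int:
    """Return the index of an identical existing buffer, else append data."""
    # Level 1: (length, first 32 bytes) is O(1) and rejects most non-matches.
    fp = (len(data), data[:32])
    candidates = fingerprint_to_idx.get(fp)
    if candidates is not None:
        # Level 2: hash only on a fingerprint match, to confirm equality
        # without a full byte-by-byte comparison.
        sha = hashlib.sha256(data).digest()
        for idx in candidates:
            if idx not in sha_cache:
                sha_cache[idx] = hashlib.sha256(buffers[idx]).digest()
            if sha == sha_cache[idx]:
                return idx
    # New content: record it under its fingerprint so later duplicates
    # (and later colliding blobs) can find it.
    idx = len(buffers)
    buffers.append(data)
    fingerprint_to_idx.setdefault(fp, []).append(idx)
    return idx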


exir/_serialize/_named_data_store.py

Lines changed: 36 additions & 28 deletions
@@ -8,7 +8,7 @@
 
 import hashlib
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Tuple, Union
 
 import torch
 from executorch.exir._serialize.data_serializer import DataEntry
@@ -75,13 +75,11 @@ class NamedDataStore:
     # Map of {filename: {key: DataEntry}}.
     external_data: Dict[str, Dict[str, DataEntry]]
 
-    # Cache of the data hash for deduplication.
-    # Use a hash instead of the data as a key because a sha256 collision is
-    # unlikely, and the data may be large.
-    data_hash_to_buffer_idx: Dict[bytes, int]
-    # Cache of the key to buffer idx to ensure uniqueness.
-    # If a key is added multiple times, check the buffer idx to ensure that the
-    # data is identical too.
+    # Fast fingerprint for dedup: (length, first 32 bytes) -> buffer indices.
+    fingerprint_to_buffer_idx: Dict[Tuple[int, bytes], List[int]]
+    # SHA-256 digest per buffer index, computed lazily on first dedup check.
+    buffer_sha256: Dict[int, bytes]
+    # Cache of key to buffer idx to detect duplicate key registration.
     key_to_buffer_idx: Dict[str, int]
 
     def __init__(self) -> None:
@@ -91,10 +89,17 @@ def __init__(self) -> None:
         self.buffers = []
         self.pte_data = {}
         self.external_data = {}
-
-        self.data_hash_to_buffer_idx = {}
+        self.fingerprint_to_buffer_idx = {}
+        self.buffer_sha256 = {}
         self.key_to_buffer_idx = {}
 
+    def _get_buffer_sha256(self, buffer_idx: int) -> bytes:
+        sha = self.buffer_sha256.get(buffer_idx)
+        if sha is None:
+            sha = hashlib.sha256(self.buffers[buffer_idx]).digest()
+            self.buffer_sha256[buffer_idx] = sha
+        return sha
+
     def _add_named_data_to_map(
         self,
         key: str,
@@ -119,31 +124,34 @@ def _add_named_data_to_map(
             ValueError: when the key exists in the store, and corresponding data
                 is different.
         """
-        # Get data hash.
-        hashed = hashlib.sha256(data).digest()
-
         # Check if the key exists.
         buffer_idx = self.key_to_buffer_idx.get(key, -1)
-        # If the key exists, the corresponding data must be identical.
-        if (
-            buffer_idx != -1
-            and self.data_hash_to_buffer_idx.get(hashed, -1) != buffer_idx
-        ):
-            raise ValueError(
-                f"Duplicate key {key} with different data. "
-                f"Existing data size: {len(self.buffers[buffer_idx])} bytes. "
-                f"New data size: {len(data)} bytes."
-            )
+        if buffer_idx != -1:
+            if data != self.buffers[buffer_idx]:
+                raise ValueError(
+                    f"Duplicate key {key} with different data. "
+                    f"Existing data size: {len(self.buffers[buffer_idx])} bytes. "
+                    f"New data size: {len(data)} bytes."
+                )
         else:
-            # Key doesn't exist; check if the data exists.
-            buffer_idx = self.data_hash_to_buffer_idx.get(hashed, -1)
+            # Two-level dedup: cheap fingerprint rejects non-matches fast,
+            # SHA-256 confirms matches without full byte comparison.
+            fingerprint = (len(data), data[:32])
+            candidates = self.fingerprint_to_buffer_idx.get(fingerprint)
+            if candidates is not None:
+                new_sha = hashlib.sha256(data).digest()
+                for candidate in candidates:
+                    if new_sha == self._get_buffer_sha256(candidate):
+                        buffer_idx = candidate
+                        break
+
         if buffer_idx == -1:
-            # The data doesn't exist; add it to the data store.
            buffer_idx = len(self.buffers)
            self.buffers.append(data)
-            self.data_hash_to_buffer_idx[hashed] = buffer_idx
+            self.fingerprint_to_buffer_idx.setdefault(fingerprint, []).append(
+                buffer_idx
+            )
 
-        # Add key to the map and the key cache.
         local_key_to_buffer_idx[key] = DataEntry(
             buffer_index=buffer_idx,
             alignment=alignment,
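
For reference, a hedged usage sketch of the store after this change. The import path is inferred from the file header above, and the two trailing None arguments mirror the alignment/tag arguments used by the tests below:

from executorch.exir._serialize._named_data_store import NamedDataStore

store = NamedDataStore()
blob = b"\x00" * 1024

# Same bytes under two keys: the fingerprints match, SHA-256 confirms,
# and both keys end up pointing at a single shared buffer.
store.add_named_data("weight_a", blob, None, None)
store.add_named_data("weight_b", blob, None, None)

output = store.get_named_data_store_output()
assert len(output.buffers) == 1
assert output.pte_data["weight_a"].buffer_index == 0
assert output.pte_data["weight_b"].buffer_index == 0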

exir/_serialize/test/test_named_data_store.py

Lines changed: 37 additions & 0 deletions
@@ -210,3 +210,40 @@ def test_merge_duplicate_error(self) -> None:
         # Merge store2 into store1 raises error as key1 is already in store1
         # with different data.
         self.assertRaises(ValueError, store1.merge_named_data_store, output2)
+
+    def test_fingerprint_collision(self) -> None:
+        """Two blobs with same length and first 32 bytes but different content
+        must not be deduped."""
+        store = NamedDataStore()
+        prefix = b"A" * 32
+        data1 = prefix + b"X" * 100
+        data2 = prefix + b"Y" * 100
+        self.assertEqual(len(data1), len(data2))
+
+        store.add_named_data("key1", data1, None, None)
+        store.add_named_data("key2", data2, None, None)
+
+        output = store.get_named_data_store_output()
+        self.assertEqual(len(output.buffers), 2)
+        self.assertEqual(output.buffers[0], data1)
+        self.assertEqual(output.buffers[1], data2)
+        self.assertEqual(output.pte_data["key1"].buffer_index, 0)
+        self.assertEqual(output.pte_data["key2"].buffer_index, 1)
+
+    def test_fingerprint_collision_with_dedup(self) -> None:
+        """After a fingerprint collision, a true duplicate of the first blob
+        must still be deduped correctly."""
+        store = NamedDataStore()
+        prefix = b"A" * 32
+        data1 = prefix + b"X" * 100
+        data2 = prefix + b"Y" * 100
+
+        store.add_named_data("key1", data1, None, None)
+        store.add_named_data("key2", data2, None, None)
+        store.add_named_data("key3", data1, None, None)  # duplicate of key1
+
+        output = store.get_named_data_store_output()
+        self.assertEqual(len(output.buffers), 2)
+        self.assertEqual(output.pte_data["key1"].buffer_index, 0)
+        self.assertEqual(output.pte_data["key2"].buffer_index, 1)
+        self.assertEqual(output.pte_data["key3"].buffer_index, 0)
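
Since the test class appears to be a standard unittest.TestCase, the new tests can presumably be run directly with the unittest module from the repository root; the project's own runner (e.g. pytest or buck) may differ:

python -m unittest exir._serialize.test.test_named_data_store -v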
