Skip to content

Commit e84dbb5

Browse files
committed
Simplify Shamir code (step 1)
1 parent ad888f5 commit e84dbb5

1 file changed

Lines changed: 46 additions & 57 deletions

File tree

common/shamir.py

Lines changed: 46 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -171,84 +171,61 @@ def _interpolate(shares: Sequence[RawShare], x: int) -> bytes:
171171
return result
172172

173173

174-
def _create_digest(random_data: bytes, shared_secret: bytes) -> bytes:
175-
return hmac.new(random_data, shared_secret, "sha256").digest()[:DIGEST_LENGTH_BYTES]
174+
def _create_digest(random_data: bytes, secret: bytes) -> bytes:
175+
return hmac.new(random_data, secret, "sha256").digest()[:DIGEST_LENGTH_BYTES]
176176

177177

178-
def _split_secret(
179-
threshold: int, share_count: int, shared_secret: bytes
180-
) -> List[RawShare]:
181-
if len(shared_secret) < MIN_KEY_LENGTH:
178+
def split_binary_secret_into_shares(
179+
secret: bytes,
180+
nr_shares: int,
181+
min_nr_shares: int,
182+
) -> list[(int, bytes)]:
183+
"""
184+
Split a secret into nr_shares shares. The minimum number of shares required to
185+
reconstruct the secret is min_nr_shares.
186+
"""
187+
if len(secret) < MIN_KEY_LENGTH:
182188
raise ValueError(
183189
f"The shared secret must be at least {MIN_KEY_LENGTH} bytes long."
184190
)
185191

186-
if threshold < 1:
187-
raise ValueError("The requested threshold must be a positive integer.")
192+
if min_nr_shares < 1:
193+
raise ValueError("The requested min_nr_shares must be a positive integer.")
188194

189-
if threshold > share_count:
195+
if min_nr_shares > nr_shares:
190196
raise ValueError(
191-
"The requested threshold must not exceed the number of shares."
197+
"The requested min_nr_shares must not exceed the number of shares."
192198
)
193199

194-
if share_count > MAX_SHARE_COUNT:
200+
if nr_shares > MAX_SHARE_COUNT:
195201
raise ValueError(
196202
f"The requested number of shares must not exceed {MAX_SHARE_COUNT}."
197203
)
198204

199-
# TODO: We won't allow a threshold of 1; we will require at least 2 (or even 3?)
200-
# If the threshold is 1, then the digest of the shared secret is not used.
201-
if threshold == 1:
202-
return [RawShare(i, shared_secret) for i in range(share_count)]
203-
204-
random_share_count = threshold - 2
205+
if min_nr_shares == 1:
206+
# If the min_nr_shares is 1, then the digest of the shared secret is not used.
207+
raw_shares = [RawShare(i, secret) for i in range(nr_shares)]
205208

206-
shares = [
207-
RawShare(i, RANDOM_BYTES(len(shared_secret))) for i in range(random_share_count)
208-
]
209+
else:
209210

210-
random_part = RANDOM_BYTES(len(shared_secret) - DIGEST_LENGTH_BYTES)
211-
digest = _create_digest(random_part, shared_secret)
211+
random_share_count = min_nr_shares - 2
212212

213-
base_shares = shares + [
214-
RawShare(DIGEST_INDEX, digest + random_part),
215-
RawShare(SECRET_INDEX, shared_secret),
216-
]
217-
218-
for i in range(random_share_count, share_count):
219-
shares.append(RawShare(i, _interpolate(base_shares, i)))
220-
221-
return shares
222-
223-
224-
def _recover_secret(threshold: int, shares: Sequence[RawShare]) -> bytes:
225-
# If the threshold is 1, then the digest of the shared secret is not used.
226-
# TODO: Disallow threshold of 1
227-
if threshold == 1:
228-
return next(iter(shares)).data
229-
230-
shared_secret = _interpolate(shares, SECRET_INDEX)
231-
digest_share = _interpolate(shares, DIGEST_INDEX)
232-
digest = digest_share[:DIGEST_LENGTH_BYTES]
233-
random_part = digest_share[DIGEST_LENGTH_BYTES:]
213+
raw_shares = [
214+
RawShare(i, RANDOM_BYTES(len(secret))) for i in range(random_share_count)
215+
]
234216

235-
if digest != _create_digest(random_part, shared_secret):
236-
raise ValueError("Invalid digest of the shared secret.")
217+
random_part = RANDOM_BYTES(len(secret) - DIGEST_LENGTH_BYTES)
218+
digest = _create_digest(random_part, secret)
237219

238-
return shared_secret
220+
base_shares = raw_shares + [
221+
RawShare(DIGEST_INDEX, digest + random_part),
222+
RawShare(SECRET_INDEX, secret),
223+
]
239224

225+
for i in range(random_share_count, nr_shares):
226+
raw_shares.append(RawShare(i, _interpolate(base_shares, i)))
240227

241-
def split_binary_secret_into_shares(
242-
secret: bytes,
243-
nr_shares: int,
244-
min_nr_shares: int,
245-
) -> list[(int, bytes)]:
246-
"""
247-
Split a binary secret into `nr_shares` shares. The minimum number of shares required to
248-
reconstruct the binary is `min_nr_shares`.
249-
"""
250228
# TODO: Remove this back-and-forth conversion between our tuple and RawShare
251-
raw_shares = _split_secret(min_nr_shares, nr_shares, secret)
252229
return [(share.x, share.data) for share in raw_shares]
253230

254231

@@ -259,4 +236,16 @@ def reconstruct_binary_secret_from_shares(
259236
Reconstruct a binary secret from shares.
260237
"""
261238
raw_shares = [RawShare(x, data) for (x, data) in shares]
262-
return _recover_secret(min_nr_shares, raw_shares)
239+
240+
if min_nr_shares == 1:
241+
return next(iter(raw_shares)).data
242+
243+
secret = _interpolate(raw_shares, SECRET_INDEX)
244+
digest_share = _interpolate(raw_shares, DIGEST_INDEX)
245+
digest = digest_share[:DIGEST_LENGTH_BYTES]
246+
random_part = digest_share[DIGEST_LENGTH_BYTES:]
247+
248+
if digest != _create_digest(random_part, secret):
249+
raise ValueError("Invalid digest of the shared secret.")
250+
251+
return secret

0 commit comments

Comments
 (0)