Skip to content

Commit 8d94e4e

Browse files
committed
Remove back-and-forth conversions
1 parent e9d53aa commit 8d94e4e

1 file changed

Lines changed: 31 additions & 44 deletions

File tree

common/shamir.py

Lines changed: 31 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@
6161

6262
import hmac
6363
import secrets
64-
from typing import List, NamedTuple, Sequence, Tuple
64+
from typing import List, Sequence, Tuple
6565

6666

6767
# The length of the digest of the shared secret in bytes.
@@ -80,17 +80,6 @@
8080
DIGEST_INDEX = 254
8181

8282

83-
class RawShare(NamedTuple):
84-
"""
85-
A raw Shamir share.
86-
"""
87-
88-
# TODO: Use a similar structure in the rest of my code
89-
90-
x: int
91-
data: bytes
92-
93-
9483
# Source of random bytes. Can be overridden for deterministic testing.
9584
RANDOM_BYTES = secrets.token_bytes
9685

@@ -113,31 +102,31 @@ def _precompute_exp_log() -> Tuple[List[int], List[int]]:
113102
EXP_TABLE, LOG_TABLE = _precompute_exp_log()
114103

115104

116-
def _interpolate(shares: Sequence[RawShare], x: int) -> bytes:
105+
def _interpolate(shares: Sequence[Tuple[int, bytes]], x: int) -> bytes:
117106
"""
118107
Returns f(x) given the Shamir shares (x_1, f(x_1)), ... , (x_k, f(x_k)).
119108
"""
120-
x_coordinates = set(share.x for share in shares)
109+
x_coordinates = set(share[0] for share in shares)
121110
if len(x_coordinates) != len(shares):
122111
raise ValueError("Invalid set of shares. Share indices must be unique.")
123-
share_value_lengths = set(len(share.data) for share in shares)
112+
share_value_lengths = set(len(share[1]) for share in shares)
124113
if len(share_value_lengths) != 1:
125114
raise ValueError(
126115
"Invalid set of shares. All share values must have the same length."
127116
)
128117
if x in x_coordinates:
129118
for share in shares:
130-
if share.x == x:
131-
return share.data
119+
if share[0] == x:
120+
return share[1]
132121
# Logarithm of the product of (x_i - x) for i = 1, ... , k.
133-
log_prod = sum(LOG_TABLE[share.x ^ x] for share in shares)
122+
log_prod = sum(LOG_TABLE[share[0] ^ x] for share in shares)
134123
result = bytes(share_value_lengths.pop())
135124
for share in shares:
136125
# The logarithm of the Lagrange basis polynomial evaluated at x.
137126
log_basis_eval = (
138127
log_prod
139-
- LOG_TABLE[share.x ^ x]
140-
- sum(LOG_TABLE[share.x ^ other.x] for other in shares)
128+
- LOG_TABLE[share[0] ^ x]
129+
- sum(LOG_TABLE[share[0] ^ other[0]] for other in shares)
141130
) % 255
142131
result = bytes(
143132
intermediate_sum
@@ -146,7 +135,7 @@ def _interpolate(shares: Sequence[RawShare], x: int) -> bytes:
146135
if share_val != 0
147136
else 0
148137
)
149-
for share_val, intermediate_sum in zip(share.data, result)
138+
for share_val, intermediate_sum in zip(share[1], result)
150139
)
151140
return result
152141

@@ -179,23 +168,20 @@ def split_binary_secret_into_shares(
179168
f"The requested number of shares must not exceed {MAX_SHARE_COUNT}."
180169
)
181170
if min_nr_shares == 1:
182-
# If the min_nr_shares is 1, then the digest of the shared secret is not used.
183-
raw_shares = [RawShare(i, secret) for i in range(nr_shares)]
184-
else:
185-
random_share_count = min_nr_shares - 2
186-
raw_shares = [
187-
RawShare(i, RANDOM_BYTES(len(secret))) for i in range(random_share_count)
188-
]
189-
random_part = RANDOM_BYTES(len(secret) - DIGEST_LENGTH_BYTES)
190-
digest = _create_digest(random_part, secret)
191-
base_shares = raw_shares + [
192-
RawShare(DIGEST_INDEX, digest + random_part),
193-
RawShare(SECRET_INDEX, secret),
194-
]
195-
for i in range(random_share_count, nr_shares):
196-
raw_shares.append(RawShare(i, _interpolate(base_shares, i)))
197-
# TODO: Remove this back-and-forth conversion between our tuple and RawShare
198-
return [(share.x, share.data) for share in raw_shares]
171+
# If the min_nr_shares is 1, then the digest of the secret is not used.
172+
return [(i, secret) for i in range(nr_shares)]
173+
random_share_count = min_nr_shares - 2
174+
shares = [(i, RANDOM_BYTES(len(secret))) for i in range(random_share_count)]
175+
digest_random_bytes = RANDOM_BYTES(len(secret) - DIGEST_LENGTH_BYTES)
176+
digest = _create_digest(digest_random_bytes, secret)
177+
digest_share_data = digest + digest_random_bytes
178+
interpolation_shares = shares + [
179+
(DIGEST_INDEX, digest_share_data),
180+
(SECRET_INDEX, secret),
181+
]
182+
for i in range(random_share_count, nr_shares):
183+
shares.append((i, _interpolate(interpolation_shares, i)))
184+
return shares
199185

200186

201187
def reconstruct_binary_secret_from_shares(
@@ -204,13 +190,14 @@ def reconstruct_binary_secret_from_shares(
204190
"""
205191
Reconstruct a binary secret from shares.
206192
"""
207-
raw_shares = [RawShare(x, data) for (x, data) in shares]
208193
if min_nr_shares == 1:
209-
return next(iter(raw_shares)).data
210-
secret = _interpolate(raw_shares, SECRET_INDEX)
211-
digest_share = _interpolate(raw_shares, DIGEST_INDEX)
194+
first_share = next(iter(shares))
195+
first_share_data = first_share[1]
196+
return first_share_data
197+
secret = _interpolate(shares, SECRET_INDEX)
198+
digest_share = _interpolate(shares, DIGEST_INDEX)
212199
digest = digest_share[:DIGEST_LENGTH_BYTES]
213-
random_part = digest_share[DIGEST_LENGTH_BYTES:]
214-
if digest != _create_digest(random_part, secret):
200+
digest_random_bytes = digest_share[DIGEST_LENGTH_BYTES:]
201+
if digest != _create_digest(digest_random_bytes, secret):
215202
raise ValueError("Invalid digest of the shared secret.")
216203
return secret

0 commit comments

Comments
 (0)