6161
6262import hmac
6363import secrets
64- from typing import List , NamedTuple , Sequence , Tuple
64+ from typing import List , Sequence , Tuple
6565
6666
6767# The length of the digest of the shared secret in bytes.
8080DIGEST_INDEX = 254
8181
8282
83- class RawShare (NamedTuple ):
84- """
85- A raw Shamir share.
86- """
87-
88- # TODO: Use a similar structure in the rest of my code
89-
90- x : int
91- data : bytes
92-
93-
9483# Source of random bytes. Can be overridden for deterministic testing.
9584RANDOM_BYTES = secrets .token_bytes
9685
@@ -113,31 +102,31 @@ def _precompute_exp_log() -> Tuple[List[int], List[int]]:
113102EXP_TABLE , LOG_TABLE = _precompute_exp_log ()
114103
115104
116- def _interpolate (shares : Sequence [RawShare ], x : int ) -> bytes :
105+ def _interpolate (shares : Sequence [Tuple [ int , bytes ] ], x : int ) -> bytes :
117106 """
118107 Returns f(x) given the Shamir shares (x_1, f(x_1)), ... , (x_k, f(x_k)).
119108 """
120- x_coordinates = set (share . x for share in shares )
109+ x_coordinates = set (share [ 0 ] for share in shares )
121110 if len (x_coordinates ) != len (shares ):
122111 raise ValueError ("Invalid set of shares. Share indices must be unique." )
123- share_value_lengths = set (len (share . data ) for share in shares )
112+ share_value_lengths = set (len (share [ 1 ] ) for share in shares )
124113 if len (share_value_lengths ) != 1 :
125114 raise ValueError (
126115 "Invalid set of shares. All share values must have the same length."
127116 )
128117 if x in x_coordinates :
129118 for share in shares :
130- if share . x == x :
131- return share . data
119+ if share [ 0 ] == x :
120+ return share [ 1 ]
132121 # Logarithm of the product of (x_i - x) for i = 1, ... , k.
133- log_prod = sum (LOG_TABLE [share . x ^ x ] for share in shares )
122+ log_prod = sum (LOG_TABLE [share [ 0 ] ^ x ] for share in shares )
134123 result = bytes (share_value_lengths .pop ())
135124 for share in shares :
136125 # The logarithm of the Lagrange basis polynomial evaluated at x.
137126 log_basis_eval = (
138127 log_prod
139- - LOG_TABLE [share . x ^ x ]
140- - sum (LOG_TABLE [share . x ^ other . x ] for other in shares )
128+ - LOG_TABLE [share [ 0 ] ^ x ]
129+ - sum (LOG_TABLE [share [ 0 ] ^ other [ 0 ] ] for other in shares )
141130 ) % 255
142131 result = bytes (
143132 intermediate_sum
@@ -146,7 +135,7 @@ def _interpolate(shares: Sequence[RawShare], x: int) -> bytes:
146135 if share_val != 0
147136 else 0
148137 )
149- for share_val , intermediate_sum in zip (share . data , result )
138+ for share_val , intermediate_sum in zip (share [ 1 ] , result )
150139 )
151140 return result
152141
@@ -179,23 +168,20 @@ def split_binary_secret_into_shares(
179168 f"The requested number of shares must not exceed { MAX_SHARE_COUNT } ."
180169 )
181170 if min_nr_shares == 1 :
182- # If the min_nr_shares is 1, then the digest of the shared secret is not used.
183- raw_shares = [RawShare (i , secret ) for i in range (nr_shares )]
184- else :
185- random_share_count = min_nr_shares - 2
186- raw_shares = [
187- RawShare (i , RANDOM_BYTES (len (secret ))) for i in range (random_share_count )
188- ]
189- random_part = RANDOM_BYTES (len (secret ) - DIGEST_LENGTH_BYTES )
190- digest = _create_digest (random_part , secret )
191- base_shares = raw_shares + [
192- RawShare (DIGEST_INDEX , digest + random_part ),
193- RawShare (SECRET_INDEX , secret ),
194- ]
195- for i in range (random_share_count , nr_shares ):
196- raw_shares .append (RawShare (i , _interpolate (base_shares , i )))
197- # TODO: Remove this back-and-forth conversion between our tuple and RawShare
198- return [(share .x , share .data ) for share in raw_shares ]
171+ # If the min_nr_shares is 1, then the digest of the secret is not used.
172+ return [(i , secret ) for i in range (nr_shares )]
173+ random_share_count = min_nr_shares - 2
174+ shares = [(i , RANDOM_BYTES (len (secret ))) for i in range (random_share_count )]
175+ digest_random_bytes = RANDOM_BYTES (len (secret ) - DIGEST_LENGTH_BYTES )
176+ digest = _create_digest (digest_random_bytes , secret )
177+ digest_share_data = digest + digest_random_bytes
178+ interpolation_shares = shares + [
179+ (DIGEST_INDEX , digest_share_data ),
180+ (SECRET_INDEX , secret ),
181+ ]
182+ for i in range (random_share_count , nr_shares ):
183+ shares .append ((i , _interpolate (interpolation_shares , i )))
184+ return shares
199185
200186
201187def reconstruct_binary_secret_from_shares (
@@ -204,13 +190,14 @@ def reconstruct_binary_secret_from_shares(
204190 """
205191 Reconstruct a binary secret from shares.
206192 """
207- raw_shares = [RawShare (x , data ) for (x , data ) in shares ]
208193 if min_nr_shares == 1 :
209- return next (iter (raw_shares )).data
210- secret = _interpolate (raw_shares , SECRET_INDEX )
211- digest_share = _interpolate (raw_shares , DIGEST_INDEX )
194+ first_share = next (iter (shares ))
195+ first_share_data = first_share [1 ]
196+ return first_share_data
197+ secret = _interpolate (shares , SECRET_INDEX )
198+ digest_share = _interpolate (shares , DIGEST_INDEX )
212199 digest = digest_share [:DIGEST_LENGTH_BYTES ]
213- random_part = digest_share [DIGEST_LENGTH_BYTES :]
214- if digest != _create_digest (random_part , secret ):
200+ digest_random_bytes = digest_share [DIGEST_LENGTH_BYTES :]
201+ if digest != _create_digest (digest_random_bytes , secret ):
215202 raise ValueError ("Invalid digest of the shared secret." )
216203 return secret
0 commit comments