Skip to content

Commit 907e359

Browse files
Merge branch 'HSM-1163'
2 parents 1f2b9ba + a7b0224 commit 907e359

2 files changed

Lines changed: 45 additions & 147 deletions

File tree

packages/wasm-mps/src/lib.rs

Lines changed: 30 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -253,20 +253,15 @@ mod mps {
253253
/// Process round 1 of DSG protocol.
254254
/// round1_messages: Public messages from other parties.
255255
/// state: Private state result from round 0.
256-
pub fn dsg_round1_process(
257-
round1_messages: &[Vec<u8>; 2],
258-
state: &[u8],
259-
) -> Result<MsgState, MpsError> {
256+
pub fn dsg_round1_process(round1_message: &[u8], state: &[u8]) -> Result<MsgState, MpsError> {
260257
// Parse state
261258
let state: DsgStateR1 =
262259
bincode::deserialize(state).map_err(|_| MpsError::DeserializationError)?;
263260

264261
// Parse messages
265-
let i0_msg1: SignMsg1 = bincode::deserialize(round1_messages[0].as_slice())
266-
.map_err(|_| MpsError::DeserializationError)?;
267-
let i1_msg1: SignMsg1 = bincode::deserialize(round1_messages[1].as_slice())
268-
.map_err(|_| MpsError::DeserializationError)?;
269-
let msgs = vec![i0_msg1, i1_msg1, state.msg];
262+
let i0_msg1: SignMsg1 =
263+
bincode::deserialize(round1_message).map_err(|_| MpsError::DeserializationError)?;
264+
let msgs = vec![i0_msg1, state.msg];
270265

271266
// Process all round1 messages together
272267
let (p2, msg2) = state
@@ -289,20 +284,15 @@ mod mps {
289284
/// Process round 2 of DSG protocol.
290285
/// round2_messages: Public messages from other parties.
291286
/// state: Private state result from round 1.
292-
pub fn dsg_round2_process(
293-
round2_messages: &[Vec<u8>; 2],
294-
state: &[u8],
295-
) -> Result<MsgState, MpsError> {
287+
pub fn dsg_round2_process(round2_message: &[u8], state: &[u8]) -> Result<MsgState, MpsError> {
296288
// Parse state
297289
let state: DsgStateR2 =
298290
bincode::deserialize(state).map_err(|_| MpsError::DeserializationError)?;
299291

300292
// Parse messages
301-
let i0_msg2: SignMsg2<EdwardsPoint> = bincode::deserialize(round2_messages[0].as_slice())
302-
.map_err(|_| MpsError::DeserializationError)?;
303-
let i1_msg2: SignMsg2<EdwardsPoint> = bincode::deserialize(round2_messages[1].as_slice())
304-
.map_err(|_| MpsError::DeserializationError)?;
305-
let msgs = vec![i0_msg2, i1_msg2, state.msg];
293+
let i0_msg2: SignMsg2<EdwardsPoint> =
294+
bincode::deserialize(round2_message).map_err(|_| MpsError::DeserializationError)?;
295+
let msgs = vec![i0_msg2, state.msg];
306296

307297
// Process all round2 messages together
308298
let party = state
@@ -328,20 +318,15 @@ mod mps {
328318
/// Process round 3 of DSG protocol.
329319
/// round3_messages: Public messages from other parties.
330320
/// state: Private state result from round 2.
331-
pub fn dsg_round3_process(
332-
round3_messages: &[Vec<u8>; 2],
333-
state: &[u8],
334-
) -> Result<Vec<u8>, MpsError> {
321+
pub fn dsg_round3_process(round3_message: &[u8], state: &[u8]) -> Result<Vec<u8>, MpsError> {
335322
// Parse state
336323
let state: DsgStateR3 =
337324
bincode::deserialize(state).map_err(|_| MpsError::DeserializationError)?;
338325

339326
// Parse messages
340-
let i0_msg3: SignMsg3<EdwardsPoint> = bincode::deserialize(round3_messages[0].as_slice())
341-
.map_err(|_| MpsError::DeserializationError)?;
342-
let i1_msg3: SignMsg3<EdwardsPoint> = bincode::deserialize(round3_messages[1].as_slice())
343-
.map_err(|_| MpsError::DeserializationError)?;
344-
let msgs = vec![i0_msg3, i1_msg3, state.msg];
327+
let i0_msg3: SignMsg3<EdwardsPoint> =
328+
bincode::deserialize(round3_message).map_err(|_| MpsError::DeserializationError)?;
329+
let msgs = vec![i0_msg3, state.msg];
345330

346331
// Process all round2 messages together
347332
let (signature, _) = state
@@ -512,11 +497,6 @@ mod tests {
512497
dkg_p0_1.state.as_slice(),
513498
)
514499
.unwrap();
515-
let dkg_p1_share = mps::dkg_round2_process(
516-
&[dkg_p0_1.msg.clone(), dkg_p2_1.msg.clone()],
517-
dkg_p1_1.state.as_slice(),
518-
)
519-
.unwrap();
520500
let dkg_p2_share = mps::dkg_round2_process(
521501
&[dkg_p0_1.msg.clone(), dkg_p1_1.msg.clone()],
522502
dkg_p2_1.state.as_slice(),
@@ -529,70 +509,31 @@ mod tests {
529509
// Process DSG round 0
530510
let dsg_p0_0 =
531511
mps::dsg_round0_process(dkg_p0_share.share.as_slice(), "m".to_string(), msg).unwrap();
532-
let dsg_p1_0 =
533-
mps::dsg_round0_process(dkg_p1_share.share.as_slice(), "m".to_string(), msg).unwrap();
534512
let dsg_p2_0 =
535513
mps::dsg_round0_process(dkg_p2_share.share.as_slice(), "m".to_string(), msg).unwrap();
536514

537515
// Process DSG round 1
538-
let dsg_p0_1 = mps::dsg_round1_process(
539-
&[dsg_p1_0.msg.clone(), dsg_p2_0.msg.clone()],
540-
dsg_p0_0.state.as_slice(),
541-
)
542-
.unwrap();
543-
let dsg_p1_1 = mps::dsg_round1_process(
544-
&[dsg_p0_0.msg.clone(), dsg_p2_0.msg.clone()],
545-
dsg_p1_0.state.as_slice(),
546-
)
547-
.unwrap();
548-
let dsg_p2_1 = mps::dsg_round1_process(
549-
&[dsg_p0_0.msg.clone(), dsg_p1_0.msg.clone()],
550-
dsg_p2_0.state.as_slice(),
551-
)
552-
.unwrap();
516+
let dsg_p0_1 =
517+
mps::dsg_round1_process(dsg_p2_0.msg.as_slice(), dsg_p0_0.state.as_slice()).unwrap();
518+
let dsg_p2_1 =
519+
mps::dsg_round1_process(dsg_p0_0.msg.as_slice(), dsg_p2_0.state.as_slice()).unwrap();
553520

554521
// Process DSG round 2
555-
let dsg_p0_2 = mps::dsg_round2_process(
556-
&[dsg_p1_1.msg.clone(), dsg_p2_1.msg.clone()],
557-
dsg_p0_1.state.as_slice(),
558-
)
559-
.unwrap();
560-
let dsg_p1_2 = mps::dsg_round2_process(
561-
&[dsg_p0_1.msg.clone(), dsg_p2_1.msg.clone()],
562-
dsg_p1_1.state.as_slice(),
563-
)
564-
.unwrap();
565-
let dsg_p2_2 = mps::dsg_round2_process(
566-
&[dsg_p0_1.msg.clone(), dsg_p1_1.msg.clone()],
567-
dsg_p2_1.state.as_slice(),
568-
)
569-
.unwrap();
522+
let dsg_p0_2 =
523+
mps::dsg_round2_process(dsg_p2_1.msg.as_slice(), dsg_p0_1.state.as_slice()).unwrap();
524+
let dsg_p2_2 =
525+
mps::dsg_round2_process(dsg_p0_1.msg.as_slice(), dsg_p2_1.state.as_slice()).unwrap();
570526

571527
// Process DSG round 3
572-
let dsg_p0_sig = mps::dsg_round3_process(
573-
&[dsg_p1_2.msg.clone(), dsg_p2_2.msg.clone()],
574-
dsg_p0_2.state.as_slice(),
575-
)
576-
.unwrap();
577-
let dsg_p1_sig = mps::dsg_round3_process(
578-
&[dsg_p0_2.msg.clone(), dsg_p2_2.msg.clone()],
579-
dsg_p1_2.state.as_slice(),
580-
)
581-
.unwrap();
582-
let dsg_p2_sig = mps::dsg_round3_process(
583-
&[dsg_p0_2.msg.clone(), dsg_p1_2.msg.clone()],
584-
dsg_p2_2.state.as_slice(),
585-
)
586-
.unwrap();
528+
let dsg_p0_sig =
529+
mps::dsg_round3_process(dsg_p2_2.msg.as_slice(), dsg_p0_2.state.as_slice()).unwrap();
530+
let dsg_p2_sig =
531+
mps::dsg_round3_process(dsg_p0_2.msg.as_slice(), dsg_p2_2.state.as_slice()).unwrap();
587532

588533
assert_eq!(
589534
dsg_p2_sig, dsg_p0_sig,
590535
"Party 0 signature differs from party 2 signature"
591536
);
592-
assert_eq!(
593-
dsg_p2_sig, dsg_p1_sig,
594-
"Party 1 signature differs from party 2 signature"
595-
);
596537

597538
// Verify signature
598539
VerifyingKey::from_bytes(&dkg_p0_share.pk)
@@ -602,13 +543,6 @@ mod tests {
602543
&Signature::from_bytes(dsg_p0_sig.as_slice().try_into().unwrap()),
603544
)
604545
.unwrap();
605-
VerifyingKey::from_bytes(&dkg_p1_share.pk)
606-
.unwrap()
607-
.verify(
608-
msg,
609-
&Signature::from_bytes(dsg_p1_sig.as_slice().try_into().unwrap()),
610-
)
611-
.unwrap();
612546
VerifyingKey::from_bytes(&dkg_p2_share.pk)
613547
.unwrap()
614548
.verify(
@@ -760,15 +694,8 @@ pub fn dsg_round0_process(
760694
}
761695

762696
#[wasm_bindgen]
763-
pub fn dsg_round1_process(round1_messages: Array, state: &[u8]) -> Result<MsgState, String> {
764-
let result = mps::dsg_round1_process(
765-
&[
766-
js_sys::Uint8Array::from(round1_messages.get(0)).to_vec(),
767-
js_sys::Uint8Array::from(round1_messages.get(1)).to_vec(),
768-
],
769-
state,
770-
)
771-
.map_err(|e| e.to_string())?;
697+
pub fn dsg_round1_process(round1_message: &[u8], state: &[u8]) -> Result<MsgState, String> {
698+
let result = mps::dsg_round1_process(round1_message, state).map_err(|e| e.to_string())?;
772699

773700
Ok(MsgState {
774701
msg: result.msg,
@@ -777,15 +704,8 @@ pub fn dsg_round1_process(round1_messages: Array, state: &[u8]) -> Result<MsgSta
777704
}
778705

779706
#[wasm_bindgen]
780-
pub fn dsg_round2_process(round2_messages: Array, state: &[u8]) -> Result<MsgState, String> {
781-
let result = mps::dsg_round2_process(
782-
&[
783-
js_sys::Uint8Array::from(round2_messages.get(0)).to_vec(),
784-
js_sys::Uint8Array::from(round2_messages.get(1)).to_vec(),
785-
],
786-
state,
787-
)
788-
.map_err(|e| e.to_string())?;
707+
pub fn dsg_round2_process(round2_message: &[u8], state: &[u8]) -> Result<MsgState, String> {
708+
let result = mps::dsg_round2_process(round2_message, state).map_err(|e| e.to_string())?;
789709

790710
Ok(MsgState {
791711
msg: result.msg,
@@ -794,15 +714,8 @@ pub fn dsg_round2_process(round2_messages: Array, state: &[u8]) -> Result<MsgSta
794714
}
795715

796716
#[wasm_bindgen]
797-
pub fn dsg_round3_process(round2_messages: Array, state: &[u8]) -> Result<Vec<u8>, String> {
798-
let result = mps::dsg_round3_process(
799-
&[
800-
js_sys::Uint8Array::from(round2_messages.get(0)).to_vec(),
801-
js_sys::Uint8Array::from(round2_messages.get(1)).to_vec(),
802-
],
803-
state,
804-
)
805-
.map_err(|e| e.to_string())?;
717+
pub fn dsg_round3_process(round2_message: &[u8], state: &[u8]) -> Result<Vec<u8>, String> {
718+
let result = mps::dsg_round3_process(round2_message, state).map_err(|e| e.to_string())?;
806719

807720
Ok(result.to_vec())
808721
}

packages/wasm-mps/test/mps.ts

Lines changed: 15 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,7 @@ describe("mps", function () {
193193
});
194194

195195
describe("dsg", function () {
196+
const otherIndex = [1, 0];
196197
let shares: Array<mps.Share>;
197198

198199
before("performs dkg", function () {
@@ -223,67 +224,51 @@ describe("mps", function () {
223224
);
224225

225226
it("performs round 0", function () {
226-
for (let i = 0; i < 3; i++) {
227+
for (const i of [0, 2]) {
227228
mps.dsg_round0_process(shares[i].share, "m", message);
228229
}
229230
});
230231

231232
let results1: Array<mps.MsgState>;
232233

233234
before("performs round 0", function () {
234-
results1 = [0, 1, 2].map((i) => mps.dsg_round0_process(shares[i].share, "m", message));
235+
results1 = [0, 2].map((i) => mps.dsg_round0_process(shares[i].share, "m", message));
235236
});
236237

237238
it("performs round 1", function () {
238-
for (let i = 0; i < 3; i++) {
239-
mps.dsg_round1_process(
240-
otherIndices[i].map((i) => results1[i].msg),
241-
results1[i].state,
242-
);
239+
for (let i = 0; i < 2; i++) {
240+
mps.dsg_round1_process(results1[otherIndex[i]].msg, results1[i].state);
243241
}
244242
});
245243

246244
let results2: Array<mps.MsgState>;
247245

248246
before("performs round 1", function () {
249-
results2 = [0, 1, 2].map((i) =>
250-
mps.dsg_round1_process(
251-
otherIndices[i].map((i) => results1[i].msg),
252-
results1[i].state,
253-
),
247+
results2 = [0, 1].map((i) =>
248+
mps.dsg_round1_process(results1[otherIndex[i]].msg, results1[i].state),
254249
);
255250
});
256251

257252
it("performs round 2", function () {
258-
for (let i = 0; i < 3; i++) {
259-
mps.dsg_round2_process(
260-
otherIndices[i].map((i) => results2[i].msg),
261-
results2[i].state,
262-
);
253+
for (let i = 0; i < 2; i++) {
254+
mps.dsg_round2_process(results2[otherIndex[i]].msg, results2[i].state);
263255
}
264256
});
265257

266258
let results3: Array<mps.MsgState>;
267259

268260
before("performs round 2", function () {
269-
results3 = [0, 1, 2].map((i) =>
270-
mps.dsg_round2_process(
271-
otherIndices[i].map((i) => results2[i].msg),
272-
results2[i].state,
273-
),
261+
results3 = [0, 1].map((i) =>
262+
mps.dsg_round2_process(results2[otherIndex[i]].msg, results2[i].state),
274263
);
275264
});
276265

277266
it("performs round 3", function () {
278-
const signatures = [0, 1, 2].map((i) =>
279-
mps.dsg_round3_process(
280-
otherIndices[i].map((i) => results3[i].msg),
281-
results3[i].state,
282-
),
267+
const signatures = [0, 1].map((i) =>
268+
mps.dsg_round3_process(results3[otherIndex[i]].msg, results3[i].state),
283269
);
284-
for (let i = 0; i < 3; i++) {
285-
assert(sodium.crypto_sign_verify_detached(signatures[i], message, shares[i].pk));
286-
}
270+
assert(sodium.crypto_sign_verify_detached(signatures[0], message, shares[0].pk));
271+
assert(sodium.crypto_sign_verify_detached(signatures[1], message, shares[2].pk));
287272
});
288273
});
289274
});

0 commit comments

Comments
 (0)