Skip to content

Commit a438c22

Browse files
committed
simplify hint_decompose_bits_xmss
1 parent cca9966 commit a438c22

2 files changed

Lines changed: 12 additions & 15 deletions

File tree

crates/lean_vm/src/isa/hint.rs

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ impl CustomHint {
129129

130130
pub fn n_args(&self) -> usize {
131131
match self {
132-
Self::DecomposeBitsXMSS => 5,
132+
Self::DecomposeBitsXMSS => 4,
133133
Self::DecomposeBitsMerkleWhir => 3,
134134
Self::DecomposeBits => 4,
135135
Self::LessThan => 3,
@@ -145,13 +145,11 @@ impl CustomHint {
145145
match self {
146146
Self::DecomposeBitsXMSS => {
147147
let decomposed_ptr = args[0].read_value(ctx.memory, ctx.fp)?.to_usize();
148-
let remaining_ptr = args[1].read_value(ctx.memory, ctx.fp)?.to_usize();
149-
let to_decompose_ptr = args[2].read_value(ctx.memory, ctx.fp)?.to_usize();
150-
let num_to_decompose = args[3].read_value(ctx.memory, ctx.fp)?.to_usize();
151-
let chunk_size = args[4].read_value(ctx.memory, ctx.fp)?.to_usize();
148+
let to_decompose_ptr = args[1].read_value(ctx.memory, ctx.fp)?.to_usize();
149+
let num_to_decompose = args[2].read_value(ctx.memory, ctx.fp)?.to_usize();
150+
let chunk_size = args[3].read_value(ctx.memory, ctx.fp)?.to_usize();
152151
assert!(24_usize.is_multiple_of(chunk_size));
153152
let mut memory_index_decomposed = decomposed_ptr;
154-
let mut memory_index_remaining = remaining_ptr;
155153
#[allow(clippy::explicit_counter_loop)]
156154
for i in 0..num_to_decompose {
157155
let value = ctx.memory.get(to_decompose_ptr + i)?.to_usize();
@@ -160,8 +158,6 @@ impl CustomHint {
160158
ctx.memory.set(memory_index_decomposed, value)?;
161159
memory_index_decomposed += 1;
162160
}
163-
ctx.memory.set(memory_index_remaining, F::from_usize(value >> 24))?;
164-
memory_index_remaining += 1;
165161
}
166162
}
167163
Self::DecomposeBitsMerkleWhir => {

crates/rec_aggregation/xmss_aggregate.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,22 +39,23 @@ def xmss_verify(merkle_root, message, slot_lo, slot_hi, merkle_chunks):
3939
poseidon16_compress(b_input, b_input + DIGEST_LEN, encoding_fe)
4040

4141
encoding = Array(NUM_ENCODING_FE * 24 / (2 * W))
42-
remaining = Array(NUM_ENCODING_FE)
4342

44-
hint_decompose_bits_xmss(encoding, remaining, encoding_fe, NUM_ENCODING_FE, 2 * W)
43+
hint_decompose_bits_xmss(encoding, encoding_fe, NUM_ENCODING_FE, 2 * W)
4544

4645
# check that the decomposition is correct
4746
for i in unroll(0, NUM_ENCODING_FE):
4847
for j in unroll(0, 24 / (2 * W)):
4948
assert encoding[i * (24 / (2 * W)) + j] < CHAIN_LENGTH**2
5049

51-
assert remaining[i] < 2**7 - 1 # ensures uniformity + prevent overflow
52-
53-
partial_sum: Mut = remaining[i] * 2**24
54-
partial_sum += encoding[i * (24 / (2 * W))]
50+
partial_sum: Mut = encoding[i * (24 / (2 * W))]
5551
for j in unroll(1, 24 / (2 * W)):
5652
partial_sum += encoding[i * (24 / (2 * W)) + j] * (CHAIN_LENGTH**2) ** j
57-
assert partial_sum == encoding_fe[i]
53+
54+
# p = 2^31 - 2^24 + 1, so inv(2^24) = -127 (mod p).
55+
# Deduce remaining_i from partial_sum + remaining_i * 2^24 == encoding_fe[i]:
56+
# remaining_i = (encoding_fe[i] - partial_sum) * inv(2^24) = (partial_sum - encoding_fe[i]) * 127
57+
remaining_i = (partial_sum - encoding_fe[i]) * 127
58+
assert remaining_i < 2**7 - 1 # ensures uniformity + prevent overflow
5859

5960
# grinding
6061
debug_assert(V_GRINDING % 2 == 0)

0 commit comments

Comments
 (0)