Skip to content

Commit f600d47

Browse files
authored
PAIR: Add Shuffle (#75)
1 parent 2955e3c commit f600d47

2 files changed

Lines changed: 55 additions & 0 deletions

File tree

pkg/pair/pair.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"crypto/sha512"
66
"errors"
77
"hash"
8+
mrandv2 "math/rand/v2"
89

910
"github.com/gtank/ristretto255"
1011
)
@@ -119,3 +120,13 @@ func (pk *PrivateKey) Decrypt(ciphertext []byte) ([]byte, error) {
119120

120121
return cipher.MarshalText()
121122
}
123+
124+
// Shuffle shuffles the data in place by using the Fisher-Yates algorithm.
125+
// Note that ideally, it should be called with less than 2^32-1 (4 billion) elements.
126+
func Shuffle(data [][]byte) {
127+
// NOTE: since go 1.20, math.Rand seeds the global random number generator.
128+
// V2 uses ChaCha8 generator as the global one.
129+
mrandv2.Shuffle(len(data), func(i, j int) {
130+
data[i], data[j] = data[j], data[i]
131+
})
132+
}

pkg/pair/pair_test.go

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
package pair
22

33
import (
4+
"bytes"
45
"crypto/rand"
56
"crypto/sha512"
7+
"slices"
68
"strings"
79
"testing"
810

@@ -59,3 +61,45 @@ func TestPAIR(t *testing.T) {
5961
t.Fatalf("want: %s, got: %s", string(ciphertext), string(decrypted))
6062
}
6163
}
64+
65+
func genData(n int) [][]byte {
66+
data := make([][]byte, n)
67+
for i := 0; i < n; i++ {
68+
// marshaled ristretto255.Scalar is 44 bytes
69+
data[i] = make([]byte, 44)
70+
rand.Read(data[i])
71+
}
72+
return data
73+
}
74+
75+
func TestShuffle(t *testing.T) {
76+
data := genData(1 << 10) // 1k
77+
orig := make([][]byte, len(data))
78+
copy(orig, data)
79+
80+
// shuffle the data in place
81+
Shuffle(data)
82+
83+
once := make([][]byte, len(data))
84+
copy(once, data)
85+
86+
if slices.EqualFunc(data, orig, bytes.Equal) {
87+
t.Fatalf("data not shuffled")
88+
}
89+
90+
// shuffle again
91+
Shuffle(data)
92+
93+
if slices.EqualFunc(data, once, bytes.Equal) {
94+
t.Fatalf("data not shuffled")
95+
}
96+
}
97+
98+
func BenchmarkShuffleOneMillionIDs(b *testing.B) {
99+
data := genData(1 << 20) // 1m
100+
b.ResetTimer()
101+
102+
for i := 0; i < b.N; i++ {
103+
Shuffle(data)
104+
}
105+
}

0 commit comments

Comments
 (0)