Skip to content

Commit a713081

Browse files
committed
update work on lexico example
1 parent a9e5fa4 commit a713081

2 files changed

Lines changed: 102 additions & 81 deletions

File tree

CHANGE_LOG

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
2025-08-XX 3.6.1:
22
-------------------
3-
* fix RecursionError in `util.random_k()`, see #239
43
* add development files for statistical tests in `devel/random/`
54
* optimize `util.sum_indices()`
5+
* fix RecursionError in `util.random_k()`, see #239
66
* add `devel/test_sum_indices.py`
77

88

examples/lexico.py

Lines changed: 101 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,13 @@
22
# http://www-graphics.stanford.edu/~seander/bithacks.html#NextBitPermutation
33

44
from bitarray import bitarray
5-
from bitarray.util import zeros, ones, ba2int, int2ba
6-
7-
from math import comb
5+
from bitarray.util import zeros, ba2int, int2ba
86

97

108
def all_perm(n, k, endian=None):
119
"""all_perm(n, k, endian=None) -> iterator
1210
13-
Return an iterator over all bitarrays of length `n` with `k` bits set to 1
11+
Return an iterator over all bitarrays of length `n` with `k` bits set to one
1412
in lexicographical order.
1513
"""
1614
if n < 0:
@@ -21,14 +19,14 @@ def all_perm(n, k, endian=None):
2119
if k == 0:
2220
yield zeros(n, endian)
2321
return
24-
if k == n:
25-
yield ones(n, endian)
26-
return
2722
raise ValueError("k must be in range 0 <= k <= n, got %s" % k)
2823

2924
v = (1 << k) - 1
30-
for _ in range(comb(n, k)):
31-
yield int2ba(v, length=n, endian=endian)
25+
while True:
26+
try:
27+
yield int2ba(v, length=n, endian=endian)
28+
except OverflowError:
29+
return
3230
t = (v | (v - 1)) + 1
3331
v = t | ((((t & -t) // (v & -v)) >> 1) - 1)
3432

@@ -45,18 +43,21 @@ def next_perm(a):
4543
if v == 0:
4644
return a
4745
t = (v | (v - 1)) + 1
48-
w = t | ((((t & -t) // (v & -v)) >> 1) - 1)
46+
v = t | ((((t & -t) // (v & -v)) >> 1) - 1)
4947
try:
50-
return int2ba(w, length=len(a), endian=a.endian)
48+
return int2ba(v, length=len(a), endian=a.endian)
5149
except OverflowError:
5250
return a[::-1]
5351

5452
# ---------------------------------------------------------------------------
5553

5654
import unittest
55+
from math import comb
56+
from random import choice, getrandbits, randrange
57+
from itertools import pairwise
5758

5859
from bitarray import frozenbitarray
59-
from bitarray.util import urandom
60+
from bitarray.util import random_k
6061

6162

6263
class PermTests(unittest.TestCase):
@@ -69,77 +70,65 @@ def test_explicit_1(self):
6970
self.assertEqual(a.count(), 3)
7071
self.assertEqual(a, bitarray(s, 'big'))
7172

72-
def test_explicit_2(self):
73-
for seq in (['0'], ['1'], ['00'], ['11'], ['01', '10'],
74-
['001', '010', '100'], ['011', '101', '110'],
75-
['0011', '0101', '0110', '1001', '1010', '1100']):
76-
a = bitarray(seq[0], 'big')
77-
for i in range(20):
78-
self.assertEqual(a, bitarray(seq[i % len(seq)]))
79-
a = next_perm(a)
73+
def test_zeros_ones(self):
74+
for n in range(1, 30):
75+
endian = choice(["little", "big"])
76+
v = getrandbits(1)
8077

81-
def test_all_same(self):
82-
for endian in 'little', 'big':
83-
for n in range(1, 30):
84-
for v in 0, 1:
85-
a = bitarray(n, endian)
86-
a.setall(v)
87-
self.assertEqual(next_perm(a), a)
78+
lst = list(all_perm(n, v * n, endian))
79+
self.assertEqual(len(lst), 1)
80+
a = lst[0]
81+
c = a.copy()
82+
self.assertEqual(a.endian, endian)
83+
self.assertEqual(len(a), n)
84+
if v:
85+
self.assertTrue(a.all())
86+
else:
87+
self.assertFalse(a.any())
88+
self.assertEqual(next_perm(a), a)
89+
self.assertEqual(a, c)
8890

8991
def test_turnover(self):
9092
for a in [bitarray('11111110000', 'big'),
9193
bitarray('0000001111111', 'little')]:
9294
self.assertEqual(next_perm(a), a[::-1])
9395

94-
def test_large(self):
95-
a = bitarray('10010101010100100110010101110100111100101111', 'big')
96-
b = next_perm(a)
97-
c = bitarray('10010101010100100110010101110100111100110111')
98-
self.assertEqual(b, c)
96+
def test_next_perm_random(self):
97+
for _ in range(100):
98+
n = randrange(2, 1_000_000)
99+
k = randrange(1, n)
100+
a = random_k(n, k, endian=choice(["little", "big"]))
101+
b = next_perm(a)
102+
self.assertEqual(len(b), n)
103+
self.assertEqual(b.count(), k)
104+
self.assertEqual(b.endian, a.endian)
105+
self.assertNotEqual(a, b)
106+
if ba2int(a) > ba2int(b):
107+
c = a.copy()
108+
c.sort(c.endian == 'big')
109+
self.assertEqual(a, c)
110+
self.assertEqual(b, a[::-1])
99111

100112
def test_errors(self):
101113
self.assertRaises(ValueError, next_perm, bitarray())
102114
self.assertRaises(TypeError, next_perm, '1')
103115

104-
def check_all_perm(self, s):
105-
s1 = s.count(1)
106-
n = 0
107-
a = bitarray(s)
116+
def check_perm_cycle(self, start):
117+
n, k = len(start), start.count()
118+
a = bitarray(start)
108119
coll = set()
109-
while 1:
120+
c = 0
121+
while True:
110122
a = next_perm(a)
111123
coll.add(frozenbitarray(a))
112-
self.assertEqual(len(a), len(s))
113-
self.assertEqual(a.count(), s1)
114-
self.assertEqual(a.endian, s.endian)
115-
n += 1
116-
if a == s:
124+
self.assertEqual(len(a), n)
125+
self.assertEqual(a.count(), k)
126+
self.assertEqual(a.endian, start.endian)
127+
c += 1
128+
if a == start:
117129
break
118-
self.assertEqual(n, comb(len(s), s1))
119-
self.assertEqual(len(coll), n)
120-
121-
def check_order(self, a):
122-
i = -1
123-
for _ in range(comb(len(a), a.count())):
124-
i, j = ba2int(a), i
125-
self.assertTrue(i > j)
126-
a = next_perm(a)
127-
128-
def test_few(self):
129-
for s in '0', '1', '00', '01', '111', '0011', '01011', '000000011':
130-
for endian in 'little', 'big':
131-
a = bitarray(s, endian)
132-
self.check_all_perm(a)
133-
a.sort(a.endian == 'little')
134-
self.check_order(a)
135-
136-
def test_random(self):
137-
for endian in "little", "big":
138-
for n in range(1, 10):
139-
a = urandom(n, endian)
140-
self.check_all_perm(a)
141-
a.sort(a.endian == 'little')
142-
self.check_order(a)
130+
self.assertEqual(c, comb(n, k))
131+
self.assertEqual(len(coll), c)
143132

144133
def test_all_perm_explicit(self):
145134
for n, k, res in [
@@ -153,24 +142,56 @@ def test_all_perm_explicit(self):
153142
(3, 1, ['001', '010', '100']),
154143
(3, 2, ['011', '101', '110']),
155144
(3, 3, ['111']),
156-
]:
157-
self.assertEqual(list(all_perm(n, k, 'big')),
158-
[bitarray(s) for s in res])
145+
(4, 2, ['0011', '0101', '0110', '1001', '1010', '1100']),
146+
]:
147+
lst = list(all_perm(n, k, 'big'))
148+
self.assertEqual(len(lst), comb(n, k))
149+
self.assertEqual(lst, [bitarray(s) for s in res])
150+
if n == 0:
151+
continue
152+
a = lst[0]
153+
for i in range(20):
154+
self.assertEqual(a, bitarray(res[i % len(lst)]))
155+
a = next_perm(a)
159156

160-
def test_all_perm_1(self):
161-
n, k = 10, 5
162-
c = 0
163-
s = set()
164-
for a in all_perm(n, k, 'little'):
157+
def test_all_perm(self):
158+
n, k = 17, 5
159+
endian=choice(["little", "big"])
160+
161+
prev = None
162+
cnt = 0
163+
coll = set()
164+
for a in all_perm(n, k, endian):
165165
self.assertEqual(type(a), bitarray)
166166
self.assertEqual(len(a), n)
167167
self.assertEqual(a.count(), k)
168-
s.add(frozenbitarray(a))
169-
c += 1
170-
self.assertEqual(c, comb(n, k))
171-
self.assertEqual(len(s), comb(n, k))
168+
self.assertEqual(a.endian, endian)
169+
coll.add(frozenbitarray(a))
170+
if prev is None:
171+
first = a.copy()
172+
c = a.copy()
173+
c.sort(c.endian == "little")
174+
self.assertEqual(a, c)
175+
else:
176+
self.assertEqual(next_perm(prev), a)
177+
self.assertTrue(ba2int(prev) < ba2int(a))
178+
prev = a
179+
cnt += 1
180+
181+
self.assertEqual(cnt, comb(n, k))
182+
self.assertEqual(len(coll), cnt)
183+
184+
# a is now the last permutation
185+
last = a.copy()
186+
self.assertTrue(ba2int(first) < ba2int(last))
187+
self.assertEqual(last, first[::-1])
188+
189+
def test_all_perm_order(self):
190+
n, k = 10, 5
191+
for a, b in pairwise(all_perm(n, k, 'little')):
192+
self.assertTrue(ba2int(b) > ba2int(a))
193+
self.assertEqual(next_perm(a), b)
172194

173-
# ---------------------------------------------------------------------------
174195

175196
if __name__ == '__main__':
176197
unittest.main()

0 commit comments

Comments
 (0)