22# http://www-graphics.stanford.edu/~seander/bithacks.html#NextBitPermutation
33
44from bitarray import bitarray
5- from bitarray .util import zeros , ones , ba2int , int2ba
6-
7- from math import comb
5+ from bitarray .util import zeros , ba2int , int2ba
86
97
108def all_perm (n , k , endian = None ):
119 """all_perm(n, k, endian=None) -> iterator
1210
13- Return an iterator over all bitarrays of length `n` with `k` bits set to 1
11+ Return an iterator over all bitarrays of length `n` with `k` bits set to one
1412in lexicographical order.
1513"""
1614 if n < 0 :
@@ -21,14 +19,14 @@ def all_perm(n, k, endian=None):
2119 if k == 0 :
2220 yield zeros (n , endian )
2321 return
24- if k == n :
25- yield ones (n , endian )
26- return
2722 raise ValueError ("k must be in range 0 <= k <= n, got %s" % k )
2823
2924 v = (1 << k ) - 1
30- for _ in range (comb (n , k )):
31- yield int2ba (v , length = n , endian = endian )
25+ while True :
26+ try :
27+ yield int2ba (v , length = n , endian = endian )
28+ except OverflowError :
29+ return
3230 t = (v | (v - 1 )) + 1
3331 v = t | ((((t & - t ) // (v & - v )) >> 1 ) - 1 )
3432
@@ -45,18 +43,21 @@ def next_perm(a):
4543 if v == 0 :
4644 return a
4745 t = (v | (v - 1 )) + 1
48- w = t | ((((t & - t ) // (v & - v )) >> 1 ) - 1 )
46+ v = t | ((((t & - t ) // (v & - v )) >> 1 ) - 1 )
4947 try :
50- return int2ba (w , length = len (a ), endian = a .endian )
48+ return int2ba (v , length = len (a ), endian = a .endian )
5149 except OverflowError :
5250 return a [::- 1 ]
5351
5452# ---------------------------------------------------------------------------
5553
5654import unittest
55+ from math import comb
56+ from random import choice , getrandbits , randrange
57+ from itertools import pairwise
5758
5859from bitarray import frozenbitarray
59- from bitarray .util import urandom
60+ from bitarray .util import random_k
6061
6162
6263class PermTests (unittest .TestCase ):
@@ -69,77 +70,65 @@ def test_explicit_1(self):
6970 self .assertEqual (a .count (), 3 )
7071 self .assertEqual (a , bitarray (s , 'big' ))
7172
72- def test_explicit_2 (self ):
73- for seq in (['0' ], ['1' ], ['00' ], ['11' ], ['01' , '10' ],
74- ['001' , '010' , '100' ], ['011' , '101' , '110' ],
75- ['0011' , '0101' , '0110' , '1001' , '1010' , '1100' ]):
76- a = bitarray (seq [0 ], 'big' )
77- for i in range (20 ):
78- self .assertEqual (a , bitarray (seq [i % len (seq )]))
79- a = next_perm (a )
73+ def test_zeros_ones (self ):
74+ for n in range (1 , 30 ):
75+ endian = choice (["little" , "big" ])
76+ v = getrandbits (1 )
8077
81- def test_all_same (self ):
82- for endian in 'little' , 'big' :
83- for n in range (1 , 30 ):
84- for v in 0 , 1 :
85- a = bitarray (n , endian )
86- a .setall (v )
87- self .assertEqual (next_perm (a ), a )
78+ lst = list (all_perm (n , v * n , endian ))
79+ self .assertEqual (len (lst ), 1 )
80+ a = lst [0 ]
81+ c = a .copy ()
82+ self .assertEqual (a .endian , endian )
83+ self .assertEqual (len (a ), n )
84+ if v :
85+ self .assertTrue (a .all ())
86+ else :
87+ self .assertFalse (a .any ())
88+ self .assertEqual (next_perm (a ), a )
89+ self .assertEqual (a , c )
8890
8991 def test_turnover (self ):
9092 for a in [bitarray ('11111110000' , 'big' ),
9193 bitarray ('0000001111111' , 'little' )]:
9294 self .assertEqual (next_perm (a ), a [::- 1 ])
9395
94- def test_large (self ):
95- a = bitarray ('10010101010100100110010101110100111100101111' , 'big' )
96- b = next_perm (a )
97- c = bitarray ('10010101010100100110010101110100111100110111' )
98- self .assertEqual (b , c )
96+ def test_next_perm_random (self ):
97+ for _ in range (100 ):
98+ n = randrange (2 , 1_000_000 )
99+ k = randrange (1 , n )
100+ a = random_k (n , k , endian = choice (["little" , "big" ]))
101+ b = next_perm (a )
102+ self .assertEqual (len (b ), n )
103+ self .assertEqual (b .count (), k )
104+ self .assertEqual (b .endian , a .endian )
105+ self .assertNotEqual (a , b )
106+ if ba2int (a ) > ba2int (b ):
107+ c = a .copy ()
108+ c .sort (c .endian == 'big' )
109+ self .assertEqual (a , c )
110+ self .assertEqual (b , a [::- 1 ])
99111
100112 def test_errors (self ):
101113 self .assertRaises (ValueError , next_perm , bitarray ())
102114 self .assertRaises (TypeError , next_perm , '1' )
103115
104- def check_all_perm (self , s ):
105- s1 = s .count (1 )
106- n = 0
107- a = bitarray (s )
116+ def check_perm_cycle (self , start ):
117+ n , k = len (start ), start .count ()
118+ a = bitarray (start )
108119 coll = set ()
109- while 1 :
120+ c = 0
121+ while True :
110122 a = next_perm (a )
111123 coll .add (frozenbitarray (a ))
112- self .assertEqual (len (a ), len ( s ) )
113- self .assertEqual (a .count (), s1 )
114- self .assertEqual (a .endian , s .endian )
115- n += 1
116- if a == s :
124+ self .assertEqual (len (a ), n )
125+ self .assertEqual (a .count (), k )
126+ self .assertEqual (a .endian , start .endian )
127+ c += 1
128+ if a == start :
117129 break
118- self .assertEqual (n , comb (len (s ), s1 ))
119- self .assertEqual (len (coll ), n )
120-
121- def check_order (self , a ):
122- i = - 1
123- for _ in range (comb (len (a ), a .count ())):
124- i , j = ba2int (a ), i
125- self .assertTrue (i > j )
126- a = next_perm (a )
127-
128- def test_few (self ):
129- for s in '0' , '1' , '00' , '01' , '111' , '0011' , '01011' , '000000011' :
130- for endian in 'little' , 'big' :
131- a = bitarray (s , endian )
132- self .check_all_perm (a )
133- a .sort (a .endian == 'little' )
134- self .check_order (a )
135-
136- def test_random (self ):
137- for endian in "little" , "big" :
138- for n in range (1 , 10 ):
139- a = urandom (n , endian )
140- self .check_all_perm (a )
141- a .sort (a .endian == 'little' )
142- self .check_order (a )
130+ self .assertEqual (c , comb (n , k ))
131+ self .assertEqual (len (coll ), c )
143132
144133 def test_all_perm_explicit (self ):
145134 for n , k , res in [
@@ -153,24 +142,56 @@ def test_all_perm_explicit(self):
153142 (3 , 1 , ['001' , '010' , '100' ]),
154143 (3 , 2 , ['011' , '101' , '110' ]),
155144 (3 , 3 , ['111' ]),
156- ]:
157- self .assertEqual (list (all_perm (n , k , 'big' )),
158- [bitarray (s ) for s in res ])
145+ (4 , 2 , ['0011' , '0101' , '0110' , '1001' , '1010' , '1100' ]),
146+ ]:
147+ lst = list (all_perm (n , k , 'big' ))
148+ self .assertEqual (len (lst ), comb (n , k ))
149+ self .assertEqual (lst , [bitarray (s ) for s in res ])
150+ if n == 0 :
151+ continue
152+ a = lst [0 ]
153+ for i in range (20 ):
154+ self .assertEqual (a , bitarray (res [i % len (lst )]))
155+ a = next_perm (a )
159156
160- def test_all_perm_1 (self ):
161- n , k = 10 , 5
162- c = 0
163- s = set ()
164- for a in all_perm (n , k , 'little' ):
157+ def test_all_perm (self ):
158+ n , k = 17 , 5
159+ endian = choice (["little" , "big" ])
160+
161+ prev = None
162+ cnt = 0
163+ coll = set ()
164+ for a in all_perm (n , k , endian ):
165165 self .assertEqual (type (a ), bitarray )
166166 self .assertEqual (len (a ), n )
167167 self .assertEqual (a .count (), k )
168- s .add (frozenbitarray (a ))
169- c += 1
170- self .assertEqual (c , comb (n , k ))
171- self .assertEqual (len (s ), comb (n , k ))
168+ self .assertEqual (a .endian , endian )
169+ coll .add (frozenbitarray (a ))
170+ if prev is None :
171+ first = a .copy ()
172+ c = a .copy ()
173+ c .sort (c .endian == "little" )
174+ self .assertEqual (a , c )
175+ else :
176+ self .assertEqual (next_perm (prev ), a )
177+ self .assertTrue (ba2int (prev ) < ba2int (a ))
178+ prev = a
179+ cnt += 1
180+
181+ self .assertEqual (cnt , comb (n , k ))
182+ self .assertEqual (len (coll ), cnt )
183+
184+ # a is now the last permutation
185+ last = a .copy ()
186+ self .assertTrue (ba2int (first ) < ba2int (last ))
187+ self .assertEqual (last , first [::- 1 ])
188+
189+ def test_all_perm_order (self ):
190+ n , k = 10 , 5
191+ for a , b in pairwise (all_perm (n , k , 'little' )):
192+ self .assertTrue (ba2int (b ) > ba2int (a ))
193+ self .assertEqual (next_perm (a ), b )
172194
173- # ---------------------------------------------------------------------------
174195
175196if __name__ == '__main__' :
176197 unittest .main ()
0 commit comments