|
7 | 7 | from pathseq._file_num_seq._decimal_range import DecimalRange |
8 | 8 |
|
9 | 9 |
|
| 10 | +def scale_to_integer_range(r1: DecimalRange, r2: DecimalRange) -> tuple[range, range]: |
| 11 | + max_exponent = max( |
| 12 | + ( |
| 13 | + -x.as_tuple().exponent |
| 14 | + for x in (r1.start, r1.stop, r1.step, r2.start, r2.stop, r2.step) |
| 15 | + ), |
| 16 | + default=0, |
| 17 | + ) |
| 18 | + return ( |
| 19 | + range( |
| 20 | + int(r1.start.scaleb(max_exponent)), |
| 21 | + int(r1.stop.scaleb(max_exponent)), |
| 22 | + int(r1.step.scaleb(max_exponent)), |
| 23 | + ), |
| 24 | + range( |
| 25 | + int(r2.start.scaleb(max_exponent)), |
| 26 | + int(r2.stop.scaleb(max_exponent)), |
| 27 | + int(r2.step.scaleb(max_exponent)), |
| 28 | + ), |
| 29 | + ) |
| 30 | + |
| 31 | + |
10 | 32 | @st.composite |
11 | 33 | def infinite_decimals(draw): |
12 | 34 | values = st.one_of( |
@@ -42,22 +64,43 @@ def invalid_ranges(draw): |
42 | 64 | @st.composite |
43 | 65 | def valid_ranges(draw, max_len=10000): |
44 | 66 | places = draw(st.integers(0, 6)) |
45 | | - valid_start = st.decimals(allow_nan=False, allow_infinity=False, places=places) |
| 67 | + valid_start = st.decimals( |
| 68 | + min_value=-max_len, |
| 69 | + max_value=max_len, |
| 70 | + allow_nan=False, |
| 71 | + allow_infinity=False, |
| 72 | + places=places, |
| 73 | + ) |
46 | 74 | start = draw(valid_start) |
47 | 75 |
|
48 | 76 | # Only allow ranges that we can loop over in a sensible amount of time |
49 | 77 | diffs = st.decimals( |
50 | | - max_value=1500, allow_nan=False, allow_infinity=False, places=places |
| 78 | + min_value=-max_len, |
| 79 | + max_value=max_len, |
| 80 | + allow_nan=False, |
| 81 | + allow_infinity=False, |
| 82 | + places=places, |
51 | 83 | ) |
52 | 84 | diff = draw(diffs) |
53 | | - if draw(st.booleans()): |
54 | | - diff = -diff |
55 | 85 | stop = start + diff |
56 | 86 |
|
| 87 | + # Step must be nonzero and such that the number of elements does not exceed max_len. |
| 88 | + abs_diff = abs(diff) |
| 89 | + if abs_diff == 0: |
| 90 | + min_step = decimal.Decimal("1") |
| 91 | + else: |
| 92 | + min_step = abs_diff / decimal.Decimal(max_len) |
57 | 93 | valid_step = st.decimals( |
58 | | - max_value=diff / max_len, allow_nan=False, allow_infinity=False, places=places |
| 94 | + min_value=min_step, |
| 95 | + max_value=abs_diff if abs_diff != 0 else decimal.Decimal("1"), |
| 96 | + allow_nan=False, |
| 97 | + allow_infinity=False, |
| 98 | + places=places, |
59 | 99 | ).filter(lambda x: not x.is_zero()) |
60 | 100 | step = draw(valid_step) |
| 101 | + # Randomly flip sign of step to allow negative steps |
| 102 | + if draw(st.booleans()): |
| 103 | + step = -step |
61 | 104 |
|
62 | 105 | try: |
63 | 106 | len(DecimalRange(start, stop, step)) |
@@ -98,65 +141,93 @@ def test_bool(values): |
98 | 141 | assert not bool(range_) |
99 | 142 |
|
100 | 143 |
|
101 | | -# TODO: This test is too slow |
102 | | -@pytest.mark.skip |
103 | 144 | @given(valid_ranges()) |
104 | 145 | def test_len(values): |
105 | 146 | range_ = DecimalRange(*values) |
106 | 147 | len_ = sum(1 for _ in range_) |
107 | 148 | assert len(range_) == len_ |
108 | 149 |
|
109 | 150 |
|
110 | | -@st.composite |
111 | | -def ranges_with_index(draw): |
112 | | - values = draw(valid_ranges()) |
113 | | - range_ = DecimalRange(*values) |
114 | | - assume(bool(range_)) |
115 | | - index = draw(st.integers(min_value=0, max_value=len(range_) - 1)) |
116 | | - return (range_, index) |
117 | | - |
118 | | - |
119 | | -@given(ranges_with_index()) |
120 | | -def test_contains_truthy(range_and_index): |
121 | | - range_, index = range_and_index |
122 | | - |
123 | | - value = range_.start + range_.step * index |
124 | | - if value in range_: |
125 | | - assert True |
126 | | - elif range_.start == pytest.approx(value): |
127 | | - # Floating point precision means that we didn't quite hit the mark. |
128 | | - # So test the iter search code path. |
129 | | - assert pytest.approx(value) in range_ |
| 151 | +@given(valid_ranges()) |
| 152 | +def test_contains(values): |
| 153 | + r = DecimalRange(*values) |
| 154 | + last_item = None |
| 155 | + for v in r: |
| 156 | + assert v in r |
| 157 | + last_item = v |
| 158 | + |
| 159 | + if last_item is not None: |
| 160 | + not_in_range = last_item + r.step |
| 161 | + assert not_in_range not in r |
130 | 162 | else: |
131 | | - # Assert the failed condition so that pytest gives debuggable output. |
132 | | - assert value in range_ |
133 | | - |
134 | | - |
135 | | -@given( |
136 | | - valid_ranges(), |
137 | | - st.decimals(allow_nan=False, allow_infinity=False).filter( |
138 | | - lambda x: x.as_integer_ratio()[1] != 1 |
139 | | - ), |
140 | | -) |
141 | | -def test_contains_falsey(values, index): |
142 | | - range_ = DecimalRange(*values) |
143 | | - |
144 | | - value = range_.start + range_.step * index |
145 | | - if value != range_.start: |
146 | | - assert value not in range_ |
147 | | - |
| 163 | + assert r.start not in r |
| 164 | + assert r.stop not in r |
| 165 | + |
| 166 | + |
| 167 | +@given(valid_ranges(), valid_ranges()) |
| 168 | +def test_eq_and_hash(values1, values2): |
| 169 | + r1 = DecimalRange(*values1) |
| 170 | + r2 = DecimalRange(*values2) |
| 171 | + # Reflexivity |
| 172 | + assert r1 == r1 |
| 173 | + assert hash(r1) == hash(r1) |
| 174 | + # Symmetry and hash equality for equal objects. |
| 175 | + # Match the behaviour of built-in range, which also only considers |
| 176 | + # start and step for hashing, and ignores stop. |
| 177 | + ri1, ri2 = scale_to_integer_range(r1, r2) |
| 178 | + if ri1 == ri2: |
| 179 | + assert r1 == r2 |
| 180 | + assert r2 == r1 |
| 181 | + assert hash(r1) == hash(r2) |
| 182 | + else: |
| 183 | + # Hash collisions are possible, so we don't assert hash(r1) != hash(r2) |
| 184 | + assert r1 != r2 |
148 | 185 |
|
149 | | -def test_eq_and_hash(): |
150 | | - pass |
151 | 186 |
|
| 187 | +@given(valid_ranges()) |
| 188 | +def test_iter_and_reversed(values): |
| 189 | + r = DecimalRange(*values) |
| 190 | + items = list(r) |
| 191 | + expected = [] |
| 192 | + current = r.start |
| 193 | + while (r.step > 0 and current < r.stop) or (r.step < 0 and current > r.stop): |
| 194 | + expected.append(current) |
| 195 | + current += r.step |
152 | 196 |
|
153 | | -def test_iter_and_reversed(): |
154 | | - pass |
| 197 | + assert items == expected |
| 198 | + assert list(reversed(r)) == expected[::-1] |
155 | 199 |
|
156 | 200 |
|
157 | | -def test_count(): |
158 | | - pass |
| 201 | +@given(valid_ranges()) |
| 202 | +def test_count(values): |
| 203 | + r = DecimalRange(*values) |
| 204 | + last_item = None |
| 205 | + # For each value in the range, count should be 1 |
| 206 | + for v in r: |
| 207 | + assert r.count(v) == 1 |
| 208 | + last_item = v |
| 209 | + |
| 210 | + # For a value not in the range, count should be 0 |
| 211 | + if last_item is not None: |
| 212 | + not_in_range = last_item + r.step |
| 213 | + assert r.count(not_in_range) == 0 |
| 214 | + else: |
| 215 | + # For empty range, any value should have count 0 |
| 216 | + assert r.count(r.start) == 0 |
| 217 | + assert r.count(r.stop) == 0 |
159 | 218 |
|
160 | 219 |
|
161 | | -def test_index(): |
162 | | - pass |
| 220 | +@given(valid_ranges()) |
| 221 | +def test_index(values): |
| 222 | + r = DecimalRange(*values) |
| 223 | + last_item = None |
| 224 | + # For each value in the range, index should return its position |
| 225 | + for idx, v in enumerate(r): |
| 226 | + assert r.index(v) == idx |
| 227 | + last_item = v |
| 228 | + |
| 229 | + # For a value not in the range, index should raise ValueError |
| 230 | + if last_item is not None: |
| 231 | + not_in_range = last_item + r.step |
| 232 | + with pytest.raises(ValueError): |
| 233 | + r.index(not_in_range) |
0 commit comments