Skip to content

Commit e9703e5

Browse files
committed
Complete tests for DecimalRange
1 parent ce1fe93 commit e9703e5

2 files changed

Lines changed: 136 additions & 63 deletions

File tree

src/pathseq/_file_num_seq/_decimal_range.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -104,16 +104,26 @@ def __eq__(self, value: object) -> bool:
104104

105105
def __hash__(self) -> int:
106106
length = len(self)
107-
to_hash: tuple[int, decimal.Decimal | None, None]
108-
if length:
107+
to_hash: tuple[int, decimal.Decimal | None, decimal.Decimal | None]
108+
if not length:
109+
to_hash = (length, None, None)
110+
elif length == 1:
109111
to_hash = (length, self.start, None)
110112
else:
111-
to_hash = (length, None, None)
113+
to_hash = (length, self.start, self.step)
112114
return hash(to_hash)
113115

114116
def __iter__(self) -> Iterator[decimal.Decimal]:
115117
return DecimalRangeIterator(self.start, self.stop, self.step)
116118

119+
def __reversed__(self) -> Iterator[decimal.Decimal]:
120+
n = len(self)
121+
if n == 0:
122+
return iter(())
123+
124+
last = self.start + self.step * (n - 1)
125+
return DecimalRangeIterator(last, self.start - self.step, -self.step)
126+
117127
def __len__(self) -> int:
118128
# x_n = a + d(n-1)
119129
# stop > start + step * (n-1)
@@ -147,14 +157,6 @@ def __repr__(self) -> str:
147157

148158
return f"{self.__class__.__name__}({self._start}, {self._stop}, {self._step})"
149159

150-
def __reversed__(self) -> Iterator[decimal.Decimal]:
151-
if not len(self):
152-
return iter(())
153-
154-
new_stop = self._start - self._step
155-
new_start = new_stop + self._step * len(self)
156-
return DecimalRangeIterator(new_start, new_stop, -self._step)
157-
158160
def count(self, value: decimal.Decimal) -> int:
159161
return 1 if value in self else 0
160162

tests/test_decimal_range.py

Lines changed: 123 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,28 @@
77
from pathseq._file_num_seq._decimal_range import DecimalRange
88

99

10+
def scale_to_integer_range(r1: DecimalRange, r2: DecimalRange) -> tuple[range, range]:
11+
max_exponent = max(
12+
(
13+
-x.as_tuple().exponent
14+
for x in (r1.start, r1.stop, r1.step, r2.start, r2.stop, r2.step)
15+
),
16+
default=0,
17+
)
18+
return (
19+
range(
20+
int(r1.start.scaleb(max_exponent)),
21+
int(r1.stop.scaleb(max_exponent)),
22+
int(r1.step.scaleb(max_exponent)),
23+
),
24+
range(
25+
int(r2.start.scaleb(max_exponent)),
26+
int(r2.stop.scaleb(max_exponent)),
27+
int(r2.step.scaleb(max_exponent)),
28+
),
29+
)
30+
31+
1032
@st.composite
1133
def infinite_decimals(draw):
1234
values = st.one_of(
@@ -42,22 +64,43 @@ def invalid_ranges(draw):
4264
@st.composite
4365
def valid_ranges(draw, max_len=10000):
4466
places = draw(st.integers(0, 6))
45-
valid_start = st.decimals(allow_nan=False, allow_infinity=False, places=places)
67+
valid_start = st.decimals(
68+
min_value=-max_len,
69+
max_value=max_len,
70+
allow_nan=False,
71+
allow_infinity=False,
72+
places=places,
73+
)
4674
start = draw(valid_start)
4775

4876
# Only allow ranges that we can loop over in a sensible amount of time
4977
diffs = st.decimals(
50-
max_value=1500, allow_nan=False, allow_infinity=False, places=places
78+
min_value=-max_len,
79+
max_value=max_len,
80+
allow_nan=False,
81+
allow_infinity=False,
82+
places=places,
5183
)
5284
diff = draw(diffs)
53-
if draw(st.booleans()):
54-
diff = -diff
5585
stop = start + diff
5686

87+
# Step must be nonzero and such that the number of elements does not exceed max_len.
88+
abs_diff = abs(diff)
89+
if abs_diff == 0:
90+
min_step = decimal.Decimal("1")
91+
else:
92+
min_step = abs_diff / decimal.Decimal(max_len)
5793
valid_step = st.decimals(
58-
max_value=diff / max_len, allow_nan=False, allow_infinity=False, places=places
94+
min_value=min_step,
95+
max_value=abs_diff if abs_diff != 0 else decimal.Decimal("1"),
96+
allow_nan=False,
97+
allow_infinity=False,
98+
places=places,
5999
).filter(lambda x: not x.is_zero())
60100
step = draw(valid_step)
101+
# Randomly flip sign of step to allow negative steps
102+
if draw(st.booleans()):
103+
step = -step
61104

62105
try:
63106
len(DecimalRange(start, stop, step))
@@ -98,65 +141,93 @@ def test_bool(values):
98141
assert not bool(range_)
99142

100143

101-
# TODO: This test is too slow
102-
@pytest.mark.skip
103144
@given(valid_ranges())
104145
def test_len(values):
105146
range_ = DecimalRange(*values)
106147
len_ = sum(1 for _ in range_)
107148
assert len(range_) == len_
108149

109150

110-
@st.composite
111-
def ranges_with_index(draw):
112-
values = draw(valid_ranges())
113-
range_ = DecimalRange(*values)
114-
assume(bool(range_))
115-
index = draw(st.integers(min_value=0, max_value=len(range_) - 1))
116-
return (range_, index)
117-
118-
119-
@given(ranges_with_index())
120-
def test_contains_truthy(range_and_index):
121-
range_, index = range_and_index
122-
123-
value = range_.start + range_.step * index
124-
if value in range_:
125-
assert True
126-
elif range_.start == pytest.approx(value):
127-
# Floating point precision means that we didn't quite hit the mark.
128-
# So test the iter search code path.
129-
assert pytest.approx(value) in range_
151+
@given(valid_ranges())
152+
def test_contains(values):
153+
r = DecimalRange(*values)
154+
last_item = None
155+
for v in r:
156+
assert v in r
157+
last_item = v
158+
159+
if last_item is not None:
160+
not_in_range = last_item + r.step
161+
assert not_in_range not in r
130162
else:
131-
# Assert the failed condition so that pytest gives debuggable output.
132-
assert value in range_
133-
134-
135-
@given(
136-
valid_ranges(),
137-
st.decimals(allow_nan=False, allow_infinity=False).filter(
138-
lambda x: x.as_integer_ratio()[1] != 1
139-
),
140-
)
141-
def test_contains_falsey(values, index):
142-
range_ = DecimalRange(*values)
143-
144-
value = range_.start + range_.step * index
145-
if value != range_.start:
146-
assert value not in range_
147-
163+
assert r.start not in r
164+
assert r.stop not in r
165+
166+
167+
@given(valid_ranges(), valid_ranges())
168+
def test_eq_and_hash(values1, values2):
169+
r1 = DecimalRange(*values1)
170+
r2 = DecimalRange(*values2)
171+
# Reflexivity
172+
assert r1 == r1
173+
assert hash(r1) == hash(r1)
174+
# Symmetry and hash equality for equal objects.
175+
# Match the behaviour of built-in range, which also only considers
176+
# start and step for hashing, and ignores stop.
177+
ri1, ri2 = scale_to_integer_range(r1, r2)
178+
if ri1 == ri2:
179+
assert r1 == r2
180+
assert r2 == r1
181+
assert hash(r1) == hash(r2)
182+
else:
183+
# Hash collisions are possible, so we don't assert hash(r1) != hash(r2)
184+
assert r1 != r2
148185

149-
def test_eq_and_hash():
150-
pass
151186

187+
@given(valid_ranges())
188+
def test_iter_and_reversed(values):
189+
r = DecimalRange(*values)
190+
items = list(r)
191+
expected = []
192+
current = r.start
193+
while (r.step > 0 and current < r.stop) or (r.step < 0 and current > r.stop):
194+
expected.append(current)
195+
current += r.step
152196

153-
def test_iter_and_reversed():
154-
pass
197+
assert items == expected
198+
assert list(reversed(r)) == expected[::-1]
155199

156200

157-
def test_count():
158-
pass
201+
@given(valid_ranges())
202+
def test_count(values):
203+
r = DecimalRange(*values)
204+
last_item = None
205+
# For each value in the range, count should be 1
206+
for v in r:
207+
assert r.count(v) == 1
208+
last_item = v
209+
210+
# For a value not in the range, count should be 0
211+
if last_item is not None:
212+
not_in_range = last_item + r.step
213+
assert r.count(not_in_range) == 0
214+
else:
215+
# For empty range, any value should have count 0
216+
assert r.count(r.start) == 0
217+
assert r.count(r.stop) == 0
159218

160219

161-
def test_index():
162-
pass
220+
@given(valid_ranges())
221+
def test_index(values):
222+
r = DecimalRange(*values)
223+
last_item = None
224+
# For each value in the range, index should return its position
225+
for idx, v in enumerate(r):
226+
assert r.index(v) == idx
227+
last_item = v
228+
229+
# For a value not in the range, index should raise ValueError
230+
if last_item is not None:
231+
not_in_range = last_item + r.step
232+
with pytest.raises(ValueError):
233+
r.index(not_in_range)

0 commit comments

Comments
 (0)