Skip to content

Commit 2ad8877

Browse files
feat: Tolerances for inner lists and arrays
1 parent 2a3010b commit 2ad8877

3 files changed

Lines changed: 99 additions & 25 deletions

File tree

diffly/_conditions.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -206,11 +206,7 @@ def _compare_sequence_columns(
206206
n_elements = dtype_right.shape[0]
207207
has_same_length = col_left.list.len().eq(pl.lit(n_elements))
208208
else: # pl.List vs pl.List
209-
if not isinstance(max_list_length, int):
210-
# Fallback for nested list comparisons where no max_list_length is
211-
# available: perform a direct equality comparison without element-wise
212-
# unrolling.
213-
return _eq_missing(col_left.eq_missing(col_right), col_left, col_right)
209+
assert max_list_length is not None
214210
n_elements = max_list_length
215211
has_same_length = col_left.list.len().eq_missing(col_right.list.len())
216212

@@ -232,7 +228,7 @@ def _get_element(col: pl.Expr, dtype: DataType | DataTypeClass, i: int) -> pl.Ex
232228
abs_tol=abs_tol,
233229
rel_tol=rel_tol,
234230
abs_tol_temporal=abs_tol_temporal,
235-
max_list_length=None,
231+
max_list_length=max_list_length,
236232
)
237233
for i in range(n_elements)
238234
]

diffly/comparison.py

Lines changed: 45 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -711,22 +711,30 @@ def _validate_subset_of_common_columns(self, subset: Iterable[str]) -> list[str]
711711

712712
@cached_property
def _max_list_lengths_by_column(self) -> dict[str, int]:
    """Map each list-bearing column to the longest list found on either side.

    A column from ``self._other_common_columns`` is included when its left or
    right dtype yields at least one list-length expression, i.e. when a
    ``pl.List`` occurs anywhere in its type tree. The value is the maximum
    list length over all nesting levels and over both frames.

    Returns:
        A dict from column name to maximum list length; empty if no column
        qualifies.
    """
    # One (name, left expression, right expression) triple per qualifying column.
    selections: list[tuple[str, pl.Expr, pl.Expr]] = []
    for name in self._other_common_columns:
        lengths_left = _list_length_exprs(pl.col(name), self.left_schema[name])
        lengths_right = _list_length_exprs(pl.col(name), self.right_schema[name])
        if lengths_left or lengths_right:
            # A side without any list level contributes a literal 0.
            selections.append(
                (
                    name,
                    _max_or_zero(lengths_left).alias(name),
                    _max_or_zero(lengths_right).alias(name),
                )
            )

    if not selections:
        return {}

    # Evaluate both frames in a single `collect_all` call.
    left_query = self.left.select([left for _, left, _ in selections])
    right_query = self.right.select([right for _, _, right in selections])
    left_max, right_max = pl.collect_all([left_query, right_query])

    longest: dict[str, int] = {}
    for name, _, _ in selections:
        # `.item()` can be None when the aggregated max is null; count that as 0.
        longest_left = int(left_max[name].item() or 0)
        longest_right = int(right_max[name].item() or 0)
        longest[name] = max(longest_left, longest_right)
    return longest
731739

732740
def _condition_equal_rows(self, columns: list[str]) -> pl.Expr:
@@ -833,3 +841,30 @@ def right_only(self) -> Schema:
833841
"""Columns that are only present in the right data frame, mapped to their data
834842
types."""
835843
return self.right() - self.left()
844+
845+
846+
def _list_length_exprs(
    expr: pl.Expr, dtype: pl.DataType | pl.datatypes.DataTypeClass
) -> list[pl.Expr]:
    """Build one max-list-length expression per ``pl.List`` level in ``dtype``.

    The type tree is walked recursively:

    * ``pl.List`` contributes ``expr.list.len().max()`` and then recurses into
      the exploded inner values.
    * ``pl.Array`` contributes nothing itself but recurses into its exploded
      inner values.
    * ``pl.Struct`` recurses into every field.
    * Any other dtype contributes nothing.

    Args:
        expr: Expression selecting the values described by ``dtype``.
        dtype: Data type of ``expr``.

    Returns:
        The collected expressions; empty if ``dtype`` contains no ``pl.List``.
    """
    if isinstance(dtype, pl.Struct):
        collected: list[pl.Expr] = []
        for field in dtype.fields:
            collected.extend(_list_length_exprs(expr.struct[field.name], field.dtype))
        return collected

    if not isinstance(dtype, (pl.List, pl.Array)):
        return []

    # Both sequence types descend into their exploded elements.
    nested = _list_length_exprs(expr.explode(), dtype.inner)
    if isinstance(dtype, pl.List):
        # Only variable-length lists add a length expression of their own.
        return [expr.list.len().max(), *nested]
    return nested
862+
863+
864+
def _max_or_zero(exprs: list[pl.Expr]) -> pl.Expr:
    """Combine scalar expressions into their horizontal maximum.

    Args:
        exprs: Scalar expressions to combine; may be empty.

    Returns:
        ``pl.lit(0)`` for an empty list, the sole expression unchanged for a
        single-element list, and ``pl.max_horizontal(exprs)`` otherwise.
    """
    if len(exprs) > 1:
        return pl.max_horizontal(exprs)
    # Zero or one expression: nothing to combine.
    return exprs[0] if exprs else pl.lit(0)

tests/test_conditions.py

Lines changed: 52 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,10 @@ def test_condition_equal_columns_list_array_with_tolerance(
102102
schema={"pk": pl.Int64, "a_right": rhs_type},
103103
)
104104

105+
max_list_length: int | None = None
106+
if isinstance(lhs_type, pl.List) or isinstance(rhs_type, pl.List):
107+
max_list_length = 2
108+
105109
# Act
106110
actual = (
107111
lhs.join(rhs, on="pk", maintain_order="left")
@@ -112,7 +116,7 @@ def test_condition_equal_columns_list_array_with_tolerance(
112116
dtype_right=rhs.schema["a_right"],
113117
abs_tol=0.5,
114118
rel_tol=0,
115-
max_list_length=2,
119+
max_list_length=max_list_length,
116120
)
117121
)
118122
.to_series()
@@ -156,6 +160,10 @@ def test_condition_equal_columns_nested_list_array_with_tolerance(
156160
schema={"pk": pl.Int64, "a_right": rhs_type},
157161
)
158162

163+
max_list_length: int | None = None
164+
if isinstance(lhs_type, pl.List) or isinstance(rhs_type, pl.List):
165+
max_list_length = 3
166+
159167
# Act
160168
actual = (
161169
lhs.join(rhs, on="pk", maintain_order="left")
@@ -166,16 +174,13 @@ def test_condition_equal_columns_nested_list_array_with_tolerance(
166174
dtype_right=rhs.schema["a_right"],
167175
abs_tol=0.5,
168176
rel_tol=0,
169-
max_list_length=2,
177+
max_list_length=max_list_length,
170178
)
171179
)
172180
.to_series()
173181
)
174182

175-
if isinstance(lhs_type, pl.List) and isinstance(rhs_type, pl.List):
176-
assert actual.to_list() == [True, False, False]
177-
else:
178-
assert actual.to_list() == [True, True, False]
183+
assert actual.to_list() == [True, True, False]
179184

180185

181186
def test_condition_equal_columns_nested_dtype_mismatch() -> None:
@@ -201,7 +206,7 @@ def test_condition_equal_columns_nested_dtype_mismatch() -> None:
201206
"a",
202207
dtype_left=lhs.schema["a_left"],
203208
dtype_right=rhs.schema["a_right"],
204-
max_list_length=None,
209+
max_list_length=2,
205210
)
206211
)
207212
.to_series()
@@ -341,7 +346,7 @@ def test_condition_equal_columns_array_vs_list_length_mismatch() -> None:
341346
"a",
342347
dtype_left=lhs.schema["a_left"],
343348
dtype_right=rhs.schema["a_right"],
344-
max_list_length=None,
349+
max_list_length=2,
345350
abs_tol=0.5,
346351
rel_tol=0,
347352
)
@@ -406,21 +411,59 @@ def test_condition_equal_columns_empty_list_array(
406411
schema={"pk": pl.Int64, "a_right": rhs_type},
407412
)
408413

414+
max_list_length: int | None = None
415+
if isinstance(lhs_type, pl.List) or isinstance(rhs_type, pl.List):
416+
max_list_length = 0
417+
409418
actual = (
410419
lhs.join(rhs, on="pk", maintain_order="left")
411420
.select(
412421
condition_equal_columns(
413422
"a",
414423
dtype_left=lhs.schema["a_left"],
415424
dtype_right=rhs.schema["a_right"],
416-
max_list_length=None,
425+
max_list_length=max_list_length,
417426
)
418427
)
419428
.to_series()
420429
)
421430
assert actual.to_list() == [True, True]
422431

423432

433+
def test_condition_equal_columns_lists_only_inner() -> None:
    """Check that the absolute tolerance applies to a list nested in a struct."""
    # Arrange
    left_structs = [
        {"x": 1, "y": [1.0, 2.0, 3.0]},
        {"x": 2, "y": [4.0, 5.0, 6.0]},
    ]
    # Row 1 deviates by 0.1 (within abs_tol=0.2); row 2 deviates by 0.3 (outside).
    right_structs = [
        {"x": 1, "y": [1.0, 2.1, 3.0]},
        {"x": 2, "y": [4.0, 5.3, 6.0]},
    ]
    lhs = pl.DataFrame({"pk": [1, 2], "a_left": left_structs})
    rhs = pl.DataFrame({"pk": [1, 2], "a_right": right_structs})

    # Act
    condition = condition_equal_columns(
        "a",
        dtype_left=lhs.schema["a_left"],
        dtype_right=rhs.schema["a_right"],
        max_list_length=3,
        abs_tol=0.2,
    )
    joined = lhs.join(rhs, on="pk", maintain_order="left")
    actual = joined.select(condition).to_series()

    # Assert
    assert actual.to_list() == [True, False]
465+
466+
424467
@pytest.mark.parametrize(
425468
("dtype_left", "dtype_right", "can_compare_dtypes"),
426469
[

0 commit comments

Comments
 (0)