Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions diffly/_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,11 +206,7 @@ def _compare_sequence_columns(
n_elements = dtype_right.shape[0]
has_same_length = col_left.list.len().eq(pl.lit(n_elements))
else: # pl.List vs pl.List
if not isinstance(max_list_length, int):
# Fallback for nested list comparisons where no max_list_length is
# available: perform a direct equality comparison without element-wise
# unrolling.
return _eq_missing(col_left.eq_missing(col_right), col_left, col_right)
assert max_list_length is not None
Comment thread
MariusMerkleQC marked this conversation as resolved.
Outdated
n_elements = max_list_length
has_same_length = col_left.list.len().eq_missing(col_right.list.len())

Expand All @@ -232,7 +228,7 @@ def _get_element(col: pl.Expr, dtype: DataType | DataTypeClass, i: int) -> pl.Ex
abs_tol=abs_tol,
rel_tol=rel_tol,
abs_tol_temporal=abs_tol_temporal,
max_list_length=None,
max_list_length=max_list_length,
)
for i in range(n_elements)
]
Expand Down
46 changes: 36 additions & 10 deletions diffly/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -711,22 +711,30 @@ def _validate_subset_of_common_columns(self, subset: Iterable[str]) -> list[str]

@cached_property
def _max_list_lengths_by_column(self) -> dict[str, int]:
    """Max list length across all nesting levels, for columns where both sides
    contain a List anywhere in their type tree.

    For every column in ``self._other_common_columns`` the left and right
    schema dtypes are walked with ``_list_length_exprs``. A column is only
    considered when *both* sides yield at least one length expression; all
    other columns are absent from the result.

    Returns:
        A mapping from column name to the larger of the left-side and
        right-side maximum list length. Empty when no column qualifies.
    """
    left_exprs: list[pl.Expr] = []
    right_exprs: list[pl.Expr] = []
    columns: list[str] = []

    for col in self._other_common_columns:
        col_left = _list_length_exprs(pl.col(col), self.left_schema[col])
        col_right = _list_length_exprs(pl.col(col), self.right_schema[col])
        # An empty result means that side's dtype yielded no list-length
        # expression; such columns are left out of the mapping.
        if not (col_left and col_right):
            continue
        columns.append(col)
        # Fold the per-level maxima of one column into a single value that is
        # stored under the column's own name.
        left_exprs.append(pl.max_horizontal(col_left).alias(col))
        right_exprs.append(pl.max_horizontal(col_right).alias(col))

    # Nothing to measure: return early without collecting either frame.
    if not columns:
        return {}

    # Both sides are handed to a single `collect_all` call.
    [left_max, right_max] = pl.collect_all(
        [self.left.select(left_exprs), self.right.select(right_exprs)]
    )
    return {
        # `or 0` maps a falsy `.item()` result to 0 (presumably a null maximum
        # from an empty or all-null column -- TODO confirm).
        col: max(int(left_max[col].item() or 0), int(right_max[col].item() or 0))
        for col in columns
    }

def _condition_equal_rows(self, columns: list[str]) -> pl.Expr:
Expand Down Expand Up @@ -833,3 +841,21 @@ def right_only(self) -> Schema:
"""Columns that are only present in the right data frame, mapped to their data
types."""
return self.right() - self.left()


def _list_length_exprs(
    expr: pl.Expr, dtype: pl.DataType | pl.datatypes.DataTypeClass
) -> list[pl.Expr]:
    """Collect max-list-length scalar expressions for every List level in the type
    tree.

    Args:
        expr: Expression addressing the values whose type is ``dtype``.
        dtype: Data type to walk recursively.

    Returns:
        One ``<expr>.list.len().max()`` expression per ``pl.List`` level found
        in ``dtype``, outermost level first. The list is empty when ``dtype``
        contains no ``pl.List`` at any depth.
    """
    if isinstance(dtype, pl.List):
        # This level contributes its own maximum length; deeper levels are
        # reached through the exploded elements.
        return [expr.list.len().max(), *_list_length_exprs(expr.explode(), dtype.inner)]
    if isinstance(dtype, pl.Array):
        # An Array level adds no length expression of its own (presumably
        # because its width is fixed by the dtype -- TODO confirm); only its
        # inner type is searched.
        return _list_length_exprs(expr.explode(), dtype.inner)
    if isinstance(dtype, pl.Struct):
        # Search every struct field and concatenate the results in field order.
        return [
            e
            for field in dtype.fields
            for e in _list_length_exprs(expr.struct[field.name], field.dtype)
        ]
    # Any other dtype contributes no list-length expression.
    return []
61 changes: 52 additions & 9 deletions tests/test_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,10 @@ def test_condition_equal_columns_list_array_with_tolerance(
schema={"pk": pl.Int64, "a_right": rhs_type},
)

max_list_length: int | None = None
Comment thread
MariusMerkleQC marked this conversation as resolved.
Outdated
if isinstance(lhs_type, pl.List) and isinstance(rhs_type, pl.List):
max_list_length = 2

# Act
actual = (
lhs.join(rhs, on="pk", maintain_order="left")
Expand All @@ -112,7 +116,7 @@ def test_condition_equal_columns_list_array_with_tolerance(
dtype_right=rhs.schema["a_right"],
abs_tol=0.5,
rel_tol=0,
max_list_length=2,
max_list_length=max_list_length,
)
)
.to_series()
Expand Down Expand Up @@ -156,6 +160,10 @@ def test_condition_equal_columns_nested_list_array_with_tolerance(
schema={"pk": pl.Int64, "a_right": rhs_type},
)

max_list_length: int | None = None
if isinstance(lhs_type, pl.List) and isinstance(rhs_type, pl.List):
max_list_length = 3

# Act
actual = (
lhs.join(rhs, on="pk", maintain_order="left")
Expand All @@ -166,16 +174,13 @@ def test_condition_equal_columns_nested_list_array_with_tolerance(
dtype_right=rhs.schema["a_right"],
abs_tol=0.5,
rel_tol=0,
max_list_length=2,
max_list_length=max_list_length,
)
)
.to_series()
)

if isinstance(lhs_type, pl.List) and isinstance(rhs_type, pl.List):
assert actual.to_list() == [True, False, False]
else:
assert actual.to_list() == [True, True, False]
assert actual.to_list() == [True, True, False]
Comment thread
MariusMerkleQC marked this conversation as resolved.


def test_condition_equal_columns_nested_dtype_mismatch() -> None:
Expand All @@ -201,7 +206,7 @@ def test_condition_equal_columns_nested_dtype_mismatch() -> None:
"a",
dtype_left=lhs.schema["a_left"],
dtype_right=rhs.schema["a_right"],
max_list_length=None,
max_list_length=2,
)
)
.to_series()
Expand Down Expand Up @@ -341,7 +346,7 @@ def test_condition_equal_columns_array_vs_list_length_mismatch() -> None:
"a",
dtype_left=lhs.schema["a_left"],
dtype_right=rhs.schema["a_right"],
max_list_length=None,
max_list_length=2,
abs_tol=0.5,
rel_tol=0,
)
Expand Down Expand Up @@ -406,21 +411,59 @@ def test_condition_equal_columns_empty_list_array(
schema={"pk": pl.Int64, "a_right": rhs_type},
)

max_list_length: int | None = None
if isinstance(lhs_type, pl.List) or isinstance(rhs_type, pl.List):
max_list_length = 0

actual = (
lhs.join(rhs, on="pk", maintain_order="left")
.select(
condition_equal_columns(
"a",
dtype_left=lhs.schema["a_left"],
dtype_right=rhs.schema["a_right"],
max_list_length=None,
max_list_length=max_list_length,
)
)
.to_series()
)
assert actual.to_list() == [True, True]


def test_condition_equal_columns_lists_only_inner() -> None:
    """Struct columns whose only list lives in an inner field are compared with
    the given tolerance."""
    # Arrange
    # Column "a" is a struct {x, y}; the list is only present in field "y".
    lhs = pl.DataFrame(
        {
            "pk": [1, 2],
            "a_left": [{"x": 1, "y": [1.0, 2.0, 3.0]}, {"x": 2, "y": [4.0, 5.0, 6.0]}],
        },
    )
    # Row 1 differs from lhs by 0.1 in one list element, row 2 by 0.3.
    rhs = pl.DataFrame(
        {
            "pk": [1, 2],
            "a_right": [{"x": 1, "y": [1.0, 2.1, 3.0]}, {"x": 2, "y": [4.0, 5.3, 6.0]}],
        },
    )

    # Act
    actual = (
        lhs.join(rhs, on="pk", maintain_order="left")
        .select(
            condition_equal_columns(
                "a",
                dtype_left=lhs.schema["a_left"],
                dtype_right=rhs.schema["a_right"],
                max_list_length=3,
                abs_tol=0.2,
            )
        )
        .to_series()
    )

    # Assert
    # With abs_tol=0.2 the first row is expected to compare equal and the
    # second row unequal.
    assert actual.to_list() == [True, False]


@pytest.mark.parametrize(
("dtype_left", "dtype_right", "can_compare_dtypes"),
[
Expand Down
Loading