Skip to content
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 73 additions & 20 deletions diffly/_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,22 +140,22 @@ def _compare_columns(
elif isinstance(dtype_left, pl.List | pl.Array) and isinstance(
dtype_right, pl.List | pl.Array
):
return _compare_sequence_columns(
col_left=col_left,
col_right=col_right,
dtype_left=dtype_left,
dtype_right=dtype_right,
max_list_length=max_list_length,
abs_tol=abs_tol,
rel_tol=rel_tol,
abs_tol_temporal=abs_tol_temporal,
)

if (
isinstance(dtype_left, pl.Enum)
and isinstance(dtype_right, pl.Enum)
and dtype_left != dtype_right
) or _enum_and_categorical(dtype_left, dtype_right):
if _needs_element_wise_comparison(dtype_left.inner, dtype_right.inner):
Comment thread
MariusMerkleQC marked this conversation as resolved.
return _compare_sequence_columns(
col_left=col_left,
col_right=col_right,
dtype_left=dtype_left,
dtype_right=dtype_right,
max_list_length=max_list_length,
abs_tol=abs_tol,
rel_tol=rel_tol,
abs_tol_temporal=abs_tol_temporal,
)
return col_left.eq_missing(col_right)
Comment thread
MariusMerkleQC marked this conversation as resolved.

if _different_enums(dtype_left, dtype_right) or _enum_and_categorical(
dtype_left, dtype_right
):
# Enums with different categories as well as enums and categoricals
# can't be compared directly.
# Fall back to comparison of strings.
Expand Down Expand Up @@ -237,6 +237,55 @@ def _get_element(col: pl.Expr, dtype: DataType | DataTypeClass, i: int) -> pl.Ex
return _eq_missing(has_same_length & elements_match, col_left, col_right)


def _is_float_numeric_pair(
    dtype_left: DataType | DataTypeClass,
    dtype_right: DataType | DataTypeClass,
) -> bool:
    """Return whether both dtypes are numeric and at least one of them is a float."""
    any_float = dtype_left.is_float() or dtype_right.is_float()
    if not any_float:
        return False
    return dtype_left.is_numeric() and dtype_right.is_numeric()


def _is_temporal_pair(
    dtype_left: DataType | DataTypeClass,
    dtype_right: DataType | DataTypeClass,
) -> bool:
    """Return whether both dtypes are temporal."""
    if not dtype_left.is_temporal():
        return False
    return dtype_right.is_temporal()


def _needs_element_wise_comparison(
    dtype_left: DataType | DataTypeClass,
    dtype_right: DataType | DataTypeClass,
) -> bool:
    """Check if two dtypes require element-wise comparison (tolerances or special
    handling).

    Returns False when eq_missing() on the whole column would produce identical results,
    allowing us to skip the expensive element-wise iteration for list/array columns.

    Nested dtypes are inspected recursively: structs via the fields present on both
    sides, lists/arrays via their inner dtypes.
    """
    # Float/numeric pairs are compared with numeric tolerances.
    if _is_float_numeric_pair(dtype_left, dtype_right):
        return True
    # Temporal pairs are compared with a temporal tolerance.
    if _is_temporal_pair(dtype_left, dtype_right):
        return True
    # Enums with different categories and enum/categorical pairs need special handling.
    if _different_enums(dtype_left, dtype_right) or _enum_and_categorical(
        dtype_left, dtype_right
    ):
        return True
    if isinstance(dtype_left, pl.Struct) and isinstance(dtype_right, pl.Struct):
        fields_left = {f.name: f.dtype for f in dtype_left.fields}
        fields_right = {f.name: f.dtype for f in dtype_right.fields}
        # Only fields that exist on both sides are inspected.
        return any(
            _needs_element_wise_comparison(fields_left[name], fields_right[name])
            for name in fields_left
            if name in fields_right
        )
    if isinstance(dtype_left, pl.List | pl.Array) and isinstance(
        dtype_right, pl.List | pl.Array
    ):
        # A sequence needs element-wise handling exactly when its inner dtype does.
        return _needs_element_wise_comparison(dtype_left.inner, dtype_right.inner)
    return False


def _compare_primitive_columns(
col_left: pl.Expr,
col_right: pl.Expr,
Expand All @@ -246,13 +295,11 @@ def _compare_primitive_columns(
rel_tol: float,
abs_tol_temporal: dt.timedelta,
) -> pl.Expr:
if (dtype_left.is_float() or dtype_right.is_float()) and (
dtype_left.is_numeric() and dtype_right.is_numeric()
):
if _is_float_numeric_pair(dtype_left, dtype_right):
return col_left.is_close(col_right, abs_tol=abs_tol, rel_tol=rel_tol).pipe(
_eq_missing_with_nan, lhs=col_left, rhs=col_right
)
elif dtype_left.is_temporal() and dtype_right.is_temporal():
elif _is_temporal_pair(dtype_left, dtype_right):
diff_less_than_tolerance = (col_left - col_right).abs() <= abs_tol_temporal
return diff_less_than_tolerance.pipe(_eq_missing, lhs=col_left, rhs=col_right)

Expand All @@ -270,6 +317,12 @@ def _eq_missing_with_nan(expr: pl.Expr, lhs: pl.Expr, rhs: pl.Expr) -> pl.Expr:
return _eq_missing(expr, lhs, rhs) | both_nan


def _different_enums(
    left: DataType | DataTypeClass, right: DataType | DataTypeClass
) -> bool:
    """Return whether both dtypes are ``pl.Enum`` instances that compare unequal."""
    both_enums = isinstance(left, pl.Enum) and isinstance(right, pl.Enum)
    return both_enums and left != right


def _enum_and_categorical(
left: DataType | DataTypeClass, right: DataType | DataTypeClass
) -> bool:
Expand Down
115 changes: 114 additions & 1 deletion tests/test_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@
import polars as pl
import pytest

from diffly._conditions import _can_compare_dtypes, condition_equal_columns
from diffly._conditions import (
_can_compare_dtypes,
_needs_element_wise_comparison,
condition_equal_columns,
)
from diffly.comparison import compare_frames


Expand Down Expand Up @@ -512,6 +516,45 @@ def test_condition_equal_columns_lists_only_inner() -> None:
assert actual.to_list() == [True, False]


def test_condition_equal_columns_list_of_different_enums() -> None:
    """Exercise condition_equal_columns() on list columns whose inner enum dtypes
    have different categories."""
    # Arrange
    enum_small = pl.Enum(["one", "two"])
    enum_large = pl.Enum(["one", "two", "three"])

    left = pl.DataFrame(
        {"pk": [1, 2], "a": [["one", "two"], ["one", "one"]]},
        schema_overrides={"a": pl.List(enum_small)},
    )
    right = pl.DataFrame(
        {"pk": [1, 2], "a": [["one", "two"], ["one", "three"]]},
        schema_overrides={"a": pl.List(enum_large)},
    )
    comparison = compare_frames(left, right, primary_key="pk")

    # Act
    left_renamed = left.rename({"a": "a_left"})
    right_renamed = right.rename({"a": "a_right"})
    joined = left_renamed.join(right_renamed, on="pk", maintain_order="left")
    condition = condition_equal_columns(
        "a",
        dtype_left=left_renamed.schema["a_left"],
        dtype_right=right_renamed.schema["a_right"],
        max_list_length=comparison._max_list_lengths_by_column.get("a"),
        abs_tol=comparison.abs_tol_by_column["a"],
        rel_tol=comparison.rel_tol_by_column["a"],
    )
    actual = joined.select(condition).to_series()

    # Assert
    assert comparison._max_list_lengths_by_column == {"a": 2}
    assert _needs_element_wise_comparison(enum_small, enum_large)
    # Row 1 matches; row 2 differs in its second element ("one" vs "three").
    assert actual.to_list() == [True, False]


@pytest.mark.parametrize(
("dtype_left", "dtype_right", "can_compare_dtypes"),
[
Expand All @@ -534,3 +577,73 @@ def test_can_compare_dtypes(
dtype_left=dtype_left, dtype_right=dtype_right
)
assert can_compare_dtypes_actual == can_compare_dtypes


@pytest.mark.parametrize(
    ("dtype_left", "dtype_right", "expected"),
    [
        # --- Primitives for which no element-wise comparison is expected ---
        (pl.Int64, pl.Int64, False),
        (pl.String, pl.String, False),
        (pl.Boolean, pl.Boolean, False),
        # --- Numeric pairs involving at least one float ---
        (pl.Float64, pl.Float64, True),
        (pl.Int64, pl.Float64, True),
        (pl.Float32, pl.Int32, True),
        # --- Temporal pairs ---
        (pl.Datetime, pl.Datetime, True),
        (pl.Date, pl.Date, True),
        (pl.Datetime, pl.Date, True),
        # --- Enums and categoricals ---
        (pl.Enum(["a", "b"]), pl.Enum(["a", "b"]), False),
        (pl.Enum(["a", "b"]), pl.Enum(["a", "b", "c"]), True),
        (pl.Enum(["a"]), pl.Categorical(), True),
        (pl.Categorical(), pl.Enum(["a"]), True),
        # --- Structs ---
        # No field requires element-wise comparison.
        (
            pl.Struct({"x": pl.Int64, "y": pl.String}),
            pl.Struct({"x": pl.Int64, "y": pl.String}),
            False,
        ),
        # One field is a float.
        (
            pl.Struct({"x": pl.Int64, "y": pl.Float64}),
            pl.Struct({"x": pl.Int64, "y": pl.Float64}),
            True,
        ),
        # One field holds enums with different categories.
        (
            pl.Struct({"x": pl.Enum(["a"])}),
            pl.Struct({"x": pl.Enum(["b"])}),
            True,
        ),
        # --- Lists and arrays ---
        # Inner dtype does not require element-wise comparison.
        (pl.List(pl.Int64), pl.List(pl.Int64), False),
        (pl.Array(pl.String, shape=3), pl.Array(pl.String, shape=3), False),
        # Inner dtype requires element-wise comparison.
        (pl.List(pl.Float64), pl.List(pl.Float64), True),
        (pl.Array(pl.Datetime, shape=2), pl.Array(pl.Datetime, shape=2), True),
        # --- Nested combinations ---
        # List of structs with a float field.
        (
            pl.List(pl.Struct({"x": pl.Float64})),
            pl.List(pl.Struct({"x": pl.Float64})),
            True,
        ),
        # List of structs with only an integer field.
        (
            pl.List(pl.Struct({"x": pl.Int64})),
            pl.List(pl.Struct({"x": pl.Int64})),
            False,
        ),
        # List of structs in which one field is itself a list of floats.
        (
            pl.List(pl.Struct({"x": pl.String, "y": pl.List(pl.Float64)})),
            pl.List(pl.Struct({"x": pl.String, "y": pl.List(pl.Float64)})),
            True,
        ),
    ],
)
def test_needs_element_wise_comparison(
    dtype_left: pl.DataType, dtype_right: pl.DataType, expected: bool
) -> None:
    """Check _needs_element_wise_comparison() against the expected flag per dtype pair."""
    actual = _needs_element_wise_comparison(dtype_left, dtype_right)
    assert actual == expected
103 changes: 103 additions & 0 deletions tests/test_performance.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@
import polars as pl

from diffly import compare_frames
from diffly._conditions import condition_equal_columns
from diffly._utils import (
ABS_TOL_DEFAULT,
ABS_TOL_TEMPORAL_DEFAULT,
REL_TOL_DEFAULT,
Side,
)


def test_summary_lazyframe_not_slower_than_dataframe() -> None:
Expand Down Expand Up @@ -74,3 +81,99 @@ def expensive_computation(col: pl.Expr) -> pl.Expr:
f"({mean_time_lf:.3f}s vs {mean_time_df:.3f}s). "
f"This suggests unnecessary re-collection of LazyFrames."
)


def test_eq_missing_not_slower_than_element_wise_for_list_columns() -> None:
    """Check that condition_equal_columns() on integer list columns stays below
    1.25x the mean runtime of a plain eq_missing() on the same columns."""
    n_rows = 500_000
    list_len = 20
    num_runs_warmup = 2
    num_runs_measured = 10
    total_runs = num_runs_warmup + num_runs_measured

    col_left = f"val_{Side.LEFT}"
    col_right = f"val_{Side.RIGHT}"
    # Both columns hold identical lists.
    df = pl.DataFrame(
        {
            name: [list(range(list_len)) for _ in range(n_rows)]
            for name in (col_left, col_right)
        }
    )

    times_eq: list[float] = []
    times_cond: list[float] = []
    for _ in range(total_runs):
        # Baseline: whole-column eq_missing().
        started = time.perf_counter()
        df.select(pl.col(col_left).eq_missing(pl.col(col_right))).to_series()
        times_eq.append(time.perf_counter() - started)

        # Candidate: the condition built by condition_equal_columns(); building
        # the expression is deliberately part of the timed region.
        started = time.perf_counter()
        df.select(
            condition_equal_columns(
                column="val",
                dtype_left=df.schema[col_left],
                dtype_right=df.schema[col_right],
                max_list_length=list_len,
                abs_tol=ABS_TOL_DEFAULT,
                rel_tol=REL_TOL_DEFAULT,
                abs_tol_temporal=ABS_TOL_TEMPORAL_DEFAULT,
            )
        ).to_series()
        times_cond.append(time.perf_counter() - started)

    # Discard the warm-up runs before averaging.
    mean_time_eq = statistics.mean(times_eq[num_runs_warmup:])
    mean_time_cond = statistics.mean(times_cond[num_runs_warmup:])
    ratio = mean_time_cond / mean_time_eq

    assert ratio < 1.25, (
        f"condition_equal_columns was {ratio:.1f}x slower than eq_missing "
        f"({mean_time_cond:.3f}s vs {mean_time_eq:.3f}s). "
        f"Expected comparable performance since list<i64> should use eq_missing directly."
    )


def test_eq_missing_not_slower_than_field_wise_for_struct_columns() -> None:
    """Check that condition_equal_columns() on struct columns with integer fields
    stays below 1.25x the mean runtime of a plain eq_missing() on the same columns."""
    n_rows = 500_000
    n_fields = 20
    num_runs_warmup = 2
    num_runs_measured = 10
    total_runs = num_runs_warmup + num_runs_measured

    col_left = f"val_{Side.LEFT}"
    col_right = f"val_{Side.RIGHT}"
    # One struct per row with fields f0..f{n_fields - 1}; both columns share the data.
    struct_data = [{f"f{i}": row + i for i in range(n_fields)} for row in range(n_rows)]
    df = pl.DataFrame({col_left: struct_data, col_right: struct_data})

    times_eq: list[float] = []
    times_cond: list[float] = []
    for _ in range(total_runs):
        # Baseline: whole-column eq_missing().
        started = time.perf_counter()
        df.select(pl.col(col_left).eq_missing(pl.col(col_right))).to_series()
        times_eq.append(time.perf_counter() - started)

        # Candidate: the condition built by condition_equal_columns(); building
        # the expression is deliberately part of the timed region.
        started = time.perf_counter()
        df.select(
            condition_equal_columns(
                column="val",
                dtype_left=df.schema[col_left],
                dtype_right=df.schema[col_right],
                max_list_length=None,
                abs_tol=ABS_TOL_DEFAULT,
                rel_tol=REL_TOL_DEFAULT,
                abs_tol_temporal=ABS_TOL_TEMPORAL_DEFAULT,
            )
        ).to_series()
        times_cond.append(time.perf_counter() - started)

    # Discard the warm-up runs before averaging.
    mean_time_eq = statistics.mean(times_eq[num_runs_warmup:])
    mean_time_cond = statistics.mean(times_cond[num_runs_warmup:])
    ratio = mean_time_cond / mean_time_eq

    assert ratio < 1.25, (
        f"condition_equal_columns was {ratio:.1f}x slower than eq_missing "
        f"({mean_time_cond:.3f}s vs {mean_time_eq:.3f}s). "
        f"Expected comparable performance since struct<i64> fields should use "
        f"eq_missing directly."
    )
Loading