Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 142 additions & 0 deletions buckaroo/polars_compare.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
import polars as pl


# Polars uses "full" instead of "outer": map user-facing 'how' values to the
# names pl.DataFrame.join accepts (identity for values Polars already knows).
_HOW_MAP = {"outer": "full", "full": "full", "inner": "inner", "left": "left", "right": "right"}


def col_join_dfs(df1, df2, join_columns, how):
    """Join two Polars DataFrames and compute column-level diff statistics.

    Parameters
    ----------
    df1, df2 : pl.DataFrame
        The two DataFrames to compare.
    join_columns : str or list[str]
        Column name(s) to join on.
    how : str
        Join type ('inner', 'outer', 'left', 'right').

    Returns
    -------
    m_df : pl.DataFrame
        Merged DataFrame with membership and equality columns.
    column_config_overrides : dict
        Buckaroo column config for styling.
    eqs : dict
        Per-column diff summary.

    Raises
    ------
    ValueError
        If any input column name contains the '|df2' sentinel, or if the
        join keys are not unique within either dataframe.
    """
    if isinstance(join_columns, str):
        join_columns = [join_columns]

    df2_suffix = "|df2"
    for col in df1.columns + df2.columns:
        if df2_suffix in col:
            raise ValueError(
                f"|df2 is a sentinel column name used by this tool, "
                f"and can't be used in a dataframe passed in, {col} violates that constraint"
            )

    df1_name, df2_name = "df_1", "df_2"

    # Validate join keys are unique to prevent cartesian explosion
    if not df1.select(pl.struct(join_columns).is_unique().all()).item():
        raise ValueError(
            f"Duplicate join keys found in df1 on columns {join_columns}. "
            "Join keys must be unique in each dataframe for a valid comparison."
        )
    if not df2.select(pl.struct(join_columns).is_unique().all()).item():
        raise ValueError(
            f"Duplicate join keys found in df2 on columns {join_columns}. "
            "Join keys must be unique in each dataframe for a valid comparison."
        )

    pl_how = _HOW_MAP.get(how, how)

    # Track row origin with explicit per-side marker columns rather than join-key
    # nullness: a row whose join key is legitimately null would otherwise be
    # misclassified (e.g. a df1-only row with a null first key has both key
    # columns null after an outer join and would look df2-only).
    # The marker names contain the '|df2' sentinel, so the validation above
    # guarantees they can't collide with user columns; they don't *end* with the
    # suffix, so they can't collide with suffixed columns either.
    left_marker = "|df2<left_marker>"
    right_marker = "|df2<right_marker>"
    m_df = df1.with_columns(pl.lit(1, dtype=pl.Int8).alias(left_marker)).join(
        df2.with_columns(pl.lit(1, dtype=pl.Int8).alias(right_marker)),
        on=join_columns,
        how=pl_how,
        suffix=df2_suffix,
        coalesce=False,
    )

    # membership: 1 => df1 only, 2 => df2 only, 3 => both.
    # A marker is non-null iff the row came from that side.
    m_df = m_df.with_columns(
        (pl.col(left_marker).fill_null(0) + pl.col(right_marker).fill_null(0) * 2)
        .cast(pl.Int8)
        .alias("membership")
    ).drop(left_marker, right_marker)

    # Coalesce join keys and drop suffixed copies
    for jc in join_columns:
        jc_right = f"{jc}{df2_suffix}"
        if jc_right in m_df.columns:
            m_df = m_df.with_columns(pl.coalesce(jc, jc_right).alias(jc)).drop(jc_right)

    # Build unified column order: df1's columns first, then df2-only columns
    col_order = df1.columns.copy()
    for col in df2.columns:
        if col not in col_order:
            col_order.append(col)

    # Compute diff stats from key-aligned rows (rows present on both sides)
    eqs = {}
    both_mask = m_df["membership"] == 3
    m_both = m_df.filter(both_mask)
    for col in col_order:
        if col in join_columns:
            eqs[col] = {"diff_count": "join_key"}
        elif col in df1.columns and col in df2.columns:
            df2_col = f"{col}{df2_suffix}"
            if df2_col in m_df.columns:
                # ne_missing treats null == null as equal, so a null on both
                # sides is not counted as a difference
                eqs[col] = {
                    "diff_count": int(
                        m_both.select(pl.col(col).ne_missing(pl.col(df2_col)).sum()).item()
                    )
                }
            else:
                eqs[col] = {"diff_count": 0}
        else:
            # Column exists on only one side; record which side owns it
            eqs[col] = {"diff_count": df1_name if col in df1.columns else df2_name}

    column_config_overrides = {}
    eq_map = ["pink", "#73ae80", "#90b2b3", "#6c83b5"]

    column_config_overrides["membership"] = {"merge_rule": "hidden"}

    both_columns = [c for c in m_df.columns if c.endswith(df2_suffix)]
    for b_col in both_columns:
        a_col = b_col.removesuffix(df2_suffix)
        eq_col = f"{a_col}|eq"
        # eq_col encodes equality (x4) plus membership, yielding a small
        # categorical code the color map below can key on
        m_df = m_df.with_columns(
            (pl.col(a_col).eq_missing(pl.col(b_col)).cast(pl.Int8) * 4 + pl.col("membership"))
            .alias(eq_col)
        )

        column_config_overrides[b_col] = {"merge_rule": "hidden"}
        column_config_overrides[eq_col] = {"merge_rule": "hidden"}
        column_config_overrides[a_col] = {
            "tooltip_config": {"tooltip_type": "simple", "val_column": b_col},
            "color_map_config": {
                "color_rule": "color_categorical",
                "map_name": eq_map,
                "val_column": eq_col,
            },
        }

    # Join-key columns are colored by row provenance instead of equality
    for jc in join_columns:
        column_config_overrides[jc] = {
            "color_map_config": {
                "color_rule": "color_categorical",
                "map_name": eq_map,
                "val_column": "membership",
            }
        }

    return m_df, column_config_overrides, eqs
144 changes: 144 additions & 0 deletions tests/unit/polars_compare_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
import polars as pl
import pytest

from buckaroo.polars_compare import col_join_dfs


def test_single_join_key():
    """A comparison joined on one key column reports the right diff stats."""
    left = pl.DataFrame({"id": [1, 2, 3], "val": [10, 20, 30]})
    right = pl.DataFrame({"id": [1, 2, 3], "val": [10, 25, 30]})

    merged, _overrides, eqs = col_join_dfs(left, right, join_columns=["id"], how="outer")

    assert "membership" in merged.columns
    assert (merged["membership"] == 3).all()
    assert eqs["id"]["diff_count"] == "join_key"
    assert eqs["val"]["diff_count"] == 1


def test_multi_key_join():
    """A composite join key behaves just like a single-column key."""
    keys = {"account_id": [1, 1, 2], "as_of_date": ["2024-01", "2024-02", "2024-01"]}
    left = pl.DataFrame({**keys, "amount": [100, 200, 300]})
    right = pl.DataFrame({**keys, "amount": [100, 250, 300]})

    merged, overrides, eqs = col_join_dfs(
        left, right, join_columns=["account_id", "as_of_date"], how="outer"
    )

    assert merged.height == 3
    assert (merged["membership"] == 3).all()
    assert eqs["amount"]["diff_count"] == 1
    for key in ("account_id", "as_of_date"):
        assert eqs[key]["diff_count"] == "join_key"
        assert key in overrides


def test_outer_join_membership():
    """Unmatched rows are tagged with the side they came from."""
    left = pl.DataFrame({"id": [1, 2, 3], "val": [10, 20, 30]})
    right = pl.DataFrame({"id": [2, 3, 4], "val": [20, 30, 40]})

    merged, _overrides, _eqs = col_join_dfs(left, right, join_columns=["id"], how="outer")

    assert merged.height == 4
    # 1 => df1 only, 2 => df2 only, 3 => both
    expected = {1: 1, 2: 3, 3: 3, 4: 2}
    actual = dict(zip(merged["id"].to_list(), merged["membership"].to_list()))
    assert actual == expected


def test_reordered_rows():
    """Row order is irrelevant: key-aligned rows compare as equal."""
    left = pl.DataFrame({"id": [1, 2, 3], "val": [10, 20, 30]})
    right = pl.DataFrame({"id": [3, 1, 2], "val": [30, 10, 20]})

    merged, _overrides, eqs = col_join_dfs(left, right, join_columns=["id"], how="outer")

    assert (merged["membership"] == 3).all()
    assert eqs["val"]["diff_count"] == 0


def test_one_sided_extra_columns():
    """A column present on only one side is labeled with that side's name."""
    left = pl.DataFrame({"id": [1, 2], "x": [10, 20]})
    right = pl.DataFrame({"id": [1, 2], "y": [30, 40]})

    _merged, _overrides, eqs = col_join_dfs(left, right, join_columns=["id"], how="outer")

    assert eqs["x"]["diff_count"] == "df_1"
    assert eqs["y"]["diff_count"] == "df_2"


def test_string_join_columns_normalized():
    """Passing join_columns as a bare string behaves like a one-item list."""
    left = pl.DataFrame({"key": [1, 2], "val": [10, 20]})
    right = pl.DataFrame({"key": [1, 2], "val": [10, 25]})

    _merged, _overrides, eqs = col_join_dfs(left, right, join_columns="key", how="inner")

    assert eqs["val"]["diff_count"] == 1


def test_sentinel_column_rejected():
    """Column names containing the '|df2' sentinel are refused up front."""
    left = pl.DataFrame({"id": [1], "bad|df2": [10]})
    right = pl.DataFrame({"id": [1], "val": [20]})

    with pytest.raises(ValueError, match="\\|df2"):
        col_join_dfs(left, right, join_columns=["id"], how="outer")


def test_inner_join():
    """An inner join drops unmatched rows entirely."""
    left = pl.DataFrame({"id": [1, 2, 3], "val": [10, 20, 30]})
    right = pl.DataFrame({"id": [2, 3, 4], "val": [20, 35, 40]})

    merged, _overrides, eqs = col_join_dfs(left, right, join_columns=["id"], how="inner")

    assert merged.height == 2
    assert (merged["membership"] == 3).all()
    assert eqs["val"]["diff_count"] == 1


def test_null_values_in_data():
    """Nulls in value columns are handled and one-sided nulls count as diffs."""
    left = pl.DataFrame({"id": [1, 2, 3], "val": [None, 20, None]})
    right = pl.DataFrame({"id": [1, 2, 3], "val": [None, None, 30]})

    merged, _overrides, eqs = col_join_dfs(left, right, join_columns=["id"], how="outer")

    assert (merged["membership"] == 3).all()
    assert eqs["val"]["diff_count"] >= 2


def test_duplicate_join_keys_rejected():
    """Non-unique join keys on either side raise ValueError."""
    clean = pl.DataFrame({"id": [1, 2, 3], "val": [10, 20, 30]})
    duped = pl.DataFrame({"id": [1, 1, 2], "val": [10, 20, 30]})

    # duplicates on the left side
    with pytest.raises(ValueError, match="Duplicate join keys"):
        col_join_dfs(duped, clean, join_columns=["id"], how="outer")

    # duplicates on the right side
    with pytest.raises(ValueError, match="Duplicate join keys"):
        col_join_dfs(clean, duped, join_columns=["id"], how="outer")


def test_how_outer_alias():
    """'outer' and 'full' are interchangeable join types."""
    left = pl.DataFrame({"id": [1, 2], "val": [10, 20]})
    right = pl.DataFrame({"id": [2, 3], "val": [20, 30]})

    via_outer, _, _ = col_join_dfs(left, right, join_columns=["id"], how="outer")
    via_full, _, _ = col_join_dfs(left, right, join_columns=["id"], how="full")

    assert via_outer.height == via_full.height == 3
Loading