Skip to content

Commit f6ba18c

Browse files
paddymul and claude
committed
fix: drop _to_python_native, let pyarrow handle numpy scalars natively
- pyarrow already handles numpy scalars (float64, int64, bool_, nan)
- Replace _to_python_native with _is_complex_for_parquet check
- Fix pd.Series.to_dict() crash on unhashable values (fall back to to_list)
- Update _resolve_all_stats test helpers to handle wide-column format

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 4e5cc8e commit f6ba18c

4 files changed

Lines changed: 111 additions & 45 deletions

File tree

buckaroo/serialization_utils.py

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -268,24 +268,16 @@ def _json_encode_cell(val):
268268
return json.dumps(_make_json_safe(val), default=str)
269269

270270

271-
def _to_python_native(val):
272-
"""Convert numpy scalars to Python builtins for pyarrow."""
271+
def _is_complex_for_parquet(val):
272+
"""Return True if val needs JSON encoding for parquet (not a scalar)."""
273273
import numpy as np
274-
if isinstance(val, np.bool_):
275-
return bool(val)
276-
if isinstance(val, np.integer):
277-
return int(val)
278-
if isinstance(val, np.floating):
279-
if np.isnan(val):
280-
return None
281-
return float(val)
282-
if isinstance(val, float) and np.isnan(val):
283-
return None
284-
if isinstance(val, np.ndarray):
285-
return val.tolist()
286274
if isinstance(val, pd.Series):
287-
return val.to_dict()
288-
return val
275+
return True
276+
if isinstance(val, np.ndarray):
277+
return True
278+
if isinstance(val, (list, dict, tuple)):
279+
return True
280+
return False
289281

290282

291283
def sd_to_parquet_b64(sd: Dict[str, Any]) -> Dict[str, str]:
@@ -310,8 +302,12 @@ def sd_to_parquet_b64(sd: Dict[str, Any]) -> Dict[str, str]:
310302
continue
311303
for stat_name, val in stats.items():
312304
parquet_col = f"{short_col}__{stat_name}"
313-
val = _to_python_native(val)
314-
if isinstance(val, (list, dict, tuple)):
305+
if isinstance(val, pd.Series):
306+
try:
307+
val = val.to_dict()
308+
except TypeError:
309+
val = val.to_list()
310+
if _is_complex_for_parquet(val):
315311
val = json.dumps(_make_json_safe(val), default=str)
316312
wide_data[parquet_col] = [val]
317313

tests/unit/lazy_infinite_polars_widget_test.py

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import polars as pl
2-
import pandas as pd
32
import base64
43
from io import BytesIO
54
import json
@@ -13,14 +12,49 @@
1312

1413

1514
def _resolve_all_stats(all_stats):
16-
"""Resolve all_stats to a list of row dicts, whether it's JSON or parquet_b64."""
15+
"""Resolve all_stats to a list of row dicts, whether it's JSON or parquet_b64.
16+
17+
Handles both old row-based and new wide-column (col__stat) formats.
18+
"""
1719
if isinstance(all_stats, list):
1820
return all_stats
1921
if isinstance(all_stats, dict) and all_stats.get('format') == 'parquet_b64':
22+
import pyarrow.parquet as pq
2023
raw = base64.b64decode(all_stats['data'])
21-
df = pd.read_parquet(BytesIO(raw), engine='pyarrow')
24+
table = pq.read_table(BytesIO(raw))
25+
col_names = table.column_names
26+
27+
# Detect wide format: column names contain '__'
28+
if any('__' in c for c in col_names):
29+
row_dict = table.to_pydict()
30+
stat_cols = {}
31+
all_cols = set()
32+
for key in col_names:
33+
sep = key.index('__')
34+
col, stat = key[:sep], key[sep+2:]
35+
all_cols.add(col)
36+
if stat not in stat_cols:
37+
stat_cols[stat] = {}
38+
val = row_dict[key][0]
39+
if isinstance(val, str):
40+
try:
41+
parsed = json.loads(val)
42+
if isinstance(parsed, (list, dict)):
43+
val = parsed
44+
except (json.JSONDecodeError, ValueError):
45+
pass
46+
stat_cols[stat][col] = val
47+
rows = []
48+
for stat, cols in stat_cols.items():
49+
row = {'index': stat, 'level_0': stat}
50+
for c in sorted(all_cols):
51+
row[c] = cols.get(c)
52+
rows.append(row)
53+
return rows
54+
55+
# Old row-based format fallback
56+
df = table.to_pandas()
2257
rows = json.loads(df.to_json(orient='records'))
23-
# JSON-parse each cell (they were JSON-encoded on the Python side)
2458
parsed_rows = []
2559
for row in rows:
2660
parsed = {}

tests/unit/polars_basic_widget_test.py

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import polars as pl
66
from polars import functions as F
77
import numpy as np
8-
import pandas as pd
98
from buckaroo.pluggable_analysis_framework.polars_analysis_management import (
109
PolarsAnalysis, polars_produce_series_df)
1110
from buckaroo.pluggable_analysis_framework.col_analysis import (
@@ -19,12 +18,48 @@
1918

2019

2120
def _resolve_all_stats(all_stats):
22-
"""Resolve all_stats to a list of row dicts, whether it's JSON or parquet_b64."""
21+
"""Resolve all_stats to a list of row dicts, whether it's JSON or parquet_b64.
22+
23+
Handles both old row-based and new wide-column (col__stat) formats.
24+
"""
2325
if isinstance(all_stats, list):
2426
return all_stats
2527
if isinstance(all_stats, dict) and all_stats.get('format') == 'parquet_b64':
28+
import pyarrow.parquet as pq
2629
raw = base64.b64decode(all_stats['data'])
27-
df = pd.read_parquet(BytesIO(raw), engine='pyarrow')
30+
table = pq.read_table(BytesIO(raw))
31+
col_names = table.column_names
32+
33+
# Detect wide format: column names contain '__'
34+
if any('__' in c for c in col_names):
35+
row_dict = table.to_pydict()
36+
stat_cols = {} # stat -> {col -> value}
37+
all_cols = set()
38+
for key in col_names:
39+
sep = key.index('__')
40+
col, stat = key[:sep], key[sep+2:]
41+
all_cols.add(col)
42+
if stat not in stat_cols:
43+
stat_cols[stat] = {}
44+
val = row_dict[key][0]
45+
if isinstance(val, str):
46+
try:
47+
parsed = json.loads(val)
48+
if isinstance(parsed, (list, dict)):
49+
val = parsed
50+
except (json.JSONDecodeError, ValueError):
51+
pass
52+
stat_cols[stat][col] = val
53+
rows = []
54+
for stat, cols in stat_cols.items():
55+
row = {'index': stat, 'level_0': stat}
56+
for c in sorted(all_cols):
57+
row[c] = cols.get(c)
58+
rows.append(row)
59+
return rows
60+
61+
# Old row-based format fallback
62+
df = table.to_pandas()
2863
rows = json.loads(df.to_json(orient='records'))
2964
parsed_rows = []
3065
for row in rows:

tests/unit/test_sd_to_parquet_b64.py

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import pandas as pd
1212
import pyarrow.parquet as pq
1313

14-
from buckaroo.serialization_utils import sd_to_parquet_b64, _to_python_native
14+
from buckaroo.serialization_utils import sd_to_parquet_b64
1515

1616

1717
def _decode_parquet_b64(result):
@@ -144,14 +144,14 @@ def test_sd_to_parquet_b64_multiple_columns():
144144
assert row['b__dtype'] == ['int64']
145145

146146

147-
def test_sd_to_parquet_b64_nan_becomes_null():
148-
"""NaN values should become parquet nulls."""
147+
def test_sd_to_parquet_b64_nan_preserved():
148+
"""NaN values should survive the parquet round-trip."""
149149
sd = {'col': {'mean': np.nan, 'dtype': 'float64'}}
150150
result = sd_to_parquet_b64(sd)
151151
table = _decode_parquet_b64(result)
152152
row = table.to_pydict()
153153

154-
assert row['a__mean'] == [None]
154+
assert np.isnan(row['a__mean'][0])
155155
assert row['a__dtype'] == ['float64']
156156

157157

@@ -173,22 +173,23 @@ def test_sd_to_parquet_b64_value_counts_series():
173173
assert parsed == {'foo': 10, 'bar': 5}
174174

175175

176-
def test_to_python_native_conversions():
177-
assert _to_python_native(np.float64(3.14)) == 3.14
178-
assert isinstance(_to_python_native(np.float64(3.14)), float)
179-
180-
assert _to_python_native(np.int64(42)) == 42
181-
assert isinstance(_to_python_native(np.int64(42)), int)
182-
183-
assert _to_python_native(np.bool_(True)) is True
184-
assert isinstance(_to_python_native(np.bool_(True)), bool)
185-
186-
assert _to_python_native(np.nan) is None
187-
188-
arr = np.array([1, 2, 3])
189-
assert _to_python_native(arr) == [1, 2, 3]
176+
def test_numpy_scalars_handled_natively_by_pyarrow():
177+
"""pyarrow handles numpy scalars without manual conversion."""
178+
sd = {
179+
'col': {
180+
'mean': np.float64(3.14),
181+
'count': np.int64(42),
182+
'is_numeric': np.bool_(True),
183+
'nan_val': np.nan,
184+
},
185+
}
186+
result = sd_to_parquet_b64(sd)
187+
table = _decode_parquet_b64(result)
188+
row = table.to_pydict()
190189

191-
assert _to_python_native("hello") == "hello"
192-
assert _to_python_native(None) is None
190+
assert row['a__mean'] == [3.14]
191+
assert row['a__count'] == [42]
192+
assert row['a__is_numeric'] == [True]
193+
assert np.isnan(row['a__nan_val'][0])
193194

194195

0 commit comments

Comments (0)