Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 172 additions & 0 deletions python/benchmarks/bench_eval_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,11 @@
import io
import os
import json
import socket
import struct
import sys
import tempfile
import threading
from typing import Any, Callable, Iterator

import numpy as np
Expand Down Expand Up @@ -1933,3 +1935,173 @@ class WindowAggPandasUDFTimeBench(_WindowAggPandasBenchMixin, _TimeBenchBase):

class WindowAggPandasUDFPeakmemBench(_WindowAggPandasBenchMixin, _PeakmemBenchBase):
    """Scenarios from ``_WindowAggPandasBenchMixin`` run on top of ``_PeakmemBenchBase``.

    Adds no behaviour of its own; everything is inherited from the two bases.
    """


# -- SQL_TRANSFORM_WITH_STATE_PANDAS_UDF ---------------------------------------
# Stateful streaming with Pandas. UDF signature is
# ``(api_client, mode, key, pdfs)`` and returns ``Iterator[pandas.DataFrame]``.
# The input wire stream is a single plain Arrow stream pre-sorted by the
# grouping key column at offset 0; ``TransformWithStateInPandasSerializer``
# chunks rows into one ``(mode, key, pdfs)`` tuple per group, then emits
# phantom ``PROCESS_TIMER`` and ``COMPLETE`` calls, each with an empty pdf iterator.
# ``StatefulProcessorApiClient.__init__`` opens a real TCP socket to the JVM
# state server; the stub listener below satisfies that connect. The benchmark
# UDFs never invoke any state API method, so no protocol exchange is needed.


class _StubStateServer:
    """Loopback TCP listener that lets ``StatefulProcessorApiClient`` connect.

    A single instance is created lazily through :meth:`get_port` and then
    shared, so every scenario and ASV iteration in the benchmark process talks
    to the same port. Accepted connections are only stored, never read from or
    written to; holding a reference keeps them open until the worker side goes
    away (the worker does not close its end explicitly, but Python GCs the
    socket when ``main`` returns).
    """

    # Process-wide singleton, populated on the first ``get_port`` call.
    _instance: "_StubStateServer | None" = None

    @classmethod
    def get_port(cls) -> int:
        """Return the listening port, starting the stub server on first use."""
        server = cls._instance
        if server is None:
            server = cls._instance = cls()
        return server.port

    def __init__(self) -> None:
        """Bind to an OS-assigned loopback port and start the accept thread."""
        listener = socket.socket()
        self._sock = listener
        # Port 0 asks the OS for any free port; the real number is read back below.
        listener.bind(("127.0.0.1", 0))
        listener.listen(128)
        self.port = listener.getsockname()[1]
        self._connections: list[socket.socket] = []
        # Daemon thread, so it never keeps the benchmark process alive on exit.
        self._thread = threading.Thread(target=self._accept_loop, daemon=True)
        self._thread.start()

    def _accept_loop(self) -> None:
        """Accept connections forever, stashing each one; stop on ``OSError``."""
        while True:
            try:
                peer, _addr = self._sock.accept()
            except OSError:
                return
            self._connections.append(peer)


class _TransformWithStatePandasBenchMixin:
    """Scenario writer for SQL_TRANSFORM_WITH_STATE_PANDAS_UDF benchmarks.

    Every scenario is serialized as one plain Arrow stream whose rows are
    already ordered by the leading int key column. For each group a UDF is
    handed an iterator of value-only Pandas DataFrames; it is also invoked
    with an empty iterator for the phantom ``PROCESS_TIMER``/``COMPLETE`` calls.
    """

    # Scenario name -> (num_groups, rows_per_group, num_value_cols).
    # Sizes are picked so that identity_udf, whose output volume is roughly
    # equal to its input volume (full pdf passthrough), stays below ASV's 60s
    # per-sample timeout.
    _scenario_configs = {
        "few_groups_sm": (50, 5_000, 5),
        "few_groups_lg": (50, 50_000, 5),
        "many_groups_sm": (2_000, 500, 5),
        "many_groups_lg": (500, 2_000, 5),
        "wide_cols": (200, 5_000, 20),
    }

    @staticmethod
    def _build_scenario(name):
        """Materialize the input data for the scenario called ``name``.

        Returns a ``(batches, schema)`` pair: ``batches`` is a plain list of
        Arrow RecordBatches whose rows are ordered by the leading int32 key
        column, and ``schema`` is the matching Spark ``StructType``.
        """
        # Seed NumPy's global RNG before any value column is generated.
        np.random.seed(42)
        config = _TransformWithStatePandasBenchMixin._scenario_configs[name]
        num_groups, rows_per_group, num_value_cols = config
        total_rows = num_groups * rows_per_group

        # Key column: each group id repeated ``rows_per_group`` times, in order.
        group_ids = np.arange(num_groups, dtype=np.int32)
        key_array = pa.array(np.repeat(group_ids, rows_per_group), type=pa.int32())

        # NOTE(review): a reviewer asked whether non-numeric types, and maybe
        # nested ones (arrays, structs, etc.), should be covered here as well.
        value_pool = MockDataFactory.NUMERIC_TYPES
        pool_size = len(value_pool)
        # Walk the pool round-robin. For each entry, item 0 is called with the
        # row count to build the column and item 1 is used as the field type.
        chosen = [value_pool[i % pool_size] for i in range(num_value_cols)]
        value_arrays = [entry[0](total_rows) for entry in chosen]

        # "col_0" is the key; "col_1".."col_N" are the value columns.
        names = [f"col_{i}" for i in range(num_value_cols + 1)]
        full_batch = pa.RecordBatch.from_arrays([key_array, *value_arrays], names=names)

        # Cut the single big batch into slices of at most ``batch_size`` rows.
        batch_size = MockDataFactory.MAX_RECORDS_PER_BATCH
        batches = []
        for start in range(0, total_rows, batch_size):
            length = min(batch_size, total_rows - start)
            batches.append(full_batch.slice(start, length))

        fields = [StructField(names[0], IntegerType())]
        fields.extend(StructField(col, entry[1]) for col, entry in zip(names[1:], chosen))
        schema = StructType(fields)
        return batches, schema

    def _tws_pandas_identity(api_client, mode, key, pdfs):
        """Pass each input DataFrame through untouched for ``PROCESS_DATA`` calls."""
        from pyspark.sql.streaming.stateful_processor_util import (
            TransformWithStateInPandasFuncMode as FuncMode,
        )

        if mode != FuncMode.PROCESS_DATA:
            return
        yield from pdfs

    def _tws_pandas_sort(api_client, mode, key, pdfs):
        """Yield each input DataFrame sorted by its first column (``PROCESS_DATA`` only)."""
        from pyspark.sql.streaming.stateful_processor_util import (
            TransformWithStateInPandasFuncMode as FuncMode,
        )

        if mode != FuncMode.PROCESS_DATA:
            return
        for frame in pdfs:
            first_col = frame.columns[0]
            yield frame.sort_values(first_col)

    def _tws_pandas_count(api_client, mode, key, pdfs):
        """Yield one single-row DataFrame with the group's row count (``PROCESS_DATA`` only)."""
        import pandas as pd
        from pyspark.sql.streaming.stateful_processor_util import (
            TransformWithStateInPandasFuncMode as FuncMode,
        )

        if mode != FuncMode.PROCESS_DATA:
            return
        total = sum(map(len, pdfs))
        yield pd.DataFrame({"col_1": [total]})

    # UDF name -> (function, return type). A return type of None means
    # "use all value columns of the input schema".
    _udfs = {
        "identity_udf": (_tws_pandas_identity, None),
        "sort_udf": (_tws_pandas_sort, None),
        "count_udf": (
            _tws_pandas_count,
            StructType([StructField("col_1", IntegerType())]),
        ),
    }
    params = [list(_scenario_configs), list(_udfs)]
    param_names = ["scenario", "udf"]

    # Number of leading columns in the schema that form the grouping key.
    _NUM_KEY_COLS = 1

    def _write_scenario(self, scenario, udf_name, buf):
        """Write the worker input for ``scenario`` paired with ``udf_name`` to ``buf``.

        Builds the scenario data, resolves the UDF and its return type, and
        hands both payload writers plus the eval conf to
        ``MockProtocolWriter.write_worker_input``.
        """
        num_keys = self._NUM_KEY_COLS
        batches, schema = self._build_scenario(scenario)
        udf_func, ret_type = self._udfs[udf_name]

        all_fields = schema.fields
        if ret_type is None:
            # NOTE(review): a reviewer pointed out that the keys are typically
            # included in the output schema for transform with state; here the
            # default return type holds the value columns only.
            ret_type = StructType(all_fields[num_keys:])

        n_value_cols = len(all_fields) - num_keys
        arg_offsets = MockUDFFactory.make_grouped_arg_offsets(num_keys, n_value_cols)
        grouping_key_schema = StructType(all_fields[:num_keys])

        def write_udf(b):
            return MockProtocolWriter.write_udf_payload(udf_func, ret_type, arg_offsets, b)

        def write_data(b):
            # A fresh iterator is created each time the data payload is written.
            return MockProtocolWriter.write_data_payload(iter(batches), b)

        eval_conf = {
            "state_server_socket_port": str(_StubStateServer.get_port()),
            "grouping_key_schema": grouping_key_schema.json(),
        }
        MockProtocolWriter.write_worker_input(
            PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF,
            write_udf,
            write_data,
            buf,
            eval_conf=eval_conf,
        )


class TransformWithStatePandasUDFTimeBench(_TransformWithStatePandasBenchMixin, _TimeBenchBase):
    """SQL_TRANSFORM_WITH_STATE_PANDAS_UDF scenarios run on top of ``_TimeBenchBase``.

    Adds no behaviour of its own; everything is inherited from the two bases.
    """


class TransformWithStatePandasUDFPeakmemBench(
    _TransformWithStatePandasBenchMixin, _PeakmemBenchBase
):
    """SQL_TRANSFORM_WITH_STATE_PANDAS_UDF scenarios run on top of ``_PeakmemBenchBase``.

    Adds no behaviour of its own; everything is inherited from the two bases.
    """