diff --git a/python/benchmarks/bench_eval_type.py b/python/benchmarks/bench_eval_type.py
index c75e4490d1ed6..d053623e2373e 100644
--- a/python/benchmarks/bench_eval_type.py
+++ b/python/benchmarks/bench_eval_type.py
@@ -26,9 +26,11 @@
 import io
 import os
 import json
+import socket
 import struct
 import sys
 import tempfile
+import threading
 from typing import Any, Callable, Iterator
 
 import numpy as np
@@ -1933,3 +1935,211 @@ class WindowAggPandasUDFTimeBench(_WindowAggPandasBenchMixin, _TimeBenchBase):
 
 class WindowAggPandasUDFPeakmemBench(_WindowAggPandasBenchMixin, _PeakmemBenchBase):
     pass
+
+
+# -- SQL_TRANSFORM_WITH_STATE_PANDAS_UDF ---------------------------------------
+# Stateful streaming with Pandas. UDF signature is
+# ``(api_client, mode, key, pdfs)`` and returns ``Iterator[pandas.DataFrame]``.
+# The input wire stream is a single plain Arrow stream pre-sorted by the
+# grouping key column at offset 0; ``TransformWithStateInPandasSerializer``
+# chunks rows into one ``(mode, key, pdfs)`` tuple per group, then emits a
+# phantom ``PROCESS_TIMER`` and ``COMPLETE`` call with an empty pdf iterator.
+# ``StatefulProcessorApiClient.__init__`` opens a real TCP socket to the JVM
+# state server; the stub listener below satisfies that connect. The benchmark
+# UDFs never invoke any state API method, so no protocol exchange is needed.
+
+
+class _StubStateServer:
+    """Stub TCP listener so ``StatefulProcessorApiClient`` init succeeds.
+
+    One instance per benchmark process; the port is reused across all scenarios
+    and ASV iterations. The accept loop stashes connections to keep them open
+    until the worker process tears them down (the worker never closes its end
+    explicitly, but Python GCs the socket on ``main`` return).
+
+    The listener only accepts connections; it never reads from or writes to
+    them, and nothing in this class closes the accepted sockets.
+    """
+
+    # Lazily created process-wide instance; populated by ``get_port``.
+    _instance: "_StubStateServer | None" = None
+
+    @classmethod
+    def get_port(cls) -> int:
+        """Return the listener's port, creating the shared instance on first use."""
+        # NOTE(review): the lazy init takes no lock; assumes callers are single-threaded.
+        if cls._instance is None:
+            cls._instance = cls()
+        return cls._instance.port
+
+    def __init__(self) -> None:
+        """Bind a loopback listener and start the background accept thread."""
+        self._sock = socket.socket()
+        # Port 0 lets the OS pick a free port; the chosen one is read back below.
+        self._sock.bind(("127.0.0.1", 0))
+        self._sock.listen(128)
+        self.port = self._sock.getsockname()[1]
+        # Holds every accepted connection so the sockets are not garbage-collected.
+        self._connections: list[socket.socket] = []
+        # Daemon thread, so it does not block interpreter exit.
+        self._thread = threading.Thread(target=self._accept_loop, daemon=True)
+        self._thread.start()
+
+    def _accept_loop(self) -> None:
+        """Accept connections until ``accept`` raises ``OSError``, retaining each one."""
+        while True:
+            try:
+                conn, _ = self._sock.accept()
+            except OSError:
+                break
+            self._connections.append(conn)
+
+
+class _TransformWithStatePandasBenchMixin:
+    """Provides ``_write_scenario`` for SQL_TRANSFORM_WITH_STATE_PANDAS_UDF.
+
+    Each scenario emits one plain Arrow stream pre-sorted by the leading int
+    key column. UDFs receive an iterator of value-only Pandas DataFrames per
+    group plus phantom ``PROCESS_TIMER``/``COMPLETE`` calls (empty iterator).
+    """
+
+    # Each scenario: (num_groups, rows_per_group, num_value_cols).
+    # Row counts are scaled so identity_udf (full pdf passthrough -> ~equal
+    # input and output volume) stays under ASV's 60s per-sample timeout.
+    _scenario_configs = {
+        "few_groups_sm": (50, 5_000, 5),
+        "few_groups_lg": (50, 50_000, 5),
+        "many_groups_sm": (2_000, 500, 5),
+        "many_groups_lg": (500, 2_000, 5),
+        "wide_cols": (200, 5_000, 20),
+    }
+
+    @staticmethod
+    def _build_scenario(name):
+        """Build a single TWS Pandas scenario.
+
+        ``name`` is a key of ``_scenario_configs``; an unknown name raises
+        ``KeyError``.
+
+        Returns ``(batches, schema)`` where ``batches`` is a plain list of Arrow
+        RecordBatches with rows pre-sorted by the leading int32 key column.
+        ``schema`` is the matching Spark ``StructType``: ``col_0`` is the key and
+        ``col_1`` .. ``col_N`` are value columns cycling through
+        ``MockDataFactory.NUMERIC_TYPES``.
+        """
+        # Fixed seed for repeatable data -- presumably the MockDataFactory
+        # generators draw from NumPy's global RNG; TODO confirm.
+        np.random.seed(42)
+        num_groups, rows_per_group, num_value_cols = (
+            _TransformWithStatePandasBenchMixin._scenario_configs[name]
+        )
+        total_rows = num_groups * rows_per_group
+        # Group ids 0..num_groups-1, each repeated rows_per_group times (already sorted).
+        key_array = pa.array(
+            np.repeat(np.arange(num_groups, dtype=np.int32), rows_per_group),
+            type=pa.int32(),
+        )
+        # Each pool entry is used as (array factory, Spark type): element 0 is called
+        # with the row count here, element 1 becomes the StructField type below.
+        value_pool = MockDataFactory.NUMERIC_TYPES
+        value_arrays = [
+            value_pool[i % len(value_pool)][0](total_rows) for i in range(num_value_cols)
+        ]
+        names = ["col_0"] + [f"col_{i + 1}" for i in range(num_value_cols)]
+        full_batch = pa.RecordBatch.from_arrays([key_array] + value_arrays, names=names)
+        # Split the single large batch into slices of at most MAX_RECORDS_PER_BATCH rows.
+        batch_size = MockDataFactory.MAX_RECORDS_PER_BATCH
+        batches = [
+            full_batch.slice(offset, min(batch_size, total_rows - offset))
+            for offset in range(0, total_rows, batch_size)
+        ]
+        schema = StructType(
+            [StructField("col_0", IntegerType())]
+            + [
+                StructField(f"col_{i + 1}", value_pool[i % len(value_pool)][1])
+                for i in range(num_value_cols)
+            ]
+        )
+        return batches, schema
+
+    # The three UDFs below take no ``self``: they are stored as plain functions in
+    # ``_udfs``, and ``_write_scenario`` fetches them through that dict.
+    def _tws_pandas_identity(api_client, mode, key, pdfs):
+        """Yield the input DataFrames unchanged for PROCESS_DATA; nothing otherwise."""
+        from pyspark.sql.streaming.stateful_processor_util import (
+            TransformWithStateInPandasFuncMode,
+        )
+
+        if mode == TransformWithStateInPandasFuncMode.PROCESS_DATA:
+            yield from pdfs
+
+    def _tws_pandas_sort(api_client, mode, key, pdfs):
+        """Yield each input DataFrame sorted by its first column (PROCESS_DATA only)."""
+        from pyspark.sql.streaming.stateful_processor_util import (
+            TransformWithStateInPandasFuncMode,
+        )
+
+        if mode == TransformWithStateInPandasFuncMode.PROCESS_DATA:
+            for pdf in pdfs:
+                yield pdf.sort_values(pdf.columns[0])
+
+    def _tws_pandas_count(api_client, mode, key, pdfs):
+        """Yield one single-row DataFrame with the group's row count (PROCESS_DATA only)."""
+        import pandas as pd
+        from pyspark.sql.streaming.stateful_processor_util import (
+            TransformWithStateInPandasFuncMode,
+        )
+
+        if mode == TransformWithStateInPandasFuncMode.PROCESS_DATA:
+            total = sum(len(pdf) for pdf in pdfs)
+            yield pd.DataFrame({"col_1": [total]})
+
+    # ret_type=None means "use all value columns of the input schema".
+    _udfs = {
+        "identity_udf": (_tws_pandas_identity, None),
+        "sort_udf": (_tws_pandas_sort, None),
+        "count_udf": (_tws_pandas_count, StructType([StructField("col_1", IntegerType())])),
+    }
+    # ASV parameter grid: scenario names and UDF names, labelled by ``param_names``.
+    params = [list(_scenario_configs), list(_udfs)]
+    param_names = ["scenario", "udf"]
+
+    # Number of leading schema columns that form the grouping key (``col_0``).
+    _NUM_KEY_COLS = 1
+
+    def _write_scenario(self, scenario, udf_name, buf):
+        """Write the worker input for one ``(scenario, udf_name)`` pair into ``buf``.
+
+        ``scenario`` is a key of ``_scenario_configs`` and ``udf_name`` a key of
+        ``_udfs``. ``buf`` is handed to ``MockProtocolWriter.write_worker_input``
+        unchanged.
+        """
+        batches, schema = self._build_scenario(scenario)
+        udf_func, ret_type = self._udfs[udf_name]
+        if ret_type is None:
+            # Default return type: every column after the key columns.
+            ret_type = StructType(schema.fields[self._NUM_KEY_COLS :])
+        n_value_cols = len(schema.fields) - self._NUM_KEY_COLS
+        arg_offsets = MockUDFFactory.make_grouped_arg_offsets(self._NUM_KEY_COLS, n_value_cols)
+        grouping_key_schema = StructType(schema.fields[: self._NUM_KEY_COLS])
+        MockProtocolWriter.write_worker_input(
+            PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF,
+            lambda b: MockProtocolWriter.write_udf_payload(udf_func, ret_type, arg_offsets, b),
+            lambda b: MockProtocolWriter.write_data_payload(iter(batches), b),
+            buf,
+            eval_conf={
+                # Points the worker at the stub listener instead of a JVM state server.
+                "state_server_socket_port": str(_StubStateServer.get_port()),
+                "grouping_key_schema": grouping_key_schema.json(),
+            },
+        )
+
+
+class TransformWithStatePandasUDFTimeBench(_TransformWithStatePandasBenchMixin, _TimeBenchBase):
+    pass
+
+
+class TransformWithStatePandasUDFPeakmemBench(
+    _TransformWithStatePandasBenchMixin, _PeakmemBenchBase
+):
+    pass