Skip to content

Commit 64e2b02

Browse files
More permissive statement parameters
1 parent 373fce7 commit 64e2b02

3 files changed

Lines changed: 113 additions & 16 deletions

File tree

examples/client_usage.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def example_prepared_statement():
9898
# Execute with different parameters
9999
for user_id in [1, 2, 3]:
100100
print(f"Fetching user {user_id}...")
101-
reader = stmt.query(parameters={"id": user_id})
101+
reader = stmt.query(parameters=[user_id])
102102
for batch in reader:
103103
print(batch.data.to_pandas())
104104

src/altertable_flightsql/client.py

Lines changed: 74 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
This module provides a high-level Python client for Altertable.
55
"""
66

7-
from collections.abc import Mapping
8-
from typing import Any, Optional
7+
from collections.abc import Mapping, Sequence
8+
from typing import Any, Optional, Union
99

1010
import pyarrow as pa
1111
import pyarrow.flight as flight
@@ -271,7 +271,17 @@ def prepare(
271271
# Parse result
272272
result = sql_pb2.ActionCreatePreparedStatementResult()
273273
_unpack_command(results[0].body.to_pybytes(), result)
274-
return PreparedStatement(self._client, result.prepared_statement_handle)
274+
275+
# Extract parameter schema if available
276+
parameter_schema = None
277+
if result.parameter_schema:
278+
parameter_schema = pa.ipc.read_schema(pa.py_buffer(result.parameter_schema))
279+
280+
return PreparedStatement(
281+
self._client,
282+
result.prepared_statement_handle,
283+
parameter_schema=parameter_schema
284+
)
275285

276286
def get_catalogs(self) -> flight.FlightStreamReader:
277287
"""
@@ -445,42 +455,63 @@ class PreparedStatement:
445455
Prepared statements can be executed multiple times with different parameters.
446456
"""
447457

448-
def __init__(self, client: flight.FlightClient, handle: bytes):
458+
def __init__(
459+
self,
460+
client: flight.FlightClient,
461+
handle: bytes,
462+
parameter_schema: Optional[pa.Schema] = None
463+
):
449464
"""
450465
Initialize a prepared statement.
451466
452467
Args:
453468
client: FlightClient instance.
454469
handle: Prepared statement handle from server.
470+
parameter_schema: Optional parameter schema for the prepared statement.
455471
"""
456472
self._client = client
457473
self._handle = handle
474+
self._parameter_schema = parameter_schema
458475

459476
def query(
460477
self,
461478
*,
462-
parameters: Optional[Mapping[str, Any]] = None,
479+
parameters: Optional[Union[pa.Table, pa.RecordBatch, Mapping[str, Any], Sequence[Any]]] = None,
463480
) -> flight.FlightStreamReader:
464481
"""
465482
Execute the prepared statement query.
466483
467484
Args:
468-
parameters: Optional RecordBatch containing parameter values.
485+
parameters: Optional parameters for the query. Can be:
486+
- pyarrow.Table: A table of parameter values
487+
- pyarrow.RecordBatch: A batch of parameter values
488+
- Mapping[str, Any]: A dictionary mapping parameter names to values
489+
- Sequence[Any]: A list of positional parameter values
469490
470491
Returns:
471492
FlightStreamReader with query results.
493+
494+
Example:
495+
>>> # Using a dictionary
496+
>>> stmt.query(parameters={"id": 42, "name": "Alice"})
497+
498+
>>> # Using a list
499+
>>> stmt.query(parameters=[42, "Alice"])
500+
501+
>>> # Using a RecordBatch
502+
>>> batch = pa.record_batch({"id": [42], "name": ["Alice"]})
503+
>>> stmt.query(parameters=batch)
472504
"""
473505
cmd = sql_pb2.CommandPreparedStatementQuery()
474506
cmd.prepared_statement_handle = self._handle
475507

476508
descriptor = flight.FlightDescriptor.for_command(_pack_command(cmd))
477509
info = self._client.get_flight_info(descriptor)
478510

479-
# If parameters are provided, send them via DoPut
480-
if parameters:
481-
record_batch = pa.record_batch({key: [value] for (key, value) in parameters.items()})
482-
writer, _ = self._client.do_put(descriptor, record_batch.schema)
483-
writer.write_batch(record_batch)
511+
if parameters is not None:
512+
as_pyarrow = self._get_parameter_as_pyarrow(parameters)
513+
writer, _ = self._client.do_put(descriptor, as_pyarrow.schema)
514+
writer.write(as_pyarrow)
484515
writer.close()
485516

486517
return self._client.do_get(info.endpoints[0].ticket)
@@ -501,3 +532,35 @@ def __enter__(self) -> "PreparedStatement":
501532
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
502533
"""Context manager exit."""
503534
self.close()
535+
536+
def _get_parameter_as_pyarrow(self, parameters: Union[pa.Table, pa.RecordBatch, Mapping[str, Any], Sequence[Any]]) -> Union[pa.Table, pa.RecordBatch]:
537+
if isinstance(parameters, pa.Table):
538+
return parameters
539+
elif isinstance(parameters, pa.RecordBatch):
540+
return parameters
541+
elif isinstance(parameters, Mapping):
542+
return pa.record_batch({key: [value] for (key, value) in parameters.items()})
543+
elif isinstance(parameters, Sequence):
544+
if self._parameter_schema is None:
545+
raise ValueError(
546+
"Cannot use positional parameters without parameter schema. "
547+
"Use a dictionary (Mapping[str, Any]) instead."
548+
)
549+
550+
# Create record batch with positional parameters
551+
if len(parameters) != len(self._parameter_schema):
552+
raise ValueError(
553+
f"Expected {len(self._parameter_schema)} parameters, "
554+
f"but got {len(parameters)}"
555+
)
556+
param_dict = {
557+
field.name: [value]
558+
for field, value in zip(self._parameter_schema, parameters)
559+
}
560+
561+
return pa.record_batch(param_dict)
562+
else:
563+
raise TypeError(
564+
f"Unsupported parameter type: {type(parameters)}. "
565+
"Expected Table, RecordBatch, Mapping, or Sequence."
566+
)

tests/test_queries.py

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
Tests basic query execution, updates, and prepared statements.
55
"""
66

7+
import pyarrow as pa
78
import pytest
89

910
from altertable_flightsql import Client
@@ -56,14 +57,47 @@ def test_empty_result_set(self, altertable_client: Client, test_table: TableInfo
5657
class TestPreparedStatements:
5758
"""Test prepared statement functionality."""
5859

59-
def test_prepare_and_execute(self, altertable_client: Client, test_table: TableInfo):
60-
"""Test creating and executing a prepared statement."""
60+
def test_prepare_with_dict_parameters(self, altertable_client: Client, test_table: TableInfo):
61+
"""Test prepared statement with dict parameters."""
6162
# Prepare statement
6263
with altertable_client.prepare(
63-
f"SELECT * FROM {test_table.full_name} WHERE id = $id"
64+
f"SELECT * FROM {test_table.full_name} WHERE id = $id AND value >= $min_value"
6465
) as stmt:
6566
# Execute prepared statement
66-
reader = stmt.query(parameters={"id": 1})
67+
reader = stmt.query(parameters={"id": 1, "min_value": 100})
68+
table = reader.read_all()
69+
assert table.num_rows > 0
70+
71+
def test_prepare_with_list_parameters(self, altertable_client: Client, test_table: TableInfo):
72+
"""Test prepared statement with list parameters."""
73+
with altertable_client.prepare(
74+
f"SELECT * FROM {test_table.full_name} WHERE id = ? AND value >= ?"
75+
) as stmt:
76+
reader = stmt.query(parameters=[1, 100])
77+
table = reader.read_all()
78+
assert table.num_rows >= 0
79+
80+
def test_prepare_with_record_batch_parameters(
81+
self, altertable_client: Client, test_table: TableInfo
82+
):
83+
"""Test prepared statement with RecordBatch parameters."""
84+
with altertable_client.prepare(
85+
f"SELECT * FROM {test_table.full_name} WHERE id = $id AND value >= $min_value"
86+
) as stmt:
87+
# Create a RecordBatch with parameters
88+
batch = pa.record_batch({"id": [1], "min_value": [100]})
89+
reader = stmt.query(parameters=batch)
90+
table = reader.read_all()
91+
assert table.num_rows > 0
92+
93+
def test_prepare_with_table_parameters(self, altertable_client: Client, test_table: TableInfo):
94+
"""Test prepared statement with Table parameters."""
95+
with altertable_client.prepare(
96+
f"SELECT * FROM {test_table.full_name} WHERE id = $id AND value >= $min_value"
97+
) as stmt:
98+
# Create a Table with parameters
99+
param_table = pa.table({"id": [1], "min_value": [100]})
100+
reader = stmt.query(parameters=param_table)
67101
table = reader.read_all()
68102
assert table.num_rows > 0
69103

0 commit comments

Comments
 (0)