44This module provides a high-level Python client for Altertable.
55"""
66
7- from collections .abc import Mapping
8- from typing import Any , Optional
7+ from collections .abc import Mapping , Sequence
8+ from typing import Any , Optional , Union
99
1010import pyarrow as pa
1111import pyarrow .flight as flight
@@ -271,7 +271,15 @@ def prepare(
271271 # Parse result
272272 result = sql_pb2 .ActionCreatePreparedStatementResult ()
273273 _unpack_command (results [0 ].body .to_pybytes (), result )
274- return PreparedStatement (self ._client , result .prepared_statement_handle )
274+
275+ # Extract parameter schema if available
276+ parameter_schema = None
277+ if result .parameter_schema :
278+ parameter_schema = pa .ipc .read_schema (pa .py_buffer (result .parameter_schema ))
279+
280+ return PreparedStatement (
281+ self ._client , result .prepared_statement_handle , parameter_schema = parameter_schema
282+ )
275283
276284 def get_catalogs (self ) -> flight .FlightStreamReader :
277285 """
@@ -445,42 +453,65 @@ class PreparedStatement:
445453 Prepared statements can be executed multiple times with different parameters.
446454 """
447455
448- def __init__ (self , client : flight .FlightClient , handle : bytes ):
456+ def __init__ (
457+ self ,
458+ client : flight .FlightClient ,
459+ handle : bytes ,
460+ parameter_schema : Optional [pa .Schema ] = None ,
461+ ):
449462 """
450463 Initialize a prepared statement.
451464
452465 Args:
453466 client: FlightClient instance.
454467 handle: Prepared statement handle from server.
468+ parameter_schema: Optional parameter schema for the prepared statement.
455469 """
456470 self ._client = client
457471 self ._handle = handle
472+ self ._parameter_schema = parameter_schema
458473
459474 def query (
460475 self ,
461476 * ,
462- parameters : Optional [Mapping [str , Any ]] = None ,
477+ parameters : Optional [
478+ Union [pa .Table , pa .RecordBatch , Mapping [str , Any ], Sequence [Any ]]
479+ ] = None ,
463480 ) -> flight .FlightStreamReader :
464481 """
465482 Execute the prepared statement query.
466483
467484 Args:
468- parameters: Optional RecordBatch containing parameter values.
485+ parameters: Optional parameters for the query. Can be:
486+ - pyarrow.Table: A table of parameter values
487+ - pyarrow.RecordBatch: A batch of parameter values
488+ - Mapping[str, Any]: A dictionary mapping parameter names to values
489+ - Sequence[Any]: A list of positional parameter values
469490
470491 Returns:
471492 FlightStreamReader with query results.
493+
494+ Example:
495+ >>> # Using a dictionary
496+ >>> stmt.query(parameters={"id": 42, "name": "Alice"})
497+
498+ >>> # Using a list
499+ >>> stmt.query(parameters=[42, "Alice"])
500+
501+ >>> # Using a RecordBatch
502+ >>> batch = pa.record_batch({"id": [42], "name": ["Alice"]})
503+ >>> stmt.query(parameters=batch)
472504 """
473505 cmd = sql_pb2 .CommandPreparedStatementQuery ()
474506 cmd .prepared_statement_handle = self ._handle
475507
476508 descriptor = flight .FlightDescriptor .for_command (_pack_command (cmd ))
477509 info = self ._client .get_flight_info (descriptor )
478510
479- # If parameters are provided, send them via DoPut
480- if parameters :
481- record_batch = pa .record_batch ({key : [value ] for (key , value ) in parameters .items ()})
482- writer , _ = self ._client .do_put (descriptor , record_batch .schema )
483- writer .write_batch (record_batch )
511+ if parameters is not None :
512+ as_pyarrow = self ._get_parameter_as_pyarrow (parameters )
513+ writer , _ = self ._client .do_put (descriptor , as_pyarrow .schema )
514+ writer .write (as_pyarrow )
484515 writer .close ()
485516
486517 return self ._client .do_get (info .endpoints [0 ].ticket )
@@ -501,3 +532,36 @@ def __enter__(self) -> "PreparedStatement":
501532 def __exit__ (self , exc_type , exc_val , exc_tb ) -> None :
502533 """Context manager exit."""
503534 self .close ()
535+
536+ def _get_parameter_as_pyarrow (
537+ self , parameters : Union [pa .Table , pa .RecordBatch , Mapping [str , Any ], Sequence [Any ]]
538+ ) -> Union [pa .Table , pa .RecordBatch ]:
539+ if isinstance (parameters , pa .Table ):
540+ return parameters
541+ elif isinstance (parameters , pa .RecordBatch ):
542+ return parameters
543+ elif isinstance (parameters , Mapping ):
544+ return pa .record_batch ({key : [value ] for (key , value ) in parameters .items ()})
545+ elif isinstance (parameters , Sequence ):
546+ if self ._parameter_schema is None :
547+ raise ValueError (
548+ "Cannot use positional parameters without parameter schema. "
549+ "Use a dictionary (Mapping[str, Any]) instead."
550+ )
551+
552+ # Create record batch with positional parameters
553+ if len (parameters ) != len (self ._parameter_schema ):
554+ raise ValueError (
555+ f"Expected { len (self ._parameter_schema )} parameters, "
556+ f"but got { len (parameters )} "
557+ )
558+ param_dict = {
559+ field .name : [value ] for field , value in zip (self ._parameter_schema , parameters )
560+ }
561+
562+ return pa .record_batch (param_dict )
563+ else :
564+ raise TypeError (
565+ f"Unsupported parameter type: { type (parameters )} . "
566+ "Expected Table, RecordBatch, Mapping, or Sequence."
567+ )
0 commit comments