44This module provides a high-level Python client for Altertable.
55"""
66
7- from collections .abc import Mapping
8- from typing import Any , Optional
7+ from collections .abc import Mapping , Sequence
8+ from typing import Any , Optional , Union
99
1010import pyarrow as pa
1111import pyarrow .flight as flight
@@ -271,7 +271,17 @@ def prepare(
271271 # Parse result
272272 result = sql_pb2 .ActionCreatePreparedStatementResult ()
273273 _unpack_command (results [0 ].body .to_pybytes (), result )
274- return PreparedStatement (self ._client , result .prepared_statement_handle )
274+
275+ # Extract parameter schema if available
276+ parameter_schema = None
277+ if result .parameter_schema :
278+ parameter_schema = pa .ipc .read_schema (pa .py_buffer (result .parameter_schema ))
279+
280+ return PreparedStatement (
281+ self ._client ,
282+ result .prepared_statement_handle ,
283+ parameter_schema = parameter_schema
284+ )
275285
276286 def get_catalogs (self ) -> flight .FlightStreamReader :
277287 """
@@ -445,42 +455,63 @@ class PreparedStatement:
445455 Prepared statements can be executed multiple times with different parameters.
446456 """
447457
448- def __init__ (self , client : flight .FlightClient , handle : bytes ):
458+ def __init__ (
459+ self ,
460+ client : flight .FlightClient ,
461+ handle : bytes ,
462+ parameter_schema : Optional [pa .Schema ] = None
463+ ):
449464 """
450465 Initialize a prepared statement.
451466
452467 Args:
453468 client: FlightClient instance.
454469 handle: Prepared statement handle from server.
470+ parameter_schema: Optional parameter schema for the prepared statement.
455471 """
456472 self ._client = client
457473 self ._handle = handle
474+ self ._parameter_schema = parameter_schema
458475
459476 def query (
460477 self ,
461478 * ,
462- parameters : Optional [Mapping [str , Any ]] = None ,
479+ parameters : Optional [Union [ pa . Table , pa . RecordBatch , Mapping [str , Any ], Sequence [ Any ] ]] = None ,
463480 ) -> flight .FlightStreamReader :
464481 """
465482 Execute the prepared statement query.
466483
467484 Args:
468- parameters: Optional RecordBatch containing parameter values.
485+ parameters: Optional parameters for the query. Can be:
486+ - pyarrow.Table: A table of parameter values
487+ - pyarrow.RecordBatch: A batch of parameter values
488+ - Mapping[str, Any]: A dictionary mapping parameter names to values
489+ - Sequence[Any]: A list of positional parameter values
469490
470491 Returns:
471492 FlightStreamReader with query results.
493+
494+ Example:
495+ >>> # Using a dictionary
496+ >>> stmt.query(parameters={"id": 42, "name": "Alice"})
497+
498+ >>> # Using a list
499+ >>> stmt.query(parameters=[42, "Alice"])
500+
501+ >>> # Using a RecordBatch
502+ >>> batch = pa.record_batch({"id": [42], "name": ["Alice"]})
503+ >>> stmt.query(parameters=batch)
472504 """
473505 cmd = sql_pb2 .CommandPreparedStatementQuery ()
474506 cmd .prepared_statement_handle = self ._handle
475507
476508 descriptor = flight .FlightDescriptor .for_command (_pack_command (cmd ))
477509 info = self ._client .get_flight_info (descriptor )
478510
479- # If parameters are provided, send them via DoPut
480- if parameters :
481- record_batch = pa .record_batch ({key : [value ] for (key , value ) in parameters .items ()})
482- writer , _ = self ._client .do_put (descriptor , record_batch .schema )
483- writer .write_batch (record_batch )
511+ if parameters is not None :
512+ as_pyarrow = self ._get_parameter_as_pyarrow (parameters )
513+ writer , _ = self ._client .do_put (descriptor , as_pyarrow .schema )
514+ writer .write (as_pyarrow )
484515 writer .close ()
485516
486517 return self ._client .do_get (info .endpoints [0 ].ticket )
@@ -501,3 +532,35 @@ def __enter__(self) -> "PreparedStatement":
501532 def __exit__ (self , exc_type , exc_val , exc_tb ) -> None :
502533 """Context manager exit."""
503534 self .close ()
535+
536+ def _get_parameter_as_pyarrow (self , parameters : Union [pa .Table , pa .RecordBatch , Mapping [str , Any ], Sequence [Any ]]) -> Union [pa .Table , pa .RecordBatch ]:
537+ if isinstance (parameters , pa .Table ):
538+ return parameters
539+ elif isinstance (parameters , pa .RecordBatch ):
540+ return parameters
541+ elif isinstance (parameters , Mapping ):
542+ return pa .record_batch ({key : [value ] for (key , value ) in parameters .items ()})
543+ elif isinstance (parameters , Sequence ):
544+ if self ._parameter_schema is None :
545+ raise ValueError (
546+ "Cannot use positional parameters without parameter schema. "
547+ "Use a dictionary (Mapping[str, Any]) instead."
548+ )
549+
550+ # Create record batch with positional parameters
551+ if len (parameters ) != len (self ._parameter_schema ):
552+ raise ValueError (
553+ f"Expected { len (self ._parameter_schema )} parameters, "
554+ f"but got { len (parameters )} "
555+ )
556+ param_dict = {
557+ field .name : [value ]
558+ for field , value in zip (self ._parameter_schema , parameters )
559+ }
560+
561+ return pa .record_batch (param_dict )
562+ else :
563+ raise TypeError (
564+ f"Unsupported parameter type: { type (parameters )} . "
565+ "Expected Table, RecordBatch, Mapping, or Sequence."
566+ )
0 commit comments