@@ -57,35 +57,32 @@ class IngestIncrementalOptions:
5757class BearerAuthMiddleware (flight .ClientMiddleware ):
5858 """Client middleware that adds Bearer token authentication to all requests."""
5959
60- def __init__ (self , token : bytes ):
61- """
62- Initialize the middleware with a Bearer token.
63-
64- Args:
65- token: The Bearer token
66- """
67- self ._token = token
60+ def __init__ (self , factory : "BearerAuthMiddlewareFactory" ):
61+ self ._factory = factory
6862
6963 def sending_headers (self ):
7064 """A callback before headers are sent."""
71- return {b"authorization" : self ._token }
65+ headers = {}
66+
67+ if self ._factory .token :
68+ headers [b"authorization" ] = self ._factory .token
69+
70+ return headers
71+
72+ def received_headers (self , headers ):
73+ if token := headers .get ("authorization" ):
74+ self ._factory .token = token
7275
7376
7477class BearerAuthMiddlewareFactory (flight .ClientMiddlewareFactory ):
7578 """Factory for creating Bearer authentication middleware."""
7679
77- def __init__ (self , token : bytes ):
78- """
79- Initialize the factory with credentials.
80-
81- Args:
82- token: The Bearer token
83- """
84- self ._token = token
80+ def __init__ (self ):
81+ self .token = None
8582
8683 def start_call (self , info ):
8784 """Create middleware instance for a new call."""
88- return BearerAuthMiddleware (self . _token )
85+ return BearerAuthMiddleware (self )
8986
9087
9188class Client :
@@ -139,9 +136,9 @@ def __init__(
139136 self ._auto_commit = auto_commit
140137 self ._transaction = None
141138
142- token = self ._handshake ()
143- auth_middleware = BearerAuthMiddlewareFactory (token )
139+ auth_middleware = BearerAuthMiddlewareFactory ()
144140 self ._client = flight .FlightClient (location , middleware = [auth_middleware ])
141+ self ._client .authenticate_basic_token (self ._username , self ._password )
145142
146143 options = {}
147144 if catalog :
@@ -153,17 +150,6 @@ def __init__(
153150 if options :
154151 self ._set_options (options )
155152
156- def _handshake (self ) -> bytes :
157- """
158- Perform authentication handshake with the server.
159-
160- Returns:
161- bytes: The Bearer token returned by the server
162- """
163- with flight .FlightClient (self ._location ) as client :
164- header = client .authenticate_basic_token (self ._username , self ._password )
165- return header [1 ]
166-
167153 def _set_options (self , options : Mapping [str , sql_pb2 .SessionOptionValue ]):
168154 cmd = sql_pb2 .SetSessionOptionsRequest (session_options = options )
169155 action = flight .Action ("SetSessionOptions" , _pack_command (cmd ))
@@ -688,8 +674,7 @@ def _get_parameter_as_pyarrow(
688674 # Create record batch with positional parameters
689675 if len (parameters ) != len (self ._parameter_schema ):
690676 raise ValueError (
691- f"Expected { len (self ._parameter_schema )} parameters, "
692- f"but got { len (parameters )} "
677+ f"Expected { len (self ._parameter_schema )} parameters, but got { len (parameters )} "
693678 )
694679 param_dict = {
695680 field .name : [value ] for field , value in zip (self ._parameter_schema , parameters )
0 commit comments