Skip to content

Commit 226e063

Browse files
fix: use a single client for handshake and queries (#6)
Using two different client is messing with the "persistent" nature of the connection we want to achieve. Using a single client for both handshake _and_ subsequent queries solves this issue.
1 parent ae9b1d0 commit 226e063

1 file changed

Lines changed: 18 additions & 33 deletions

File tree

src/altertable_flightsql/client.py

Lines changed: 18 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -57,35 +57,32 @@ class IngestIncrementalOptions:
5757
class BearerAuthMiddleware(flight.ClientMiddleware):
5858
"""Client middleware that adds Bearer token authentication to all requests."""
5959

60-
def __init__(self, token: bytes):
61-
"""
62-
Initialize the middleware with a Bearer token.
63-
64-
Args:
65-
token: The Bearer token
66-
"""
67-
self._token = token
60+
def __init__(self, factory: "BearerAuthMiddlewareFactory"):
61+
self._factory = factory
6862

6963
def sending_headers(self):
7064
"""A callback before headers are sent."""
71-
return {b"authorization": self._token}
65+
headers = {}
66+
67+
if self._factory.token:
68+
headers[b"authorization"] = self._factory.token
69+
70+
return headers
71+
72+
def received_headers(self, headers):
73+
if token := headers.get("authorization"):
74+
self._factory.token = token
7275

7376

7477
class BearerAuthMiddlewareFactory(flight.ClientMiddlewareFactory):
7578
"""Factory for creating Bearer authentication middleware."""
7679

77-
def __init__(self, token: bytes):
78-
"""
79-
Initialize the factory with credentials.
80-
81-
Args:
82-
token: The Bearer token
83-
"""
84-
self._token = token
80+
def __init__(self):
81+
self.token = None
8582

8683
def start_call(self, info):
8784
"""Create middleware instance for a new call."""
88-
return BearerAuthMiddleware(self._token)
85+
return BearerAuthMiddleware(self)
8986

9087

9188
class Client:
@@ -139,9 +136,9 @@ def __init__(
139136
self._auto_commit = auto_commit
140137
self._transaction = None
141138

142-
token = self._handshake()
143-
auth_middleware = BearerAuthMiddlewareFactory(token)
139+
auth_middleware = BearerAuthMiddlewareFactory()
144140
self._client = flight.FlightClient(location, middleware=[auth_middleware])
141+
self._client.authenticate_basic_token(self._username, self._password)
145142

146143
options = {}
147144
if catalog:
@@ -153,17 +150,6 @@ def __init__(
153150
if options:
154151
self._set_options(options)
155152

156-
def _handshake(self) -> bytes:
157-
"""
158-
Perform authentication handshake with the server.
159-
160-
Returns:
161-
bytes: The Bearer token returned by the server
162-
"""
163-
with flight.FlightClient(self._location) as client:
164-
header = client.authenticate_basic_token(self._username, self._password)
165-
return header[1]
166-
167153
def _set_options(self, options: Mapping[str, sql_pb2.SessionOptionValue]):
168154
cmd = sql_pb2.SetSessionOptionsRequest(session_options=options)
169155
action = flight.Action("SetSessionOptions", _pack_command(cmd))
@@ -688,8 +674,7 @@ def _get_parameter_as_pyarrow(
688674
# Create record batch with positional parameters
689675
if len(parameters) != len(self._parameter_schema):
690676
raise ValueError(
691-
f"Expected {len(self._parameter_schema)} parameters, "
692-
f"but got {len(parameters)}"
677+
f"Expected {len(self._parameter_schema)} parameters, but got {len(parameters)}"
693678
)
694679
param_dict = {
695680
field.name: [value] for field, value in zip(self._parameter_schema, parameters)

0 commit comments

Comments
 (0)