This repository was archived by the owner on Nov 12, 2025. It is now read-only.

Commit 4f56029

tswast and tseaver authored
feat: read_session optional to ReadRowsStream.rows() (#228)
* feat: `read_session` optional to `ReadRowsStream.rows()`

  The schema from the first `ReadRowsResponse` message can be used to decode
  messages, instead.

  Note: `to_arrow()` and `to_dataframe()` do not work on an empty stream
  unless a `read_session` has been passed in, as the schema is not available.
  This should not affect `google-cloud-bigquery` and `pandas-gbq`, as those
  packages use the lower-level message->dataframe/arrow methods.

* revert change to comment

* use else for empty arrow streams in try-except block

  Co-authored-by: Tres Seaver <tseaver@palladion.com>

* update docstring to reflect that `ReadSession` and `ReadRowsResponse` can
  be used interchangeably

* update arrow deserializer, too

Co-authored-by: Tres Seaver <tseaver@palladion.com>
1 parent a8a8c78 commit 4f56029
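For context, a minimal usage sketch of the new calling convention (not part of this commit's diff). The project and table paths and the variable names are illustrative placeholders; it assumes the standard v1 API surface and that fastavro is installed for Avro parsing.

from google.cloud import bigquery_storage_v1
from google.cloud.bigquery_storage_v1 import types

client = bigquery_storage_v1.BigQueryReadClient()

# Placeholder resource names; substitute your own project and table.
parent = "projects/my-project"
table = "projects/my-project/datasets/my_dataset/tables/my_table"

session = client.create_read_session(
    parent=parent,
    read_session=types.ReadSession(table=table, data_format=types.DataFormat.AVRO),
    max_stream_count=1,
)
reader = client.read_rows(session.streams[0].name)

# Before this commit, the read session was required: reader.rows(session).
# Now the schema is taken from the first ReadRowsResponse, so it may be omitted.
for row in reader.rows():
    print(row)  # each row is a Mapping of column name to value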

Showing 4 changed files with 196 additions and 140 deletions.


google/cloud/bigquery_storage_v1/reader.py

Lines changed: 91 additions & 34 deletions
@@ -156,7 +156,7 @@ def _reconnect(self):
             read_stream=self._name, offset=self._offset, **self._read_rows_kwargs
         )

-    def rows(self, read_session):
+    def rows(self, read_session=None):
         """Iterate over all rows in the stream.

         This method requires the fastavro library in order to parse row
@@ -169,19 +169,21 @@ def rows(self, read_session):

         Args:
             read_session ( \
-                ~google.cloud.bigquery_storage_v1.types.ReadSession \
+                Optional[~google.cloud.bigquery_storage_v1.types.ReadSession] \
             ):
-                The read session associated with this read rows stream. This
-                contains the schema, which is required to parse the data
-                messages.
+                DEPRECATED.
+
+                This argument was used to specify the schema of the rows in the
+                stream, but now the first message in a read stream contains
+                this information.

         Returns:
             Iterable[Mapping]:
                 A sequence of rows, represented as dictionaries.
         """
-        return ReadRowsIterable(self, read_session)
+        return ReadRowsIterable(self, read_session=read_session)

-    def to_arrow(self, read_session):
+    def to_arrow(self, read_session=None):
         """Create a :class:`pyarrow.Table` of all rows in the stream.

         This method requires the pyarrow library and a stream using the Arrow
@@ -191,17 +193,19 @@ def to_arrow(self, read_session):
             read_session ( \
                 ~google.cloud.bigquery_storage_v1.types.ReadSession \
             ):
-                The read session associated with this read rows stream. This
-                contains the schema, which is required to parse the data
-                messages.
+                DEPRECATED.
+
+                This argument was used to specify the schema of the rows in the
+                stream, but now the first message in a read stream contains
+                this information.

         Returns:
             pyarrow.Table:
                 A table of all rows in the stream.
         """
-        return self.rows(read_session).to_arrow()
+        return self.rows(read_session=read_session).to_arrow()

-    def to_dataframe(self, read_session, dtypes=None):
+    def to_dataframe(self, read_session=None, dtypes=None):
         """Create a :class:`pandas.DataFrame` of all rows in the stream.

         This method requires the pandas libary to create a data frame and the
@@ -215,9 +219,11 @@ def to_dataframe(self, read_session, dtypes=None):
             read_session ( \
                 ~google.cloud.bigquery_storage_v1.types.ReadSession \
             ):
-                The read session associated with this read rows stream. This
-                contains the schema, which is required to parse the data
-                messages.
+                DEPRECATED.
+
+                This argument was used to specify the schema of the rows in the
+                stream, but now the first message in a read stream contains
+                this information.
             dtypes ( \
                 Map[str, Union[str, pandas.Series.dtype]] \
             ):
@@ -233,7 +239,7 @@ def to_dataframe(self, read_session, dtypes=None):
         if pandas is None:
             raise ImportError(_PANDAS_REQUIRED)

-        return self.rows(read_session).to_dataframe(dtypes=dtypes)
+        return self.rows(read_session=read_session).to_dataframe(dtypes=dtypes)


 class ReadRowsIterable(object):
@@ -242,18 +248,25 @@ class ReadRowsIterable(object):
     Args:
         reader (google.cloud.bigquery_storage_v1.reader.ReadRowsStream):
             A read rows stream.
-        read_session (google.cloud.bigquery_storage_v1.types.ReadSession):
-            A read session. This is required because it contains the schema
-            used in the stream messages.
+        read_session ( \
+            Optional[~google.cloud.bigquery_storage_v1.types.ReadSession] \
+        ):
+            DEPRECATED.
+
+            This argument was used to specify the schema of the rows in the
+            stream, but now the first message in a read stream contains
+            this information.
     """

     # This class is modelled after the google.cloud.bigquery.table.RowIterator
     # and aims to be API compatible where possible.

-    def __init__(self, reader, read_session):
+    def __init__(self, reader, read_session=None):
         self._reader = reader
-        self._read_session = read_session
-        self._stream_parser = _StreamParser.from_read_session(self._read_session)
+        if read_session is not None:
+            self._stream_parser = _StreamParser.from_read_session(read_session)
+        else:
+            self._stream_parser = None

     @property
     def pages(self):
@@ -266,6 +279,10 @@ def pages(self):
         # Each page is an iterator of rows. But also has num_items, remaining,
         # and to_dataframe.
         for message in self._reader:
+            # Only the first message contains the schema, which is needed to
+            # decode the messages.
+            if not self._stream_parser:
+                self._stream_parser = _StreamParser.from_read_rows_response(message)
             yield ReadRowsPage(self._stream_parser, message)

     def __iter__(self):
@@ -328,10 +345,11 @@ def to_dataframe(self, dtypes=None):
         # pandas dataframe is about 2x faster. This is because pandas.concat is
         # rarely no-copy, whereas pyarrow.Table.from_batches + to_pandas is
         # usually no-copy.
-        schema_type = self._read_session._pb.WhichOneof("schema")
-
-        if schema_type == "arrow_schema":
+        try:
             record_batch = self.to_arrow()
+        except NotImplementedError:
+            pass
+        else:
             df = record_batch.to_pandas()
             for column in dtypes:
                 df[column] = pandas.Series(df[column], dtype=dtypes[column])
@@ -491,6 +509,12 @@ def to_dataframe(self, message, dtypes=None):
     def to_rows(self, message):
         raise NotImplementedError("Not implemented.")

+    def _parse_avro_schema(self):
+        raise NotImplementedError("Not implemented.")
+
+    def _parse_arrow_schema(self):
+        raise NotImplementedError("Not implemented.")
+
     @staticmethod
     def from_read_session(read_session):
         schema_type = read_session._pb.WhichOneof("schema")
@@ -503,22 +527,38 @@ def from_read_session(read_session):
                 "Unsupported schema type in read_session: {0}".format(schema_type)
             )

+    @staticmethod
+    def from_read_rows_response(message):
+        schema_type = message._pb.WhichOneof("schema")
+        if schema_type == "avro_schema":
+            return _AvroStreamParser(message)
+        elif schema_type == "arrow_schema":
+            return _ArrowStreamParser(message)
+        else:
+            raise TypeError(
+                "Unsupported schema type in message: {0}".format(schema_type)
+            )
+

 class _AvroStreamParser(_StreamParser):
     """Helper to parse Avro messages into useful representations."""

-    def __init__(self, read_session):
+    def __init__(self, message):
         """Construct an _AvroStreamParser.

         Args:
-            read_session (google.cloud.bigquery_storage_v1.types.ReadSession):
-                A read session. This is required because it contains the schema
-                used in the stream messages.
+            message (Union[
+                google.cloud.bigquery_storage_v1.types.ReadSession, \
+                google.cloud.bigquery_storage_v1.types.ReadRowsResponse, \
+            ]):
+                Either the first message of data from a read rows stream or a
+                read session. Both types contain a oneof "schema" field, which
+                can be used to determine how to deserialize rows.
         """
         if fastavro is None:
             raise ImportError(_FASTAVRO_REQUIRED)

-        self._read_session = read_session
+        self._first_message = message
         self._avro_schema_json = None
         self._fastavro_schema = None
         self._column_names = None
@@ -548,6 +588,10 @@ def to_dataframe(self, message, dtypes=None):
             strings in the fastavro library.

         Args:
+            message ( \
+                ~google.cloud.bigquery_storage_v1.types.ReadRowsResponse \
+            ):
+                A message containing Avro bytes to parse into a pandas DataFrame.
             dtypes ( \
                 Map[str, Union[str, pandas.Series.dtype]] \
             ):
@@ -578,10 +622,11 @@ def _parse_avro_schema(self):
         if self._avro_schema_json:
             return

-        self._avro_schema_json = json.loads(self._read_session.avro_schema.schema)
+        self._avro_schema_json = json.loads(self._first_message.avro_schema.schema)
         self._column_names = tuple(
             (field["name"] for field in self._avro_schema_json["fields"])
         )
+        self._first_message = None

     def _parse_fastavro(self):
         """Convert parsed Avro schema to fastavro format."""
@@ -615,11 +660,22 @@ def to_rows(self, message):


 class _ArrowStreamParser(_StreamParser):
-    def __init__(self, read_session):
+    def __init__(self, message):
+        """Construct an _ArrowStreamParser.
+
+        Args:
+            message (Union[
+                google.cloud.bigquery_storage_v1.types.ReadSession, \
+                google.cloud.bigquery_storage_v1.types.ReadRowsResponse, \
+            ]):
+                Either the first message of data from a read rows stream or a
+                read session. Both types contain a oneof "schema" field, which
+                can be used to determine how to deserialize rows.
+        """
         if pyarrow is None:
             raise ImportError(_PYARROW_REQUIRED)

-        self._read_session = read_session
+        self._first_message = message
         self._schema = None

     def to_arrow(self, message):
@@ -659,6 +715,7 @@ def _parse_arrow_schema(self):
             return

         self._schema = pyarrow.ipc.read_schema(
-            pyarrow.py_buffer(self._read_session.arrow_schema.serialized_schema)
+            pyarrow.py_buffer(self._first_message.arrow_schema.serialized_schema)
         )
         self._column_names = [field.name for field in self._schema]
+        self._first_message = None
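As the commit message warns, the schema now arrives only with the first ReadRowsResponse, so an empty stream has nothing to build a schema from. A short sketch of the resulting trade-off, reusing the illustrative reader and session names from the example above:

# Works even if the stream yields no messages: the schema comes from the
# explicitly passed (now deprecated) read_session argument.
arrow_table = reader.to_arrow(read_session=session)

# Takes the schema from the first ReadRowsResponse; per the commit message,
# this does not work on an empty stream because no schema is ever received.
arrow_table = reader.to_arrow()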

tests/system/conftest.py

Lines changed: 40 additions & 37 deletions
@@ -18,13 +18,41 @@
 import os
 import uuid

+import google.auth
+from google.cloud import bigquery
 import pytest
+import test_utils.prefixer

 from . import helpers


+prefixer = test_utils.prefixer.Prefixer("python-bigquery-storage", "tests/system")
+
+
 _TABLE_FORMAT = "projects/{}/datasets/{}/tables/{}"
 _ASSETS_DIR = os.path.join(os.path.abspath(os.path.dirname(__file__)), "assets")
+_ALL_TYPES_SCHEMA = [
+    bigquery.SchemaField("string_field", "STRING"),
+    bigquery.SchemaField("bytes_field", "BYTES"),
+    bigquery.SchemaField("int64_field", "INT64"),
+    bigquery.SchemaField("float64_field", "FLOAT64"),
+    bigquery.SchemaField("numeric_field", "NUMERIC"),
+    bigquery.SchemaField("bool_field", "BOOL"),
+    bigquery.SchemaField("geography_field", "GEOGRAPHY"),
+    bigquery.SchemaField(
+        "person_struct_field",
+        "STRUCT",
+        fields=(
+            bigquery.SchemaField("name", "STRING"),
+            bigquery.SchemaField("age", "INT64"),
+        ),
+    ),
+    bigquery.SchemaField("timestamp_field", "TIMESTAMP"),
+    bigquery.SchemaField("date_field", "DATE"),
+    bigquery.SchemaField("time_field", "TIME"),
+    bigquery.SchemaField("datetime_field", "DATETIME"),
+    bigquery.SchemaField("string_array_field", "STRING", mode="REPEATED"),
+]


 @pytest.fixture(scope="session")
@@ -38,18 +66,9 @@ def use_mtls():


 @pytest.fixture(scope="session")
-def credentials(use_mtls):
-    import google.auth
-    from google.oauth2 import service_account
-
-    if use_mtls:
-        # mTLS test uses user credentials instead of service account credentials
-        creds, _ = google.auth.default()
-        return creds
-
-    # NOTE: the test config in noxfile checks that the env variable is indeed set
-    filename = os.environ["GOOGLE_APPLICATION_CREDENTIALS"]
-    return service_account.Credentials.from_service_account_file(filename)
+def credentials():
+    creds, _ = google.auth.default()
+    return creds


 @pytest.fixture()
@@ -77,8 +96,7 @@ def local_shakespeare_table_reference(project_id, use_mtls):
 def dataset(project_id, bq_client):
     from google.cloud import bigquery

-    unique_suffix = str(uuid.uuid4()).replace("-", "_")
-    dataset_name = "bq_storage_system_tests_" + unique_suffix
+    dataset_name = prefixer.create_prefix()

     dataset_id = "{}.{}".format(project_id, dataset_name)
     dataset = bigquery.Dataset(dataset_id)
@@ -120,35 +138,20 @@ def bq_client(credentials, use_mtls):
     return bigquery.Client(credentials=credentials)


+@pytest.fixture(scope="session", autouse=True)
+def cleanup_datasets(bq_client: bigquery.Client):
+    for dataset in bq_client.list_datasets():
+        if prefixer.should_cleanup(dataset.dataset_id):
+            bq_client.delete_dataset(dataset, delete_contents=True, not_found_ok=True)
+
+
 @pytest.fixture
 def all_types_table_ref(project_id, dataset, bq_client):
     from google.cloud import bigquery

-    schema = [
-        bigquery.SchemaField("string_field", "STRING"),
-        bigquery.SchemaField("bytes_field", "BYTES"),
-        bigquery.SchemaField("int64_field", "INT64"),
-        bigquery.SchemaField("float64_field", "FLOAT64"),
-        bigquery.SchemaField("numeric_field", "NUMERIC"),
-        bigquery.SchemaField("bool_field", "BOOL"),
-        bigquery.SchemaField("geography_field", "GEOGRAPHY"),
-        bigquery.SchemaField(
-            "person_struct_field",
-            "STRUCT",
-            fields=(
-                bigquery.SchemaField("name", "STRING"),
-                bigquery.SchemaField("age", "INT64"),
-            ),
-        ),
-        bigquery.SchemaField("timestamp_field", "TIMESTAMP"),
-        bigquery.SchemaField("date_field", "DATE"),
-        bigquery.SchemaField("time_field", "TIME"),
-        bigquery.SchemaField("datetime_field", "DATETIME"),
-        bigquery.SchemaField("string_array_field", "STRING", mode="REPEATED"),
-    ]
     bq_table = bigquery.table.Table(
         table_ref="{}.{}.complex_records".format(project_id, dataset.dataset_id),
-        schema=schema,
+        schema=_ALL_TYPES_SCHEMA,
     )

     created_table = bq_client.create_table(bq_table)
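The system-test fixtures above swap ad-hoc uuid suffixes for test_utils.prefixer.Prefixer from google-cloud-testutils. A hedged sketch of the contract the fixtures rely on; the exact name format and staleness cutoff are the library's concern, not this commit's:

import test_utils.prefixer

prefixer = test_utils.prefixer.Prefixer("python-bigquery-storage", "tests/system")

# Mint a unique, timestamped dataset name for this test run.
dataset_name = prefixer.create_prefix()

# should_cleanup() reports whether a name was minted by these tests and is
# stale, which is what lets the autouse cleanup_datasets fixture sweep up
# datasets left behind by earlier, possibly crashed, runs.
is_leftover = prefixer.should_cleanup(dataset_name)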
