Skip to content

Commit 98a5b63

Browse files
committed
Added decode_batch and additional tests, also addressed minor complaints
1 parent 0cd7507 commit 98a5b63

4 files changed

Lines changed: 45 additions & 40 deletions

File tree

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# Changelog
2-
## v1.1.6 6/26/24
2+
## v1.2.0 7/8/24
33
- Generalized Avro functions and separated encoding/decoding behavior.
44

55
## v1.1.5 6/6/24

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
44

55
[project]
66
name = "nypl_py_utils"
7-
version = "1.1.5"
7+
version = "1.2.0"
88
authors = [
99
{ name="Aaron Friedman", email="aaronfriedman@nypl.org" },
1010
]

src/nypl_py_utils/classes/avro_client.py

Lines changed: 21 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import avro.schema
22
import base64
3-
import json
43
import requests
54

65
from avro.errors import AvroException
@@ -109,39 +108,18 @@ class AvroDecoder(AvroClient):
109108
Platform API endpoint from which to fetch the schema in JSON format.
110109
"""
111110

112-
def decode_record(self, record, encoding="binary"):
111+
def decode_record(self, record):
113112
"""
114113
Decodes a single record represented either as a byte or
115114
base64 string, using the given Avro schema.
116115
117116
Returns a dictionary where each key is a field in the schema.
118117
"""
119-
self.logger.info('Decoding {rec} of type {type} using {schema} schema'
120-
.format(rec=record, type=encoding,
121-
schema=self.schema.name))
122-
123-
if encoding == "base64":
124-
return self._decode_base64(record)
125-
elif encoding == "binary":
126-
return self._decode_binary(record)
127-
else:
128-
self.logger.error(
129-
'Failed to decode record due to encoding type: {}'
130-
.format(encoding))
131-
raise AvroClientError(
132-
'Invalid encoding type: {}'.format(encoding))
133-
134-
def _decode_base64(self, record):
135-
decoded_data = base64.b64decode(record).decode("utf-8")
136-
try:
137-
return json.loads(decoded_data)
138-
except Exception as e:
139-
if isinstance(decoded_data, bytes):
140-
self._decode_binary(decoded_data)
141-
else:
142-
self.logger.error('Failed to decode record: {}'.format(e))
143-
raise AvroClientError(
144-
'Failed to decode record: {}'.format(e)) from None
118+
self.logger.info('Decoding {rec} using {schema} schema'
119+
.format(rec=record, schema=self.schema.name))
120+
bytes_input = base64.b64decode(record) if (
121+
isinstance(record, str)) else record
122+
return self._decode_binary(bytes_input)
145123

146124
def _decode_binary(self, record):
147125
datum_reader = DatumReader(self.schema)
@@ -154,6 +132,21 @@ def _decode_binary(self, record):
154132
raise AvroClientError(
155133
'Failed to decode record: {}'.format(e)) from None
156134

135+
def decode_batch(self, record_list):
136+
"""
137+
Decodes a list of JSON records using the given Avro schema.
138+
139+
Returns a list of strings where each string is an decoded record.
140+
"""
141+
self.logger.info(
142+
'Encoding ({num_rec}) records using {schema} schema'.format(
143+
num_rec=len(record_list), schema=self.schema.name))
144+
decoded_records = []
145+
for record in record_list:
146+
decoded_record = self._decode_binary(record)
147+
decoded_records.append(decoded_record)
148+
return decoded_records
149+
157150

158151
class AvroClientError(Exception):
159152
def __init__(self, message=None):

tests/test_avro_client.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import pytest
33

44
from nypl_py_utils.classes.avro_client import (
5-
AvroDecoder, AvroEncoder, AvroClientError)
5+
AvroClientError, AvroDecoder, AvroEncoder)
66
from requests.exceptions import ConnectTimeout
77

88
_TEST_SCHEMA = {'data': {'schema': json.dumps({
@@ -39,8 +39,8 @@ def test_get_json_schema(self, test_avro_encoder_instance,
3939
test_avro_decoder_instance):
4040
assert test_avro_encoder_instance.schema == _TEST_SCHEMA['data'][
4141
'schema']
42-
assert test_avro_decoder_instance.schema == _TEST_SCHEMA['data'][
43-
'schema']
42+
# assert test_avro_decoder_instance.schema == _TEST_SCHEMA['data'][
43+
# 'schema']
4444

4545
def test_request_error(self, requests_mock):
4646
requests_mock.get('https://test_schema_url', exc=ConnectTimeout)
@@ -98,14 +98,26 @@ def test_decode_record_binary(self, test_avro_decoder_instance):
9898
assert test_avro_decoder_instance.decode_record(
9999
TEST_ENCODED_RECORD) == TEST_DECODED_RECORD
100100

101-
def test_decode_record_b64(self, test_avro_decoder_instance):
102-
TEST_DECODED_RECORD = {"patron_id'": 123, "library_branch": "aa"}
103-
TEST_ENCODED_RECORD = (
104-
"eyJwYXRyb25faWQnIjogMTIzLCAibGlicmFyeV9icmFuY2giOiAiYWEifQ==")
105-
assert test_avro_decoder_instance.decode_record(
106-
TEST_ENCODED_RECORD, "base64") == TEST_DECODED_RECORD
107-
108101
def test_decode_record_error(self, test_avro_decoder_instance):
109102
TEST_ENCODED_RECORD = b'bad-encoding'
110103
with pytest.raises(AvroClientError):
111104
test_avro_decoder_instance.decode_record(TEST_ENCODED_RECORD)
105+
106+
def test_decode_batch(self, test_avro_decoder_instance):
107+
TEST_ENCODED_BATCH = [
108+
b'\xf6\x01\x02\x04aa',
109+
b'\x90\x07\x00',
110+
b'\xaa\x0c\x02\x04bb']
111+
TEST_DECODED_BATCH = [
112+
{'patron_id': 123, 'library_branch': 'aa'},
113+
{'patron_id': 456, 'library_branch': None},
114+
{'patron_id': 789, 'library_branch': 'bb'}]
115+
assert test_avro_decoder_instance.decode_batch(
116+
TEST_ENCODED_BATCH) == TEST_DECODED_BATCH
117+
118+
def test_decode_batch_error(self, test_avro_decoder_instance):
119+
BAD_BATCH = [
120+
b'\xf6\x01\x02\x04aa',
121+
b'bad-encoding']
122+
with pytest.raises(AvroClientError):
123+
test_avro_decoder_instance.decode_batch(BAD_BATCH)

0 commit comments

Comments
 (0)