Skip to content

Commit d7cd66c

Browse files
authored
Add encoder and decoder for DoubleXYData (#358)
* Add message_type to datatypeinfo and ParameterMetadata. Add DoubleXYData type enum. No parsers for it yet. * ni-python-styleguide fix * Modify generate stubs to work when test asset .proto files reference ni types in stubs. * Add serializer and deserializer for DoubleXYData with tests * Fix lint and styleguide errors * Get tests passing * Fix mypy errors * Put underscore before some private methods. * Move helpers to _message.py * Use ValueError * Add unsupported encoder for Array messages * Update to _EncodeVarint and _DecodeVarint copied directly * Remove unnecessary local * Add type hints for encoder and decoder * Get tests passing * Remove unnecessary method * Brad's review feedback * Couple more minor feedback items * Review comments * Review feedback, type hints etc * One more type hint * Add return type hint to _encode_message
1 parent 57b36c9 commit d7cd66c

10 files changed

Lines changed: 291 additions & 72 deletions

File tree

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
import struct
2+
from typing import Any, Callable, Dict, Optional, Tuple, Type, TypeVar
3+
4+
from google.protobuf.internal import encoder, wire_format
5+
from google.protobuf.message import Message
6+
7+
from ni_measurementlink_service._internal.parameter._serializer_types import (
8+
Decoder,
9+
Key,
10+
NewDefault,
11+
WriteFunction,
12+
)
13+
14+
15+
def _message_encoder_constructor(
    field_index: int, is_repeated: bool, is_packed: bool
) -> Callable[[WriteFunction, Message, bool], int]:
    """Mimics google.protobuf.internal.MessageEncoder.

    This function was forked in order to call SerializeToString instead of _InternalSerialize.

    _InternalSerialize is only defined for the pure-Python protobuf implementation. Our child
    messages (like DoubleXYData) are defined in .proto files, so they use whichever protobuf
    implementation that google.protobuf.internal.api_implementation chooses (usually upb).

    Args:
        field_index: Protobuf field number used to build the length-delimited field tag.
        is_repeated: Not used by this constructor; present so the signature matches the
            other encoder constructors.
        is_packed: Not used by this constructor; present for the same reason.

    Returns:
        A function that writes the tag, the varint-encoded payload length, and the
        serialized message through ``write``, returning the result of the final
        ``write`` call.
    """
    tag = encoder.TagBytes(field_index, wire_format.WIRETYPE_LENGTH_DELIMITED)
    encode_varint = _varint_encoder()

    def _encode_message(write: WriteFunction, value: Message, deterministic: bool) -> int:
        write(tag)
        # Forward the caller's determinism request, as the upstream MessageEncoder does
        # through _InternalSerialize. The local is not called "bytes" so the builtin
        # stays usable inside this function.
        serialized = value.SerializeToString(deterministic=deterministic)
        encode_varint(write, len(serialized), deterministic)
        return write(serialized)

    return _encode_message
36+
37+
38+
def _varint_encoder() -> Callable[[WriteFunction, int, Optional[bool]], int]:
    """Return an encoder for a basic varint value (does not include tag).

    From google.protobuf.internal.encoder.py _VarintEncoder

    Returns:
        A function ``encode_varint(write, value, unused_deterministic=None)`` that emits
        ``value`` as a base-128 varint, one byte per ``write`` call, and returns the
        result of the final ``write`` call.

    Raises:
        ValueError: Raised by the returned function when ``value`` is negative.
    """
    local_int2byte = struct.Struct(">B").pack

    def encode_varint(
        write: WriteFunction, value: int, unused_deterministic: Optional[bool] = None
    ) -> int:
        if value < 0:
            # A negative Python int never becomes zero under ">>= 7" (it settles at -1),
            # so the loop below would write bytes forever.
            raise ValueError(f"Cannot encode negative value as an unsigned varint: {value}")
        bits = value & 0x7F
        value >>= 7
        while value:
            # More groups follow: set the continuation bit on this 7-bit group.
            write(local_int2byte(0x80 | bits))
            bits = value & 0x7F
            value >>= 7
        return write(local_int2byte(bits))

    return encode_varint
57+
58+
59+
def _message_decoder_constructor(
    field_index: int, is_repeated: bool, is_packed: bool, key: Key, new_default: NewDefault
) -> Decoder:
    """Mimics google.protobuf.internal.MessageDecoder.

    This function was forked in order to call ParseFromString instead of _InternalParse.

    _InternalParse is only defined for the pure-Python protobuf implementation. Our child messages
    (like DoubleXYData) are defined in .proto files, so they use whichever protobuf implementation
    that google.protobuf.internal.api_implementation chooses (usually upb).

    Args:
        field_index: Not used by this constructor; present so the signature matches the
            other decoder constructors.
        is_repeated: Not used by this constructor.
        is_packed: Not used by this constructor.
        key: Key under which the decoded message is stored in the field dictionary.
        new_default: Factory called with the parent message to create the child message
            when the field dictionary has no entry for ``key``.

    Returns:
        A decoder that reads a varint length followed by that many payload bytes, parses
        them into the child message, and returns the position just past the payload.
    """
    # The length decoder does not depend on per-call state, so build it once here
    # instead of on every decode call.
    decode_varint = _varint_decoder(mask=(1 << 64) - 1, result_type=int)

    def _decode_message(
        buffer: memoryview, pos: int, end: int, message: Message, field_dict: Dict[Key, Any]
    ) -> int:
        value = field_dict.get(key)
        if value is None:
            value = field_dict.setdefault(key, new_default(message))
        # Read length.
        (size, pos) = decode_varint(buffer, pos)
        new_pos = pos + size
        if new_pos > end:
            raise ValueError("Error decoding a message. Message is truncated.")
        parsed_bytes = value.ParseFromString(buffer[pos:new_pos])
        if parsed_bytes != size:
            raise ValueError("Parsed incorrect number of bytes.")
        return new_pos

    return _decode_message
89+
90+
91+
T = TypeVar("T", bound="int")


def _varint_decoder(mask: int, result_type: Type[T]) -> Callable[[memoryview, int], Tuple[T, int]]:
    """Return a decoder for a basic varint value (does not include tag).

    Decoded values will be bitwise-anded with the given mask before being
    returned, e.g. to limit them to 32 bits. The returned decoder does not
    take the usual "end" parameter -- the caller is expected to do bounds checking
    after the fact (often the caller can defer such checking until later). The
    decoder returns a (value, new_pos) pair.

    From google.protobuf.internal.decoder.py _VarintDecoder

    Args:
        mask: Bit mask applied to the decoded integer before conversion.
        result_type: Callable type used to convert the masked integer.

    Returns:
        A function ``decode_varint(buffer, pos)`` returning ``(value, new_pos)``.

    Raises:
        ValueError: Raised by the returned function when the accumulated shift
            reaches 64 bits while continuation bits are still set.
    """

    def decode_varint(buffer: memoryview, pos: int) -> Tuple[T, int]:
        result = 0
        shift = 0
        while True:
            b = buffer[pos]
            # The low 7 bits carry payload; groups arrive least-significant first.
            result |= (b & 0x7F) << shift
            pos += 1
            if not (b & 0x80):
                # Continuation bit clear: this was the final byte.
                return (result_type(result & mask), pos)
            shift += 7
            if shift >= 64:
                raise ValueError(f"Too many bytes when decoding varint: {shift}")

    return decode_varint
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from __future__ import annotations
2+
3+
import sys
4+
import typing
5+
from typing import Any, Callable, Dict
6+
7+
from google.protobuf.descriptor import FieldDescriptor
8+
from google.protobuf.message import Message
9+
10+
if typing.TYPE_CHECKING:
11+
if sys.version_info >= (3, 10):
12+
from typing import TypeAlias
13+
else:
14+
from typing_extensions import TypeAlias
15+
16+
17+
Key: TypeAlias = FieldDescriptor
18+
WriteFunction: TypeAlias = Callable[[bytes], int]
19+
Encoder: TypeAlias = Callable[[WriteFunction, bytes, bool], int]
20+
PartialEncoderConstructor: TypeAlias = Callable[[int], Encoder]
21+
EncoderConstructor: TypeAlias = Callable[[int, bool, bool], Encoder]
22+
23+
Decoder: TypeAlias = Callable[[memoryview, int, int, Message, Dict[Key, Any]], int]
24+
PartialDecoderConstructor: TypeAlias = Callable[[int, Key], Decoder]
25+
NewDefault: TypeAlias = Callable[[Message], Message]
26+
DecoderConstructor: TypeAlias = Callable[[int, bool, bool, Key, NewDefault], Decoder]

ni_measurementlink_service/_internal/parameter/serialization_strategy.py

Lines changed: 57 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,23 @@
11
"""Serialization Strategy."""
22
from __future__ import annotations
33

4-
import sys
5-
import typing
6-
from typing import Any, Callable, Dict, Optional, cast
4+
from typing import Any, Optional, cast
75

86
from google.protobuf import type_pb2
9-
from google.protobuf.descriptor import FieldDescriptor
107
from google.protobuf.internal import decoder, encoder
118
from google.protobuf.message import Message
129

13-
if typing.TYPE_CHECKING:
14-
if sys.version_info >= (3, 10):
15-
from typing import TypeAlias
16-
else:
17-
from typing_extensions import TypeAlias
18-
19-
Key: TypeAlias = FieldDescriptor
20-
WriteFunction: TypeAlias = Callable[[bytes], int]
21-
Encoder: TypeAlias = Callable[[WriteFunction, bytes, bool], int]
22-
PartialEncoderConstructor: TypeAlias = Callable[[int], Encoder]
23-
EncoderConstructor: TypeAlias = Callable[[int, bool, bool], Encoder]
24-
25-
Decoder: TypeAlias = Callable[[memoryview, int, int, Message, Dict[Key, Any]], int]
26-
PartialDecoderConstructor: TypeAlias = Callable[[int, Key], Decoder]
27-
NewDefault: TypeAlias = Callable[[Message], Message]
28-
DecoderConstructor: TypeAlias = Callable[[int, bool, bool, Key, NewDefault], Decoder]
10+
from ni_measurementlink_service._internal.parameter import _message
11+
from ni_measurementlink_service._internal.parameter._serializer_types import (
12+
Decoder,
13+
DecoderConstructor,
14+
Encoder,
15+
EncoderConstructor,
16+
Key,
17+
PartialDecoderConstructor,
18+
PartialEncoderConstructor,
19+
)
20+
from ni_measurementlink_service._internal.stubs.ni.protobuf.types import xydata_pb2
2921

3022

3123
def _scalar_encoder(encoder: EncoderConstructor) -> PartialEncoderConstructor:
@@ -63,6 +55,10 @@ def vector_encoder(field_index: int) -> Encoder:
6355
return vector_encoder
6456

6557

58+
def _unsupported_encoder(field_index: int, is_repeated: bool, is_packed: bool) -> Encoder:
    """Encoder constructor for field types that cannot be encoded.

    Always raises NotImplementedError naming ``field_index``; ``is_repeated`` and
    ``is_packed`` are not used.
    """
    message = f"Unsupported data type for field {field_index}"
    raise NotImplementedError(message)
60+
61+
6662
def _scalar_decoder(decoder: DecoderConstructor) -> PartialDecoderConstructor:
6763
"""Constructs a scalar decoder constructor.
6864
@@ -106,6 +102,23 @@ def vector_decoder(field_index: int, key: Key) -> Decoder:
106102
return vector_decoder
107103

108104

105+
def _double_xy_data_decoder(decoder: DecoderConstructor) -> PartialDecoderConstructor:
    """Build a partial decoder constructor for DoubleXYData fields.

    The returned callable takes a field index and a key and returns the Decoder that
    ``decoder`` builds for them, using an empty DoubleXYData as the default value.
    """

    def _empty_xy_data(unused_message: Optional[Message] = None) -> Any:
        # Default-value factory handed to the decoder constructor; the argument is ignored.
        return xydata_pb2.DoubleXYData()

    def message_decoder(field_index: int, key: Key) -> Decoder:
        # NOTE(review): is_repeated and is_packed are both passed as True even though this
        # builds a decoder for a single message -- confirm whether False was intended.
        return decoder(field_index, True, True, key, _empty_xy_data)

    return message_decoder
120+
121+
109122
# Cast works around this issue in typeshed
110123
# https://github.com/python/typeshed/issues/10695
111124
FloatEncoder = _scalar_encoder(cast(EncoderConstructor, encoder.FloatEncoder))
@@ -114,13 +127,15 @@ def vector_decoder(field_index: int, key: Key) -> Decoder:
114127
UIntEncoder = _scalar_encoder(cast(EncoderConstructor, encoder.UInt32Encoder))
115128
BoolEncoder = _scalar_encoder(encoder.BoolEncoder)
116129
StringEncoder = _scalar_encoder(encoder.StringEncoder)
130+
MessageEncoder = _scalar_encoder(cast(EncoderConstructor, _message._message_encoder_constructor))
117131

118132
FloatArrayEncoder = _vector_encoder(cast(EncoderConstructor, encoder.FloatEncoder))
119133
DoubleArrayEncoder = _vector_encoder(cast(EncoderConstructor, encoder.DoubleEncoder))
120134
IntArrayEncoder = _vector_encoder(cast(EncoderConstructor, encoder.Int32Encoder))
121135
UIntArrayEncoder = _vector_encoder(cast(EncoderConstructor, encoder.UInt32Encoder))
122136
BoolArrayEncoder = _vector_encoder(encoder.BoolEncoder)
123137
StringArrayEncoder = _vector_encoder(encoder.StringEncoder, is_packed=False)
138+
UnsupportedMessageArrayEncoder = _vector_encoder(_unsupported_encoder)
124139

125140
# Cast works around this issue in typeshed
126141
# https://github.com/python/typeshed/issues/10697
@@ -132,6 +147,7 @@ def vector_decoder(field_index: int, key: Key) -> Decoder:
132147
UInt64Decoder = _scalar_decoder(cast(DecoderConstructor, decoder.UInt64Decoder))
133148
BoolDecoder = _scalar_decoder(cast(DecoderConstructor, decoder.BoolDecoder))
134149
StringDecoder = _scalar_decoder(cast(DecoderConstructor, decoder.StringDecoder))
150+
XYDataDecoder = _double_xy_data_decoder(_message._message_decoder_constructor)
135151

136152
FloatArrayDecoder = _vector_decoder(cast(DecoderConstructor, decoder.FloatDecoder))
137153
DoubleArrayDecoder = _vector_decoder(cast(DecoderConstructor, decoder.DoubleDecoder))
@@ -155,6 +171,7 @@ def vector_decoder(field_index: int, key: Key) -> Decoder:
155171
type_pb2.Field.TYPE_BOOL: (BoolEncoder, BoolArrayEncoder),
156172
type_pb2.Field.TYPE_STRING: (StringEncoder, StringArrayEncoder),
157173
type_pb2.Field.TYPE_ENUM: (IntEncoder, IntArrayEncoder),
174+
type_pb2.Field.TYPE_MESSAGE: (MessageEncoder, UnsupportedMessageArrayEncoder),
158175
}
159176

160177
_FIELD_TYPE_TO_DECODER_MAPPING = {
@@ -181,6 +198,10 @@ def vector_decoder(field_index: int, key: Key) -> Decoder:
181198
type_pb2.Field.TYPE_ENUM: int(),
182199
}
183200

201+
_MESSAGE_TYPE_TO_DECODER = {
202+
xydata_pb2.DoubleXYData.DESCRIPTOR.full_name: XYDataDecoder,
203+
}
204+
184205

185206
def get_encoder(type: type_pb2.Field.Kind.ValueType, repeated: bool) -> PartialEncoderConstructor:
186207
"""Get the appropriate partial encoder constructor for the specified type.
@@ -195,17 +216,23 @@ def get_encoder(type: type_pb2.Field.Kind.ValueType, repeated: bool) -> PartialE
195216
return scalar
196217

197218

198-
def get_decoder(
    type: type_pb2.Field.Kind.ValueType, repeated: bool, message_type: str = ""
) -> PartialDecoderConstructor:
    """Look up the partial decoder constructor for a field type.

    Types present in the field-type mapping yield their array constructor when
    ``repeated`` is true and their scalar constructor otherwise. Message fields are
    resolved by ``message_type`` instead.

    Raises:
        ValueError: If the type is not mapped and is not a message, if a message
            field is repeated, or if ``message_type`` has no registered decoder.
    """
    constructors = _FIELD_TYPE_TO_DECODER_MAPPING.get(type)
    if constructors is not None:
        scalar_constructor, array_constructor = constructors
        if repeated:
            return array_constructor
        return scalar_constructor

    if type != type_pb2.Field.Kind.TYPE_MESSAGE:
        raise ValueError(f"Error can not decode type '{type}'")
    if repeated:
        raise ValueError(f"Repeated message types are not supported '{message_type}'")

    message_constructor = _MESSAGE_TYPE_TO_DECODER.get(message_type)
    if message_constructor is None:
        raise ValueError(f"Unknown message type '{message_type}'")
    return message_constructor
209236

210237

211238
def get_type_default(type: type_pb2.Field.Kind.ValueType, repeated: bool) -> Any:

ni_measurementlink_service/_internal/parameter/serializer.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -158,9 +158,10 @@ def _get_overlapping_parameters(
158158
raise Exception(
159159
f"Error occurred while reading the parameter - given protobuf index '{field_index}' is invalid."
160160
)
161-
type = parameter_metadata_dict[field_index].type
162-
is_repeated = parameter_metadata_dict[field_index].repeated
163-
decoder = serialization_strategy.get_decoder(type, is_repeated)
161+
field_metadata = parameter_metadata_dict[field_index]
162+
decoder = serialization_strategy.get_decoder(
163+
field_metadata.type, field_metadata.repeated, field_metadata.message_type
164+
)
164165
inner_decoder = decoder(field_index, field_index)
165166
parameter_bytes_io = BytesIO(parameter_bytes)
166167
parameter_bytes_memory_view = parameter_bytes_io.getbuffer()

0 commit comments

Comments
 (0)