Skip to content

Commit 6c70e0f

Browse files
authored
Support an array of XY data plots (#632)
* Add serialization strategy for array of DoubleXYData * Generate new stubs for test.proto with an array of XYData field * Get serializer tests passing with DoubleXYData array * Fix serialization and deserialization of array with multiple xy data message values * Add test_service tests for DoubleXYDataArray1D * Fix Black errors * Fix lint error * Fix mypy errors * Revert unintentional changes * Fix tests
1 parent c3fa33d commit 6c70e0f

10 files changed

Lines changed: 161 additions & 49 deletions

File tree

ni_measurementlink_service/_datatypeinfo.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,4 +62,9 @@ def get_type_info(data_type: DataType) -> DataTypeInfo:
6262
DataType.PinArray1D: DataTypeInfo(type_pb2.Field.TYPE_STRING, True, TypeSpecialization.Pin),
6363
DataType.PathArray1D: DataTypeInfo(type_pb2.Field.TYPE_STRING, True, TypeSpecialization.Path),
6464
DataType.EnumArray1D: DataTypeInfo(type_pb2.Field.TYPE_ENUM, True, TypeSpecialization.Enum),
65+
DataType.DoubleXYDataArray1D: DataTypeInfo(
66+
type_pb2.Field.TYPE_MESSAGE,
67+
True,
68+
message_type=xydata_pb2.DoubleXYData.DESCRIPTOR.full_name,
69+
),
6570
}

ni_measurementlink_service/_internal/parameter/_message.py

Lines changed: 86 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,16 @@
11
import struct
2-
from typing import Any, Callable, Dict, Optional, Tuple, Type, TypeVar
2+
from typing import (
3+
Any,
4+
Callable,
5+
Dict,
6+
List,
7+
Optional,
8+
Tuple,
9+
Type,
10+
TypeVar,
11+
Union,
12+
cast,
13+
)
314

415
from google.protobuf.internal import encoder, wire_format
516
from google.protobuf.message import Message
@@ -14,7 +25,7 @@
1425

1526
def _message_encoder_constructor(
1627
field_index: int, is_repeated: bool, is_packed: bool
17-
) -> Callable[[WriteFunction, Message, bool], int]:
28+
) -> Callable[[WriteFunction, Union[Message, List[Message]], bool], int]:
1829
"""Mimics google.protobuf.internal.MessageEncoder.
1930
2031
This function was forked in order to call SerializeToString instead of _InternalSerialize.
@@ -26,13 +37,31 @@ def _message_encoder_constructor(
2637
tag = encoder.TagBytes(field_index, wire_format.WIRETYPE_LENGTH_DELIMITED)
2738
encode_varint = _varint_encoder()
2839

29-
def _encode_message(write: WriteFunction, value: Message, deterministic: bool) -> int:
30-
write(tag)
31-
bytes = value.SerializeToString()
32-
encode_varint(write, len(bytes), deterministic)
33-
return write(bytes)
40+
if is_repeated:
41+
42+
def _encode_repeated_message(
43+
write: WriteFunction, value: Union[Message, List[Message]], deterministic: bool
44+
) -> int:
45+
bytes_written = 0
46+
for element in cast(List[Message], value):
47+
write(tag)
48+
bytes = element.SerializeToString()
49+
encode_varint(write, len(bytes), deterministic)
50+
bytes_written += write(bytes)
51+
return bytes_written
52+
53+
return _encode_repeated_message
54+
else:
3455

35-
return _encode_message
56+
def _encode_message(
57+
write: WriteFunction, value: Union[Message, List[Message]], deterministic: bool
58+
) -> int:
59+
write(tag)
60+
bytes = cast(Message, value).SerializeToString()
61+
encode_varint(write, len(bytes), deterministic)
62+
return write(bytes)
63+
64+
return _encode_message
3665

3766

3867
def _varint_encoder() -> Callable[[WriteFunction, int, Optional[bool]], int]:
@@ -67,25 +96,55 @@ def _message_decoder_constructor(
6796
(like DoubleXYData) are defined in .proto files, so they use whichever protobuf implementation
6897
that google.protobuf.internal.api_implementation chooses (usually upb).
6998
"""
70-
71-
def _decode_message(
72-
buffer: memoryview, pos: int, end: int, message: Message, field_dict: Dict[Key, Any]
73-
) -> int:
74-
decode_varint = _varint_decoder(mask=(1 << 64) - 1, result_type=int)
75-
value = field_dict.get(key)
76-
if value is None:
77-
value = field_dict.setdefault(key, new_default(message))
78-
# Read length.
79-
(size, pos) = decode_varint(buffer, pos)
80-
new_pos = pos + size
81-
if new_pos > end:
82-
raise ValueError("Error decoding a message. Message is truncated.")
83-
parsed_bytes = value.ParseFromString(buffer[pos:new_pos])
84-
if parsed_bytes != size:
85-
raise ValueError("Parsed incorrect number of bytes.")
86-
return new_pos
87-
88-
return _decode_message
99+
if is_repeated:
100+
tag_bytes = encoder.TagBytes(field_index, wire_format.WIRETYPE_LENGTH_DELIMITED)
101+
tag_len = len(tag_bytes)
102+
103+
def _decode_repeated_message(
104+
buffer: memoryview, pos: int, end: int, message: Message, field_dict: Dict[Key, Any]
105+
) -> int:
106+
decode_varint = _varint_decoder(mask=(1 << 64) - 1, result_type=int)
107+
value = field_dict.get(key)
108+
if value is None:
109+
value = field_dict.setdefault(key, [])
110+
while 1:
111+
parsed_value = new_default(message)
112+
# Read length.
113+
(size, pos) = decode_varint(buffer, pos)
114+
new_pos = pos + size
115+
if new_pos > end:
116+
raise ValueError("Error decoding a message. Message is truncated.")
117+
parsed_bytes = parsed_value.ParseFromString(buffer[pos:new_pos])
118+
if parsed_bytes != size:
119+
raise ValueError("Parsed incorrect number of bytes.")
120+
value.append(parsed_value)
121+
# Predict that the next tag is another copy of the same repeated field.
122+
pos = new_pos + tag_len
123+
if buffer[new_pos:pos] != tag_bytes or new_pos == end:
124+
# Prediction failed. Return.
125+
return new_pos
126+
127+
return _decode_repeated_message
128+
else:
129+
130+
def _decode_message(
131+
buffer: memoryview, pos: int, end: int, message: Message, field_dict: Dict[Key, Any]
132+
) -> int:
133+
decode_varint = _varint_decoder(mask=(1 << 64) - 1, result_type=int)
134+
value = field_dict.get(key)
135+
if value is None:
136+
value = field_dict.setdefault(key, new_default(message))
137+
# Read length.
138+
(size, pos) = decode_varint(buffer, pos)
139+
new_pos = pos + size
140+
if new_pos > end:
141+
raise ValueError("Error decoding a message. Message is truncated.")
142+
parsed_bytes = value.ParseFromString(buffer[pos:new_pos])
143+
if parsed_bytes != size:
144+
raise ValueError("Parsed incorrect number of bytes.")
145+
return new_pos
146+
147+
return _decode_message
89148

90149

91150
T = TypeVar("T", bound="int")

ni_measurementlink_service/_internal/parameter/serialization_strategy.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,6 @@ def vector_encoder(field_index: int) -> Encoder:
5656
return vector_encoder
5757

5858

59-
def _unsupported_encoder(field_index: int, is_repeated: bool, is_packed: bool) -> Encoder:
60-
raise NotImplementedError(f"Unsupported data type for field {field_index}")
61-
62-
6359
def _scalar_decoder(decoder: DecoderConstructor) -> PartialDecoderConstructor:
6460
"""Constructs a scalar decoder constructor.
6561
@@ -103,7 +99,9 @@ def vector_decoder(field_index: int, key: Key) -> Decoder:
10399
return vector_decoder
104100

105101

106-
def _double_xy_data_decoder(decoder: DecoderConstructor) -> PartialDecoderConstructor:
102+
def _double_xy_data_decoder(
103+
decoder: DecoderConstructor, is_repeated: bool
104+
) -> PartialDecoderConstructor:
107105
"""Constructs a DoubleXYData decoder constructor.
108106
109107
Takes a field index and a key and returns a Decoder for DoubleXYData.
@@ -113,7 +111,6 @@ def _new_default(unused_message: Optional[Message] = None) -> Any:
113111
return xydata_pb2.DoubleXYData()
114112

115113
def message_decoder(field_index: int, key: Key) -> Decoder:
116-
is_repeated = True
117114
is_packed = True
118115
return decoder(field_index, is_repeated, is_packed, key, _new_default)
119116

@@ -136,7 +133,9 @@ def message_decoder(field_index: int, key: Key) -> Decoder:
136133
UIntArrayEncoder = _vector_encoder(cast(EncoderConstructor, encoder.UInt32Encoder))
137134
BoolArrayEncoder = _vector_encoder(encoder.BoolEncoder)
138135
StringArrayEncoder = _vector_encoder(encoder.StringEncoder, is_packed=False)
139-
UnsupportedMessageArrayEncoder = _vector_encoder(_unsupported_encoder)
136+
MessageArrayEncoder = _vector_encoder(
137+
cast(EncoderConstructor, _message._message_encoder_constructor)
138+
)
140139

141140
# Cast works around this issue in typeshed
142141
# https://github.com/python/typeshed/issues/10697
@@ -148,7 +147,7 @@ def message_decoder(field_index: int, key: Key) -> Decoder:
148147
UInt64Decoder = _scalar_decoder(cast(DecoderConstructor, decoder.UInt64Decoder))
149148
BoolDecoder = _scalar_decoder(cast(DecoderConstructor, decoder.BoolDecoder))
150149
StringDecoder = _scalar_decoder(cast(DecoderConstructor, decoder.StringDecoder))
151-
XYDataDecoder = _double_xy_data_decoder(_message._message_decoder_constructor)
150+
XYDataDecoder = _double_xy_data_decoder(_message._message_decoder_constructor, is_repeated=False)
152151

153152
FloatArrayDecoder = _vector_decoder(cast(DecoderConstructor, decoder.FloatDecoder))
154153
DoubleArrayDecoder = _vector_decoder(cast(DecoderConstructor, decoder.DoubleDecoder))
@@ -160,6 +159,9 @@ def message_decoder(field_index: int, key: Key) -> Decoder:
160159
StringArrayDecoder = _vector_decoder(
161160
cast(DecoderConstructor, decoder.StringDecoder), is_packed=False
162161
)
162+
XYDataArrayDecoder = _double_xy_data_decoder(
163+
_message._message_decoder_constructor, is_repeated=True
164+
)
163165

164166

165167
_FIELD_TYPE_TO_ENCODER_MAPPING = {
@@ -172,7 +174,7 @@ def message_decoder(field_index: int, key: Key) -> Decoder:
172174
type_pb2.Field.TYPE_BOOL: (BoolEncoder, BoolArrayEncoder),
173175
type_pb2.Field.TYPE_STRING: (StringEncoder, StringArrayEncoder),
174176
type_pb2.Field.TYPE_ENUM: (IntEncoder, IntArrayEncoder),
175-
type_pb2.Field.TYPE_MESSAGE: (MessageEncoder, UnsupportedMessageArrayEncoder),
177+
type_pb2.Field.TYPE_MESSAGE: (MessageEncoder, MessageArrayEncoder),
176178
}
177179

178180
_FIELD_TYPE_TO_DECODER_MAPPING = {
@@ -203,6 +205,10 @@ def message_decoder(field_index: int, key: Key) -> Decoder:
203205
xydata_pb2.DoubleXYData.DESCRIPTOR.full_name: XYDataDecoder,
204206
}
205207

208+
_ARRAY_MESSAGE_TYPE_TO_DECODER = {
209+
xydata_pb2.DoubleXYData.DESCRIPTOR.full_name: XYDataArrayDecoder,
210+
}
211+
206212

207213
def get_encoder(type: type_pb2.Field.Kind.ValueType, repeated: bool) -> PartialEncoderConstructor:
208214
"""Get the appropriate partial encoder constructor for the specified type.
@@ -227,8 +233,9 @@ def get_decoder(
227233
return array_decoder if repeated else scalar_decoder
228234
elif type == type_pb2.Field.Kind.TYPE_MESSAGE:
229235
if repeated:
230-
raise ValueError(f"Repeated message types are not supported '{message_type}'")
231-
decoder = _MESSAGE_TYPE_TO_DECODER.get(message_type)
236+
decoder = _ARRAY_MESSAGE_TYPE_TO_DECODER.get(message_type)
237+
else:
238+
decoder = _MESSAGE_TYPE_TO_DECODER.get(message_type)
232239
if decoder is None:
233240
raise ValueError(f"Unknown message type '{message_type}'")
234241
return decoder

ni_measurementlink_service/measurement/info.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,3 +99,4 @@ class DataType(enum.Enum):
9999
PinArray1D = 108
100100
PathArray1D = 109
101101
EnumArray1D = 110
102+
DoubleXYDataArray1D = 111

tests/unit/test_serialization_strategy.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
(type_pb2.Field.TYPE_STRING, False, serialization_strategy.StringEncoder),
2121
(type_pb2.Field.TYPE_ENUM, False, serialization_strategy.IntEncoder),
2222
(type_pb2.Field.TYPE_MESSAGE, False, serialization_strategy.MessageEncoder),
23+
(type_pb2.Field.TYPE_MESSAGE, True, serialization_strategy.MessageArrayEncoder),
2324
],
2425
)
2526
def test___serialization_strategy___get_encoder___returns_expected_encoder(
@@ -48,6 +49,12 @@ def test___serialization_strategy___get_encoder___returns_expected_encoder(
4849
xydata_pb2.DoubleXYData.DESCRIPTOR.full_name,
4950
serialization_strategy.XYDataDecoder,
5051
),
52+
(
53+
type_pb2.Field.TYPE_MESSAGE,
54+
True,
55+
xydata_pb2.DoubleXYData.DESCRIPTOR.full_name,
56+
serialization_strategy.XYDataArrayDecoder,
57+
),
5158
],
5259
)
5360
def test___serialization_strategy___get_decoder___returns_expected_decoder(
@@ -71,6 +78,7 @@ def test___serialization_strategy___get_decoder___returns_expected_decoder(
7178
(type_pb2.Field.TYPE_STRING, False, ""),
7279
(type_pb2.Field.TYPE_ENUM, False, 0),
7380
(type_pb2.Field.TYPE_MESSAGE, False, None),
81+
(type_pb2.Field.TYPE_MESSAGE, True, []),
7482
],
7583
)
7684
def test___serialization_strategy___get_default_value___returns_type_defaults(

tests/unit/test_serializer.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import pytest
77
from google.protobuf import any_pb2, type_pb2
88

9-
109
from ni_measurementlink_service._annotations import (
1110
ENUM_VALUES_KEY,
1211
TYPE_SPECIALIZATION_KEY,
@@ -43,6 +42,12 @@ class Countries(IntEnum):
4342
double_xy_data.x_data.append(4)
4443
double_xy_data.y_data.append(6)
4544

45+
double_xy_data2 = xydata_pb2.DoubleXYData()
46+
double_xy_data2.x_data.append(8)
47+
double_xy_data2.y_data.append(10)
48+
49+
double_xy_data_array = [double_xy_data, double_xy_data2]
50+
4651
# This should match the number of fields in bigmessage.proto.
4752
BIG_MESSAGE_SIZE = 100
4853

@@ -72,6 +77,7 @@ class Countries(IntEnum):
7277
Countries.AUSTRALIA,
7378
[Countries.AUSTRALIA, Countries.CANADA],
7479
double_xy_data,
80+
double_xy_data_array,
7581
],
7682
[
7783
-0.9999,
@@ -95,6 +101,7 @@ class Countries(IntEnum):
95101
Countries.AUSTRALIA,
96102
[Countries.AUSTRALIA, Countries.CANADA],
97103
double_xy_data,
104+
double_xy_data_array,
98105
],
99106
],
100107
)
@@ -133,6 +140,7 @@ def test___serializer___serialize_parameter___successful_serialization(test_valu
133140
Countries.AUSTRALIA,
134141
[Countries.AUSTRALIA, Countries.CANADA],
135142
double_xy_data,
143+
double_xy_data_array,
136144
],
137145
[
138146
-0.9999,
@@ -156,6 +164,7 @@ def test___serializer___serialize_parameter___successful_serialization(test_valu
156164
Countries.AUSTRALIA,
157165
[Countries.AUSTRALIA, Countries.CANADA],
158166
double_xy_data,
167+
double_xy_data_array,
159168
],
160169
],
161170
)
@@ -193,6 +202,7 @@ def test___serializer___serialize_default_parameter___successful_serialization(d
193202
Countries.AUSTRALIA,
194203
[Countries.AUSTRALIA, Countries.CANADA],
195204
double_xy_data,
205+
double_xy_data_array,
196206
]
197207
],
198208
)
@@ -230,6 +240,7 @@ def test___empty_buffer___deserialize_parameters___returns_zero_or_empty():
230240
Countries.AUSTRALIA,
231241
[Countries.AUSTRALIA, Countries.CANADA],
232242
double_xy_data,
243+
double_xy_data_array,
233244
]
234245
parameter = _get_test_parameter_by_id(nonzero_defaults)
235246
parameter_value_by_id = serializer.deserialize_parameters(parameter, bytes())
@@ -449,6 +460,14 @@ def _get_test_parameter_by_id(default_values):
449460
annotations={},
450461
message_type=xydata_pb2.DoubleXYData.DESCRIPTOR.full_name,
451462
),
463+
22: ParameterMetadata(
464+
display_name="xy_data_array",
465+
type=type_pb2.Field.TYPE_MESSAGE,
466+
repeated=True,
467+
default_value=default_values[21],
468+
annotations={},
469+
message_type=xydata_pb2.DoubleXYData.DESCRIPTOR.full_name,
470+
),
452471
}
453472
return parameter_by_id
454473

@@ -477,6 +496,7 @@ def _get_test_grpc_message(test_values):
477496
parameter.int_enum_array_data.extend(list(map(lambda x: x.value, test_values[19])))
478497
parameter.xy_data.x_data.append(test_values[20].x_data[0])
479498
parameter.xy_data.y_data.append(test_values[20].y_data[0])
499+
parameter.xy_data_array.extend(test_values[21])
480500
return parameter
481501

482502

0 commit comments

Comments
 (0)