Skip to content

Commit fd93a3b

Browse files
feat: introduce writeTo(byte[]) (#614)
Signed-off-by: Anthony Petrov <anthony@swirldslabs.com>
1 parent f3f3eb3 commit fd93a3b

10 files changed

Lines changed: 1526 additions & 16 deletions

File tree

pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/protobuf/CodecGenerator.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,13 +59,17 @@ public void generate(
5959
}
6060
final String writeMethod =
6161
CodecWriteMethodGenerator.generateWriteMethod(modelClassName, schemaClassName, fields);
62+
final String writeByteArrayMethod =
63+
CodecWriteByteArrayMethodGenerator.generateWriteMethod(modelClassName, schemaClassName, fields);
6264

6365
final String staticModifier = Generator.isInner(msgDef) ? " static" : "";
6466

6567
writer.addImport("com.hedera.pbj.runtime.*");
6668
writer.addImport("com.hedera.pbj.runtime.io.*");
6769
writer.addImport("com.hedera.pbj.runtime.io.buffer.*");
6870
writer.addImport("com.hedera.pbj.runtime.io.stream.EOFException");
71+
writer.addImport("com.hedera.pbj.runtime.io.stream.WritableStreamingData");
72+
writer.addImport("com.hedera.pbj.runtime.ProtoArrayWriterTools");
6973
writer.addImport("java.io.IOException");
7074
writer.addImport("java.nio.*");
7175
writer.addImport("java.nio.charset.*");
@@ -78,6 +82,7 @@ public void generate(
7882
writer.addImport("static com.hedera.pbj.runtime.ProtoWriterTools.*");
7983
writer.addImport("static com.hedera.pbj.runtime.ProtoParserTools.*");
8084
writer.addImport("static com.hedera.pbj.runtime.ProtoConstants.*");
85+
writer.addImport("static com.hedera.pbj.runtime.Utf8Tools.*");
8186

8287
// spotless:off
8388
writer.append("""
@@ -104,6 +109,7 @@ public void generate(
104109
$unsetOneOfConstants
105110
$parseMethod
106111
$writeMethod
112+
$writeByteArrayMethod
107113
$measureDataMethod
108114
$measureRecordMethod
109115
$fastEqualsMethod
@@ -116,6 +122,7 @@ public void generate(
116122
.replace("$unsetOneOfConstants", CodecParseMethodGenerator.generateUnsetOneOfConstants(fields))
117123
.replace("$parseMethod", CodecParseMethodGenerator.generateParseMethod(modelClassName, schemaClassName, fields))
118124
.replace("$writeMethod", writeMethod)
125+
.replace("$writeByteArrayMethod", writeByteArrayMethod)
119126
.replace("$measureDataMethod", CodecMeasureDataMethodGenerator.generateMeasureMethod(modelClassName, fields))
120127
.replace("$measureRecordMethod", CodecMeasureRecordMethodGenerator.generateMeasureMethod(modelClassName, fields))
121128
.replace("$fastEqualsMethod", CodecFastEqualsMethodGenerator.generateFastEqualsMethod(modelClassName, fields))
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,237 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
package com.hedera.pbj.compiler.impl.generators.protobuf;
3+
4+
import static com.hedera.pbj.compiler.impl.Common.DEFAULT_INDENT;
5+
6+
import com.hedera.pbj.compiler.impl.Common;
7+
import com.hedera.pbj.compiler.impl.Field;
8+
import com.hedera.pbj.compiler.impl.MapField;
9+
import com.hedera.pbj.compiler.impl.OneOfField;
10+
import com.hedera.pbj.compiler.impl.SingleField;
11+
import java.util.Comparator;
12+
import java.util.List;
13+
import java.util.function.Function;
14+
import java.util.stream.Collectors;
15+
import java.util.stream.Stream;
16+
17+
/**
18+
* Code to generate the write method for Codec classes.
19+
*/
20+
final class CodecWriteByteArrayMethodGenerator {
21+
22+
static String generateWriteMethod(
23+
final String modelClassName, final String schemaClassName, final List<Field> fields) {
24+
final String fieldWriteLines = buildFieldWriteLines(
25+
modelClassName,
26+
schemaClassName,
27+
fields,
28+
field -> " data.%s()".formatted(field.nameCamelFirstLower()),
29+
true);
30+
// spotless:off
31+
return
32+
"""
33+
/**
34+
* Writes an item to the given byte array, this is a performance focused method. In non-performance centric use
35+
* cases there are simpler methods such as toBytes() or writing to a {@link WritableStreamingData}.
36+
*
37+
* @param data The item to write. Must not be null.
38+
* @param output The byte array to write to, this must be large enough to hold the entire item.
39+
* @param startOffset The offset in the output array to start writing at.
40+
* @return The number of bytes written to the output array.
41+
* @throws IndexOutOfBoundsException If the output array is not large enough to hold the entire item.
42+
*/
43+
public int write(@NonNull $modelClass data, @NonNull byte[] output, final int startOffset) {
44+
int offset = startOffset;
45+
$fieldWriteLines
46+
// Write unknown fields if there are any
47+
for (final UnknownField uf : data.getUnknownFields()) {
48+
final int tag = (uf.field() << TAG_FIELD_OFFSET) | uf.wireType().ordinal();
49+
offset += ProtoArrayWriterTools.writeUnsignedVarInt(output, offset, tag);
50+
offset += uf.bytes().writeTo(output, offset);
51+
}
52+
return offset - startOffset;
53+
}
54+
"""
55+
.replace("$modelClass", modelClassName)
56+
.replace("$fieldWriteLines", fieldWriteLines)
57+
.indent(DEFAULT_INDENT);
58+
// spotless:on
59+
}
60+
61+
private static String buildFieldWriteLines(
62+
final String modelClassName,
63+
final String schemaClassName,
64+
final List<Field> fields,
65+
final Function<Field, String> getValueBuilder,
66+
final boolean skipDefault) {
67+
return fields.stream()
68+
.flatMap(field -> field.type() == Field.FieldType.ONE_OF
69+
? ((OneOfField) field).fields().stream()
70+
: Stream.of(field))
71+
.sorted(Comparator.comparingInt(Field::fieldNumber))
72+
.map(field -> generateFieldWriteLines(
73+
field, modelClassName, schemaClassName, getValueBuilder.apply(field), skipDefault))
74+
.collect(Collectors.joining("\n"))
75+
.indent(DEFAULT_INDENT);
76+
}
77+
78+
/**
79+
* Generate lines of code for writing field
80+
*
81+
* @param field The field to generate writing line of code for
82+
* @param modelClassName The model class name for model class for message type we are generating writer for
83+
* @param getValueCode java code to get the value of field
84+
* @param skipDefault skip writing the field if it has default value (for non-oneOf only)
85+
* @return java code to write field to output
86+
*/
87+
private static String generateFieldWriteLines(
88+
final Field field,
89+
final String modelClassName,
90+
final String schemaClassName,
91+
String getValueCode,
92+
boolean skipDefault) {
93+
final String fieldDef = schemaClassName + "." + Common.camelToUpperSnake(field.name());
94+
String prefix = "// [%d] - %s%n".formatted(field.fieldNumber(), field.name());
95+
96+
if (field.parent() != null) {
97+
final OneOfField oneOfField = field.parent();
98+
final String oneOfType = "%s.%sOneOfType".formatted(modelClassName, oneOfField.nameCamelFirstUpper());
99+
getValueCode = "(%s)data.%s().as()".formatted(field.javaFieldType(), oneOfField.nameCamelFirstLower());
100+
prefix += "if (data.%s().kind() == %s.%s)%n"
101+
.formatted(oneOfField.nameCamelFirstLower(), oneOfType, Common.camelToUpperSnake(field.name()));
102+
}
103+
// spotless:off
104+
final String writeMethodName = field.methodNameType();
105+
if (field.optionalValueType()) {
106+
return prefix + switch (field.messageType()) {
107+
case "StringValue" -> "offset += ProtoArrayWriterTools.writeOptionalString(output, offset, %s, %s);"
108+
.formatted(fieldDef,getValueCode);
109+
case "BoolValue" -> "offset += ProtoArrayWriterTools.writeOptionalBoolean(output, offset, %s, %s);"
110+
.formatted(fieldDef, getValueCode);
111+
case "Int32Value" -> "offset += ProtoArrayWriterTools.writeOptionalInt32Value(output, offset, %s, %s);"
112+
.formatted(fieldDef, getValueCode);
113+
case "UInt32Value" -> "offset += ProtoArrayWriterTools.writeOptionalUInt32Value(output, offset, %s, %s);"
114+
.formatted(fieldDef, getValueCode);
115+
case "Int64Value","UInt64Value" -> "offset += ProtoArrayWriterTools.writeOptionalInt64Value(output, offset, %s, %s);"
116+
.formatted(fieldDef, getValueCode);
117+
case "FloatValue" -> "offset += ProtoArrayWriterTools.writeOptionalFloat(output, offset, %s, %s);"
118+
.formatted(fieldDef, getValueCode);
119+
case "DoubleValue" -> "offset += ProtoArrayWriterTools.writeOptionalDouble(output, offset, %s, %s);"
120+
.formatted(fieldDef, getValueCode);
121+
case "BytesValue" -> "offset += ProtoArrayWriterTools.writeOptionalBytes(output, offset, %s, %s);"
122+
.formatted(fieldDef, getValueCode);
123+
default -> throw new UnsupportedOperationException(
124+
"Unhandled optional message type:%s".formatted(field.messageType()));
125+
};
126+
} else {
127+
String codecReference = "";
128+
if (Field.FieldType.MESSAGE.equals(field.type())) {
129+
codecReference = "%s.%s.PROTOBUF".formatted(((SingleField) field).messageTypeModelPackage(),
130+
((SingleField) field).completeClassName());
131+
}
132+
if (field.repeated()) {
133+
return prefix + switch(field.type()) {
134+
case ENUM -> "offset += ProtoArrayWriterTools.writeEnumList(output, offset, %s, %s);"
135+
.formatted(fieldDef, getValueCode);
136+
case MESSAGE -> "offset += ProtoArrayWriterTools.writeMessageList(output, offset, %s, %s, %s);"
137+
.formatted(fieldDef, getValueCode, codecReference);
138+
case INT32 -> "offset += ProtoArrayWriterTools.writeInt32List(output, offset, %s, %s);"
139+
.formatted(fieldDef, getValueCode);
140+
case UINT32 -> "offset += ProtoArrayWriterTools.writeUInt32List(output, offset, %s, %s);"
141+
.formatted(fieldDef, getValueCode);
142+
case SINT32 -> "offset += ProtoArrayWriterTools.writeSInt32List(output, offset, %s, %s);"
143+
.formatted(fieldDef, getValueCode);
144+
case FIXED32, SFIXED32 -> "offset += ProtoArrayWriterTools.writeFixed32List(output, offset, %s, %s);"
145+
.formatted(fieldDef, getValueCode);
146+
case INT64, UINT64 -> "offset += ProtoArrayWriterTools.writeInt64List(output, offset, %s, %s);"
147+
.formatted(fieldDef, getValueCode);
148+
case SINT64 -> "offset += ProtoArrayWriterTools.writeSInt64List(output, offset, %s, %s);"
149+
.formatted(fieldDef, getValueCode);
150+
case FIXED64, SFIXED64 -> "offset += ProtoArrayWriterTools.writeFixed64List(output, offset, %s, %s);"
151+
.formatted(fieldDef, getValueCode);
152+
153+
default -> "offset += ProtoArrayWriterTools.write%sList(output, offset, %s, %s);"
154+
.formatted(writeMethodName, fieldDef, getValueCode);
155+
};
156+
} else if (field.type() == Field.FieldType.MAP) {
157+
// https://protobuf.dev/programming-guides/proto3/#maps
158+
// On the wire, a map is equivalent to:
159+
// message MapFieldEntry {
160+
// key_type key = 1;
161+
// value_type value = 2;
162+
// }
163+
// repeated MapFieldEntry map_field = N;
164+
// NOTE: we serialize the map in the natural order of keys by design,
165+
// so that the binary representation of the map is deterministic.
166+
// NOTE: protoc serializes default values (e.g. "") in maps, so we should too.
167+
final MapField mapField = (MapField) field;
168+
final List<Field> mapEntryFields = List.of(mapField.keyField(), mapField.valueField());
169+
final Function<Field, String> getValueBuilder = mapEntryField ->
170+
mapEntryField == mapField.keyField() ? "k" : (mapEntryField == mapField.valueField() ? "v" : null);
171+
final String fieldWriteLines = buildFieldWriteLines(
172+
field.name(),
173+
schemaClassName,
174+
mapEntryFields,
175+
getValueBuilder,
176+
false);
177+
final String fieldSizeOfLines = CodecMeasureRecordMethodGenerator.buildFieldSizeOfLines(
178+
field.name(),
179+
mapEntryFields,
180+
getValueBuilder,
181+
false);
182+
return prefix + """
183+
if (!$map.isEmpty()) {
184+
final Pbj$javaFieldType pbjMap = (Pbj$javaFieldType) $map;
185+
final int mapSize = pbjMap.size();
186+
for (int i = 0; i < mapSize; i++) {
187+
offset += ProtoArrayWriterTools.writeTag(output, offset, $fieldDef, WIRE_TYPE_DELIMITED);
188+
$K k = pbjMap.getSortedKeys().get(i);
189+
$V v = pbjMap.get(k);
190+
int size = 0;
191+
$fieldSizeOfLines
192+
offset += ProtoArrayWriterTools.writeUnsignedVarInt(output, offset, size);
193+
$fieldWriteLines
194+
}
195+
}
196+
"""
197+
.replace("$fieldDef", fieldDef)
198+
.replace("$map", getValueCode)
199+
.replace("$javaFieldType", mapField.javaFieldType())
200+
.replace("$K", mapField.keyField().type().boxedType)
201+
.replace("$V", mapField.valueField().type() == Field.FieldType.MESSAGE ? ((SingleField)mapField.valueField()).messageType() : mapField.valueField().type().boxedType)
202+
.replace("$fieldWriteLines", fieldWriteLines.indent(DEFAULT_INDENT))
203+
.replace("$fieldSizeOfLines", fieldSizeOfLines.indent(DEFAULT_INDENT));
204+
} else {
205+
return prefix + switch(field.type()) {
206+
case ENUM -> "offset += ProtoArrayWriterTools.writeEnum(output, offset, %s, %s);"
207+
.formatted(fieldDef, getValueCode);
208+
case STRING -> "offset += ProtoArrayWriterTools.writeString(output, offset, %s, %s, %s);"
209+
.formatted(fieldDef, getValueCode, skipDefault);
210+
case MESSAGE -> "offset += ProtoArrayWriterTools.writeMessage(output, offset, %s, %s, %s);"
211+
.formatted(fieldDef, getValueCode, codecReference);
212+
case BOOL -> "offset += ProtoArrayWriterTools.writeBoolean(output, offset, %s, %s, %s);"
213+
.formatted(fieldDef, getValueCode, skipDefault);
214+
case INT32 -> "offset += ProtoArrayWriterTools.writeInt32(output, offset, %s, %s, %s);"
215+
.formatted(fieldDef, getValueCode, skipDefault);
216+
case UINT32 -> "offset += ProtoArrayWriterTools.writeUInt32(output, offset, %s, %s, %s);"
217+
.formatted(fieldDef, getValueCode, skipDefault);
218+
case SINT32 -> "offset += ProtoArrayWriterTools.writSInt32(output, offset, %s, %s, %s);"
219+
.formatted(fieldDef, getValueCode, skipDefault);
220+
case FIXED32, SFIXED32 -> "offset += ProtoArrayWriterTools.writeFixed32(output, offset, %s, %s, %s);"
221+
.formatted(fieldDef, getValueCode, skipDefault);
222+
case INT64, UINT64 -> "offset += ProtoArrayWriterTools.writeInt64(output, offset, %s, %s, %s);"
223+
.formatted(fieldDef, getValueCode, skipDefault);
224+
case SINT64 -> "offset += ProtoArrayWriterTools.writeSInt64(output, offset, %s, %s, %s);"
225+
.formatted(fieldDef, getValueCode, skipDefault);
226+
case FIXED64, SFIXED64 -> "offset += ProtoArrayWriterTools.writeFixed64(output, offset, %s, %s, %s);"
227+
.formatted(fieldDef, getValueCode, skipDefault);
228+
case BYTES -> "offset += ProtoArrayWriterTools.writeBytes(output, offset, %s, %s, %s);"
229+
.formatted(fieldDef, getValueCode, skipDefault);
230+
default -> "offset += ProtoArrayWriterTools.write%s(output, offset, %s, %s);"
231+
.formatted(writeMethodName, fieldDef, getValueCode);
232+
};
233+
}
234+
}
235+
// spotless:on
236+
}
237+
}

pbj-core/pbj-grpc-helidon/src/test/java/com/hedera/pbj/grpc/helidon/PbjTest.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -696,6 +696,8 @@ private GrpcStatus grpcStatus(Http2ClientResponse response) {
696696
try {
697697
return grpcStatus(response.headers());
698698
} catch (NoSuchElementException e) {
699+
// We cannot request trailers before requesting an entity, so:
700+
response.entity();
699701
return grpcStatus(response.trailers());
700702
}
701703
}

pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/Codec.java

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,7 @@
1515
*
1616
* @param <T> The type of object to serialize and deserialize
1717
*/
18-
public interface Codec<T /*extends Record*/> {
19-
20-
// NOTE: When services has finished migrating to protobuf based objects in state,
21-
// then we should strongly enforce Codec works with Records. This will reduce bugs
22-
// where people try to use a mutable object.
18+
public interface Codec<T> {
2319

2420
/**
2521
* Parses an object from the {@link ReadableSequentialData} and returns it.
@@ -157,6 +153,27 @@ default T parseStrict(@NonNull Bytes bytes) throws ParseException {
157153
*/
158154
void write(@NonNull T item, @NonNull WritableSequentialData output) throws IOException;
159155

156+
/**
157+
* Writes an item to the given byte array, this is a performance focused method. In non-performance centric use
158+
* cases there are simpler methods such as {@link #toBytes(T)} or writing to a {@link WritableStreamingData}.
159+
*
160+
* @param item The item to write. Must not be null.
161+
* @param output The byte array to write to, this must be large enough to hold the entire item.
162+
* @param startOffset The offset in the output array to start writing at.
163+
* @return The number of bytes written to the output array.
164+
* @throws UncheckedIOException If the there is a problem writing to the output array.
165+
* @throws IndexOutOfBoundsException If the output array is not large enough to hold the entire item.
166+
*/
167+
default int write(@NonNull T item, @NonNull byte[] output, final int startOffset) {
168+
final BufferedData bufferedData = BufferedData.wrap(output, startOffset, output.length - startOffset);
169+
try {
170+
write(item, bufferedData);
171+
} catch (IOException e) {
172+
throw new UncheckedIOException(e);
173+
}
174+
return (int) bufferedData.position();
175+
}
176+
160177
/**
161178
* Reads from this data input the length of the data within the input. The implementation may
162179
* read all the data, or just some special serialized data, as needed to find out the length of
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
package com.hedera.pbj.runtime;
3+
4+
/**
5+
* Thrown during the UTF-8 encoding process when it is malformed.
6+
*/
7+
public class MalformedUtf8Exception extends RuntimeException {
8+
9+
/**
10+
* Construct new MalformedUtf8Exception
11+
*
12+
* @param message error message
13+
*/
14+
public MalformedUtf8Exception(final String message) {
15+
super(message);
16+
}
17+
18+
public MalformedUtf8Exception(final String message, final Throwable cause) {
19+
super(message, cause);
20+
}
21+
}

0 commit comments

Comments
 (0)