|
| 1 | +// SPDX-License-Identifier: Apache-2.0 |
| 2 | +package com.hedera.pbj.compiler.impl.generators.protobuf; |
| 3 | + |
| 4 | +import static com.hedera.pbj.compiler.impl.Common.DEFAULT_INDENT; |
| 5 | + |
| 6 | +import com.hedera.pbj.compiler.impl.Common; |
| 7 | +import com.hedera.pbj.compiler.impl.Field; |
| 8 | +import com.hedera.pbj.compiler.impl.MapField; |
| 9 | +import com.hedera.pbj.compiler.impl.OneOfField; |
| 10 | +import com.hedera.pbj.compiler.impl.SingleField; |
| 11 | +import java.util.Comparator; |
| 12 | +import java.util.List; |
| 13 | +import java.util.function.Function; |
| 14 | +import java.util.stream.Collectors; |
| 15 | +import java.util.stream.Stream; |
| 16 | + |
| 17 | +/** |
| 18 | + * Code to generate the write method for Codec classes. |
| 19 | + */ |
| 20 | +final class CodecWriteByteArrayMethodGenerator { |
| 21 | + |
| 22 | + static String generateWriteMethod( |
| 23 | + final String modelClassName, final String schemaClassName, final List<Field> fields) { |
| 24 | + final String fieldWriteLines = buildFieldWriteLines( |
| 25 | + modelClassName, |
| 26 | + schemaClassName, |
| 27 | + fields, |
| 28 | + field -> " data.%s()".formatted(field.nameCamelFirstLower()), |
| 29 | + true); |
| 30 | + // spotless:off |
| 31 | + return |
| 32 | + """ |
| 33 | + /** |
| 34 | + * Writes an item to the given byte array, this is a performance focused method. In non-performance centric use |
| 35 | + * cases there are simpler methods such as toBytes() or writing to a {@link WritableStreamingData}. |
| 36 | + * |
| 37 | + * @param data The item to write. Must not be null. |
| 38 | + * @param output The byte array to write to, this must be large enough to hold the entire item. |
| 39 | + * @param startOffset The offset in the output array to start writing at. |
| 40 | + * @return The number of bytes written to the output array. |
| 41 | + * @throws IndexOutOfBoundsException If the output array is not large enough to hold the entire item. |
| 42 | + */ |
| 43 | + public int write(@NonNull $modelClass data, @NonNull byte[] output, final int startOffset) { |
| 44 | + int offset = startOffset; |
| 45 | + $fieldWriteLines |
| 46 | + // Write unknown fields if there are any |
| 47 | + for (final UnknownField uf : data.getUnknownFields()) { |
| 48 | + final int tag = (uf.field() << TAG_FIELD_OFFSET) | uf.wireType().ordinal(); |
| 49 | + offset += ProtoArrayWriterTools.writeUnsignedVarInt(output, offset, tag); |
| 50 | + offset += uf.bytes().writeTo(output, offset); |
| 51 | + } |
| 52 | + return offset - startOffset; |
| 53 | + } |
| 54 | + """ |
| 55 | + .replace("$modelClass", modelClassName) |
| 56 | + .replace("$fieldWriteLines", fieldWriteLines) |
| 57 | + .indent(DEFAULT_INDENT); |
| 58 | + // spotless:on |
| 59 | + } |
| 60 | + |
| 61 | + private static String buildFieldWriteLines( |
| 62 | + final String modelClassName, |
| 63 | + final String schemaClassName, |
| 64 | + final List<Field> fields, |
| 65 | + final Function<Field, String> getValueBuilder, |
| 66 | + final boolean skipDefault) { |
| 67 | + return fields.stream() |
| 68 | + .flatMap(field -> field.type() == Field.FieldType.ONE_OF |
| 69 | + ? ((OneOfField) field).fields().stream() |
| 70 | + : Stream.of(field)) |
| 71 | + .sorted(Comparator.comparingInt(Field::fieldNumber)) |
| 72 | + .map(field -> generateFieldWriteLines( |
| 73 | + field, modelClassName, schemaClassName, getValueBuilder.apply(field), skipDefault)) |
| 74 | + .collect(Collectors.joining("\n")) |
| 75 | + .indent(DEFAULT_INDENT); |
| 76 | + } |
| 77 | + |
| 78 | + /** |
| 79 | + * Generate lines of code for writing field |
| 80 | + * |
| 81 | + * @param field The field to generate writing line of code for |
| 82 | + * @param modelClassName The model class name for model class for message type we are generating writer for |
| 83 | + * @param getValueCode java code to get the value of field |
| 84 | + * @param skipDefault skip writing the field if it has default value (for non-oneOf only) |
| 85 | + * @return java code to write field to output |
| 86 | + */ |
| 87 | + private static String generateFieldWriteLines( |
| 88 | + final Field field, |
| 89 | + final String modelClassName, |
| 90 | + final String schemaClassName, |
| 91 | + String getValueCode, |
| 92 | + boolean skipDefault) { |
| 93 | + final String fieldDef = schemaClassName + "." + Common.camelToUpperSnake(field.name()); |
| 94 | + String prefix = "// [%d] - %s%n".formatted(field.fieldNumber(), field.name()); |
| 95 | + |
| 96 | + if (field.parent() != null) { |
| 97 | + final OneOfField oneOfField = field.parent(); |
| 98 | + final String oneOfType = "%s.%sOneOfType".formatted(modelClassName, oneOfField.nameCamelFirstUpper()); |
| 99 | + getValueCode = "(%s)data.%s().as()".formatted(field.javaFieldType(), oneOfField.nameCamelFirstLower()); |
| 100 | + prefix += "if (data.%s().kind() == %s.%s)%n" |
| 101 | + .formatted(oneOfField.nameCamelFirstLower(), oneOfType, Common.camelToUpperSnake(field.name())); |
| 102 | + } |
| 103 | + // spotless:off |
| 104 | + final String writeMethodName = field.methodNameType(); |
| 105 | + if (field.optionalValueType()) { |
| 106 | + return prefix + switch (field.messageType()) { |
| 107 | + case "StringValue" -> "offset += ProtoArrayWriterTools.writeOptionalString(output, offset, %s, %s);" |
| 108 | + .formatted(fieldDef,getValueCode); |
| 109 | + case "BoolValue" -> "offset += ProtoArrayWriterTools.writeOptionalBoolean(output, offset, %s, %s);" |
| 110 | + .formatted(fieldDef, getValueCode); |
| 111 | + case "Int32Value" -> "offset += ProtoArrayWriterTools.writeOptionalInt32Value(output, offset, %s, %s);" |
| 112 | + .formatted(fieldDef, getValueCode); |
| 113 | + case "UInt32Value" -> "offset += ProtoArrayWriterTools.writeOptionalUInt32Value(output, offset, %s, %s);" |
| 114 | + .formatted(fieldDef, getValueCode); |
| 115 | + case "Int64Value","UInt64Value" -> "offset += ProtoArrayWriterTools.writeOptionalInt64Value(output, offset, %s, %s);" |
| 116 | + .formatted(fieldDef, getValueCode); |
| 117 | + case "FloatValue" -> "offset += ProtoArrayWriterTools.writeOptionalFloat(output, offset, %s, %s);" |
| 118 | + .formatted(fieldDef, getValueCode); |
| 119 | + case "DoubleValue" -> "offset += ProtoArrayWriterTools.writeOptionalDouble(output, offset, %s, %s);" |
| 120 | + .formatted(fieldDef, getValueCode); |
| 121 | + case "BytesValue" -> "offset += ProtoArrayWriterTools.writeOptionalBytes(output, offset, %s, %s);" |
| 122 | + .formatted(fieldDef, getValueCode); |
| 123 | + default -> throw new UnsupportedOperationException( |
| 124 | + "Unhandled optional message type:%s".formatted(field.messageType())); |
| 125 | + }; |
| 126 | + } else { |
| 127 | + String codecReference = ""; |
| 128 | + if (Field.FieldType.MESSAGE.equals(field.type())) { |
| 129 | + codecReference = "%s.%s.PROTOBUF".formatted(((SingleField) field).messageTypeModelPackage(), |
| 130 | + ((SingleField) field).completeClassName()); |
| 131 | + } |
| 132 | + if (field.repeated()) { |
| 133 | + return prefix + switch(field.type()) { |
| 134 | + case ENUM -> "offset += ProtoArrayWriterTools.writeEnumList(output, offset, %s, %s);" |
| 135 | + .formatted(fieldDef, getValueCode); |
| 136 | + case MESSAGE -> "offset += ProtoArrayWriterTools.writeMessageList(output, offset, %s, %s, %s);" |
| 137 | + .formatted(fieldDef, getValueCode, codecReference); |
| 138 | + case INT32 -> "offset += ProtoArrayWriterTools.writeInt32List(output, offset, %s, %s);" |
| 139 | + .formatted(fieldDef, getValueCode); |
| 140 | + case UINT32 -> "offset += ProtoArrayWriterTools.writeUInt32List(output, offset, %s, %s);" |
| 141 | + .formatted(fieldDef, getValueCode); |
| 142 | + case SINT32 -> "offset += ProtoArrayWriterTools.writeSInt32List(output, offset, %s, %s);" |
| 143 | + .formatted(fieldDef, getValueCode); |
| 144 | + case FIXED32, SFIXED32 -> "offset += ProtoArrayWriterTools.writeFixed32List(output, offset, %s, %s);" |
| 145 | + .formatted(fieldDef, getValueCode); |
| 146 | + case INT64, UINT64 -> "offset += ProtoArrayWriterTools.writeInt64List(output, offset, %s, %s);" |
| 147 | + .formatted(fieldDef, getValueCode); |
| 148 | + case SINT64 -> "offset += ProtoArrayWriterTools.writeSInt64List(output, offset, %s, %s);" |
| 149 | + .formatted(fieldDef, getValueCode); |
| 150 | + case FIXED64, SFIXED64 -> "offset += ProtoArrayWriterTools.writeFixed64List(output, offset, %s, %s);" |
| 151 | + .formatted(fieldDef, getValueCode); |
| 152 | + |
| 153 | + default -> "offset += ProtoArrayWriterTools.write%sList(output, offset, %s, %s);" |
| 154 | + .formatted(writeMethodName, fieldDef, getValueCode); |
| 155 | + }; |
| 156 | + } else if (field.type() == Field.FieldType.MAP) { |
| 157 | + // https://protobuf.dev/programming-guides/proto3/#maps |
| 158 | + // On the wire, a map is equivalent to: |
| 159 | + // message MapFieldEntry { |
| 160 | + // key_type key = 1; |
| 161 | + // value_type value = 2; |
| 162 | + // } |
| 163 | + // repeated MapFieldEntry map_field = N; |
| 164 | + // NOTE: we serialize the map in the natural order of keys by design, |
| 165 | + // so that the binary representation of the map is deterministic. |
| 166 | + // NOTE: protoc serializes default values (e.g. "") in maps, so we should too. |
| 167 | + final MapField mapField = (MapField) field; |
| 168 | + final List<Field> mapEntryFields = List.of(mapField.keyField(), mapField.valueField()); |
| 169 | + final Function<Field, String> getValueBuilder = mapEntryField -> |
| 170 | + mapEntryField == mapField.keyField() ? "k" : (mapEntryField == mapField.valueField() ? "v" : null); |
| 171 | + final String fieldWriteLines = buildFieldWriteLines( |
| 172 | + field.name(), |
| 173 | + schemaClassName, |
| 174 | + mapEntryFields, |
| 175 | + getValueBuilder, |
| 176 | + false); |
| 177 | + final String fieldSizeOfLines = CodecMeasureRecordMethodGenerator.buildFieldSizeOfLines( |
| 178 | + field.name(), |
| 179 | + mapEntryFields, |
| 180 | + getValueBuilder, |
| 181 | + false); |
| 182 | + return prefix + """ |
| 183 | + if (!$map.isEmpty()) { |
| 184 | + final Pbj$javaFieldType pbjMap = (Pbj$javaFieldType) $map; |
| 185 | + final int mapSize = pbjMap.size(); |
| 186 | + for (int i = 0; i < mapSize; i++) { |
| 187 | + offset += ProtoArrayWriterTools.writeTag(output, offset, $fieldDef, WIRE_TYPE_DELIMITED); |
| 188 | + $K k = pbjMap.getSortedKeys().get(i); |
| 189 | + $V v = pbjMap.get(k); |
| 190 | + int size = 0; |
| 191 | + $fieldSizeOfLines |
| 192 | + offset += ProtoArrayWriterTools.writeUnsignedVarInt(output, offset, size); |
| 193 | + $fieldWriteLines |
| 194 | + } |
| 195 | + } |
| 196 | + """ |
| 197 | + .replace("$fieldDef", fieldDef) |
| 198 | + .replace("$map", getValueCode) |
| 199 | + .replace("$javaFieldType", mapField.javaFieldType()) |
| 200 | + .replace("$K", mapField.keyField().type().boxedType) |
| 201 | + .replace("$V", mapField.valueField().type() == Field.FieldType.MESSAGE ? ((SingleField)mapField.valueField()).messageType() : mapField.valueField().type().boxedType) |
| 202 | + .replace("$fieldWriteLines", fieldWriteLines.indent(DEFAULT_INDENT)) |
| 203 | + .replace("$fieldSizeOfLines", fieldSizeOfLines.indent(DEFAULT_INDENT)); |
| 204 | + } else { |
| 205 | + return prefix + switch(field.type()) { |
| 206 | + case ENUM -> "offset += ProtoArrayWriterTools.writeEnum(output, offset, %s, %s);" |
| 207 | + .formatted(fieldDef, getValueCode); |
| 208 | + case STRING -> "offset += ProtoArrayWriterTools.writeString(output, offset, %s, %s, %s);" |
| 209 | + .formatted(fieldDef, getValueCode, skipDefault); |
| 210 | + case MESSAGE -> "offset += ProtoArrayWriterTools.writeMessage(output, offset, %s, %s, %s);" |
| 211 | + .formatted(fieldDef, getValueCode, codecReference); |
| 212 | + case BOOL -> "offset += ProtoArrayWriterTools.writeBoolean(output, offset, %s, %s, %s);" |
| 213 | + .formatted(fieldDef, getValueCode, skipDefault); |
| 214 | + case INT32 -> "offset += ProtoArrayWriterTools.writeInt32(output, offset, %s, %s, %s);" |
| 215 | + .formatted(fieldDef, getValueCode, skipDefault); |
| 216 | + case UINT32 -> "offset += ProtoArrayWriterTools.writeUInt32(output, offset, %s, %s, %s);" |
| 217 | + .formatted(fieldDef, getValueCode, skipDefault); |
| 218 | + case SINT32 -> "offset += ProtoArrayWriterTools.writSInt32(output, offset, %s, %s, %s);" |
| 219 | + .formatted(fieldDef, getValueCode, skipDefault); |
| 220 | + case FIXED32, SFIXED32 -> "offset += ProtoArrayWriterTools.writeFixed32(output, offset, %s, %s, %s);" |
| 221 | + .formatted(fieldDef, getValueCode, skipDefault); |
| 222 | + case INT64, UINT64 -> "offset += ProtoArrayWriterTools.writeInt64(output, offset, %s, %s, %s);" |
| 223 | + .formatted(fieldDef, getValueCode, skipDefault); |
| 224 | + case SINT64 -> "offset += ProtoArrayWriterTools.writeSInt64(output, offset, %s, %s, %s);" |
| 225 | + .formatted(fieldDef, getValueCode, skipDefault); |
| 226 | + case FIXED64, SFIXED64 -> "offset += ProtoArrayWriterTools.writeFixed64(output, offset, %s, %s, %s);" |
| 227 | + .formatted(fieldDef, getValueCode, skipDefault); |
| 228 | + case BYTES -> "offset += ProtoArrayWriterTools.writeBytes(output, offset, %s, %s, %s);" |
| 229 | + .formatted(fieldDef, getValueCode, skipDefault); |
| 230 | + default -> "offset += ProtoArrayWriterTools.write%s(output, offset, %s, %s);" |
| 231 | + .formatted(writeMethodName, fieldDef, getValueCode); |
| 232 | + }; |
| 233 | + } |
| 234 | + } |
| 235 | + // spotless:on |
| 236 | + } |
| 237 | +} |
0 commit comments