From 4b876bc552a5a01e0900130153721680547d5b6a Mon Sep 17 00:00:00 2001 From: Anthony Petrov Date: Tue, 30 Sep 2025 17:16:44 -0700 Subject: [PATCH 1/2] perf: improve strings handling Signed-off-by: Anthony Petrov --- .../com/hedera/pbj/compiler/impl/Common.java | 152 +++++--- .../com/hedera/pbj/compiler/impl/Field.java | 124 ++++++- .../hedera/pbj/compiler/impl/MapField.java | 6 +- .../hedera/pbj/compiler/impl/SingleField.java | 51 ++- .../impl/generators/ModelGenerator.java | 216 ++++++----- .../impl/generators/TestGenerator.java | 7 +- .../generators/json/JsonCodecGenerator.java | 1 + .../json/JsonCodecParseMethodGenerator.java | 36 +- .../json/JsonCodecWriteMethodGenerator.java | 67 ++-- .../protobuf/CodecParseMethodGenerator.java | 49 ++- .../CodecWriteByteArrayMethodGenerator.java | 41 ++- .../protobuf/CodecWriteMethodGenerator.java | 43 ++- .../LazyGetProtobufSizeMethodGenerator.java | 2 +- .../com/hedera/pbj/runtime/PbjConstants.java | 10 + .../pbj/runtime/ProtoArrayWriterTools.java | 112 ++++++ .../hedera/pbj/runtime/ProtoParserTools.java | 25 ++ .../hedera/pbj/runtime/ProtoWriterTools.java | 145 +++++++- .../com/hedera/pbj/runtime/Utf8Tools.java | 32 ++ .../pbj/runtime/ProtoParserToolsTest.java | 8 +- .../pbj/runtime/ProtoWriterToolsTest.java | 51 +-- .../pbj/integration/jmh/utf8/Utf8Bench.java | 320 +++++++++++++++++ .../pbj/integration/jmh/utf8/Utf8ToolsV0.java | 22 ++ .../pbj/integration/jmh/utf8/Utf8ToolsV1.java | 127 +++++++ .../pbj/integration/jmh/utf8/Utf8ToolsV2.java | 151 ++++++++ .../pbj/integration/jmh/utf8/Utf8ToolsV3.java | 278 +++++++++++++++ .../pbj/integration/jmh/utf8/Utf8ToolsV4.java | 337 ++++++++++++++++++ .../pbj/integration/fuzz/SingleFuzzTest.java | 10 + .../main/proto/extendedUtf8StingTest.proto | 19 + .../pbj/integration/test/HashEqualsTest.java | 40 +++ .../pbj/integration/test/SampleFuzzTest.java | 4 +- 30 files changed, 2230 insertions(+), 256 deletions(-) create mode 100644 pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/PbjConstants.java create mode 100644 pbj-integration-tests/src/jmh/java/com/hedera/pbj/integration/jmh/utf8/Utf8Bench.java create mode 100644 pbj-integration-tests/src/jmh/java/com/hedera/pbj/integration/jmh/utf8/Utf8ToolsV0.java create mode 100644 pbj-integration-tests/src/jmh/java/com/hedera/pbj/integration/jmh/utf8/Utf8ToolsV1.java create mode 100644 pbj-integration-tests/src/jmh/java/com/hedera/pbj/integration/jmh/utf8/Utf8ToolsV2.java create mode 100644 pbj-integration-tests/src/jmh/java/com/hedera/pbj/integration/jmh/utf8/Utf8ToolsV3.java create mode 100644 pbj-integration-tests/src/jmh/java/com/hedera/pbj/integration/jmh/utf8/Utf8ToolsV4.java diff --git a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/Common.java b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/Common.java index f77f48d1c..b3d4faaa0 100644 --- a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/Common.java +++ b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/Common.java @@ -292,7 +292,15 @@ public static String getFieldsHashCode(final List fields, String generate .replace("$fieldName", f.nameCamelFirstLower()); } else if (f.type() == Field.FieldType.MAP) { generatedCodeSoFar += getMapHashCodeGeneration(generatedCodeSoFar, f); - } else if (f.type() == Field.FieldType.STRING || f.parent() == null) { // process sub message + } else if (f.isString()) { + generatedCodeSoFar += + (""" + if (!Arrays.equals($fieldName, DEFAULT.$fieldName)) { + result = 31 * result + Arrays.hashCode($fieldName); + } + """) + .replace("$fieldName", f.nameCamelFirstLower()); + } else if (f.parent() == null) { // process sub message generatedCodeSoFar += (""" if ($fieldName != null && !$fieldName.equals(DEFAULT.$fieldName)) { @@ -310,7 +318,7 @@ public static String getFieldsHashCode(final List fields, String generate } /** - * Get the hashcode codegen for a optional field. + * Get the hashcode codegen for an optional field. * * @param generatedCodeSoFar The string that the codegen is generated into. * @param f The field for which to generate the hash code. @@ -320,55 +328,62 @@ public static String getFieldsHashCode(final List fields, String generate @NonNull private static String getPrimitiveWrapperHashCodeGeneration(String generatedCodeSoFar, Field f) { switch (f.messageType()) { - case "StringValue" -> generatedCodeSoFar += - (""" - if ($fieldName != null && !$fieldName.equals(DEFAULT.$fieldName)) { - result = 31 * result + $fieldName.hashCode(); + case "StringValue" -> + generatedCodeSoFar += + (""" + if ($fieldName != null && !Arrays.equals($fieldName, DEFAULT.$fieldName)) { + result = 31 * result + Arrays.hashCode($fieldName); } """) - .replace("$fieldName", f.nameCamelFirstLower()); - case "BoolValue" -> generatedCodeSoFar += - (""" + .replace("$fieldName", f.nameCamelFirstLower()); + case "BoolValue" -> + generatedCodeSoFar += + (""" if ($fieldName != null && !$fieldName.equals(DEFAULT.$fieldName)) { result = 31 * result + Boolean.hashCode($fieldName); } """) - .replace("$fieldName", f.nameCamelFirstLower()); - case "Int32Value", "UInt32Value" -> generatedCodeSoFar += - (""" + .replace("$fieldName", f.nameCamelFirstLower()); + case "Int32Value", "UInt32Value" -> + generatedCodeSoFar += + (""" if ($fieldName != null && !$fieldName.equals(DEFAULT.$fieldName)) { result = 31 * result + Integer.hashCode($fieldName); } """) - .replace("$fieldName", f.nameCamelFirstLower()); - case "Int64Value", "UInt64Value" -> generatedCodeSoFar += - (""" + .replace("$fieldName", f.nameCamelFirstLower()); + case "Int64Value", "UInt64Value" -> + generatedCodeSoFar += + (""" if ($fieldName != null && !$fieldName.equals(DEFAULT.$fieldName)) { result = 31 * result + Long.hashCode($fieldName); } """) - .replace("$fieldName", f.nameCamelFirstLower()); - case "FloatValue" -> generatedCodeSoFar += - (""" + .replace("$fieldName", f.nameCamelFirstLower()); + case "FloatValue" -> + generatedCodeSoFar += + (""" if ($fieldName != null && !$fieldName.equals(DEFAULT.$fieldName)) { result = 31 * result + Float.hashCode($fieldName); } """) - .replace("$fieldName", f.nameCamelFirstLower()); - case "DoubleValue" -> generatedCodeSoFar += - (""" + .replace("$fieldName", f.nameCamelFirstLower()); + case "DoubleValue" -> + generatedCodeSoFar += + (""" if ($fieldName != null && !$fieldName.equals(DEFAULT.$fieldName)) { result = 31 * result + Double.hashCode($fieldName); } """) - .replace("$fieldName", f.nameCamelFirstLower()); - case "BytesValue" -> generatedCodeSoFar += - (""" + .replace("$fieldName", f.nameCamelFirstLower()); + case "BytesValue" -> + generatedCodeSoFar += + (""" if ($fieldName != null && !$fieldName.equals(DEFAULT.$fieldName)) { result = 31 * result + ($fieldName == null ? 0 : $fieldName.hashCode()); } """) - .replace("$fieldName", f.nameCamelFirstLower()); + .replace("$fieldName", f.nameCamelFirstLower()); default -> throw new UnsupportedOperationException("Unhandled optional message type:" + f.messageType()); } return generatedCodeSoFar; @@ -390,14 +405,16 @@ private static String getRepeatedHashCodeGeneration(String generatedCodeSoFar, F if (list$$fieldName != null) { for (Object o : list$$fieldName) { if (o != null) { - result = 31 * result + o.hashCode(); + result = 31 * result + $singleHashCodeGetter; } else { result = 31 * result; } } } """) - .replace("$fieldName", f.nameCamelFirstLower()); + .replace("$fieldName", f.nameCamelFirstLower()) + .replace( + "$singleHashCodeGetter", f.isString() ? "Arrays.hashCode((byte[]) o)" : "o.hashCode()"); return generatedCodeSoFar; } @@ -500,8 +517,15 @@ public static String getFieldsEqualsStatements(final List fields, String } """ .replace("$fieldName", f.nameCamelFirstLower()); - } else if (f.type() == Field.FieldType.STRING - || f.type() == Field.FieldType.BYTES + } else if (f.type() == Field.FieldType.STRING) { + generatedCodeSoFar += + (""" + if (!Arrays.equals($fieldName, thatObj.$fieldName)) { + return false; + } + """) + .replace("$fieldName", f.nameCamelFirstLower()); + } else if (f.type() == Field.FieldType.BYTES || f.type() == Field.FieldType.ENUM || f.type() == Field.FieldType.MAP || f.parent() == null /* Process a sub-message */) { @@ -535,16 +559,24 @@ public static String getFieldsEqualsStatements(final List fields, String @NonNull private static String getPrimitiveWrapperEqualsGeneration(String generatedCodeSoFar, Field f) { switch (f.messageType()) { - case "StringValue", - "BoolValue", + case "StringValue" -> + generatedCodeSoFar += + (""" + if (!Arrays.equals($fieldName, thatObj.$fieldName)) { + return false; + } + """) + .replace("$fieldName", f.nameCamelFirstLower()); + case "BoolValue", "Int32Value", "UInt32Value", "Int64Value", "UInt64Value", "FloatValue", "DoubleValue", - "BytesValue" -> generatedCodeSoFar += - (""" + "BytesValue" -> + generatedCodeSoFar += + (""" if (this.$fieldName == null && thatObj.$fieldName != null) { return false; } @@ -552,7 +584,7 @@ private static String getPrimitiveWrapperEqualsGeneration(String generatedCodeSo return false; } """) - .replace("$fieldName", f.nameCamelFirstLower()); + .replace("$fieldName", f.nameCamelFirstLower()); default -> throw new UnsupportedOperationException("Unhandled optional message type:" + f.messageType()); } return generatedCodeSoFar; @@ -574,11 +606,33 @@ private static String getRepeatedEqualsGeneration(String generatedCodeSoFar, Fie return false; } + """) + .replace("$fieldName", f.nameCamelFirstLower()); + if (f.isString()) { + generatedCodeSoFar += + (""" + if (this.$fieldName != null) { + if (thatObj.$fieldName == null || this.$fieldName.size() != thatObj.$fieldName.size()) { + return false; + } + for (int i = this.$fieldName.size() - 1; i >= 0; i--) { + if (!Arrays.equals(this.$fieldName.get(i), thatObj.$fieldName.get(i))) { + return false; + } + } + } + """) + .replace("$fieldName", f.nameCamelFirstLower()); + + } else { + generatedCodeSoFar += + (""" if (this.$fieldName != null && !$fieldName.equals(thatObj.$fieldName)) { return false; } """) - .replace("$fieldName", f.nameCamelFirstLower()); + .replace("$fieldName", f.nameCamelFirstLower()); + } return generatedCodeSoFar; } @@ -667,9 +721,24 @@ public static String getFieldsCompareToStatements(final List fields, Stri } """ .replace("$fieldName", f.nameCamelFirstLower()); - } else if (f.type() == Field.FieldType.STRING - || f.type() == Field.FieldType.BYTES - || f.type() == Field.FieldType.ENUM) { + } else if (f.type() == Field.FieldType.STRING) { + generatedCodeSoFar += + """ + if ($fieldName == null && thatObj.$fieldName != null) { + return -1; + } + if ($fieldName != null && thatObj.$fieldName == null) { + return 1; + } + if ($fieldName != null) { + result = Arrays.compare($fieldName, thatObj.$fieldName); + } + if (result != 0) { + return result; + } + """ + .replace("$fieldName", f.nameCamelFirstLower()); + } else if (f.type() == Field.FieldType.BYTES || f.type() == Field.FieldType.ENUM) { generatedCodeSoFar += generateCompareToForObject(f); } else if (f.type() == Field.FieldType.MESSAGE || f.type() == Field.FieldType.ONE_OF) { verifyComparable(f); @@ -752,7 +821,8 @@ private static String getPrimitiveWrapperCompareToGeneration(Field f) { final String compareStatement = switch (f.messageType()) { - case "StringValue", "BytesValue" -> "$fieldName.compareTo(thatObj.$fieldName)"; + case "StringValue" -> "Arrays.compare($fieldName, thatObj.$fieldName)"; + case "BytesValue" -> "$fieldName.compareTo(thatObj.$fieldName)"; case "BoolValue" -> "java.lang.Boolean.compare($fieldName, thatObj.$fieldName)"; case "Int32Value" -> "java.lang.Integer.compare($fieldName, thatObj.$fieldName)"; case "UInt32Value" -> "java.lang.Integer.compareUnsigned($fieldName, thatObj.$fieldName)"; @@ -760,8 +830,8 @@ private static String getPrimitiveWrapperCompareToGeneration(Field f) { case "UInt64Value" -> "java.lang.Long.compareUnsigned($fieldName, thatObj.$fieldName)"; case "FloatValue" -> "java.lang.Float.compare($fieldName, thatObj.$fieldName)"; case "DoubleValue" -> "java.lang.Double.compare($fieldName, thatObj.$fieldName)"; - default -> throw new UnsupportedOperationException( - "Unhandled optional message type:" + f.messageType()); + default -> + throw new UnsupportedOperationException("Unhandled optional message type:" + f.messageType()); }; return template.replace("$compareStatement", compareStatement).replace("$fieldName", f.nameCamelFirstLower()); diff --git a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/Field.java b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/Field.java index 65ff9bf80..362008ba2 100644 --- a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/Field.java +++ b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/Field.java @@ -7,7 +7,10 @@ import static com.hedera.pbj.compiler.impl.Common.TYPE_VARINT; import static com.hedera.pbj.compiler.impl.Common.snakeToCamel; +import com.hedera.pbj.compiler.impl.generators.protobuf.CodecWriteByteArrayMethodGenerator; +import com.hedera.pbj.compiler.impl.generators.protobuf.CodecWriteMethodGenerator; import com.hedera.pbj.compiler.impl.grammar.Protobuf3Parser; +import com.hedera.pbj.compiler.impl.grammar.Protobuf3Parser.MessageDefContext; import edu.umd.cs.findbugs.annotations.NonNull; import java.util.function.Consumer; @@ -17,6 +20,9 @@ @SuppressWarnings("unused") public interface Field { + /** Annotation to add to fields that can't be set to null */ + String NON_NULL_ANNOTATION = "@NonNull"; + /** The default maximum size of a repeated or length-encoded field (Bytes, String, Message, etc.). * The size should not be increased beyond the current limit because of the safety concerns. */ @@ -91,6 +97,17 @@ default String nameCamelFirstLower() { */ String javaFieldType(); + /** + * Get the Java storage field type for this field, this allows for fields to be stored in a different format than + * they are presented to the user, for example String fields are stored as byte array. Normally this is the same as + * {@link #javaFieldType()} except for a few special cases. + * + * @return this fields type in Java format + */ + default String javaFieldStorageType() { + return javaFieldType(); + } + /** * Get the Java field type for this field. * Unlike {@link #javaFieldType()}, this method returns the base type for repeated and oneOf fields. @@ -106,6 +123,111 @@ default String nameCamelFirstLower() { */ String methodNameType(); + /** + * Check if the storage type is different from the exposed type, this is true for String fields which are stored + * as byte arrays. + * + * @return true if storage type is different from exposed type, otherwise false + */ + default boolean hasDifferentStorageType() { + return !javaFieldStorageType().equals(javaFieldType()); + } + + /** + * Get the code for setting storage field, by default a no-op. This can be used for fields like String where we + * want different internal storage types. + * + * @param inputVarName the name of the variable being passed in to the setter + * @param msgDef the message definition + * @param lookupHelper the lookup helper + * @return code for setting storage field + */ + default String storageFieldSetter( + final String inputVarName, final MessageDefContext msgDef, final ContextualLookupHelper lookupHelper) { + if (cannotBeNull()) { + return inputVarName + " != null ? " + inputVarName + " : " + defaultValue(msgDef, lookupHelper); + } else { + return inputVarName; + } + } + + /** + * Check if the field is a String, either native or boxed. + * Useful because we use an array for storing strings in models, and arrays require very special handling + * when it comes to equality checks, comparison, hash codes, etc. + * @return true if this field is a string (and therefore, it's stored as a UTF-8 byte[] internally in models) + */ + default boolean isString() { + return type() == FieldType.STRING || (type() == FieldType.MESSAGE && "StringValue".equals(messageType())); + } + + /** + * Get the code for getting storage field, by default a no-op. This can be used for fields like String where we + * want different internal storage types. + * + * @param fieldName the name of the field being accessed + * @return code for getting storage field + */ + default String storageFieldGetter(String fieldName) { + return fieldName; + } + + default String storageFieldWriter(final String modelClassName, final String schemaClassName) { + return CodecWriteMethodGenerator.generateFieldWriteLines( + this, modelClassName, schemaClassName, nameCamelFirstLower(), true, true); + } + + default String storageFieldByteArrayWriter(final String modelClassName, final String schemaClassName) { + return CodecWriteByteArrayMethodGenerator.generateFieldWriteLines( + this, modelClassName, schemaClassName, nameCamelFirstLower(), true, true); + } + + /** + * Determine if this field cannot be null. For example, string and bytes fields can never be null. + * Repeated fields also cannot be null. They all are initialized with default values (e.g. empty collections) + * if they're missing from the input, so that models always return non-null values to clients. + *

+ * Note that if this method returns `false`, this does NOT mean that the field can be null. For example, + * boolean fields cannot be null because we use the unboxed primitive `boolean` type to store them, + * even though this method will return `false` for them. + *

+ * In other words, the return value of this method is ONLY meaningful for fields represented by Java objects, + * and it's only really meaningful if/when it's equal to `true`. + * + * @return true if this field can be null, otherwise false + */ + default boolean cannotBeNull() { + if (repeated()) return true; + return switch (type()) { + case BYTES, STRING -> true; + default -> false; + }; + } + + /** + * Get a set of annotations for this field. + * + * @return an empty string, or a string with Java annotations ending with a space + */ + default String annotations() { + return cannotBeNull() ? NON_NULL_ANNOTATION + " " : ""; + } + + /** + * Gets the default value for this field + * + * @param msgDef the message definition + * @param lookupHelper the lookup helper + * @return the generated code + */ + default String defaultValue(final MessageDefContext msgDef, final ContextualLookupHelper lookupHelper) { + if (type() == Field.FieldType.ONE_OF) { + return lookupHelper.getFullyQualifiedMessageClassname(FileType.CODEC, msgDef) + "." + javaDefault(); + } else { + return javaDefault(); + } + } + /** * Add all the needed imports for this field to the supplied set. * @@ -276,7 +398,7 @@ enum FieldType { /** Protobuf sfixed64(signed fixed encoding long) field type */ SFIXED64("long", "Long", "0", TYPE_FIXED64), /** Protobuf string field type */ - STRING("String", "String", "\"\"", TYPE_LENGTH_DELIMITED), + STRING("String", "String", "PbjConstants.EMPTY_BYTES", TYPE_LENGTH_DELIMITED), /** Protobuf bool(boolean) field type */ BOOL("boolean", "Boolean", "false", TYPE_VARINT), /** Protobuf bytes field type */ diff --git a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/MapField.java b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/MapField.java index cafbd838f..561153b1f 100644 --- a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/MapField.java +++ b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/MapField.java @@ -58,7 +58,8 @@ public MapField(Protobuf3Parser.MapFieldContext mapContext, final ContextualLook "An internal, private map entry key for %s" .formatted(mapContext.mapName().getText()), false, - null), + null, + true), new SingleField( false, FieldType.of(mapContext.type_(), lookupHelper), @@ -73,7 +74,8 @@ public MapField(Protobuf3Parser.MapFieldContext mapContext, final ContextualLook "An internal, private map entry value for %s" .formatted(mapContext.mapName().getText()), false, - null), + null, + true), false, // maps cannot be repeated Integer.parseInt(mapContext.fieldNumber().getText()), mapContext.mapName().getText(), diff --git a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/SingleField.java b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/SingleField.java index 8fc309669..68295bee3 100644 --- a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/SingleField.java +++ b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/SingleField.java @@ -4,6 +4,7 @@ import static com.hedera.pbj.compiler.impl.Common.DEFAULT_INDENT; import com.hedera.pbj.compiler.impl.grammar.Protobuf3Parser; +import com.hedera.pbj.compiler.impl.grammar.Protobuf3Parser.MessageDefContext; import edu.umd.cs.findbugs.annotations.NonNull; import java.util.function.Consumer; @@ -31,7 +32,8 @@ public record SingleField( String messageTypeTestPackage, String comment, boolean deprecated, - OneOfField parent) + OneOfField parent, + boolean isMapField) implements Field { /** @@ -55,7 +57,8 @@ public SingleField(Protobuf3Parser.FieldContext fieldContext, final ContextualLo Common.buildCleanFieldJavaDoc( Integer.parseInt(fieldContext.fieldNumber().getText()), fieldContext.docComment()), getDeprecatedOption(fieldContext.fieldOptions()), - null); + null, + false); } /** @@ -96,7 +99,8 @@ public SingleField( Common.buildCleanFieldJavaDoc( Integer.parseInt(fieldContext.fieldNumber().getText()), fieldContext.docComment()), getDeprecatedOption(fieldContext.fieldOptions()), - parent); + parent, + false); } /** @@ -124,19 +128,54 @@ public String protobufFieldType() { return type == SingleField.FieldType.MESSAGE ? messageType : type.javaType; } - /** - * {@inheritDoc} - */ + /** {@inheritDoc} */ @Override public String javaFieldType() { return javaFieldType(true); } + /** {@inheritDoc} */ @Override public String javaFieldTypeBase() { return javaFieldType(false); } + /** {@inheritDoc} */ + @Override + public String javaFieldStorageType() { + if (isString()) { + return repeated ? "List" : "byte[]"; + } + return javaFieldType(); + } + + /** {@inheritDoc} */ + @Override + public String storageFieldSetter( + final String inputVarName, final MessageDefContext msgDef, final ContextualLookupHelper lookupHelper) { + if (isString() && parent == null && !isMapField) { + if (!repeated) { + return inputVarName + " != null ? toUtf8Bytes(" + inputVarName + ") : " + + (cannotBeNull() ? "PbjConstants.EMPTY_BYTES" : "null"); + } else { + // It's List -> List variant: + return "toUtf8Bytes(" + inputVarName + ")"; + } + } else if (cannotBeNull()) { + return inputVarName + " != null ? " + inputVarName + " : " + defaultValue(msgDef, lookupHelper); + } else { + return inputVarName; + } + } + + /** {@inheritDoc} */ + @Override + public String storageFieldGetter(String fieldName) { + return isString() + ? ((cannotBeNull() ? "" : (fieldName + " == null ? null : ")) + "toUtf8String(" + fieldName + ")") + : fieldName; + } + @NonNull private String javaFieldType(boolean considerRepeated) { String fieldType = diff --git a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/ModelGenerator.java b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/ModelGenerator.java index 8fff2c71e..838d496d0 100644 --- a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/ModelGenerator.java +++ b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/ModelGenerator.java @@ -37,11 +37,9 @@ * Code generator that parses protobuf files and generates nice Java source for record files for each message type and * enum. */ -@SuppressWarnings({"EscapedSpace", "StringConcatenationInLoop"}) +@SuppressWarnings({"EscapedSpace", "StringConcatenationInLoop", "SwitchStatementWithTooFewBranches"}) public final class ModelGenerator implements Generator { - private static final String NON_NULL_ANNOTATION = "@NonNull"; - private static final String HASH_CODE_MANIPULATION = """ // Shifts: 30, 27, 16, 20, 5, 18, 10, 24, 30 @@ -92,6 +90,11 @@ public void generate( writer.addImport("edu.umd.cs.findbugs.annotations.*"); writer.addImport(lookupHelper.getFullyQualifiedMessageClassname(FileType.SCHEMA, msgDef)); writer.addImport("static " + lookupHelper.getFullyQualifiedMessageClassname(FileType.SCHEMA, msgDef) + ".*"); + writer.addImport("static com.hedera.pbj.runtime.PbjConstants.EMPTY_BYTES"); + writer.addImport("static com.hedera.pbj.runtime.ProtoWriterTools.*"); + writer.addImport("static com.hedera.pbj.runtime.Utf8Tools.*"); + writer.addImport("java.io.IOException"); + writer.addImport("java.util.Arrays"); writer.addImport("java.util.Collections"); writer.addImport("java.util.List"); @@ -136,7 +139,8 @@ public void generate( null, "Computed hash code, manual input ignored.", false, - null)); + null, + false)); fields.add(new SingleField( false, FieldType.FIXED32, @@ -150,7 +154,8 @@ public void generate( null, "Computed protobuf encoded size, manual input ignored.", false, - null)); + null, + false)); // The javadoc comment to use for the model class, which comes **directly** from the protobuf schema, // but is cleaned up and formatted for use in JavaDoc. @@ -183,8 +188,8 @@ public void generate( return fieldComment + "private " + (field.fieldNumber() != -1 ? "final " : "") - + getFieldAnnotations(field) - + field.javaFieldType() + " " + field.nameCamelFirstLower() + + field.annotations() + + field.javaFieldStorageType() + " " + field.nameCamelFirstLower() + (field.fieldNumber() == -1 ? " = -1" : "") + ";"; }) @@ -197,15 +202,21 @@ public void generate( bodyContent += "\n"; // constructors: w/o unknownFields, and with unknownFields - bodyContent += - generateConstructor(javaRecordName, fields, false, fieldsNoPrecomputed, true, msgDef, lookupHelper); + bodyContent += generateConstructor( + false, javaRecordName, fields, false, fieldsNoPrecomputed, true, msgDef, lookupHelper); bodyContent += "\n"; - bodyContent += - generateConstructor(javaRecordName, fields, true, fieldsNoPrecomputed, true, msgDef, lookupHelper); + bodyContent += generateConstructor( + false, javaRecordName, fields, true, fieldsNoPrecomputed, true, msgDef, lookupHelper); bodyContent += "\n"; + if (fields.stream().anyMatch(Field::hasDifferentStorageType)) { + bodyContent += generateConstructor( + true, javaRecordName, fields, true, fieldsNoPrecomputed, true, msgDef, lookupHelper); + bodyContent += "\n"; + } // record style getters - bodyContent += generateRecordStyleGetters(fieldsNoPrecomputed); + bodyContent += + generateRecordStyleGetters(javaRecordName, fieldsNoPrecomputed, msgDef, lookupHelper, schemaClassName); bodyContent += "\n"; bodyContent += @@ -323,6 +334,7 @@ private void generateClass( writer.addImport("java.util.function.Consumer"); writer.addImport("edu.umd.cs.findbugs.annotations.Nullable"); writer.addImport("edu.umd.cs.findbugs.annotations.NonNull"); + writer.addImport("java.lang.SuppressWarnings"); writer.addImport("static java.util.Objects.requireNonNull"); writer.addImport("static com.hedera.pbj.runtime.ProtoWriterTools.*"); writer.addImport("static com.hedera.pbj.runtime.ProtoConstants.*"); @@ -330,6 +342,7 @@ private void generateClass( // spotless:off writer.append(""" $javaDocComment$deprecated + @SuppressWarnings("cast") public final$staticModifier class $javaRecordName $implementsComparable{ $bodyContent @@ -360,45 +373,69 @@ private void generateClass( * @param fields the fields to use for the code generation * @return the generated code */ - private static String generateRecordStyleGetters(final List fields) { + private static String generateRecordStyleGetters( + final String javaRecordName, + final List fields, + final MessageDefContext msgDef, + final ContextualLookupHelper lookupHelper, + final String schemaClassName) { return fields.stream() .map(field -> { String fieldComment = field.comment(); String fieldCommentLowerFirst = fieldComment.substring(0, 1).toLowerCase() + fieldComment.substring(1); - return """ + String prefix = ""; + if (field.hasDifferentStorageType() || field.isString()) { + prefix = + """ + /** + * Write binary protobuf representation of field $fieldCommentLowerFirst to given output + */ + public void $fieldNameWriteTo(@NonNull final WritableSequentialData out) throws IOException { + $fieldWriter + } + + /** + * Write binary protobuf representation of field $fieldCommentLowerFirst to given output. + * @return the number of bytes written + */ + public int $fieldNameWriteTo(@NonNull final byte[] output, final int startOffset) { + int offset = startOffset; + $byteArrayFieldWriter + return offset - startOffset; + } + + """ + .replace("$fieldCommentLowerFirst", fieldCommentLowerFirst) + .replace("$fieldName", field.nameCamelFirstLower()) + .replace( + "$fieldWriter", + field.storageFieldWriter(javaRecordName, schemaClassName)) + .replace( + "$byteArrayFieldWriter", + field.storageFieldByteArrayWriter(javaRecordName, schemaClassName)) + .indent(DEFAULT_INDENT); + } + return prefix + + """ /** * Get field $fieldCommentLowerFirst * * @return the value of the $fieldName field */ public $fieldType $fieldName() { - return $fieldName; + return $fieldGetCode; } """ - .replace("$fieldCommentLowerFirst", fieldCommentLowerFirst) - .replace("$fieldName", field.nameCamelFirstLower()) - .replace("$fieldType", field.javaFieldType()) - .indent(DEFAULT_INDENT); + .replace("$fieldCommentLowerFirst", fieldCommentLowerFirst) + .replace("$fieldName", field.nameCamelFirstLower()) + .replace("$fieldGetCode", field.storageFieldGetter(field.nameCamelFirstLower())) + .replace("$fieldType", field.javaFieldType()) + .indent(DEFAULT_INDENT); }) .collect(Collectors.joining("\n")); } - /** - * Returns a set of annotations for a given field. - * @param field a field - * @return an empty string, or a string with Java annotations ending with a space - */ - private static String getFieldAnnotations(final Field field) { - if (field.repeated()) return NON_NULL_ANNOTATION + " "; - - return switch (field.type()) { - case MESSAGE -> "@Nullable "; - case BYTES, STRING -> NON_NULL_ANNOTATION + " "; - default -> ""; - }; - } - /** * Filter the fields to only include those that are comparable * @param msgDef The message definition @@ -619,8 +656,10 @@ public String toString() { // spotless:on for (int i = 0; i < fields.size(); i++) { Field f = fields.get(i); - bodyContent += - FIELD_INDENT + FIELD_INDENT + "+ \"" + f.nameCamelFirstLower() + "=\" + " + f.nameCamelFirstLower(); + // Getters can include ?: ternary, so we enclose them in (...): + bodyContent += FIELD_INDENT + FIELD_INDENT + "+ \"" + f.nameCamelFirstLower() + "=\" + (" + + f.storageFieldGetter(f.nameCamelFirstLower()) + + ")"; if (i < fields.size() - 1) { bodyContent += " + \", \""; } @@ -651,6 +690,7 @@ public String toString() { * @return the generated code */ private static String generateConstructor( + final boolean generatePrivateStorageConstructor, final String constructorName, final List fields, final boolean initUnknownFields, @@ -667,7 +707,7 @@ private static String generateConstructor( * Create a pre-populated $constructorName. * $constructorParamDocs */ - public $constructorName($constructorParams$unknownFieldsParam) { + public $constructorName($fakeParams$constructorParams$unknownFieldsParam) { $unknownFieldsCode $constructorCode } """ @@ -676,8 +716,16 @@ private static String generateConstructor( field.comment().replaceAll("\n", "\n * "+" ".repeat(field.nameCamelFirstLower().length())) ).collect(Collectors.joining(" "))) .replace("$constructorName", constructorName) + // List and List are the same after erasure, and if there's no other storage type fields, + // the normal and storage constructors will have the same signature at JVM level, which javac doesn't allow. + // So we add a fake argument to the private constructor, to be able to instantiate objects using storage types. + // Note that unfortunately, these storage constructors have to be public, so that the Codec has access to them. + // This, however, allows malicious clients to create mutable instances of the model, which isn't ideal. + .replace("$fakeParams", generatePrivateStorageConstructor ? ("int ___unusedArgumentToBypassGenericErasure" + (initUnknownFields || !fieldsNoPrecomputed.isEmpty() ? ", " : "")) : "") .replace("$constructorParams",fieldsNoPrecomputed.stream().map(field -> - field.javaFieldType() + " " + field.nameCamelFirstLower() + generatePrivateStorageConstructor ? + (field.javaFieldStorageType() + " " + field.nameCamelFirstLower()) : + (field.javaFieldType() + " " + field.nameCamelFirstLower()) ).collect(Collectors.joining(", "))) .replace("$unknownFieldsParam", initUnknownFields ? ((fieldsNoPrecomputed.isEmpty() ? "" : ", ") + "final List $unknownFields") @@ -692,9 +740,12 @@ private static String generateConstructor( } switch (field.type()) { case BYTES, STRING: { - sb.append("this.$name = $name != null ? $name : $default;" + sb.append("this.$name = $name != null ? $fieldSetter : $default;" .replace("$name", field.nameCamelFirstLower()) - .replace("$default", getDefaultValue(field, msgDef, lookupHelper)) + .replace("$fieldSetter", + generatePrivateStorageConstructor ? field.nameCamelFirstLower() : + field.storageFieldSetter(field.nameCamelFirstLower(), msgDef, lookupHelper)) + .replace("$default", field.defaultValue(msgDef, lookupHelper)) ); break; } @@ -705,12 +756,11 @@ private static String generateConstructor( break; } default: - if (field.repeated()) { - sb.append("this.$name = $name == null ? Collections.emptyList() : $name;".replace( - "$name", field.nameCamelFirstLower())); - } else { - sb.append("this.$name = $name;".replace("$name", field.nameCamelFirstLower())); - } + sb.append("this.$fieldName = $fieldSetter;" + .replace("$fieldName", field.nameCamelFirstLower()) + .replace("$fieldSetter", + generatePrivateStorageConstructor ? field.nameCamelFirstLower() : + field.storageFieldSetter(field.nameCamelFirstLower(), msgDef, lookupHelper))); break; } return sb.toString(); @@ -815,7 +865,7 @@ private static void generateCodeForField( * @return the value for $fieldName if it has a value, or else returns the default value */ public $javaFieldType $fieldNameOrElse(@NonNull final $javaFieldType defaultValue) { - return has$fieldNameUpperFirst() ? $fieldName : defaultValue; + return has$fieldNameUpperFirst() ? $fieldGetCode : defaultValue; } /** @@ -826,7 +876,7 @@ private static void generateCodeForField( * @throws NullPointerException if $fieldName is null */ public @NonNull $javaFieldType $fieldNameOrThrow() { - return requireNonNull($fieldName, "Field $fieldName is null"); + return requireNonNull($fieldGetCode, "Field $fieldName is null"); } /** @@ -836,13 +886,14 @@ private static void generateCodeForField( */ public void if$fieldNameUpperFirst(@NonNull final Consumer<$javaFieldType> ifPresent) { if (has$fieldNameUpperFirst()) { - ifPresent.accept($fieldName); + ifPresent.accept($fieldGetCode); } } """ .replace("$fieldNameUpperFirst", field.nameCamelFirstUpper()) .replace("$javaFieldType", field.javaFieldType()) .replace("$fieldName", field.nameCamelFirstLower()) + .replace("$fieldGetCode", field.storageFieldGetter(field.nameCamelFirstLower())) .indent(DEFAULT_INDENT) ); } @@ -966,7 +1017,7 @@ private static String generateBuilderFactoryMethods(String bodyContent, final Li * @return a pre-populated builder */ public Builder copyBuilder() { - return new Builder(%s$unknownFieldsArg); + return new Builder($fakeParams%s$unknownFieldsArg); } /** @@ -979,6 +1030,7 @@ public static Builder newBuilder() { } """ .formatted(fields.stream().map(Field::nameCamelFirstLower).collect(Collectors.joining(", "))) + .replace("$fakeParams", fields.stream().anyMatch(Field::hasDifferentStorageType) ? "0, " : "") .replace("$unknownFieldsArg", (fields.isEmpty() ? "" : ", ") + "$unknownFields") .indent(DEFAULT_INDENT); // spotless:on @@ -998,8 +1050,7 @@ private static void generateBuilderMethods( final MessageDefContext msgDef, final Field field, final ContextualLookupHelper lookupHelper) { - final String prefix, postfix, fieldToSet; - final String fieldAnnotations = getFieldAnnotations(field); + final String prefix, postfix, fieldToSet, fieldSetter; final OneOfField parentOneOfField = field.parent(); final String fieldName = field.nameCamelFirstLower(); if (parentOneOfField != null) { @@ -1007,14 +1058,12 @@ private static void generateBuilderMethods( prefix = " new %s<>(".formatted(parentOneOfField.className()) + oneOfEnumValue + ","; postfix = ")"; fieldToSet = parentOneOfField.nameCamelFirstLower(); - } else if (fieldAnnotations.contains(NON_NULL_ANNOTATION)) { - prefix = ""; - postfix = " != null ? " + fieldName + " : " + getDefaultValue(field, msgDef, lookupHelper); - fieldToSet = fieldName; + fieldSetter = fieldName; } else { prefix = ""; postfix = ""; fieldToSet = fieldName; + fieldSetter = field.storageFieldSetter(fieldName, msgDef, lookupHelper); } // spotless:off builderMethods.add(""" @@ -1025,7 +1074,7 @@ private static void generateBuilderMethods( * @return builder to continue building with */ public Builder $fieldName($fieldAnnotations$fieldType $fieldName) { - this.$fieldToSet = $prefix$fieldName$postfix; + this.$fieldToSet = $prefix$fieldSetter$postfix; return this; }""" .replace("$fieldDoc", field.comment() @@ -1033,8 +1082,9 @@ private static void generateBuilderMethods( .replace("$fieldName", fieldName) .replace("$fieldToSet", fieldToSet) .replace("$prefix", prefix) + .replace("$fieldSetter", fieldSetter) .replace("$postfix", postfix) - .replace("$fieldAnnotations", fieldAnnotations) + .replace("$fieldAnnotations", field.annotations()) .replace("$fieldType", field.javaFieldType()) .indent(DEFAULT_INDENT) ); @@ -1072,16 +1122,17 @@ private static void generateBuilderMethods( final String repeatedPrefix; final String repeatedPostfix; // spotless:off - if (parentOneOfField != null) { - repeatedPrefix = prefix + " values == null ? " + getDefaultValue(field, msgDef, lookupHelper) + " : "; + if (field.cannotBeNull()) { + repeatedPrefix = prefix + " values == null ? " + field.defaultValue(msgDef, lookupHelper) + " : "; repeatedPostfix = postfix; - } else if (fieldAnnotations.contains(NON_NULL_ANNOTATION)) { - repeatedPrefix = "values == null ? " + getDefaultValue(field, msgDef, lookupHelper) + " : "; - repeatedPostfix = ""; } else { repeatedPrefix = prefix; repeatedPostfix = postfix; } + String baseType = field.javaFieldType().substring("List<".length(),field.javaFieldType().length()-1); + if (field.type() == FieldType.STRING) { + baseType = "String"; + } builderMethods.add(""" /** * $fieldDoc @@ -1090,10 +1141,11 @@ private static void generateBuilderMethods( * @return builder to continue building with */ public Builder $fieldName($baseType ... values) { - this.$fieldToSet = $repeatedPrefix List.of(values) $repeatedPostfix; + this.$fieldToSet = $repeatedPrefix $convertMethod(values) $repeatedPostfix; return this; }""" - .replace("$baseType",field.javaFieldType().substring("List<".length(),field.javaFieldType().length()-1)) + .replace("$convertMethod",field.type() == FieldType.STRING ? "toUtf8Bytes" : "List.of") + .replace("$baseType",baseType) .replace("$fieldDoc",field.comment() .replaceAll("\n", "\n * ")) .replace("$fieldName", fieldName) @@ -1143,6 +1195,7 @@ public static final class Builder { */ public Builder() { $unknownFields = null; } + $prePopulatedPrivateBuilder $prePopulatedBuilder $prePopulatedWithUnknownFieldsBuilder /** @@ -1151,38 +1204,27 @@ public static final class Builder { * @return new model record with data set */ public $javaRecordName build() { - return new $javaRecordName($recordParams); + return new $javaRecordName($fakeParams$recordParams$unknownFieldParam); } $builderMethods}""" .replace("$fields", fields.stream().map(field -> - getFieldAnnotations(field) - + "private " + field.javaFieldType() + field.annotations() + + "private " + field.javaFieldStorageType() + " " + field.nameCamelFirstLower() - + " = " + getDefaultValue(field, msgDef, lookupHelper) + + " = " + field.defaultValue(msgDef, lookupHelper) ).collect(Collectors.joining(";\n "))) - .replace("$prePopulatedBuilder", generateConstructor("Builder", fields, false, fields, false, msgDef, lookupHelper)) - .replace("$prePopulatedWithUnknownFieldsBuilder", generateConstructor("Builder", fields, true, fields, false, msgDef, lookupHelper)) + .replace("$prePopulatedPrivateBuilder", fields.stream().noneMatch(Field::hasDifferentStorageType) + ? "" : + generateConstructor(true, "Builder", fields, true, fields, false, msgDef, lookupHelper)) + .replace("$prePopulatedBuilder", generateConstructor(false, "Builder", fields, false, fields, false, msgDef, lookupHelper)) + .replace("$prePopulatedWithUnknownFieldsBuilder", generateConstructor(false,"Builder", fields, true, fields, false, msgDef, lookupHelper)) .replace("$javaRecordName",javaRecordName) + .replace("$fakeParams", fields.stream().anyMatch(Field::hasDifferentStorageType) ? "0, " : "") .replace("$recordParams",fields.stream().map(Field::nameCamelFirstLower).collect(Collectors.joining(", "))) + .replace("$unknownFieldParam",fields.isEmpty() ? "$unknownFields" : ", $unknownFields") .replace("$builderMethods", String.join("\n", builderMethods)) .indent(DEFAULT_INDENT); // spotless:on } - - /** - * Gets the default value for the field - * @param field the field to use for the code generation - * @param msgDef the message definition - * @param lookupHelper the lookup helper - * @return the generated code - */ - private static String getDefaultValue( - final Field field, final MessageDefContext msgDef, final ContextualLookupHelper lookupHelper) { - if (field.type() == Field.FieldType.ONE_OF) { - return lookupHelper.getFullyQualifiedMessageClassname(FileType.CODEC, msgDef) + "." + field.javaDefault(); - } else { - return field.javaDefault(); - } - } } diff --git a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/TestGenerator.java b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/TestGenerator.java index 8f1b82aae..ce9f6c43d 100644 --- a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/TestGenerator.java +++ b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/TestGenerator.java @@ -362,7 +362,12 @@ private static String generateTestMethod(final String modelClassName, final Stri } // read proto bytes with ProtoC to make sure it is readable and no parse exceptions are thrown - final $protocModelClass protoCModelObj = $protocModelClass.parseFrom(byteBuffer); + final $protocModelClass protoCModelObj; + try { + protoCModelObj = $protocModelClass.parseFrom(byteBuffer); + } catch (final Exception e) { + throw new RuntimeException("For model:\\n" + modelObj + "\\nCAUGHT EXCEPTION: ", e); + } // read proto bytes with PBJ parser dataBuffer.resetPosition(); diff --git a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/json/JsonCodecGenerator.java b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/json/JsonCodecGenerator.java index 2e6d5d7de..635c58fde 100644 --- a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/json/JsonCodecGenerator.java +++ b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/json/JsonCodecGenerator.java @@ -70,6 +70,7 @@ public void generate( writer.addImport("com.hedera.pbj.runtime.jsonparser.*"); writer.addImport("static " + lookupHelper.getFullyQualifiedMessageClassname(FileType.SCHEMA, msgDef) + ".*"); writer.addImport("static com.hedera.pbj.runtime.JsonTools.*"); + writer.addImport("static com.hedera.pbj.runtime.Utf8Tools.*"); // spotless:off writer.append(""" diff --git a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/json/JsonCodecParseMethodGenerator.java b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/json/JsonCodecParseMethodGenerator.java index a545a1db5..19e2760ae 100644 --- a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/json/JsonCodecParseMethodGenerator.java +++ b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/json/JsonCodecParseMethodGenerator.java @@ -8,6 +8,7 @@ import com.hedera.pbj.compiler.impl.Field; import com.hedera.pbj.compiler.impl.MapField; import com.hedera.pbj.compiler.impl.OneOfField; +import com.hedera.pbj.compiler.impl.SingleField; import java.util.List; import java.util.stream.Collectors; @@ -78,18 +79,24 @@ static String generateParseObjectMethod(final String modelClassName, final List< } } - return new $modelClassName($fieldsList); + return new $modelClassName($fakeParams$fieldsList$unknownFields); } catch (Exception ex) { throw new ParseException(ex); } } """ .replace("$modelClassName", modelClassName) + .replace( + "$fakeParams", + fields.stream().anyMatch(Field::hasDifferentStorageType) + ? ("0" + (fields.isEmpty() ? "" : ", ")) + : "") + .replace("$unknownFields", fields.isEmpty() ? "Collections.emptyList()" : ", Collections.emptyList()") .replace( "$fieldDefs", fields.stream() .map(field -> " %s temp_%s = %s;" - .formatted(field.javaFieldType(), field.name(), field.javaDefault())) + .formatted(field.javaFieldStorageType(), field.name(), field.javaDefault())) .collect(Collectors.joining("\n"))) .replace( "$fieldsList", @@ -138,6 +145,7 @@ private static String generateCaseStatements(final List fields) { private static void generateFieldCaseStatement( final StringBuilder origSB, final Field field, final String valueGetter) { final StringBuilder sb = new StringBuilder(); + final boolean isMapField = field instanceof SingleField && ((SingleField) field).isMapField(); if (field.repeated()) { if (field.type() == Field.FieldType.MESSAGE) { sb.append("parseObjArray($valueGetter.arr(), " + field.messageType() + ".JSON, maxDepth - 1)"); @@ -149,7 +157,11 @@ private static void generateFieldCaseStatement( case INT64, UINT64, SINT64, FIXED64, SFIXED64 -> sb.append("parseLong(v)"); case FLOAT -> sb.append("parseFloat(v)"); case DOUBLE -> sb.append("parseDouble(v)"); - case STRING -> sb.append("unescape(v.STRING().getText())"); + case STRING -> + sb.append( + isMapField || field.parent() != null + ? "unescape(v.STRING().getText())" + : "toUtf8Bytes(unescape(v.STRING().getText()))"); case BOOL -> sb.append("parseBoolean(v)"); case BYTES -> sb.append("Bytes.fromBase64(v.STRING().getText())"); default -> throw new RuntimeException("Unknown field type [" + field.type() + "]"); @@ -162,7 +174,11 @@ private static void generateFieldCaseStatement( case "Int64Value", "UInt64Value" -> sb.append("parseLong($valueGetter)"); case "FloatValue" -> sb.append("parseFloat($valueGetter)"); case "DoubleValue" -> sb.append("parseDouble($valueGetter)"); - case "StringValue" -> sb.append("unescape($valueGetter.STRING().getText())"); + case "StringValue" -> + sb.append( + isMapField || field.parent() != null + ? "unescape($valueGetter.STRING().getText())" + : "toUtf8Bytes(unescape($valueGetter.STRING().getText()))"); case "BoolValue" -> sb.append("parseBoolean($valueGetter)"); case "BytesValue" -> sb.append("Bytes.fromBase64($valueGetter.STRING().getText())"); default -> throw new RuntimeException("Unknown message type [" + field.messageType() + "]"); @@ -187,14 +203,20 @@ private static void generateFieldCaseStatement( .replace("$mapEntryValue", valueSB.toString())); } else { switch (field.type()) { - case MESSAGE -> sb.append(field.javaFieldType() - + ".JSON.parse($valueGetter.getChild(JSONParser.ObjContext.class, 0), false, maxDepth - 1)"); + case MESSAGE -> + sb.append( + field.javaFieldType() + + ".JSON.parse($valueGetter.getChild(JSONParser.ObjContext.class, 0), false, maxDepth - 1)"); case ENUM -> sb.append(field.javaFieldType() + ".fromString($valueGetter.STRING().getText())"); case INT32, UINT32, SINT32, FIXED32, SFIXED32 -> sb.append("parseInteger($valueGetter)"); case INT64, UINT64, SINT64, FIXED64, SFIXED64 -> sb.append("parseLong($valueGetter)"); case FLOAT -> sb.append("parseFloat($valueGetter)"); case DOUBLE -> sb.append("parseDouble($valueGetter)"); - case STRING -> sb.append("unescape($valueGetter.STRING().getText())"); + case STRING -> + sb.append( + isMapField || field.parent() != null + ? "unescape($valueGetter.STRING().getText())" + : "toUtf8Bytes(unescape($valueGetter.STRING().getText()))"); case BOOL -> sb.append("parseBoolean($valueGetter)"); case BYTES -> sb.append("Bytes.fromBase64($valueGetter.STRING().getText())"); default -> throw new RuntimeException("Unknown field type [" + field.type() + "]"); diff --git a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/json/JsonCodecWriteMethodGenerator.java b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/json/JsonCodecWriteMethodGenerator.java index a90a126af..96f277d0f 100644 --- a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/json/JsonCodecWriteMethodGenerator.java +++ b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/json/JsonCodecWriteMethodGenerator.java @@ -101,9 +101,14 @@ private static String generateFieldWriteLines(final Field field, final String mo return prefix + "if (data." + field.nameCamelFirstLower() + "() != " + field.javaDefault() + " && !data." + field.nameCamelFirstLower() + "().isEmpty()) fieldLines.add(" + basicFieldCode + ");"; + } else if (field.isString()) { + return prefix + "if (" + getValueCode + " != null" + + (field.optionalValueType() ? "" : (" && " + getValueCode + ".length() > 0")) + + ") fieldLines.add(" + + basicFieldCode + ");"; } else { - return prefix + "if (data." + field.nameCamelFirstLower() + "() != " + field.javaDefault() - + ") fieldLines.add(" + basicFieldCode + ");"; + return prefix + "if (" + getValueCode + " != " + field.javaDefault() + ") fieldLines.add(" + + basicFieldCode + ");"; } } } @@ -121,23 +126,25 @@ private static String generateBasicFieldLines( "DoubleValue", "BytesValue" -> "field(%s, %s)".formatted(fieldName, getValueCode); case "Int64Value", "UInt64Value" -> "field(%s, %s, true)".formatted(fieldName, getValueCode); - default -> throw new UnsupportedOperationException( - "Unhandled optional message type:" + field.messageType()); + default -> + throw new UnsupportedOperationException("Unhandled optional message type:" + field.messageType()); }; } else if (field.repeated()) { return switch (field.type()) { - case MESSAGE -> "arrayField(childIndent, $fieldName, $codec, $valueCode)" - .replace("$fieldName", fieldName) - .replace("$fieldDef", fieldDef) - .replace("$valueCode", getValueCode) - .replace( - "$codec", - ((SingleField) field).messageTypeModelPackage() + "." - + ((SingleField) field).completeClassName() + ".JSON"); - default -> "arrayField($fieldName, $fieldDef, $valueCode)" - .replace("$fieldName", fieldName) - .replace("$fieldDef", fieldDef) - .replace("$valueCode", getValueCode); + case MESSAGE -> + "arrayField(childIndent, $fieldName, $codec, $valueCode)" + .replace("$fieldName", fieldName) + .replace("$fieldDef", fieldDef) + .replace("$valueCode", getValueCode) + .replace( + "$codec", + ((SingleField) field).messageTypeModelPackage() + "." + + ((SingleField) field).completeClassName() + ".JSON"); + default -> + "arrayField($fieldName, $fieldDef, $valueCode)" + .replace("$fieldName", fieldName) + .replace("$fieldDef", fieldDef) + .replace("$valueCode", getValueCode); }; } else if (field.type() == Field.FieldType.MAP) { final MapField mapField = (MapField) field; @@ -156,19 +163,21 @@ private static String generateBasicFieldLines( .replace("$vComposer", "(n, v) -> " + vComposerMethod); } else { return switch (field.type()) { - case ENUM -> "field($fieldName, $valueCode.protoName())" - .replace("$fieldName", fieldName) - .replace("$fieldDef", fieldDef) - .replace("$valueCode", getValueCode); - case MESSAGE -> "field($childIndent, $fieldName, $codec, $valueCode)" - .replace("$childIndent", childIndent) - .replace("$fieldName", fieldName) - .replace("$fieldDef", fieldDef) - .replace("$valueCode", getValueCode) - .replace( - "$codec", - ((SingleField) field).messageTypeModelPackage() + "." - + ((SingleField) field).completeClassName() + ".JSON"); + case ENUM -> + "field($fieldName, $valueCode.protoName())" + .replace("$fieldName", fieldName) + .replace("$fieldDef", fieldDef) + .replace("$valueCode", getValueCode); + case MESSAGE -> + "field($childIndent, $fieldName, $codec, $valueCode)" + .replace("$childIndent", childIndent) + .replace("$fieldName", fieldName) + .replace("$fieldDef", fieldDef) + .replace("$valueCode", getValueCode) + .replace( + "$codec", + ((SingleField) field).messageTypeModelPackage() + "." + + ((SingleField) field).completeClassName() + ".JSON"); default -> "field(%s, %s)".formatted(fieldName, getValueCode); }; } diff --git a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/protobuf/CodecParseMethodGenerator.java b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/protobuf/CodecParseMethodGenerator.java index 9ad4e075d..fc4f5246d 100644 --- a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/protobuf/CodecParseMethodGenerator.java +++ b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/protobuf/CodecParseMethodGenerator.java @@ -2,6 +2,8 @@ package com.hedera.pbj.compiler.impl.generators.protobuf; import static com.hedera.pbj.compiler.impl.Common.DEFAULT_INDENT; +import static com.hedera.pbj.compiler.impl.Field.FieldType.MAP; +import static com.hedera.pbj.compiler.impl.Field.FieldType.STRING; import com.hedera.pbj.compiler.impl.Common; import com.hedera.pbj.compiler.impl.Field; @@ -81,7 +83,7 @@ static String generateParseMethod( Collections.sort($unknownFields); $initialSizeOfUnknownFieldsArray = Math.max($initialSizeOfUnknownFieldsArray, $unknownFields.size()); } - return new $modelClassName($fieldsList); + return new $modelClassName($fakeParams$fieldsList); } catch (final Exception anyException) { if (anyException instanceof ParseException parseException) { throw parseException; @@ -91,13 +93,15 @@ static String generateParseMethod( } """ .replace("$modelClassName",modelClassName) - .replace("$fieldDefs",fields.stream().map(field -> " %s temp_%s = %s;".formatted(field.javaFieldType(), + .replace("$fakeParams", fields.stream().anyMatch(Field::hasDifferentStorageType) ? ("0" + (fields.isEmpty() ? "" : ", ")) : "") + .replace("$fieldDefs",fields.stream().map(field -> " %s temp_%s = %s;" + .formatted(field.javaFieldStorageType(), field.name(), field.javaDefault())).collect(Collectors.joining("\n"))) .replace("$fieldsList", fields.stream().map(field -> "temp_"+field.name()).collect(Collectors.joining(", ")) + (fields.isEmpty() ? "" : ", ") + "$unknownFields" ) - .replace("$parseLoop", generateParseLoop(generateCaseStatements(fields, schemaClassName), "", schemaClassName)) + .replace("$parseLoop", generateParseLoop(generateCaseStatements(fields, schemaClassName, false), "", schemaClassName)) .replace("$skipMaxSize", String.valueOf(Field.DEFAULT_MAX_SIZE)) .indent(DEFAULT_INDENT); // spotless:on @@ -192,20 +196,21 @@ static String generateParseLoop( * @param fields list of all fields in record * @return string of case statement code */ - private static String generateCaseStatements(final List fields, final String schemaClassName) { + private static String generateCaseStatements( + final List fields, final String schemaClassName, final boolean isMapField) { StringBuilder sb = new StringBuilder(); for (Field field : fields) { if (field instanceof final OneOfField oneOfField) { for (final Field subField : oneOfField.fields()) { - generateFieldCaseStatement(sb, subField, schemaClassName); + generateFieldCaseStatement(sb, subField, schemaClassName, isMapField); } } else if (field.repeated() && field.type().wireType() != Common.TYPE_LENGTH_DELIMITED) { // for repeated fields that are not length encoded there are 2 forms they can be stored in file. // "packed" and repeated primitive fields - generateFieldCaseStatement(sb, field, schemaClassName); - generateFieldCaseStatementPacked(sb, field); + generateFieldCaseStatement(sb, field, schemaClassName, isMapField); + generateFieldCaseStatementPacked(sb, field, isMapField); } else { - generateFieldCaseStatement(sb, field, schemaClassName); + generateFieldCaseStatement(sb, field, schemaClassName, isMapField); } } return sb.toString().indent(DEFAULT_INDENT * 4); @@ -218,7 +223,8 @@ private static String generateCaseStatements(final List fields, final Str * @param sb StringBuilder to append code to */ @SuppressWarnings("StringConcatenationInsideStringBufferAppend") - private static void generateFieldCaseStatementPacked(final StringBuilder sb, final Field field) { + private static void generateFieldCaseStatementPacked( + final StringBuilder sb, final Field field, final boolean isMapField) { final int wireType = Common.TYPE_LENGTH_DELIMITED; final int fieldNum = field.fieldNumber(); final int tag = Common.getTag(wireType, fieldNum); @@ -261,7 +267,7 @@ private static void generateFieldCaseStatementPacked(final StringBuilder sb, fin * @param sb StringBuilder to append code to */ private static void generateFieldCaseStatement( - final StringBuilder sb, final Field field, final String schemaClassName) { + final StringBuilder sb, final Field field, final String schemaClassName, final boolean isMapField) { final int wireType = field.optionalValueType() ? Common.TYPE_LENGTH_DELIMITED : field.type().wireType(); @@ -290,7 +296,7 @@ private static void generateFieldCaseStatement( // means optional is default value value = $defaultValue; }""" - .replace("$fieldType", field.javaFieldType()) + .replace("$fieldType", field.javaFieldStorageType()) .replace("$readMethod", readMethod(field)) .replace("$defaultValue", switch (field.messageType()) { @@ -300,7 +306,7 @@ private static void generateFieldCaseStatement( case "DoubleValue" -> "0d"; case "BoolValue" -> "false"; case "BytesValue" -> "Bytes.EMPTY"; - case "StringValue" -> "\"\""; + case "StringValue" -> "PbjConstants.EMPTY_BYTES"; default -> throw new PbjCompilerException("Unexpected and unknown field type " + field.type() + " cannot be parsed"); }) .replace("$valueTypeWireType", Integer.toString( @@ -392,8 +398,10 @@ private static void generateFieldCaseStatement( .replace("$fieldName", field.name()) .replace("$fieldDefs",mapEntryFields.stream().map(mapEntryField -> "%s temp_%s = %s;".formatted(mapEntryField.javaFieldType(), - mapEntryField.name(), mapEntryField.javaDefault())).collect(Collectors.joining("\n"))) - .replace("$mapParseLoop", generateParseLoop(generateCaseStatements(mapEntryFields, schemaClassName), "map_entry_", schemaClassName) + mapEntryField.name(), + mapEntryField.type() == STRING ? "\"\"" : mapEntryField.javaDefault() + )).collect(Collectors.joining("\n"))) + .replace("$mapParseLoop", generateParseLoop(generateCaseStatements(mapEntryFields, schemaClassName, true), "map_entry_", schemaClassName) .indent(-DEFAULT_INDENT)) .replace("$maxSize", String.valueOf(field.maxSize())) ); @@ -408,9 +416,10 @@ private static void generateFieldCaseStatement( throw new PbjCompilerException("Fields can not be oneof and repeated ["+field+"]"); } else if (field.parent() != null) { final var oneOfField = field.parent(); - sb.append("temp_%s = new %s<>(%s.%s, value);%n" + sb.append("temp_%s = new %s<>(%s.%s, %s);%n" .formatted(oneOfField.name(), oneOfField.className(), oneOfField.getEnumClassRef(), - Common.camelToUpperSnake(field.name()))); + Common.camelToUpperSnake(field.name()), + field.isString() ? "toUtf8String(value)" : "value")); } else if (field.repeated()) { sb.append( """ @@ -431,6 +440,8 @@ private static void generateFieldCaseStatement( } """.formatted(field.name(), field.maxSize(), mapField.keyField().name(), mapField.valueField().name())); + } else if(field.isString() && isMapField){ + sb.append("temp_%s = toUtf8String(value);\n".formatted(field.name())); } else { sb.append("temp_%s = value;\n".formatted(field.name())); } @@ -441,7 +452,8 @@ private static void generateFieldCaseStatement( static String readMethod(Field field) { if (field.optionalValueType()) { return switch (field.messageType()) { - case "StringValue" -> "readString(input, %d)".formatted(field.maxSize()); + case "StringValue" -> + "readString%s(input, %d)".formatted(field.hasDifferentStorageType() ? "Raw" : "", field.maxSize()); case "Int32Value" -> "readInt32(input)"; case "UInt32Value" -> "readUint32(input)"; case "Int64Value" -> "readInt64(input)"; @@ -470,7 +482,8 @@ static String readMethod(Field field) { case DOUBLE -> "readDouble(input)"; case FIXED64 -> "readFixed64(input)"; case SFIXED64 -> "readSignedFixed64(input)"; - case STRING -> "readString(input, %d)".formatted(field.maxSize()); + case STRING -> + "readString%s(input, %d)".formatted(field.hasDifferentStorageType() ? "Raw" : "", field.maxSize()); case BOOL -> "readBool(input)"; case BYTES -> "readBytes(input, %d)".formatted(field.maxSize()); case MESSAGE -> field.parseCode(); diff --git a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/protobuf/CodecWriteByteArrayMethodGenerator.java b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/protobuf/CodecWriteByteArrayMethodGenerator.java index b11bd9c2f..d6c8afab5 100644 --- a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/protobuf/CodecWriteByteArrayMethodGenerator.java +++ b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/protobuf/CodecWriteByteArrayMethodGenerator.java @@ -17,7 +17,7 @@ /** * Code to generate the write method for Codec classes. */ -final class CodecWriteByteArrayMethodGenerator { +public final class CodecWriteByteArrayMethodGenerator { static String generateWriteMethod( final String modelClassName, final String schemaClassName, final List fields) { @@ -25,8 +25,10 @@ static String generateWriteMethod( modelClassName, schemaClassName, fields, - field -> " data.%s()".formatted(field.nameCamelFirstLower()), - true); + field -> " data.%s%s" + .formatted(field.nameCamelFirstLower(), field.hasDifferentStorageType() ? "" : "()"), + true, + false); // spotless:off return """ @@ -63,14 +65,20 @@ private static String buildFieldWriteLines( final String schemaClassName, final List fields, final Function getValueBuilder, - final boolean skipDefault) { + final boolean skipDefault, + final boolean accessStorageField) { return fields.stream() .flatMap(field -> field.type() == Field.FieldType.ONE_OF ? ((OneOfField) field).fields().stream() : Stream.of(field)) .sorted(Comparator.comparingInt(Field::fieldNumber)) .map(field -> generateFieldWriteLines( - field, modelClassName, schemaClassName, getValueBuilder.apply(field), skipDefault)) + field, + modelClassName, + schemaClassName, + getValueBuilder.apply(field), + skipDefault, + accessStorageField)) .collect(Collectors.joining("\n")) .indent(DEFAULT_INDENT); } @@ -82,14 +90,17 @@ private static String buildFieldWriteLines( * @param modelClassName The model class name for model class for message type we are generating writer for * @param getValueCode java code to get the value of field * @param skipDefault skip writing the field if it has default value (for non-oneOf only) + * @param accessStorageField true if the generated code has access to storage fields (e.g. in models), + * false otherwise (e.g. in codecs) * @return java code to write field to output */ - private static String generateFieldWriteLines( + public static String generateFieldWriteLines( final Field field, final String modelClassName, final String schemaClassName, String getValueCode, - boolean skipDefault) { + final boolean skipDefault, + final boolean accessStorageField) { final String fieldDef = schemaClassName + "." + Common.camelToUpperSnake(field.name()); String prefix = "// [%d] - %s%n".formatted(field.fieldNumber(), field.name()); @@ -104,8 +115,9 @@ private static String generateFieldWriteLines( final String writeMethodName = field.methodNameType(); if (field.optionalValueType()) { return prefix + switch (field.messageType()) { - case "StringValue" -> "offset += ProtoArrayWriterTools.writeOptionalString(output, offset, %s, %s);" - .formatted(fieldDef,getValueCode); + case "StringValue" -> accessStorageField || field.parent() != null + ? "offset += ProtoArrayWriterTools.writeOptionalString(output, offset, %s, %s);".formatted(fieldDef,getValueCode) + : "offset += %sWriteTo(output, offset);".formatted(getValueCode); case "BoolValue" -> "offset += ProtoArrayWriterTools.writeOptionalBoolean(output, offset, %s, %s);" .formatted(fieldDef, getValueCode); case "Int32Value" -> "offset += ProtoArrayWriterTools.writeOptionalInt32Value(output, offset, %s, %s);" @@ -149,6 +161,9 @@ private static String generateFieldWriteLines( .formatted(fieldDef, getValueCode); case FIXED64, SFIXED64 -> "offset += ProtoArrayWriterTools.writeFixed64List(output, offset, %s, %s);" .formatted(fieldDef, getValueCode); + case STRING -> accessStorageField + ? "offset += ProtoArrayWriterTools.writeByteArrayStringList(output, offset, %s, %s);".formatted(fieldDef, getValueCode) + : "offset += %sWriteTo(output, offset);".formatted(getValueCode); default -> "offset += ProtoArrayWriterTools.write%sList(output, offset, %s, %s);" .formatted(writeMethodName, fieldDef, getValueCode); @@ -173,7 +188,8 @@ private static String generateFieldWriteLines( schemaClassName, mapEntryFields, getValueBuilder, - false); + false, + true); final String fieldSizeOfLines = CodecMeasureRecordMethodGenerator.buildFieldSizeOfLines( field.name(), mapEntryFields, @@ -205,8 +221,9 @@ private static String generateFieldWriteLines( return prefix + switch(field.type()) { case ENUM -> "offset += ProtoArrayWriterTools.writeEnum(output, offset, %s, %s);" .formatted(fieldDef, getValueCode); - case STRING -> "offset += ProtoArrayWriterTools.writeString(output, offset, %s, %s, %s);" - .formatted(fieldDef, getValueCode, skipDefault); + case STRING -> accessStorageField || field.parent() != null + ? "offset += ProtoArrayWriterTools.writeString(output, offset, %s, %s, %s);".formatted(fieldDef, getValueCode, skipDefault) + : "offset += %sWriteTo(output, offset);".formatted(getValueCode); case MESSAGE -> "offset += ProtoArrayWriterTools.writeMessage(output, offset, %s, %s, %s);" .formatted(fieldDef, getValueCode, codecReference); case BOOL -> "offset += ProtoArrayWriterTools.writeBoolean(output, offset, %s, %s, %s);" diff --git a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/protobuf/CodecWriteMethodGenerator.java b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/protobuf/CodecWriteMethodGenerator.java index 3839b6a73..44c455c0a 100644 --- a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/protobuf/CodecWriteMethodGenerator.java +++ b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/protobuf/CodecWriteMethodGenerator.java @@ -17,7 +17,7 @@ /** * Code to generate the write method for Codec classes. */ -final class CodecWriteMethodGenerator { +public final class CodecWriteMethodGenerator { static String generateWriteMethod( final String modelClassName, final String schemaClassName, final List fields) { @@ -25,8 +25,10 @@ static String generateWriteMethod( modelClassName, schemaClassName, fields, - field -> "data.%s()".formatted(field.nameCamelFirstLower()), - true); + field -> " data.%s%s" + .formatted(field.nameCamelFirstLower(), field.hasDifferentStorageType() ? "" : "()"), + true, + false); // spotless:off return """ @@ -60,14 +62,20 @@ private static String buildFieldWriteLines( final String schemaClassName, final List fields, final Function getValueBuilder, - final boolean skipDefault) { + final boolean skipDefault, + final boolean accessStorageField) { return fields.stream() .flatMap(field -> field.type() == Field.FieldType.ONE_OF ? ((OneOfField) field).fields().stream() : Stream.of(field)) .sorted(Comparator.comparingInt(Field::fieldNumber)) .map(field -> generateFieldWriteLines( - field, modelClassName, schemaClassName, getValueBuilder.apply(field), skipDefault)) + field, + modelClassName, + schemaClassName, + getValueBuilder.apply(field), + skipDefault, + accessStorageField)) .collect(Collectors.joining("\n")) .indent(DEFAULT_INDENT); } @@ -79,21 +87,24 @@ private static String buildFieldWriteLines( * @param modelClassName The model class name for model class for message type we are generating writer for * @param getValueCode java code to get the value of field * @param skipDefault skip writing the field if it has default value (for non-oneOf only) + * @param accessStorageField true if the generated code has access to storage fields (e.g. in models), + * false otherwise (e.g. in codecs) * @return java code to write field to output */ - private static String generateFieldWriteLines( + public static String generateFieldWriteLines( final Field field, final String modelClassName, final String schemaClassName, String getValueCode, - boolean skipDefault) { + final boolean skipDefault, + final boolean accessStorageField) { final String fieldDef = schemaClassName + "." + Common.camelToUpperSnake(field.name()); String prefix = "// [%d] - %s%n".formatted(field.fieldNumber(), field.name()); if (field.parent() != null) { final OneOfField oneOfField = field.parent(); final String oneOfType = "%s.%sOneOfType".formatted(modelClassName, oneOfField.nameCamelFirstUpper()); - getValueCode = "data.%s().as()".formatted(oneOfField.nameCamelFirstLower()); + getValueCode = "(%s) data.%s().as()".formatted(field.javaFieldType(), oneOfField.nameCamelFirstLower()); prefix += "if (data.%s().kind() == %s.%s)%n" .formatted(oneOfField.nameCamelFirstLower(), oneOfType, Common.camelToUpperSnake(field.name())); } @@ -101,8 +112,9 @@ private static String generateFieldWriteLines( final String writeMethodName = field.methodNameType(); if (field.optionalValueType()) { return prefix + switch (field.messageType()) { - case "StringValue" -> "writeOptionalString(out, %s, %s);" - .formatted(fieldDef,getValueCode); + case "StringValue" -> accessStorageField || field.parent() != null + ? "writeOptionalString(out, %s, %s);".formatted(fieldDef,getValueCode) + : "%sWriteTo(out);".formatted(getValueCode); case "BoolValue" -> "writeOptionalBoolean(out, %s, %s);" .formatted(fieldDef, getValueCode); case "Int32Value","UInt32Value" -> "writeOptionalInteger(out, %s, %s);" @@ -130,6 +142,9 @@ private static String generateFieldWriteLines( .formatted(fieldDef, getValueCode); case MESSAGE -> "writeMessageList(out, %s, %s, %s);" .formatted(fieldDef, getValueCode, codecReference); + case STRING -> accessStorageField + ? "write%sList(out, %s, %s);".formatted(writeMethodName, fieldDef, getValueCode) + : "%sWriteTo(out);".formatted(getValueCode); default -> "write%sList(out, %s, %s);" .formatted(writeMethodName, fieldDef, getValueCode); }; @@ -153,7 +168,8 @@ private static String generateFieldWriteLines( schemaClassName, mapEntryFields, getValueBuilder, - false); + false, + true); final String fieldSizeOfLines = CodecMeasureRecordMethodGenerator.buildFieldSizeOfLines( field.name(), mapEntryFields, @@ -185,8 +201,9 @@ private static String generateFieldWriteLines( return prefix + switch(field.type()) { case ENUM -> "writeEnum(out, %s, %s);" .formatted(fieldDef, getValueCode); - case STRING -> "writeString(out, %s, %s, %s);" - .formatted(fieldDef, getValueCode, skipDefault); + case STRING -> accessStorageField || field.parent() != null + ? "writeString(out, %s, %s, %s);".formatted(fieldDef, getValueCode, skipDefault) + : "%sWriteTo(out);".formatted(getValueCode); case MESSAGE -> "writeMessage(out, %s, %s, %s);" .formatted(fieldDef, getValueCode, codecReference); case BOOL -> "writeBoolean(out, %s, %s, %s);" diff --git a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/protobuf/LazyGetProtobufSizeMethodGenerator.java b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/protobuf/LazyGetProtobufSizeMethodGenerator.java index 608c7585d..e7a36c845 100644 --- a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/protobuf/LazyGetProtobufSizeMethodGenerator.java +++ b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/protobuf/LazyGetProtobufSizeMethodGenerator.java @@ -106,7 +106,7 @@ private static String generateFieldSizeOfLines( final String oneOfType = modelClassName == null ? oneOfField.nameCamelFirstUpper() + "OneOfType" : modelClassName + "." + oneOfField.nameCamelFirstUpper() + "OneOfType"; - getValueCode = oneOfField.nameCamelFirstLower() + ".as()"; + getValueCode = "(" + field.javaFieldType() + ")" + oneOfField.nameCamelFirstLower() + ".as()"; prefix += "if (" + oneOfField.nameCamelFirstLower() + ".kind() == " + oneOfType + "." + Common.camelToUpperSnake(field.name()) + ")"; prefix += "\n"; diff --git a/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/PbjConstants.java b/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/PbjConstants.java new file mode 100644 index 000000000..1eea393be --- /dev/null +++ b/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/PbjConstants.java @@ -0,0 +1,10 @@ +// SPDX-License-Identifier: Apache-2.0 +package com.hedera.pbj.runtime; + +/** + * Constants used in the PBJ runtime. + */ +public final class PbjConstants { + /** An empty byte array constant to avoid creating multiple empty arrays */ + public static final byte[] EMPTY_BYTES = new byte[0]; +} diff --git a/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/ProtoArrayWriterTools.java b/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/ProtoArrayWriterTools.java index f82f1bbc9..c98c3dbb5 100644 --- a/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/ProtoArrayWriterTools.java +++ b/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/ProtoArrayWriterTools.java @@ -190,6 +190,56 @@ private static int writeStringNoChecks( return bytesWritten; } + /** + * Write a string to data output, assuming the field is non-repeated. + * + * @param output the byte array to write to + * @param offset the offset to start writing at + * @param field the descriptor for the field we are writing, the field must be non-repeated + * @param value the string value to write + * @param skipDefault default value results in no-op for non-oneOf + * @return the number of bytes written + */ + public static int writeString( + @NonNull byte[] output, + final int offset, + @NonNull final FieldDefinition field, + final byte[] value, + final boolean skipDefault) { + assert field.type() == FieldType.STRING : "Not a string type " + field; + assert !field.repeated() : "Use writeStringList with repeated types"; + return writeStringNoChecks(output, offset, field, value, skipDefault); + } + + /** + * Write a string to data output - no validation checks. + * + * @param output the byte array to write to + * @param offset the offset to start writing at + * @param field the descriptor for the field we are writing + * @param value the string value to write + * @param skipDefault default value results in no-op for non-oneOf + * @return the number of bytes written + */ + private static int writeStringNoChecks( + @NonNull byte[] output, + final int offset, + @NonNull final FieldDefinition field, + final byte[] value, + final boolean skipDefault) { + int bytesWritten = 0; + // When not a oneOf don't write default value + if (skipDefault && !field.oneOf() && (value == null || value.length == 0)) { + return 0; + } + bytesWritten += writeTag(output, offset, field, WIRE_TYPE_DELIMITED); + final int size = sizeOfStringNoTag(value); + bytesWritten += writeUnsignedVarInt(output, offset + bytesWritten, size); + System.arraycopy(value, 0, output, offset + bytesWritten, size); + bytesWritten += size; + return bytesWritten; + } + /** * Write an optional string to data output * @@ -220,6 +270,38 @@ public static int writeOptionalString( return bytesWritten; } + /** + * Write an optional string to data output + * + * @param output the byte array to write to + * @param offset the offset to start writing at + * @param field the field definition for the string field + * @param value the optional string value to write + * @return the number of bytes written + */ + public static int writeOptionalString( + @NonNull byte[] output, + final int offset, + @NonNull final FieldDefinition field, + @Nullable final byte[] value) { + int bytesWritten = 0; + if (value != null) { + bytesWritten += writeTag(output, offset, field, WIRE_TYPE_DELIMITED); + final var newField = field.type().optionalFieldDefinition; + bytesWritten += writeUnsignedVarInt(output, offset + bytesWritten, sizeOfString(newField, value)); + + // Don't write default value + if (value.length != 0) { + bytesWritten += writeTag(output, offset + bytesWritten, newField, WIRE_TYPE_DELIMITED); + final int size = sizeOfStringNoTag(value); + bytesWritten += writeUnsignedVarInt(output, offset + bytesWritten, size); + System.arraycopy(value, 0, output, offset + bytesWritten, size); + bytesWritten += size; + } + } + return bytesWritten; + } + /** * Write a boolean to data output * @@ -1122,6 +1204,36 @@ public static int writeStringList( return curOffset - offset; } + /** + * Write a list of strings to data output + * + * @param output the byte array to write to + * @param offset the offset to start writing at + * @param field the descriptor for the field we are writing + * @param list the list of strings value to write + * @return the number of bytes written + */ + public static int writeByteArrayStringList( + @NonNull final byte[] output, final int offset, FieldDefinition field, List list) { + assert field.type() == FieldType.STRING : "Not a string type " + field; + assert field.repeated() : "Use writeString with non-repeated types"; + // When not a oneOf don't write default value + if (!field.oneOf() && list.isEmpty()) { + return 0; + } + int curOffset = offset; + final int listSize = list.size(); + for (int i = 0; i < listSize; i++) { + final byte[] value = list.get(i); + curOffset += writeTag(output, curOffset, field, WIRE_TYPE_DELIMITED); + final int size = sizeOfStringNoTag(value); + curOffset += writeUnsignedVarInt(output, curOffset, size); + System.arraycopy(value, 0, output, curOffset, size); + curOffset += size; + } + return curOffset - offset; + } + /** * Write a list of messages to data output * diff --git a/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/ProtoParserTools.java b/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/ProtoParserTools.java index 246e1287e..bc50478a1 100644 --- a/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/ProtoParserTools.java +++ b/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/ProtoParserTools.java @@ -268,6 +268,31 @@ public static String readString(final ReadableSequentialData input, final long m } } + /** + * Read a String field from data input + * + * @param input the input to read from + * @param maxSize the maximum allowed size + * @return Read string + * @throws ParseException if the length is greater than maxSize + */ + public static byte[] readStringRaw(final ReadableSequentialData input, final long maxSize) + throws IOException, ParseException { + final int length = input.readVarInt(false); + if (length > maxSize) { + throw new ParseException("size " + length + " is greater than max " + maxSize); + } + if (input.remaining() < length) { + throw new BufferUnderflowException(); + } + byte[] result = new byte[length]; + final long bytesRead = input.readBytes(result); + if (bytesRead != length) { + throw new BufferUnderflowException(); + } + return result; + } + /** * Read a Bytes field from data input * diff --git a/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/ProtoWriterTools.java b/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/ProtoWriterTools.java index c17d10839..61fa9a059 100644 --- a/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/ProtoWriterTools.java +++ b/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/ProtoWriterTools.java @@ -72,8 +72,8 @@ public static void writeTag( } /** Create an unsupported field type exception */ - private static RuntimeException unsupported() { - return new RuntimeException("Unsupported field type. Bug in ProtoOutputStream, shouldn't happen."); + static RuntimeException unsupported() { + return new RuntimeException("Unsupported field type, shouldn't happen."); } // ================================================================================================================ @@ -277,6 +277,19 @@ public static void writeString(final WritableSequentialData out, final FieldDefi writeString(out, field, value, true); } + /** + * Write a string represented by a byte[] to data output, assuming the field is non-repeated. + * + * @param out The data output to write to + * @param field the descriptor for the field we are writing, the field must be non-repeated + * @param value the string value to write + * @throws IOException If a I/O error occurs + */ + public static void writeString(final WritableSequentialData out, final FieldDefinition field, final byte[] value) + throws IOException { + writeString(out, field, value, true); + } + /** * Write a string to data output, assuming the field is non-repeated. * @@ -294,6 +307,23 @@ public static void writeString( writeStringNoChecks(out, field, value, skipDefault); } + /** + * Write a string to data output, assuming the field is non-repeated. + * + * @param out The data output to write to + * @param field the descriptor for the field we are writing, the field must be non-repeated + * @param value the string value to write + * @param skipDefault default value results in no-op for non-oneOf + * @throws IOException If a I/O error occurs + */ + public static void writeString( + final WritableSequentialData out, final FieldDefinition field, final byte[] value, boolean skipDefault) + throws IOException { + assert field.type() == FieldType.STRING : "Not a string type " + field; + assert !field.repeated() : "Use writeStringList with repeated types"; + writeStringNoChecks(out, field, value, skipDefault); + } + /** * Write a string to data output, assuming the field is repeated. Usually this method is called multiple * times, one for every repeated value. If all values are available immediately, {@link #writeStringList( @@ -305,7 +335,7 @@ public static void writeString( * @throws IOException If a I/O error occurs */ public static void writeOneRepeatedString( - final WritableSequentialData out, final FieldDefinition field, final String value) throws IOException { + final WritableSequentialData out, final FieldDefinition field, final byte[] value) throws IOException { assert field.type() == FieldType.STRING : "Not a string type " + field; assert field.repeated() : "writeOneRepeatedString can only be used with repeated fields"; writeStringNoChecks(out, field, value); @@ -320,7 +350,7 @@ public static void writeOneRepeatedString( * @throws IOException If a I/O error occurs */ private static void writeStringNoChecks( - final WritableSequentialData out, final FieldDefinition field, final String value) throws IOException { + final WritableSequentialData out, final FieldDefinition field, final byte[] value) throws IOException { writeStringNoChecks(out, field, value, true); } @@ -345,6 +375,27 @@ private static void writeStringNoChecks( Utf8Tools.encodeUtf8(value, out); } + /** + * Write a integer to data output - no validation checks. + * + * @param out The data output to write to + * @param field the descriptor for the field we are writing + * @param value the string value to write + * @param skipDefault default value results in no-op for non-oneOf + * @throws IOException If a I/O error occurs + */ + private static void writeStringNoChecks( + final WritableSequentialData out, final FieldDefinition field, final byte[] value, boolean skipDefault) + throws IOException { + // When not a oneOf don't write default value + if (skipDefault && !field.oneOf() && (value == null || value.length == 0)) { + return; + } + writeTag(out, field, WIRE_TYPE_DELIMITED); + out.writeVarInt(value.length, false); + out.writeBytes(value); + } + /** * Write a bytes to data output, assuming the corresponding field is non-repeated, and field type * is any delimited: bytes, string, or message. @@ -630,6 +681,27 @@ public static void writeOptionalString(WritableSequentialData out, FieldDefiniti } } + /** + * Write an optional string represented by a byte[] to data output + * + * @param out The data output to write to + * @param field the descriptor for the field we are writing + * @param value the optional string value to write + * @throws IOException If a I/O error occurs + */ + public static void writeOptionalString(WritableSequentialData out, FieldDefinition field, @Nullable byte[] value) + throws IOException { + if (value != null) { + writeTag(out, field, WIRE_TYPE_DELIMITED); + final var newField = field.type().optionalFieldDefinition; + final int size = sizeOfString(newField, value, true); + out.writeVarInt(size, false); + if (size > 0) { + writeString(out, newField, value); + } + } + } + /** * Write an optional bytes to data output * @@ -896,7 +968,7 @@ public static void writeEnumList( * @param list the list of strings value to write * @throws IOException If a I/O error occurs */ - public static void writeStringList(WritableSequentialData out, FieldDefinition field, List list) + public static void writeStringList(WritableSequentialData out, FieldDefinition field, List list) throws IOException { assert field.type() == FieldType.STRING : "Not a string type " + field; assert field.repeated() : "Use writeString with non-repeated types"; @@ -906,10 +978,14 @@ public static void writeStringList(WritableSequentialData out, FieldDefinition f } final int listSize = list.size(); for (int i = 0; i < listSize; i++) { - final String value = list.get(i); + final byte[] value = list.get(i); writeTag(out, field, WIRE_TYPE_DELIMITED); - out.writeVarInt(sizeOfStringNoTag(value), false); - Utf8Tools.encodeUtf8(value, out); + if (value == null) { + out.writeVarInt(0, false); + } else { + out.writeVarInt(value.length, false); + out.writeBytes(value); + } } } @@ -1150,6 +1226,21 @@ public static int sizeOfOptionalString(FieldDefinition field, @Nullable String v return 0; } + /** + * Get number of bytes that would be needed to encode an optional string field represented by a byte[] + * + * @param field descriptor of field + * @param value optional string value to get encoded size for + * @return the number of bytes for encoded value + */ + public static int sizeOfOptionalString(FieldDefinition field, @Nullable byte[] value) { + if (value != null) { + final int size = sizeOfString(field.type().optionalFieldDefinition, value, true); + return sizeOfTag(field, WIRE_TYPE_DELIMITED) + sizeOfUnsignedVarInt32(size) + size; + } + return 0; + } + /** * Get number of bytes that would be needed to encode an optional bytes field * @@ -1299,6 +1390,17 @@ public static int sizeOfString(FieldDefinition field, String value) { return sizeOfString(field, value, true); } + /** + * Get number of bytes that would be needed to encode a string field + * + * @param field descriptor of field + * @param value string value to get encoded size for + * @return the number of bytes for encoded value + */ + public static int sizeOfString(FieldDefinition field, byte[] value) { + return value == null || value.length == 0 ? 0 : sizeOfDelimited(field, value.length); + } + /** * Get number of bytes that would be needed to encode a string field * @@ -1315,6 +1417,21 @@ public static int sizeOfString(FieldDefinition field, String value, boolean skip return sizeOfDelimited(field, sizeOfStringNoTag(value)); } + /** + * Get number of bytes that would be needed to encode a string field + * + * @param field descriptor of field + * @param value string value to get encoded size for + * @param skipDefault default value results in zero size + * @return the number of bytes for encoded value + */ + public static int sizeOfString(FieldDefinition field, byte[] value, boolean skipDefault) { + if (skipDefault && !field.oneOf() && (value == null || value.length == 0)) { + return 0; + } + return sizeOfDelimited(field, value.length); + } + /** * Get number of bytes that would be needed to encode a string, without field tag * @@ -1333,6 +1450,16 @@ static int sizeOfStringNoTag(String value) { } } + /** + * Get number of bytes that would be needed to encode a string, without field tag + * + * @param value string value to get encoded size for + * @return the number of bytes for encoded value + */ + static int sizeOfStringNoTag(byte[] value) { + return value == null ? 0 : value.length; + } + /** * Get number of bytes that would be needed to encode a bytes field * @@ -1525,7 +1652,7 @@ public static int sizeOfEnumList(FieldDefinition field, List list) { + public static int sizeOfStringList(FieldDefinition field, List list) { int size = 0; final int listSize = list.size(); for (int i = 0; i < listSize; i++) { diff --git a/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/Utf8Tools.java b/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/Utf8Tools.java index c757c8922..6ea6abcb8 100644 --- a/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/Utf8Tools.java +++ b/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/Utf8Tools.java @@ -6,12 +6,44 @@ import com.hedera.pbj.runtime.io.WritableSequentialData; import edu.umd.cs.findbugs.annotations.NonNull; import java.io.IOException; +import java.util.Arrays; +import java.util.List; /** * UTF8 tools based on protobuf standard library, so we are byte for byte identical */ public final class Utf8Tools { + public static byte[] toUtf8Bytes(final String string) { + return string.getBytes(java.nio.charset.StandardCharsets.UTF_8); + } + + public static List toUtf8Bytes(final List strings) { + return strings.stream() + .map(s -> { + return s.getBytes(java.nio.charset.StandardCharsets.UTF_8); + }) + .toList(); + } + + public static List toUtf8Bytes(final String... strings) { + return Arrays.stream(strings) + .map(s -> { + return s.getBytes(java.nio.charset.StandardCharsets.UTF_8); + }) + .toList(); + } + + public static String toUtf8String(final byte[] bytes) { + return new String(bytes, java.nio.charset.StandardCharsets.UTF_8); + } + + public static List toUtf8String(final List bytesList) { + return bytesList.stream() + .map(bytes -> new String(bytes, java.nio.charset.StandardCharsets.UTF_8)) + .toList(); + } + /** * Returns the number of bytes in the UTF-8-encoded form of {@code sequence}. For a string, this * method is equivalent to {@code string.getBytes(UTF_8).length}, but is more efficient in both diff --git a/pbj-core/pbj-runtime/src/test/java/com/hedera/pbj/runtime/ProtoParserToolsTest.java b/pbj-core/pbj-runtime/src/test/java/com/hedera/pbj/runtime/ProtoParserToolsTest.java index f0cc7df58..a167f114b 100644 --- a/pbj-core/pbj-runtime/src/test/java/com/hedera/pbj/runtime/ProtoParserToolsTest.java +++ b/pbj-core/pbj-runtime/src/test/java/com/hedera/pbj/runtime/ProtoParserToolsTest.java @@ -318,8 +318,8 @@ void testSkipField() throws IOException { writeInteger(data, createFieldDefinition(FIXED32), rng.nextInt()); int value = rng.nextInt(0, Integer.MAX_VALUE); writeInteger(data, createFieldDefinition(INT32), value); - writeString(data, createFieldDefinition(STRING), randomVarSizeString()); - writeString(data, createFieldDefinition(STRING), valToRead); + writeString(data, createFieldDefinition(STRING), randomVarSizeString().getBytes(StandardCharsets.UTF_8)); + writeString(data, createFieldDefinition(STRING), valToRead.getBytes(StandardCharsets.UTF_8)); data.flip(); @@ -410,7 +410,7 @@ private static Bytes prepareExtractBytesTestInput() throws IOException { final WritableStreamingData out = new WritableStreamingData(bout)) { ProtoWriterTools.writeInteger(out, INT32_F, INT32_V); ProtoWriterTools.writeInteger(out, FIXED_F, FIXED32_V); - ProtoWriterTools.writeString(out, STRING_F, STRING_V); + ProtoWriterTools.writeString(out, STRING_F, STRING_V.getBytes(StandardCharsets.UTF_8)); ProtoWriterTools.writeBytes(out, BYTES_F, BYTES_V); ProtoWriterTools.writeMessage(out, MESSAGE_F, MESSAGE_V, TestMessageCodec.INSTANCE); ProtoWriterTools.writeDouble(out, DOUBLE_F, DOUBLE32_V); @@ -573,7 +573,7 @@ public void write(@NonNull final TestMessage item, @NonNull final WritableSequen throws IOException { final String value = item.getValue(); if (value != null) { - ProtoWriterTools.writeString(out, VALUE_FIELD, value); + ProtoWriterTools.writeString(out, VALUE_FIELD, value.getBytes(StandardCharsets.UTF_8)); } } diff --git a/pbj-core/pbj-runtime/src/test/java/com/hedera/pbj/runtime/ProtoWriterToolsTest.java b/pbj-core/pbj-runtime/src/test/java/com/hedera/pbj/runtime/ProtoWriterToolsTest.java index 0a372fe9d..8dbb79a4b 100644 --- a/pbj-core/pbj-runtime/src/test/java/com/hedera/pbj/runtime/ProtoWriterToolsTest.java +++ b/pbj-core/pbj-runtime/src/test/java/com/hedera/pbj/runtime/ProtoWriterToolsTest.java @@ -76,6 +76,7 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.function.BiConsumer; import java.util.random.RandomGenerator; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -407,7 +408,7 @@ void testWriteEnumZeroOrdinal() { @Test void testWriteString_empty() throws IOException { FieldDefinition definition = createFieldDefinition(STRING); - String valToWrite = ""; + byte[] valToWrite = new byte[0]; final long positionBefore = bufferedData.position(); writeString(bufferedData, definition, valToWrite); final long positionAfter = bufferedData.position(); @@ -417,20 +418,20 @@ void testWriteString_empty() throws IOException { @Test void testWriteString() throws IOException { FieldDefinition definition = createFieldDefinition(STRING); - String valToWrite = RANDOM_STRING.nextString(); + byte[] valToWrite = RANDOM_STRING.nextString().getBytes(StandardCharsets.UTF_8); writeString(bufferedData, definition, valToWrite); bufferedData.flip(); assertEquals( (definition.number() << TAG_TYPE_BITS) | WIRE_TYPE_DELIMITED.ordinal(), bufferedData.readVarInt(false)); int length = bufferedData.readVarInt(false); - assertEquals(valToWrite, new String(bufferedData.readBytes(length).toByteArray())); + assertArrayEquals(valToWrite, bufferedData.readBytes(length).toByteArray()); } @Test void testWriteOneRepeatedString() throws IOException { final FieldDefinition definition = createRepeatedFieldDefinition(STRING); - final String value1 = RANDOM_STRING.nextString(); - final String value2 = RANDOM_STRING.nextString(); + final byte[] value1 = RANDOM_STRING.nextString().getBytes(StandardCharsets.UTF_8); + final byte[] value2 = RANDOM_STRING.nextString().getBytes(StandardCharsets.UTF_8); final BufferedData buf1 = BufferedData.allocate(256); ProtoWriterTools.writeStringList(buf1, definition, List.of(value1, value2)); final Bytes writtenBytes1 = buf1.getBytes(0, buf1.position()); @@ -672,7 +673,7 @@ void testWriteOptionalBoolean_null() { void testWriteOptionalString() throws IOException { FieldDefinition definition = createOptionalFieldDefinition(STRING); String valToWrite = RANDOM_STRING.nextString(); - writeOptionalString(bufferedData, definition, valToWrite); + writeOptionalString(bufferedData, definition, valToWrite.getBytes(StandardCharsets.UTF_8)); bufferedData.flip(); assertTypeDelimitedTag(definition); assertEquals(valToWrite.length() + TAG_SIZE + MIN_LENGTH_VAR_SIZE, bufferedData.readVarInt(false)); @@ -686,7 +687,7 @@ void testWriteOptionalString() throws IOException { @Test void testWriteOptionalString_null() throws IOException { FieldDefinition definition = createOptionalFieldDefinition(STRING); - writeOptionalString(bufferedData, definition, null); + writeOptionalString(bufferedData, definition, (String) null); bufferedData.flip(); assertEquals(0, bufferedData.length()); } @@ -1090,11 +1091,11 @@ void testSizeOfOneOfEnumList_empty() { @Test void testSizeOfStringList() { FieldDefinition definition = createFieldDefinition(STRING); - String str1 = randomVarSizeString(); - String str2 = randomVarSizeString(); + byte[] str1 = randomVarSizeString().getBytes(StandardCharsets.UTF_8); + byte[] str2 = randomVarSizeString().getBytes(StandardCharsets.UTF_8); assertEquals( - MIN_LENGTH_VAR_SIZE * 2 + TAG_SIZE * 2 + str1.length() + str2.length(), + MIN_LENGTH_VAR_SIZE * 2 + TAG_SIZE * 2 + str1.length + str2.length, sizeOfStringList(definition, asList(str1, str2))); } @@ -1102,7 +1103,7 @@ void testSizeOfStringList() { void testSizeOfStringList_nullAndEmpty() { FieldDefinition definition = createFieldDefinition(STRING); - assertEquals(MIN_LENGTH_VAR_SIZE * 2 + TAG_SIZE * 2, sizeOfStringList(definition, asList(null, ""))); + assertEquals(MIN_LENGTH_VAR_SIZE * 2 + TAG_SIZE * 2, sizeOfStringList(definition, asList(null, new byte[0]))); } @Test @@ -1186,7 +1187,7 @@ void testSizeOfString() { @Test void testSizeOfString_null() { final FieldDefinition definition = createFieldDefinition(STRING); - assertEquals(0, sizeOfString(definition, null)); + assertEquals(0, sizeOfString(definition, (byte[]) null)); } @Test @@ -1206,7 +1207,7 @@ void testSizeOfString_oneOf() { @Test void testSizeOfString_oneOf_null() { final FieldDefinition definition = createOneOfFieldDefinition(STRING); - assertEquals(MIN_LENGTH_VAR_SIZE + TAG_SIZE, sizeOfString(definition, null)); + assertEquals(MIN_LENGTH_VAR_SIZE + TAG_SIZE, sizeOfString(definition, (String) null)); } @Test @@ -1435,22 +1436,24 @@ private static Stream provideWriteUnpackedListArguments() { return Stream.of( Arguments.of( STRING, - (WriterMethod) (out, field, list) -> { + (WriterMethod) (out, field, list) -> { try { ProtoWriterTools.writeStringList(out, field, list); } catch (IOException e) { Sneaky.sneakyThrow(e); } }, - List.of("string 1", "testing here", "testing there"), - (ReaderMethod>) (BufferedData bd, long pos) -> { + List.of( + "string 1".getBytes(StandardCharsets.UTF_8), + "testing here".getBytes(StandardCharsets.UTF_8), + "testing there".getBytes(StandardCharsets.UTF_8)), + (ReaderMethod>) (BufferedData bd, long pos) -> { int size = bd.getVarInt(pos, false); int sizeOfSize = ProtoWriterTools.sizeOfVarInt32(size); return new UnpackedField<>( - new String( - bd.getBytes(pos + sizeOfSize, size).toByteArray(), StandardCharsets.UTF_8), - sizeOfSize + size); - }), + bd.getBytes(pos + sizeOfSize, size).toByteArray(), sizeOfSize + size); + }, + (BiConsumer) (a, b) -> assertArrayEquals(a, b)), Arguments.of( BYTES, (WriterMethod) (out, field, list) -> { @@ -1468,7 +1471,8 @@ private static Stream provideWriteUnpackedListArguments() { int size = bd.getVarInt(pos, false); int sizeOfSize = ProtoWriterTools.sizeOfVarInt32(size); return new UnpackedField<>(bd.getBytes(pos + sizeOfSize, size), sizeOfSize + size); - })); + }, + (BiConsumer) (a, b) -> assertEquals(a, b))); } @ParameterizedTest @@ -1477,7 +1481,8 @@ void testWriteUnpackedList( final FieldType type, final WriterMethod writerMethod, final List list, - final ReaderMethod> readerMethod) { + final ReaderMethod> readerMethod, + final BiConsumer assertMethod) { final FieldDefinition definition = createRepeatedFieldDefinition(type); final long start = bufferedData.position(); @@ -1491,7 +1496,7 @@ void testWriteUnpackedList( int sizeOfTag = ProtoWriterTools.sizeOfVarInt32(tag); UnpackedField value = readerMethod.read(bufferedData, start + offset + sizeOfTag); - assertEquals(list.get(i), value.value()); + assertMethod.accept(list.get(i), value.value()); offset += sizeOfTag + value.size(); } diff --git a/pbj-integration-tests/src/jmh/java/com/hedera/pbj/integration/jmh/utf8/Utf8Bench.java b/pbj-integration-tests/src/jmh/java/com/hedera/pbj/integration/jmh/utf8/Utf8Bench.java new file mode 100644 index 000000000..a3ff7b49d --- /dev/null +++ b/pbj-integration-tests/src/jmh/java/com/hedera/pbj/integration/jmh/utf8/Utf8Bench.java @@ -0,0 +1,320 @@ +// SPDX-License-Identifier: Apache-2.0 +package com.hedera.pbj.integration.jmh.utf8; + +import java.nio.charset.StandardCharsets; +import java.util.Random; +import java.util.concurrent.TimeUnit; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.infra.Blackhole; +import org.openjdk.jmh.runner.Runner; +import org.openjdk.jmh.runner.options.Options; +import org.openjdk.jmh.runner.options.OptionsBuilder; + +/** + * Benchmarks three implementations: Utf8ToolsV1, Utf8ToolsV2, Utf8ToolsV3. + * + *

Each implementation must provide:

+ *
+ *   public final class Utf8ToolsV* {
+ *       // private static int encodedLength(String) throws IOException {...} // (not used here)
+ *       public static String decodeUtf8(byte[] in, int offset, int length) throws java.io.IOException;
+ *       public static void encodeUtf8(String in, byte[] out, int offset) throws java.io.IOException;
+ *   }
+ * 
+ *

We precompute input strings and their UTF-8 byte[] using the JDK encoder in @Setup, + * and we preallocate output buffers sized to the exact UTF-8 length so the measured + * methods do not allocate (other than what the implementation itself does).

+ * + *

Make sure you run with JVM arg --add-opens java.base/java.lang=ALL-UNNAMED

+ */ +@SuppressWarnings("SameParameterValue") +@BenchmarkMode({Mode.AverageTime}) // ops/sec; switch to SampleTime if you want latency histograms +@OutputTimeUnit(TimeUnit.NANOSECONDS) +@Warmup(iterations = 3, time = 2) +@Measurement(iterations = 5, time = 3) +@Fork(value = 1) +@State(Scope.Thread) +public class Utf8Bench { + // -------------------- Parameters -------------------- + + /** Which implementation to run (maps to class switch below). */ + @Param({"V0", "V1", "V2", "V3", "V4"}) + // @Param({"V4"}) + public String impl; + + /** Dataset shape: ASCII-heavy, Latin-1, Mixed BMP, Emoji (surrogates). */ + @Param({"ascii", "latin1", "mixed", "emoji"}) + // @Param({"ascii"}) + public String dataset; + + /** Mean string length to generate. Actual strings vary around this length. */ + @Param({"8", "32", "100"}) + public int meanLen; + + /** Number of distinct strings in the corpus (cycled during the run). */ + @Param({"1024"}) + public int corpusSize; + + // -------------------- Corpus -------------------- + + private String[] strings; // inputs for encode & round-trip + private byte[][] utf8Bytes; // pre-encoded with JDK for decode benchmark + private byte[][] encodeBuffers; // sized exactly to utf8 length + private int[] utf8Lens; // lengths of utf8Bytes[i] + private int[] utf8LensJdk; // for correctness check of encodedLength + private int idxMask; // for cheap modulo (power-of-two corpus) + + private final Random rnd = new Random(4141684161512124L); + + public static void main(String[] args) throws Exception { + Options opt = + new OptionsBuilder().include(Utf8Bench.class.getSimpleName()).build(); + + new Runner(opt).run(); + } + + // -------------------- Lifecycle -------------------- + + @Setup(Level.Trial) + public void setUp() throws Exception { + // Ensure corpusSize is a power of two for cheap cycling + int pow2 = 1; + while (pow2 < corpusSize) pow2 <<= 1; + if (pow2 != corpusSize) { + corpusSize = pow2; // silently round up + } + idxMask = corpusSize - 1; + + strings = new String[corpusSize]; + utf8Bytes = new byte[corpusSize][]; + utf8Lens = new int[corpusSize]; + utf8LensJdk = new int[corpusSize]; + encodeBuffers = new byte[corpusSize][]; + + // Generate corpus + for (int i = 0; i < corpusSize; i++) { + String s = + switch (dataset) { + case "ascii" -> genAsciiString(meanLen, 0.50); + case "latin1" -> genLatin1String(meanLen, 0.50); // includes bytes 0x80..0xFF (→ 2-byte UTF-8) + case "mixed" -> genMixedBmpString(meanLen, 0.30, 0.10); // some non-ASCII BMP + case "emoji" -> genEmojiString(meanLen, 0.15); // surrogate pairs sprinkled in + default -> throw new IllegalArgumentException("Unknown dataset: " + dataset); + }; + strings[i] = s; + + // Pre-encode with the JDK for decode() input and for sizing encode buffers. + byte[] u = s.getBytes(StandardCharsets.UTF_8); + utf8Bytes[i] = u; + utf8Lens[i] = u.length; + utf8LensJdk[i] = u.length; + encodeBuffers[i] = new byte[u.length]; // exact size; offset=0 in benchmarks + } + + // Quick sanity: round-trip each impl once to catch broken code before timing + for (String version : new String[] {"V1", "V2", "V3"}) { + for (int i = 0; i < Math.min(corpusSize, 128); i++) { + String s = strings[i]; + byte[] buf = new byte[utf8Lens[i]]; + encode(version, s, buf, 0); + String back = decode(version, buf, 0, buf.length); + if (!s.equals(back)) { + throw new IllegalStateException(version + " failed round-trip on sample " + i); + } + } + } + + // Sanity: encodedLength must match JDK UTF-8 byte length + for (String version : new String[] {"V1", "V2", "V3"}) { + for (int i = 0; i < Math.min(corpusSize, 1024); i++) { + int len = strings[i].getBytes(java.nio.charset.StandardCharsets.UTF_8).length; + if (len != utf8LensJdk[i]) { + throw new IllegalStateException(version + " encodedLength mismatch at " + i + + " expected=" + utf8LensJdk[i] + " got=" + len + + " str=\"" + preview(strings[i]) + "\""); + } + } + } + } + + // -------------------- Benchmarks -------------------- + + private int cursor = 0; + + /** Encode String -> UTF-8 bytes into a pre-sized buffer. */ + @Benchmark + public void encode(Blackhole bh) throws Exception { + final int i = (cursor++) & idxMask; + final String s = strings[i]; + final byte[] out = encodeBuffers[i]; + encode(impl, s, out, 0); + // Consume a couple of bytes to keep JIT honest (avoid DCE): + bh.consume(out[0]); + bh.consume(out[out.length - 1]); + } + + /** Decode UTF-8 bytes -> String (input was pre-encoded by the JDK). */ + @Benchmark + public void decode(Blackhole bh) throws Exception { + final int i = (cursor++) & idxMask; + final byte[] in = utf8Bytes[i]; + final String s = decode(impl, in, 0, in.length); + // Consume length & first char (if present) to avoid DCE: + bh.consume(s.length()); + if (!s.isEmpty()) bh.consume(s.charAt(0)); + } + + /** Round-trip using the impl for both encode and decode (avoids JDK encoder in timed region). */ + @Benchmark + public void roundTrip(Blackhole bh) throws Exception { + final int i = (cursor++) & idxMask; + final String s = strings[i]; + final byte[] buf = encodeBuffers[i]; + encode(impl, s, buf, 0); + final String back = decode(impl, buf, 0, buf.length); + bh.consume(back.length()); + } + + /** Measure just the UTF-8 length computation (no encoding). */ + @Benchmark + public void encodedLength(Blackhole bh) { + final int i = (cursor++) & idxMask; + final String s = strings[i]; + final int len = encodedLength(impl, s); + // consume value and a char to discourage CSE / constant folding + bh.consume(len); + if (!s.isEmpty()) bh.consume(s.charAt(0)); + } + + // -------------------- Dispatch to implementations -------------------- + + // Replace these with your real classes (same static method signatures). + // Example: Utf8ToolsV1.encodeUtf8(in, out, off); + private static void encode(String version, String in, byte[] out, int off) throws Exception { + switch (version) { + case "V0" -> Utf8ToolsV0.encodeUtf8(in, out, off); + case "V1" -> Utf8ToolsV1.encodeUtf8(in, out, off); + case "V2" -> Utf8ToolsV2.encodeUtf8(in, out, off); + case "V3" -> Utf8ToolsV3.encodeUtf8(in, out, off); + case "V4" -> Utf8ToolsV4.encodeUtf8(in, out, off); + default -> throw new IllegalArgumentException(version); + } + } + + private static String decode(String version, byte[] in, int off, int len) throws Exception { + return switch (version) { + case "V0" -> Utf8ToolsV0.decodeUtf8(in, off, len); + case "V1" -> Utf8ToolsV1.decodeUtf8(in, off, len); + case "V2" -> Utf8ToolsV2.decodeUtf8(in, off, len); + case "V3" -> Utf8ToolsV3.decodeUtf8(in, off, len); + case "V4" -> Utf8ToolsV4.decodeUtf8(in, off, len); + default -> throw new IllegalArgumentException(version); + }; + } + + private static int encodedLength(String version, String s) { + try { + return switch (version) { + case "V0" -> Utf8ToolsV0.encodedLength(s); + case "V1" -> Utf8ToolsV1.encodedLength(s); + case "V2" -> Utf8ToolsV2.encodedLength(s); + case "V3" -> Utf8ToolsV3.encodedLength(s); + case "V4" -> Utf8ToolsV4.encodedLength(s); + default -> throw new IllegalArgumentException(version); + }; + } catch (java.io.IOException e) { + // Treat malformed handling as a failure in correctness check + throw new RuntimeException(e); + } + } + + // -------------------- Generators (fast & simple; deterministic-ish) -------------------- + + private String genAsciiString(int mean, double punctRatio) { + int len = jitteredLen(mean); + StringBuilder sb = new StringBuilder(len); + for (int i = 0; i < len; i++) { + if (rnd.nextDouble() < punctRatio) { + sb.append(" .,-_/+[]()".charAt(rnd.nextInt(11))); + } else { + char c = (char) ('a' + rnd.nextInt(26)); + if (rnd.nextBoolean()) c = Character.toUpperCase(c); + sb.append(c); + } + } + return sb.toString(); + } + + private String genLatin1String(int mean, double highRatio) { + int len = jitteredLen(mean); + StringBuilder sb = new StringBuilder(len); + for (int i = 0; i < len; i++) { + if (rnd.nextDouble() < highRatio) { + // 0x80..0xFF (valid Latin-1; forces 2-byte UTF-8) + sb.append((char) (0x80 + rnd.nextInt(0x80))); + } else { + sb.append((char) (' ' + rnd.nextInt(95))); // ASCII printable + } + } + return sb.toString(); + } + + private String genMixedBmpString(int mean, double nonAsciiRatio, double threeByteRatio) { + int len = jitteredLen(mean); + StringBuilder sb = new StringBuilder(len); + for (int i = 0; i < len; i++) { + double r = rnd.nextDouble(); + if (r < nonAsciiRatio) { + if (r < threeByteRatio) { + // 3-byte UTF-8 BMP range excluding surrogates (e.g., Greek/Cyrillic) + char c = (char) (0x0800 + rnd.nextInt(0xD7FF - 0x0800)); + sb.append(c); + } else { + // Latin-1 high bytes (2-byte UTF-8) + sb.append((char) (0x80 + rnd.nextInt(0x80))); + } + } else { + sb.append((char) (' ' + rnd.nextInt(95))); + } + } + return sb.toString(); + } + + private String genEmojiString(int mean, double emojiRatio) { + int len = jitteredLen(mean); + StringBuilder sb = new StringBuilder(len); + for (int i = 0; i < len; i++) { + if (rnd.nextDouble() < emojiRatio) { + // A few common emoji code points (U+1F3xx / U+1F60x / U+1F9xx) + int[] cps = {0x1F600, 0x1F602, 0x1F603, 0x1F60D, 0x1F680, 0x1F64C, 0x1F4AF, 0x1F3C3, 0x1F9E9}; + int cp = cps[rnd.nextInt(cps.length)]; + sb.appendCodePoint(cp); + } else { + sb.append((char) (' ' + rnd.nextInt(95))); + } + } + return sb.toString(); + } + + private int jitteredLen(int mean) { + // ±25% jitter around mean, min 1 + int span = Math.max(1, mean / 4); + return Math.max(1, mean - span + rnd.nextInt(2 * span + 1)); + } + + private static String preview(String s) { + if (s.length() <= 24) return s; + return s.substring(0, 24) + "…(" + s.length() + ")"; + } +} diff --git a/pbj-integration-tests/src/jmh/java/com/hedera/pbj/integration/jmh/utf8/Utf8ToolsV0.java b/pbj-integration-tests/src/jmh/java/com/hedera/pbj/integration/jmh/utf8/Utf8ToolsV0.java new file mode 100644 index 000000000..8ee46f0a8 --- /dev/null +++ b/pbj-integration-tests/src/jmh/java/com/hedera/pbj/integration/jmh/utf8/Utf8ToolsV0.java @@ -0,0 +1,22 @@ +// SPDX-License-Identifier: Apache-2.0 +package com.hedera.pbj.integration.jmh.utf8; + +/** + * UTF8 Tools based on java standard library + */ +@SuppressWarnings("DuplicatedCode") +public final class Utf8ToolsV0 { + public static int encodedLength(final String in) { + return in.getBytes(java.nio.charset.StandardCharsets.UTF_8).length; + } + + public static String decodeUtf8(byte[] in, int offset, int length) { + return new String(in, offset, length, java.nio.charset.StandardCharsets.UTF_8); + } + + public static int encodeUtf8(String in, byte[] out, int offset) { + byte[] b = in.getBytes(java.nio.charset.StandardCharsets.UTF_8); + System.arraycopy(b, 0, out, offset, b.length); + return b.length; + } +} diff --git a/pbj-integration-tests/src/jmh/java/com/hedera/pbj/integration/jmh/utf8/Utf8ToolsV1.java b/pbj-integration-tests/src/jmh/java/com/hedera/pbj/integration/jmh/utf8/Utf8ToolsV1.java new file mode 100644 index 000000000..ab5e83eec --- /dev/null +++ b/pbj-integration-tests/src/jmh/java/com/hedera/pbj/integration/jmh/utf8/Utf8ToolsV1.java @@ -0,0 +1,127 @@ +// SPDX-License-Identifier: Apache-2.0 +package com.hedera.pbj.integration.jmh.utf8; + +import static java.lang.Character.MAX_SURROGATE; +import static java.lang.Character.MIN_SUPPLEMENTARY_CODE_POINT; +import static java.lang.Character.MIN_SURROGATE; +import static java.lang.Character.isSurrogatePair; +import static java.lang.Character.toCodePoint; + +import com.hedera.pbj.runtime.MalformedProtobufException; +import com.hedera.pbj.runtime.io.WritableSequentialData; +import java.io.IOException; + +/** + * UTF8 tools based on protobuf standard library, so we are byte for byte identical + */ +@SuppressWarnings("DuplicatedCode") +public final class Utf8ToolsV1 { + + /** + * Returns the number of bytes in the UTF-8-encoded form of {@code sequence}. For a string, this + * method is equivalent to {@code string.getBytes(UTF_8).length}, but is more efficient in both + * time and space. + * + * @throws IllegalArgumentException if {@code sequence} contains ill-formed UTF-16 (unpaired + * surrogates) + */ + static int encodedLength(final String sequence) throws IOException { + if (sequence == null) { + return 0; + } + // Warning to maintainers: this implementation is highly optimized. + int utf16Length = sequence.length(); + int utf8Length = utf16Length; + int i = 0; + + // This loop optimizes for pure ASCII. + while (i < utf16Length && sequence.charAt(i) < 0x80) { + i++; + } + + // This loop optimizes for chars less than 0x800. + for (; i < utf16Length; i++) { + char c = sequence.charAt(i); + if (c < 0x800) { + utf8Length += ((0x7f - c) >>> 31); // branch free! + } else { + utf8Length += encodedLengthGeneral(sequence, i); + break; + } + } + + if (utf8Length < utf16Length) { + // Necessary and sufficient condition for overflow because of maximum 3x expansion + throw new IllegalArgumentException("UTF-8 length does not fit in int: " + (utf8Length + (1L << 32))); + } + return utf8Length; + } + + private static int encodedLengthGeneral(final CharSequence sequence, final int start) throws IOException { + int utf16Length = sequence.length(); + int utf8Length = 0; + for (int i = start; i < utf16Length; i++) { + char c = sequence.charAt(i); + if (c < 0x800) { + utf8Length += (0x7f - c) >>> 31; // branch free! + } else { + utf8Length += 2; + // jdk7+: if (Character.isSurrogate(c)) { + if (Character.MIN_SURROGATE <= c && c <= Character.MAX_SURROGATE) { + // Check that we have a well-formed surrogate pair. + int cp = Character.codePointAt(sequence, i); + if (cp < MIN_SUPPLEMENTARY_CODE_POINT) { + throw new MalformedProtobufException("Unpaired surrogate at index " + i + " of " + utf16Length); + } + i++; + } + } + } + return utf8Length; + } + + public static String decodeUtf8(byte[] in, int offset, int length) { + return new String(in, offset, length, java.nio.charset.StandardCharsets.UTF_8); + } + + /** + * Encodes the input character sequence to a {@link WritableSequentialData} using the same algorithm as protoc, so we are + * byte for byte the same. + */ + static void encodeUtf8(final String in, byte[] out, int off) throws IOException { + final int inLength = in.length(); + for (int inIx = 0; inIx < inLength; ++inIx) { + final char c = in.charAt(inIx); + if (c < 0x80) { + // One byte (0xxx xxxx) + out[off++] = (byte) c; + } else if (c < 0x800) { + // Two bytes (110x xxxx 10xx xxxx) + + // Benchmarks show put performs better than putShort here (for HotSpot). + out[off++] = (byte) (0xC0 | (c >>> 6)); + out[off++] = (byte) (0x80 | (0x3F & c)); + } else if (c < MIN_SURROGATE || MAX_SURROGATE < c) { + // Three bytes (1110 xxxx 10xx xxxx 10xx xxxx) + // Maximum single-char code point is 0xFFFF, 16 bits. + + // Benchmarks show put performs better than putShort here (for HotSpot). + out[off++] = (byte) (0xE0 | (c >>> 12)); + out[off++] = (byte) (0x80 | (0x3F & (c >>> 6))); + out[off++] = (byte) (0x80 | (0x3F & c)); + } else { + // Four bytes (1111 xxxx 10xx xxxx 10xx xxxx 10xx xxxx) + // Minimum code point represented by a surrogate pair is 0x10000, 17 bits, four UTF-8 bytes + final char low; + if (inIx + 1 == inLength || !isSurrogatePair(c, (low = in.charAt(++inIx)))) { + throw new MalformedProtobufException("Unpaired surrogate at index " + inIx + " of " + inLength); + } + int codePoint = toCodePoint(c, low); + out[off++] = (byte) ((0xF << 4) | (codePoint >>> 18)); + out[off++] = (byte) (0x80 | (0x3F & (codePoint >>> 12))); + out[off++] = (byte) (0x80 | (0x3F & (codePoint >>> 6))); + out[off++] = (byte) (0x80 | (0x3F & codePoint)); + } + } + } +} diff --git a/pbj-integration-tests/src/jmh/java/com/hedera/pbj/integration/jmh/utf8/Utf8ToolsV2.java b/pbj-integration-tests/src/jmh/java/com/hedera/pbj/integration/jmh/utf8/Utf8ToolsV2.java new file mode 100644 index 000000000..0af36f8e1 --- /dev/null +++ b/pbj-integration-tests/src/jmh/java/com/hedera/pbj/integration/jmh/utf8/Utf8ToolsV2.java @@ -0,0 +1,151 @@ +// SPDX-License-Identifier: Apache-2.0 +package com.hedera.pbj.integration.jmh.utf8; + +import static java.lang.Character.MIN_SUPPLEMENTARY_CODE_POINT; + +import com.hedera.pbj.runtime.MalformedProtobufException; +import edu.umd.cs.findbugs.annotations.NonNull; +import java.io.IOException; + +/** + * UTF8 tools based on protobuf standard library, so we are byte for byte identical + */ +public final class Utf8ToolsV2 { + + /** + * Returns the number of bytes in the UTF-8-encoded form of {@code sequence}. For a string, this + * method is equivalent to {@code string.getBytes(UTF_8).length}, but is more efficient in both + * time and space. + * + * @throws IllegalArgumentException if {@code sequence} contains ill-formed UTF-16 (unpaired + * surrogates) + */ + public static int encodedLength(final String in) throws IOException { + int len = 0; + for (int i = 0; i < in.length(); ) { + int codePoint = in.codePointAt(i); + if (codePoint <= 0x7F) { + len += 1; + } else if (codePoint <= 0x7FF) { + len += 2; + } else if (codePoint <= 0xFFFF) { + len += 3; + } else { + len += 4; + } + i += Character.charCount(codePoint); + } + return len; + // if (in == null) { + // return 0; + // } + // // Warning to maintainers: this implementation is highly optimized. + // int utf16Length = in.length(); + // int utf8Length = utf16Length; + // int i = 0; + // + // // This loop optimizes for pure ASCII. + // while (i < utf16Length && in.charAt(i) < 0x80) { + // i++; + // } + // + // // This loop optimizes for chars less than 0x800. + // for (; i < utf16Length; i++) { + // char c = in.charAt(i); + // if (c < 0x800) { + // utf8Length += ((0x7f - c) >>> 31); // branch free! + // } else { + // utf8Length += encodedLengthGeneral(in, i); + // break; + // } + // } + // + // if (utf8Length < utf16Length) { + // // Necessary and sufficient condition for overflow because of maximum 3x expansion + // throw new IllegalArgumentException("UTF-8 length does not fit in int: " + (utf8Length + (1L << + // 32))); + // } + // return utf8Length; + } + + private static int encodedLengthGeneral(final CharSequence sequence, final int start) throws IOException { + int utf16Length = sequence.length(); + int utf8Length = 0; + for (int i = start; i < utf16Length; i++) { + char c = sequence.charAt(i); + if (c < 0x800) { + utf8Length += (0x7f - c) >>> 31; // branch free! + } else { + utf8Length += 2; + // jdk7+: if (Character.isSurrogate(c)) { + if (Character.MIN_SURROGATE <= c && c <= Character.MAX_SURROGATE) { + // Check that we have a well-formed surrogate pair. + int cp = Character.codePointAt(sequence, i); + if (cp < MIN_SUPPLEMENTARY_CODE_POINT) { + throw new MalformedProtobufException("Unpaired surrogate at index " + i + " of " + utf16Length); + } + i++; + } + } + } + return utf8Length; + } + + public static String decodeUtf8(byte[] in, int offset, int length) { + return new String(in, offset, length, java.nio.charset.StandardCharsets.UTF_8); + } + + /** + * Encodes the input character sequence to a byte array using the same algorithm as protoc, so we are byte for + * byte the same. Returns the number of bytes written. + * + * @param out The byte array to write to + * @param offset The offset in the byte array to start writing at + * @param in The input character sequence to encode + * @return The number of bytes written + * @throws MalformedProtobufException if the input contains unpaired surrogates + */ + public static int encodeUtf8(final String in, @NonNull final byte[] out, final int offset) + throws MalformedProtobufException { + int utf16Length = in.length(); + int i = 0; + int j = offset; + // Designed to take advantage of + // https://wiki.openjdk.java.net/display/HotSpotInternals/RangeCheckElimination + for (char c; i < utf16Length && (c = in.charAt(i)) < 0x80; i++) { + out[j + i] = (byte) c; + } + if (i == utf16Length) { + return j + utf16Length - offset; + } + j += i; + for (char c; i < utf16Length; i++) { + c = in.charAt(i); + if (c < 0x80) { + out[j++] = (byte) c; + } else if (c < 0x800) { // 11 bits, two UTF-8 bytes + out[j++] = (byte) ((0xF << 6) | (c >>> 6)); + out[j++] = (byte) (0x80 | (0x3F & c)); + } else if ((c < Character.MIN_SURROGATE || Character.MAX_SURROGATE < c)) { + // Maximum single-char code point is 0xFFFF, 16 bits, three UTF-8 bytes + out[j++] = (byte) ((0xF << 5) | (c >>> 12)); + out[j++] = (byte) (0x80 | (0x3F & (c >>> 6))); + out[j++] = (byte) (0x80 | (0x3F & c)); + } else { + // Minimum code point represented by a surrogate pair is 0x10000, 17 bits, + // four UTF-8 bytes + final char low; + if (i + 1 == in.length() || !Character.isSurrogatePair(c, (low = in.charAt(++i)))) { + throw new MalformedProtobufException( + "Unpaired surrogate at index " + (i - 1) + " of " + utf16Length); + } + int codePoint = Character.toCodePoint(c, low); + out[j++] = (byte) ((0xF << 4) | (codePoint >>> 18)); + out[j++] = (byte) (0x80 | (0x3F & (codePoint >>> 12))); + out[j++] = (byte) (0x80 | (0x3F & (codePoint >>> 6))); + out[j++] = (byte) (0x80 | (0x3F & codePoint)); + } + } + return j - offset; + } +} diff --git a/pbj-integration-tests/src/jmh/java/com/hedera/pbj/integration/jmh/utf8/Utf8ToolsV3.java b/pbj-integration-tests/src/jmh/java/com/hedera/pbj/integration/jmh/utf8/Utf8ToolsV3.java new file mode 100644 index 000000000..64264bb2d --- /dev/null +++ b/pbj-integration-tests/src/jmh/java/com/hedera/pbj/integration/jmh/utf8/Utf8ToolsV3.java @@ -0,0 +1,278 @@ +// SPDX-License-Identifier: Apache-2.0 +package com.hedera.pbj.integration.jmh.utf8; + +import java.io.IOException; + +/** + * UTF8 tools based on protobuf standard library, so we are byte for byte identical + */ +public final class Utf8ToolsV3 { + + // ---------------------------------------------------------------------- + // Public API + // ---------------------------------------------------------------------- + + /** Strict UTF-8 decode. Throws IOException on malformed sequences. */ + public static String decodeUtf8(final byte[] in, final int offset, final int length) throws IOException { + if ((offset | length) < 0 || offset + length > in.length) { + throw new IndexOutOfBoundsException("decodeUtf8: bad offset/length"); + } + final int end = offset + length; + + // Fast ASCII-only pass; if all ASCII, build String directly. + int i = offset; + while (i < end && (in[i] >= 0)) i++; + if (i == end) { + // All ASCII + char[] chars = new char[length]; + for (int p = 0; p < length; p++) { + chars[p] = (char) (in[offset + p] & 0x7F); + } + return new String(chars); + } + + // Two-pass: count UTF-16 code units, then decode into a single char[]. + final int charCount = countUtf16UnitsStrict(in, offset, end); + final char[] out = new char[charCount]; + + // Decode pass + int outPos = 0; + i = offset; + while (i < end) { + int b0 = in[i++] & 0xFF; + if (b0 < 0x80) { + out[outPos++] = (char) b0; + // try to chew a few ASCII in a tight loop + while (i < end && (in[i] >= 0)) { + out[outPos++] = (char) (in[i++] & 0x7F); + } + continue; + } + if ((b0 & 0xE0) == 0xC0) { // 2-byte + if (i >= end) throw bad(); + int b1 = in[i++] & 0xFF; + if ((b1 & 0xC0) != 0x80) throw bad(); + int cp = ((b0 & 0x1F) << 6) | (b1 & 0x3F); + // reject overlongs: must be >= 0x80; also b0 >= 0xC2 + if (cp < 0x80 || b0 < 0xC2) throw bad(); + out[outPos++] = (char) cp; + continue; + } + if ((b0 & 0xF0) == 0xE0) { // 3-byte + if (i + 1 >= end) throw bad(); + int b1 = in[i++] & 0xFF; + int b2 = in[i++] & 0xFF; + if ((b1 & 0xC0) != 0x80 || (b2 & 0xC0) != 0x80) throw bad(); + + // E0: b1 >= 0xA0 to avoid overlong; ED: b1 <= 0x9F to avoid surrogates + if (b0 == 0xE0 && b1 < 0xA0) throw bad(); + if (b0 == 0xED && b1 >= 0xA0) throw bad(); + + int cp = ((b0 & 0x0F) << 12) | ((b1 & 0x3F) << 6) | (b2 & 0x3F); + if (cp >= 0xD800 && cp <= 0xDFFF) throw bad(); // no UTF-8-encoded surrogates + if (cp < 0x800) throw bad(); // overlong (should have been 2-byte) + out[outPos++] = (char) cp; + continue; + } + if ((b0 & 0xF8) == 0xF0) { // 4-byte + if (i + 2 >= end) throw bad(); + int b1 = in[i++] & 0xFF; + int b2 = in[i++] & 0xFF; + int b3 = in[i++] & 0xFF; + if ((b1 & 0xC0) != 0x80 || (b2 & 0xC0) != 0x80 || (b3 & 0xC0) != 0x80) throw bad(); + + // F0: b1 >= 0x90 (avoid overlong); F4: b1 <= 0x8F (max U+10FFFF); F5..FF invalid + if (b0 == 0xF0 && b1 < 0x90) throw bad(); + if (b0 > 0xF4 || (b0 == 0xF4 && b1 > 0x8F)) throw bad(); + + int cp = ((b0 & 0x07) << 18) | ((b1 & 0x3F) << 12) | ((b2 & 0x3F) << 6) | (b3 & 0x3F); + if (cp < 0x10000 || cp > 0x10FFFF) throw bad(); + + // Encode as surrogate pair in UTF-16 + int hi = ((cp - 0x10000) >>> 10) + 0xD800; + int lo = ((cp - 0x10000) & 0x3FF) + 0xDC00; + out[outPos++] = (char) hi; + out[outPos++] = (char) lo; + continue; + } + throw bad(); + } + + return new String(out); + } + + /** + * Encodes {@code in} to UTF-8 into {@code out} starting at {@code offset}. + * Returns the number of bytes written. Throws IOException if {@code out} + * does not have enough space or on invalid surrogate usage. + */ + public static int encodeUtf8(final String in, final byte[] out, final int offset) throws IOException { + if (in == null) throw new NullPointerException("in"); + if (offset < 0 || offset > out.length) throw new IndexOutOfBoundsException("encodeUtf8: bad offset"); + + final int need = encodedLength(in); // also validates surrogate pairs + if (out.length - offset < need) { + throw new IOException("encodeUtf8: insufficient space: need=" + need + " have=" + (out.length - offset)); + } + + int pos = offset; + final int n = in.length(); + int i = 0; + + // ASCII fast path (eat a run) + while (i < n) { + char c = in.charAt(i); + if (c <= 0x7F) { + out[pos++] = (byte) c; + i++; + while (i < n) { + char d = in.charAt(i); + if (d > 0x7F) break; + out[pos++] = (byte) d; + i++; + } + continue; + } + + if (c <= 0x7FF) { + out[pos++] = (byte) (0xC0 | (c >>> 6)); + out[pos++] = (byte) (0x80 | (c & 0x3F)); + i++; + continue; + } + + if (Character.isHighSurrogate(c)) { + if (i + 1 >= n) throw new IOException("encodeUtf8: unpaired high surrogate at end"); + char d = in.charAt(i + 1); + if (!Character.isLowSurrogate(d)) throw new IOException("encodeUtf8: unpaired high surrogate"); + int cp = Character.toCodePoint(c, d); + // 4-byte + out[pos++] = (byte) (0xF0 | (cp >>> 18)); + out[pos++] = (byte) (0x80 | ((cp >>> 12) & 0x3F)); + out[pos++] = (byte) (0x80 | ((cp >>> 6) & 0x3F)); + out[pos++] = (byte) (0x80 | (cp & 0x3F)); + i += 2; + continue; + } + + if (Character.isLowSurrogate(c)) { + throw new IOException("encodeUtf8: unpaired low surrogate"); + } + + // 3-byte (BMP non-surrogate) + out[pos++] = (byte) (0xE0 | (c >>> 12)); + out[pos++] = (byte) (0x80 | ((c >>> 6) & 0x3F)); + out[pos++] = (byte) (0x80 | (c & 0x3F)); + i++; + } + return pos - offset; + } + + // ---------------------------------------------------------------------- + // Private helpers + // ---------------------------------------------------------------------- + + /** Computes the exact number of UTF-8 bytes required for {@code str}. Validates surrogate pairing. */ + public static int encodedLength(final String str) throws IOException { + final int n = str.length(); + int len = 0; + int i = 0; + + // Fast ASCII prefix + while (i < n) { + char c = str.charAt(i); + if (c > 0x7F) break; + len++; + i++; + // nibble a few ASCII in a burst + while (i < n) { + char d = str.charAt(i); + if (d > 0x7F) break; + len++; + i++; + } + } + + while (i < n) { + char c = str.charAt(i++); + if (c <= 0x7F) { + len += 1; + } else if (c <= 0x7FF) { + len += 2; + } else if (Character.isHighSurrogate(c)) { + if (i >= n) throw new IOException("encodedLength: unpaired high surrogate at end"); + char d = str.charAt(i); + if (!Character.isLowSurrogate(d)) throw new IOException("encodedLength: unpaired high surrogate"); + i++; // consume pair + len += 4; + } else if (Character.isLowSurrogate(c)) { + throw new IOException("encodedLength: unpaired low surrogate"); + } else { + len += 3; + } + } + return len; + } + + /** Counts UTF-16 code units produced by decoding strict UTF-8 in in[offset..end). */ + private static int countUtf16UnitsStrict(final byte[] in, final int offset, final int end) throws IOException { + int i = offset; + int count = 0; + + while (i < end) { + int b0 = in[i++] & 0xFF; + if (b0 < 0x80) { + count += 1; + // run of ASCII + while (i < end && (in[i] >= 0)) { + i++; + count++; + } + continue; + } + + if ((b0 & 0xE0) == 0xC0) { // 2-byte + if (i >= end) throw bad(); + int b1 = in[i++] & 0xFF; + if ((b1 & 0xC0) != 0x80) throw bad(); + int cp = ((b0 & 0x1F) << 6) | (b1 & 0x3F); + if (cp < 0x80 || b0 < 0xC2) throw bad(); // overlong / invalid + count += 1; + continue; + } + + if ((b0 & 0xF0) == 0xE0) { // 3-byte + if (i + 1 >= end) throw bad(); + int b1 = in[i++] & 0xFF; + int b2 = in[i++] & 0xFF; + if ((b1 & 0xC0) != 0x80 || (b2 & 0xC0) != 0x80) throw bad(); + if (b0 == 0xE0 && b1 < 0xA0) throw bad(); // overlong + if (b0 == 0xED && b1 >= 0xA0) throw bad(); // surrogate range + int cp = ((b0 & 0x0F) << 12) | ((b1 & 0x3F) << 6) | (b2 & 0x3F); + if (cp < 0x800) throw bad(); // overlong + if (cp >= 0xD800 && cp <= 0xDFFF) throw bad(); // encoded surrogate + count += 1; + continue; + } + + if ((b0 & 0xF8) == 0xF0) { // 4-byte + if (i + 2 >= end) throw bad(); + int b1 = in[i++] & 0xFF; + int b2 = in[i++] & 0xFF; + int b3 = in[i++] & 0xFF; + if ((b1 & 0xC0) != 0x80 || (b2 & 0xC0) != 0x80 || (b3 & 0xC0) != 0x80) throw bad(); + if (b0 == 0xF0 && b1 < 0x90) throw bad(); // overlong + if (b0 > 0xF4 || (b0 == 0xF4 && b1 > 0x8F)) throw bad(); // > U+10FFFF + count += 2; // surrogate pair in UTF-16 + continue; + } + + throw bad(); + } + return count; + } + + private static IOException bad() { + return new IOException("Malformed UTF-8 input"); + } +} diff --git a/pbj-integration-tests/src/jmh/java/com/hedera/pbj/integration/jmh/utf8/Utf8ToolsV4.java b/pbj-integration-tests/src/jmh/java/com/hedera/pbj/integration/jmh/utf8/Utf8ToolsV4.java new file mode 100644 index 000000000..c108490ab --- /dev/null +++ b/pbj-integration-tests/src/jmh/java/com/hedera/pbj/integration/jmh/utf8/Utf8ToolsV4.java @@ -0,0 +1,337 @@ +// SPDX-License-Identifier: Apache-2.0 +package com.hedera.pbj.integration.jmh.utf8; + +import java.io.IOException; +import java.lang.invoke.MethodHandles; +import java.lang.invoke.VarHandle; +import java.util.Objects; + +/** + * UTF8 tools based on protobuf standard library, so we are byte for byte identical + */ +public final class Utf8ToolsV4 { + + // ---- Internal fast-path plumbing --------------------------------------------------------- + + // Compact Strings: coder == 0 (LATIN1), 1 (UTF16) + private static final byte CODER_LATIN1 = 0; + private static final byte CODER_UTF16 = 1; + + // Guard: if these are non-null, we can fast-path on String internals. + private static final VarHandle STRING_VALUE_VH; + private static final VarHandle STRING_CODER_VH; + private static final boolean HAS_STRING_VH; + + static { + VarHandle v = null, c = null; + boolean ok = false; + try { + MethodHandles.Lookup l = MethodHandles.privateLookupIn(String.class, MethodHandles.lookup()); + v = l.findVarHandle(String.class, "value", byte[].class); + c = l.findVarHandle(String.class, "coder", byte.class); + ok = (v != null && c != null); + } catch (Throwable ignore) { + ignore.printStackTrace(); + ok = false; + } + STRING_VALUE_VH = v; + STRING_CODER_VH = c; + HAS_STRING_VH = ok; + } + + // Helpers + private static byte[] stringValueBytes(String s) { + return (byte[]) STRING_VALUE_VH.get(s); + } + + private static byte stringCoder(String s) { + return (byte) STRING_CODER_VH.get(s); + } + + // ---- Public API -------------------------------------------------------------------------- + + /** Strict UTF-8 decode. Throws IOException on malformed sequences. */ + public static String decodeUtf8(final byte[] in, final int offset, final int length) throws IOException { + if ((offset | length) < 0 || offset + length > in.length) { + throw new IndexOutOfBoundsException("decodeUtf8: bad offset/length"); + } + final int end = offset + length; + + // ASCII run fast path + int i = offset; + while (i < end && in[i] >= 0) i++; + if (i == end) { + char[] chars = new char[length]; + for (int p = 0; p < length; p++) chars[p] = (char) (in[offset + p] & 0x7F); + return new String(chars); + } + + // Count UTF-16 units (strict validation) + final int charCount = countUtf16UnitsStrict(in, offset, end); + final char[] out = new char[charCount]; + + // Decode + int outPos = 0; + i = offset; + while (i < end) { + int b0 = in[i++] & 0xFF; + if (b0 < 0x80) { + out[outPos++] = (char) b0; + while (i < end && in[i] >= 0) out[outPos++] = (char) (in[i++] & 0x7F); + continue; + } + if ((b0 & 0xE0) == 0xC0) { + if (i >= end) throw bad(); + int b1 = in[i++] & 0xFF; + if ((b1 & 0xC0) != 0x80) throw bad(); + int cp = ((b0 & 0x1F) << 6) | (b1 & 0x3F); + if (cp < 0x80 || b0 < 0xC2) throw bad(); // overlong + out[outPos++] = (char) cp; + continue; + } + if ((b0 & 0xF0) == 0xE0) { + if (i + 1 >= end) throw bad(); + int b1 = in[i++] & 0xFF, b2 = in[i++] & 0xFF; + if ((b1 & 0xC0) != 0x80 || (b2 & 0xC0) != 0x80) throw bad(); + if (b0 == 0xE0 && b1 < 0xA0) throw bad(); + if (b0 == 0xED && b1 >= 0xA0) throw bad(); + int cp = ((b0 & 0x0F) << 12) | ((b1 & 0x3F) << 6) | (b2 & 0x3F); + if (cp < 0x800 || (cp >= 0xD800 && cp <= 0xDFFF)) throw bad(); + out[outPos++] = (char) cp; + continue; + } + if ((b0 & 0xF8) == 0xF0) { + if (i + 2 >= end) throw bad(); + int b1 = in[i++] & 0xFF, b2 = in[i++] & 0xFF, b3 = in[i++] & 0xFF; + if ((b1 & 0xC0) != 0x80 || (b2 & 0xC0) != 0x80 || (b3 & 0xC0) != 0x80) throw bad(); + if (b0 == 0xF0 && b1 < 0x90) throw bad(); + if (b0 > 0xF4 || (b0 == 0xF4 && b1 > 0x8F)) throw bad(); + int cp = ((b0 & 0x07) << 18) | ((b1 & 0x3F) << 12) | ((b2 & 0x3F) << 6) | (b3 & 0x3F); + if (cp < 0x10000 || cp > 0x10FFFF) throw bad(); + int hi = ((cp - 0x10000) >>> 10) + 0xD800; + int lo = ((cp - 0x10000) & 0x3FF) + 0xDC00; + out[outPos++] = (char) hi; + out[outPos++] = (char) lo; + continue; + } + throw bad(); + } + return new String(out); + } + + /** + * Encodes {@code in} to UTF-8 into {@code out} at {@code offset}. + * Returns bytes written. Throws if insufficient space or malformed surrogates. + * Uses a VarHandle fast path for String LATIN1 (Compact String) when available. + */ + public static int encodeUtf8(final String in, final byte[] out, final int offset) throws IOException { + Objects.requireNonNull(in, "in"); + // Try the internal LATIN1 fast path + if (HAS_STRING_VH && stringCoder(in) == CODER_LATIN1) { + final byte[] v = stringValueBytes(in); + // Quick ASCII check; if all ASCII, we can memcpy directly. + int i = 0, n = v.length; + while (i < n && (v[i] & 0x80) == 0) i++; + if (i == n) { + // Pure ASCII: exact size == n + if (out.length - offset < n) throw new IOException("encodeUtf8: insufficient space (ASCII path)"); + System.arraycopy(v, 0, out, offset, n); + return n; + } + // Mixed Latin-1: encode in one pass. Worst-case length = n (ASCII) + 2*(non-ascii bytes) + // Exact length is computed below (without allocating). + return encodeLatin1ToUtf8(v, out, offset); + } + // Portable path (handles UTF16 coder as well) + return encodePortable(in, out, offset); + } + + /** + * Returns the exact UTF-8 byte length for {@code str}. Validates surrogate pairing. + * Uses internal LATIN1 fast path if available. + */ + public static int encodedLength(final String str) throws IOException { + if (HAS_STRING_VH && stringCoder(str) == CODER_LATIN1) { + return encodedLengthLatin1(stringValueBytes(str)); + } + // Portable path for UTF16 (or if internals not available) + final int n = str.length(); + int len = 0; + int i = 0; + + // ASCII prefix + while (i < n && str.charAt(i) <= 0x7F) { + len++; + i++; + } + while (i < n) { + char c = str.charAt(i++); + if (c <= 0x7F) { + len += 1; + } else if (c <= 0x7FF) { + len += 2; + } else if (Character.isHighSurrogate(c)) { + if (i >= n) throw new IOException("encodedLength: unpaired high surrogate at end"); + char d = str.charAt(i); + if (!Character.isLowSurrogate(d)) throw new IOException("encodedLength: unpaired high surrogate"); + i++; + len += 4; + } else if (Character.isLowSurrogate(c)) { + throw new IOException("encodedLength: unpaired low surrogate"); + } else { + len += 3; + } + } + return len; + } + + // ---- Internal fast-path (LATIN1 String.value) -------------------------------------------- + + /** Exact UTF-8 length for a LATIN1 byte[] (no surrogates exist in LATIN1). */ + private static int encodedLengthLatin1(final byte[] latin1) { + int ascii = 0, hi = 0; // count ASCII vs high bytes + for (byte b : latin1) { + if ((b & 0x80) == 0) ascii++; + else hi++; + } + // ASCII -> 1 byte; high bytes (0x80..0xFF) -> 2-byte UTF-8 + return ascii + (hi << 1); + } + + /** Encodes a LATIN1 byte[] directly to UTF-8. Returns bytes written. */ + private static int encodeLatin1ToUtf8(final byte[] latin1, final byte[] out, final int offset) { + int pos = offset; + int i = 0, n = latin1.length; + + // ASCII run + while (i < n && (latin1[i] & 0x80) == 0) { + out[pos++] = latin1[i++]; + while (i < n && (latin1[i] & 0x80) == 0) { + out[pos++] = latin1[i++]; + } + // then fall into non-ASCII handling if applicable + } + + while (i < n) { + int b = latin1[i++] & 0xFF; + if ((b & 0x80) == 0) { + // ASCII + out[pos++] = (byte) b; + } else { + // LATIN1 0x80..0xFF -> two-byte UTF-8: 0xC2/0xC3 prefix depending on top bit of 0x80..0xFF + // Values 0x80..0xBF => 0xC2 xx ; 0xC0..0xFF => 0xC3 (b - 0x40) + if (b < 0xC0) { + out[pos++] = (byte) 0xC2; + out[pos++] = (byte) b; + } else { + out[pos++] = (byte) 0xC3; + out[pos++] = (byte) (b - 0x40); // (b & 0x3F) | 0x80 + } + } + } + return pos - offset; + } + + // ---- Portable encode (UTF-16 String) ----------------------------------------------------- + + private static int encodePortable(final String in, final byte[] out, final int offset) throws IOException { + int pos = offset; + final int n = in.length(); + int i = 0; + + // ASCII fast path + while (i < n) { + char c = in.charAt(i); + if (c <= 0x7F) { + out[pos++] = (byte) c; + i++; + while (i < n) { + char d = in.charAt(i); + if (d > 0x7F) break; + out[pos++] = (byte) d; + i++; + } + continue; + } + if (c <= 0x7FF) { + out[pos++] = (byte) (0xC0 | (c >>> 6)); + out[pos++] = (byte) (0x80 | (c & 0x3F)); + i++; + continue; + } + if (Character.isHighSurrogate(c)) { + if (i + 1 >= n) throw new IOException("encodeUtf8: unpaired high surrogate at end"); + char d = in.charAt(i + 1); + if (!Character.isLowSurrogate(d)) throw new IOException("encodeUtf8: unpaired high surrogate"); + int cp = Character.toCodePoint(c, d); + out[pos++] = (byte) (0xF0 | (cp >>> 18)); + out[pos++] = (byte) (0x80 | ((cp >>> 12) & 0x3F)); + out[pos++] = (byte) (0x80 | ((cp >>> 6) & 0x3F)); + out[pos++] = (byte) (0x80 | (cp & 0x3F)); + i += 2; + continue; + } + if (Character.isLowSurrogate(c)) { + throw new IOException("encodeUtf8: unpaired low surrogate"); + } + out[pos++] = (byte) (0xE0 | (c >>> 12)); + out[pos++] = (byte) (0x80 | ((c >>> 6) & 0x3F)); + out[pos++] = (byte) (0x80 | (c & 0x3F)); + i++; + } + return pos - offset; + } + + // ---- Strict counting for decode ---------------------------------------------------------- + + private static int countUtf16UnitsStrict(final byte[] in, final int offset, final int end) throws IOException { + int i = offset, count = 0; + while (i < end) { + int b0 = in[i++] & 0xFF; + if (b0 < 0x80) { + count += 1; + while (i < end && in[i] >= 0) { + i++; + count++; + } + continue; + } + if ((b0 & 0xE0) == 0xC0) { + if (i >= end) throw bad(); + int b1 = in[i++] & 0xFF; + if ((b1 & 0xC0) != 0x80) throw bad(); + if (b0 < 0xC2) throw bad(); // overlong + count += 1; + continue; + } + if ((b0 & 0xF0) == 0xE0) { + if (i + 1 >= end) throw bad(); + int b1 = in[i++] & 0xFF, b2 = in[i++] & 0xFF; + if ((b1 & 0xC0) != 0x80 || (b2 & 0xC0) != 0x80) throw bad(); + if (b0 == 0xE0 && b1 < 0xA0) throw bad(); + if (b0 == 0xED && b1 >= 0xA0) throw bad(); + count += 1; + continue; + } + if ((b0 & 0xF8) == 0xF0) { + if (i + 2 >= end) throw bad(); + int b1 = in[i++] & 0xFF, b2 = in[i++] & 0xFF, b3 = in[i++] & 0xFF; + if ((b1 & 0xC0) != 0x80 || (b2 & 0xC0) != 0x80 || (b3 & 0xC0) != 0x80) throw bad(); + if (b0 == 0xF0 && b1 < 0x90) throw bad(); + if (b0 > 0xF4 || (b0 == 0xF4 && b1 > 0x8F)) throw bad(); + count += 2; // surrogate pair + continue; + } + throw bad(); + } + return count; + } + + private static IOException bad() { + return new IOException("Malformed UTF-8"); + } + + public static void main(String[] args) throws IOException { + encodeUtf8("hello", new byte[10], 0); + } +} diff --git a/pbj-integration-tests/src/main/java/com/hedera/pbj/integration/fuzz/SingleFuzzTest.java b/pbj-integration-tests/src/main/java/com/hedera/pbj/integration/fuzz/SingleFuzzTest.java index c2065e12b..33873afc9 100644 --- a/pbj-integration-tests/src/main/java/com/hedera/pbj/integration/fuzz/SingleFuzzTest.java +++ b/pbj-integration-tests/src/main/java/com/hedera/pbj/integration/fuzz/SingleFuzzTest.java @@ -1,6 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 package com.hedera.pbj.integration.fuzz; +import com.google.protobuf.InvalidProtocolBufferException; import com.hedera.pbj.runtime.Codec; import com.hedera.pbj.runtime.io.buffer.BufferedData; import java.io.InputStream; @@ -97,6 +98,15 @@ private static void tryProtocParser( } } catch (Exception ex) { // Protoc didn't like the bytes. + // However, Protoc always parses UTF-8 and fails on invalid chars. PBJ may skip UTF-8 encoding until later. + for (Throwable t = ex; t != null; t = t.getCause()) { + if (t instanceof InvalidProtocolBufferException + && t.getMessage() != null + && t.getMessage().contains("Protocol message had invalid UTF-8")) { + return; + } + } + if (doThrow) { throw new FuzzTestException( prefix + "Protoc threw an exception " diff --git a/pbj-integration-tests/src/main/proto/extendedUtf8StingTest.proto b/pbj-integration-tests/src/main/proto/extendedUtf8StingTest.proto index bd1a249c4..6c0e5208d 100644 --- a/pbj-integration-tests/src/main/proto/extendedUtf8StingTest.proto +++ b/pbj-integration-tests/src/main/proto/extendedUtf8StingTest.proto @@ -27,12 +27,31 @@ option java_package = "com.hedera.pbj.test.proto.java"; option java_multiple_files = true; // <<>> This comment is special code for setting PBJ Compiler java package +import "google/protobuf/wrappers.proto"; + /** * Simple message with a string for extended UTF8 testing */ +// <<>> message MessageWithString { /** * A single string for extended testing */ string aTestString = 1; } + +// <<>> +message MessageWithBoxedString { + google.protobuf.StringValue boxedString = 5; +} + +message MessageWithRepeatedString { + repeated string repeatedText = 1; +} + +// <<>> +message MessageWithOneOfString { + oneof oneofExample { + string text = 1; + } +} \ No newline at end of file diff --git a/pbj-integration-tests/src/test/java/com/hedera/pbj/integration/test/HashEqualsTest.java b/pbj-integration-tests/src/test/java/com/hedera/pbj/integration/test/HashEqualsTest.java index 482220fe2..69e2dab40 100644 --- a/pbj-integration-tests/src/test/java/com/hedera/pbj/integration/test/HashEqualsTest.java +++ b/pbj-integration-tests/src/test/java/com/hedera/pbj/integration/test/HashEqualsTest.java @@ -4,9 +4,14 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import com.hedera.pbj.test.proto.pbj.MessageWithBoxedString; +import com.hedera.pbj.test.proto.pbj.MessageWithRepeatedString; +import com.hedera.pbj.test.proto.pbj.MessageWithString; import com.hedera.pbj.test.proto.pbj.TimestampTest; import com.hedera.pbj.test.proto.pbj.TimestampTest2; +import java.util.List; import org.junit.jupiter.api.Test; class HashEqualsTest { @@ -89,4 +94,39 @@ void differentObjectsWithNoDefaulHashCode4() { assertNotEquals(tst.hashCode(), tst2.hashCode()); } + + @Test + void testStrings() { + // new String() to ensure we actually create a brand-new string instance + final MessageWithString msg1 = new MessageWithString(new String("test")); + // Same characters, but a brand-new string instance again + final MessageWithString msg2 = new MessageWithString(new String("test")); + + assertEquals(msg1.hashCode(), msg2.hashCode()); + assertTrue(msg1.equals(msg2)); + } + + @Test + void testBoxedStrings() { + // new String() to ensure we actually create a brand-new string instance + final MessageWithBoxedString msg1 = new MessageWithBoxedString(new String("test")); + // Same characters, but a brand-new string instance again + final MessageWithBoxedString msg2 = new MessageWithBoxedString(new String("test")); + + assertEquals(msg1.hashCode(), msg2.hashCode()); + assertTrue(msg1.equals(msg2)); + } + + @Test + void testRepeatedStrings() { + // new String() to ensure we actually create a brand-new string instance + final MessageWithRepeatedString msg1 = + new MessageWithRepeatedString(List.of(new String("test1"), new String("test2"))); + // Same characters, but a brand-new string instance again + final MessageWithRepeatedString msg2 = + new MessageWithRepeatedString(List.of(new String("test1"), new String("test2"))); + + assertEquals(msg1.hashCode(), msg2.hashCode()); + assertTrue(msg1.equals(msg2)); + } } diff --git a/pbj-integration-tests/src/test/java/com/hedera/pbj/integration/test/SampleFuzzTest.java b/pbj-integration-tests/src/test/java/com/hedera/pbj/integration/test/SampleFuzzTest.java index 5a2011218..ba2b2f2cd 100644 --- a/pbj-integration-tests/src/test/java/com/hedera/pbj/integration/test/SampleFuzzTest.java +++ b/pbj-integration-tests/src/test/java/com/hedera/pbj/integration/test/SampleFuzzTest.java @@ -60,7 +60,7 @@ public class SampleFuzzTest { * The fuzz test as a whole is considered passed * if that many individual model tests pass. */ - private static final double PASS_RATE_THRESHOLD = 1.; + private static final double PASS_RATE_THRESHOLD = .997; /** * A threshold for the mean value of the shares of DESERIALIZATION_FAILED @@ -70,7 +70,7 @@ public class SampleFuzzTest { * if the mean value of all the individual DESERIALIZATION_FAILED * shares is greater than this threshold. */ - private static final double DESERIALIZATION_FAILED_MEAN_THRESHOLD = .9829; + private static final double DESERIALIZATION_FAILED_MEAN_THRESHOLD = .9824; /** * Fuzz tests are tagged with this tag to allow Gradle/JUnit From 67e27462e91990c8fe15bbbfb96e6a2d19bf3695 Mon Sep 17 00:00:00 2001 From: Anthony Petrov Date: Fri, 3 Oct 2025 11:23:23 -0700 Subject: [PATCH 2/2] address comments, fix previous merge, spotless Signed-off-by: Anthony Petrov --- .../json/JsonCodecParseMethodGenerator.java | 41 ++++++++++--------- .../protobuf/CodecParseMethodGenerator.java | 7 ++-- .../pbj/integration/jmh/utf8/Utf8ToolsV2.java | 30 -------------- 3 files changed, 26 insertions(+), 52 deletions(-) diff --git a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/json/JsonCodecParseMethodGenerator.java b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/json/JsonCodecParseMethodGenerator.java index 31b28228f..847e8988e 100644 --- a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/json/JsonCodecParseMethodGenerator.java +++ b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/json/JsonCodecParseMethodGenerator.java @@ -157,7 +157,7 @@ private static void generateFieldCaseStatement( final StringBuilder origSB, final Field field, final String valueGetter) { final StringBuilder sb = new StringBuilder(); final boolean isMapField = field instanceof SingleField && ((SingleField) field).isMapField(); - final boolean isMapFieldOrOneOf = isMapField || field.parent() != null; + final boolean isNotMapFieldOrOneOf = !isMapField && field.parent() == null; if (field.repeated()) { if (field.type() == Field.FieldType.MESSAGE) { sb.append(("parseObjArray(checkSize(\"$fieldName\", $valueGetter.arr().value(), $maxSize), " @@ -175,12 +175,13 @@ private static void generateFieldCaseStatement( case FLOAT -> sb.append("parseFloat(v)"); case DOUBLE -> sb.append("parseDouble(v)"); case STRING -> - sb.append((isMapFieldOrOneOf ? "toUtf8Bytes(" : "") + - "unescape(checkSize(\"$fieldName\", v.STRING().getText(), $maxSize))" - .replace("$maxSize", field.maxSize() >= 0 ? String.valueOf(field.maxSize()) : "maxSize") - .replace("$fieldName", field.name()) + - (isMapFieldOrOneOf ? ")" : "") - ); + sb.append((isNotMapFieldOrOneOf ? "toUtf8Bytes(" : "") + + "unescape(checkSize(\"$fieldName\", v.STRING().getText(), $maxSize))" + .replace( + "$maxSize", + field.maxSize() >= 0 ? String.valueOf(field.maxSize()) : "maxSize") + .replace("$fieldName", field.name()) + + (isNotMapFieldOrOneOf ? ")" : "")); case BOOL -> sb.append("parseBoolean(v)"); // maxSize * 2 - because Base64. The *2 math isn't precise, but it's good enough for our purposes. @@ -202,12 +203,13 @@ private static void generateFieldCaseStatement( case "FloatValue" -> sb.append("parseFloat($valueGetter)"); case "DoubleValue" -> sb.append("parseDouble($valueGetter)"); case "StringValue" -> - sb.append((isMapFieldOrOneOf ? "toUtf8Bytes(" : "") + - "unescape(checkSize(\"$fieldName\", $valueGetter.STRING().getText(), $maxSize))" - .replace("$maxSize", field.maxSize() >= 0 ? String.valueOf(field.maxSize()) : "maxSize") - .replace("$fieldName", field.name()) + - (isMapFieldOrOneOf ? ")" : "") - ); + sb.append((isNotMapFieldOrOneOf ? "toUtf8Bytes(" : "") + + "unescape(checkSize(\"$fieldName\", $valueGetter.STRING().getText(), $maxSize))" + .replace( + "$maxSize", + field.maxSize() >= 0 ? String.valueOf(field.maxSize()) : "maxSize") + .replace("$fieldName", field.name()) + + (isNotMapFieldOrOneOf ? ")" : "")); case "BoolValue" -> sb.append("parseBoolean($valueGetter)"); // maxSize * 2 - because Base64. The *2 math isn't precise, but it's good enough for our purposes: @@ -252,12 +254,13 @@ private static void generateFieldCaseStatement( case FLOAT -> sb.append("parseFloat($valueGetter)"); case DOUBLE -> sb.append("parseDouble($valueGetter)"); case STRING -> - sb.append((isMapFieldOrOneOf ? "toUtf8Bytes(" : "") + - "unescape(checkSize(\"$fieldName\", $valueGetter.STRING().getText(), $maxSize))" - .replace("$maxSize", field.maxSize() >= 0 ? String.valueOf(field.maxSize()) : "maxSize") - .replace("$fieldName", field.name()) + - (isMapFieldOrOneOf ? ")" : "") - ); + sb.append((isNotMapFieldOrOneOf ? "toUtf8Bytes(" : "") + + "unescape(checkSize(\"$fieldName\", $valueGetter.STRING().getText(), $maxSize))" + .replace( + "$maxSize", + field.maxSize() >= 0 ? String.valueOf(field.maxSize()) : "maxSize") + .replace("$fieldName", field.name()) + + (isNotMapFieldOrOneOf ? ")" : "")); case BOOL -> sb.append("parseBoolean($valueGetter)"); // maxSize * 2 - because Base64. The *2 math isn't precise, but it's good enough for our purposes: diff --git a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/protobuf/CodecParseMethodGenerator.java b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/protobuf/CodecParseMethodGenerator.java index baa5787d7..8e1882a6b 100644 --- a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/protobuf/CodecParseMethodGenerator.java +++ b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/protobuf/CodecParseMethodGenerator.java @@ -498,9 +498,10 @@ static String readMethod(Field field) { case FIXED64 -> "readFixed64(input)"; case SFIXED64 -> "readSignedFixed64(input)"; case STRING -> - "readString%s(input, %s)".formatted( - field.hasDifferentStorageType() ? "Raw" : "", - field.maxSize() >= 0 ? String.valueOf(field.maxSize()) : "maxSize"); + "readString%s(input, %s)" + .formatted( + field.hasDifferentStorageType() ? "Raw" : "", + field.maxSize() >= 0 ? String.valueOf(field.maxSize()) : "maxSize"); case BOOL -> "readBool(input)"; case BYTES -> "readBytes(input, %s)".formatted(field.maxSize() >= 0 ? String.valueOf(field.maxSize()) : "maxSize"); diff --git a/pbj-integration-tests/src/jmh/java/com/hedera/pbj/integration/jmh/utf8/Utf8ToolsV2.java b/pbj-integration-tests/src/jmh/java/com/hedera/pbj/integration/jmh/utf8/Utf8ToolsV2.java index 0af36f8e1..e863200f2 100644 --- a/pbj-integration-tests/src/jmh/java/com/hedera/pbj/integration/jmh/utf8/Utf8ToolsV2.java +++ b/pbj-integration-tests/src/jmh/java/com/hedera/pbj/integration/jmh/utf8/Utf8ToolsV2.java @@ -36,36 +36,6 @@ public static int encodedLength(final String in) throws IOException { i += Character.charCount(codePoint); } return len; - // if (in == null) { - // return 0; - // } - // // Warning to maintainers: this implementation is highly optimized. - // int utf16Length = in.length(); - // int utf8Length = utf16Length; - // int i = 0; - // - // // This loop optimizes for pure ASCII. - // while (i < utf16Length && in.charAt(i) < 0x80) { - // i++; - // } - // - // // This loop optimizes for chars less than 0x800. - // for (; i < utf16Length; i++) { - // char c = in.charAt(i); - // if (c < 0x800) { - // utf8Length += ((0x7f - c) >>> 31); // branch free! - // } else { - // utf8Length += encodedLengthGeneral(in, i); - // break; - // } - // } - // - // if (utf8Length < utf16Length) { - // // Necessary and sufficient condition for overflow because of maximum 3x expansion - // throw new IllegalArgumentException("UTF-8 length does not fit in int: " + (utf8Length + (1L << - // 32))); - // } - // return utf8Length; } private static int encodedLengthGeneral(final CharSequence sequence, final int start) throws IOException {