diff --git a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/Common.java b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/Common.java index b8b6be1f..66ecaeae 100644 --- a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/Common.java +++ b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/Common.java @@ -1,203 +1,704 @@ +/* + * Copyright (C) 2023 Hedera Hashgraph, LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package com.hedera.pbj.compiler.impl; import com.hedera.pbj.compiler.impl.grammar.Protobuf3Parser; import java.io.File; -import java.util.Arrays; -import java.util.List; -import java.util.Objects; +import java.util.*; import java.util.stream.Collectors; +import org.jetbrains.annotations.NotNull; + /** * Common functions and constants for code generation */ @SuppressWarnings({"DuplicatedCode", "EscapedSpace"}) public final class Common { - /** The indent for fields, default 4 spaces */ - public static final String FIELD_INDENT = " ".repeat(4); - - /** Number of bits used to represent the tag type */ - static final int TAG_TYPE_BITS = 3; - - /** Wire format code for var int */ - public static final int TYPE_VARINT = 0; - /** Wire format code for fixed 64bit number */ - public static final int TYPE_FIXED64 = 1; - /** Wire format code for length delimited, all the complex types */ - public static final int TYPE_LENGTH_DELIMITED = 2; - /** Wire format code for fixed 32bit number */ - public static final int TYPE_FIXED32 = 5; - - - /** - * Makes a tag value given a field number and wire type. - * - * @param wireType the wire type part of tag - * @param fieldNumber the field number part of tag - * @return packed encoded tag - */ - public static int getTag(final int wireType, final int fieldNumber) { - return (fieldNumber << TAG_TYPE_BITS) | wireType; - } - - /** - * Make sure first character of a string is upper case - * - * @param name string input who's first character can be upper or lower case - * @return name with first character converted to upper case - */ - public static String capitalizeFirstLetter(String name) { - if (name.length() > 0) { - if (name.chars().allMatch(Character::isUpperCase)) { - return Character.toUpperCase(name.charAt(0)) + name.substring(1).toLowerCase(); - } else { - return Character.toUpperCase(name.charAt(0)) + name.substring(1); - } - } - return name; - } - - /** - * Convert names like "hello_world" to "HelloWorld" or "helloWorld" depending on firstUpper. Also handles special case - * like "HELLO_WORLD" to same output as "hello_world", while "HelloWorld_Two" still becomes "helloWorldTwo". - * - * @param name input name in snake case - * @param firstUpper if true then first char is upper case otherwise it is lower - * @return out name in camel case - */ - public static String snakeToCamel(String name, boolean firstUpper) { - final String out = Arrays.stream(name.split("_")).map(Common::capitalizeFirstLetter).collect( - Collectors.joining("")); - return (firstUpper ? Character.toUpperCase(out.charAt(0)) : Character.toLowerCase(out.charAt(0)) ) - + out.substring(1); - } - - /** - * Convert a camel case name to upper case snake case - * - * @param name the input name in camel case - * @return output name in upper snake case - */ - public static String camelToUpperSnake(String name) { - // check if already camel upper - if (name.chars().allMatch(c -> Character.isUpperCase(c) || Character.isDigit(c) || c == '_')) return name; - // check if already has underscores, then just capitalize - if (name.chars().anyMatch(c -> c == '_')) return name.toUpperCase(); - // else convert - final StringBuilder buf = new StringBuilder(); - for (int i = 0; i < name.length(); i++) { - final char c = name.charAt(i); - if (Character.isUpperCase(c) && i > 0) { - buf.append("_"); - buf.append(c); - } else { - buf.append(Character.toUpperCase(c)); - } - } - // fix special case for capital ID - return buf.toString().replaceAll("_I_D", "_ID"); - } - - /** - * Build a clean java doc comment for a field - * - * @param fieldNumber The field proto number - * @param docContext The parsed field comment contact - * @return clean comment - */ - public static String buildCleanFieldJavaDoc(int fieldNumber, Protobuf3Parser.DocCommentContext docContext) { - final String cleanedComment = docContext == null ? "" : cleanJavaDocComment(docContext.getText()); - final String fieldNumComment = "("+fieldNumber+") "; - return fieldNumComment + cleanedComment; - } - - /** - * Build a clean java doc comment for an oneof field - * - * @param fieldNumbers The field proto numbers for all fields in oneof - * @param docContext The parsed field comment contact - * @return clean comment - */ - public static String buildCleanFieldJavaDoc(List fieldNumbers, Protobuf3Parser.DocCommentContext docContext) { - final String cleanedComment = docContext == null ? "" : cleanJavaDocComment(docContext.getText()); - final String fieldNumComment = - "("+fieldNumbers.stream().map(Objects::toString).collect(Collectors.joining(", "))+") "; - return fieldNumComment + cleanedComment; - } - - /** - * Clean up a java doc style comment removing all the "*" etc. - * - * @param fieldComment raw Java doc style comment - * @return clean multi-line content of the comment - */ - public static String cleanJavaDocComment(String fieldComment) { - return cleanDocStr(fieldComment - .replaceAll("/\\*\\*[\n\r\s\t]*\\*[\t\s]*|[\n\r\s\t]*\\*/","") // remove java doc - .replaceAll("\n\s+\\*\s+","\n") // remove indenting and * - .replaceAll("/\\*\\*","") // remove indenting and /** at beginning of comment. - .trim() // Remove leading and trailing spaces. - ); - } - - /** - * Clean a string so that it can be included in JavaDoc. Does things like replace unsupported HTML tags. - * - * @param docStr The string to clean - * @return cleaned output - */ - public static String cleanDocStr(String docStr) { - return docStr - .replaceAll("<(/?)tt>", "<$1code>") // tt tags are not supported in javadoc - .replaceAll(" < ", " < ") // escape loose less than - .replaceAll(" > ", " > ") // escape loose less than - .replaceAll(" & ", " & ") // escape loose less than - ; - } - - /** - * Convert a field type like "long" to the Java object wrapper type "Long", or pass though if not java primitive - * - * @param primitiveFieldType java field type like "int" etc - * @return java object wrapper type like "Integer" or pass though - */ - public static String javaPrimitiveToObjectType(String primitiveFieldType) { - return switch(primitiveFieldType){ - case "boolean" -> "Boolean"; - case "int" -> "Integer"; - case "long" -> "Long"; - case "float" -> "Float"; - case "double" -> "Double"; - default -> primitiveFieldType; - }; - } - - /** - * Remove leading dot from a string so ".a.b.c" becomes "a.b.c" - * - * @param text text to remove leading dot from - * @return text without a leading dot - */ - public static String removingLeadingDot(String text) { - if (text.length() > 0 & text.charAt(0) == '.') { - return text.substring(1); - } - return text; - } - - /** - * Get the java file for a src directory, package and classname with optional suffix. All parent directories will - * also be created. - * - * @param srcDir The src dir root of all java src - * @param javaPackage the java package with '.' deliminators - * @param className the camel case class name - * @return File object for java file - */ - public static File getJavaFile(File srcDir, String javaPackage, String className) { - File packagePath = new File(srcDir.getPath() + File.separatorChar + javaPackage.replaceAll("\\.","\\" + File.separator)); - //noinspection ResultOfMethodCallIgnored - packagePath.mkdirs(); - return new File(packagePath,className+".java"); - } + /** + * The indent for fields, default 4 spaces + */ + public static final String FIELD_INDENT = " ".repeat(4); + + /** + * Number of bits used to represent the tag type + */ + static final int TAG_TYPE_BITS = 3; + + /** + * Wire format code for var int + */ + public static final int TYPE_VARINT = 0; + /** + * Wire format code for fixed 64bit number + */ + public static final int TYPE_FIXED64 = 1; + /** + * Wire format code for length delimited, all the complex types + */ + public static final int TYPE_LENGTH_DELIMITED = 2; + /** + * Wire format code for fixed 32bit number + */ + public static final int TYPE_FIXED32 = 5; + + /** + * Makes a tag value given a field number and wire type. + * + * @param wireType the wire type part of tag + * @param fieldNumber the field number part of tag + * @return packed encoded tag + */ + public static int getTag(final int wireType, final int fieldNumber) { + return (fieldNumber << TAG_TYPE_BITS) | wireType; + } + + /** + * Make sure first character of a string is upper case + * + * @param name string input who's first character can be upper or lower case + * @return name with first character converted to upper case + */ + public static String capitalizeFirstLetter(String name) { + if (name.length() > 0) { + if (name.chars().allMatch(Character::isUpperCase)) { + return Character.toUpperCase(name.charAt(0)) + name.substring(1).toLowerCase(); + } else { + return Character.toUpperCase(name.charAt(0)) + name.substring(1); + } + } + return name; + } + + /** + * Convert names like "hello_world" to "HelloWorld" or "helloWorld" depending on firstUpper. + * Also handles special case like "HELLO_WORLD" to same output as "hello_world", while + * "HelloWorld_Two" still becomes "helloWorldTwo". + * + * @param name input name in snake case + * @param firstUpper if true then first char is upper case otherwise it is lower + * @return out name in camel case + */ + public static String snakeToCamel(String name, boolean firstUpper) { + final String out = + Arrays.stream(name.split("_")) + .map(Common::capitalizeFirstLetter) + .collect(Collectors.joining("")); + return (firstUpper + ? Character.toUpperCase(out.charAt(0)) + : Character.toLowerCase(out.charAt(0))) + + out.substring(1); + } + + /** + * Convert a camel case name to upper case snake case + * + * @param name the input name in camel case + * @return output name in upper snake case + */ + public static String camelToUpperSnake(String name) { + // check if already camel upper + if (name.chars() + .allMatch(c -> Character.isUpperCase(c) || Character.isDigit(c) || c == '_')) + return name; + // check if already has underscores, then just capitalize + if (name.contains("_")) return name.toUpperCase(); + // else convert + final StringBuilder buf = new StringBuilder(); + for (int i = 0; i < name.length(); i++) { + final char c = name.charAt(i); + if (Character.isUpperCase(c) && i > 0) { + buf.append("_"); + buf.append(c); + } else { + buf.append(Character.toUpperCase(c)); + } + } + // fix special case for capital ID + return buf.toString().replaceAll("_I_D", "_ID"); + } + + /** + * Build a clean java doc comment for a field + * + * @param fieldNumber The field proto number + * @param docContext The parsed field comment contact + * @return clean comment + */ + public static String buildCleanFieldJavaDoc( + int fieldNumber, Protobuf3Parser.DocCommentContext docContext) { + final String cleanedComment = + docContext == null ? "" : cleanJavaDocComment(docContext.getText()); + final String fieldNumComment = "(" + fieldNumber + ") "; + return fieldNumComment + cleanedComment; + } + + /** + * Build a clean java doc comment for an oneof field + * + * @param fieldNumbers The field proto numbers for all fields in oneof + * @param docContext The parsed field comment contact + * @return clean comment + */ + public static String buildCleanFieldJavaDoc( + List fieldNumbers, Protobuf3Parser.DocCommentContext docContext) { + final String cleanedComment = + docContext == null ? "" : cleanJavaDocComment(docContext.getText()); + final String fieldNumComment = + "(" + + fieldNumbers.stream() + .map(Objects::toString) + .collect(Collectors.joining(", ")) + + ") "; + return fieldNumComment + cleanedComment; + } + + /** + * Clean up a java doc style comment removing all the "*" etc. + * + * @param fieldComment raw Java doc style comment + * @return clean multi-line content of the comment + */ + public static String cleanJavaDocComment(String fieldComment) { + return cleanDocStr( + fieldComment + .replaceAll( + "/\\*\\*[\n\r\s\t]*\\*[\t\s]*|[\n\r\s\t]*\\*/", + "") // remove java doc + .replaceAll("\n\s+\\*\s+", "\n") // remove indenting and * + .replaceAll( + "/\\*\\*", "") // remove indenting and /** at beginning of comment. + .trim() // Remove leading and trailing spaces. + ); + } + + /** + * Clean a string so that it can be included in JavaDoc. Does things like replace unsupported + * HTML tags. + * + * @param docStr The string to clean + * @return cleaned output + */ + public static String cleanDocStr(String docStr) { + return docStr.replaceAll("<(/?)tt>", "<$1code>") // tt tags are not supported in javadoc + .replaceAll(" < ", " < ") // escape loose less than + .replaceAll(" > ", " > ") // escape loose less than + .replaceAll(" & ", " & ") // escape loose less than + ; + } + + /** + * Convert a field type like "long" to the Java object wrapper type "Long", or pass though if + * not java primitive + * + * @param primitiveFieldType java field type like "int" etc + * @return java object wrapper type like "Integer" or pass though + */ + public static String javaPrimitiveToObjectType(String primitiveFieldType) { + return switch (primitiveFieldType) { + case "boolean" -> "Boolean"; + case "int" -> "Integer"; + case "long" -> "Long"; + case "float" -> "Float"; + case "double" -> "Double"; + default -> primitiveFieldType; + }; + } + + /** + * Recursively calculates the hashcode for a message fields. + * + * @param fields The fields of this object. + * @param generatedCodeSoFar The accumulated hash code so far. + * @return The generated code for getting the hashCode value. + */ + public static String getFieldsHashCode(final List fields, String generatedCodeSoFar) + throws RuntimeException { + for (Field f : fields) { + if (f.parent() != null) { + final OneOfField oneOfField = f.parent(); + generatedCodeSoFar += getFieldsHashCode(oneOfField.fields(), generatedCodeSoFar); + } + + if (f.optionalValueType()) { + generatedCodeSoFar = getOptionalHashCodeGeneration(generatedCodeSoFar, f); + } else if (f.repeated()) { + generatedCodeSoFar = getRepeatedHashCodeGeneration(generatedCodeSoFar, f); + } else if (f.nameCamelFirstLower() != null) { + if (f.type() == Field.FieldType.FIXED32 + || f.type() == Field.FieldType.INT32 + || f.type() == Field.FieldType.SFIXED32 + || f.type() == Field.FieldType.SINT32 + || f.type() == Field.FieldType.UINT32) { + generatedCodeSoFar += + (FIELD_INDENT + + """ + if ($fieldName != DEFAULT.$fieldName) { + result = 31 * result + Integer.hashCode($fieldName); + } + """) + .replace("$fieldName", f.nameCamelFirstLower()); + } else if (f.type() == Field.FieldType.FIXED64 + || f.type() == Field.FieldType.INT64 + || f.type() == Field.FieldType.SFIXED64 + || f.type() == Field.FieldType.SINT64 + || f.type() == Field.FieldType.UINT64) { + generatedCodeSoFar += + FIELD_INDENT + + """ + if ($fieldName != DEFAULT.$fieldName) { + result = 31 * result + Long.hashCode($fieldName); + } + """ + .replace("$fieldName", f.nameCamelFirstLower()); + } else if (f.type() == Field.FieldType.ENUM) { + generatedCodeSoFar += + FIELD_INDENT + + """ + if ($fieldName != DEFAULT.$fieldName) { + result = 31 * result + $fieldName.hashCode(); + } + """ + .replace("$fieldName", f.nameCamelFirstLower()); + } else if (f.type() == Field.FieldType.BOOL) { + generatedCodeSoFar += + FIELD_INDENT + + """ + if ($fieldName != DEFAULT.$fieldName) { + result = 31 * result + Boolean.hashCode($fieldName); + } + """ + .replace("$fieldName", f.nameCamelFirstLower()); + } else if (f.type() == Field.FieldType.FLOAT) { + generatedCodeSoFar += + FIELD_INDENT + + """ + if ($fieldName != DEFAULT.$fieldName) { + result = 31 * result + Float.hashCode($fieldName); + } + """ + .replace("$fieldName", f.nameCamelFirstLower()); + } else if (f.type() == Field.FieldType.DOUBLE) { + generatedCodeSoFar += + FIELD_INDENT + + """ + if ($fieldName != DEFAULT.$fieldName) { + result = 31 * result + Double.hashCode($fieldName); + } + """ + .replace("$fieldName", f.nameCamelFirstLower()); + } else if (f.type() == Field.FieldType.STRING) { + generatedCodeSoFar += + FIELD_INDENT + + """ + if ($fieldName != DEFAULT.$fieldName) { + result = 31 * result + $fieldName.hashCode(); + } + """ + .replace("$fieldName", f.nameCamelFirstLower()); + } else if (f.type() == Field.FieldType.BYTES) { + generatedCodeSoFar += + FIELD_INDENT + + """ + if ($fieldName != DEFAULT.$fieldName) { + result = 31 * result + ($fieldName == null ? 0 : $fieldName.hashCode()); + } + """ + .replace("$fieldName", f.nameCamelFirstLower()); + } else if (f.parent() == null) { // process sub message + generatedCodeSoFar += + FIELD_INDENT + + """ + if ($fieldName != null && $fieldName != DEFAULT.$fieldName) { + result = 31 * result + $fieldName.hashCode(); + } + """ + .replace("$fieldName", f.nameCamelFirstLower()); + } else { + throw new RuntimeException( + "Unexpected field type for getting HashCode - " + f.type().toString()); + } + } + } + return generatedCodeSoFar; + } + + /** + * Get the hashcode codegen for a optional field. + * + * @param generatedCodeSoFar The string that the codegen is generated into. + * @param f The field for which to generate the hash code. + * @return Updated codegen string. + */ + @NotNull + private static String getOptionalHashCodeGeneration(String generatedCodeSoFar, Field f) { + switch (f.messageType()) { + case "StringValue" -> generatedCodeSoFar += + FIELD_INDENT + + """ + if ($fieldName != DEFAULT.$fieldName) { + result = 31 * result + $fieldName.hashCode(); + } + """ + .replace("$fieldName", f.nameCamelFirstLower()); + case "BoolValue" -> generatedCodeSoFar += + (FIELD_INDENT + + """ + if ($fieldName != DEFAULT.$fieldName) { + result = 31 * result + Boolean.hashCode($fieldName); + } + """) + .replace("$fieldName", f.nameCamelFirstLower()); + case "Int32Value", "UInt32Value" -> generatedCodeSoFar += + FIELD_INDENT + + """ + if ($fieldName != DEFAULT.$fieldName) { + result = 31 * result + Integer.hashCode($fieldName); + } + """ + .replace("$fieldName", f.nameCamelFirstLower()); + case "Int64Value", "UInt64Value" -> generatedCodeSoFar += + FIELD_INDENT + + """ + if ($fieldName != DEFAULT.$fieldName) { + result = 31 * result + Long.hashCode($fieldName); + } + """ + .replace("$fieldName", f.nameCamelFirstLower()); + case "FloatValue" -> generatedCodeSoFar += + FIELD_INDENT + + """ + if ($fieldName != DEFAULT.$fieldName) { + result = 31 * result + Float.hashCode($fieldName); + } + """ + .replace("$fieldName", f.nameCamelFirstLower()); + case "DoubleValue" -> generatedCodeSoFar += + FIELD_INDENT + + """ + if ($fieldName != DEFAULT.$fieldName) { + result = 31 * result + Double.hashCode($fieldName); + } + """ + .replace("$fieldName", f.nameCamelFirstLower()); + case "BytesValue" -> generatedCodeSoFar += + FIELD_INDENT + + """ + if ($fieldName != DEFAULT.$fieldName) { + result = 31 * result + ($fieldName == null ? 0 : $fieldName.hashCode()); + } + """ + .replace("$fieldName", f.nameCamelFirstLower()); + default -> throw new UnsupportedOperationException( + "Unhandled optional message type:" + f.messageType()); + } + return generatedCodeSoFar; + } + + /** + * Get the hashcode codegen for a repeated field. + * + * @param generatedCodeSoFar The string that the codegen is generated into. + * @param f The field for which to generate the hash code. + * @return Updated codegen string. + */ + @NotNull + private static String getRepeatedHashCodeGeneration(String generatedCodeSoFar, Field f) { + generatedCodeSoFar += + FIELD_INDENT + + """ + java.util.List list$$fieldName = $fieldName; + for (Object o : list$$fieldName) { + if ($fieldName != DEFAULT.$fieldName) { + if (o != null) { + result = 31 * result; + } + else { + result = 31 * result + o.hashCode(); + } + } + } + """ + .replace("$fieldName", f.nameCamelFirstLower()); + return generatedCodeSoFar; + } + + /** + * Recursively calculates the hashcode for a message fields. + * + * @param fields The fields of this object. + * @param generatedCodeSoFar The accumulated hash code so far. + * @return The generated code for getting the hashCode value. + */ + public static String getFieldsEqualsStatements( + final List fields, String generatedCodeSoFar) { + for (Field f : fields) { + if (f.parent() != null) { + final OneOfField oneOfField = f.parent(); + generatedCodeSoFar += + getFieldsEqualsStatements(oneOfField.fields(), generatedCodeSoFar); + } + + if (f.optionalValueType()) { + generatedCodeSoFar = getOptionalEqualsGeneration(generatedCodeSoFar, f); + } else if (f.repeated()) { + generatedCodeSoFar = getRepeatedEqualsGeneration(generatedCodeSoFar, f); + } else if (f.nameCamelFirstLower() != null) { + if (f.type() == Field.FieldType.FIXED32 + || f.type() == Field.FieldType.INT32 + || f.type() == Field.FieldType.SFIXED32 + || f.type() == Field.FieldType.SINT32 + || f.type() == Field.FieldType.UINT32) { + generatedCodeSoFar += + FIELD_INDENT + + """ + if ($fieldName != thatObj.$fieldName) { + return false; + } + """ + .replace("$fieldName", f.nameCamelFirstLower()); + } else if (f.type() == Field.FieldType.FIXED64 + || f.type() == Field.FieldType.INT64 + || f.type() == Field.FieldType.SFIXED64 + || f.type() == Field.FieldType.SINT64 + || f.type() == Field.FieldType.UINT64) { + generatedCodeSoFar += + FIELD_INDENT + + """ + if ($fieldName != thatObj.$fieldName) { + return false; + } + """ + .replace("$fieldName", f.nameCamelFirstLower()); + } else if (f.type() == Field.FieldType.BOOL) { + generatedCodeSoFar += + FIELD_INDENT + + """ + if ($fieldName != thatObj.$fieldName) { + return false; + } + """ + .replace("$fieldName", f.nameCamelFirstLower()); + } else if (f.type() == Field.FieldType.DOUBLE) { + generatedCodeSoFar += + FIELD_INDENT + + """ + if ($fieldName != thatObj.$fieldName) { + return false; + } + """ + .replace("$fieldName", f.nameCamelFirstLower()); + } else if (f.type() == Field.FieldType.FLOAT) { + generatedCodeSoFar += + FIELD_INDENT + + """ + if ($fieldName != thatObj.$fieldName) { + return false; + } + """ + .replace("$fieldName", f.nameCamelFirstLower()); + } else if (f.type() == Field.FieldType.DOUBLE) { + generatedCodeSoFar += + FIELD_INDENT + + """ + if ($fieldName != thatObj.$fieldName) { + return false; + } + """ + .replace("$fieldName", f.nameCamelFirstLower()); + } else if (f.type() == Field.FieldType.STRING + || f.type() == Field.FieldType.BYTES + || f.type() == Field.FieldType.ENUM + || f.parent() == null /* Process a sub-message */) { + generatedCodeSoFar += + FIELD_INDENT + + """ + if ($fieldName == null && thatObj.$fieldName != null) { + return false; + } + if ($fieldName != null && !$fieldName.equals(thatObj.$fieldName)) { + return false; + } + """ + .replace("$fieldName", f.nameCamelFirstLower()); + } else { + throw new RuntimeException( + "Unexpected field type for getting HashCode - " + f.type().toString()); + } + } + } + return generatedCodeSoFar; + } + + /** + * Get the equals codegen for a optional field. + * + * @param generatedCodeSoFar The string that the codegen is generated into. + * @param f The field for which to generate the equals code. + * @return Updated codegen string. + */ + @NotNull + private static String getOptionalEqualsGeneration(String generatedCodeSoFar, Field f) { + switch (f.messageType()) { + case "StringValue" -> generatedCodeSoFar += + FIELD_INDENT + + """ + if (this.$fieldName == null && thatObj.$fieldName != null) { + return false; + } + if (this.$fieldName != null && !$fieldName.equals(thatObj.$fieldName)) { + return false; + } + """ + .replace("$fieldName", f.nameCamelFirstLower()); + case "BoolValue" -> generatedCodeSoFar += + FIELD_INDENT + + """ + if ($fieldName instanceof Object) { + if (this.$fieldName == null && thatObj.$fieldName != null) { + return false; + } + if (!$fieldName.equals(thatObj.$fieldName)) { + return false; + } + } + else if ($fieldName != thatObj.$fieldName) { + return false; + } + """ + .replace("$fieldName", f.nameCamelFirstLower()); + case "Int32Value", "UInt32Value", "Int64Value", "UInt64Value" -> generatedCodeSoFar += + FIELD_INDENT + + """ + if ($fieldName instanceof Object) { + if (this.$fieldName == null && thatObj.$fieldName != null) { + return false; + } + if (this.$fieldName != null && !$fieldName.equals(thatObj.$fieldName)) { + return false; + } + } + else if ($fieldName != thatObj.$fieldName) { + return false; + } + """ + .replace("$fieldName", f.nameCamelFirstLower()); + case "FloatValue" -> generatedCodeSoFar += + FIELD_INDENT + + """ + if ($fieldName instanceof Object) { + if (this.$fieldName == null && thatObj.$fieldName != null) { + return false; + } + if (this.$fieldName != null && !$fieldName.equals(thatObj.$fieldName)) { + return false; + } + } + else if ($fieldName != thatObj.$fieldName) { + return false; + } + """ + .replace("$fieldName", f.nameCamelFirstLower()); + case "DoubleValue" -> generatedCodeSoFar += + FIELD_INDENT + + """ + if ($fieldName instanceof Object) { + if (this.$fieldName == null && thatObj.$fieldName != null) { + return false; + } + if (this.$fieldName != null && !$fieldName.equals(thatObj.$fieldName)) { + return false; + } + } + else if ($fieldName != thatObj.$fieldName) { + return false; + } + """ + .replace("$fieldName", f.nameCamelFirstLower()); + case "BytesValue" -> generatedCodeSoFar += + FIELD_INDENT + + """ + if (this.$fieldName == null && thatObj.$fieldName != null) { + return false; + } + if (this.$fieldName != null && !$fieldName.equals(thatObj.$fieldName)) { + return false; + } + """ + .replace("$fieldName", f.nameCamelFirstLower()); + default -> throw new UnsupportedOperationException( + "Unhandled optional message type:" + f.messageType()); + } + return generatedCodeSoFar; + } + + /** + * Get the equals codegen for a repeated field. + * + * @param generatedCodeSoFar The string that the codegen is generated into. + * @param f The field for which to generate the equals code. + * @return Updated codegen string. + */ + @NotNull + private static String getRepeatedEqualsGeneration(String generatedCodeSoFar, Field f) { + generatedCodeSoFar += + """ + if (this.$fieldName == null && this.$fieldName != null) { + return false; + } + + if (this.$fieldName != null && !$fieldName.equals(thatObj.$fieldName)) { + return false; + } + """ + .replace("$fieldName", f.nameCamelFirstLower()); + return generatedCodeSoFar; + } + + /** + * Remove leading dot from a string so ".a.b.c" becomes "a.b.c" + * + * @param text text to remove leading dot from + * @return text without a leading dot + */ + public static String removingLeadingDot(String text) { + if (text.length() > 0 & text.charAt(0) == '.') { + return text.substring(1); + } + return text; + } + + /** + * Get the java file for a src directory, package and classname with optional suffix. All parent + * directories will also be created. + * + * @param srcDir The src dir root of all java src + * @param javaPackage the java package with '.' deliminators + * @param className the camel case class name + * @return File object for java file + */ + public static File getJavaFile(File srcDir, String javaPackage, String className) { + File packagePath = + new File( + srcDir.getPath() + + File.separatorChar + + javaPackage.replaceAll("\\.", "\\" + File.separator)); + //noinspection ResultOfMethodCallIgnored + packagePath.mkdirs(); + return new File(packagePath, className + ".java"); + } } diff --git a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/ModelGenerator.java b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/ModelGenerator.java index a1a27f00..e831e59e 100644 --- a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/ModelGenerator.java +++ b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/ModelGenerator.java @@ -1,5 +1,25 @@ +/* + * Copyright (C) 2023 Hedera Hashgraph, LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package com.hedera.pbj.compiler.impl.generators; +import static com.hedera.pbj.compiler.impl.Common.*; +import static com.hedera.pbj.compiler.impl.generators.EnumGenerator.EnumValue; +import static com.hedera.pbj.compiler.impl.generators.EnumGenerator.createEnum; + import com.hedera.pbj.compiler.impl.*; import com.hedera.pbj.compiler.impl.Field.FieldType; import com.hedera.pbj.compiler.impl.grammar.Protobuf3Parser; @@ -10,537 +30,818 @@ import java.util.*; import java.util.stream.Collectors; -import static com.hedera.pbj.compiler.impl.Common.*; -import static com.hedera.pbj.compiler.impl.generators.EnumGenerator.EnumValue; -import static com.hedera.pbj.compiler.impl.generators.EnumGenerator.createEnum; - /** - * Code generator that parses protobuf files and generates nice Java source for record files for each message type and - * enum. + * Code generator that parses protobuf files and generates nice Java source for record files for + * each message type and enum. */ -@SuppressWarnings({"StringConcatenationInLoop", "EscapedSpace", "RedundantLabeledSwitchRuleCodeBlock"}) +@SuppressWarnings({ + "StringConcatenationInLoop", + "EscapedSpace", + "RedundantLabeledSwitchRuleCodeBlock" +}) public final class ModelGenerator implements Generator { - /** - * {@inheritDoc} - * - *

Generates a new model object, as a Java Record type. - */ - public void generate(final Protobuf3Parser.MessageDefContext msgDef, - final File destinationSrcDir, - File destinationTestSrcDir, final ContextualLookupHelper lookupHelper) throws IOException { - - // The javaRecordName will be something like "AccountID". - final var javaRecordName = lookupHelper.getUnqualifiedClassForMessage(FileType.MODEL, msgDef); - // The modelPackage is the Java package to put the model class into. - final String modelPackage = lookupHelper.getPackageForMessage(FileType.MODEL, msgDef); - // The File to write the sources that we generate into - final File javaFile = getJavaFile(destinationSrcDir, modelPackage, javaRecordName); - // The javadoc comment to use for the model class, which comes **directly** from the protobuf schema, - // but is cleaned up and formatted for use in JavaDoc. - String javaDocComment = (msgDef.docComment()== null) ? "" : - cleanDocStr(msgDef.docComment().getText().replaceAll("\n \\*\s*\n","\n *

\n")); - // The Javadoc "@Deprecated" tag, which is set if the protobuf schema says the field is deprecated - String deprecated = ""; - // The list of fields, as defined in the protobuf schema - final List fields = new ArrayList<>(); - // The generated Java code for an enum field if OneOf is used - final List oneofEnums = new ArrayList<>(); - // The generated Java code for getters if OneOf is used - final List oneofGetters = new ArrayList<>(); - // The generated Java code for has methods for normal fields - final List hasMethods = new ArrayList<>(); - // The generated Java import statements. We'll build this up as we go. - final Set imports = new TreeSet<>(); - imports.add("com.hedera.pbj.runtime"); - imports.add("com.hedera.pbj.runtime.io"); - imports.add("com.hedera.pbj.runtime.io.buffer"); - imports.add("com.hedera.pbj.runtime.io.stream"); - imports.add("edu.umd.cs.findbugs.annotations"); - - // Iterate over all the items in the protobuf schema - for(var item: msgDef.messageBody().messageElement()) { - if (item.messageDef() != null) { // process sub messages - generate(item.messageDef(), destinationSrcDir, destinationTestSrcDir, lookupHelper); - } else if (item.oneof() != null) { // process one ofs - final var oneOfField = new OneOfField(item.oneof(), javaRecordName, lookupHelper); - final var enumName = oneOfField.nameCamelFirstUpper() + "OneOfType"; - final int maxIndex = oneOfField.fields().get(oneOfField.fields().size() - 1).fieldNumber(); - final Map enumValues = new HashMap<>(); - for (final var field : oneOfField.fields()) { - final String javaFieldType = javaPrimitiveToObjectType(field.javaFieldType()); - final String enumComment = cleanDocStr(field.comment()) - .replaceAll("[\t\s]*/\\*\\*","") // remove doc start indenting - .replaceAll("\n[\t\s]+\\*","\n") // remove doc indenting - .replaceAll("/\\*\\*","") // remove doc start - .replaceAll("\\*\\*/",""); // remove doc end - enumValues.put(field.fieldNumber(), new EnumValue(field.name(), field.deprecated(), enumComment)); - // generate getters for one ofs - oneofGetters.add(""" - /** - * Direct typed getter for one of field $fieldName. - * - * @return one of value or null if one of is not set or a different one of value - */ - public @Nullable $javaFieldType $fieldName() { - return $oneOfField.kind() == $enumName.$enumValue ? ($javaFieldType)$oneOfField.value() : null; - } - - /** - * Convenience method to check if the $oneOfField has a one-of with type $enumValue - * - * @return true of the one of kind is $enumValue - */ - public boolean has$fieldNameUpperFirst() { - return $oneOfField.kind() == $enumName.$enumValue; - } - - /** - * Gets the value for $fieldName if it has a value, or else returns the default - * value for the type. - * - * @param defaultValue the default value to return if $fieldName is null - * @return the value for $fieldName if it has a value, or else returns the default value - */ - public $javaFieldType $fieldNameOrElse(@NonNull final $javaFieldType defaultValue) { - return has$fieldNameUpperFirst() ? $fieldName() : defaultValue; - } - - /** - * Gets the value for $fieldName if it was set, or throws a NullPointerException if it was not set. - * - * @return the value for $fieldName if it has a value - * @throws NullPointerException if $fieldName is null - */ - public @NonNull $javaFieldType $fieldNameOrThrow() { - return requireNonNull($fieldName(), "Field $fieldName is null"); - } - """ - .replace("$fieldNameUpperFirst",field.nameCamelFirstUpper()) - .replace("$fieldName",field.nameCamelFirstLower()) - .replace("$javaFieldType",javaFieldType) - .replace("$oneOfField",oneOfField.nameCamelFirstLower()) - .replace("$enumName",enumName) - .replace("$enumValue",camelToUpperSnake(field.name())) - .replace("$enumValue",camelToUpperSnake(field.name())) - .replaceAll("\n","\n" + FIELD_INDENT)); - if (field.type() == Field.FieldType.MESSAGE) { - field.addAllNeededImports(imports, true, false, false); - } - } - final String enumComment = """ - /** - * Enum for the type of "%s" oneof value - */""".formatted(oneOfField.name()); - final String enumString = createEnum(FIELD_INDENT,enumComment ,"",enumName,maxIndex,enumValues, true); - oneofEnums.add(enumString); - fields.add(oneOfField); - imports.add("com.hedera.pbj.runtime"); - } else if (item.mapField() != null) { // process map fields - System.err.println("Encountered a mapField that was not handled in " + javaRecordName); - } else if (item.field() != null && item.field().fieldName() != null) { - final SingleField field = new SingleField(item.field(), lookupHelper); - fields.add(field); - field.addAllNeededImports(imports, true, false, false); - if (field.type() == FieldType.MESSAGE) { - hasMethods.add(""" - /** - * Convenience method to check if the $fieldName has a value - * - * @return true of the $fieldName has a value - */ - public boolean has$fieldNameUpperFirst() { - return $fieldName != null; - } - - /** - * Gets the value for $fieldName if it has a value, or else returns the default - * value for the type. - * - * @param defaultValue the default value to return if $fieldName is null - * @return the value for $fieldName if it has a value, or else returns the default value - */ - public $javaFieldType $fieldNameOrElse(@NonNull final $javaFieldType defaultValue) { - return has$fieldNameUpperFirst() ? $fieldName : defaultValue; - } - - /** - * Gets the value for $fieldName if it has a value, or else throws an NPE. - * value for the type. - * - * @return the value for $fieldName if it has a value - * @throws NullPointerException if $fieldName is null - */ - public @NonNull $javaFieldType $fieldNameOrThrow() { - return requireNonNull($fieldName, "Field $fieldName is null"); - } - - /** - * Executes the supplied {@link Consumer} if, and only if, the $fieldName has a value - * - * @param ifPresent the {@link Consumer} to execute - */ - public void if$fieldNameUpperFirst(@NonNull final Consumer<$javaFieldType> ifPresent) { - if (has$fieldNameUpperFirst()) { - ifPresent.accept($fieldName); - } - } - """ - .replace("$fieldNameUpperFirst", field.nameCamelFirstUpper()) - .replace("$javaFieldType", field.javaFieldType()) - .replace("$fieldName", field.nameCamelFirstLower())); - } - } else if (item.optionStatement() != null){ - if ("deprecated".equals(item.optionStatement().optionName().getText())) { - deprecated = "@Deprecated "; - } else { - System.err.println("Unhandled Option: "+item.optionStatement().getText()); - } - } else if (item.reserved() == null){ // ignore reserved and warn about anything else - System.err.println("ModelGenerator Warning - Unknown element: "+item+" -- "+item.getText()); - } - } - - // process field java doc and insert into record java doc - if (!fields.isEmpty()) { - String recordJavaDoc = javaDocComment.length() > 0 ? - javaDocComment.replaceAll("\n\s*\\*/","") : - "/**\n * "+javaRecordName; - recordJavaDoc += "\n *"; - for(var field: fields) { - recordJavaDoc += "\n * @param "+field.nameCamelFirstLower()+" "+ - field.comment() - .replaceAll("\n", "\n * "+" ".repeat(field.nameCamelFirstLower().length())); - } - recordJavaDoc += "\n */"; - javaDocComment = cleanDocStr(recordJavaDoc); - } - - // === Build Body Content - String bodyContent = ""; - - // static codec and default instance - bodyContent += """ - /** Protobuf codec for reading and writing in protobuf format */ - public static final Codec<$modelClass> PROTOBUF = new $qualifiedCodecClass(); - /** JSON codec for reading and writing in JSON format */ - public static final JsonCodec<$modelClass> JSON = new $qualifiedJsonCodecClass(); - - /** Default instance with all fields set to default values */ - public static final $modelClass DEFAULT = newBuilder().build(); - """ - .replace("$modelClass",javaRecordName) - .replace("$qualifiedCodecClass",lookupHelper.getFullyQualifiedMessageClassname(FileType.CODEC, msgDef)) - .replace("$qualifiedJsonCodecClass",lookupHelper.getFullyQualifiedMessageClassname(FileType.JSON_CODEC, msgDef)) - .replaceAll("\n","\n"+FIELD_INDENT); - - // constructor - if (fields.stream().anyMatch(f -> f instanceof OneOfField || f.optionalValueType())) { - bodyContent += """ - - /** - * Override the default constructor adding input validation - * %s - */ - public %s { - %s - } - - """.formatted( - fields.stream().map(field -> "\n * @param "+field.nameCamelFirstLower()+" "+ - field.comment() - .replaceAll("\n", "\n * "+" ".repeat(field.nameCamelFirstLower().length())) - ).collect(Collectors.joining()), - javaRecordName, - fields.stream() - .filter(f -> f instanceof OneOfField) - .map(ModelGenerator::generateConstructorCode) - .collect(Collectors.joining("\n")) - ).replaceAll("\n","\n"+FIELD_INDENT); - } - - // Add here hashCode() for object with a single int field. - if (fields.size() == 1) { - FieldType fieldType = fields.get(0).type(); - switch (fieldType) { - case INT32, UINT32, SINT32, FIXED32, SFIXED32, - FIXED64, SFIXED64, INT64, UINT64, SINT64 -> { - bodyContent += FIELD_INDENT + """ - /** - * Override the default hashCode method for - * single field int objects. - */ - @Override - public int hashCode() { - // Shifts: 30, 27, 16, 20, 5, 18, 10, 24, 30 - long x = $fieldName; - x += x << 30; - x ^= x >>> 27; - x += x << 16; - x ^= x >>> 20; - x += x << 5; - x ^= x >>> 18; - x += x << 10; - x ^= x >>> 24; - x += x << 30; - return (int)x; - }""" - .replace("$fieldName", fields.get(0).name()) - .replaceAll("\n","\n"+FIELD_INDENT); - } - default -> { - // Do nothing. - } - } - } - - // Has methods - bodyContent += String.join("\n", hasMethods).replaceAll("\n","\n"+FIELD_INDENT); - bodyContent += "\n"; - - // oneof getters - bodyContent += String.join("\n ", oneofGetters); - bodyContent += "\n"; - - // builder copy & new builder methods - bodyContent += FIELD_INDENT + """ - /** - * Return a builder for building a copy of this model object. It will be pre-populated with all the data from this - * model object. - * - * @return a pre-populated builder - */ - public Builder copyBuilder() { - return new Builder(%s); - } - - /** - * Return a new builder for building a model object. This is just a shortcut for new Model.Builder(). - * - * @return a new builder - */ - public static Builder newBuilder() { - return new Builder(); - } - - """ - .formatted(fields.stream().map(Field::nameCamelFirstLower).collect(Collectors.joining(", "))) - .replaceAll("\n","\n"+FIELD_INDENT); - - // generate builder - bodyContent += generateBuilder(msgDef, fields, lookupHelper); - bodyContent += "\n"+FIELD_INDENT; - - // oneof enums - bodyContent += String.join("\n ", oneofEnums); - - // === Build file - try (FileWriter javaWriter = new FileWriter(javaFile)) { - javaWriter.write(""" - package $package; - $imports - import com.hedera.pbj.runtime.Codec; - import java.util.function.Consumer; - import edu.umd.cs.findbugs.annotations.Nullable; - import static java.util.Objects.requireNonNull; - - $javaDocComment - $deprecated - public record $javaRecordName( - $fields - ){ - $bodyContent - } - """ - .replace("$package",modelPackage) - .replace("$imports",imports.isEmpty() ? "" : imports.stream().collect(Collectors.joining(".*;\nimport ","\nimport ",".*;\n"))) - .replace("$javaDocComment",javaDocComment) - .replace("$deprecated",deprecated) - .replace("$javaRecordName",javaRecordName) - .replace("$fields",fields.stream().map(field -> - FIELD_INDENT + (field.type() == FieldType.MESSAGE ? "@Nullable " : "") + field.javaFieldType() + " " + field.nameCamelFirstLower() - ).collect(Collectors.joining(",\n "))) - .replace("$bodyContent",bodyContent) - ); - } - } - - private static void generateBuilderMethods(List builderMethods, Field field) { - final String prefix, postfix, fieldToSet; - final OneOfField parentOneOfField = field.parent(); - if (parentOneOfField != null) { - final String oneOfEnumValue = parentOneOfField.getEnumClassRef()+"."+camelToUpperSnake(field.name()); - prefix = "new OneOf<>("+oneOfEnumValue+","; - postfix = ")"; - fieldToSet = parentOneOfField.nameCamelFirstLower(); - } else { - prefix = ""; - postfix = ""; - fieldToSet = field.nameCamelFirstLower(); - } - builderMethods.add(""" - /** - * $fieldDoc - * - * @param $fieldName value to set - * @return builder to continue building with - */ - public Builder $fieldName($fieldType $fieldName) { - this.$fieldToSet = $prefix $fieldName $postfix; - return this; - }""" - .replace("$fieldDoc",field.comment() - .replaceAll("\n", "\n * ")) - .replace("$fieldName",field.nameCamelFirstLower()) - .replace("$fieldToSet",fieldToSet) - .replace("$prefix",prefix) - .replace("$postfix",postfix) - .replace("$fieldType",field.javaFieldType()) - .replaceAll("\n","\n"+FIELD_INDENT)); - // add nice method for simple message fields so can just set using un-built builder - if (field.type() == Field.FieldType.MESSAGE && !field.optionalValueType() && !field.repeated()) { - builderMethods.add(""" - /** - * $fieldDoc - * - * @param builder A pre-populated builder - * @return builder to continue building with - */ - public Builder $fieldName($messageClass.Builder builder) { - this.$fieldToSet = $prefix builder.build() $postfix; - return this; - }""" - .replace("$messageClass",field.messageType()) - .replace("$fieldDoc",field.comment() - .replaceAll("\n", "\n * ")) - .replace("$fieldName",field.nameCamelFirstLower()) - .replace("$fieldToSet",fieldToSet) - .replace("$prefix",prefix) - .replace("$postfix",postfix) - .replace("$fieldType",field.javaFieldType()) - .replaceAll("\n","\n"+FIELD_INDENT)); - } - - // add nice method for message fields with list types for varargs - if (field.repeated()) { - builderMethods.add(""" - /** - * $fieldDoc - * - * @param values varargs value to be built into a list - * @return builder to continue building with - */ - public Builder $fieldName($baseType ... values) { - this.$fieldToSet = $prefix List.of(values) $postfix; - return this; - }""" - .replace("$baseType",field.javaFieldType().substring("List<".length(),field.javaFieldType().length()-1)) - .replace("$fieldDoc",field.comment() - .replaceAll("\n", "\n * ")) - .replace("$fieldName",field.nameCamelFirstLower()) - .replace("$fieldToSet",fieldToSet) - .replace("$fieldType",field.javaFieldType()) - .replace("$prefix",prefix) - .replace("$postfix",postfix) - .replaceAll("\n","\n"+FIELD_INDENT)); - } - } - - private static String generateBuilder(final Protobuf3Parser.MessageDefContext msgDef, List fields, final ContextualLookupHelper lookupHelper) { - final String javaRecordName = msgDef.messageName().getText(); - List builderMethods = new ArrayList<>(); - for (Field field: fields) { - if (field.type() == Field.FieldType.ONE_OF) { - final OneOfField oneOfField = (OneOfField) field; - for (Field subField: oneOfField.fields()) { - generateBuilderMethods(builderMethods, subField); - } - } else { - generateBuilderMethods(builderMethods, field); - } - } - return """ - /** - * Builder class for easy creation, ideal for clean code where performance is not critical. In critical performance - * paths use the constructor directly. - */ - public static final class Builder { - $fields; - - /** - * Create an empty builder - */ - public Builder() {} - - /** - * Create a pre-populated builder - * $constructorParamDocs - */ - public Builder($constructorParams) { - $constructorCode; - } - - /** - * Build a new model record with data set on builder - * - * @return new model record with data set - */ - public $javaRecordName build() { - return new $javaRecordName($recordParams); - } - - $builderMethods - }""" - .replace("$fields", fields.stream().map(field -> - "private " + field.javaFieldType() + " " + field.nameCamelFirstLower() + - " = " + getDefaultValue(field, msgDef, lookupHelper) - ).collect(Collectors.joining(";\n "))) - .replace("$constructorParamDocs",fields.stream().map(field -> - "\n * @param "+field.nameCamelFirstLower()+" "+ - field.comment().replaceAll("\n", "\n * "+" ".repeat(field.nameCamelFirstLower().length())) - ).collect(Collectors.joining(", "))) - .replace("$constructorParams",fields.stream().map(field -> - field.javaFieldType() + " " + field.nameCamelFirstLower() - ).collect(Collectors.joining(", "))) - .replace("$constructorCode",fields.stream().map(field -> - "this." + field.nameCamelFirstLower() + " = " + field.nameCamelFirstLower() - ).collect(Collectors.joining(";\n"+FIELD_INDENT+FIELD_INDENT))) - .replace("$javaRecordName",javaRecordName) - .replace("$recordParams",fields.stream().map(Field::nameCamelFirstLower).collect(Collectors.joining(", "))) - .replace("$builderMethods",builderMethods.stream().collect(Collectors.joining("\n\n"+FIELD_INDENT))) - .replaceAll("\n","\n"+FIELD_INDENT); - } - - private static String getDefaultValue(Field field, final Protobuf3Parser.MessageDefContext msgDef, final ContextualLookupHelper lookupHelper) { - if (field.type() == Field.FieldType.ONE_OF) { - return lookupHelper.getFullyQualifiedMessageClassname(FileType.CODEC, msgDef)+"."+field.javaDefault(); - } else { - return field.javaDefault(); - } - } - - private static String generateConstructorCode(final Field f) { - StringBuilder sb = new StringBuilder(FIELD_INDENT+""" - if ($fieldName == null) { - throw new NullPointerException("Parameter '$fieldName' must be supplied and can not be null"); - }""".replace("$fieldName", f.nameCamelFirstLower())); - if (f instanceof final OneOfField oof) { - for (Field subField: oof.fields()) { - if(subField.optionalValueType()) { - sb.append(""" - - // handle special case where protobuf does not have destination between a OneOf with optional - // value of empty vs an unset OneOf. - if($fieldName.kind() == $fieldUpperNameOneOfType.$subFieldNameUpper && $fieldName.value() == null) { - $fieldName = new OneOf<>($fieldUpperNameOneOfType.UNSET, null); - }""" - .replace("$fieldName", f.nameCamelFirstLower()) - .replace("$fieldUpperName", f.nameCamelFirstUpper()) - .replace("$subFieldNameUpper", camelToUpperSnake(subField.name())) - ); - } - } - } - return sb.toString().replaceAll("\n","\n"+FIELD_INDENT); - } -} \ No newline at end of file + /** + * {@inheritDoc} + * + *

Generates a new model object, as a Java Record type. + */ + public void generate( + final Protobuf3Parser.MessageDefContext msgDef, + final File destinationSrcDir, + File destinationTestSrcDir, + final ContextualLookupHelper lookupHelper) + throws IOException { + + // The javaRecordName will be something like "AccountID". + final var javaRecordName = + lookupHelper.getUnqualifiedClassForMessage(FileType.MODEL, msgDef); + // The modelPackage is the Java package to put the model class into. + final String modelPackage = lookupHelper.getPackageForMessage(FileType.MODEL, msgDef); + // The File to write the sources that we generate into + final File javaFile = getJavaFile(destinationSrcDir, modelPackage, javaRecordName); + // The javadoc comment to use for the model class, which comes **directly** from the + // protobuf schema, + // but is cleaned up and formatted for use in JavaDoc. + String javaDocComment = + (msgDef.docComment() == null) + ? "" + : cleanDocStr( + msgDef.docComment() + .getText() + .replaceAll("\n \\*\s*\n", "\n *

\n")); + // The Javadoc "@Deprecated" tag, which is set if the protobuf schema says the field is + // deprecated + String deprecated = ""; + // The list of fields, as defined in the protobuf schema + final List fields = new ArrayList<>(); + // The generated Java code for an enum field if OneOf is used + final List oneofEnums = new ArrayList<>(); + // The generated Java code for getters if OneOf is used + final List oneofGetters = new ArrayList<>(); + // The generated Java code for has methods for normal fields + final List hasMethods = new ArrayList<>(); + // The generated Java import statements. We'll build this up as we go. + final Set imports = new TreeSet<>(); + imports.add("com.hedera.pbj.runtime"); + imports.add("com.hedera.pbj.runtime.io"); + imports.add("com.hedera.pbj.runtime.io.buffer"); + imports.add("com.hedera.pbj.runtime.io.stream"); + imports.add("edu.umd.cs.findbugs.annotations"); + + // Iterate over all the items in the protobuf schema + for (var item : msgDef.messageBody().messageElement()) { + if (item.messageDef() != null) { // process sub messages + generate(item.messageDef(), destinationSrcDir, destinationTestSrcDir, lookupHelper); + } else if (item.oneof() != null) { // process one ofs + final var oneOfField = new OneOfField(item.oneof(), javaRecordName, lookupHelper); + final var enumName = oneOfField.nameCamelFirstUpper() + "OneOfType"; + final int maxIndex = + oneOfField.fields().get(oneOfField.fields().size() - 1).fieldNumber(); + final Map enumValues = new HashMap<>(); + for (final var field : oneOfField.fields()) { + final String javaFieldType = javaPrimitiveToObjectType(field.javaFieldType()); + final String enumComment = + cleanDocStr(field.comment()) + .replaceAll("[\t\s]*/\\*\\*", "") // remove doc start indenting + .replaceAll("\n[\t\s]+\\*", "\n") // remove doc indenting + .replaceAll("/\\*\\*", "") // remove doc start + .replaceAll("\\*\\*/", ""); // remove doc end + enumValues.put( + field.fieldNumber(), + new EnumValue(field.name(), field.deprecated(), enumComment)); + // generate getters for one ofs + oneofGetters.add( + """ + /** + * Direct typed getter for one of field $fieldName. + * + * @return one of value or null if one of is not set or a different one of value + */ + public @Nullable $javaFieldType $fieldName() { + return $oneOfField.kind() == $enumName.$enumValue ? ($javaFieldType)$oneOfField.value() : null; + } + + /** + * Convenience method to check if the $oneOfField has a one-of with type $enumValue + * + * @return true of the one of kind is $enumValue + */ + public boolean has$fieldNameUpperFirst() { + return $oneOfField.kind() == $enumName.$enumValue; + } + + /** + * Gets the value for $fieldName if it has a value, or else returns the default + * value for the type. + * + * @param defaultValue the default value to return if $fieldName is null + * @return the value for $fieldName if it has a value, or else returns the default value + */ + public $javaFieldType $fieldNameOrElse(@NonNull final $javaFieldType defaultValue) { + return has$fieldNameUpperFirst() ? $fieldName() : defaultValue; + } + + /** + * Gets the value for $fieldName if it was set, or throws a NullPointerException if it was not set. + * + * @return the value for $fieldName if it has a value + * @throws NullPointerException if $fieldName is null + */ + public @NonNull $javaFieldType $fieldNameOrThrow() { + return requireNonNull($fieldName(), "Field $fieldName is null"); + } + """ + .replace("$fieldNameUpperFirst", field.nameCamelFirstUpper()) + .replace("$fieldName", field.nameCamelFirstLower()) + .replace("$javaFieldType", javaFieldType) + .replace("$oneOfField", oneOfField.nameCamelFirstLower()) + .replace("$enumName", enumName) + .replace("$enumValue", camelToUpperSnake(field.name())) + .replace("$enumValue", camelToUpperSnake(field.name())) + .replaceAll("\n", "\n" + FIELD_INDENT)); + if (field.type() == Field.FieldType.MESSAGE) { + field.addAllNeededImports(imports, true, false, false); + } + } + final String enumComment = + """ + /** + * Enum for the type of "%s" oneof value + */""" + .formatted(oneOfField.name()); + final String enumString = + createEnum( + FIELD_INDENT, + enumComment, + "", + enumName, + maxIndex, + enumValues, + true); + oneofEnums.add(enumString); + fields.add(oneOfField); + imports.add("com.hedera.pbj.runtime"); + } else if (item.mapField() != null) { // process map fields + System.err.println( + "Encountered a mapField that was not handled in " + javaRecordName); + } else if (item.field() != null && item.field().fieldName() != null) { + final SingleField field = new SingleField(item.field(), lookupHelper); + fields.add(field); + field.addAllNeededImports(imports, true, false, false); + if (field.type() == FieldType.MESSAGE) { + hasMethods.add( + """ + /** + * Convenience method to check if the $fieldName has a value + * + * @return true of the $fieldName has a value + */ + public boolean has$fieldNameUpperFirst() { + return $fieldName != null; + } + + /** + * Gets the value for $fieldName if it has a value, or else returns the default + * value for the type. + * + * @param defaultValue the default value to return if $fieldName is null + * @return the value for $fieldName if it has a value, or else returns the default value + */ + public $javaFieldType $fieldNameOrElse(@NonNull final $javaFieldType defaultValue) { + return has$fieldNameUpperFirst() ? $fieldName : defaultValue; + } + + /** + * Gets the value for $fieldName if it has a value, or else throws an NPE. + * value for the type. + * + * @return the value for $fieldName if it has a value + * @throws NullPointerException if $fieldName is null + */ + public @NonNull $javaFieldType $fieldNameOrThrow() { + return requireNonNull($fieldName, "Field $fieldName is null"); + } + + /** + * Executes the supplied {@link Consumer} if, and only if, the $fieldName has a value + * + * @param ifPresent the {@link Consumer} to execute + */ + public void if$fieldNameUpperFirst(@NonNull final Consumer<$javaFieldType> ifPresent) { + if (has$fieldNameUpperFirst()) { + ifPresent.accept($fieldName); + } + } + """ + .replace("$fieldNameUpperFirst", field.nameCamelFirstUpper()) + .replace("$javaFieldType", field.javaFieldType()) + .replace("$fieldName", field.nameCamelFirstLower())); + } + } else if (item.optionStatement() != null) { + if ("deprecated".equals(item.optionStatement().optionName().getText())) { + deprecated = "@Deprecated "; + } else { + System.err.println("Unhandled Option: " + item.optionStatement().getText()); + } + } else if (item.reserved() == null) { // ignore reserved and warn about anything else + System.err.println( + "ModelGenerator Warning - Unknown element: " + + item + + " -- " + + item.getText()); + } + } + + // process field java doc and insert into record java doc + if (!fields.isEmpty()) { + String recordJavaDoc = + javaDocComment.length() > 0 + ? javaDocComment.replaceAll("\n\s*\\*/", "") + : "/**\n * " + javaRecordName; + recordJavaDoc += "\n *"; + for (var field : fields) { + recordJavaDoc += + "\n * @param " + + field.nameCamelFirstLower() + + " " + + field.comment() + .replaceAll( + "\n", + "\n * " + + " " + .repeat( + field.nameCamelFirstLower() + .length())); + } + recordJavaDoc += "\n */"; + javaDocComment = cleanDocStr(recordJavaDoc); + } + + // === Build Body Content + String bodyContent = ""; + + // static codec and default instance + bodyContent += + """ + /** Protobuf codec for reading and writing in protobuf format */ + public static final Codec<$modelClass> PROTOBUF = new $qualifiedCodecClass(); + /** JSON codec for reading and writing in JSON format */ + public static final JsonCodec<$modelClass> JSON = new $qualifiedJsonCodecClass(); + + /** Default instance with all fields set to default values */ + public static final $modelClass DEFAULT = newBuilder().build(); + """ + .replace("$modelClass", javaRecordName) + .replace( + "$qualifiedCodecClass", + lookupHelper.getFullyQualifiedMessageClassname( + FileType.CODEC, msgDef)) + .replace( + "$qualifiedJsonCodecClass", + lookupHelper.getFullyQualifiedMessageClassname( + FileType.JSON_CODEC, msgDef)) + .replaceAll("\n", "\n" + FIELD_INDENT); + + // constructor + if (fields.stream().anyMatch(f -> f instanceof OneOfField || f.optionalValueType())) { + bodyContent += + """ + + /** + * Override the default constructor adding input validation + * %s + */ + public %s { + %s + } + + """ + .formatted( + fields.stream() + .map( + field -> + "\n * @param " + + field.nameCamelFirstLower() + + " " + + field.comment() + .replaceAll( + "\n", + "\n * " + + " " + .repeat( + field.nameCamelFirstLower() + .length()))) + .collect(Collectors.joining()), + javaRecordName, + fields.stream() + .filter(f -> f instanceof OneOfField) + .map(ModelGenerator::generateConstructorCode) + .collect(Collectors.joining("\n"))) + .replaceAll("\n", "\n" + FIELD_INDENT); + } + + // Add here hashCode() for object with a single int field. + boolean hashCodeGenerated = true; + if (fields.size() == 1) { + FieldType fieldType = fields.get(0).type(); + switch (fieldType) { + case INT32, + UINT32, + SINT32, + FIXED32, + SFIXED32, + FIXED64, + SFIXED64, + INT64, + UINT64, + SINT64 -> { + bodyContent += + FIELD_INDENT + + """ + /** + * Override the default hashCode method for + * single field int objects. + */ + @Override + public int hashCode() { + // Shifts: 30, 27, 16, 20, 5, 18, 10, 24, 30 + long x = $fieldName; + x += x << 30; + x ^= x >>> 27; + x += x << 16; + x ^= x >>> 20; + x += x << 5; + x ^= x >>> 18; + x += x << 10; + x ^= x >>> 24; + x += x << 30; + return (int)x; + }""" + .replace("$fieldName", fields.get(0).name()) + .replaceAll("\n", "\n" + FIELD_INDENT); + } + case FLOAT, DOUBLE -> { + bodyContent += + FIELD_INDENT + + """ + /** + * Override the default hashCode method for + * single field float and double objects. + */ + @Override + public int hashCode() { + // Shifts: 30, 27, 16, 20, 5, 18, 10, 24, 30 + double x = $fieldName; + x += x << 30; + x ^= x >>> 27; + x += x << 16; + x ^= x >>> 20; + x += x << 5; + x ^= x >>> 18; + x += x << 10; + x ^= x >>> 24; + x += x << 30; + return (int)x; + }""" + .replace("$fieldName", fields.get(0).name()) + .replaceAll("\n", "\n" + FIELD_INDENT); + } + case STRING -> { + bodyContent += + FIELD_INDENT + + """ + /** + * Override the default hashCode method for + * single field String objects. + */ + @Override + public int hashCode() { + // Shifts: 30, 27, 16, 20, 5, 18, 10, 24, 30 + long x = $fieldName.hashCode(); + x += x << 30; + x ^= x >>> 27; + x += x << 16; + x ^= x >>> 20; + x += x << 5; + x ^= x >>> 18; + x += x << 10; + x ^= x >>> 24; + x += x << 30; + return (int)x; + }""" + .replace("$fieldName", fields.get(0).name()) + .replaceAll("\n", "\n" + FIELD_INDENT); + } + default -> { + hashCodeGenerated = false; + } + } + } else { + hashCodeGenerated = false; + } + + String statements = ""; + if (!hashCodeGenerated) { + // Generate a call to private method that iterates through fields + // and calculates the hashcode. + statements = Common.getFieldsHashCode(fields, statements); + + bodyContent += + """ + /** + * Override the default hashCode method for + * all other objects to make hashCode + */ + @Override + public int hashCode() { + int result = 1; + """; + + bodyContent += statements; + + bodyContent += + """ + long hashCode = result; + hashCode += hashCode << 30; + hashCode ^= hashCode >>> 27; + hashCode += hashCode << 16; + hashCode ^= hashCode >>> 20; + hashCode += hashCode << 5; + hashCode ^= hashCode >>> 18; + hashCode += hashCode << 10; + hashCode ^= hashCode >>> 24; + hashCode += hashCode << 30; + return (int)hashCode; + } + """; + } + + String equalsStatements = ""; + // Generate a call to private method that iterates through fields + // and calculates the hashcode. + equalsStatements = Common.getFieldsEqualsStatements(fields, equalsStatements); + + bodyContent += + FIELD_INDENT + + """ + /** + * Override the default equals method for + */ + @Override + public boolean equals(Object that) { + if (that == null || this.getClass() != that.getClass()) { + return false; + } + + $javaRecordName thatObj = ($javaRecordName)that; + """ + .replace("$javaRecordName", javaRecordName); + + bodyContent += equalsStatements; + bodyContent += """ + return true; + } + """; + + bodyContent = bodyContent.replaceAll("\n", "\n" + FIELD_INDENT); + + // Has methods + bodyContent += String.join("\n", hasMethods).replaceAll("\n", "\n" + FIELD_INDENT); + bodyContent += "\n"; + + // oneof getters + bodyContent += String.join("\n ", oneofGetters); + bodyContent += "\n"; + + // builder copy & new builder methods + bodyContent += + FIELD_INDENT + + """ + /** + * Return a builder for building a copy of this model object. It will be pre-populated with all the data from this + * model object. + * + * @return a pre-populated builder + */ + public Builder copyBuilder() { + return new Builder(%s); + } + + /** + * Return a new builder for building a model object. This is just a shortcut for new Model.Builder(). + * + * @return a new builder + */ + public static Builder newBuilder() { + return new Builder(); + } + + """ + .formatted( + fields.stream() + .map(Field::nameCamelFirstLower) + .collect(Collectors.joining(", "))) + .replaceAll("\n", "\n" + FIELD_INDENT); + + // generate builder + bodyContent += generateBuilder(msgDef, fields, lookupHelper); + bodyContent += "\n" + FIELD_INDENT; + + // oneof enums + bodyContent += String.join("\n ", oneofEnums); + + // === Build file + try (FileWriter javaWriter = new FileWriter(javaFile)) { + javaWriter.write( + """ + package $package; + $imports + import com.hedera.pbj.runtime.Codec; + import java.util.function.Consumer; + import edu.umd.cs.findbugs.annotations.Nullable; + import static java.util.Objects.requireNonNull; + + $javaDocComment + $deprecated + public record $javaRecordName( + $fields + ){ + $bodyContent + } + """ + .replace("$package", modelPackage) + .replace( + "$imports", + imports.isEmpty() + ? "" + : imports.stream() + .collect( + Collectors.joining( + ".*;\nimport ", + "\nimport ", + ".*;\n"))) + .replace("$javaDocComment", javaDocComment) + .replace("$deprecated", deprecated) + .replace("$javaRecordName", javaRecordName) + .replace( + "$fields", + fields.stream() + .map( + field -> + FIELD_INDENT + + (field.type() + == FieldType + .MESSAGE + ? "@Nullable " + : "") + + field.javaFieldType() + + " " + + field.nameCamelFirstLower()) + .collect(Collectors.joining(",\n "))) + .replace("$bodyContent", bodyContent)); + } + } + + private static void generateBuilderMethods(List builderMethods, Field field) { + final String prefix, postfix, fieldToSet; + final OneOfField parentOneOfField = field.parent(); + if (parentOneOfField != null) { + final String oneOfEnumValue = + parentOneOfField.getEnumClassRef() + "." + camelToUpperSnake(field.name()); + prefix = "new OneOf<>(" + oneOfEnumValue + ","; + postfix = ")"; + fieldToSet = parentOneOfField.nameCamelFirstLower(); + } else { + prefix = ""; + postfix = ""; + fieldToSet = field.nameCamelFirstLower(); + } + builderMethods.add( + """ + /** + * $fieldDoc + * + * @param $fieldName value to set + * @return builder to continue building with + */ + public Builder $fieldName($fieldType $fieldName) { + this.$fieldToSet = $prefix $fieldName $postfix; + return this; + }""" + .replace("$fieldDoc", field.comment().replaceAll("\n", "\n * ")) + .replace("$fieldName", field.nameCamelFirstLower()) + .replace("$fieldToSet", fieldToSet) + .replace("$prefix", prefix) + .replace("$postfix", postfix) + .replace("$fieldType", field.javaFieldType()) + .replaceAll("\n", "\n" + FIELD_INDENT)); + // add nice method for simple message fields so can just set using un-built builder + if (field.type() == Field.FieldType.MESSAGE + && !field.optionalValueType() + && !field.repeated()) { + builderMethods.add( + """ + /** + * $fieldDoc + * + * @param builder A pre-populated builder + * @return builder to continue building with + */ + public Builder $fieldName($messageClass.Builder builder) { + this.$fieldToSet = $prefix builder.build() $postfix; + return this; + }""" + .replace("$messageClass", field.messageType()) + .replace("$fieldDoc", field.comment().replaceAll("\n", "\n * ")) + .replace("$fieldName", field.nameCamelFirstLower()) + .replace("$fieldToSet", fieldToSet) + .replace("$prefix", prefix) + .replace("$postfix", postfix) + .replace("$fieldType", field.javaFieldType()) + .replaceAll("\n", "\n" + FIELD_INDENT)); + } + + // add nice method for message fields with list types for varargs + if (field.repeated()) { + builderMethods.add( + """ + /** + * $fieldDoc + * + * @param values varargs value to be built into a list + * @return builder to continue building with + */ + public Builder $fieldName($baseType ... values) { + this.$fieldToSet = $prefix List.of(values) $postfix; + return this; + }""" + .replace( + "$baseType", + field.javaFieldType() + .substring( + "List<".length(), + field.javaFieldType().length() - 1)) + .replace("$fieldDoc", field.comment().replaceAll("\n", "\n * ")) + .replace("$fieldName", field.nameCamelFirstLower()) + .replace("$fieldToSet", fieldToSet) + .replace("$fieldType", field.javaFieldType()) + .replace("$prefix", prefix) + .replace("$postfix", postfix) + .replaceAll("\n", "\n" + FIELD_INDENT)); + } + } + + private static String generateBuilder( + final Protobuf3Parser.MessageDefContext msgDef, + List fields, + final ContextualLookupHelper lookupHelper) { + final String javaRecordName = msgDef.messageName().getText(); + List builderMethods = new ArrayList<>(); + for (Field field : fields) { + if (field.type() == Field.FieldType.ONE_OF) { + final OneOfField oneOfField = (OneOfField) field; + for (Field subField : oneOfField.fields()) { + generateBuilderMethods(builderMethods, subField); + } + } else { + generateBuilderMethods(builderMethods, field); + } + } + return """ + /** + * Builder class for easy creation, ideal for clean code where performance is not critical. In critical performance + * paths use the constructor directly. + */ + public static final class Builder { + $fields; + + /** + * Create an empty builder + */ + public Builder() {} + + /** + * Create a pre-populated builder + * $constructorParamDocs + */ + public Builder($constructorParams) { + $constructorCode; + } + + /** + * Build a new model record with data set on builder + * + * @return new model record with data set + */ + public $javaRecordName build() { + return new $javaRecordName($recordParams); + } + + $builderMethods + }""" + .replace( + "$fields", + fields.stream() + .map( + field -> + "private " + + field.javaFieldType() + + " " + + field.nameCamelFirstLower() + + " = " + + getDefaultValue( + field, msgDef, lookupHelper)) + .collect(Collectors.joining(";\n "))) + .replace( + "$constructorParamDocs", + fields.stream() + .map( + field -> + "\n * @param " + + field.nameCamelFirstLower() + + " " + + field.comment() + .replaceAll( + "\n", + "\n * " + + " " + .repeat( + field.nameCamelFirstLower() + .length()))) + .collect(Collectors.joining(", "))) + .replace( + "$constructorParams", + fields.stream() + .map( + field -> + field.javaFieldType() + + " " + + field.nameCamelFirstLower()) + .collect(Collectors.joining(", "))) + .replace( + "$constructorCode", + fields.stream() + .map( + field -> + "this." + + field.nameCamelFirstLower() + + " = " + + field.nameCamelFirstLower()) + .collect(Collectors.joining(";\n" + FIELD_INDENT + FIELD_INDENT))) + .replace("$javaRecordName", javaRecordName) + .replace( + "$recordParams", + fields.stream() + .map(Field::nameCamelFirstLower) + .collect(Collectors.joining(", "))) + .replace( + "$builderMethods", + builderMethods.stream().collect(Collectors.joining("\n\n" + FIELD_INDENT))) + .replaceAll("\n", "\n" + FIELD_INDENT); + } + + private static String getDefaultValue( + Field field, + final Protobuf3Parser.MessageDefContext msgDef, + final ContextualLookupHelper lookupHelper) { + if (field.type() == Field.FieldType.ONE_OF) { + return lookupHelper.getFullyQualifiedMessageClassname(FileType.CODEC, msgDef) + + "." + + field.javaDefault(); + } else { + return field.javaDefault(); + } + } + + private static String generateConstructorCode(final Field f) { + StringBuilder sb = + new StringBuilder( + FIELD_INDENT + + """ + if ($fieldName == null) { + throw new NullPointerException("Parameter '$fieldName' must be supplied and can not be null"); + }""" + .replace("$fieldName", f.nameCamelFirstLower())); + if (f instanceof final OneOfField oof) { + for (Field subField : oof.fields()) { + if (subField.optionalValueType()) { + sb.append( + """ + + // handle special case where protobuf does not have destination between a OneOf with optional + // value of empty vs an unset OneOf. + if($fieldName.kind() == $fieldUpperNameOneOfType.$subFieldNameUpper && $fieldName.value() == null) { + $fieldName = new OneOf<>($fieldUpperNameOneOfType.UNSET, null); + }""" + .replace("$fieldName", f.nameCamelFirstLower()) + .replace("$fieldUpperName", f.nameCamelFirstUpper()) + .replace( + "$subFieldNameUpper", + camelToUpperSnake(subField.name()))); + } + } + } + return sb.toString().replaceAll("\n", "\n" + FIELD_INDENT); + } +} diff --git a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/TestGenerator.java b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/TestGenerator.java index 57267120..060f7601 100644 --- a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/TestGenerator.java +++ b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/TestGenerator.java @@ -1,3 +1,19 @@ +/* + * Copyright (C) 2023 Hedera Hashgraph, LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package com.hedera.pbj.compiler.impl.generators; import com.hedera.pbj.compiler.impl.*; @@ -17,320 +33,425 @@ */ public final class TestGenerator implements Generator { - /** - * {@inheritDoc} - */ - public void generate(Protobuf3Parser.MessageDefContext msgDef, File destinationSrcDir, - File destinationTestSrcDir, final ContextualLookupHelper lookupHelper) throws IOException { - final var modelClassName = lookupHelper.getUnqualifiedClassForMessage(FileType.MODEL, msgDef); - final var testClassName = lookupHelper.getUnqualifiedClassForMessage(FileType.TEST, msgDef); - final String testPackage = lookupHelper.getPackageForMessage(FileType.TEST, msgDef); - final String protoCJavaFullQualifiedClass = lookupHelper.getFullyQualifiedMessageClassname(FileType.PROTOC,msgDef); - final File javaFile = Common.getJavaFile(destinationTestSrcDir, testPackage, testClassName); - final List fields = new ArrayList<>(); - final Set imports = new TreeSet<>(); - imports.add("com.hedera.pbj.runtime.io.buffer"); - imports.add(lookupHelper.getPackageForMessage(FileType.MODEL, msgDef)); - for(var item: msgDef.messageBody().messageElement()) { - if (item.messageDef() != null) { // process sub messages - generate(item.messageDef(), destinationSrcDir, destinationTestSrcDir, lookupHelper); - } else if (item.oneof() != null) { // process one ofs - final var field = new OneOfField(item.oneof(), modelClassName, lookupHelper); - fields.add(field); - field.addAllNeededImports(imports, true, false, true); - for(var subField : field.fields()) { - subField.addAllNeededImports(imports, true, false, true); - } - } else if (item.mapField() != null) { // process map fields - throw new IllegalStateException("Encountered a mapField that was not handled in "+ modelClassName); - } else if (item.field() != null && item.field().fieldName() != null) { - final var field = new SingleField(item.field(), lookupHelper); - fields.add(field); - if (field.type() == Field.FieldType.MESSAGE || field.type() == Field.FieldType.ENUM) { - field.addAllNeededImports(imports, true, false, true); - } - } else if (item.reserved() == null && item.optionStatement() == null) { - System.err.println("TestGenerator Warning - Unknown element: "+item+" -- "+item.getText()); - } - } - imports.add("java.util"); - try (FileWriter javaWriter = new FileWriter(javaFile)) { - javaWriter.write(""" - package %s; - - import com.google.protobuf.util.JsonFormat; - import com.google.protobuf.CodedOutputStream; - import com.hedera.pbj.runtime.io.buffer.BufferedData; - import com.hedera.pbj.runtime.JsonTools; - import org.junit.jupiter.params.ParameterizedTest; - import org.junit.jupiter.params.provider.MethodSource; - import com.hedera.pbj.runtime.test.*; - import java.util.stream.IntStream; - import java.util.stream.Stream; - import java.nio.ByteBuffer; - import java.nio.CharBuffer; - %s - - import com.google.protobuf.CodedInputStream; - import com.google.protobuf.WireFormat; - import java.io.IOException; - import java.nio.charset.StandardCharsets; - - import static com.hedera.pbj.runtime.ProtoTestTools.*; - import static org.junit.jupiter.api.Assertions.*; - - /** - * Unit Test for %s model object. Generate based on protobuf schema. - */ - public final class %s { - %s - %s - } - """.formatted( - testPackage, - imports.isEmpty() ? "" : imports.stream() - .filter(input -> !input.equals(testPackage)) - .collect(Collectors.joining(".*;\nimport ","\nimport ",".*;\n")), - modelClassName, - testClassName, - generateTestMethod(modelClassName, protoCJavaFullQualifiedClass) - .replaceAll("\n","\n"+ Common.FIELD_INDENT), - generateModelTestArgumentsMethod(modelClassName, fields) - .replaceAll("\n","\n"+ Common.FIELD_INDENT) - ) - ); - } - } + /** + * {@inheritDoc} + */ + public void generate( + Protobuf3Parser.MessageDefContext msgDef, + File destinationSrcDir, + File destinationTestSrcDir, + final ContextualLookupHelper lookupHelper) + throws IOException { + final var modelClassName = + lookupHelper.getUnqualifiedClassForMessage(FileType.MODEL, msgDef); + final var testClassName = lookupHelper.getUnqualifiedClassForMessage(FileType.TEST, msgDef); + final String testPackage = lookupHelper.getPackageForMessage(FileType.TEST, msgDef); + final String protoCJavaFullQualifiedClass = + lookupHelper.getFullyQualifiedMessageClassname(FileType.PROTOC, msgDef); + final File javaFile = Common.getJavaFile(destinationTestSrcDir, testPackage, testClassName); + final List fields = new ArrayList<>(); + final Set imports = new TreeSet<>(); + imports.add("com.hedera.pbj.runtime.io.buffer"); + imports.add(lookupHelper.getPackageForMessage(FileType.MODEL, msgDef)); + for (var item : msgDef.messageBody().messageElement()) { + if (item.messageDef() != null) { // process sub messages + generate(item.messageDef(), destinationSrcDir, destinationTestSrcDir, lookupHelper); + } else if (item.oneof() != null) { // process one ofs + final var field = new OneOfField(item.oneof(), modelClassName, lookupHelper); + fields.add(field); + field.addAllNeededImports(imports, true, false, true); + for (var subField : field.fields()) { + subField.addAllNeededImports(imports, true, false, true); + } + } else if (item.mapField() != null) { // process map fields + throw new IllegalStateException( + "Encountered a mapField that was not handled in " + modelClassName); + } else if (item.field() != null && item.field().fieldName() != null) { + final var field = new SingleField(item.field(), lookupHelper); + fields.add(field); + if (field.type() == Field.FieldType.MESSAGE + || field.type() == Field.FieldType.ENUM) { + field.addAllNeededImports(imports, true, false, true); + } + } else if (item.reserved() == null && item.optionStatement() == null) { + System.err.println( + "TestGenerator Warning - Unknown element: " + + item + + " -- " + + item.getText()); + } + } + imports.add("java.util"); + try (FileWriter javaWriter = new FileWriter(javaFile)) { + javaWriter.write( + """ + package %s; + + import com.google.protobuf.util.JsonFormat; + import com.google.protobuf.CodedOutputStream; + import com.hedera.pbj.runtime.io.buffer.BufferedData; + import com.hedera.pbj.runtime.JsonTools; + import org.junit.jupiter.params.ParameterizedTest; + import org.junit.jupiter.params.provider.MethodSource; + import com.hedera.pbj.runtime.test.*; + import java.util.stream.IntStream; + import java.util.stream.Stream; + import java.nio.ByteBuffer; + import java.nio.CharBuffer; + %s + + import com.google.protobuf.CodedInputStream; + import com.google.protobuf.WireFormat; + import java.io.IOException; + import java.nio.charset.StandardCharsets; + + import static com.hedera.pbj.runtime.ProtoTestTools.*; + import static org.junit.jupiter.api.Assertions.*; + + /** + * Unit Test for %s model object. Generate based on protobuf schema. + */ + public final class %s { + %s + %s + } + """ + .formatted( + testPackage, + imports.isEmpty() + ? "" + : imports.stream() + .filter(input -> !input.equals(testPackage)) + .collect( + Collectors.joining( + ".*;\nimport ", + "\nimport ", + ".*;\n")), + modelClassName, + testClassName, + generateTestMethod(modelClassName, protoCJavaFullQualifiedClass) + .replaceAll("\n", "\n" + Common.FIELD_INDENT), + generateModelTestArgumentsMethod(modelClassName, fields) + .replaceAll("\n", "\n" + Common.FIELD_INDENT))); + } + } + + private static String generateModelTestArgumentsMethod( + final String modelClassName, final List fields) { + return """ + /** + * List of all valid arguments for testing, built as a static list, so we can reuse it. + */ + public static final List<%s> ARGUMENTS; + + static { + %s + // work out the longest of all the lists of args as that is how many test cases we need + final int maxValues = IntStream.of( + %s + ).max().getAsInt(); + // create new stream of model objects using lists above as constructor params + ARGUMENTS = IntStream.range(0,maxValues) + .mapToObj(i -> new %s( + %s + )).toList(); + } + + /** + * Create a stream of all test permutations of the %s class we are testing. This is reused by other tests + * as well that have model objects with fields of this type. + * + * @return stream of model objects for all test cases + */ + public static Stream> createModelTestArguments() { + return ARGUMENTS.stream().map(NoToStringWrapper::new); + } + """ + .formatted( + modelClassName, + fields.stream() + .map( + f -> + "final var " + + f.nameCamelFirstLower() + + "List = " + + generateTestData( + modelClassName, + f, + f.optionalValueType(), + f.repeated()) + + ";") + .collect(Collectors.joining("\n" + Common.FIELD_INDENT)), + fields.stream() + .map(f -> f.nameCamelFirstLower() + "List.size()") + .collect( + Collectors.joining( + ",\n" + Common.FIELD_INDENT + Common.FIELD_INDENT)), + modelClassName, + fields.stream() + .map( + field -> + "%sList.get(Math.min(i, %sList.size()-1))" + .formatted( + field.nameCamelFirstLower(), + field.nameCamelFirstLower())) + .collect( + Collectors.joining( + ",\n" + + Common.FIELD_INDENT + + Common.FIELD_INDENT + + Common.FIELD_INDENT + + Common.FIELD_INDENT)), + modelClassName, + modelClassName); + } + + private static String generateTestData( + String modelClassName, Field field, boolean optional, boolean repeated) { + if (optional) { + + Field.FieldType convertedFieldType = getOptionalConvertedFieldType(field); + return """ + addNull(%s)""" + .formatted( + getOptionsForFieldType(convertedFieldType, convertedFieldType.javaType)) + .replaceAll("\n", "\n" + Common.FIELD_INDENT + Common.FIELD_INDENT); + } else if (repeated) { + final String optionsList = + generateTestData(modelClassName, field, field.optionalValueType(), false); + return """ + generateListArguments(%s)""" + .formatted(optionsList) + .replaceAll("\n", "\n" + Common.FIELD_INDENT + Common.FIELD_INDENT); + } else if (field instanceof final OneOfField oneOf) { + final List options = new ArrayList<>(); + for (var subField : oneOf.fields()) { + if (subField instanceof SingleField) { + final String enumValueName = Common.camelToUpperSnake(subField.name()); + // special cases to break cyclic dependencies + if (!("THRESHOLD_KEY".equals(enumValueName) + || "KEY_LIST".equals(enumValueName) + || "THRESHOLD_SIGNATURE".equals(enumValueName) + || "SIGNATURE_LIST".equals(enumValueName))) { + final String listStr; + if (subField.optionalValueType()) { + Field.FieldType convertedSubFieldType = + getOptionalConvertedFieldType(subField); + listStr = + getOptionsForFieldType( + convertedSubFieldType, convertedSubFieldType.javaType); + } else { + listStr = + getOptionsForFieldType( + subField.type(), + ((SingleField) subField).javaFieldTypeForTest()); + } + options.add( + listStr + + """ + .stream() + .map(value -> new OneOf<>(%sOneOfType.%s, value)) + .toList()""" + .formatted( + modelClassName + + "." + + field.nameCamelFirstUpper(), + enumValueName) + .replaceAll( + "\n", + "\n" + + Common.FIELD_INDENT + + Common.FIELD_INDENT)); + } + } else { + System.err.println( + "Did not expect a OneOfField in a OneOfField. In " + + "modelClassName=" + + modelClassName + + " field=" + + field + + " subField=" + + subField); + } + } + return """ + Stream.of( + List.of(new OneOf<>(%sOneOfType.UNSET, null)), + %s + ).flatMap(List::stream).toList()""" + .formatted( + modelClassName + "." + field.nameCamelFirstUpper(), + options.stream() + .collect(Collectors.joining(",\n" + Common.FIELD_INDENT))) + .replaceAll("\n", "\n" + Common.FIELD_INDENT + Common.FIELD_INDENT); + } else { + return getOptionsForFieldType( + field.type(), ((SingleField) field).javaFieldTypeForTest()); + } + } + + private static Field.FieldType getOptionalConvertedFieldType(final Field field) { + return switch (field.messageType()) { + case "StringValue" -> Field.FieldType.STRING; + case "Int32Value" -> Field.FieldType.INT32; + case "UInt32Value" -> Field.FieldType.UINT32; + case "Int64Value" -> Field.FieldType.INT64; + case "UInt64Value" -> Field.FieldType.UINT64; + case "FloatValue" -> Field.FieldType.FLOAT; + case "DoubleValue" -> Field.FieldType.DOUBLE; + case "BoolValue" -> Field.FieldType.BOOL; + case "BytesValue" -> Field.FieldType.BYTES; + default -> Field.FieldType.MESSAGE; + }; + } + + private static String getOptionsForFieldType(Field.FieldType fieldType, String javaFieldType) { + return switch (fieldType) { + case INT32, SINT32, SFIXED32 -> "INTEGER_TESTS_LIST"; + case UINT32, FIXED32 -> "UNSIGNED_INTEGER_TESTS_LIST"; + case INT64, SINT64, SFIXED64 -> "LONG_TESTS_LIST"; + case UINT64, FIXED64 -> "UNSIGNED_LONG_TESTS_LIST"; + case FLOAT -> "FLOAT_TESTS_LIST"; + case DOUBLE -> "DOUBLE_TESTS_LIST"; + case BOOL -> "BOOLEAN_TESTS_LIST"; + case STRING -> "STRING_TESTS_LIST"; + case BYTES -> "BYTES_TESTS_LIST"; + case ENUM -> "Arrays.asList(" + javaFieldType + ".values())"; + case ONE_OF -> throw new RuntimeException( + "Should never happen, should have been caught in generateTestData()"); + case MESSAGE -> javaFieldType + + FileAndPackageNamesConfig.TEST_JAVA_FILE_SUFFIX + + ".ARGUMENTS"; + }; + } + + /** + * Generate code for test method. The test method is designed to reuse thread local buffers. + * This is very important for performance as without this the tests quickly overwhelm the + * garbage collector. + * + * @param modelClassName The class name of the model object we are creating a test for + * @param protoCJavaFullQualifiedClass The qualified class name of the protoc generated object + * class + * @return Code for test method + */ + private static String generateTestMethod( + final String modelClassName, final String protoCJavaFullQualifiedClass) { + return """ + @ParameterizedTest + @MethodSource("createModelTestArguments") + public void test$modelClassNameAgainstProtoC(final NoToStringWrapper<$modelClassName> modelObjWrapper) throws Exception { + final $modelClassName modelObj = modelObjWrapper.getValue(); + // get reusable thread buffers + final var dataBuffer = getThreadLocalDataBuffer(); + final var dataBuffer2 = getThreadLocalDataBuffer2(); + final var byteBuffer = getThreadLocalByteBuffer(); + final var charBuffer = getThreadLocalCharBuffer(); + final var charBuffer2 = getThreadLocalCharBuffer2(); + + // model to bytes with PBJ + $modelClassName.PROTOBUF.write(modelObj, dataBuffer); + // clamp limit to bytes written + dataBuffer.limit(dataBuffer.position()); + + // copy bytes to ByteBuffer + dataBuffer.resetPosition(); + final int protoBufByteCount = (int)dataBuffer.remaining(); + dataBuffer.readBytes(byteBuffer); + byteBuffer.flip(); + + // read proto bytes with ProtoC to make sure it is readable and no parse exceptions are thrown + final $protocModelClass protoCModelObj = $protocModelClass.parseFrom(byteBuffer); + + // read proto bytes with PBJ parser + dataBuffer.resetPosition(); + final $modelClassName modelObj2 = $modelClassName.PROTOBUF.parse(dataBuffer); - private static String generateModelTestArgumentsMethod(final String modelClassName, final List fields) { - return """ - /** - * List of all valid arguments for testing, built as a static list, so we can reuse it. - */ - public static final List<%s> ARGUMENTS; - - static { - %s - // work out the longest of all the lists of args as that is how many test cases we need - final int maxValues = IntStream.of( - %s - ).max().getAsInt(); - // create new stream of model objects using lists above as constructor params - ARGUMENTS = IntStream.range(0,maxValues) - .mapToObj(i -> new %s( - %s - )).toList(); - } - - /** - * Create a stream of all test permutations of the %s class we are testing. This is reused by other tests - * as well that have model objects with fields of this type. - * - * @return stream of model objects for all test cases - */ - public static Stream> createModelTestArguments() { - return ARGUMENTS.stream().map(NoToStringWrapper::new); - } - """.formatted( - modelClassName, - fields.stream() - .map(f -> "final var "+f.nameCamelFirstLower()+"List = "+generateTestData(modelClassName, f, f.optionalValueType(), f.repeated())+";") - .collect(Collectors.joining("\n"+ Common.FIELD_INDENT)), - fields.stream() - .map(f -> f.nameCamelFirstLower()+"List.size()") - .collect(Collectors.joining(",\n"+ Common.FIELD_INDENT+ Common.FIELD_INDENT)), - modelClassName, - fields.stream().map(field -> "%sList.get(Math.min(i, %sList.size()-1))".formatted( - field.nameCamelFirstLower(), - field.nameCamelFirstLower() - )) - .collect(Collectors.joining(",\n"+ Common.FIELD_INDENT+ Common.FIELD_INDENT+ Common.FIELD_INDENT+ Common.FIELD_INDENT)), - modelClassName, - modelClassName - ); - } + // check the read back object is equal to written original one + //assertEquals(modelObj.toString(), modelObj2.toString()); + assertEquals(modelObj, modelObj2); - private static String generateTestData(String modelClassName, Field field, boolean optional, boolean repeated) { - if (optional) { + // model to bytes with ProtoC writer + byteBuffer.clear(); + final CodedOutputStream codedOutput = CodedOutputStream.newInstance(byteBuffer); + protoCModelObj.writeTo(codedOutput); + codedOutput.flush(); + byteBuffer.flip(); + // copy to a data buffer + dataBuffer2.writeBytes(byteBuffer); + dataBuffer2.flip(); - Field.FieldType convertedFieldType = getOptionalConvertedFieldType(field); - return """ - addNull(%s)""" - .formatted(getOptionsForFieldType(convertedFieldType, convertedFieldType.javaType)) - .replaceAll("\n","\n"+ Common.FIELD_INDENT+ Common.FIELD_INDENT); - } else if (repeated) { - final String optionsList = generateTestData(modelClassName, field, field.optionalValueType(), false); - return """ - generateListArguments(%s)""".formatted(optionsList) - .replaceAll("\n","\n"+ Common.FIELD_INDENT+ Common.FIELD_INDENT); - } else if(field instanceof final OneOfField oneOf) { - final List options = new ArrayList<>(); - for (var subField: oneOf.fields()) { - if(subField instanceof SingleField) { - final String enumValueName = Common.camelToUpperSnake(subField.name()); - // special cases to break cyclic dependencies - if (!("THRESHOLD_KEY".equals(enumValueName) || "KEY_LIST".equals(enumValueName) - || "THRESHOLD_SIGNATURE".equals(enumValueName)|| "SIGNATURE_LIST".equals(enumValueName))) { - final String listStr; - if (subField.optionalValueType()) { - Field.FieldType convertedSubFieldType = getOptionalConvertedFieldType(subField); - listStr = getOptionsForFieldType(convertedSubFieldType, convertedSubFieldType.javaType); - } else { - listStr = getOptionsForFieldType(subField.type(), ((SingleField) subField).javaFieldTypeForTest()); - } - options.add(listStr + """ - .stream() - .map(value -> new OneOf<>(%sOneOfType.%s, value)) - .toList()""".formatted( - modelClassName + "." + field.nameCamelFirstUpper(), - enumValueName - ).replaceAll("\n", "\n" + Common.FIELD_INDENT + Common.FIELD_INDENT) - ); - } - } else { - System.err.println("Did not expect a OneOfField in a OneOfField. In "+ - "modelClassName="+modelClassName+" field="+field+" subField="+subField); - } - } - return """ - Stream.of( - List.of(new OneOf<>(%sOneOfType.UNSET, null)), - %s - ).flatMap(List::stream).toList()""".formatted( - modelClassName+"."+field.nameCamelFirstUpper(), - options.stream().collect(Collectors.joining(",\n"+ Common.FIELD_INDENT)) - ).replaceAll("\n","\n"+ Common.FIELD_INDENT+ Common.FIELD_INDENT); - } else { - return getOptionsForFieldType(field.type(), ((SingleField)field).javaFieldTypeForTest()); - } - } + // compare written bytes + assertEquals(dataBuffer, dataBuffer2); - private static Field.FieldType getOptionalConvertedFieldType(final Field field) { - return switch (field.messageType()) { - case "StringValue" -> Field.FieldType.STRING; - case "Int32Value" -> Field.FieldType.INT32; - case "UInt32Value" -> Field.FieldType.UINT32; - case "Int64Value" -> Field.FieldType.INT64; - case "UInt64Value" -> Field.FieldType.UINT64; - case "FloatValue" -> Field.FieldType.FLOAT; - case "DoubleValue" -> Field.FieldType.DOUBLE; - case "BoolValue" -> Field.FieldType.BOOL; - case "BytesValue" -> Field.FieldType.BYTES; - default -> Field.FieldType.MESSAGE; - }; - } + // parse those bytes again with PBJ + dataBuffer2.resetPosition(); + final $modelClassName modelObj3 = $modelClassName.PROTOBUF.parse(dataBuffer2); + assertEquals(modelObj, modelObj3); - private static String getOptionsForFieldType(Field.FieldType fieldType, String javaFieldType) { - return switch (fieldType) { - case INT32, SINT32, SFIXED32 -> "INTEGER_TESTS_LIST"; - case UINT32, FIXED32 -> "UNSIGNED_INTEGER_TESTS_LIST"; - case INT64, SINT64, SFIXED64 -> "LONG_TESTS_LIST"; - case UINT64, FIXED64 -> "UNSIGNED_LONG_TESTS_LIST"; - case FLOAT -> "FLOAT_TESTS_LIST"; - case DOUBLE -> "DOUBLE_TESTS_LIST"; - case BOOL -> "BOOLEAN_TESTS_LIST"; - case STRING -> "STRING_TESTS_LIST"; - case BYTES -> "BYTES_TESTS_LIST"; - case ENUM -> "Arrays.asList(" + javaFieldType + ".values())"; - case ONE_OF -> throw new RuntimeException("Should never happen, should have been caught in generateTestData()"); - case MESSAGE -> javaFieldType + FileAndPackageNamesConfig.TEST_JAVA_FILE_SUFFIX + ".ARGUMENTS"; - }; - } + // check measure methods + dataBuffer2.resetPosition(); + assertEquals(protoBufByteCount, $modelClassName.PROTOBUF.measure(dataBuffer2)); + assertEquals(protoBufByteCount, $modelClassName.PROTOBUF.measureRecord(modelObj)); - /** - * Generate code for test method. The test method is designed to reuse thread local buffers. This is - * very important for performance as without this the tests quickly overwhelm the garbage collector. - * - * @param modelClassName The class name of the model object we are creating a test for - * @param protoCJavaFullQualifiedClass The qualified class name of the protoc generated object class - * @return Code for test method - */ - private static String generateTestMethod(final String modelClassName, final String protoCJavaFullQualifiedClass) { - return """ - @ParameterizedTest - @MethodSource("createModelTestArguments") - public void test$modelClassNameAgainstProtoC(final NoToStringWrapper<$modelClassName> modelObjWrapper) throws Exception { - final $modelClassName modelObj = modelObjWrapper.getValue(); - // get reusable thread buffers - final var dataBuffer = getThreadLocalDataBuffer(); - final var dataBuffer2 = getThreadLocalDataBuffer2(); - final var byteBuffer = getThreadLocalByteBuffer(); - final var charBuffer = getThreadLocalCharBuffer(); - final var charBuffer2 = getThreadLocalCharBuffer2(); - - // model to bytes with PBJ - $modelClassName.PROTOBUF.write(modelObj, dataBuffer); - // clamp limit to bytes written - dataBuffer.limit(dataBuffer.position()); - - // copy bytes to ByteBuffer - dataBuffer.resetPosition(); - final int protoBufByteCount = (int)dataBuffer.remaining(); - dataBuffer.readBytes(byteBuffer); - byteBuffer.flip(); - - // read proto bytes with ProtoC to make sure it is readable and no parse exceptions are thrown - final $protocModelClass protoCModelObj = $protocModelClass.parseFrom(byteBuffer); - - // read proto bytes with PBJ parser - dataBuffer.resetPosition(); - final $modelClassName modelObj2 = $modelClassName.PROTOBUF.parse(dataBuffer); - - // check the read back object is equal to written original one - //assertEquals(modelObj.toString(), modelObj2.toString()); - assertEquals(modelObj, modelObj2); - - // model to bytes with ProtoC writer - byteBuffer.clear(); - final CodedOutputStream codedOutput = CodedOutputStream.newInstance(byteBuffer); - protoCModelObj.writeTo(codedOutput); - codedOutput.flush(); - byteBuffer.flip(); - // copy to a data buffer - dataBuffer2.writeBytes(byteBuffer); - dataBuffer2.flip(); - - // compare written bytes - assertEquals(dataBuffer, dataBuffer2); + // check fast equals + dataBuffer2.resetPosition(); + assertTrue($modelClassName.PROTOBUF.fastEquals(modelObj, dataBuffer2)); - // parse those bytes again with PBJ - dataBuffer2.resetPosition(); - final $modelClassName modelObj3 = $modelClassName.PROTOBUF.parse(dataBuffer2); - assertEquals(modelObj, modelObj3); + // Test toBytes() + Bytes bytes = $modelClassName.PROTOBUF.toBytes(modelObj); + final var dataBuffer3 = getThreadLocalDataBuffer(); + bytes.toReadableSequentialData().readBytes(dataBuffer3); + byte[] readBytes = new byte[(int)dataBuffer3.length()]; + dataBuffer3.getBytes(0, readBytes); + assertArrayEquals(bytes.toByteArray(), readBytes); - // check measure methods - dataBuffer2.resetPosition(); - assertEquals(protoBufByteCount, $modelClassName.PROTOBUF.measure(dataBuffer2)); - assertEquals(protoBufByteCount, $modelClassName.PROTOBUF.measureRecord(modelObj)); - - // check fast equals - dataBuffer2.resetPosition(); - assertTrue($modelClassName.PROTOBUF.fastEquals(modelObj, dataBuffer2)); + // Test JSON Writing + final CharBufferToWritableSequentialData charBufferToWritableSequentialData = new CharBufferToWritableSequentialData(charBuffer); + $modelClassName.JSON.write(modelObj,charBufferToWritableSequentialData); + charBuffer.flip(); + JsonFormat.printer().appendTo(protoCModelObj, charBuffer2); + charBuffer2.flip(); + assertEquals(charBuffer2, charBuffer); - // Test toBytes() - Bytes bytes = $modelClassName.PROTOBUF.toBytes(modelObj); - final var dataBuffer3 = getThreadLocalDataBuffer(); - bytes.toReadableSequentialData().readBytes(dataBuffer3); - byte[] readBytes = new byte[(int)dataBuffer3.length()]; - dataBuffer3.getBytes(0, readBytes); - assertArrayEquals(bytes.toByteArray(), readBytes); + // Test JSON Reading + final $modelClassName jsonReadPbj = $modelClassName.JSON.parse(JsonTools.parseJson(charBuffer), false); + assertEquals(modelObj, jsonReadPbj); - // Test JSON Writing - final CharBufferToWritableSequentialData charBufferToWritableSequentialData = new CharBufferToWritableSequentialData(charBuffer); - $modelClassName.JSON.write(modelObj,charBufferToWritableSequentialData); - charBuffer.flip(); - JsonFormat.printer().appendTo(protoCModelObj, charBuffer2); - charBuffer2.flip(); - assertEquals(charBuffer2, charBuffer); - - // Test JSON Reading - final $modelClassName jsonReadPbj = $modelClassName.JSON.parse(JsonTools.parseJson(charBuffer), false); - assertEquals(modelObj, jsonReadPbj); - } - """ - .replace("$modelClassName",modelClassName) - .replace("$protocModelClass",protoCJavaFullQualifiedClass) - .replace("$modelClassName",modelClassName) - ; - } + // Very slow for now. Too much garbage generated to enable in general case. + // // Test hashCode and equals() + // Stream> objects = createModelTestArguments(); + // Object[] objArray = objects.toArray(); + // for (int i = 0; i < objArray.length; i++) { + // for (int j = i; j < objArray.length; j++) { + // if (objArray[i].hashCode() != objArray[i].hashCode()) { + // fail("Same object, different hash."); + // } + // if (objArray[j].hashCode() != objArray[j].hashCode()) { + // fail("Same object, different hash 1."); + // } + // if (objArray[i].hashCode() == objArray[j].hashCode()) { + // if (!objArray[i].equals(objArray[j])) { + // fail("equalsHash, different objects."); + // } + // } + // } + // } + // + // Map map = new HashMap<>(); + // map.put(objArray[0], objArray[0]); + // for (int i = 1; i < objArray.length; i++) { + // Object o = map.put(objArray[i], objArray[i]); + // if (o != null) { + // Object existing = map.get(objArray[i]); + // assertEquals(existing.hashCode(), objArray[i].hashCode()); + // assertEquals(existing, objArray[i]); + // } + // } + } + """ + .replace("$modelClassName", modelClassName) + .replace("$protocModelClass", protoCJavaFullQualifiedClass) + .replace("$modelClassName", modelClassName); + } } diff --git a/pbj-integration-tests/build.gradle.kts b/pbj-integration-tests/build.gradle.kts index e65cb4e9..0e0ab24e 100644 --- a/pbj-integration-tests/build.gradle.kts +++ b/pbj-integration-tests/build.gradle.kts @@ -81,6 +81,9 @@ jmh { // includes.add("AccountDetailsBench") // includes.add("JsonBench") // includes.add("VarIntBench") +// includes.add("HashBench") +// includes.add("EqualsHashCodeBench"); + jmhVersion.set("1.35") includeTests.set(true) // jvmArgsAppend.add("-XX:MaxInlineSize=100 -XX:MaxInlineLevel=20") diff --git a/pbj-integration-tests/src/jmh/java/com/hedera/pbj/intergration/jmh/EqualsHashCodeBench.java b/pbj-integration-tests/src/jmh/java/com/hedera/pbj/intergration/jmh/EqualsHashCodeBench.java new file mode 100644 index 00000000..66584db7 --- /dev/null +++ b/pbj-integration-tests/src/jmh/java/com/hedera/pbj/intergration/jmh/EqualsHashCodeBench.java @@ -0,0 +1,62 @@ +package com.hedera.pbj.intergration.jmh; + +import com.google.protobuf.CodedInputStream; +import com.google.protobuf.CodedOutputStream; +import com.hedera.pbj.intergration.test.TestHashFunctions; +import com.hedera.pbj.runtime.MalformedProtobufException; +import com.hedera.pbj.runtime.io.buffer.BufferedData; +import com.hedera.pbj.runtime.io.buffer.Bytes; +import com.hedera.pbj.runtime.io.stream.ReadableStreamingData; +import com.hedera.pbj.test.proto.pbj.Hasheval; +import com.hedera.pbj.test.proto.pbj.Suit; +import com.hedera.pbj.test.proto.pbj.TimestampTest; +import com.hedera.pbj.test.proto.pbj.tests.HashevalTest; +import org.openjdk.jmh.annotations.*; +import org.openjdk.jmh.infra.Blackhole; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.nio.*; +import java.util.Random; +import java.util.concurrent.TimeUnit; + +@SuppressWarnings("unused") +@State(Scope.Benchmark) +@Fork(1) +@Warmup(iterations = 4, time = 2) +@Measurement(iterations = 5, time = 2) +@OutputTimeUnit(TimeUnit.NANOSECONDS) +@BenchmarkMode(Mode.AverageTime) +public class EqualsHashCodeBench { + private TimestampTest testStamp; + private TimestampTest testStamp1; + + public EqualsHashCodeBench() { + testStamp = new TimestampTest(987L, 123); + testStamp1 = new TimestampTest(987L, 122); + } + + @Benchmark + @OperationsPerInvocation(1050) + public void benchHashCode(Blackhole blackhole) throws IOException { + for (int i = 0; i < 1050; i++) { + testStamp.hashCode(); + } + } + + @Benchmark + @OperationsPerInvocation(1050) + public void benchEquals(Blackhole blackhole) throws IOException { + for (int i = 0; i < 1050; i++) { + testStamp.equals(testStamp); + } + } + + @Benchmark + @OperationsPerInvocation(1050) + public void benchNotEquals(Blackhole blackhole) throws IOException { + for (int i = 0; i < 1050; i++) { + testStamp.equals(testStamp1); + } + } +} \ No newline at end of file diff --git a/pbj-integration-tests/src/jmh/java/com/hedera/pbj/intergration/jmh/HashBench.java b/pbj-integration-tests/src/jmh/java/com/hedera/pbj/intergration/jmh/HashBench.java new file mode 100644 index 00000000..beb3970a --- /dev/null +++ b/pbj-integration-tests/src/jmh/java/com/hedera/pbj/intergration/jmh/HashBench.java @@ -0,0 +1,56 @@ +package com.hedera.pbj.intergration.jmh; + +import com.google.protobuf.CodedInputStream; +import com.google.protobuf.CodedOutputStream; +import com.hedera.pbj.intergration.test.TestHashFunctions; +import com.hedera.pbj.runtime.MalformedProtobufException; +import com.hedera.pbj.runtime.io.buffer.BufferedData; +import com.hedera.pbj.runtime.io.buffer.Bytes; +import com.hedera.pbj.runtime.io.stream.ReadableStreamingData; +import com.hedera.pbj.test.proto.pbj.Hasheval; +import com.hedera.pbj.test.proto.pbj.Suit; +import com.hedera.pbj.test.proto.pbj.TimestampTest; +import com.hedera.pbj.test.proto.pbj.tests.HashevalTest; +import org.openjdk.jmh.annotations.*; +import org.openjdk.jmh.infra.Blackhole; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.nio.*; +import java.util.Random; +import java.util.concurrent.TimeUnit; + +@SuppressWarnings("unused") +@State(Scope.Benchmark) +@Fork(1) +@Warmup(iterations = 4, time = 2) +@Measurement(iterations = 5, time = 2) +@OutputTimeUnit(TimeUnit.NANOSECONDS) +@BenchmarkMode(Mode.AverageTime) +public class HashBench { + private Hasheval hasheval; + + public HashBench() { + TimestampTest tst = new TimestampTest(987L, 123); + hasheval = new Hasheval(1, -1, 2, 3, -2, + 123f, 7L, -7L, 123L, 234L, + -345L, 456.789D, true, Suit.ACES, tst, "FooBarKKKKHHHHOIOIOI", + Bytes.wrap(new byte[]{1, 2, 3, 4, 5, 6, 7, (byte)255})); + } + + @Benchmark + @OperationsPerInvocation(1050) + public void hashBenchSHA256(Blackhole blackhole) throws IOException { + for (int i = 0; i < 1050; i++) { + TestHashFunctions.hash1(hasheval); + } + } + + @Benchmark + @OperationsPerInvocation(1050) + public void hashBenchFieldWise(Blackhole blackhole) throws IOException { + for (int i = 0; i < 1050; i++) { + TestHashFunctions.hash2(hasheval); + } + } +} \ No newline at end of file diff --git a/pbj-integration-tests/src/main/proto/hasheval.proto b/pbj-integration-tests/src/main/proto/hasheval.proto new file mode 100644 index 00000000..0d9b01f7 --- /dev/null +++ b/pbj-integration-tests/src/main/proto/hasheval.proto @@ -0,0 +1,34 @@ +syntax = "proto3"; + +package proto; + +option java_package = "com.hedera.pbj.test.proto.java"; +option java_multiple_files = true; +// <<>> This comment is special code for setting PBJ Compiler java package + +import "timestampTest.proto"; +import "google/protobuf/wrappers.proto"; +import "everything.proto"; + +/** + * Example protobuf containing examples of all types + */ +message Hasheval { + int32 int32Number = 1; + sint32 sint32Number = 2; + uint32 uint32Number = 3; + fixed32 fixed32Number = 4; + sfixed32 sfixed32Number = 5; + float floatNumber = 6; + int64 int64Number = 7; + sint64 sint64Number = 8; + uint64 uint64Number = 9; + fixed64 fixed64Number = 10; + sfixed64 sfixed64Number = 11; + double doubleNumber = 12; + bool booleanField = 13; + Suit enumSuit = 14; + TimestampTest subObject = 15; + string text = 16; + bytes bytesField = 17; +} diff --git a/pbj-integration-tests/src/main/proto/timestampTest2.proto b/pbj-integration-tests/src/main/proto/timestampTest2.proto new file mode 100644 index 00000000..1eb2c3da --- /dev/null +++ b/pbj-integration-tests/src/main/proto/timestampTest2.proto @@ -0,0 +1,62 @@ +syntax = "proto3"; + +package proto; + +/** Test issue 87 */ +/** Test issue 87 */ + +/*- + * ‌ + * Hedera Network Services Protobuf + * ​ + * Copyright (C) 2018 - 2021 Hedera Hashgraph, LLC + * ​ + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ‍ + */ + +option java_package = "com.hedera.pbj.test.proto.java"; +option java_multiple_files = true; +// <<>> This comment is special code for setting PBJ Compiler java package + +/** + * An exact date and time. This is the same data structure as the protobuf Timestamp.proto (see the + * comments in https://github.com/google/protobuf/blob/master/src/google/protobuf/timestamp.proto) + */ +message TimestampTest2 { + /** + * Number of complete seconds since the start of the epoch + */ + int64 seconds = 1; + + /** + * Number of nanoseconds since the start of the last second + */ + int32 nanos = 2; + + /** + * Number of picoseconds since the start of the last nanosecond + */ + int32 pico = 3; +} + +/** + * An exact date and time, with a resolution of one second (no nanoseconds). + */ +message TimestampTestSeconds2 { + /** + * Number of complete seconds since the start of the epoch + */ + int64 seconds = 1; +} + diff --git a/pbj-integration-tests/src/test/java/com/hedera/pbj/intergration/test/HashEqualsTest.java b/pbj-integration-tests/src/test/java/com/hedera/pbj/intergration/test/HashEqualsTest.java new file mode 100644 index 00000000..71bf02db --- /dev/null +++ b/pbj-integration-tests/src/test/java/com/hedera/pbj/intergration/test/HashEqualsTest.java @@ -0,0 +1,90 @@ +package com.hedera.pbj.intergration.test; + +import org.junit.jupiter.api.Test; + +import com.hedera.pbj.test.proto.pbj.TimestampTest; +import com.hedera.pbj.test.proto.pbj.TimestampTest2; + +import static org.junit.jupiter.api.Assertions.*; + +class HasjhEqualsTest { + @Test + void differentObjectsWithDefaulEquals() { + TimestampTest tst = new TimestampTest(1, 2); + TimestampTest2 tst2 = new TimestampTest2(1, 2, 0); + + assertFalse(tst.equals(tst2)); + } + + @Test + void sameObjectsWithNoDefaulEquals() { + TimestampTest tst = new TimestampTest(3, 4); + TimestampTest tst1 = new TimestampTest(3, 4); + + assertEquals(tst, tst1); + } + + @Test + void sameObjectsWithDefaulNoEquals() { + TimestampTest tst = new TimestampTest(3, 4); + TimestampTest tst1 = new TimestampTest(3, 5); + + assertNotEquals(tst, tst1); + } + + @Test + void sameObjectsWithDefaulEquals() { + TimestampTest tst = new TimestampTest(0, 0); + TimestampTest tst1 = new TimestampTest(0, 0); + + assertEquals(tst, tst1); + } + + @Test + void differentObjectsWithDefaulHashCode() { + TimestampTest tst = new TimestampTest(0, 0); + TimestampTest2 tst2 = new TimestampTest2(0, 0, 0); + + assertEquals(tst.hashCode(), tst2.hashCode()); + } + + @Test + void differentObjectsWithNoDefaulHashCode() { + TimestampTest tst = new TimestampTest(1, 0); + TimestampTest2 tst2 = new TimestampTest2(1, 0, 0); + + assertEquals(tst.hashCode(), tst2.hashCode()); + } + + @Test + void differentObjectsWithNoDefaulHashCode1() { + TimestampTest tst = new TimestampTest(0, 0); + TimestampTest2 tst2 = new TimestampTest2(0, 0, 3); + + assertNotEquals(tst.hashCode(), tst2.hashCode()); + } + + @Test + void differentObjectsWithNoDefaulHashCode2() { + TimestampTest2 tst = new TimestampTest2(0, 0, 0); + TimestampTest2 tst2 = new TimestampTest2(0, 0, 0); + + assertEquals(tst.hashCode(), tst2.hashCode()); + } + + @Test + void differentObjectsWithNoDefaulHashCode3() { + TimestampTest2 tst = new TimestampTest2(1, 2, 3); + TimestampTest2 tst2 = new TimestampTest2(1, 2, 3); + + assertEquals(tst.hashCode(), tst2.hashCode()); + } + + @Test + void differentObjectsWithNoDefaulHashCode4() { + TimestampTest2 tst = new TimestampTest2(1, 4, 3); + TimestampTest2 tst2 = new TimestampTest2(1, 2, 3); + + assertNotEquals(tst.hashCode(), tst2.hashCode()); + } +} \ No newline at end of file diff --git a/pbj-integration-tests/src/test/java/com/hedera/pbj/intergration/test/TestHashFunctions.java b/pbj-integration-tests/src/test/java/com/hedera/pbj/intergration/test/TestHashFunctions.java new file mode 100644 index 00000000..3d81b383 --- /dev/null +++ b/pbj-integration-tests/src/test/java/com/hedera/pbj/intergration/test/TestHashFunctions.java @@ -0,0 +1,116 @@ +package com.hedera.pbj.intergration.test; + +import com.google.protobuf.CodedOutputStream; +import com.hedera.hapi.node.base.Timestamp; +import com.hedera.pbj.runtime.io.buffer.BufferedData; +import com.hedera.pbj.runtime.test.NoToStringWrapper; +import com.hedera.pbj.test.proto.pbj.Hasheval; +import com.hedera.pbj.test.proto.pbj.TimestampTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import java.nio.ByteBuffer; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; +import java.util.List; +import java.util.stream.IntStream; +import java.util.stream.Stream; +import static com.hedera.pbj.runtime.ProtoTestTools.INTEGER_TESTS_LIST; +import static com.hedera.pbj.runtime.ProtoTestTools.LONG_TESTS_LIST; +import static com.hedera.pbj.runtime.ProtoTestTools.getThreadLocalByteBuffer; +import static com.hedera.pbj.runtime.ProtoTestTools.getThreadLocalDataBuffer; +import static com.hedera.pbj.runtime.ProtoTestTools.getThreadLocalDataBuffer2; +import static org.junit.jupiter.api.Assertions.assertEquals; + +/** + * Unit Test for TimestampTest model object. Generate based on protobuf schema. + */ +public final class TestHashFunctions { + public static int hash1(Hasheval hashEval) { + try { + byte[] hash = MessageDigest.getInstance("SHA-256").digest( + Hasheval.PROTOBUF.toBytes(hashEval).toByteArray()); + int res = hash[0] << 24 | hash[1] << 16 | hash[2] << 8 | hash[3]; + return processForBetterDistribution(res); + } catch (NoSuchAlgorithmException e) { + throw new RuntimeException(e); + } + } + + public static int hash2(Hasheval hashEval) { + if (hashEval == null) return 0; + + int result = 1; + if (hashEval.int32Number() != Hasheval.DEFAULT.int32Number()) { + result = 31 * result + Integer.hashCode(hashEval.int32Number()); + } + if (hashEval.sint32Number() != Hasheval.DEFAULT.sint32Number()) { + result = 31 * result + Integer.hashCode(hashEval.sint32Number()); + } + if (hashEval.uint32Number() != Hasheval.DEFAULT.uint32Number()) { + result = 31 * result + Integer.hashCode(hashEval.uint32Number()); + } + if (hashEval.fixed32Number() != Hasheval.DEFAULT.fixed32Number()) { + result = 31 * result + Integer.hashCode(hashEval.fixed32Number()); + } + if (hashEval.sfixed32Number() != Hasheval.DEFAULT.sfixed32Number()) { + result = 31 * result + Integer.hashCode(hashEval.sfixed32Number()); + } + if (hashEval.floatNumber() != Hasheval.DEFAULT.floatNumber()) { + result = 31 * result + Float.hashCode(hashEval.floatNumber()); + } + if (hashEval.int64Number() != Hasheval.DEFAULT.int64Number()) { + result = 31 * result + Long.hashCode(hashEval.int64Number()); + } + if (hashEval.sint64Number() != Hasheval.DEFAULT.sint64Number()) { + result = 31 * result + Long.hashCode(hashEval.sint64Number()); + } + if (hashEval.uint64Number() != Hasheval.DEFAULT.uint64Number()) { + result = 31 * result + Long.hashCode(hashEval.uint64Number()); + } + if (hashEval.fixed64Number() != Hasheval.DEFAULT.fixed64Number()) { + result = 31 * result + Long.hashCode(hashEval.fixed64Number()); + } + if (hashEval.sfixed64Number() != Hasheval.DEFAULT.sfixed64Number()) { + result = 31 * result + Long.hashCode(hashEval.sfixed64Number()); + } + if (hashEval.doubleNumber() != Hasheval.DEFAULT.doubleNumber()) { + result = 31 * result + Double.hashCode(hashEval.doubleNumber()); + } + if (hashEval.booleanField() != Hasheval.DEFAULT.booleanField()) { + result = 31 * result + Boolean.hashCode(hashEval.booleanField()); + } + if (hashEval.enumSuit() != Hasheval.DEFAULT.enumSuit()) { + result = 31 * result + hashEval.enumSuit().hashCode(); + } + if (hashEval.subObject() != Hasheval.DEFAULT.subObject()) { + TimestampTest sub = hashEval.subObject(); + if (sub.nanos() != sub.DEFAULT.nanos()) { + result = 31 * result + Integer.hashCode(sub.nanos()); + } + if (sub.seconds() != sub.DEFAULT.seconds()) { + result = 31 * result + Long.hashCode(sub.seconds()); + } + } + if (hashEval.text() != Hasheval.DEFAULT.text()) { + result = 31 * result + hashEval.text().hashCode(); + } + if (hashEval.bytesField() != Hasheval.DEFAULT.bytesField()) { + result = 31 * result + (hashEval.bytesField() == null ? 0 : hashEval.bytesField().hashCode()); + } + + return processForBetterDistribution(result); + } + + private static int processForBetterDistribution(int val) { + val += val << 30; + val ^= val >>> 27; + val += val << 16; + val ^= val >>> 20; + val += val << 5; + val ^= val >>> 18; + val += val << 10; + val ^= val >>> 24; + val += val << 30; + return val; + } +} \ No newline at end of file