diff --git a/pom.xml b/pom.xml index 0d669e4..be5cae5 100644 --- a/pom.xml +++ b/pom.xml @@ -4,7 +4,7 @@ com.rtbhouse avro-fastserde - 1.0.4 + 1.0.5-SNAPSHOT jar avro-fastserde diff --git a/src/main/java/com/rtbhouse/utils/avro/FastDeserializerGenerator.java b/src/main/java/com/rtbhouse/utils/avro/FastDeserializerGenerator.java index a16fab5..cda71f8 100644 --- a/src/main/java/com/rtbhouse/utils/avro/FastDeserializerGenerator.java +++ b/src/main/java/com/rtbhouse/utils/avro/FastDeserializerGenerator.java @@ -53,12 +53,15 @@ public class FastDeserializerGenerator extends FastDeserializerGeneratorBase< private Map>> exceptionFromMethodMap = new HashMap<>(); private SchemaAssistant schemaAssistant; + private JClass string; + FastDeserializerGenerator(boolean useGenericTypes, Schema writer, Schema reader, File destination, ClassLoader classLoader, String compileClassPath) { super(writer, reader, destination, classLoader, compileClassPath); this.useGenericTypes = useGenericTypes; - this.schemaAssistant = new SchemaAssistant(codeModel, useGenericTypes); + this.schemaAssistant = new SchemaAssistant(codeModel, useGenericTypes, false); + this.string = codeModel.ref(String.class); } public FastDeserializer generateDeserializer() { @@ -88,9 +91,8 @@ public FastDeserializer generateDeserializer() { } JClass readerSchemaClass = schemaAssistant.classFromSchema(reader); - JClass writerSchemaClass = schemaAssistant.classFromSchema(aliasedWriterSchema); - deserializerClass._implements(codeModel.ref(FastDeserializer.class).narrow(writerSchemaClass)); + deserializerClass._implements(codeModel.ref(FastDeserializer.class).narrow(readerSchemaClass)); JMethod deserializeMethod = deserializerClass.method(JMod.PUBLIC, readerSchemaClass, "deserialize"); JBlock topLevelDeserializeBlock = new JBlock(); @@ -262,8 +264,12 @@ private void processRecord(JVar recordSchemaVar, String recordName, final Schema processComplexType(fieldSchemaVar, field.name(), field.schema(), readerFieldSchema, methodBody, action, putExpressionInRecord); } else { - processSimpleType(field.schema(), methodBody, action, putExpressionInRecord); - + // to preserve reader string specific options use reader field schema + if (action.getShouldRead() && Schema.Type.STRING.equals(field.schema().getType())) { + processSimpleType(readerFieldSchema, methodBody, action, putExpressionInRecord); + } else { + processSimpleType(field.schema(), methodBody, action, putExpressionInRecord); + } } } @@ -489,7 +495,13 @@ private void processUnion(JVar unionSchemaVar, final String name, final Schema u processComplexType(optionSchemaVar, optionName, optionSchema, readerOptionSchema, thenBlock, unionAction, putValueIntoParent); } else { - processSimpleType(optionSchema, thenBlock, unionAction, putValueIntoParent); + // to preserve reader string specific options use reader option schema + if (action.getShouldRead() && Schema.Type.STRING.equals(optionSchema.getType())) { + processSimpleType(readerOptionSchema, thenBlock, unionAction, putValueIntoParent); + + } else { + processSimpleType(optionSchema, thenBlock, unionAction, putValueIntoParent); + } } } } @@ -518,14 +530,14 @@ private void processArray(JVar arraySchemaVar, final String name, final Schema a action = FieldAction.fromValues(arraySchema.getElementType().getType(), false, EMPTY_SYMBOL); } - final JVar arrayVar = action.getShouldRead() ? declareValueVar(name, arraySchema, parentBody) : null; + final JVar arrayVar = action.getShouldRead() ? declareValueVar(name, readerArraySchema, parentBody) : null; JVar chunkLen = parentBody.decl(codeModel.LONG, getVariableName("chunkLen"), JExpr.direct(DECODER + ".readArrayStart()")); JConditional conditional = parentBody._if(chunkLen.gt(JExpr.lit(0))); JBlock ifBlock = conditional._then(); - JClass arrayClass = schemaAssistant.classFromSchema(arraySchema, false); + JClass arrayClass = schemaAssistant.classFromSchema(readerArraySchema, false); if (action.getShouldRead()) { JInvocation newArrayExp = JExpr._new(arrayClass); @@ -560,14 +572,16 @@ private void processArray(JVar arraySchemaVar, final String name, final Schema a if (SchemaAssistant.isComplexType(arraySchema.getElementType())) { String elemName = name + "Elem"; - Schema readerArrayElementSchema = null; - if (action.getShouldRead()) { - readerArrayElementSchema = readerArraySchema.getElementType(); - } + Schema readerArrayElementSchema = action.getShouldRead() ? readerArraySchema.getElementType() : null; processComplexType(elementSchemaVar, elemName, arraySchema.getElementType(), readerArrayElementSchema, forBody, action, putValueInArray); } else { - processSimpleType(arraySchema.getElementType(), forBody, action, putValueInArray); + // to preserve reader string specific options use reader array schema + if (action.getShouldRead() && Schema.Type.STRING.equals(arraySchema.getElementType().getType())) { + processSimpleType(readerArraySchema.getElementType(), forBody, action, putValueInArray); + } else { + processSimpleType(arraySchema.getElementType(), forBody, action, putValueInArray); + } } doLoop.body().assign(chunkLen, JExpr.direct(DECODER + ".arrayNext()")); @@ -599,7 +613,7 @@ private void processMap(JVar mapSchemaVar, final String name, final Schema mapSc action = FieldAction.fromValues(mapSchema.getValueType().getType(), false, EMPTY_SYMBOL); } - final JVar mapVar = action.getShouldRead() ? declareValueVar(name, mapSchema, parentBody) : null; + final JVar mapVar = action.getShouldRead() ? declareValueVar(name, readerMapSchema, parentBody) : null; JVar chunkLen = parentBody.decl(codeModel.LONG, getVariableName("chunkLen"), JExpr.direct(DECODER + ".readMapStart()")); @@ -607,7 +621,7 @@ private void processMap(JVar mapSchemaVar, final String name, final Schema mapSc JBlock ifBlock = conditional._then(); if (action.getShouldRead()) { - ifBlock.assign(mapVar, JExpr._new(schemaAssistant.classFromSchema(mapSchema, false))); + ifBlock.assign(mapVar, JExpr._new(schemaAssistant.classFromSchema(readerMapSchema, false))); JBlock elseBlock = conditional._else(); elseBlock.assign(mapVar, codeModel.ref(Collections.class).staticInvoke("emptyMap")); } @@ -619,10 +633,13 @@ private void processMap(JVar mapSchemaVar, final String name, final Schema mapSc forLoop.update(counter.incr()); JBlock forBody = forLoop.body(); - JClass keyClass = schemaAssistant.keyClassFromMapSchema(mapSchema); - JExpression keyValueExpression = JExpr.direct(DECODER + ".readString()"); + JClass keyClass = schemaAssistant.keyClassFromMapSchema(readerMapSchema); + JExpression keyValueExpression = (string.equals(keyClass)) ? + JExpr.direct(DECODER + ".readString()") + : JExpr.direct(DECODER + ".readString(null)"); + if (SchemaAssistant.hasStringableKey(mapSchema)) { - keyValueExpression = JExpr._new(keyClass).arg(keyValueExpression); + keyValueExpression = JExpr._new(keyClass).arg(keyValueExpression.invoke("toString")); } JVar key = forBody.decl(keyClass, getVariableName("key"), keyValueExpression); @@ -631,10 +648,12 @@ private void processMap(JVar mapSchemaVar, final String name, final Schema mapSc mapValueSchemaVar = declareSchemaVar(mapSchema.getValueType(), name + "MapValueSchema", mapSchemaVar.invoke("getValueType")); } + BiConsumer putValueInMap = null; if (action.getShouldRead()) { putValueInMap = (block, expression) -> block.invoke(mapVar, "put").arg(key).arg(expression); } + if (SchemaAssistant.isComplexType(mapSchema.getValueType())) { String valueName = name + "Value"; Schema readerMapValueSchema = null; @@ -644,9 +663,15 @@ private void processMap(JVar mapSchemaVar, final String name, final Schema mapSc processComplexType(mapValueSchemaVar, valueName, mapSchema.getValueType(), readerMapValueSchema, forBody, action, putValueInMap); } else { - processSimpleType(mapSchema.getValueType(), forBody, action, putValueInMap); + // to preserve reader string specific options use reader map schema + if (action.getShouldRead() && Schema.Type.STRING.equals(mapSchema.getValueType().getType())) { + processSimpleType(readerMapSchema.getValueType(), forBody, action, putValueInMap); + } else { + processSimpleType(mapSchema.getValueType(), forBody, action, putValueInMap); + } } doLoop.body().assign(chunkLen, JExpr.direct(DECODER + ".mapNext()")); + if (action.getShouldRead()) { putMapIntoParent.accept(parentBody, mapVar); } @@ -727,7 +752,16 @@ private void processPrimitive(final Schema schema, JBlock body, FieldAction acti String readFunction; switch (schema.getType()) { case STRING: - readFunction = action.getShouldRead() ? "readString()" : "skipString()"; + if (action.getShouldRead()) { + if (string.equals(schemaAssistant.classFromSchema(schema))) { + readFunction = "readString()"; + } else { + // reads as Utf8 + readFunction = "readString(null)"; + } + } else { + readFunction = "skipString()"; + } break; case BYTES: readFunction = "readBytes(null)"; @@ -756,7 +790,7 @@ private void processPrimitive(final Schema schema, JBlock body, FieldAction acti if (action.getShouldRead()) { if (schema.getType().equals(Schema.Type.STRING) && SchemaAssistant.isStringable(schema)) { primitiveValueExpression = JExpr._new(schemaAssistant.classFromSchema(schema)) - .arg(primitiveValueExpression); + .arg(primitiveValueExpression.invoke("toString")); } putValueIntoParent.accept(body, primitiveValueExpression); } else { diff --git a/src/main/java/com/rtbhouse/utils/avro/FastSerializerGenerator.java b/src/main/java/com/rtbhouse/utils/avro/FastSerializerGenerator.java index ac5dee3..c9bb405 100644 --- a/src/main/java/com/rtbhouse/utils/avro/FastSerializerGenerator.java +++ b/src/main/java/com/rtbhouse/utils/avro/FastSerializerGenerator.java @@ -30,11 +30,14 @@ public class FastSerializerGenerator extends FastSerializerGeneratorBase { private final Map serializeMethodMap = new HashMap<>(); private final SchemaAssistant schemaAssistant; + private final JClass string; + public FastSerializerGenerator(boolean useGenericTypes, Schema schema, File destination, ClassLoader classLoader, - String compileClassPath) { + String compileClassPath) { super(schema, destination, classLoader, compileClassPath); this.useGenericTypes = useGenericTypes; - this.schemaAssistant = new SchemaAssistant(codeModel, useGenericTypes); + this.schemaAssistant = new SchemaAssistant(codeModel, useGenericTypes, true); + this.string = codeModel.ref(String.class); } @Override @@ -48,29 +51,29 @@ public FastSerializer generateSerializer() { final JMethod serializeMethod = serializerClass.method(JMod.PUBLIC, void.class, "serialize"); final JVar serializeMethodParam; - JClass outputClass = schemaAssistant.classFromSchema(schema); - serializerClass._implements(codeModel.ref(FastSerializer.class).narrow(outputClass)); - serializeMethodParam = serializeMethod.param(outputClass, "data"); + JClass inputClass = schemaAssistant.classFromSchema(schema); + serializerClass._implements(codeModel.ref(FastSerializer.class).narrow(inputClass)); + serializeMethodParam = serializeMethod.param(inputClass, "data"); switch (schema.getType()) { - case RECORD: - processRecord(schema, serializeMethodParam, serializeMethod.body()); - break; - case ARRAY: - processArray(schema, serializeMethodParam, serializeMethod.body()); - break; - case MAP: - processMap(schema, serializeMethodParam, serializeMethod.body()); - break; - default: - throw new FastSerializerGeneratorException("Unsupported input schema type: " + schema.getType()); + case RECORD: + processRecord(schema, serializeMethodParam, serializeMethod.body()); + break; + case ARRAY: + processArray(schema, serializeMethodParam, serializeMethod.body()); + break; + case MAP: + processMap(schema, serializeMethodParam, serializeMethod.body()); + break; + default: + throw new FastSerializerGeneratorException("Unsupported input schema type: " + schema.getType()); } serializeMethod.param(codeModel.ref(Encoder.class), ENCODER); serializeMethod._throws(codeModel.ref(IOException.class)); final Class> clazz = compileClass(className); - return clazz.newInstance(); + return clazz.getConstructor().newInstance(); } catch (JClassAlreadyExistsException e) { throw new FastSerializerGeneratorException("Class: " + className + " already exists"); } catch (Exception e) { @@ -80,34 +83,34 @@ public FastSerializer generateSerializer() { private void processComplexType(Schema schema, JExpression valueExpr, JBlock body) { switch (schema.getType()) { - case RECORD: - processRecord(schema, valueExpr, body); - break; - case ARRAY: - processArray(schema, valueExpr, body); - break; - case UNION: - processUnion(schema, valueExpr, body); - break; - case MAP: - processMap(schema, valueExpr, body); - break; - default: - throw new FastSerializerGeneratorException("Not a complex schema type: " + schema.getType()); + case RECORD: + processRecord(schema, valueExpr, body); + break; + case ARRAY: + processArray(schema, valueExpr, body); + break; + case UNION: + processUnion(schema, valueExpr, body); + break; + case MAP: + processMap(schema, valueExpr, body); + break; + default: + throw new FastSerializerGeneratorException("Not a complex schema type: " + schema.getType()); } } private void processSimpleType(Schema schema, JExpression valueExpression, JBlock body) { switch (schema.getType()) { - case ENUM: - processEnum(schema, valueExpression, body); - break; - case FIXED: - processFixed(schema, valueExpression, body); - break; - default: - processPrimitive(schema, valueExpression, body); + case ENUM: + processEnum(schema, valueExpression, body); + break; + case FIXED: + processFixed(schema, valueExpression, body); + break; + default: + processPrimitive(schema, valueExpression, body); } } @@ -167,13 +170,13 @@ private void processArray(final Schema arraySchema, JExpression arrayExpr, JBloc } else { processSimpleType(elementSchema, arrayExpr.invoke("get").arg(counter), forBody); } + body.invoke(JExpr.direct(ENCODER), "writeArrayEnd"); } private void processMap(final Schema mapSchema, JExpression mapExpr, JBlock body) { final JClass mapClass = schemaAssistant.classFromSchema(mapSchema); - JClass keyClass = schemaAssistant.keyClassFromMapSchema(mapSchema); body.invoke(JExpr.direct(ENCODER), "writeMapStart"); @@ -196,7 +199,7 @@ private void processMap(final Schema mapSchema, JExpression mapExpr, JBlock body JVar keyStringVar; if (SchemaAssistant.hasStringableKey(mapSchema)) { - keyStringVar = forBody.decl(codeModel.ref(String.class), getVariableName("keyString"), + keyStringVar = forBody.decl(string, getVariableName("keyString"), mapKeysLoop.var().invoke("toString")); } else { keyStringVar = mapKeysLoop.var(); @@ -236,24 +239,27 @@ private void processUnion(final Schema unionSchema, JExpression unionExpr, JBloc JClass optionClass = schemaAssistant.classFromSchema(schemaOption); JClass rawOptionClass = schemaAssistant.classFromSchema(schemaOption, true, true); JExpression condition = unionExpr._instanceof(rawOptionClass); + if (useGenericTypes && SchemaAssistant.isNamedType(schemaOption)) { condition = condition.cand(JExpr.invoke(JExpr.lit(schemaOption.getFullName()), "equals") .arg(JExpr.invoke(JExpr.cast(optionClass, unionExpr), "getSchema").invoke("getFullName"))); } + ifBlock = ifBlock != null ? ifBlock._elseif(condition) : body._if(condition); JBlock thenBlock = ifBlock._then(); thenBlock.invoke(JExpr.direct(ENCODER), "writeIndex") .arg(JExpr.lit(unionSchema.getIndexNamed(schemaOption.getFullName()))); + switch (schemaOption.getType()) { - case UNION: - case NULL: - throw new FastSerializerGeneratorException("Incorrect union subschema processing: " + schemaOption); - default: - if (SchemaAssistant.isComplexType(schemaOption)) { - processComplexType(schemaOption, JExpr.cast(optionClass, unionExpr), thenBlock); - } else { - processSimpleType(schemaOption, unionExpr, thenBlock); - } + case UNION: + case NULL: + throw new FastSerializerGeneratorException("Incorrect union subschema processing: " + schemaOption); + default: + if (SchemaAssistant.isComplexType(schemaOption)) { + processComplexType(schemaOption, JExpr.cast(optionClass, unionExpr), thenBlock); + } else { + processSimpleType(schemaOption, unionExpr, thenBlock); + } } } } @@ -282,34 +288,35 @@ private void processPrimitive(final Schema primitiveSchema, JExpression primitiv String writeFunction; JClass primitiveClass = schemaAssistant.classFromSchema(primitiveSchema); JExpression castedValue = JExpr.cast(primitiveClass, primitiveValueExpression); + switch (primitiveSchema.getType()) { - case STRING: - writeFunction = "writeString"; - if (SchemaAssistant.isStringable(primitiveSchema)) { - castedValue = JExpr.cast(codeModel.ref(String.class), castedValue.invoke("toString")); - } - break; - case BYTES: - writeFunction = "writeBytes"; - break; - case INT: - writeFunction = "writeInt"; - break; - case LONG: - writeFunction = "writeLong"; - break; - case FLOAT: - writeFunction = "writeFloat"; - break; - case DOUBLE: - writeFunction = "writeDouble"; - break; - case BOOLEAN: - writeFunction = "writeBoolean"; - break; - default: - throw new FastSerializerGeneratorException( - "Unsupported primitive schema of type: " + primitiveSchema.getType()); + case STRING: + writeFunction = "writeString"; + if (SchemaAssistant.isStringable(primitiveSchema)) { + castedValue = JExpr.cast(string, castedValue.invoke("toString")); + } + break; + case BYTES: + writeFunction = "writeBytes"; + break; + case INT: + writeFunction = "writeInt"; + break; + case LONG: + writeFunction = "writeLong"; + break; + case FLOAT: + writeFunction = "writeFloat"; + break; + case DOUBLE: + writeFunction = "writeDouble"; + break; + case BOOLEAN: + writeFunction = "writeBoolean"; + break; + default: + throw new FastSerializerGeneratorException( + "Unsupported primitive schema of type: " + primitiveSchema.getType()); } body.invoke(JExpr.direct(ENCODER), writeFunction).arg(castedValue); diff --git a/src/main/java/com/rtbhouse/utils/avro/SchemaAssistant.java b/src/main/java/com/rtbhouse/utils/avro/SchemaAssistant.java index af1c105..5a4021f 100644 --- a/src/main/java/com/rtbhouse/utils/avro/SchemaAssistant.java +++ b/src/main/java/com/rtbhouse/utils/avro/SchemaAssistant.java @@ -23,15 +23,18 @@ import com.sun.codemodel.JExpr; import com.sun.codemodel.JExpression; import com.sun.codemodel.JInvocation; +import org.apache.avro.util.Utf8; public class SchemaAssistant { private final JCodeModel codeModel; private final boolean useGenericTypes; + private final boolean useCharSequence; private Set> exceptionsFromStringable; - public SchemaAssistant(JCodeModel codeModel, boolean useGenericTypes) { + public SchemaAssistant(JCodeModel codeModel, boolean useGenericTypes, boolean useCharSequence) { this.codeModel = codeModel; this.useGenericTypes = useGenericTypes; + this.useCharSequence = useCharSequence; this.exceptionsFromStringable = new HashSet<>(); } @@ -59,14 +62,23 @@ private void extendExceptionsFromStringable(String className) { } public JClass keyClassFromMapSchema(Schema schema) { + if (!Schema.Type.MAP.equals(schema.getType())) { throw new SchemaAssistantException("Map schema was expected, instead got:" + schema.getType().getName()); } + + if (hasStringableKey(schema) && !useGenericTypes) { extendExceptionsFromStringable(schema.getProp(SpecificData.KEY_CLASS_PROP)); return codeModel.ref(schema.getProp(SpecificData.KEY_CLASS_PROP)); } else { - return codeModel.ref(String.class); + if (useCharSequence) { + return codeModel.ref(CharSequence.class); + } else if ("String".equals(schema.getProp(GenericData.STRING_PROP))) { + return codeModel.ref(String.class); + } else { + return codeModel.ref(Utf8.class); + } } } @@ -120,74 +132,81 @@ public JClass classFromSchema(Schema schema, boolean abstractType, boolean rawTy switch (schema.getType()) { - case RECORD: - outputClass = useGenericTypes ? codeModel.ref(GenericData.Record.class) - : codeModel.ref(schema.getFullName()); - break; + case RECORD: + outputClass = useGenericTypes ? codeModel.ref(GenericData.Record.class) + : codeModel.ref(schema.getFullName()); + break; - case ARRAY: - if (abstractType) { - outputClass = codeModel.ref(List.class); - } else { - if (useGenericTypes) { - outputClass = codeModel.ref(GenericData.Array.class); + case ARRAY: + if (abstractType) { + outputClass = codeModel.ref(List.class); } else { - outputClass = codeModel.ref(ArrayList.class); + if (useGenericTypes) { + outputClass = codeModel.ref(GenericData.Array.class); + } else { + outputClass = codeModel.ref(ArrayList.class); + } } - } - if (!rawType) { - outputClass = outputClass.narrow(elementClassFromArraySchema(schema)); - } - break; - case MAP: - if (!abstractType) { - outputClass = codeModel.ref(HashMap.class); - } else { - outputClass = codeModel.ref(Map.class); - } - if (!rawType) { - outputClass = outputClass.narrow(keyClassFromMapSchema(schema), valueClassFromMapSchema(schema)); - } - break; - case UNION: - outputClass = classFromUnionSchema(schema); - break; - case ENUM: - outputClass = useGenericTypes ? codeModel.ref(GenericData.EnumSymbol.class) - : codeModel.ref(schema.getFullName()); - break; - case FIXED: - outputClass = useGenericTypes ? codeModel.ref(GenericData.Fixed.class) - : codeModel.ref(schema.getFullName()); - break; - case BOOLEAN: - outputClass = codeModel.ref(Boolean.class); - break; - case DOUBLE: - outputClass = codeModel.ref(Double.class); - break; - case FLOAT: - outputClass = codeModel.ref(Float.class); - break; - case INT: - outputClass = codeModel.ref(Integer.class); - break; - case LONG: - outputClass = codeModel.ref(Long.class); - break; - case STRING: - if (isStringable(schema) && !useGenericTypes) { - outputClass = codeModel.ref(schema.getProp(SpecificData.CLASS_PROP)); - extendExceptionsFromStringable(schema.getProp(SpecificData.CLASS_PROP)); - } else { - outputClass = codeModel.ref(String.class); - } - break; - case BYTES: - outputClass = codeModel.ref(ByteBuffer.class); - break; - default: - throw new SchemaAssistantException("Incorrect request for " + schema.getType().getName() + " class!"); + if (!rawType) { + outputClass = outputClass.narrow(elementClassFromArraySchema(schema)); + } + break; + case MAP: + if (!abstractType) { + outputClass = codeModel.ref(HashMap.class); + } else { + outputClass = codeModel.ref(Map.class); + } + if (!rawType) { + outputClass = outputClass.narrow(keyClassFromMapSchema(schema), valueClassFromMapSchema(schema)); + } + break; + case UNION: + outputClass = classFromUnionSchema(schema); + break; + case ENUM: + outputClass = useGenericTypes ? codeModel.ref(GenericData.EnumSymbol.class) + : codeModel.ref(schema.getFullName()); + break; + case FIXED: + outputClass = useGenericTypes ? codeModel.ref(GenericData.Fixed.class) + : codeModel.ref(schema.getFullName()); + break; + case BOOLEAN: + outputClass = codeModel.ref(Boolean.class); + break; + case DOUBLE: + outputClass = codeModel.ref(Double.class); + break; + case FLOAT: + outputClass = codeModel.ref(Float.class); + break; + case INT: + outputClass = codeModel.ref(Integer.class); + break; + case LONG: + outputClass = codeModel.ref(Long.class); + break; + case STRING: + + if (isStringable(schema) && !useGenericTypes) { + outputClass = codeModel.ref(schema.getProp(SpecificData.CLASS_PROP)); + extendExceptionsFromStringable(schema.getProp(SpecificData.CLASS_PROP)); + } else { + if (useCharSequence) { + outputClass = codeModel.ref(CharSequence.class); + } else if ("String".equals(schema.getProp(GenericData.STRING_PROP))) { + outputClass = codeModel.ref(String.class); + } else { + outputClass = codeModel.ref(Utf8.class); + } + } + break; + case BYTES: + outputClass = codeModel.ref(ByteBuffer.class); + break; + default: + throw new SchemaAssistantException("Incorrect request for " + schema.getType().getName() + " class!"); } return outputClass; @@ -229,24 +248,24 @@ public JExpression getStringableValue(Schema schema, JExpression stringExpr) { /* Complex type here means type that it have to handle other types inside itself. */ public static boolean isComplexType(Schema schema) { switch (schema.getType()) { - case MAP: - case RECORD: - case ARRAY: - case UNION: - return true; - default: - return false; + case MAP: + case RECORD: + case ARRAY: + case UNION: + return true; + default: + return false; } } public static boolean isNamedType(Schema schema) { switch (schema.getType()) { - case RECORD: - case ENUM: - case FIXED: - return true; - default: - return false; + case RECORD: + case ENUM: + case FIXED: + return true; + default: + return false; } } diff --git a/src/test/java/com/rtbhouse/utils/avro/FastDatumReaderTest.java b/src/test/java/com/rtbhouse/utils/avro/FastDatumReaderTest.java index 54c78fa..4d599f1 100644 --- a/src/test/java/com/rtbhouse/utils/avro/FastDatumReaderTest.java +++ b/src/test/java/com/rtbhouse/utils/avro/FastDatumReaderTest.java @@ -112,6 +112,6 @@ public void shouldCreateGenericDatumReader() throws IOException, InterruptedExce Assert.assertNotEquals(2, fastGenericDeserializer.getClass().getDeclaredMethods().length); Assert.assertEquals( "test", - fastGenericDatumReader.read(null, serializeGeneric(recordBuilder.build())).get("test")); + fastGenericDatumReader.read(null, serializeGeneric(recordBuilder.build())).get("test").toString()); } } diff --git a/src/test/java/com/rtbhouse/utils/avro/FastDeserializerDefaultsTest.java b/src/test/java/com/rtbhouse/utils/avro/FastDeserializerDefaultsTest.java index 74074ce..df197ad 100644 --- a/src/test/java/com/rtbhouse/utils/avro/FastDeserializerDefaultsTest.java +++ b/src/test/java/com/rtbhouse/utils/avro/FastDeserializerDefaultsTest.java @@ -24,6 +24,7 @@ import org.apache.avro.generic.GenericRecordBuilder; import org.apache.avro.io.Decoder; import org.apache.avro.specific.SpecificDatumReader; +import org.apache.avro.util.Utf8; import org.junit.Assert; import org.junit.Before; import org.junit.Test; @@ -46,7 +47,7 @@ public void prepare() throws Exception { Path tempPath = Files.createTempDirectory("generated"); tempDir = tempPath.toFile(); - classLoader = URLClassLoader.newInstance(new URL[] { tempDir.toURI().toURL() }, + classLoader = URLClassLoader.newInstance(new URL[]{tempDir.toURI().toURL()}, FastDeserializerDefaultsTest.class.getClassLoader()); } @@ -81,14 +82,14 @@ public void shouldReadSpecificDefaults() throws IOException { Assert.assertNull(testRecord.getTestFloatUnion()); Assert.assertEquals(true, testRecord.getTestBoolean()); Assert.assertNull(testRecord.getTestBooleanUnion()); - Assert.assertEquals(ByteBuffer.wrap(new byte[] { 0, 1, 2, 3 }), testRecord.getTestBytes()); + Assert.assertEquals(ByteBuffer.wrap(new byte[]{0, 1, 2, 3}), testRecord.getTestBytes()); Assert.assertNull(testRecord.getTestBytesUnion()); Assert.assertEquals("testStringValue", testRecord.getTestString()); Assert.assertEquals(new URL("http://www.example.com"), testRecord.getTestStringable()); Assert.assertNull(testRecord.getTestStringUnion()); - Assert.assertEquals(new DefaultsFixed(new byte[] { (byte) 0xFF }), testRecord.getTestFixed()); + Assert.assertEquals(new DefaultsFixed(new byte[]{(byte) 0xFF}), testRecord.getTestFixed()); Assert.assertNull(testRecord.getTestFixedUnion()); - Assert.assertEquals(Collections.singletonList(new DefaultsFixed(new byte[] { (byte) 0xFA })), + Assert.assertEquals(Collections.singletonList(new DefaultsFixed(new byte[]{(byte) 0xFA})), testRecord.getTestFixedArray()); List listWithNull = new LinkedList(); @@ -157,17 +158,17 @@ public void shouldReadGenericDefaults() throws IOException { Assert.assertNull(testRecord.get("testFloatUnion")); Assert.assertEquals(true, testRecord.get("testBoolean")); Assert.assertNull(testRecord.get("testBooleanUnion")); - Assert.assertEquals(ByteBuffer.wrap(new byte[] { 0, 1, 2, 3 }), testRecord.get("testBytes")); + Assert.assertEquals(ByteBuffer.wrap(new byte[]{0, 1, 2, 3}), testRecord.get("testBytes")); Assert.assertNull(testRecord.get("testBytesUnion")); Assert.assertEquals("testStringValue", testRecord.get("testString")); Assert.assertEquals("http://www.example.com", testRecord.get("testStringable")); Assert.assertNull(testRecord.get("testStringUnion")); - Assert.assertEquals(new GenericData.Fixed(DefaultsFixed.getClassSchema(), new byte[] { (byte) 0xFF }), + Assert.assertEquals(new GenericData.Fixed(DefaultsFixed.getClassSchema(), new byte[]{(byte) 0xFF}), testRecord.get("testFixed")); Assert.assertNull(testRecord.get("testFixedUnion")); Assert.assertEquals( Collections.singletonList( - new GenericData.Fixed(DefaultsFixed.getClassSchema(), new byte[] { (byte) 0xFA })), + new GenericData.Fixed(DefaultsFixed.getClassSchema(), new byte[]{(byte) 0xFA})), testRecord.get("testFixedArray")); List listWithNull = new LinkedList(); @@ -259,7 +260,7 @@ public void shouldAddFieldsInMiddleOfSchema() throws IOException { GenericData.EnumSymbol testEnum = new GenericData.EnumSymbol( oldRecordSchema.getField("testEnum").schema(), "A"); GenericData.Fixed testFixed = new GenericData.Fixed(oldRecordSchema.getField("testFixed").schema(), - new byte[] { 0x01 }); + new byte[]{0x01}); GenericData.Record oldRecord = new GenericData.Record(oldRecordSchema); oldRecord.put("testInt", 1); @@ -267,7 +268,7 @@ public void shouldAddFieldsInMiddleOfSchema() throws IOException { oldRecord.put("testDouble", 1.0); oldRecord.put("testFloat", 1.0f); oldRecord.put("testBoolean", true); - oldRecord.put("testBytes", ByteBuffer.wrap(new byte[] { 0x01, 0x02 })); + oldRecord.put("testBytes", ByteBuffer.wrap(new byte[]{0x01, 0x02})); oldRecord.put("testString", "aaa"); oldRecord.put("testFixed", testFixed); oldRecord.put("testEnum", testEnum); @@ -278,8 +279,8 @@ public void shouldAddFieldsInMiddleOfSchema() throws IOException { oldRecord.put("subRecordUnion", subRecord); oldRecord.put("subRecord", subRecord); oldRecord.put("recordsArray", Collections.singletonList(subRecord)); - Map recordsMap = new HashMap<>(); - recordsMap.put("1", subRecord); + Map recordsMap = new HashMap<>(); + recordsMap.put(new Utf8("1"), subRecord); oldRecord.put("recordsMap", recordsMap); oldRecord.put("testFixedArray", Collections.emptyList()); @@ -300,7 +301,7 @@ public void shouldAddFieldsInMiddleOfSchema() throws IOException { newSubRecord.put("subField", "abc"); newSubRecord.put("anotherField", "ghi"); newSubRecord.put("newSubField", "newSubFieldValue"); - recordsMap.put("1", newSubRecord); + recordsMap.put(new Utf8("1"), newSubRecord); Assert.assertEquals("newSubFieldValue", ((GenericRecord) record.get("subRecordUnion")).get("newSubField").toString()); @@ -310,8 +311,8 @@ public void shouldAddFieldsInMiddleOfSchema() throws IOException { Assert.assertEquals(1.0, record.get("testDouble")); Assert.assertEquals(1.0f, record.get("testFloat")); Assert.assertEquals(true, record.get("testBoolean")); - Assert.assertEquals(ByteBuffer.wrap(new byte[] { 0x01, 0x02 }), record.get("testBytes")); - Assert.assertEquals("aaa", record.get("testString")); + Assert.assertEquals(ByteBuffer.wrap(new byte[]{0x01, 0x02}), record.get("testBytes")); + Assert.assertEquals("aaa", record.get("testString").toString()); Assert.assertEquals(testFixed, record.get("testFixed")); Assert.assertEquals(testEnum, record.get("testEnum")); Assert.assertEquals(newSubRecord, record.get("subRecordUnion")); diff --git a/src/test/java/com/rtbhouse/utils/avro/FastGenericDeserializerGeneratorTest.java b/src/test/java/com/rtbhouse/utils/avro/FastGenericDeserializerGeneratorTest.java index 7c31a43..bd2164d 100644 --- a/src/test/java/com/rtbhouse/utils/avro/FastGenericDeserializerGeneratorTest.java +++ b/src/test/java/com/rtbhouse/utils/avro/FastGenericDeserializerGeneratorTest.java @@ -1,17 +1,14 @@ package com.rtbhouse.utils.avro; -import static com.rtbhouse.utils.avro.FastSerdeTestsSupport.addAliases; -import static com.rtbhouse.utils.avro.FastSerdeTestsSupport.createArrayFieldSchema; -import static com.rtbhouse.utils.avro.FastSerdeTestsSupport.createEnumSchema; -import static com.rtbhouse.utils.avro.FastSerdeTestsSupport.createField; -import static com.rtbhouse.utils.avro.FastSerdeTestsSupport.createFixedSchema; -import static com.rtbhouse.utils.avro.FastSerdeTestsSupport.createMapFieldSchema; -import static com.rtbhouse.utils.avro.FastSerdeTestsSupport.createPrimitiveFieldSchema; -import static com.rtbhouse.utils.avro.FastSerdeTestsSupport.createPrimitiveUnionFieldSchema; -import static com.rtbhouse.utils.avro.FastSerdeTestsSupport.createRecord; -import static com.rtbhouse.utils.avro.FastSerdeTestsSupport.createUnionField; -import static com.rtbhouse.utils.avro.FastSerdeTestsSupport.createUnionSchema; -import static com.rtbhouse.utils.avro.FastSerdeTestsSupport.serializeGeneric; +import org.apache.avro.Schema; +import org.apache.avro.generic.GenericData; +import org.apache.avro.generic.GenericRecord; +import org.apache.avro.generic.GenericRecordBuilder; +import org.apache.avro.io.Decoder; +import org.apache.avro.util.Utf8; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; import java.io.File; import java.net.URL; @@ -25,14 +22,18 @@ import java.util.List; import java.util.Map; -import org.apache.avro.Schema; -import org.apache.avro.generic.GenericData; -import org.apache.avro.generic.GenericRecord; -import org.apache.avro.generic.GenericRecordBuilder; -import org.apache.avro.io.Decoder; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; +import static com.rtbhouse.utils.avro.FastSerdeTestsSupport.addAliases; +import static com.rtbhouse.utils.avro.FastSerdeTestsSupport.createArrayFieldSchema; +import static com.rtbhouse.utils.avro.FastSerdeTestsSupport.createEnumSchema; +import static com.rtbhouse.utils.avro.FastSerdeTestsSupport.createField; +import static com.rtbhouse.utils.avro.FastSerdeTestsSupport.createFixedSchema; +import static com.rtbhouse.utils.avro.FastSerdeTestsSupport.createMapFieldSchema; +import static com.rtbhouse.utils.avro.FastSerdeTestsSupport.createPrimitiveFieldSchema; +import static com.rtbhouse.utils.avro.FastSerdeTestsSupport.createPrimitiveUnionFieldSchema; +import static com.rtbhouse.utils.avro.FastSerdeTestsSupport.createRecord; +import static com.rtbhouse.utils.avro.FastSerdeTestsSupport.createUnionField; +import static com.rtbhouse.utils.avro.FastSerdeTestsSupport.createUnionSchema; +import static com.rtbhouse.utils.avro.FastSerdeTestsSupport.serializeGeneric; public class FastGenericDeserializerGeneratorTest { @@ -44,18 +45,22 @@ public void prepare() throws Exception { Path tempPath = Files.createTempDirectory("generated"); tempDir = tempPath.toFile(); - classLoader = URLClassLoader.newInstance(new URL[] { tempDir.toURI().toURL() }, + classLoader = URLClassLoader.newInstance(new URL[]{tempDir.toURI().toURL()}, FastGenericDeserializerGeneratorTest.class.getClassLoader()); } @Test public void shouldReadPrimitives() { // given + Schema javaLangStringSchema = Schema.create(Schema.Type.STRING); + GenericData.setStringType(javaLangStringSchema, GenericData.StringType.String); Schema recordSchema = createRecord("testRecord", createField("testInt", Schema.create(Schema.Type.INT)), createPrimitiveUnionFieldSchema("testIntUnion", Schema.Type.INT), createField("testString", Schema.create(Schema.Type.STRING)), createPrimitiveUnionFieldSchema("testStringUnion", Schema.Type.STRING), + createField("testJavaString", javaLangStringSchema), + createUnionField("testJavaStringUnion", javaLangStringSchema), createField("testLong", Schema.create(Schema.Type.LONG)), createPrimitiveUnionFieldSchema("testLongUnion", Schema.Type.LONG), createField("testDouble", Schema.create(Schema.Type.DOUBLE)), @@ -72,16 +77,18 @@ public void shouldReadPrimitives() { builder.set("testIntUnion", 1); builder.set("testString", "aaa"); builder.set("testStringUnion", "aaa"); - builder.set("testLong", 1l); - builder.set("testLongUnion", 1l); + builder.set("testJavaString", "aaa"); + builder.set("testJavaStringUnion", "aaa"); + builder.set("testLong", 1L); + builder.set("testLongUnion", 1L); builder.set("testDouble", 1.0); builder.set("testDoubleUnion", 1.0); builder.set("testFloat", 1.0f); builder.set("testFloatUnion", 1.0f); builder.set("testBoolean", true); builder.set("testBooleanUnion", true); - builder.set("testBytes", ByteBuffer.wrap(new byte[] { 0x01, 0x02 })); - builder.set("testBytesUnion", ByteBuffer.wrap(new byte[] { 0x01, 0x02 })); + builder.set("testBytes", ByteBuffer.wrap(new byte[]{0x01, 0x02})); + builder.set("testBytesUnion", ByteBuffer.wrap(new byte[]{0x01, 0x02})); // when GenericRecord record = deserializeGenericFast(recordSchema, recordSchema, serializeGeneric(builder.build())); @@ -89,22 +96,25 @@ public void shouldReadPrimitives() { // then Assert.assertEquals(1, record.get("testInt")); Assert.assertEquals(1, record.get("testIntUnion")); - Assert.assertEquals("aaa", record.get("testString")); - Assert.assertEquals("aaa", record.get("testStringUnion")); - Assert.assertEquals(1l, record.get("testLong")); - Assert.assertEquals(1l, record.get("testLongUnion")); + Assert.assertEquals("aaa", record.get("testString").toString()); + Assert.assertEquals("aaa", record.get("testStringUnion").toString()); + Assert.assertEquals("aaa", record.get("testJavaString")); + Assert.assertEquals("aaa", record.get("testJavaStringUnion")); + Assert.assertEquals(1L, record.get("testLong")); + Assert.assertEquals(1L, record.get("testLongUnion")); Assert.assertEquals(1.0, record.get("testDouble")); Assert.assertEquals(1.0, record.get("testDoubleUnion")); Assert.assertEquals(1.0f, record.get("testFloat")); Assert.assertEquals(1.0f, record.get("testFloatUnion")); Assert.assertEquals(true, record.get("testBoolean")); Assert.assertEquals(true, record.get("testBooleanUnion")); - Assert.assertEquals(ByteBuffer.wrap(new byte[] { 0x01, 0x02 }), record.get("testBytes")); - Assert.assertEquals(ByteBuffer.wrap(new byte[] { 0x01, 0x02 }), record.get("testBytesUnion")); + Assert.assertEquals(ByteBuffer.wrap(new byte[]{0x01, 0x02}), record.get("testBytes")); + Assert.assertEquals(ByteBuffer.wrap(new byte[]{0x01, 0x02}), record.get("testBytesUnion")); } @Test + @SuppressWarnings("unchecked") public void shouldReadFixed() { // given Schema fixedSchema = createFixedSchema("testFixed", 2); @@ -113,28 +123,29 @@ public void shouldReadFixed() { createArrayFieldSchema("testFixedUnionArray", createUnionSchema(fixedSchema))); GenericRecordBuilder builder = new GenericRecordBuilder(recordSchema); - builder.set("testFixed", new GenericData.Fixed(fixedSchema, new byte[] { 0x01, 0x02 })); - builder.set("testFixedUnion", new GenericData.Fixed(fixedSchema, new byte[] { 0x03, 0x04 })); - builder.set("testFixedArray", Arrays.asList(new GenericData.Fixed(fixedSchema, new byte[] { 0x05, 0x06 }))); + builder.set("testFixed", new GenericData.Fixed(fixedSchema, new byte[]{0x01, 0x02})); + builder.set("testFixedUnion", new GenericData.Fixed(fixedSchema, new byte[]{0x03, 0x04})); + builder.set("testFixedArray", Arrays.asList(new GenericData.Fixed(fixedSchema, new byte[]{0x05, 0x06}))); builder.set("testFixedUnionArray", - Arrays.asList(new GenericData.Fixed(fixedSchema, new byte[] { 0x07, 0x08 }))); + Arrays.asList(new GenericData.Fixed(fixedSchema, new byte[]{0x07, 0x08}))); // when GenericRecord record = deserializeGenericFast(recordSchema, recordSchema, serializeGeneric(builder.build())); // then - Assert.assertArrayEquals(new byte[] { 0x01, 0x02 }, ((GenericData.Fixed) record.get("testFixed")).bytes()); - Assert.assertArrayEquals(new byte[] { 0x03, 0x04 }, ((GenericData.Fixed) record.get("testFixedUnion")).bytes()); - Assert.assertArrayEquals(new byte[] { 0x05, 0x06 }, + Assert.assertArrayEquals(new byte[]{0x01, 0x02}, ((GenericData.Fixed) record.get("testFixed")).bytes()); + Assert.assertArrayEquals(new byte[]{0x03, 0x04}, ((GenericData.Fixed) record.get("testFixedUnion")).bytes()); + Assert.assertArrayEquals(new byte[]{0x05, 0x06}, ((List) record.get("testFixedArray")).get(0).bytes()); - Assert.assertArrayEquals(new byte[] { 0x07, 0x08 }, + Assert.assertArrayEquals(new byte[]{0x07, 0x08}, ((List) record.get("testFixedUnionArray")).get(0).bytes()); } @Test + @SuppressWarnings("unchecked") public void shouldReadEnum() { // given - Schema enumSchema = createEnumSchema("testEnum", new String[] { "A", "B" }); + Schema enumSchema = createEnumSchema("testEnum", new String[]{"A", "B"}); Schema recordSchema = createRecord("testRecord", createField("testEnum", enumSchema), createUnionField("testEnumUnion", enumSchema), createArrayFieldSchema("testEnumArray", enumSchema), createArrayFieldSchema("testEnumUnionArray", createUnionSchema(enumSchema))); @@ -156,9 +167,10 @@ public void shouldReadEnum() { } @Test + @SuppressWarnings("unchecked") public void shouldReadPermutatedEnum() { // given - Schema enumSchema = createEnumSchema("testEnum", new String[] { "A", "B", "C", "D", "E" }); + Schema enumSchema = createEnumSchema("testEnum", new String[]{"A", "B", "C", "D", "E"}); Schema recordSchema = createRecord("testRecord", createField("testEnum", enumSchema), createUnionField("testEnumUnion", enumSchema), createArrayFieldSchema("testEnumArray", enumSchema), createArrayFieldSchema("testEnumUnionArray", createUnionSchema(enumSchema))); @@ -169,7 +181,7 @@ public void shouldReadPermutatedEnum() { builder.set("testEnumArray", Arrays.asList(new GenericData.EnumSymbol(enumSchema, "C"))); builder.set("testEnumUnionArray", Arrays.asList(new GenericData.EnumSymbol(enumSchema, "D"))); - Schema enumSchema1 = createEnumSchema("testEnum", new String[] { "B", "A", "D", "E", "C" }); + Schema enumSchema1 = createEnumSchema("testEnum", new String[]{"B", "A", "D", "E", "C"}); Schema recordSchema1 = createRecord("testRecord", createField("testEnum", enumSchema1), createUnionField("testEnumUnion", enumSchema1), createArrayFieldSchema("testEnumArray", enumSchema1), createArrayFieldSchema("testEnumUnionArray", createUnionSchema(enumSchema1))); @@ -187,13 +199,13 @@ public void shouldReadPermutatedEnum() { @Test(expected = FastDeserializerGeneratorException.class) public void shouldNotReadStrippedEnum() { // given - Schema enumSchema = createEnumSchema("testEnum", new String[] { "A", "B", "C" }); + Schema enumSchema = createEnumSchema("testEnum", new String[]{"A", "B", "C"}); Schema recordSchema = createRecord("testRecord", createField("testEnum", enumSchema)); GenericRecordBuilder builder = new GenericRecordBuilder(recordSchema); builder.set("testEnum", new GenericData.EnumSymbol(enumSchema, "C")); - Schema enumSchema1 = createEnumSchema("testEnum", new String[] { "A", "B" }); + Schema enumSchema1 = createEnumSchema("testEnum", new String[]{"A", "B"}); Schema recordSchema1 = createRecord("testRecord", createField("testEnum", enumSchema1)); // when @@ -221,14 +233,15 @@ public void shouldReadSubRecordField() { GenericRecord record = deserializeGenericFast(recordSchema, recordSchema, serializeGeneric(builder.build())); // then - Assert.assertEquals("abc", ((GenericRecord) record.get("record")).get("subField")); + Assert.assertEquals("abc", ((GenericRecord) record.get("record")).get("subField").toString()); Assert.assertEquals(subRecordSchema.hashCode(), ((GenericRecord) record.get("record")).getSchema().hashCode()); - Assert.assertEquals("abc", ((GenericRecord) record.get("record1")).get("subField")); + Assert.assertEquals("abc", ((GenericRecord) record.get("record1")).get("subField").toString()); Assert.assertEquals(subRecordSchema.hashCode(), ((GenericRecord) record.get("record1")).getSchema().hashCode()); - Assert.assertEquals("abc", record.get("field")); + Assert.assertEquals("abc", record.get("field").toString()); } @Test + @SuppressWarnings("unchecked") public void shouldReadSubRecordCollectionsField() { // given Schema subRecordSchema = createRecord("subRecord", @@ -256,16 +269,17 @@ public void shouldReadSubRecordCollectionsField() { // then Assert.assertEquals("abc", - ((List) record.get("recordsArray")).get(0).get("subField")); + ((List) record.get("recordsArray")).get(0).get("subField").toString()); Assert.assertEquals("abc", - ((List) record.get("recordsArrayUnion")).get(0).get("subField")); + ((List) record.get("recordsArrayUnion")).get(0).get("subField").toString()); Assert.assertEquals("abc", - ((Map) record.get("recordsMap")).get("1").get("subField")); + ((Map) record.get("recordsMap")).get(new Utf8("1")).get("subField").toString()); Assert.assertEquals("abc", - ((Map) record.get("recordsMapUnion")).get("1").get("subField")); + ((Map) record.get("recordsMapUnion")).get(new Utf8("1")).get("subField").toString()); } @Test + @SuppressWarnings("unchecked") public void shouldReadSubRecordComplexCollectionsField() { // given Schema subRecordSchema = createRecord("subRecord", @@ -304,15 +318,15 @@ public void shouldReadSubRecordComplexCollectionsField() { // then Assert.assertEquals("abc", - ((List>) record.get("recordsArrayMap")).get(0).get("1").get("subField")); + ((List>) record.get("recordsArrayMap")).get(0).get(new Utf8("1")).get("subField").toString()); Assert.assertEquals("abc", - ((Map>) record.get("recordsMapArray")).get("1").get(0).get("subField")); + ((Map>) record.get("recordsMapArray")).get(new Utf8("1")).get(0).get("subField").toString()); Assert.assertEquals("abc", - ((List>) record.get("recordsArrayMapUnion")).get(0).get("1") - .get("subField")); + ((List>) record.get("recordsArrayMapUnion")).get(0).get(new Utf8("1")) + .get("subField").toString()); Assert.assertEquals("abc", - ((Map>) record.get("recordsMapArrayUnion")).get("1").get(0) - .get("subField")); + ((Map>) record.get("recordsMapArrayUnion")).get(new Utf8("1")).get(0) + .get("subField").toString()); } @Test @@ -334,11 +348,12 @@ public void shouldReadAliasedField() { GenericRecord record = deserializeGenericFast(record1Schema, record2Schema, serializeGeneric(builder.build())); // then - Assert.assertEquals("abc", record.get("testString")); - Assert.assertEquals("def", record.get("testStringUnionAlias")); + Assert.assertEquals("abc", record.get("testString").toString()); + Assert.assertEquals("def", record.get("testStringUnionAlias").toString()); } @Test + @SuppressWarnings("unchecked") public void shouldSkipRemovedField() { // given Schema subRecord1Schema = createRecord("subRecord", @@ -382,13 +397,13 @@ public void shouldSkipRemovedField() { GenericRecord record = deserializeGenericFast(record1Schema, record2Schema, serializeGeneric(builder.build())); // then - Assert.assertEquals("abc", record.get("testNotRemoved")); + Assert.assertEquals("abc", record.get("testNotRemoved").toString()); Assert.assertNull(record.get("testRemoved")); - Assert.assertEquals("ghi", record.get("testNotRemoved2")); - Assert.assertEquals("ghi", ((GenericRecord) record.get("subRecord")).get("testNotRemoved2")); - Assert.assertEquals("ghi", ((List) record.get("subRecordArray")).get(0).get("testNotRemoved2")); + Assert.assertEquals("ghi", record.get("testNotRemoved2").toString()); + Assert.assertEquals("ghi", ((GenericRecord) record.get("subRecord")).get("testNotRemoved2").toString()); + Assert.assertEquals("ghi", ((List) record.get("subRecordArray")).get(0).get("testNotRemoved2").toString()); Assert.assertEquals("ghi", - ((Map) record.get("subRecordMap")).get("1").get("testNotRemoved2")); + ((Map) record.get("subRecordMap")).get(new Utf8("1")).get("testNotRemoved2").toString()); } @Test @@ -429,10 +444,10 @@ public void shouldSkipRemovedRecord() { GenericRecord record = deserializeGenericFast(record1Schema, record2Schema, serializeGeneric(builder.build())); // then - Assert.assertEquals("abc", ((GenericRecord) record.get("subRecord1")).get("test1")); - Assert.assertEquals("def", ((GenericRecord) record.get("subRecord1")).get("test2")); - Assert.assertEquals("abc", ((GenericRecord) record.get("subRecord4")).get("test1")); - Assert.assertEquals("def", ((GenericRecord) record.get("subRecord4")).get("test2")); + Assert.assertEquals("abc", ((GenericRecord) record.get("subRecord1")).get("test1").toString()); + Assert.assertEquals("def", ((GenericRecord) record.get("subRecord1")).get("test2").toString()); + Assert.assertEquals("abc", ((GenericRecord) record.get("subRecord4")).get("test1").toString()); + Assert.assertEquals("def", ((GenericRecord) record.get("subRecord4")).get("test2").toString()); } @Test @@ -473,8 +488,8 @@ public void shouldSkipRemovedNestedRecord() { GenericRecord record = deserializeGenericFast(record1Schema, record2Schema, serializeGeneric(builder.build())); // then - Assert.assertEquals("abc", ((GenericRecord) record.get("subRecord")).get("test1")); - Assert.assertEquals("def", ((GenericRecord) record.get("subRecord")).get("test4")); + Assert.assertEquals("abc", ((GenericRecord) record.get("subRecord")).get("test1").toString()); + Assert.assertEquals("def", ((GenericRecord) record.get("subRecord")).get("test4").toString()); } @Test @@ -498,7 +513,7 @@ public void shouldReadMultipleChoiceUnion() { GenericRecord record = deserializeGenericFast(recordSchema, recordSchema, serializeGeneric(builder.build())); // then - Assert.assertEquals("abc", ((GenericData.Record) record.get("union")).get("subField")); + Assert.assertEquals("abc", ((GenericData.Record) record.get("union")).get("subField").toString()); // given builder = new GenericRecordBuilder(recordSchema); @@ -508,7 +523,7 @@ public void shouldReadMultipleChoiceUnion() { record = deserializeGenericFast(recordSchema, recordSchema, serializeGeneric(builder.build())); // then - Assert.assertEquals("abc", record.get("union")); + Assert.assertEquals("abc", record.get("union").toString()); // given builder = new GenericRecordBuilder(recordSchema); @@ -543,8 +558,8 @@ public void shouldReadArrayOfRecords() { // then Assert.assertEquals(2, array.size()); - Assert.assertEquals("abc", array.get(0).get("field")); - Assert.assertEquals("abc", array.get(1).get("field")); + Assert.assertEquals("abc", array.get(0).get("field").toString()); + Assert.assertEquals("abc", array.get(1).get("field").toString()); // given @@ -562,10 +577,115 @@ public void shouldReadArrayOfRecords() { // then Assert.assertEquals(2, array.size()); - Assert.assertEquals("abc", array.get(0).get("field")); - Assert.assertEquals("abc", array.get(1).get("field")); + Assert.assertEquals("abc", array.get(0).get("field").toString()); + Assert.assertEquals("abc", array.get(1).get("field").toString()); } + @Test + public void shouldReadArrayOfPrimitives() { + // given + Schema stringArraySchema = Schema.createArray(Schema.create(Schema.Type.STRING)); + + GenericData.Array stringArray = new GenericData.Array<>(0, stringArraySchema); + stringArray.add("aaa"); + stringArray.add("abc"); + + Schema intArraySchema = Schema.createArray(Schema.create(Schema.Type.INT)); + + GenericData.Array intArray = new GenericData.Array<>(0, intArraySchema); + intArray.add(1); + intArray.add(2); + + Schema longArraySchema = Schema.createArray(Schema.create(Schema.Type.LONG)); + + GenericData.Array longArray = new GenericData.Array<>(0, longArraySchema); + longArray.add(1L); + longArray.add(2L); + + Schema doubleArraySchema = Schema.createArray(Schema.create(Schema.Type.DOUBLE)); + + GenericData.Array doubleArray = new GenericData.Array<>(0, doubleArraySchema); + doubleArray.add(1.0); + doubleArray.add(2.0); + + Schema floatArraySchema = Schema.createArray(Schema.create(Schema.Type.FLOAT)); + + GenericData.Array floatArray = new GenericData.Array<>(0, floatArraySchema); + floatArray.add(1.0f); + floatArray.add(2.0f); + + Schema bytesArraySchema = Schema.createArray(Schema.create(Schema.Type.BYTES)); + + GenericData.Array bytesArray = new GenericData.Array<>(0, bytesArraySchema); + bytesArray.add(ByteBuffer.wrap(new byte[]{0x01})); + bytesArray.add(ByteBuffer.wrap(new byte[]{0x02})); + + // when + GenericData.Array resultStringArray = deserializeGenericFast(stringArraySchema, stringArraySchema, + serializeGeneric(stringArray)); + + GenericData.Array resultIntegerArray = deserializeGenericFast(intArraySchema, intArraySchema, + serializeGeneric(intArray)); + + GenericData.Array resultLongArray = deserializeGenericFast(longArraySchema, longArraySchema, + serializeGeneric(longArray)); + + GenericData.Array resultDoubleArray = deserializeGenericFast(doubleArraySchema, doubleArraySchema, + serializeGeneric(doubleArray)); + + GenericData.Array resultFloatArray = deserializeGenericFast(floatArraySchema, floatArraySchema, + serializeGeneric(floatArray)); + + GenericData.Array resultBytesArray = deserializeGenericFast(bytesArraySchema, bytesArraySchema, + serializeGeneric(bytesArray)); + + // then + Assert.assertEquals(2, resultStringArray.size()); + Assert.assertEquals("aaa", resultStringArray.get(0).toString()); + Assert.assertEquals("abc", resultStringArray.get(1).toString()); + + Assert.assertEquals(2, resultIntegerArray.size()); + Assert.assertEquals(Integer.valueOf(1), resultIntegerArray.get(0)); + Assert.assertEquals(Integer.valueOf(2), resultIntegerArray.get(1)); + + Assert.assertEquals(2, resultLongArray.size()); + Assert.assertEquals(Long.valueOf(1L), resultLongArray.get(0)); + Assert.assertEquals(Long.valueOf(2L), resultLongArray.get(1)); + + Assert.assertEquals(2, resultDoubleArray.size()); + Assert.assertEquals(Double.valueOf(1.0), resultDoubleArray.get(0)); + Assert.assertEquals(Double.valueOf(2.0), resultDoubleArray.get(1)); + + Assert.assertEquals(2, resultFloatArray.size()); + Assert.assertEquals(Float.valueOf(1f), resultFloatArray.get(0)); + Assert.assertEquals(Float.valueOf(2f), resultFloatArray.get(1)); + + Assert.assertEquals(2, resultBytesArray.size()); + Assert.assertEquals(0x01, resultBytesArray.get(0).get()); + Assert.assertEquals(0x02, resultBytesArray.get(1).get()); + } + + @Test + public void shouldReadArrayOfJavaStrings() { + // given + Schema javaStringSchema = Schema.create(Schema.Type.STRING); + GenericData.setStringType(javaStringSchema, GenericData.StringType.String); + Schema javaStringArraySchema = Schema.createArray(javaStringSchema); + + GenericData.Array javaStringArray = new GenericData.Array<>(0, javaStringArraySchema); + javaStringArray.add("aaa"); + javaStringArray.add("abc"); + + GenericData.Array resultJavaStringArray = deserializeGenericFast(javaStringArraySchema, javaStringArraySchema, + serializeGeneric(javaStringArray)); + + // then + Assert.assertEquals(2, resultJavaStringArray.size()); + Assert.assertEquals("aaa", resultJavaStringArray.get(0)); + Assert.assertEquals("abc", resultJavaStringArray.get(1)); + } + + @Test public void shouldReadMapOfRecords() { // given @@ -582,31 +702,155 @@ public void shouldReadMapOfRecords() { recordsMap.put("2", subRecordBuilder.build()); // when - Map map = deserializeGenericFast(mapRecordSchema, mapRecordSchema, + Map map = deserializeGenericFast(mapRecordSchema, mapRecordSchema, serializeGeneric(recordsMap, mapRecordSchema)); // then Assert.assertEquals(2, map.size()); - Assert.assertEquals("abc", map.get("1").get("field")); - Assert.assertEquals("abc", map.get("2").get("field")); + Assert.assertEquals("abc", map.get(new Utf8("1")).get("field").toString()); + Assert.assertEquals("abc", map.get(new Utf8("2")).get("field").toString()); // given mapRecordSchema = Schema.createMap(createUnionSchema(recordSchema)); - subRecordBuilder = new GenericRecordBuilder(recordSchema); + // when + map = deserializeGenericFast(mapRecordSchema, mapRecordSchema, serializeGeneric(recordsMap, mapRecordSchema)); + + // then + Assert.assertEquals(2, map.size()); + Assert.assertEquals("abc", map.get(new Utf8("1")).get("field").toString()); + Assert.assertEquals("abc", map.get(new Utf8("2")).get("field").toString()); + } + + @Test + public void shouldReadMapOfPrimitives() { + // given + Schema stringMapSchema = Schema.createMap(Schema.create(Schema.Type.STRING)); + + Map stringMap = new HashMap<>(0); + stringMap.put("1", "abc"); + stringMap.put("2", "aaa"); + + Schema intMapSchema = Schema.createMap(Schema.create(Schema.Type.INT)); + + Map intMap = new HashMap<>(0); + intMap.put("1", 1); + intMap.put("2", 2); + + Schema longMapSchema = Schema.createMap(Schema.create(Schema.Type.LONG)); + + Map longMap = new HashMap<>(0); + longMap.put("1", 1L); + longMap.put("2", 2L); + + Schema doubleMapSchema = Schema.createMap(Schema.create(Schema.Type.DOUBLE)); + + Map doubleMap = new HashMap<>(0); + doubleMap.put("1", 1.0); + doubleMap.put("2", 2.0); + + Schema floatMapSchema = Schema.createMap(Schema.create(Schema.Type.FLOAT)); + + Map floatMap = new HashMap<>(0); + floatMap.put("1", 1.0f); + floatMap.put("2", 2.0f); + + Schema bytesMapSchema = Schema.createMap(Schema.create(Schema.Type.BYTES)); + + Map bytesMap = new HashMap<>(0); + bytesMap.put("1", ByteBuffer.wrap(new byte[]{0x01})); + bytesMap.put("2", ByteBuffer.wrap(new byte[]{0x02})); + + // when + Map resultStringMap = deserializeGenericFast(stringMapSchema, stringMapSchema, + serializeGeneric(stringMap, stringMapSchema)); + + Map resultIntegerMap = deserializeGenericFast(intMapSchema, intMapSchema, + serializeGeneric(intMap, intMapSchema)); + + Map resultLongMap = deserializeGenericFast(longMapSchema, longMapSchema, + serializeGeneric(longMap, longMapSchema)); + + Map resultDoubleMap = deserializeGenericFast(doubleMapSchema, doubleMapSchema, + serializeGeneric(doubleMap, doubleMapSchema)); + + Map resultFloatMap = deserializeGenericFast(floatMapSchema, floatMapSchema, + serializeGeneric(floatMap, floatMapSchema)); + + Map resultBytesMap = deserializeGenericFast(bytesMapSchema, bytesMapSchema, + serializeGeneric(bytesMap, bytesMapSchema)); + + // then + Assert.assertEquals(2, resultStringMap.size()); + Assert.assertEquals("abc", resultStringMap.get(new Utf8("1")).toString()); + Assert.assertEquals("aaa", resultStringMap.get(new Utf8("2")).toString()); + + Assert.assertEquals(2, resultIntegerMap.size()); + Assert.assertEquals(Integer.valueOf(1), resultIntegerMap.get(new Utf8("1"))); + Assert.assertEquals(Integer.valueOf(2), resultIntegerMap.get(new Utf8("2"))); + + Assert.assertEquals(2, resultLongMap.size()); + Assert.assertEquals(Long.valueOf(1L), resultLongMap.get(new Utf8("1"))); + Assert.assertEquals(Long.valueOf(2L), resultLongMap.get(new Utf8("2"))); + + Assert.assertEquals(2, resultDoubleMap.size()); + Assert.assertEquals(Double.valueOf(1.0), resultDoubleMap.get(new Utf8("1"))); + Assert.assertEquals(Double.valueOf(2.0), resultDoubleMap.get(new Utf8("2"))); + + Assert.assertEquals(2, resultFloatMap.size()); + Assert.assertEquals(Float.valueOf(1f), resultFloatMap.get(new Utf8("1"))); + Assert.assertEquals(Float.valueOf(2f), resultFloatMap.get(new Utf8("2"))); + + Assert.assertEquals(2, resultBytesMap.size()); + Assert.assertEquals(0x01, resultBytesMap.get(new Utf8("1")).get()); + Assert.assertEquals(0x02, resultBytesMap.get(new Utf8("2")).get()); + } + + @Test + public void shouldReadMapOfJavaStrings() { + // given + Schema stringMapSchema = Schema.createMap(Schema.create(Schema.Type.STRING)); + Schema javaStringSchema = Schema.create(Schema.Type.STRING); + GenericData.setStringType(javaStringSchema, GenericData.StringType.String); + Schema javaStringMapSchema = Schema.createMap(javaStringSchema); + + Map stringMap = new HashMap<>(0); + stringMap.put("1", "abc"); + stringMap.put("2", "aaa"); + + // when + Map resultJavaStringMap = deserializeGenericFast(stringMapSchema, javaStringMapSchema, + serializeGeneric(stringMap, javaStringMapSchema)); + + // then + Assert.assertEquals(2, resultJavaStringMap.size()); + Assert.assertEquals("abc", resultJavaStringMap.get(new Utf8("1"))); + Assert.assertEquals("aaa", resultJavaStringMap.get(new Utf8("2"))); + } + + @Test + public void shouldReadJavaStringKeyedMapOfRecords() { + // given + Schema recordSchema = createRecord("record", + createPrimitiveUnionFieldSchema("field", Schema.Type.STRING)); + + Schema mapRecordSchema = Schema.createMap(recordSchema); + GenericData.setStringType(mapRecordSchema, GenericData.StringType.String); + + GenericRecordBuilder subRecordBuilder = new GenericRecordBuilder(recordSchema); subRecordBuilder.set("field", "abc"); - recordsMap = new HashMap<>(); + Map recordsMap = new HashMap<>(); recordsMap.put("1", subRecordBuilder.build()); recordsMap.put("2", subRecordBuilder.build()); // when - map = deserializeGenericFast(mapRecordSchema, mapRecordSchema, serializeGeneric(recordsMap, mapRecordSchema)); + Map mapWithStringKeys = deserializeGenericFast(mapRecordSchema, mapRecordSchema, serializeGeneric(recordsMap, mapRecordSchema)); // then - Assert.assertEquals(2, map.size()); - Assert.assertEquals("abc", map.get("1").get("field")); - Assert.assertEquals("abc", map.get("2").get("field")); + Assert.assertEquals(2, mapWithStringKeys.size()); + Assert.assertEquals("abc", mapWithStringKeys.get("1").get("field").toString()); + Assert.assertEquals("abc", mapWithStringKeys.get("2").get("field").toString()); } @Test @@ -621,14 +865,14 @@ public void shouldDeserializeNullElementInMap() { records.put("2", 2); // when - Map map = deserializeGenericFast(mapRecordSchema, mapRecordSchema, + Map map = deserializeGenericFast(mapRecordSchema, mapRecordSchema, serializeGeneric(records, mapRecordSchema)); // then Assert.assertEquals(3, map.size()); - Assert.assertEquals("0", map.get("0")); - Assert.assertNull(map.get("1")); - Assert.assertEquals(2, map.get("2")); + Assert.assertEquals("0", map.get(new Utf8("0")).toString()); + Assert.assertNull(map.get(new Utf8("1"))); + Assert.assertEquals(2, map.get(new Utf8("2"))); } @Test @@ -648,7 +892,7 @@ public void shouldDeserializeNullElementInArray() { // then Assert.assertEquals(3, array.size()); - Assert.assertEquals("0", array.get(0)); + Assert.assertEquals("0", array.get(0).toString()); Assert.assertNull(array.get(1)); Assert.assertEquals(2, array.get(2)); } diff --git a/src/test/java/com/rtbhouse/utils/avro/FastGenericSerializerGeneratorTest.java b/src/test/java/com/rtbhouse/utils/avro/FastGenericSerializerGeneratorTest.java index 1876c4c..85ae41e 100644 --- a/src/test/java/com/rtbhouse/utils/avro/FastGenericSerializerGeneratorTest.java +++ b/src/test/java/com/rtbhouse/utils/avro/FastGenericSerializerGeneratorTest.java @@ -48,18 +48,22 @@ public void prepare() throws Exception { Path tempPath = Files.createTempDirectory("generated"); tempDir = tempPath.toFile(); - classLoader = URLClassLoader.newInstance(new URL[] { tempDir.toURI().toURL() }, + classLoader = URLClassLoader.newInstance(new URL[]{tempDir.toURI().toURL()}, FastGenericSerializerGeneratorTest.class.getClassLoader()); } @Test public void shouldWritePrimitives() { // given + Schema javaLangStringSchema = Schema.create(Schema.Type.STRING); + GenericData.setStringType(javaLangStringSchema, GenericData.StringType.String); Schema recordSchema = createRecord("testRecord", createField("testInt", Schema.create(Schema.Type.INT)), createPrimitiveUnionFieldSchema("testIntUnion", Schema.Type.INT), createField("testString", Schema.create(Schema.Type.STRING)), createPrimitiveUnionFieldSchema("testStringUnion", Schema.Type.STRING), + createField("testJavaString", javaLangStringSchema), + createUnionField("testJavaStringUnion", javaLangStringSchema), createField("testLong", Schema.create(Schema.Type.LONG)), createPrimitiveUnionFieldSchema("testLongUnion", Schema.Type.LONG), createField("testDouble", Schema.create(Schema.Type.DOUBLE)), @@ -76,16 +80,18 @@ public void shouldWritePrimitives() { builder.set("testIntUnion", 1); builder.set("testString", "aaa"); builder.set("testStringUnion", "aaa"); - builder.set("testLong", 1l); - builder.set("testLongUnion", 1l); + builder.set("testJavaString", "aaa"); + builder.set("testJavaStringUnion", "aaa"); + builder.set("testLong", 1L); + builder.set("testLongUnion", 1L); builder.set("testDouble", 1.0); builder.set("testDoubleUnion", 1.0); builder.set("testFloat", 1.0f); builder.set("testFloatUnion", 1.0f); builder.set("testBoolean", true); builder.set("testBooleanUnion", true); - builder.set("testBytes", ByteBuffer.wrap(new byte[] { 0x01, 0x02 })); - builder.set("testBytesUnion", ByteBuffer.wrap(new byte[] { 0x01, 0x02 })); + builder.set("testBytes", ByteBuffer.wrap(new byte[]{0x01, 0x02})); + builder.set("testBytesUnion", ByteBuffer.wrap(new byte[]{0x01, 0x02})); // when GenericRecord record = deserializeGeneric(recordSchema, serializeGenericFast(builder.build())); @@ -95,20 +101,23 @@ public void shouldWritePrimitives() { Assert.assertEquals(1, record.get("testIntUnion")); Assert.assertEquals("aaa", record.get("testString").toString()); Assert.assertEquals("aaa", record.get("testStringUnion").toString()); - Assert.assertEquals(1l, record.get("testLong")); - Assert.assertEquals(1l, record.get("testLongUnion")); + Assert.assertEquals("aaa", record.get("testJavaString")); + Assert.assertEquals("aaa", record.get("testJavaStringUnion")); + Assert.assertEquals(1L, record.get("testLong")); + Assert.assertEquals(1L, record.get("testLongUnion")); Assert.assertEquals(1.0, record.get("testDouble")); Assert.assertEquals(1.0, record.get("testDoubleUnion")); Assert.assertEquals(1.0f, record.get("testFloat")); Assert.assertEquals(1.0f, record.get("testFloatUnion")); Assert.assertEquals(true, record.get("testBoolean")); Assert.assertEquals(true, record.get("testBooleanUnion")); - Assert.assertEquals(ByteBuffer.wrap(new byte[] { 0x01, 0x02 }), record.get("testBytes")); - Assert.assertEquals(ByteBuffer.wrap(new byte[] { 0x01, 0x02 }), record.get("testBytesUnion")); + Assert.assertEquals(ByteBuffer.wrap(new byte[]{0x01, 0x02}), record.get("testBytes")); + Assert.assertEquals(ByteBuffer.wrap(new byte[]{0x01, 0x02}), record.get("testBytesUnion")); } @Test + @SuppressWarnings("unchecked") public void shouldWriteFixed() { // given Schema fixedSchema = createFixedSchema("testFixed", 2); @@ -117,28 +126,29 @@ public void shouldWriteFixed() { createArrayFieldSchema("testFixedUnionArray", createUnionSchema(fixedSchema))); GenericRecordBuilder builder = new GenericRecordBuilder(recordSchema); - builder.set("testFixed", new GenericData.Fixed(fixedSchema, new byte[] { 0x01, 0x02 })); - builder.set("testFixedUnion", new GenericData.Fixed(fixedSchema, new byte[] { 0x03, 0x04 })); - builder.set("testFixedArray", Arrays.asList(new GenericData.Fixed(fixedSchema, new byte[] { 0x05, 0x06 }))); + builder.set("testFixed", new GenericData.Fixed(fixedSchema, new byte[]{0x01, 0x02})); + builder.set("testFixedUnion", new GenericData.Fixed(fixedSchema, new byte[]{0x03, 0x04})); + builder.set("testFixedArray", Arrays.asList(new GenericData.Fixed(fixedSchema, new byte[]{0x05, 0x06}))); builder.set("testFixedUnionArray", - Arrays.asList(new GenericData.Fixed(fixedSchema, new byte[] { 0x07, 0x08 }))); + Arrays.asList(new GenericData.Fixed(fixedSchema, new byte[]{0x07, 0x08}))); // when GenericRecord record = deserializeGeneric(recordSchema, serializeGenericFast(builder.build())); // then - Assert.assertArrayEquals(new byte[] { 0x01, 0x02 }, ((GenericData.Fixed) record.get("testFixed")).bytes()); - Assert.assertArrayEquals(new byte[] { 0x03, 0x04 }, ((GenericData.Fixed) record.get("testFixedUnion")).bytes()); - Assert.assertArrayEquals(new byte[] { 0x05, 0x06 }, + Assert.assertArrayEquals(new byte[]{0x01, 0x02}, ((GenericData.Fixed) record.get("testFixed")).bytes()); + Assert.assertArrayEquals(new byte[]{0x03, 0x04}, ((GenericData.Fixed) record.get("testFixedUnion")).bytes()); + Assert.assertArrayEquals(new byte[]{0x05, 0x06}, ((List) record.get("testFixedArray")).get(0).bytes()); - Assert.assertArrayEquals(new byte[] { 0x07, 0x08 }, + Assert.assertArrayEquals(new byte[]{0x07, 0x08}, ((List) record.get("testFixedUnionArray")).get(0).bytes()); } @Test + @SuppressWarnings("unchecked") public void shouldWriteEnum() { // given - Schema enumSchema = createEnumSchema("testEnum", new String[] { "A", "B" }); + Schema enumSchema = createEnumSchema("testEnum", new String[]{"A", "B"}); Schema recordSchema = createRecord("testRecord", createField("testEnum", enumSchema), createUnionField("testEnumUnion", enumSchema), createArrayFieldSchema("testEnumArray", enumSchema), createArrayFieldSchema("testEnumUnionArray", createUnionSchema(enumSchema))); @@ -188,6 +198,7 @@ public void shouldWriteSubRecordField() { } @Test + @SuppressWarnings("unchecked") public void shouldWriteSubRecordCollectionsField() { // given Schema subRecordSchema = createRecord("subRecord", @@ -219,14 +230,15 @@ public void shouldWriteSubRecordCollectionsField() { Assert.assertEquals("abc", ((List) record.get("recordsArrayUnion")).get(0).get("subField").toString()); Assert.assertEquals("abc", - ((Map) record.get("recordsMap")).get(new Utf8("1")).get("subField") + ((Map) record.get("recordsMap")).get(new Utf8("1")).get("subField") .toString()); Assert.assertEquals("abc", - ((Map) record.get("recordsMapUnion")).get(new Utf8("1")).get("subField") + ((Map) record.get("recordsMapUnion")).get(new Utf8("1")).get("subField") .toString()); } @Test + @SuppressWarnings("unchecked") public void shouldWriteSubRecordComplexCollectionsField() { // given Schema subRecordSchema = createRecord("subRecord", @@ -303,7 +315,7 @@ public void shouldWriteMultipleChoiceUnion() { // given builder = new GenericRecordBuilder(recordSchema); - builder.set("union", "abc"); + builder.set("union", new Utf8("abc")); // when record = deserializeGeneric(recordSchema, serializeGenericFast(builder.build())); @@ -323,6 +335,90 @@ record = deserializeGeneric(recordSchema, serializeGenericFast(builder.build())) } + @Test + public void shouldWriteArrayOfPrimitives() { + // given + Schema stringArraySchema = Schema.createArray(Schema.create(Schema.Type.STRING)); + + GenericData.Array stringArray = new GenericData.Array<>(0, stringArraySchema); + stringArray.add("aaa"); + stringArray.add("abc"); + + Schema intArraySchema = Schema.createArray(Schema.create(Schema.Type.INT)); + + GenericData.Array intArray = new GenericData.Array<>(0, intArraySchema); + intArray.add(1); + intArray.add(2); + + Schema longArraySchema = Schema.createArray(Schema.create(Schema.Type.LONG)); + + GenericData.Array longArray = new GenericData.Array<>(0, longArraySchema); + longArray.add(1L); + longArray.add(2L); + + Schema doubleArraySchema = Schema.createArray(Schema.create(Schema.Type.DOUBLE)); + + GenericData.Array doubleArray = new GenericData.Array<>(0, doubleArraySchema); + doubleArray.add(1.0); + doubleArray.add(2.0); + + Schema floatArraySchema = Schema.createArray(Schema.create(Schema.Type.FLOAT)); + + GenericData.Array floatArray = new GenericData.Array<>(0, floatArraySchema); + floatArray.add(1.0f); + floatArray.add(2.0f); + + Schema bytesArraySchema = Schema.createArray(Schema.create(Schema.Type.BYTES)); + + GenericData.Array bytesArray = new GenericData.Array<>(0, bytesArraySchema); + bytesArray.add(ByteBuffer.wrap(new byte[]{0x01})); + bytesArray.add(ByteBuffer.wrap(new byte[]{0x02})); + + // when + GenericData.Array resultStringArray = deserializeGeneric(stringArraySchema, + serializeGenericFast(stringArray)); + + GenericData.Array resultIntegerArray = deserializeGeneric(intArraySchema, + serializeGenericFast(intArray)); + + GenericData.Array resultLongArray = deserializeGeneric(longArraySchema, + serializeGenericFast(longArray)); + + GenericData.Array resultDoubleArray = deserializeGeneric(doubleArraySchema, + serializeGenericFast(doubleArray)); + + GenericData.Array resultFloatArray = deserializeGeneric(floatArraySchema, + serializeGenericFast(floatArray)); + + GenericData.Array resultBytesArray = deserializeGeneric(bytesArraySchema, + serializeGenericFast(bytesArray)); + + // then + Assert.assertEquals(2, resultStringArray.size()); + Assert.assertEquals("aaa", resultStringArray.get(0).toString()); + Assert.assertEquals("abc", resultStringArray.get(1).toString()); + + Assert.assertEquals(2, resultIntegerArray.size()); + Assert.assertEquals(Integer.valueOf(1), resultIntegerArray.get(0)); + Assert.assertEquals(Integer.valueOf(2), resultIntegerArray.get(1)); + + Assert.assertEquals(2, resultLongArray.size()); + Assert.assertEquals(Long.valueOf(1L), resultLongArray.get(0)); + Assert.assertEquals(Long.valueOf(2L), resultLongArray.get(1)); + + Assert.assertEquals(2, resultDoubleArray.size()); + Assert.assertEquals(Double.valueOf(1.0), resultDoubleArray.get(0)); + Assert.assertEquals(Double.valueOf(2.0), resultDoubleArray.get(1)); + + Assert.assertEquals(2, resultFloatArray.size()); + Assert.assertEquals(Float.valueOf(1f), resultFloatArray.get(0)); + Assert.assertEquals(Float.valueOf(2f), resultFloatArray.get(1)); + + Assert.assertEquals(2, resultBytesArray.size()); + Assert.assertEquals(0x01, resultBytesArray.get(0).get()); + Assert.assertEquals(0x02, resultBytesArray.get(1).get()); + } + @Test public void shouldWriteArrayOfRecords() { // given @@ -367,6 +463,90 @@ public void shouldWriteArrayOfRecords() { Assert.assertEquals("abc", array.get(1).get("field").toString()); } + @Test + public void shouldWriteMapOfPrimitives() { + // given + Schema stringMapSchema = Schema.createMap(Schema.create(Schema.Type.STRING)); + + Map stringMap = new HashMap<>(0); + stringMap.put("1", "abc"); + stringMap.put("2", "aaa"); + + Schema intMapSchema = Schema.createMap(Schema.create(Schema.Type.INT)); + + Map intMap = new HashMap<>(0); + intMap.put("1", 1); + intMap.put("2", 2); + + Schema longMapSchema = Schema.createMap(Schema.create(Schema.Type.LONG)); + + Map longMap = new HashMap<>(0); + longMap.put("1", 1L); + longMap.put("2", 2L); + + Schema doubleMapSchema = Schema.createMap(Schema.create(Schema.Type.DOUBLE)); + + Map doubleMap = new HashMap<>(0); + doubleMap.put("1", 1.0); + doubleMap.put("2", 2.0); + + Schema floatMapSchema = Schema.createMap(Schema.create(Schema.Type.FLOAT)); + + Map floatMap = new HashMap<>(0); + floatMap.put("1", 1.0f); + floatMap.put("2", 2.0f); + + Schema bytesMapSchema = Schema.createMap(Schema.create(Schema.Type.BYTES)); + + Map bytesMap = new HashMap<>(0); + bytesMap.put("1", ByteBuffer.wrap(new byte[]{0x01})); + bytesMap.put("2", ByteBuffer.wrap(new byte[]{0x02})); + + // when + Map resultStringMap = deserializeGeneric(stringMapSchema, + serializeGenericFast(stringMap, stringMapSchema)); + + Map resultIntegerMap = deserializeGeneric(intMapSchema, + serializeGenericFast(intMap, intMapSchema)); + + Map resultLongMap = deserializeGeneric(longMapSchema, + serializeGenericFast(longMap, longMapSchema)); + + Map resultDoubleMap = deserializeGeneric(doubleMapSchema, + serializeGenericFast(doubleMap, doubleMapSchema)); + + Map resultFloatMap = deserializeGeneric(floatMapSchema, + serializeGenericFast(floatMap, floatMapSchema)); + + Map resultBytesMap = deserializeGeneric(bytesMapSchema, + serializeGenericFast(bytesMap, bytesMapSchema)); + + // then + Assert.assertEquals(2, resultStringMap.size()); + Assert.assertEquals("abc", resultStringMap.get(new Utf8("1")).toString()); + Assert.assertEquals("aaa", resultStringMap.get(new Utf8("2")).toString()); + + Assert.assertEquals(2, resultIntegerMap.size()); + Assert.assertEquals(Integer.valueOf(1), resultIntegerMap.get(new Utf8("1"))); + Assert.assertEquals(Integer.valueOf(2), resultIntegerMap.get(new Utf8("2"))); + + Assert.assertEquals(2, resultLongMap.size()); + Assert.assertEquals(Long.valueOf(1L), resultLongMap.get(new Utf8("1"))); + Assert.assertEquals(Long.valueOf(2L), resultLongMap.get(new Utf8("2"))); + + Assert.assertEquals(2, resultDoubleMap.size()); + Assert.assertEquals(Double.valueOf(1.0), resultDoubleMap.get(new Utf8("1"))); + Assert.assertEquals(Double.valueOf(2.0), resultDoubleMap.get(new Utf8("2"))); + + Assert.assertEquals(2, resultFloatMap.size()); + Assert.assertEquals(Float.valueOf(1f), resultFloatMap.get(new Utf8("1"))); + Assert.assertEquals(Float.valueOf(2f), resultFloatMap.get(new Utf8("2"))); + + Assert.assertEquals(2, resultBytesMap.size()); + Assert.assertEquals(0x01, resultBytesMap.get(new Utf8("1")).get()); + Assert.assertEquals(0x02, resultBytesMap.get(new Utf8("2")).get()); + } + @Test public void shouldWriteMapOfRecords() { // given diff --git a/src/test/java/com/rtbhouse/utils/avro/FastSerdeTestsSupport.java b/src/test/java/com/rtbhouse/utils/avro/FastSerdeTestsSupport.java index e857088..16f51f2 100644 --- a/src/test/java/com/rtbhouse/utils/avro/FastSerdeTestsSupport.java +++ b/src/test/java/com/rtbhouse/utils/avro/FastSerdeTestsSupport.java @@ -10,6 +10,7 @@ import org.apache.avro.Schema; import org.apache.avro.generic.GenericContainer; +import org.apache.avro.generic.GenericData; import org.apache.avro.generic.GenericDatumReader; import org.apache.avro.generic.GenericDatumWriter; import org.apache.avro.io.BinaryEncoder; @@ -23,18 +24,20 @@ public final class FastSerdeTestsSupport { + public static final String NAMESPACE = "com.rtbhouse.utils.generated.avro"; + private FastSerdeTestsSupport() { } public static Schema createRecord(String name, Schema.Field... fields) { - Schema schema = Schema.createRecord(name, name, "com.adpilot.utils.generated.avro", false); + Schema schema = Schema.createRecord(name, name, NAMESPACE, false); schema.setFields(Arrays.asList(fields)); return schema; } public static Schema.Field createField(String name, Schema schema) { - return new Schema.Field(name, schema, "", null, Schema.Field.Order.ASCENDING); + return new Schema.Field(name, schema, "", (Object) null, Schema.Field.Order.ASCENDING); } public static Schema.Field createUnionField(String name, Schema... schemas) { @@ -43,38 +46,38 @@ public static Schema.Field createUnionField(String name, Schema... schemas) { typeList.addAll(Arrays.asList(schemas)); Schema unionSchema = Schema.createUnion(typeList); - return new Schema.Field(name, unionSchema, null, NullNode.getInstance(), Schema.Field.Order.ASCENDING); + return new Schema.Field(name, unionSchema, null, Schema.Field.Order.ASCENDING); } public static Schema.Field createPrimitiveFieldSchema(String name, Schema.Type type) { - return new Schema.Field(name, Schema.create(type), null, null); + return new Schema.Field(name, Schema.create(type), null, (Object) null); } public static Schema.Field createPrimitiveUnionFieldSchema(String name, Schema.Type... types) { List typeList = new ArrayList<>(); typeList.add(Schema.create(Schema.Type.NULL)); - typeList.addAll(Arrays.asList(types).stream().map(Schema::create).collect(Collectors.toList())); + typeList.addAll(Arrays.stream(types).map(Schema::create).collect(Collectors.toList())); Schema unionSchema = Schema.createUnion(typeList); - return new Schema.Field(name, unionSchema, null, NullNode.getInstance(), Schema.Field.Order.ASCENDING); + return new Schema.Field(name, unionSchema, null, (Object) null, Schema.Field.Order.ASCENDING); } public static Schema.Field createArrayFieldSchema(String name, Schema elementType, String... aliases) { - return addAliases(new Schema.Field(name, Schema.createArray(elementType), null, null, + return addAliases(new Schema.Field(name, Schema.createArray(elementType), null, (Object) null, Schema.Field.Order.ASCENDING), aliases); } public static Schema.Field createMapFieldSchema(String name, Schema valueType, String... aliases) { - return addAliases(new Schema.Field(name, Schema.createMap(valueType), null, null, + return addAliases(new Schema.Field(name, Schema.createMap(valueType), null, (Object) null, Schema.Field.Order.ASCENDING), aliases); } public static Schema createFixedSchema(String name, int size) { - return Schema.createFixed(name, "", "com.adpilot.utils.generated.avro", size); + return Schema.createFixed(name, "", NAMESPACE, size); } public static Schema createEnumSchema(String name, String[] ordinals) { - return Schema.createEnum(name, "", "com.adpilot.utils.generated.avro", Arrays.asList(ordinals)); + return Schema.createEnum(name, "", NAMESPACE, Arrays.asList(ordinals)); } public static Schema createUnionSchema(Schema... schemas) { @@ -158,11 +161,11 @@ public static com.rtbhouse.utils.generated.avro.TestRecord emptyTestRecord() { record.put("testFixed", new com.rtbhouse.utils.generated.avro.TestFixed(new byte[] { 0x01 })); record.put("testFixedArray", Collections.EMPTY_LIST); - record.put("testFixedUnionArray", Arrays.asList(new com.rtbhouse.utils.generated.avro.TestFixed(new byte[] { 0x01 }))); + record.put("testFixedUnionArray", Collections.singletonList(new com.rtbhouse.utils.generated.avro.TestFixed(new byte[] { 0x01 }))); record.put("testEnum", com.rtbhouse.utils.generated.avro.TestEnum.A); record.put("testEnumArray", Collections.EMPTY_LIST); - record.put("testEnumUnionArray", Arrays.asList(com.rtbhouse.utils.generated.avro.TestEnum.A)); + record.put("testEnumUnionArray", Collections.singletonList(com.rtbhouse.utils.generated.avro.TestEnum.A)); record.put("subRecord", new com.rtbhouse.utils.generated.avro.SubRecord()); record.put("recordsArray", Collections.emptyList()); @@ -171,7 +174,7 @@ public static com.rtbhouse.utils.generated.avro.TestRecord emptyTestRecord() { record.put("recordsMapArray", Collections.emptyMap()); record.put("testInt", 1); - record.put("testLong", 1l); + record.put("testLong", 1L); record.put("testDouble", 1.0); record.put("testFloat", 1.0f); record.put("testBoolean", true); diff --git a/src/test/java/com/rtbhouse/utils/avro/FastSpecificDeserializerGeneratorTest.java b/src/test/java/com/rtbhouse/utils/avro/FastSpecificDeserializerGeneratorTest.java index 4eacf49..345956e 100644 --- a/src/test/java/com/rtbhouse/utils/avro/FastSpecificDeserializerGeneratorTest.java +++ b/src/test/java/com/rtbhouse/utils/avro/FastSpecificDeserializerGeneratorTest.java @@ -22,6 +22,7 @@ import org.apache.avro.Schema; import org.apache.avro.generic.GenericData; import org.apache.avro.io.Decoder; +import org.apache.avro.util.Utf8; import org.junit.Assert; import org.junit.Before; import org.junit.Test; @@ -497,13 +498,13 @@ public void shouldReadMapOfRecords() { recordsMap.put("2", testRecord); // when - Map map = deserializeSpecificFast(mapRecordSchema, mapRecordSchema, + Map map = deserializeSpecificFast(mapRecordSchema, mapRecordSchema, serializeSpecific(recordsMap, mapRecordSchema)); // then Assert.assertEquals(2, map.size()); - Assert.assertEquals("abc", map.get("1").get("testStringUnion")); - Assert.assertEquals("abc", map.get("2").get("testStringUnion")); + Assert.assertEquals("abc", map.get(new Utf8("1")).get("testStringUnion").toString()); + Assert.assertEquals("abc", map.get(new Utf8("2")).get("testStringUnion").toString()); // given mapRecordSchema = Schema.createMap(createUnionSchema(TestRecord @@ -522,15 +523,16 @@ public void shouldReadMapOfRecords() { // then Assert.assertEquals(2, map.size()); - Assert.assertEquals("abc", map.get("1").get("testStringUnion")); - Assert.assertEquals("abc", map.get("2").get("testStringUnion")); + Assert.assertEquals("abc", map.get(new Utf8("1")).get("testStringUnion")); + Assert.assertEquals("abc", map.get(new Utf8("2")).get("testStringUnion")); } @Test public void shouldDeserializeNullElementInMap() { // given + Schema stringSchema = Schema.create(Schema.Type.STRING); Schema mapRecordSchema = Schema.createMap(Schema.createUnion( - Schema.create(Schema.Type.STRING), Schema.create(Schema.Type.NULL), Schema.create(Schema.Type.INT))); + stringSchema, Schema.create(Schema.Type.NULL), Schema.create(Schema.Type.INT))); Map records = new HashMap<>(); records.put("0", "0"); @@ -538,14 +540,14 @@ public void shouldDeserializeNullElementInMap() { records.put("2", 2); // when - Map map = deserializeSpecificFast(mapRecordSchema, mapRecordSchema, + Map map = deserializeSpecificFast(mapRecordSchema, mapRecordSchema, serializeSpecific(records, mapRecordSchema)); // then Assert.assertEquals(3, map.size()); - Assert.assertEquals("0", map.get("0")); - Assert.assertNull(map.get("1")); - Assert.assertEquals(2, map.get("2")); + Assert.assertEquals("0", map.get(new Utf8("0")).toString()); + Assert.assertNull(map.get(new Utf8("1"))); + Assert.assertEquals(2, map.get(new Utf8("2"))); } @Test @@ -565,7 +567,7 @@ public void shouldDeserializeNullElementInArray() { // then Assert.assertEquals(3, array.size()); - Assert.assertEquals("0", array.get(0)); + Assert.assertEquals("0", array.get(0).toString()); Assert.assertNull(array.get(1)); Assert.assertEquals(2, array.get(2)); } diff --git a/src/test/java/com/rtbhouse/utils/avro/FastSpecificSerializerGeneratorTest.java b/src/test/java/com/rtbhouse/utils/avro/FastSpecificSerializerGeneratorTest.java index 70809dc..fed2308 100644 --- a/src/test/java/com/rtbhouse/utils/avro/FastSpecificSerializerGeneratorTest.java +++ b/src/test/java/com/rtbhouse/utils/avro/FastSpecificSerializerGeneratorTest.java @@ -19,6 +19,7 @@ import org.apache.avro.Schema; import org.apache.avro.generic.GenericContainer; +import org.apache.avro.generic.GenericData; import org.apache.avro.io.BinaryEncoder; import org.apache.avro.io.Decoder; import org.apache.avro.io.DecoderFactory; @@ -292,6 +293,7 @@ public void shouldWriteArrayOfRecords() { public void shouldWriteMapOfRecords() { // given Schema mapRecordSchema = Schema.createMap(TestRecord.getClassSchema()); + GenericData.setStringType(mapRecordSchema, GenericData.StringType.String); TestRecord testRecord = emptyTestRecord(); testRecord.put("testString", "abc"); @@ -301,17 +303,18 @@ public void shouldWriteMapOfRecords() { recordsMap.put("2", testRecord); // when - Map map = deserializeSpecific(mapRecordSchema, + Map map = deserializeSpecific(mapRecordSchema, serializeSpecificFast(recordsMap, mapRecordSchema)); // then Assert.assertEquals(2, map.size()); - Assert.assertEquals("abc", map.get(new Utf8("1")).get("testString")); - Assert.assertEquals("abc", map.get(new Utf8("2")).get("testString")); + Assert.assertEquals("abc", map.get("1").get("testString")); + Assert.assertEquals("abc", map.get("2").get("testString")); // given mapRecordSchema = Schema.createMap(createUnionSchema(TestRecord .getClassSchema())); + GenericData.setStringType(mapRecordSchema, GenericData.StringType.String); testRecord = emptyTestRecord(); testRecord.put("testString", "abc"); @@ -325,8 +328,8 @@ public void shouldWriteMapOfRecords() { // then Assert.assertEquals(2, map.size()); - Assert.assertEquals("abc", map.get(new Utf8("1")).get("testString")); - Assert.assertEquals("abc", map.get(new Utf8("2")).get("testString")); + Assert.assertEquals("abc", map.get("1").get("testString")); + Assert.assertEquals("abc", map.get("2").get("testString")); } @Test @@ -345,7 +348,7 @@ public void shouldSerializeNullElementInMap() { // then Assert.assertEquals(3, map.size()); - Assert.assertEquals(new Utf8("0"), map.get(new Utf8("0"))); + Assert.assertEquals("0", map.get(new Utf8("0")).toString()); Assert.assertNull(map.get(new Utf8("1"))); Assert.assertEquals(2, map.get(new Utf8("2"))); }