diff --git a/CHANGELOG.md b/CHANGELOG.md index a09f40bbc..2ebcb2862 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Add Support for Multi Values in innerHit for Nested k-NN Fields in Lucene and FAISS (#2283)[https://github.com/opensearch-project/k-NN/pull/2283] - Add binary index support for Lucene engine. (#2292)[https://github.com/opensearch-project/k-NN/pull/2292] - Add expand_nested_docs Parameter support to NMSLIB engine (#2331)[https://github.com/opensearch-project/k-NN/pull/2331] +- Add cosine similarity support for faiss engine (#2376)[https://github.com/opensearch-project/k-NN/pull/2376] ### Enhancements - Introduced a writing layer in native engines where relies on the writing interface to process IO. (#2241)[https://github.com/opensearch-project/k-NN/pull/2241] - Allow method parameter override for training based indices (#2290) https://github.com/opensearch-project/k-NN/pull/2290] diff --git a/src/main/java/org/opensearch/knn/index/SpaceType.java b/src/main/java/org/opensearch/knn/index/SpaceType.java index 5d90071e8..147e260b9 100644 --- a/src/main/java/org/opensearch/knn/index/SpaceType.java +++ b/src/main/java/org/opensearch/knn/index/SpaceType.java @@ -77,6 +77,11 @@ public float scoreTranslation(float rawScore) { return Math.max((2.0F - rawScore) / 2.0F, 0.0F); } + @Override + public float scoreToDistanceTranslation(float score) { + return score; + } + @Override public KNNVectorSimilarityFunction getKnnVectorSimilarityFunction() { return KNNVectorSimilarityFunction.COSINE; diff --git a/src/main/java/org/opensearch/knn/index/engine/AbstractKNNMethod.java b/src/main/java/org/opensearch/knn/index/engine/AbstractKNNMethod.java index f53655136..98153b267 100644 --- a/src/main/java/org/opensearch/knn/index/engine/AbstractKNNMethod.java +++ b/src/main/java/org/opensearch/knn/index/engine/AbstractKNNMethod.java @@ -13,6 +13,8 @@ import org.opensearch.knn.index.mapper.PerDimensionProcessor; import org.opensearch.knn.index.mapper.PerDimensionValidator; import org.opensearch.knn.index.mapper.SpaceVectorValidator; +import org.opensearch.knn.index.mapper.VectorTransformer; +import org.opensearch.knn.index.mapper.VectorTransformerFactory; import org.opensearch.knn.index.mapper.VectorValidator; import java.util.ArrayList; @@ -106,6 +108,10 @@ protected PerDimensionProcessor doGetPerDimensionProcessor( return PerDimensionProcessor.NOOP_PROCESSOR; } + protected VectorTransformer getVectorTransformer(KNNMethodContext knnMethodContext) { + return VectorTransformerFactory.getVectorTransformer(knnMethodContext); + } + @Override public KNNLibraryIndexingContext getKNNLibraryIndexingContext( KNNMethodContext knnMethodContext, @@ -116,7 +122,7 @@ public KNNLibraryIndexingContext getKNNLibraryIndexingContext( knnMethodConfigContext ); Map parameterMap = knnLibraryIndexingContext.getLibraryParameters(); - parameterMap.put(KNNConstants.SPACE_TYPE, knnMethodContext.getSpaceType().getValue()); + parameterMap.put(KNNConstants.SPACE_TYPE, getCompatibleSpaceType(knnMethodContext.getSpaceType()).getValue()); parameterMap.put(KNNConstants.VECTOR_DATA_TYPE_FIELD, knnMethodConfigContext.getVectorDataType().getValue()); return KNNLibraryIndexingContextImpl.builder() .quantizationConfig(knnLibraryIndexingContext.getQuantizationConfig()) @@ -124,6 +130,7 @@ public KNNLibraryIndexingContext getKNNLibraryIndexingContext( .vectorValidator(doGetVectorValidator(knnMethodContext, knnMethodConfigContext)) .perDimensionValidator(doGetPerDimensionValidator(knnMethodContext, knnMethodConfigContext)) .perDimensionProcessor(doGetPerDimensionProcessor(knnMethodContext, knnMethodConfigContext)) + .vectorTransformer(getVectorTransformer(knnMethodContext)) .build(); } @@ -131,4 +138,19 @@ public KNNLibraryIndexingContext getKNNLibraryIndexingContext( public KNNLibrarySearchContext getKNNLibrarySearchContext() { return knnLibrarySearchContext; } + + /** + * Gets the compatible space type for the given space type parameter. + * This method validates and returns the appropriate space type that + * is compatible with the system's requirements. + * + * @param spaceType The space type to check for compatibility + * @return The compatible space type for the given input, returns the same + * space type if it's already compatible + * @see SpaceType + */ + + protected SpaceType getCompatibleSpaceType(SpaceType spaceType) { + return spaceType; + } } diff --git a/src/main/java/org/opensearch/knn/index/engine/KNNLibraryIndexingContext.java b/src/main/java/org/opensearch/knn/index/engine/KNNLibraryIndexingContext.java index 9208661af..283576bf6 100644 --- a/src/main/java/org/opensearch/knn/index/engine/KNNLibraryIndexingContext.java +++ b/src/main/java/org/opensearch/knn/index/engine/KNNLibraryIndexingContext.java @@ -8,6 +8,7 @@ import org.opensearch.knn.index.engine.qframe.QuantizationConfig; import org.opensearch.knn.index.mapper.PerDimensionProcessor; import org.opensearch.knn.index.mapper.PerDimensionValidator; +import org.opensearch.knn.index.mapper.VectorTransformer; import org.opensearch.knn.index.mapper.VectorValidator; import java.util.Map; @@ -47,4 +48,6 @@ public interface KNNLibraryIndexingContext { * @return Get the per dimension processor */ PerDimensionProcessor getPerDimensionProcessor(); + + VectorTransformer getVectorTransformer(); } diff --git a/src/main/java/org/opensearch/knn/index/engine/KNNLibraryIndexingContextImpl.java b/src/main/java/org/opensearch/knn/index/engine/KNNLibraryIndexingContextImpl.java index f5329fc31..9822033b7 100644 --- a/src/main/java/org/opensearch/knn/index/engine/KNNLibraryIndexingContextImpl.java +++ b/src/main/java/org/opensearch/knn/index/engine/KNNLibraryIndexingContextImpl.java @@ -9,6 +9,7 @@ import org.opensearch.knn.index.engine.qframe.QuantizationConfig; import org.opensearch.knn.index.mapper.PerDimensionProcessor; import org.opensearch.knn.index.mapper.PerDimensionValidator; +import org.opensearch.knn.index.mapper.VectorTransformer; import org.opensearch.knn.index.mapper.VectorValidator; import java.util.Collections; @@ -23,6 +24,7 @@ public class KNNLibraryIndexingContextImpl implements KNNLibraryIndexingContext private VectorValidator vectorValidator; private PerDimensionValidator perDimensionValidator; private PerDimensionProcessor perDimensionProcessor; + private VectorTransformer vectorTransformer; @Builder.Default private Map parameters = Collections.emptyMap(); @Builder.Default @@ -43,6 +45,11 @@ public VectorValidator getVectorValidator() { return vectorValidator; } + @Override + public VectorTransformer getVectorTransformer() { + return vectorTransformer; + } + @Override public PerDimensionValidator getPerDimensionValidator() { return perDimensionValidator; diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissMethod.java b/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissMethod.java index 7ae403445..5810a5f54 100644 --- a/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissMethod.java +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissMethod.java @@ -17,6 +17,7 @@ import org.opensearch.knn.index.engine.MethodComponentContext; import org.opensearch.knn.index.mapper.PerDimensionProcessor; import org.opensearch.knn.index.mapper.PerDimensionValidator; +import org.opensearch.knn.index.mapper.VectorTransformer; import java.util.Objects; import java.util.Set; @@ -89,6 +90,11 @@ protected PerDimensionProcessor doGetPerDimensionProcessor( throw new IllegalStateException("Unsupported vector data type " + vectorDataType); } + @Override + protected VectorTransformer getVectorTransformer(KNNMethodContext knnMethodContext) { + return super.getVectorTransformer(knnMethodContext); + } + static KNNLibraryIndexingContext adjustIndexDescription( MethodAsMapBuilder methodAsMapBuilder, MethodComponentContext methodComponentContext, @@ -132,4 +138,15 @@ static MethodComponentContext getEncoderMethodComponent(MethodComponentContext m } return (MethodComponentContext) object; } + + @Override + protected SpaceType getCompatibleSpaceType(SpaceType spaceType) { + // While FAISS doesn't directly support cosine similarity, we can leverage the mathematical + // relationship between cosine similarity and inner product for normalized vectors to add support. + // When ||a|| = ||b|| = 1, cos(θ) = a · b + if (spaceType == SpaceType.COSINESIMIL) { + return SpaceType.INNER_PRODUCT; + } + return spaceType; + } } diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/Faiss.java b/src/main/java/org/opensearch/knn/index/engine/faiss/Faiss.java index a602619a1..d4222dc8d 100644 --- a/src/main/java/org/opensearch/knn/index/engine/faiss/Faiss.java +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/Faiss.java @@ -26,6 +26,7 @@ */ public class Faiss extends NativeLibrary { public static final String FAISS_BINARY_INDEX_DESCRIPTION_PREFIX = "B"; + Map> distanceTransform; Map> scoreTransform; // TODO: Current version is not really current version. Instead, it encodes information in the file name @@ -36,7 +37,10 @@ public class Faiss extends NativeLibrary { // Map that overrides OpenSearch score translation by space type of scores returned by faiss private final static Map> SCORE_TRANSLATIONS = ImmutableMap.of( SpaceType.INNER_PRODUCT, - rawScore -> SpaceType.INNER_PRODUCT.scoreTranslation(-1 * rawScore) + rawScore -> SpaceType.INNER_PRODUCT.scoreTranslation(-1 * rawScore), + // COSINESIMIL expects the raw score in 1 - cosine(x,y) + SpaceType.COSINESIMIL, + rawScore -> SpaceType.COSINESIMIL.scoreTranslation(1 - rawScore) ); // Map that overrides radial search score threshold to faiss required distance, check more details in knn documentation: @@ -45,6 +49,10 @@ public class Faiss extends NativeLibrary { SpaceType, Function>builder().put(SpaceType.INNER_PRODUCT, score -> score > 1 ? 1 - score : 1 / score - 1).build(); + private final static Map> DISTANCE_TRANSLATIONS = ImmutableMap.< + SpaceType, + Function>builder().put(SpaceType.COSINESIMIL, distance -> (2 - distance) / 2).build(); + // Package private so that the method resolving logic can access the methods final static Map METHODS = ImmutableMap.of(METHOD_HNSW, new FaissHNSWMethod(), METHOD_IVF, new FaissIVFMethod()); @@ -53,7 +61,8 @@ public class Faiss extends NativeLibrary { SCORE_TRANSLATIONS, CURRENT_VERSION, KNNConstants.FAISS_EXTENSION, - SCORE_TO_DISTANCE_TRANSFORMATIONS + SCORE_TO_DISTANCE_TRANSFORMATIONS, + DISTANCE_TRANSLATIONS ); private final MethodResolver methodResolver; @@ -71,22 +80,25 @@ private Faiss( Map> scoreTranslation, String currentVersion, String extension, - Map> scoreTransform + Map> scoreTransform, + Map> distanceTransform ) { super(methods, scoreTranslation, currentVersion, extension); this.scoreTransform = scoreTransform; + this.distanceTransform = distanceTransform; this.methodResolver = new FaissMethodResolver(); } @Override public Float distanceToRadialThreshold(Float distance, SpaceType spaceType) { - // Faiss engine uses distance as is and does not need transformation + if (this.distanceTransform.containsKey(spaceType)) { + return this.distanceTransform.get(spaceType).apply(distance); + } return distance; } @Override public Float scoreToRadialThreshold(Float score, SpaceType spaceType) { - // Faiss engine uses distance as is and need transformation if (this.scoreTransform.containsKey(spaceType)) { return this.scoreTransform.get(spaceType).apply(score); } diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissHNSWMethod.java b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissHNSWMethod.java index c153a9328..3386f871c 100644 --- a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissHNSWMethod.java +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissHNSWMethod.java @@ -46,7 +46,8 @@ public class FaissHNSWMethod extends AbstractFaissMethod { SpaceType.UNDEFINED, SpaceType.HAMMING, SpaceType.L2, - SpaceType.INNER_PRODUCT + SpaceType.INNER_PRODUCT, + SpaceType.COSINESIMIL ); private final static MethodComponentContext DEFAULT_ENCODER_CONTEXT = new MethodComponentContext( diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissIVFMethod.java b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissIVFMethod.java index 340c1f4d8..582029392 100644 --- a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissIVFMethod.java +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissIVFMethod.java @@ -49,7 +49,8 @@ public class FaissIVFMethod extends AbstractFaissMethod { SpaceType.UNDEFINED, SpaceType.L2, SpaceType.INNER_PRODUCT, - SpaceType.HAMMING + SpaceType.HAMMING, + SpaceType.COSINESIMIL ); private final static MethodComponentContext DEFAULT_ENCODER_CONTEXT = new MethodComponentContext( diff --git a/src/main/java/org/opensearch/knn/index/mapper/FlatVectorFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/FlatVectorFieldMapper.java index 8da41aa59..f671b07f8 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/FlatVectorFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/FlatVectorFieldMapper.java @@ -109,4 +109,9 @@ protected PerDimensionValidator getPerDimensionValidator() { protected PerDimensionProcessor getPerDimensionProcessor() { return PerDimensionProcessor.NOOP_PROCESSOR; } + + @Override + protected VectorTransformer getVectorTransformer() { + return VectorTransformer.NOOP_VECTOR_TRANSFORMER; + } } diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java index 67f3efa5b..2b94d4e9f 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java @@ -681,6 +681,8 @@ protected void validatePreparse() { */ protected abstract PerDimensionProcessor getPerDimensionProcessor(); + protected abstract VectorTransformer getVectorTransformer(); + protected void parseCreateField(ParseContext context, int dimension, VectorDataType vectorDataType) throws IOException { validatePreparse(); @@ -691,7 +693,8 @@ protected void parseCreateField(ParseContext context, int dimension, VectorDataT } final byte[] array = bytesArrayOptional.get(); getVectorValidator().validateVector(array); - context.doc().addAll(getFieldsForByteVector(array)); + final byte[] transformedArray = getVectorTransformer().transform(array); + context.doc().addAll(getFieldsForByteVector(transformedArray)); } else if (VectorDataType.FLOAT == vectorDataType) { Optional floatsArrayOptional = getFloatsFromContext(context, dimension); @@ -700,7 +703,8 @@ protected void parseCreateField(ParseContext context, int dimension, VectorDataT } final float[] array = floatsArrayOptional.get(); getVectorValidator().validateVector(array); - context.doc().addAll(getFieldsForFloatVector(array)); + final float[] transformedArray = getVectorTransformer().transform(array); + context.doc().addAll(getFieldsForFloatVector(transformedArray)); } else { throw new IllegalArgumentException( String.format(Locale.ROOT, "Cannot parse context for unsupported values provided for field [%s]", VECTOR_DATA_TYPE_FIELD) diff --git a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java index 4ceb9b4b2..83f3ce4c5 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java @@ -42,6 +42,7 @@ public class LuceneFieldMapper extends KNNVectorFieldMapper { private final PerDimensionProcessor perDimensionProcessor; private final PerDimensionValidator perDimensionValidator; private final VectorValidator vectorValidator; + private final VectorTransformer vectorTransformer; static LuceneFieldMapper createFieldMapper( String fullname, @@ -122,6 +123,7 @@ private LuceneFieldMapper( this.perDimensionProcessor = knnLibraryIndexingContext.getPerDimensionProcessor(); this.perDimensionValidator = knnLibraryIndexingContext.getPerDimensionValidator(); this.vectorValidator = knnLibraryIndexingContext.getVectorValidator(); + this.vectorTransformer = knnLibraryIndexingContext.getVectorTransformer(); } @Override @@ -169,6 +171,11 @@ protected PerDimensionProcessor getPerDimensionProcessor() { return perDimensionProcessor; } + @Override + protected VectorTransformer getVectorTransformer() { + return vectorTransformer; + } + @Override void updateEngineStats() { KNNEngine.LUCENE.setInitialized(true); diff --git a/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java index 755439ce6..814bc4f63 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java @@ -39,6 +39,7 @@ public class MethodFieldMapper extends KNNVectorFieldMapper { private final PerDimensionProcessor perDimensionProcessor; private final PerDimensionValidator perDimensionValidator; private final VectorValidator vectorValidator; + private final VectorTransformer vectorTransformer; public static MethodFieldMapper createFieldMapper( String fullname, @@ -180,6 +181,7 @@ private MethodFieldMapper( this.perDimensionProcessor = knnLibraryIndexingContext.getPerDimensionProcessor(); this.perDimensionValidator = knnLibraryIndexingContext.getPerDimensionValidator(); this.vectorValidator = knnLibraryIndexingContext.getVectorValidator(); + this.vectorTransformer = knnLibraryIndexingContext.getVectorTransformer(); } @Override @@ -196,4 +198,9 @@ protected PerDimensionValidator getPerDimensionValidator() { protected PerDimensionProcessor getPerDimensionProcessor() { return perDimensionProcessor; } + + @Override + protected VectorTransformer getVectorTransformer() { + return vectorTransformer; + } } diff --git a/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java index cbc7520cf..2458e3355 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java @@ -41,6 +41,7 @@ public class ModelFieldMapper extends KNNVectorFieldMapper { private PerDimensionProcessor perDimensionProcessor; private PerDimensionValidator perDimensionValidator; private VectorValidator vectorValidator; + private VectorTransformer vectorTransformer; private final String modelId; @@ -192,6 +193,31 @@ protected PerDimensionProcessor getPerDimensionProcessor() { return perDimensionProcessor; } + @Override + protected VectorTransformer getVectorTransformer() { + initVectorTransformer(); + return vectorTransformer; + } + + private void initVectorTransformer() { + if (vectorTransformer != null) { + return; + } + ModelMetadata modelMetadata = getModelMetadata(modelDao, modelId); + + KNNMethodContext knnMethodContext = getKNNMethodContextFromModelMetadata(modelMetadata); + KNNMethodConfigContext knnMethodConfigContext = getKNNMethodConfigContextFromModelMetadata(modelMetadata); + // Need to handle BWC case + if (knnMethodContext == null || knnMethodConfigContext == null) { + vectorTransformer = VectorTransformerFactory.getVectorTransformer(modelMetadata.getKnnEngine(), modelMetadata.getSpaceType()); + return; + } + + KNNLibraryIndexingContext knnLibraryIndexingContext = knnMethodContext.getKnnEngine() + .getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext); + vectorTransformer = knnLibraryIndexingContext.getVectorTransformer(); + } + private void initVectorValidator() { if (vectorValidator != null) { return; diff --git a/src/main/java/org/opensearch/knn/index/mapper/NormalizeVectorTransformer.java b/src/main/java/org/opensearch/knn/index/mapper/NormalizeVectorTransformer.java new file mode 100644 index 000000000..6a9642435 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/mapper/NormalizeVectorTransformer.java @@ -0,0 +1,31 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.mapper; + +import org.apache.lucene.util.VectorUtil; + +/** + * Normalizes vectors using L2 (Euclidean) normalization. This transformation ensures + * that the vector's magnitude becomes 1 while preserving its directional properties. + */ +public class NormalizeVectorTransformer implements VectorTransformer { + + /** + * Transforms the input vector into unit vector by applying L2 normalization. + * + * @param vector The input vector to be normalized. Must not be null. + * @return A new float array containing the L2-normalized version of the input vector. + * Each component is divided by the Euclidean norm of the vector. + * @throws IllegalArgumentException if the input vector is null, empty, or a zero vector + */ + @Override + public float[] transform(float[] vector) { + if (vector == null || vector.length == 0) { + throw new IllegalArgumentException("Vector cannot be null or empty"); + } + return VectorUtil.l2normalize(vector); + } +} diff --git a/src/main/java/org/opensearch/knn/index/mapper/VectorTransformer.java b/src/main/java/org/opensearch/knn/index/mapper/VectorTransformer.java new file mode 100644 index 000000000..d29e4a460 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/mapper/VectorTransformer.java @@ -0,0 +1,47 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.mapper; + +/** + * Defines operations for transforming vectors in the k-NN search context. + * Implementations can modify vectors while preserving their dimensional properties + * for specific use cases such as normalization, scaling, or other transformations. + * + *

This interface provides default implementations that pass through the original + * vector without modification. Implementing classes should override these methods + * to provide specific transformation logic. + */ +public interface VectorTransformer { + + /** + * Transforms a float vector into a new vector of the same type. + * The default implementation returns the input vector unchanged. + * + * @param vector The input vector to transform + * @return The transformed vector + */ + default float[] transform(float[] vector) { + return vector; + } + + /** + * Transforms a byte vector into a new vector of the same type. + * The default implementation returns the input vector unchanged. + * + * @param vector The input vector to transform + * @return The transformed vector + */ + default byte[] transform(byte[] vector) { + return vector; + } + + /** + * A no-operation transformer that returns vectors unchanged. + * This constant can be used when no transformation is needed. + */ + VectorTransformer NOOP_VECTOR_TRANSFORMER = new VectorTransformer() { + }; +} diff --git a/src/main/java/org/opensearch/knn/index/mapper/VectorTransformerFactory.java b/src/main/java/org/opensearch/knn/index/mapper/VectorTransformerFactory.java new file mode 100644 index 000000000..8726c48b0 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/mapper/VectorTransformerFactory.java @@ -0,0 +1,55 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.mapper; + +import lombok.AccessLevel; +import lombok.NoArgsConstructor; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.engine.KNNMethodContext; + +/** + * Factory class responsible for creating appropriate vector transformers based on the KNN method context. + * This factory determines whether vectors need transformation based on the engine type and space type. + */ +@NoArgsConstructor(access = AccessLevel.PRIVATE) +public final class VectorTransformerFactory { + + /** + * Returns a vector transformer based on the provided KNN method context. + * For FAISS engine with cosine similarity space type, returns a NormalizeVectorTransformer + * since FAISS doesn't natively support cosine space type. For all other cases, + * returns a no-operation transformer. + * + * @param context The KNN method context containing engine and space type information + * @return VectorTransformer An appropriate vector transformer instance + * @throws IllegalArgumentException if the context parameter is null + */ + public static VectorTransformer getVectorTransformer(final KNNMethodContext context) { + if (context == null) { + throw new IllegalArgumentException("KNNMethod context cannot be null"); + } + return getVectorTransformer(context.getKnnEngine(), context.getSpaceType()); + } + + /** + * Returns a vector transformer based on the provided KNN engine and space type. + * For FAISS engine with cosine similarity space type, returns a NormalizeVectorTransformer + * since FAISS doesn't natively support cosine space type. For all other cases, + * returns a no-operation transformer. + * + * @param knnEngine The KNN engine type + * @param spaceType The space type + * @return VectorTransformer An appropriate vector transformer instance + */ + public static VectorTransformer getVectorTransformer(final KNNEngine knnEngine, final SpaceType spaceType) { + return shouldNormalizeVector(knnEngine, spaceType) ? new NormalizeVectorTransformer() : VectorTransformer.NOOP_VECTOR_TRANSFORMER; + } + + private static boolean shouldNormalizeVector(final KNNEngine knnEngine, final SpaceType spaceType) { + return knnEngine == KNNEngine.FAISS && spaceType == SpaceType.COSINESIMIL; + } +} diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index ee18394f6..c2998df6c 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -12,6 +12,7 @@ import org.apache.commons.lang.StringUtils; import org.apache.lucene.search.MatchNoDocsQuery; import org.apache.lucene.search.Query; +import org.apache.lucene.util.VectorUtil; import org.opensearch.common.ValidationException; import org.opensearch.core.ParseField; import org.opensearch.core.common.Strings; @@ -541,7 +542,7 @@ protected Query doToQuery(QueryShardContext context) { .knnEngine(knnEngine) .indexName(indexName) .fieldName(this.fieldName) - .vector(getVectorForCreatingQueryRequest(vectorDataType, knnEngine)) + .vector(getVectorForCreatingQueryRequest(vectorDataType, knnEngine, spaceType)) .byteVector(getVectorForCreatingQueryRequest(vectorDataType, knnEngine, byteVector)) .vectorDataType(vectorDataType) .k(this.k) @@ -558,8 +559,8 @@ protected Query doToQuery(QueryShardContext context) { .knnEngine(knnEngine) .indexName(indexName) .fieldName(this.fieldName) - .vector(VectorDataType.FLOAT == vectorDataType ? this.vector : null) - .byteVector(VectorDataType.BYTE == vectorDataType ? byteVector : null) + .vector(getVectorForCreatingQueryRequest(vectorDataType, knnEngine, spaceType)) + .byteVector(getVectorForCreatingQueryRequest(vectorDataType, knnEngine, byteVector)) .vectorDataType(vectorDataType) .radius(radius) .methodParameters(this.methodParameters) @@ -611,7 +612,13 @@ private void updateQueryStats(VectorQueryType vectorQueryType) { } } - private float[] getVectorForCreatingQueryRequest(VectorDataType vectorDataType, KNNEngine knnEngine) { + private float[] getVectorForCreatingQueryRequest(VectorDataType vectorDataType, KNNEngine knnEngine, SpaceType spaceType) { + + // Cosine similarity is supported as Inner product by FAISS by normalizing input vector, hence, we have to normalize + // query vector before applying search + if (knnEngine == KNNEngine.FAISS && spaceType == SpaceType.COSINESIMIL && VectorDataType.FLOAT == vectorDataType) { + return VectorUtil.l2normalize(this.vector); + } if ((VectorDataType.FLOAT == vectorDataType) || (VectorDataType.BYTE == vectorDataType && KNNEngine.FAISS == knnEngine)) { return this.vector; } diff --git a/src/test/java/org/opensearch/knn/index/FaissIT.java b/src/test/java/org/opensearch/knn/index/FaissIT.java index c2e75ecb2..e7db4868b 100644 --- a/src/test/java/org/opensearch/knn/index/FaissIT.java +++ b/src/test/java/org/opensearch/knn/index/FaissIT.java @@ -17,6 +17,7 @@ import lombok.SneakyThrows; import org.apache.hc.core5.http.ParseException; import org.apache.hc.core5.http.io.entity.EntityUtils; +import org.apache.lucene.util.VectorUtil; import org.junit.BeforeClass; import org.opensearch.client.Response; import org.opensearch.common.settings.Settings; @@ -43,6 +44,7 @@ import java.util.Map; import java.util.Random; import java.util.TreeMap; +import java.util.function.BiFunction; import java.util.stream.Collectors; import static org.opensearch.knn.common.KNNConstants.DIMENSION; @@ -93,6 +95,7 @@ public class FaissIT extends KNNRestTestCase { private static final String FILED_TYPE_INTEGER = "integer"; private static final String NON_EXISTENT_INTEGER_FIELD_NAME = "nonexistent_int_field"; public static final int NEVER_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD = -1; + public static final int ALWAYS_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD = 0; static TestUtils.TestData testData; @@ -373,16 +376,7 @@ public void testEndToEnd_whenDoRadiusSearch_whenDistanceThreshold_whenMethodIsHN deleteModel(modelId); // Search every 5 seconds 14 times to confirm graph gets evicted - int intervals = 14; - for (int i = 0; i < intervals; i++) { - if (getTotalGraphsInCache() == 0) { - return; - } - - Thread.sleep(5 * 1000); - } - - fail("Graphs are not getting evicted"); + validateGraphEviction(); } @SneakyThrows @@ -555,17 +549,7 @@ public void testEndToEnd_whenMethodIsHNSWPQ_thenSucceed() { deleteKNNIndex(indexName); deleteModel(modelId); - // Search every 5 seconds 14 times to confirm graph gets evicted - int intervals = 14; - for (int i = 0; i < intervals; i++) { - if (getTotalGraphsInCache() == 0) { - return; - } - - Thread.sleep(5 * 1000); - } - - fail("Graphs are not getting evicted"); + validateGraphEviction(); } @SneakyThrows @@ -1239,17 +1223,7 @@ public void testEndToEnd_whenMethodIsHNSWPQAndHyperParametersNotSet_thenSucceed( deleteKNNIndex(indexName); deleteModel(modelId); - // Search every 5 seconds 14 times to confirm graph gets evicted - int intervals = 14; - for (int i = 0; i < intervals; i++) { - if (getTotalGraphsInCache() == 0) { - return; - } - - Thread.sleep(5 * 1000); - } - - fail("Graphs are not getting evicted"); + validateGraphEviction(); } /** @@ -2050,6 +2024,20 @@ public void testQueryWithFilter_whenNonExistingFieldUsedInFilter_thenSuccessful( assertEquals(1, resultsQuery2.size()); } + public void testCosineSimilarity_withHNSW_withExactSearch_thenSucceed() throws Exception { + testCosineSimilarityForApproximateSearch(NEVER_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD); + } + + public void testCosineSimilarity_withHNSW_withApproximate_thenSucceed() throws Exception { + testCosineSimilarityForApproximateSearch(ALWAYS_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD); + validateGraphEviction(); + } + + public void testCosineSimilarity_withHNSW_withRadialSearch_thenSucceed() throws Exception { + testCosineSimilarityForRadialSearch(ALWAYS_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD); + validateGraphEviction(); + } + protected void setupKNNIndexForFilterQuery() throws Exception { setupKNNIndexForFilterQuery(getKNNDefaultIndexSettings()); } @@ -2161,7 +2149,7 @@ private List> validateRadiusSearchResults( if (filterQuery != null) { queryBuilder.field("filter", filterQuery); } - if (methodParameters != null) { + if (methodParameters != null && methodParameters.size() > 0) { queryBuilder.startObject(METHOD_PARAMETER); for (Map.Entry entry : methodParameters.entrySet()) { queryBuilder.field(entry.getKey(), entry.getValue()); @@ -2182,6 +2170,8 @@ private List> validateRadiusSearchResults( assertTrue(KNNScoringUtil.l2Squared(queryVector, vector) <= distance); } else if (spaceType == SpaceType.INNER_PRODUCT) { assertTrue(KNNScoringUtil.innerProduct(queryVector, vector) >= distance); + } else if (spaceType == SpaceType.COSINESIMIL) { + assertTrue(KNNScoringUtil.cosinesimil(queryVector, vector) >= distance); } else { throw new IllegalArgumentException("Invalid space type"); } @@ -2190,4 +2180,97 @@ private List> validateRadiusSearchResults( } return queryResults; } + + private void testCosineSimilarityForApproximateSearch(int approximateThreshold) throws Exception { + String indexName = randomLowerCaseString(); + String fieldName = randomLowerCaseString(); + SpaceType spaceType = SpaceType.COSINESIMIL; + indexTestData(approximateThreshold, indexName, spaceType, fieldName); + + // search index + validateNearestNeighborsSearch(indexName, fieldName, spaceType, 10, VectorUtil::cosine); + + // Delete index + deleteKNNIndex(indexName); + } + + private void testCosineSimilarityForRadialSearch(int approximateThreshold) throws Exception { + String indexName = randomLowerCaseString(); + String fieldName = randomLowerCaseString(); + SpaceType spaceType = SpaceType.COSINESIMIL; + indexTestData(approximateThreshold, indexName, spaceType, fieldName); + + // search index + validateRadiusSearchResults(indexName, fieldName, testData.queries, 0.92f, null, spaceType, null, null); + + // Delete index + deleteKNNIndex(indexName); + } + + private void indexTestData(int approximateThreshold, String indexName, SpaceType spaceType, String fieldName) throws Exception { + Integer dimension = testData.indexData.vectors[0].length; + // Create an index + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(fieldName) + .field("type", "knn_vector") + .field("dimension", dimension) + .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .startObject(KNNConstants.KNN_METHOD) + .field(KNNConstants.NAME, KNNConstants.METHOD_HNSW) + .field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName()) + .endObject() + .endObject() + .endObject() + .endObject(); + + String mapping = builder.toString(); + createKnnIndex(indexName, buildKNNIndexSettings(approximateThreshold), mapping); + + // Index the test data + for (int i = 0; i < testData.indexData.docs.length; i++) { + addKnnDoc( + indexName, + Integer.toString(testData.indexData.docs[i]), + fieldName, + Floats.asList(testData.indexData.vectors[i]).toArray() + ); + } + + refreshAllIndices(); + // Assert we have the right number of documents in the index + assertEquals(testData.indexData.docs.length, getDocCount(indexName)); + } + + @SneakyThrows + private void validateNearestNeighborsSearch( + final String indexName, + final String fieldName, + final SpaceType spaceType, + final int k, + final BiFunction scoringFunction + ) { + for (int i = 0; i < testData.queries.length; i++) { + final Response response = searchKNNIndex( + indexName, + KNNQueryBuilder.builder().fieldName(fieldName).vector(testData.queries[i]).k(k).build(), + k + ); + final String responseBody = EntityUtils.toString(response.getEntity()); + final List knnResults = parseSearchResponse(responseBody, fieldName); + assertEquals(k, knnResults.size()); + + final List actualScores = parseSearchResponseScore(responseBody, fieldName); + for (int j = 0; j < k; j++) { + final float[] primitiveArray = knnResults.get(j).getVector(); + assertEquals( + KNNEngine.FAISS.score(scoringFunction.apply(testData.queries[i], primitiveArray), spaceType), + actualScores.get(j), + 0.0001 + ); + } + } + } + } diff --git a/src/test/java/org/opensearch/knn/index/mapper/NormalizeVectorTransformerTests.java b/src/test/java/org/opensearch/knn/index/mapper/NormalizeVectorTransformerTests.java new file mode 100644 index 000000000..923e59a26 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/mapper/NormalizeVectorTransformerTests.java @@ -0,0 +1,42 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.mapper; + +import org.opensearch.knn.KNNTestCase; + +public class NormalizeVectorTransformerTests extends KNNTestCase { + private final NormalizeVectorTransformer transformer = new NormalizeVectorTransformer(); + private static final float DELTA = 0.001f; // Delta for floating point comparisons + + public void testNormalizeTransformer_withNullVector_thenThrowsException() { + assertThrows(IllegalArgumentException.class, () -> transformer.transform((float[]) null)); + assertThrows(IllegalArgumentException.class, () -> transformer.transform((byte[]) null)); + } + + public void testNormalizeTransformer_withEmptyVector_thenThrowsException() { + assertThrows(IllegalArgumentException.class, () -> transformer.transform(new float[0])); + } + + public void testNormalizeTransformer_withValidVector_thenSuccess() { + float[] input = { -3.0f, 4.0f }; + float[] normalized = transformer.transform(input); + + assertEquals(-0.6f, normalized[0], DELTA); + assertEquals(0.8f, normalized[1], DELTA); + + // Verify the magnitude is 1 + assertEquals(1.0f, calculateMagnitude(normalized), DELTA); + } + + private float calculateMagnitude(float[] vector) { + float magnitude = 0.0f; + for (float value : vector) { + magnitude += value * value; + } + return (float) Math.sqrt(magnitude); + } + +} diff --git a/src/test/java/org/opensearch/knn/index/mapper/VectorTransformerFactoryTests.java b/src/test/java/org/opensearch/knn/index/mapper/VectorTransformerFactoryTests.java new file mode 100644 index 000000000..6148f83d6 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/mapper/VectorTransformerFactoryTests.java @@ -0,0 +1,72 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.mapper; + +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.engine.KNNMethodContext; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class VectorTransformerFactoryTests extends KNNTestCase { + public void testAllSpaceTypes_withFaiss() { + for (SpaceType spaceType : SpaceType.values()) { + VectorTransformer transformer = VectorTransformerFactory.getVectorTransformer(KNNEngine.FAISS, spaceType); + validateTransformer(spaceType, KNNEngine.FAISS, transformer); + } + } + + public void testAllEngines_withCosine() { + // Test all engines with COSINESIMIL space type + for (KNNEngine engine : KNNEngine.values()) { + VectorTransformer transformer = VectorTransformerFactory.getVectorTransformer(engine, SpaceType.COSINESIMIL); + validateTransformer(SpaceType.COSINESIMIL, engine, transformer); + } + } + + public void testGetVectorTransformer_withNullContext() { + // Test case for null context + assertThrows(IllegalArgumentException.class, () -> VectorTransformerFactory.getVectorTransformer(null)); + } + + public void testAllSpaceTypes_usingContext_withFaiss() { + for (SpaceType spaceType : SpaceType.values()) { + KNNMethodContext context = mock(KNNMethodContext.class); + when(context.getKnnEngine()).thenReturn(KNNEngine.FAISS); + when(context.getSpaceType()).thenReturn(spaceType); + VectorTransformer transformer = VectorTransformerFactory.getVectorTransformer(context); + validateTransformer(spaceType, KNNEngine.FAISS, transformer); + } + } + + public void testAllEngines_usingContext_withCosine() { + // Test all engines with COSINESIMIL space type + for (KNNEngine engine : KNNEngine.values()) { + KNNMethodContext context = mock(KNNMethodContext.class); + when(context.getKnnEngine()).thenReturn(engine); + when(context.getSpaceType()).thenReturn(SpaceType.COSINESIMIL); + VectorTransformer transformer = VectorTransformerFactory.getVectorTransformer(context); + validateTransformer(SpaceType.COSINESIMIL, engine, transformer); + } + } + + private static void validateTransformer(SpaceType spaceType, KNNEngine engine, VectorTransformer transformer) { + if (spaceType == SpaceType.COSINESIMIL && engine == KNNEngine.FAISS) { + assertTrue( + "Should return NormalizeVectorTransformer for FAISS with " + spaceType, + transformer instanceof NormalizeVectorTransformer + ); + } else { + assertSame( + "Should return NOOP transformer for " + engine + " with COSINESIMIL", + VectorTransformer.NOOP_VECTOR_TRANSFORMER, + transformer + ); + } + } +} diff --git a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java index 896674a18..3fd6aedc1 100644 --- a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java @@ -129,6 +129,8 @@ public class KNNRestTestCase extends ODFERestTestCase { protected static final int DELAY_MILLI_SEC = 1000; protected static final int NUM_OF_ATTEMPTS = 30; private static final String SYSTEM_INDEX_PREFIX = ".opendistro"; + public static final int MIN_CODE_UNITS = 4; + public static final int MAX_CODE_UNITS = 10; @AfterClass public static void dumpCoverage() throws IOException, MalformedObjectNameException { @@ -1939,4 +1941,22 @@ protected boolean isApproximateThresholdSupported(final Optional bwcVers final Version version = Version.fromString(versionString); return version.onOrAfter(Version.V_2_18_0); } + + /** + * Generates a random lowercase string with length between MIN_CODE_UNITS and MAX_CODE_UNITS. + * This method is used for test fixtures to generate random string values that can be used + * as identifiers, names, or other string-based test data. + * Example usage: + *

+     * String randomId = randomLowerCaseString();
+     * String indexName = randomLowerCaseString();
+     * String fieldName = randomLowerCaseString();
+     * 
+ * + * @return A random lowercase string of variable length between MIN_CODE_UNITS and MAX_CODE_UNITS + * @see #randomAlphaOfLengthBetween(int, int) + */ + protected static String randomLowerCaseString() { + return randomAlphaOfLengthBetween(MIN_CODE_UNITS, MAX_CODE_UNITS).toLowerCase(Locale.ROOT); + } }