diff --git a/CHANGELOG.md b/CHANGELOG.md index 03a757acd..36d667448 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Add Support for Multi Values in innerHit for Nested k-NN Fields in Lucene and FAISS (#2283)[https://github.com/opensearch-project/k-NN/pull/2283] - Add binary index support for Lucene engine. (#2292)[https://github.com/opensearch-project/k-NN/pull/2292] - Add expand_nested_docs Parameter support to NMSLIB engine (#2331)[https://github.com/opensearch-project/k-NN/pull/2331] +- Add cosine similarity support for faiss engine (#2376)[https://github.com/opensearch-project/k-NN/pull/2376] ### Enhancements - Introduced a writing layer in native engines where relies on the writing interface to process IO. (#2241)[https://github.com/opensearch-project/k-NN/pull/2241] - Allow method parameter override for training based indices (#2290) https://github.com/opensearch-project/k-NN/pull/2290] diff --git a/src/main/java/org/opensearch/knn/index/engine/AbstractKNNMethod.java b/src/main/java/org/opensearch/knn/index/engine/AbstractKNNMethod.java index f53655136..bfd908a09 100644 --- a/src/main/java/org/opensearch/knn/index/engine/AbstractKNNMethod.java +++ b/src/main/java/org/opensearch/knn/index/engine/AbstractKNNMethod.java @@ -13,6 +13,8 @@ import org.opensearch.knn.index.mapper.PerDimensionProcessor; import org.opensearch.knn.index.mapper.PerDimensionValidator; import org.opensearch.knn.index.mapper.SpaceVectorValidator; +import org.opensearch.knn.index.mapper.VectorTransformer; +import org.opensearch.knn.index.mapper.VectorTransformerFactory; import org.opensearch.knn.index.mapper.VectorValidator; import java.util.ArrayList; @@ -106,6 +108,10 @@ protected PerDimensionProcessor doGetPerDimensionProcessor( return PerDimensionProcessor.NOOP_PROCESSOR; } + protected VectorTransformer getVectorTransformer(SpaceType spaceType) { + return 
VectorTransformerFactory.NOOP_VECTOR_TRANSFORMER; + } + @Override public KNNLibraryIndexingContext getKNNLibraryIndexingContext( KNNMethodContext knnMethodContext, @@ -116,7 +122,7 @@ public KNNLibraryIndexingContext getKNNLibraryIndexingContext( knnMethodConfigContext ); Map parameterMap = knnLibraryIndexingContext.getLibraryParameters(); - parameterMap.put(KNNConstants.SPACE_TYPE, knnMethodContext.getSpaceType().getValue()); + parameterMap.put(KNNConstants.SPACE_TYPE, convertUserToMethodSpaceType(knnMethodContext.getSpaceType()).getValue()); parameterMap.put(KNNConstants.VECTOR_DATA_TYPE_FIELD, knnMethodConfigContext.getVectorDataType().getValue()); return KNNLibraryIndexingContextImpl.builder() .quantizationConfig(knnLibraryIndexingContext.getQuantizationConfig()) @@ -124,6 +130,7 @@ public KNNLibraryIndexingContext getKNNLibraryIndexingContext( .vectorValidator(doGetVectorValidator(knnMethodContext, knnMethodConfigContext)) .perDimensionValidator(doGetPerDimensionValidator(knnMethodContext, knnMethodConfigContext)) .perDimensionProcessor(doGetPerDimensionProcessor(knnMethodContext, knnMethodConfigContext)) + .vectorTransformer(getVectorTransformer(knnMethodContext.getSpaceType())) .build(); } @@ -131,4 +138,21 @@ public KNNLibraryIndexingContext getKNNLibraryIndexingContext( public KNNLibrarySearchContext getKNNLibrarySearchContext() { return knnLibrarySearchContext; } + + /** + * Converts a user-defined space type to a method space type that is supported by the library. + * A subclass can override this method and return the appropriate space type that + * is supported by the library. This is required because some libraries may not + * support all the space types supported by OpenSearch; however, an unsupported space type can often be emulated with a compatible space type that the library does support. + * For example, faiss does not support cosine similarity. However, we can use inner product space type for cosine similarity after normalization. 
+ * In this case, we can return the inner product space type for cosine similarity. + * + * @param spaceType The space type to check for compatibility + * @return The compatible space type for the given input, returns the same + * space type if it's already compatible + * @see SpaceType + */ + protected SpaceType convertUserToMethodSpaceType(SpaceType spaceType) { + return spaceType; + } } diff --git a/src/main/java/org/opensearch/knn/index/engine/KNNLibraryIndexingContext.java b/src/main/java/org/opensearch/knn/index/engine/KNNLibraryIndexingContext.java index 9208661af..1ff677cd6 100644 --- a/src/main/java/org/opensearch/knn/index/engine/KNNLibraryIndexingContext.java +++ b/src/main/java/org/opensearch/knn/index/engine/KNNLibraryIndexingContext.java @@ -8,6 +8,7 @@ import org.opensearch.knn.index.engine.qframe.QuantizationConfig; import org.opensearch.knn.index.mapper.PerDimensionProcessor; import org.opensearch.knn.index.mapper.PerDimensionValidator; +import org.opensearch.knn.index.mapper.VectorTransformer; import org.opensearch.knn.index.mapper.VectorValidator; import java.util.Map; @@ -47,4 +48,12 @@ public interface KNNLibraryIndexingContext { * @return Get the per dimension processor */ PerDimensionProcessor getPerDimensionProcessor(); + + /** + * Get the vector transformer that will be used to transform the vector before indexing. + * This will be applied at vector level once entire vector is parsed and validated. 
+ * + * @return VectorTransformer + */ + VectorTransformer getVectorTransformer(); } diff --git a/src/main/java/org/opensearch/knn/index/engine/KNNLibraryIndexingContextImpl.java b/src/main/java/org/opensearch/knn/index/engine/KNNLibraryIndexingContextImpl.java index f5329fc31..9822033b7 100644 --- a/src/main/java/org/opensearch/knn/index/engine/KNNLibraryIndexingContextImpl.java +++ b/src/main/java/org/opensearch/knn/index/engine/KNNLibraryIndexingContextImpl.java @@ -9,6 +9,7 @@ import org.opensearch.knn.index.engine.qframe.QuantizationConfig; import org.opensearch.knn.index.mapper.PerDimensionProcessor; import org.opensearch.knn.index.mapper.PerDimensionValidator; +import org.opensearch.knn.index.mapper.VectorTransformer; import org.opensearch.knn.index.mapper.VectorValidator; import java.util.Collections; @@ -23,6 +24,7 @@ public class KNNLibraryIndexingContextImpl implements KNNLibraryIndexingContext private VectorValidator vectorValidator; private PerDimensionValidator perDimensionValidator; private PerDimensionProcessor perDimensionProcessor; + private VectorTransformer vectorTransformer; @Builder.Default private Map parameters = Collections.emptyMap(); @Builder.Default @@ -43,6 +45,11 @@ public VectorValidator getVectorValidator() { return vectorValidator; } + @Override + public VectorTransformer getVectorTransformer() { + return vectorTransformer; + } + @Override public PerDimensionValidator getPerDimensionValidator() { return perDimensionValidator; diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissMethod.java b/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissMethod.java index 7ae403445..5e7b72b69 100644 --- a/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissMethod.java +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissMethod.java @@ -8,15 +8,11 @@ import org.apache.commons.lang.StringUtils; import org.opensearch.knn.index.SpaceType; import 
org.opensearch.knn.index.VectorDataType; -import org.opensearch.knn.index.engine.AbstractKNNMethod; -import org.opensearch.knn.index.engine.KNNLibraryIndexingContext; -import org.opensearch.knn.index.engine.KNNLibrarySearchContext; -import org.opensearch.knn.index.engine.KNNMethodConfigContext; -import org.opensearch.knn.index.engine.KNNMethodContext; -import org.opensearch.knn.index.engine.MethodComponent; -import org.opensearch.knn.index.engine.MethodComponentContext; +import org.opensearch.knn.index.engine.*; import org.opensearch.knn.index.mapper.PerDimensionProcessor; import org.opensearch.knn.index.mapper.PerDimensionValidator; +import org.opensearch.knn.index.mapper.VectorTransformer; +import org.opensearch.knn.index.mapper.VectorTransformerFactory; import java.util.Objects; import java.util.Set; @@ -132,4 +128,20 @@ static MethodComponentContext getEncoderMethodComponent(MethodComponentContext m } return (MethodComponentContext) object; } + + @Override + protected SpaceType convertUserToMethodSpaceType(SpaceType spaceType) { + // While FAISS doesn't directly support cosine similarity, we can leverage the mathematical + // relationship between cosine similarity and inner product for normalized vectors to add support. 
+ // When ||a|| = ||b|| = 1, cos(θ) = a · b + if (spaceType == SpaceType.COSINESIMIL) { + return SpaceType.INNER_PRODUCT; + } + return super.convertUserToMethodSpaceType(spaceType); + } + + @Override + protected VectorTransformer getVectorTransformer(SpaceType spaceType) { + return VectorTransformerFactory.getVectorTransformer(KNNEngine.FAISS, spaceType); + } } diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/Faiss.java b/src/main/java/org/opensearch/knn/index/engine/faiss/Faiss.java index a602619a1..5a0258279 100644 --- a/src/main/java/org/opensearch/knn/index/engine/faiss/Faiss.java +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/Faiss.java @@ -26,6 +26,7 @@ */ public class Faiss extends NativeLibrary { public static final String FAISS_BINARY_INDEX_DESCRIPTION_PREFIX = "B"; + Map> distanceTransform; Map> scoreTransform; // TODO: Current version is not really current version. Instead, it encodes information in the file name @@ -36,14 +37,24 @@ public class Faiss extends NativeLibrary { // Map that overrides OpenSearch score translation by space type of scores returned by faiss private final static Map> SCORE_TRANSLATIONS = ImmutableMap.of( SpaceType.INNER_PRODUCT, - rawScore -> SpaceType.INNER_PRODUCT.scoreTranslation(-1 * rawScore) + rawScore -> SpaceType.INNER_PRODUCT.scoreTranslation(-1 * rawScore), + // COSINESIMIL expects the raw score in 1 - cosine(x,y) + SpaceType.COSINESIMIL, + rawScore -> SpaceType.COSINESIMIL.scoreTranslation(1 - rawScore) ); // Map that overrides radial search score threshold to faiss required distance, check more details in knn documentation: // https://opensearch.org/docs/latest/search-plugins/knn/approximate-knn/#spaces private final static Map> SCORE_TO_DISTANCE_TRANSFORMATIONS = ImmutableMap.< SpaceType, - Function>builder().put(SpaceType.INNER_PRODUCT, score -> score > 1 ? 1 - score : 1 / score - 1).build(); + Function>builder() + .put(SpaceType.INNER_PRODUCT, score -> score > 1 ? 
1 - score : (1 / score) - 1) + .put(SpaceType.COSINESIMIL, score -> 2 * score - 1) + .build(); + + private final static Map> DISTANCE_TRANSLATIONS = ImmutableMap.< + SpaceType, + Function>builder().put(SpaceType.COSINESIMIL, distance -> 1 - distance).build(); // Package private so that the method resolving logic can access the methods final static Map METHODS = ImmutableMap.of(METHOD_HNSW, new FaissHNSWMethod(), METHOD_IVF, new FaissIVFMethod()); @@ -53,7 +64,8 @@ public class Faiss extends NativeLibrary { SCORE_TRANSLATIONS, CURRENT_VERSION, KNNConstants.FAISS_EXTENSION, - SCORE_TO_DISTANCE_TRANSFORMATIONS + SCORE_TO_DISTANCE_TRANSFORMATIONS, + DISTANCE_TRANSLATIONS ); private final MethodResolver methodResolver; @@ -71,16 +83,20 @@ private Faiss( Map> scoreTranslation, String currentVersion, String extension, - Map> scoreTransform + Map> scoreTransform, + Map> distanceTransform ) { super(methods, scoreTranslation, currentVersion, extension); this.scoreTransform = scoreTransform; + this.distanceTransform = distanceTransform; this.methodResolver = new FaissMethodResolver(); } @Override public Float distanceToRadialThreshold(Float distance, SpaceType spaceType) { - // Faiss engine uses distance as is and does not need transformation + if (this.distanceTransform.containsKey(spaceType)) { + return this.distanceTransform.get(spaceType).apply(distance); + } return distance; } diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissHNSWMethod.java b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissHNSWMethod.java index c153a9328..3386f871c 100644 --- a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissHNSWMethod.java +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissHNSWMethod.java @@ -46,7 +46,8 @@ public class FaissHNSWMethod extends AbstractFaissMethod { SpaceType.UNDEFINED, SpaceType.HAMMING, SpaceType.L2, - SpaceType.INNER_PRODUCT + SpaceType.INNER_PRODUCT, + SpaceType.COSINESIMIL ); private final static 
MethodComponentContext DEFAULT_ENCODER_CONTEXT = new MethodComponentContext( diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissIVFMethod.java b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissIVFMethod.java index 340c1f4d8..582029392 100644 --- a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissIVFMethod.java +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissIVFMethod.java @@ -49,7 +49,8 @@ public class FaissIVFMethod extends AbstractFaissMethod { SpaceType.UNDEFINED, SpaceType.L2, SpaceType.INNER_PRODUCT, - SpaceType.HAMMING + SpaceType.HAMMING, + SpaceType.COSINESIMIL ); private final static MethodComponentContext DEFAULT_ENCODER_CONTEXT = new MethodComponentContext( diff --git a/src/main/java/org/opensearch/knn/index/mapper/FlatVectorFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/FlatVectorFieldMapper.java index 8da41aa59..9f1ebcf01 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/FlatVectorFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/FlatVectorFieldMapper.java @@ -109,4 +109,5 @@ protected PerDimensionValidator getPerDimensionValidator() { protected PerDimensionProcessor getPerDimensionProcessor() { return PerDimensionProcessor.NOOP_PROCESSOR; } + } diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java index c17abc2e1..99c6ebe2a 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java @@ -675,7 +675,7 @@ protected void validatePreparse() { protected abstract VectorValidator getVectorValidator(); /** - * Getter for per dimension validator during vector parsing + * Getter for per dimension validator during vector parsing, and before any transformation * * @return PerDimensionValidator */ @@ -688,6 +688,23 @@ protected void validatePreparse() { */ protected 
abstract PerDimensionProcessor getPerDimensionProcessor(); + /** + * Retrieves the vector transformer for the KNN vector field. + * This method provides access to the vector transformer instance that will be used + * for processing vectors in the KNN field. The transformer is responsible for any + * necessary vector transformations before indexing or searching. + * This default implementation returns the no-op transformer defined in VectorTransformerFactory; + * subclasses may override it to supply a different transformer. The returned transformer is typically + * stateless and thread-safe. + * + * @return VectorTransformer An instance of VectorTransformer that will be used + * for vector transformations in this field + * + */ + protected VectorTransformer getVectorTransformer() { + return VectorTransformerFactory.NOOP_VECTOR_TRANSFORMER; + } + protected void parseCreateField(ParseContext context, int dimension, VectorDataType vectorDataType) throws IOException { validatePreparse(); @@ -698,6 +715,7 @@ protected void parseCreateField(ParseContext context, int dimension, VectorDataT } final byte[] array = bytesArrayOptional.get(); getVectorValidator().validateVector(array); + getVectorTransformer().transform(array); context.doc().addAll(getFieldsForByteVector(array)); } else if (VectorDataType.FLOAT == vectorDataType) { Optional floatsArrayOptional = getFloatsFromContext(context, dimension); @@ -707,6 +725,7 @@ protected void parseCreateField(ParseContext context, int dimension, VectorDataT } final float[] array = floatsArrayOptional.get(); getVectorValidator().validateVector(array); + getVectorTransformer().transform(array); context.doc().addAll(getFieldsForFloatVector(array)); } else { throw new IllegalArgumentException( diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldType.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldType.java index d12247ad7..461c6f7c8 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldType.java +++ 
b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldType.java @@ -20,7 +20,10 @@ import org.opensearch.index.query.QueryShardException; import org.opensearch.knn.index.KNNVectorIndexFieldData; import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.query.rescore.RescoreContext; +import org.opensearch.knn.indices.ModelDao; +import org.opensearch.knn.indices.ModelMetadata; import org.opensearch.search.aggregations.support.CoreValuesSourceType; import org.opensearch.search.lookup.SearchLookup; @@ -28,6 +31,7 @@ import java.util.Collections; import java.util.Locale; import java.util.Map; +import java.util.Optional; import java.util.function.Supplier; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.deserializeStoredVector; @@ -115,4 +119,38 @@ public RescoreContext resolveRescoreContext(RescoreContext userProvidedContext) Mode mode = knnMappingConfig.getMode(); return compressionLevel.getDefaultRescoreContext(mode, dimension); } + + /** + * Transforms a query vector based on the field's configuration. The transformation is performed + * in-place on the input vector according to either the KNN method context or the model ID. + * + * @param vector The float array to be transformed in-place. Must not be null. + * @throws IllegalStateException if neither KNN method context nor Model ID is configured + * + * The transformation process follows this order: + * 1. If vector is not FLOAT type, no transformation is performed + * 2. Attempts to use KNN method context if present + * 3. Falls back to model ID if KNN method context is not available + * 4. 
Throws exception if neither configuration is present + */ + public void transformQueryVector(float[] vector) { + if (VectorDataType.FLOAT != vectorDataType) { + return; + } + final Optional knnMethodContext = knnMappingConfig.getKnnMethodContext(); + if (knnMethodContext.isPresent()) { + KNNMethodContext context = knnMethodContext.get(); + VectorTransformerFactory.getVectorTransformer(context.getKnnEngine(), context.getSpaceType()).transform(vector); + return; + } + final Optional modelId = knnMappingConfig.getModelId(); + if (modelId.isPresent()) { + ModelDao modelDao = ModelDao.OpenSearchKNNModelDao.getInstance(); + final ModelMetadata metadata = modelDao.getMetadata(modelId.get()); + VectorTransformerFactory.getVectorTransformer(metadata.getKnnEngine(), metadata.getSpaceType()).transform(vector); + return; + } + throw new IllegalStateException("Either KNN method context or Model Id should be configured"); + + } } diff --git a/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java index 755439ce6..814bc4f63 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java @@ -39,6 +39,7 @@ public class MethodFieldMapper extends KNNVectorFieldMapper { private final PerDimensionProcessor perDimensionProcessor; private final PerDimensionValidator perDimensionValidator; private final VectorValidator vectorValidator; + private final VectorTransformer vectorTransformer; public static MethodFieldMapper createFieldMapper( String fullname, @@ -180,6 +181,7 @@ private MethodFieldMapper( this.perDimensionProcessor = knnLibraryIndexingContext.getPerDimensionProcessor(); this.perDimensionValidator = knnLibraryIndexingContext.getPerDimensionValidator(); this.vectorValidator = knnLibraryIndexingContext.getVectorValidator(); + this.vectorTransformer = knnLibraryIndexingContext.getVectorTransformer(); } 
@Override @@ -196,4 +198,9 @@ protected PerDimensionValidator getPerDimensionValidator() { protected PerDimensionProcessor getPerDimensionProcessor() { return perDimensionProcessor; } + + @Override + protected VectorTransformer getVectorTransformer() { + return vectorTransformer; + } } diff --git a/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java index cbc7520cf..d472090fc 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java @@ -5,6 +5,7 @@ package org.opensearch.knn.index.mapper; +import lombok.extern.log4j.Log4j2; import org.apache.lucene.document.FieldType; import org.apache.lucene.index.DocValuesType; import org.apache.lucene.index.VectorEncoding; @@ -33,6 +34,7 @@ /** * Field mapper for model in mapping */ +@Log4j2 public class ModelFieldMapper extends KNNVectorFieldMapper { // If the dimension has not yet been set because we do not have access to model metadata, it will be -1 @@ -41,6 +43,7 @@ public class ModelFieldMapper extends KNNVectorFieldMapper { private PerDimensionProcessor perDimensionProcessor; private PerDimensionValidator perDimensionValidator; private VectorValidator vectorValidator; + private VectorTransformer vectorTransformer; private final String modelId; @@ -192,6 +195,43 @@ protected PerDimensionProcessor getPerDimensionProcessor() { return perDimensionProcessor; } + @Override + protected VectorTransformer getVectorTransformer() { + // we don't want to call model metadata to get space type and engine for every vector, + // since getVectorTransformer() will be called once per vector. Hence, + // we initialize it once, and use it every other time + initVectorTransformer(); + return this.vectorTransformer; + } + + /** + * Initializes the vector transformer for the model field if not already initialized. 
+ * This method handles the vector transformation configuration based on the model metadata + * and KNN method context. + * @throws IllegalStateException if model metadata cannot be retrieved + */ + + private void initVectorTransformer() { + if (vectorTransformer != null) { + return; + } + ModelMetadata modelMetadata = getModelMetadata(modelDao, modelId); + + KNNMethodContext knnMethodContext = getKNNMethodContextFromModelMetadata(modelMetadata); + KNNMethodConfigContext knnMethodConfigContext = getKNNMethodConfigContextFromModelMetadata(modelMetadata); + // Need to handle BWC case where method context is not available + if (knnMethodContext == null || knnMethodConfigContext == null) { + vectorTransformer = VectorTransformerFactory.NOOP_VECTOR_TRANSFORMER; + return; + } + // Get the vector transformer from the indexing context. We want the engine/library to provide the necessary + // input rather than creating the transformer from the engine and space type. This design + // decision ensures that the engine, rather than the field mapper, drives the implementation. 
+ KNNLibraryIndexingContext knnLibraryIndexingContext = knnMethodContext.getKnnEngine() + .getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext); + vectorTransformer = knnLibraryIndexingContext.getVectorTransformer(); + } + private void initVectorValidator() { if (vectorValidator != null) { return; diff --git a/src/main/java/org/opensearch/knn/index/mapper/NormalizeVectorTransformer.java b/src/main/java/org/opensearch/knn/index/mapper/NormalizeVectorTransformer.java new file mode 100644 index 000000000..cd1331d3d --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/mapper/NormalizeVectorTransformer.java @@ -0,0 +1,38 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.knn.index.mapper; + +import org.apache.lucene.util.VectorUtil; + +/** + * Normalizes vectors using L2 (Euclidean) normalization, ensuring the vector's + * magnitude becomes 1 while preserving its directional properties. + */ +public class NormalizeVectorTransformer implements VectorTransformer { + + @Override + public void transform(float[] vector) { + validateVector(vector); + VectorUtil.l2normalize(vector); + } + + /** + * Transforms a byte array vector by normalizing it. + * This operation is currently not supported for byte arrays. 
+ * + * @param vector the byte array to be normalized + * @throws UnsupportedOperationException when this method is called, as byte array normalization is not supported + */ + @Override + public void transform(byte[] vector) { + throw new UnsupportedOperationException("Byte array normalization is not supported"); + } + + private void validateVector(float[] vector) { + if (vector == null || vector.length == 0) { + throw new IllegalArgumentException("Vector cannot be null or empty"); + } + } +} diff --git a/src/main/java/org/opensearch/knn/index/mapper/VectorTransformer.java b/src/main/java/org/opensearch/knn/index/mapper/VectorTransformer.java new file mode 100644 index 000000000..eb8c9dca0 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/mapper/VectorTransformer.java @@ -0,0 +1,37 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.knn.index.mapper; + +/** + * Defines operations for transforming vectors in the k-NN search context. + * Implementations can modify vectors while preserving their dimensional properties + * for specific use cases such as normalization, scaling, or other transformations. + */ +public interface VectorTransformer { + + /** + * Transforms a float vector in place. + * + * @param vector The input vector to transform (must not be null) + * @throws IllegalArgumentException if the input vector is null + */ + default void transform(final float[] vector) { + if (vector == null) { + throw new IllegalArgumentException("Input vector cannot be null"); + } + } + + /** + * Transforms a byte vector in place. 
+ * + * @param vector The input vector to transform (must not be null) + * @throws IllegalArgumentException if the input vector is null + */ + default void transform(final byte[] vector) { + if (vector == null) { + throw new IllegalArgumentException("Input vector cannot be null"); + } + } +} diff --git a/src/main/java/org/opensearch/knn/index/mapper/VectorTransformerFactory.java b/src/main/java/org/opensearch/knn/index/mapper/VectorTransformerFactory.java new file mode 100644 index 000000000..f87e496df --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/mapper/VectorTransformerFactory.java @@ -0,0 +1,43 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.mapper; + +import lombok.AccessLevel; +import lombok.NoArgsConstructor; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.engine.KNNEngine; + +/** + * Factory class responsible for creating appropriate vector transformers. + * This factory determines whether vectors need transformation based on the engine type and space type. + */ +@NoArgsConstructor(access = AccessLevel.PRIVATE) +public final class VectorTransformerFactory { + + /** + * A no-operation transformer that returns vector values unchanged. + */ + public final static VectorTransformer NOOP_VECTOR_TRANSFORMER = new VectorTransformer() { + }; + + /** + * Returns a vector transformer based on the provided KNN engine and space type. + * For FAISS engine with cosine similarity space type, returns a NormalizeVectorTransformer + * since FAISS doesn't natively support cosine space type. For all other cases, + * returns a no-operation transformer. 
+ * + * @param knnEngine The KNN engine type + * @param spaceType The space type + * @return VectorTransformer An appropriate vector transformer instance + */ + public static VectorTransformer getVectorTransformer(final KNNEngine knnEngine, final SpaceType spaceType) { + return shouldNormalizeVector(knnEngine, spaceType) ? new NormalizeVectorTransformer() : NOOP_VECTOR_TRANSFORMER; + } + + private static boolean shouldNormalizeVector(final KNNEngine knnEngine, final SpaceType spaceType) { + return knnEngine == KNNEngine.FAISS && spaceType == SpaceType.COSINESIMIL; + } +} diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index 6015d8482..7b27fd5a4 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -430,6 +430,7 @@ protected Query doToQuery(QueryShardContext context) { SpaceType spaceType = queryConfigFromMapping.get().getSpaceType(); VectorDataType vectorDataType = queryConfigFromMapping.get().getVectorDataType(); RescoreContext processedRescoreContext = knnVectorFieldType.resolveRescoreContext(rescoreContext); + knnVectorFieldType.transformQueryVector(vector); VectorQueryType vectorQueryType = getVectorQueryType(k, maxDistance, minScore); updateQueryStats(vectorQueryType); diff --git a/src/test/java/org/opensearch/knn/index/FaissIT.java b/src/test/java/org/opensearch/knn/index/FaissIT.java index c2e75ecb2..4113579bf 100644 --- a/src/test/java/org/opensearch/knn/index/FaissIT.java +++ b/src/test/java/org/opensearch/knn/index/FaissIT.java @@ -17,6 +17,7 @@ import lombok.SneakyThrows; import org.apache.hc.core5.http.ParseException; import org.apache.hc.core5.http.io.entity.EntityUtils; +import org.apache.lucene.util.VectorUtil; import org.junit.BeforeClass; import org.opensearch.client.Response; import org.opensearch.common.settings.Settings; @@ -43,6 +44,7 @@ import 
java.util.Map; import java.util.Random; import java.util.TreeMap; +import java.util.function.BiFunction; import java.util.stream.Collectors; import static org.opensearch.knn.common.KNNConstants.DIMENSION; @@ -93,6 +95,7 @@ public class FaissIT extends KNNRestTestCase { private static final String FILED_TYPE_INTEGER = "integer"; private static final String NON_EXISTENT_INTEGER_FIELD_NAME = "nonexistent_int_field"; public static final int NEVER_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD = -1; + public static final int ALWAYS_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD = 0; static TestUtils.TestData testData; @@ -373,16 +376,7 @@ public void testEndToEnd_whenDoRadiusSearch_whenDistanceThreshold_whenMethodIsHN deleteModel(modelId); // Search every 5 seconds 14 times to confirm graph gets evicted - int intervals = 14; - for (int i = 0; i < intervals; i++) { - if (getTotalGraphsInCache() == 0) { - return; - } - - Thread.sleep(5 * 1000); - } - - fail("Graphs are not getting evicted"); + validateGraphEviction(); } @SneakyThrows @@ -555,17 +549,7 @@ public void testEndToEnd_whenMethodIsHNSWPQ_thenSucceed() { deleteKNNIndex(indexName); deleteModel(modelId); - // Search every 5 seconds 14 times to confirm graph gets evicted - int intervals = 14; - for (int i = 0; i < intervals; i++) { - if (getTotalGraphsInCache() == 0) { - return; - } - - Thread.sleep(5 * 1000); - } - - fail("Graphs are not getting evicted"); + validateGraphEviction(); } @SneakyThrows @@ -1239,17 +1223,7 @@ public void testEndToEnd_whenMethodIsHNSWPQAndHyperParametersNotSet_thenSucceed( deleteKNNIndex(indexName); deleteModel(modelId); - // Search every 5 seconds 14 times to confirm graph gets evicted - int intervals = 14; - for (int i = 0; i < intervals; i++) { - if (getTotalGraphsInCache() == 0) { - return; - } - - Thread.sleep(5 * 1000); - } - - fail("Graphs are not getting evicted"); + validateGraphEviction(); } /** @@ -2050,6 +2024,90 @@ public void 
testQueryWithFilter_whenNonExistingFieldUsedInFilter_thenSuccessful( assertEquals(1, resultsQuery2.size()); } + public void testCosineSimilarity_withHNSW_withExactSearch_thenSucceed() throws Exception { + testCosineSimilarityForApproximateSearch(NEVER_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD); + } + + public void testCosineSimilarity_withHNSW_withApproximate_thenSucceed() throws Exception { + testCosineSimilarityForApproximateSearch(ALWAYS_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD); + validateGraphEviction(); + } + + public void testCosineSimilarity_withGraph_withRadialSearch_withDistanceThreshold_thenSucceed() throws Exception { + testCosineSimilarityForRadialSearch(ALWAYS_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD, null, 0.1f); + validateGraphEviction(); + } + + public void testCosineSimilarity_withGraph_withRadialSearch_withScore_thenSucceed() throws Exception { + testCosineSimilarityForRadialSearch(ALWAYS_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD, 0.9f, null); + validateGraphEviction(); + } + + public void testCosineSimilarity_withNoGraphs_withRadialSearch_withDistanceThreshold_thenSucceed() throws Exception { + testCosineSimilarityForRadialSearch(NEVER_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD, null, 0.1f); + validateGraphEviction(); + } + + public void testCosineSimilarity_withNoGraphs_withRadialSearch_withScore_thenSucceed() throws Exception { + testCosineSimilarityForRadialSearch(NEVER_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD, 0.9f, null); + validateGraphEviction(); + } + + public void testEndToEnd_withApproxAndExactSearch_inSameIndex_ForCosineSpaceType() throws Exception { + String indexName = randomLowerCaseString(); + String fieldName = randomLowerCaseString(); + SpaceType spaceType = SpaceType.COSINESIMIL; + Integer dimension = testData.indexData.vectors[0].length; + + // Create an index + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(fieldName) + .field("type", "knn_vector") + .field("dimension", 
dimension) + .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .startObject(KNNConstants.KNN_METHOD) + .field(KNNConstants.NAME, KNNConstants.METHOD_HNSW) + .field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName()) + .endObject() + .endObject() + .endObject() + .endObject(); + + Map mappingMap = xContentBuilderToMap(builder); + String mapping = builder.toString(); + + createKnnIndex(indexName, buildKNNIndexSettings(0), mapping); + + // Index one document + addKnnDoc(indexName, randomAlphaOfLength(5), fieldName, Floats.asList(testData.indexData.vectors[0]).toArray()); + + // Assert we have the right number of documents in the index + refreshAllIndices(); + assertEquals(1, getDocCount(indexName)); + // update threshold setting to skip building graph + updateIndexSettings(indexName, Settings.builder().put(KNNSettings.INDEX_KNN_ADVANCED_APPROXIMATE_THRESHOLD, -1)); + // add duplicate document with different id + addKnnDoc(indexName, randomAlphaOfLength(5), fieldName, Floats.asList(testData.indexData.vectors[0]).toArray()); + assertEquals(2, getDocCount(indexName)); + final int k = 2; + // search index + Response response = searchKNNIndex( + indexName, + KNNQueryBuilder.builder().fieldName(fieldName).vector(testData.queries[0]).k(k).build(), + k + ); + String responseBody = EntityUtils.toString(response.getEntity()); + List knnResults = parseSearchResponse(responseBody, fieldName); + assertEquals(k, knnResults.size()); + + List actualScores = parseSearchResponseScore(responseBody, fieldName); + + // both document should have identical score + assertEquals(actualScores.get(0), actualScores.get(1), 0.001); + } + protected void setupKNNIndexForFilterQuery() throws Exception { setupKNNIndexForFilterQuery(getKNNDefaultIndexSettings()); } @@ -2161,7 +2219,7 @@ private List> validateRadiusSearchResults( if (filterQuery != null) { queryBuilder.field("filter", filterQuery); } - if (methodParameters != null) { + if (methodParameters != null && 
methodParameters.size() > 0) { queryBuilder.startObject(METHOD_PARAMETER); for (Map.Entry entry : methodParameters.entrySet()) { queryBuilder.field(entry.getKey(), entry.getValue()); @@ -2182,6 +2240,8 @@ private List> validateRadiusSearchResults( assertTrue(KNNScoringUtil.l2Squared(queryVector, vector) <= distance); } else if (spaceType == SpaceType.INNER_PRODUCT) { assertTrue(KNNScoringUtil.innerProduct(queryVector, vector) >= distance); + } else if (spaceType == SpaceType.COSINESIMIL) { + assertTrue(KNNScoringUtil.cosinesimil(queryVector, vector) >= distance); } else { throw new IllegalArgumentException("Invalid space type"); } @@ -2190,4 +2250,97 @@ private List> validateRadiusSearchResults( } return queryResults; } + + private void testCosineSimilarityForApproximateSearch(int approximateThreshold) throws Exception { + String indexName = randomLowerCaseString(); + String fieldName = randomLowerCaseString(); + SpaceType spaceType = SpaceType.COSINESIMIL; + indexTestData(approximateThreshold, indexName, spaceType, fieldName); + + // search index + validateNearestNeighborsSearch(indexName, fieldName, spaceType, 10, VectorUtil::cosine); + + // Delete index + deleteKNNIndex(indexName); + } + + private void testCosineSimilarityForRadialSearch(int approximateThreshold, Float score, Float distance) throws Exception { + String indexName = randomLowerCaseString(); + String fieldName = randomLowerCaseString(); + SpaceType spaceType = SpaceType.COSINESIMIL; + indexTestData(approximateThreshold, indexName, spaceType, fieldName); + + // search index + validateRadiusSearchResults(indexName, fieldName, testData.queries, distance, score, spaceType, null, null); + + // Delete index + deleteKNNIndex(indexName); + } + + private void indexTestData(int approximateThreshold, String indexName, SpaceType spaceType, String fieldName) throws Exception { + Integer dimension = testData.indexData.vectors[0].length; + // Create an index + XContentBuilder builder = 
XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(fieldName) + .field("type", "knn_vector") + .field("dimension", dimension) + .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .startObject(KNNConstants.KNN_METHOD) + .field(KNNConstants.NAME, KNNConstants.METHOD_HNSW) + .field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName()) + .endObject() + .endObject() + .endObject() + .endObject(); + + String mapping = builder.toString(); + createKnnIndex(indexName, buildKNNIndexSettings(approximateThreshold), mapping); + + // Index the test data + for (int i = 0; i < testData.indexData.docs.length; i++) { + addKnnDoc( + indexName, + Integer.toString(testData.indexData.docs[i]), + fieldName, + Floats.asList(testData.indexData.vectors[i]).toArray() + ); + } + + refreshAllIndices(); + // Assert we have the right number of documents in the index + assertEquals(testData.indexData.docs.length, getDocCount(indexName)); + } + + @SneakyThrows + private void validateNearestNeighborsSearch( + final String indexName, + final String fieldName, + final SpaceType spaceType, + final int k, + final BiFunction scoringFunction + ) { + for (int i = 0; i < testData.queries.length; i++) { + final Response response = searchKNNIndex( + indexName, + KNNQueryBuilder.builder().fieldName(fieldName).vector(testData.queries[i]).k(k).build(), + k + ); + final String responseBody = EntityUtils.toString(response.getEntity()); + final List knnResults = parseSearchResponse(responseBody, fieldName); + assertEquals(k, knnResults.size()); + + final List actualScores = parseSearchResponseScore(responseBody, fieldName); + for (int j = 0; j < k; j++) { + final float[] primitiveArray = knnResults.get(j).getVector(); + assertEquals( + KNNEngine.FAISS.score(scoringFunction.apply(testData.queries[i], primitiveArray), spaceType), + actualScores.get(j), + 0.0001 + ); + } + } + } + } diff --git 
a/src/test/java/org/opensearch/knn/index/mapper/NormalizeVectorTransformerTests.java b/src/test/java/org/opensearch/knn/index/mapper/NormalizeVectorTransformerTests.java new file mode 100644 index 000000000..4b17b9a12 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/mapper/NormalizeVectorTransformerTests.java @@ -0,0 +1,45 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.mapper; + +import org.opensearch.knn.KNNTestCase; + +public class NormalizeVectorTransformerTests extends KNNTestCase { + private final NormalizeVectorTransformer transformer = new NormalizeVectorTransformer(); + private static final float DELTA = 0.001f; // Delta for floating point comparisons + + public void testNormalizeTransformer_withNullVector_thenThrowsException() { + assertThrows(IllegalArgumentException.class, () -> transformer.transform((float[]) null)); + } + + public void testNormalizeTransformer_withEmptyVector_thenThrowsException() { + assertThrows(IllegalArgumentException.class, () -> transformer.transform(new float[0])); + } + + public void testNormalizeTransformer_withByteVector_thenThrowsException() { + assertThrows(UnsupportedOperationException.class, () -> transformer.transform(new byte[0])); + } + + public void testNormalizeTransformer_withValidVector_thenSuccess() { + float[] input = { -3.0f, 4.0f }; + transformer.transform(input); + + assertEquals(-0.6f, input[0], DELTA); + assertEquals(0.8f, input[1], DELTA); + + // Verify the magnitude is 1 + assertEquals(1.0f, calculateMagnitude(input), DELTA); + } + + private float calculateMagnitude(float[] vector) { + float magnitude = 0.0f; + for (float value : vector) { + magnitude += value * value; + } + return (float) Math.sqrt(magnitude); + } + +} diff --git a/src/test/java/org/opensearch/knn/index/mapper/VectorTransformerFactoryTests.java b/src/test/java/org/opensearch/knn/index/mapper/VectorTransformerFactoryTests.java new file mode 100644 
index 000000000..6e213c151 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/mapper/VectorTransformerFactoryTests.java @@ -0,0 +1,42 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.mapper; + +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.engine.KNNEngine; + +public class VectorTransformerFactoryTests extends KNNTestCase { + + public void testAllSpaceTypes_withFaiss() { + for (SpaceType spaceType : SpaceType.values()) { + VectorTransformer transformer = VectorTransformerFactory.getVectorTransformer(KNNEngine.FAISS, spaceType); + validateTransformer(spaceType, KNNEngine.FAISS, transformer); + } + } + + public void testAllEngines_withCosine() { + for (KNNEngine engine : KNNEngine.values()) { + VectorTransformer transformer = VectorTransformerFactory.getVectorTransformer(engine, SpaceType.COSINESIMIL); + validateTransformer(SpaceType.COSINESIMIL, engine, transformer); + } + } + + private static void validateTransformer(SpaceType spaceType, KNNEngine engine, VectorTransformer transformer) { + if (spaceType == SpaceType.COSINESIMIL && engine == KNNEngine.FAISS) { + assertTrue( + "Should return NormalizeVectorTransformer for FAISS with " + spaceType, + transformer instanceof NormalizeVectorTransformer + ); + } else { + assertSame( + "Should return NOOP transformer for " + engine + " with COSINESIMIL", + VectorTransformerFactory.NOOP_VECTOR_TRANSFORMER, + transformer + ); + } + } +} diff --git a/src/test/java/org/opensearch/knn/recall/RecallTestsIT.java b/src/test/java/org/opensearch/knn/recall/RecallTestsIT.java index be2dd7b82..ae162401b 100644 --- a/src/test/java/org/opensearch/knn/recall/RecallTestsIT.java +++ b/src/test/java/org/opensearch/knn/recall/RecallTestsIT.java @@ -231,7 +231,7 @@ public void testRecall_whenLuceneHnswFP32_thenRecallAbove75percent() { */ @SneakyThrows public void 
testRecall_whenFaissHnswFP32_thenRecallAbove75percent() { - List spaceTypes = List.of(SpaceType.L2, SpaceType.INNER_PRODUCT); + List spaceTypes = List.of(SpaceType.L2, SpaceType.INNER_PRODUCT, SpaceType.COSINESIMIL); for (SpaceType spaceType : spaceTypes) { String indexName = createIndexName(KNNEngine.FAISS, spaceType); XContentBuilder builder = XContentFactory.jsonBuilder() diff --git a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java index 477816d91..7dd1ec237 100644 --- a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java @@ -129,6 +129,8 @@ public class KNNRestTestCase extends ODFERestTestCase { protected static final int DELAY_MILLI_SEC = 1000; protected static final int NUM_OF_ATTEMPTS = 30; private static final String SYSTEM_INDEX_PREFIX = ".opendistro"; + public static final int MIN_CODE_UNITS = 4; + public static final int MAX_CODE_UNITS = 10; @AfterClass public static void dumpCoverage() throws IOException, MalformedObjectNameException { @@ -1960,4 +1962,22 @@ protected boolean isApproximateThresholdSupported(final Optional bwcVers final Version version = Version.fromString(versionString); return version.onOrAfter(Version.V_2_18_0); } + + /** + * Generates a random lowercase string with length between MIN_CODE_UNITS and MAX_CODE_UNITS. + * This method is used for test fixtures to generate random string values that can be used + * as identifiers, names, or other string-based test data. + * Example usage: + * <pre>
+     * String randomId = randomLowerCaseString();
+     * String indexName = randomLowerCaseString();
+     * String fieldName = randomLowerCaseString();
+     * </pre>
+ * + * @return A random lowercase string of variable length between MIN_CODE_UNITS and MAX_CODE_UNITS + * @see #randomAlphaOfLengthBetween(int, int) + */ + protected static String randomLowerCaseString() { + return randomAlphaOfLengthBetween(MIN_CODE_UNITS, MAX_CODE_UNITS).toLowerCase(Locale.ROOT); + } }