Skip to content

Commit

Permalink
Add cosine similarity support for faiss engine (opensearch-project#2376)
Browse files Browse the repository at this point in the history
* Add cosine similarity support for faiss engine

The FAISS engine doesn't support cosine similarity natively.
However, we can use inner product to achieve the same result, because
when vectors are normalized, the inner product is the same
as the cosine similarity. Hence, before ingestion and before performing a search,
we normalize the input vector and add it to the faiss index with the space type
set to inner product.

Since we will be storing normalized vectors in segments, the
actual vectors can be retrieved from the source. By saving normalized vectors,
we don't have to normalize whenever segments are merged. This
keeps force merge time and search latency competitive, at the cost of
additional latency during indexing (the one time where we normalize).

We also support radial search for cosine similarity.

Signed-off-by: Vijayan Balasubramanian <[email protected]>
  • Loading branch information
VijayanB committed Jan 22, 2025
1 parent 209b0b7 commit 2624c6e
Show file tree
Hide file tree
Showing 22 changed files with 605 additions and 50 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Add Support for Multi Values in innerHit for Nested k-NN Fields in Lucene and FAISS (#2283)[https://github.com/opensearch-project/k-NN/pull/2283]
- Add binary index support for Lucene engine. (#2292)[https://github.com/opensearch-project/k-NN/pull/2292]
- Add expand_nested_docs Parameter support to NMSLIB engine (#2331)[https://github.com/opensearch-project/k-NN/pull/2331]
- Add cosine similarity support for faiss engine (#2376)[https://github.com/opensearch-project/k-NN/pull/2376]
### Enhancements
- Introduced a writing layer in native engines, which relies on the writing interface to process IO. (#2241)[https://github.com/opensearch-project/k-NN/pull/2241]
- Allow method parameter override for training based indices (#2290)[https://github.com/opensearch-project/k-NN/pull/2290]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import org.opensearch.knn.index.mapper.PerDimensionProcessor;
import org.opensearch.knn.index.mapper.PerDimensionValidator;
import org.opensearch.knn.index.mapper.SpaceVectorValidator;
import org.opensearch.knn.index.mapper.VectorTransformer;
import org.opensearch.knn.index.mapper.VectorTransformerFactory;
import org.opensearch.knn.index.mapper.VectorValidator;

import java.util.ArrayList;
Expand Down Expand Up @@ -106,6 +108,10 @@ protected PerDimensionProcessor doGetPerDimensionProcessor(
return PerDimensionProcessor.NOOP_PROCESSOR;
}

/**
 * Returns the vector transformer to use for the given space type.
 * This base implementation ignores the space type and always returns the no-op
 * transformer; the method is protected and non-final, so subclasses can override it
 * to supply a different transformer.
 *
 * @param spaceType the space type configured for the method (not used by this implementation)
 * @return {@link VectorTransformerFactory#NOOP_VECTOR_TRANSFORMER}
 */
protected VectorTransformer getVectorTransformer(SpaceType spaceType) {
return VectorTransformerFactory.NOOP_VECTOR_TRANSFORMER;
}

@Override
public KNNLibraryIndexingContext getKNNLibraryIndexingContext(
KNNMethodContext knnMethodContext,
Expand All @@ -116,19 +122,37 @@ public KNNLibraryIndexingContext getKNNLibraryIndexingContext(
knnMethodConfigContext
);
Map<String, Object> parameterMap = knnLibraryIndexingContext.getLibraryParameters();
parameterMap.put(KNNConstants.SPACE_TYPE, knnMethodContext.getSpaceType().getValue());
parameterMap.put(KNNConstants.SPACE_TYPE, convertUserToMethodSpaceType(knnMethodContext.getSpaceType()).getValue());
parameterMap.put(KNNConstants.VECTOR_DATA_TYPE_FIELD, knnMethodConfigContext.getVectorDataType().getValue());
return KNNLibraryIndexingContextImpl.builder()
.quantizationConfig(knnLibraryIndexingContext.getQuantizationConfig())
.parameters(parameterMap)
.vectorValidator(doGetVectorValidator(knnMethodContext, knnMethodConfigContext))
.perDimensionValidator(doGetPerDimensionValidator(knnMethodContext, knnMethodConfigContext))
.perDimensionProcessor(doGetPerDimensionProcessor(knnMethodContext, knnMethodConfigContext))
.vectorTransformer(getVectorTransformer(knnMethodContext.getSpaceType()))
.build();
}

/**
 * Returns the search context held by this method.
 *
 * @return the {@code knnLibrarySearchContext} of this method
 */
@Override
public KNNLibrarySearchContext getKNNLibrarySearchContext() {
return knnLibrarySearchContext;
}

/**
 * Converts the user-defined space type to the method space type that is supported by the library.
 * Subclasses can override this method to return the appropriate space type supported by their
 * library. This is required because some libraries may not support all the space types supported
 * by OpenSearch; however, an unsupported space type can sometimes be achieved by using a compatible
 * space type of the library.
 * For example, faiss does not support cosine similarity. However, the inner product space type can
 * be used for cosine similarity after normalization, so in that case the inner product space type
 * can be returned for cosine similarity.
 * This base implementation returns the given space type unchanged.
 *
 * @param spaceType The space type to check for compatibility
 * @return The compatible space type for the given input; the same
 * space type if it is already compatible
 * @see SpaceType
 */
protected SpaceType convertUserToMethodSpaceType(SpaceType spaceType) {
return spaceType;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import org.opensearch.knn.index.engine.qframe.QuantizationConfig;
import org.opensearch.knn.index.mapper.PerDimensionProcessor;
import org.opensearch.knn.index.mapper.PerDimensionValidator;
import org.opensearch.knn.index.mapper.VectorTransformer;
import org.opensearch.knn.index.mapper.VectorValidator;

import java.util.Map;
Expand Down Expand Up @@ -47,4 +48,12 @@ public interface KNNLibraryIndexingContext {
* @return Get the per dimension processor
*/
PerDimensionProcessor getPerDimensionProcessor();

/**
 * Get the vector transformer that will be used to transform the vector before indexing.
 * The transformer is applied at vector level, once the entire vector has been parsed and validated.
 *
 * @return the {@link VectorTransformer} to apply to a vector before it is indexed
 */
VectorTransformer getVectorTransformer();
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import org.opensearch.knn.index.engine.qframe.QuantizationConfig;
import org.opensearch.knn.index.mapper.PerDimensionProcessor;
import org.opensearch.knn.index.mapper.PerDimensionValidator;
import org.opensearch.knn.index.mapper.VectorTransformer;
import org.opensearch.knn.index.mapper.VectorValidator;

import java.util.Collections;
Expand All @@ -23,6 +24,7 @@ public class KNNLibraryIndexingContextImpl implements KNNLibraryIndexingContext
private VectorValidator vectorValidator;
private PerDimensionValidator perDimensionValidator;
private PerDimensionProcessor perDimensionProcessor;
private VectorTransformer vectorTransformer;
@Builder.Default
private Map<String, Object> parameters = Collections.emptyMap();
@Builder.Default
Expand All @@ -43,6 +45,11 @@ public VectorValidator getVectorValidator() {
return vectorValidator;
}

/**
 * Returns the vector transformer this indexing context was built with.
 *
 * @return the configured {@link VectorTransformer}
 */
@Override
public VectorTransformer getVectorTransformer() {
    return this.vectorTransformer;
}

@Override
public PerDimensionValidator getPerDimensionValidator() {
return perDimensionValidator;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,11 @@
import org.apache.commons.lang.StringUtils;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.AbstractKNNMethod;
import org.opensearch.knn.index.engine.KNNLibraryIndexingContext;
import org.opensearch.knn.index.engine.KNNLibrarySearchContext;
import org.opensearch.knn.index.engine.KNNMethodConfigContext;
import org.opensearch.knn.index.engine.KNNMethodContext;
import org.opensearch.knn.index.engine.MethodComponent;
import org.opensearch.knn.index.engine.MethodComponentContext;
import org.opensearch.knn.index.engine.*;
import org.opensearch.knn.index.mapper.PerDimensionProcessor;
import org.opensearch.knn.index.mapper.PerDimensionValidator;
import org.opensearch.knn.index.mapper.VectorTransformer;
import org.opensearch.knn.index.mapper.VectorTransformerFactory;

import java.util.Objects;
import java.util.Set;
Expand Down Expand Up @@ -132,4 +128,20 @@ static MethodComponentContext getEncoderMethodComponent(MethodComponentContext m
}
return (MethodComponentContext) object;
}

/**
 * Maps the user-facing space type onto the space type the FAISS index is configured with.
 * Cosine similarity is mapped to inner product; every other space type is resolved by the
 * parent implementation.
 *
 * @param spaceType the space type requested by the user
 * @return {@link SpaceType#INNER_PRODUCT} for {@link SpaceType#COSINESIMIL}, otherwise the
 *         result of the parent conversion
 */
@Override
protected SpaceType convertUserToMethodSpaceType(SpaceType spaceType) {
    // FAISS has no direct cosine similarity support. For normalized vectors
    // (||a|| = ||b|| = 1) we have cos(x) = a dot b, so inner product can stand in for it.
    return spaceType == SpaceType.COSINESIMIL
        ? SpaceType.INNER_PRODUCT
        : super.convertUserToMethodSpaceType(spaceType);
}

/**
 * Returns the vector transformer for the given space type, as resolved by
 * {@link VectorTransformerFactory} for the {@link KNNEngine#FAISS} engine.
 *
 * @param spaceType the space type configured for the method
 * @return the transformer obtained from the factory for FAISS and {@code spaceType}
 */
@Override
protected VectorTransformer getVectorTransformer(SpaceType spaceType) {
return VectorTransformerFactory.getVectorTransformer(KNNEngine.FAISS, spaceType);
}
}
26 changes: 21 additions & 5 deletions src/main/java/org/opensearch/knn/index/engine/faiss/Faiss.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
*/
public class Faiss extends NativeLibrary {
public static final String FAISS_BINARY_INDEX_DESCRIPTION_PREFIX = "B";
// Per-space-type functions applied by distanceToRadialThreshold to convert a distance into a radial threshold
Map<SpaceType, Function<Float, Float>> distanceTransform;
// Per-space-type functions received through the constructor's scoreTransform argument
Map<SpaceType, Function<Float, Float>> scoreTransform;

// TODO: Current version is not really current version. Instead, it encodes information in the file name
Expand All @@ -36,14 +37,24 @@ public class Faiss extends NativeLibrary {
// Map that overrides OpenSearch score translation by space type of scores returned by faiss
private final static Map<SpaceType, Function<Float, Float>> SCORE_TRANSLATIONS = ImmutableMap.of(
SpaceType.INNER_PRODUCT,
rawScore -> SpaceType.INNER_PRODUCT.scoreTranslation(-1 * rawScore)
rawScore -> SpaceType.INNER_PRODUCT.scoreTranslation(-1 * rawScore),
// COSINESIMIL expects the raw score in 1 - cosine(x,y)
SpaceType.COSINESIMIL,
rawScore -> SpaceType.COSINESIMIL.scoreTranslation(1 - rawScore)
);

// Map that overrides radial search score threshold to faiss required distance, check more details in knn documentation:
// https://opensearch.org/docs/latest/search-plugins/knn/approximate-knn/#spaces
private final static Map<SpaceType, Function<Float, Float>> SCORE_TO_DISTANCE_TRANSFORMATIONS = ImmutableMap.<
SpaceType,
Function<Float, Float>>builder().put(SpaceType.INNER_PRODUCT, score -> score > 1 ? 1 - score : 1 / score - 1).build();
Function<Float, Float>>builder()
.put(SpaceType.INNER_PRODUCT, score -> score > 1 ? 1 - score : (1 / score) - 1)
.put(SpaceType.COSINESIMIL, score -> 2 * score - 1)
.build();

// Map of per-space-type conversions from a radial search distance to the threshold handed to faiss;
// passed to the Faiss constructor as its distanceTransform argument. For COSINESIMIL the distance d becomes 1 - d.
private final static Map<SpaceType, Function<Float, Float>> DISTANCE_TRANSLATIONS = ImmutableMap.<
SpaceType,
Function<Float, Float>>builder().put(SpaceType.COSINESIMIL, distance -> 1 - distance).build();

// Package private so that the method resolving logic can access the methods
final static Map<String, KNNMethod> METHODS = ImmutableMap.of(METHOD_HNSW, new FaissHNSWMethod(), METHOD_IVF, new FaissIVFMethod());
Expand All @@ -53,7 +64,8 @@ public class Faiss extends NativeLibrary {
SCORE_TRANSLATIONS,
CURRENT_VERSION,
KNNConstants.FAISS_EXTENSION,
SCORE_TO_DISTANCE_TRANSFORMATIONS
SCORE_TO_DISTANCE_TRANSFORMATIONS,
DISTANCE_TRANSLATIONS
);

private final MethodResolver methodResolver;
Expand All @@ -71,16 +83,20 @@ private Faiss(
Map<SpaceType, Function<Float, Float>> scoreTranslation,
String currentVersion,
String extension,
Map<SpaceType, Function<Float, Float>> scoreTransform
Map<SpaceType, Function<Float, Float>> scoreTransform,
Map<SpaceType, Function<Float, Float>> distanceTransform
) {
super(methods, scoreTranslation, currentVersion, extension);
this.scoreTransform = scoreTransform;
this.distanceTransform = distanceTransform;
this.methodResolver = new FaissMethodResolver();
}

/**
 * Converts a radial search distance into the threshold used for the given space type.
 * When a transformation is registered for the space type in {@code distanceTransform}, it is
 * applied; otherwise the distance is returned unchanged.
 *
 * @param distance  the distance supplied for the radial search
 * @param spaceType the space type of the search
 * @return the transformed distance, or {@code distance} itself when no transformation is registered
 */
@Override
public Float distanceToRadialThreshold(Float distance, SpaceType spaceType) {
    // Single lookup: space types without a registered transformation use the distance as is
    final Function<Float, Float> transform = this.distanceTransform.get(spaceType);
    return transform == null ? distance : transform.apply(distance);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ public class FaissHNSWMethod extends AbstractFaissMethod {
SpaceType.UNDEFINED,
SpaceType.HAMMING,
SpaceType.L2,
SpaceType.INNER_PRODUCT
SpaceType.INNER_PRODUCT,
SpaceType.COSINESIMIL
);

private final static MethodComponentContext DEFAULT_ENCODER_CONTEXT = new MethodComponentContext(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ public class FaissIVFMethod extends AbstractFaissMethod {
SpaceType.UNDEFINED,
SpaceType.L2,
SpaceType.INNER_PRODUCT,
SpaceType.HAMMING
SpaceType.HAMMING,
SpaceType.COSINESIMIL
);

private final static MethodComponentContext DEFAULT_ENCODER_CONTEXT = new MethodComponentContext(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,4 +109,5 @@ protected PerDimensionValidator getPerDimensionValidator() {
protected PerDimensionProcessor getPerDimensionProcessor() {
return PerDimensionProcessor.NOOP_PROCESSOR;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -675,7 +675,7 @@ protected void validatePreparse() {
protected abstract VectorValidator getVectorValidator();

/**
* Getter for per dimension validator during vector parsing
* Getter for per dimension validator during vector parsing, and before any transformation
*
* @return PerDimensionValidator
*/
Expand All @@ -688,6 +688,23 @@ protected void validatePreparse() {
*/
protected abstract PerDimensionProcessor getPerDimensionProcessor();

/**
 * Retrieves the vector transformer for the KNN vector field.
 * The transformer is applied to a vector in {@code parseCreateField} after the vector has been
 * validated and before its fields are added to the document.
 * This default implementation returns the no-op transformer
 * ({@link VectorTransformerFactory#NOOP_VECTOR_TRANSFORMER}); subclasses can override this method
 * to supply a different transformer.
 *
 * @return VectorTransformer An instance of VectorTransformer that will be used
 * for vector transformations in this field
 *
 */
protected VectorTransformer getVectorTransformer() {
return VectorTransformerFactory.NOOP_VECTOR_TRANSFORMER;
}

protected void parseCreateField(ParseContext context, int dimension, VectorDataType vectorDataType) throws IOException {
validatePreparse();

Expand All @@ -698,6 +715,7 @@ protected void parseCreateField(ParseContext context, int dimension, VectorDataT
}
final byte[] array = bytesArrayOptional.get();
getVectorValidator().validateVector(array);
getVectorTransformer().transform(array);
context.doc().addAll(getFieldsForByteVector(array));
} else if (VectorDataType.FLOAT == vectorDataType) {
Optional<float[]> floatsArrayOptional = getFloatsFromContext(context, dimension);
Expand All @@ -707,6 +725,7 @@ protected void parseCreateField(ParseContext context, int dimension, VectorDataT
}
final float[] array = floatsArrayOptional.get();
getVectorValidator().validateVector(array);
getVectorTransformer().transform(array);
context.doc().addAll(getFieldsForFloatVector(array));
} else {
throw new IllegalArgumentException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,16 @@
import org.opensearch.index.query.QueryShardException;
import org.opensearch.knn.index.KNNVectorIndexFieldData;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.KNNMethodContext;
import org.opensearch.knn.index.query.rescore.RescoreContext;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.search.aggregations.support.CoreValuesSourceType;
import org.opensearch.search.lookup.SearchLookup;

import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.function.Supplier;

import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.deserializeStoredVector;
Expand Down Expand Up @@ -99,4 +103,38 @@ public RescoreContext resolveRescoreContext(RescoreContext userProvidedContext)
Mode mode = knnMappingConfig.getMode();
return compressionLevel.getDefaultRescoreContext(mode, dimension);
}

/**
 * Transforms a query vector based on the field's configuration. The transformation is performed
 * in-place on the input vector according to either the KNN method context or the model ID.
 *
 * <p>The transformation process follows this order:
 * <ol>
 *   <li>If the field's vector data type is not FLOAT, no transformation is performed</li>
 *   <li>Uses the KNN method context if present</li>
 *   <li>Falls back to the model ID if the KNN method context is not available</li>
 *   <li>Throws an exception if neither configuration is present</li>
 * </ol>
 *
 * @param vector The float array to be transformed in-place. Must not be null.
 * @throws IllegalStateException if neither KNN method context nor Model ID is configured, or if
 *                               no metadata is returned for the configured model ID
 */
public void transformQueryVector(float[] vector) {
    if (VectorDataType.FLOAT != vectorDataType) {
        return;
    }
    final Optional<KNNMethodContext> knnMethodContext = knnMappingConfig.getKnnMethodContext();
    if (knnMethodContext.isPresent()) {
        final KNNMethodContext context = knnMethodContext.get();
        VectorTransformerFactory.getVectorTransformer(context.getKnnEngine(), context.getSpaceType()).transform(vector);
        return;
    }
    final String modelId = knnMappingConfig.getModelId()
        .orElseThrow(() -> new IllegalStateException("Either KNN method context or Model Id should be configured"));
    final ModelMetadata metadata = ModelDao.OpenSearchKNNModelDao.getInstance().getMetadata(modelId);
    // NOTE(review): assumes getMetadata can return null for an unknown model - confirm against ModelDao.
    // Guarding here turns that case into a descriptive failure instead of a bare NullPointerException.
    if (metadata == null) {
        throw new IllegalStateException(String.format(Locale.ROOT, "Model metadata not found for model id '%s'", modelId));
    }
    VectorTransformerFactory.getVectorTransformer(metadata.getKnnEngine(), metadata.getSpaceType()).transform(vector);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ public class MethodFieldMapper extends KNNVectorFieldMapper {
private final PerDimensionProcessor perDimensionProcessor;
private final PerDimensionValidator perDimensionValidator;
private final VectorValidator vectorValidator;
private final VectorTransformer vectorTransformer;

public static MethodFieldMapper createFieldMapper(
String fullname,
Expand Down Expand Up @@ -180,6 +181,7 @@ private MethodFieldMapper(
this.perDimensionProcessor = knnLibraryIndexingContext.getPerDimensionProcessor();
this.perDimensionValidator = knnLibraryIndexingContext.getPerDimensionValidator();
this.vectorValidator = knnLibraryIndexingContext.getVectorValidator();
this.vectorTransformer = knnLibraryIndexingContext.getVectorTransformer();
}

@Override
Expand All @@ -196,4 +198,9 @@ protected PerDimensionValidator getPerDimensionValidator() {
protected PerDimensionProcessor getPerDimensionProcessor() {
return perDimensionProcessor;
}

/**
 * Returns the vector transformer this mapper obtained from its library indexing context
 * at construction time.
 *
 * @return the {@link VectorTransformer} stored on this mapper
 */
@Override
protected VectorTransformer getVectorTransformer() {
    return this.vectorTransformer;
}
}
Loading

0 comments on commit 2624c6e

Please sign in to comment.