Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add cosine similarity support for faiss engine #2376

Merged
merged 4 commits into from
Jan 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Add Support for Multi Values in innerHit for Nested k-NN Fields in Lucene and FAISS (#2283)[https://github.com/opensearch-project/k-NN/pull/2283]
- Add binary index support for Lucene engine. (#2292)[https://github.com/opensearch-project/k-NN/pull/2292]
- Add expand_nested_docs Parameter support to NMSLIB engine (#2331)[https://github.com/opensearch-project/k-NN/pull/2331]
- Add cosine similarity support for faiss engine (#2376)[https://github.com/opensearch-project/k-NN/pull/2376]
### Enhancements
- Introduced a writing layer in native engines that relies on the writing interface to process IO. (#2241)[https://github.com/opensearch-project/k-NN/pull/2241]
- Allow method parameter override for training based indices (#2290)[https://github.com/opensearch-project/k-NN/pull/2290]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import org.opensearch.knn.index.mapper.PerDimensionProcessor;
import org.opensearch.knn.index.mapper.PerDimensionValidator;
import org.opensearch.knn.index.mapper.SpaceVectorValidator;
import org.opensearch.knn.index.mapper.VectorTransformer;
import org.opensearch.knn.index.mapper.VectorTransformerFactory;
import org.opensearch.knn.index.mapper.VectorValidator;

import java.util.ArrayList;
Expand Down Expand Up @@ -106,6 +108,10 @@ protected PerDimensionProcessor doGetPerDimensionProcessor(
return PerDimensionProcessor.NOOP_PROCESSOR;
}

/**
 * Returns the vector transformer to use for the given space type.
 * This base implementation ignores {@code spaceType} and always returns
 * {@code VectorTransformerFactory.NOOP_VECTOR_TRANSFORMER}; subclasses can override it
 * to supply a transformer that depends on the space type.
 *
 * @param spaceType the space type configured for the method (unused by this implementation)
 * @return the transformer to apply to vectors
 */
protected VectorTransformer getVectorTransformer(SpaceType spaceType) {
return VectorTransformerFactory.NOOP_VECTOR_TRANSFORMER;
}

@Override
public KNNLibraryIndexingContext getKNNLibraryIndexingContext(
KNNMethodContext knnMethodContext,
Expand All @@ -116,19 +122,37 @@ public KNNLibraryIndexingContext getKNNLibraryIndexingContext(
knnMethodConfigContext
);
Map<String, Object> parameterMap = knnLibraryIndexingContext.getLibraryParameters();
parameterMap.put(KNNConstants.SPACE_TYPE, knnMethodContext.getSpaceType().getValue());
parameterMap.put(KNNConstants.SPACE_TYPE, convertUserToMethodSpaceType(knnMethodContext.getSpaceType()).getValue());
parameterMap.put(KNNConstants.VECTOR_DATA_TYPE_FIELD, knnMethodConfigContext.getVectorDataType().getValue());
return KNNLibraryIndexingContextImpl.builder()
.quantizationConfig(knnLibraryIndexingContext.getQuantizationConfig())
.parameters(parameterMap)
.vectorValidator(doGetVectorValidator(knnMethodContext, knnMethodConfigContext))
.perDimensionValidator(doGetPerDimensionValidator(knnMethodContext, knnMethodConfigContext))
.perDimensionProcessor(doGetPerDimensionProcessor(knnMethodContext, knnMethodConfigContext))
.vectorTransformer(getVectorTransformer(knnMethodContext.getSpaceType()))
.build();
}

/**
 * Returns the search context held by this method.
 *
 * @return the {@code knnLibrarySearchContext} member, returned as is
 */
@Override
public KNNLibrarySearchContext getKNNLibrarySearchContext() {
return knnLibrarySearchContext;
}

/**
* Converts user defined space type to method space type that is supported by library.
* The subclass can override this method and returns the appropriate space type that
* is supported by the library. This is required because, some libraries may not
* support all the space types supported by OpenSearch, however. this can be achieved by using compatible space type by the library.
* For example, faiss does not support cosine similarity. However, we can use inner product space type for cosine similarity after normalization.
* In this case, we can return the inner product space type for cosine similarity.
*
* @param spaceType The space type to check for compatibility
* @return The compatible space type for the given input, returns the same
* space type if it's already compatible
* @see SpaceType
*/
protected SpaceType convertUserToMethodSpaceType(SpaceType spaceType) {
VijayanB marked this conversation as resolved.
Show resolved Hide resolved
return spaceType;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import org.opensearch.knn.index.engine.qframe.QuantizationConfig;
import org.opensearch.knn.index.mapper.PerDimensionProcessor;
import org.opensearch.knn.index.mapper.PerDimensionValidator;
import org.opensearch.knn.index.mapper.VectorTransformer;
import org.opensearch.knn.index.mapper.VectorValidator;

import java.util.Map;
Expand Down Expand Up @@ -47,4 +48,12 @@ public interface KNNLibraryIndexingContext {
* @return Get the per dimension processor
*/
PerDimensionProcessor getPerDimensionProcessor();

/**
* Get the vector transformer that will be used to transform the vector before indexing.
* This will be applied at vector level once entire vector is parsed and validated.
*
* @return VectorTransformer
*/
VectorTransformer getVectorTransformer();
VijayanB marked this conversation as resolved.
Show resolved Hide resolved
VijayanB marked this conversation as resolved.
Show resolved Hide resolved
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import org.opensearch.knn.index.engine.qframe.QuantizationConfig;
import org.opensearch.knn.index.mapper.PerDimensionProcessor;
import org.opensearch.knn.index.mapper.PerDimensionValidator;
import org.opensearch.knn.index.mapper.VectorTransformer;
import org.opensearch.knn.index.mapper.VectorValidator;

import java.util.Collections;
Expand All @@ -23,6 +24,7 @@ public class KNNLibraryIndexingContextImpl implements KNNLibraryIndexingContext
private VectorValidator vectorValidator;
private PerDimensionValidator perDimensionValidator;
private PerDimensionProcessor perDimensionProcessor;
private VectorTransformer vectorTransformer;
@Builder.Default
private Map<String, Object> parameters = Collections.emptyMap();
@Builder.Default
Expand All @@ -43,6 +45,11 @@ public VectorValidator getVectorValidator() {
return vectorValidator;
}

/**
 * Returns the vector transformer stored in this indexing context.
 * NOTE(review): the backing field is declared without {@code @Builder.Default}, so this presumably
 * returns {@code null} when the builder was not given a transformer; confirm callers handle that.
 *
 * @return the stored vector transformer
 */
@Override
public VectorTransformer getVectorTransformer() {
return vectorTransformer;
}

@Override
public PerDimensionValidator getPerDimensionValidator() {
return perDimensionValidator;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,11 @@
import org.apache.commons.lang.StringUtils;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.AbstractKNNMethod;
import org.opensearch.knn.index.engine.KNNLibraryIndexingContext;
import org.opensearch.knn.index.engine.KNNLibrarySearchContext;
import org.opensearch.knn.index.engine.KNNMethodConfigContext;
import org.opensearch.knn.index.engine.KNNMethodContext;
import org.opensearch.knn.index.engine.MethodComponent;
import org.opensearch.knn.index.engine.MethodComponentContext;
import org.opensearch.knn.index.engine.*;
import org.opensearch.knn.index.mapper.PerDimensionProcessor;
import org.opensearch.knn.index.mapper.PerDimensionValidator;
import org.opensearch.knn.index.mapper.VectorTransformer;
import org.opensearch.knn.index.mapper.VectorTransformerFactory;

import java.util.Objects;
import java.util.Set;
Expand Down Expand Up @@ -132,4 +128,20 @@ static MethodComponentContext getEncoderMethodComponent(MethodComponentContext m
}
return (MethodComponentContext) object;
}

/**
 * Maps the user-facing space type onto the space type faiss is asked to use.
 * Cosine similarity is rewritten to inner product; any other space type is resolved by the parent class.
 *
 * @param spaceType the space type configured by the user
 * @return {@code SpaceType.INNER_PRODUCT} for cosine similarity, otherwise the parent's result
 */
@Override
protected SpaceType convertUserToMethodSpaceType(SpaceType spaceType) {
    // FAISS has no native cosine similarity space. For unit-length vectors the inner product
    // equals the cosine of the angle between them (||a|| = ||b|| = 1  =>  a · b = cos(θ)),
    // so cosine similarity is served through the inner product space.
    return SpaceType.COSINESIMIL == spaceType
        ? SpaceType.INNER_PRODUCT
        : super.convertUserToMethodSpaceType(spaceType);
}

/**
 * Resolves the vector transformer for a faiss method by asking {@code VectorTransformerFactory}
 * for the transformer that matches {@code KNNEngine.FAISS} and the given space type.
 *
 * @param spaceType the space type passed through to the factory lookup
 * @return the transformer the factory returns for faiss and {@code spaceType}
 */
@Override
protected VectorTransformer getVectorTransformer(SpaceType spaceType) {
return VectorTransformerFactory.getVectorTransformer(KNNEngine.FAISS, spaceType);
}
}
26 changes: 21 additions & 5 deletions src/main/java/org/opensearch/knn/index/engine/faiss/Faiss.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
*/
public class Faiss extends NativeLibrary {
public static final String FAISS_BINARY_INDEX_DESCRIPTION_PREFIX = "B";
// Per-space-type functions applied to a radial-search distance in distanceToRadialThreshold;
// assigned from the constructor's distanceTransform argument.
Map<SpaceType, Function<Float, Float>> distanceTransform;
// Per-space-type functions assigned from the constructor's scoreTransform argument
// (presumably the score-side counterpart used for radial search; usage is not visible here).
Map<SpaceType, Function<Float, Float>> scoreTransform;

// TODO: Current version is not really current version. Instead, it encodes information in the file name
Expand All @@ -36,14 +37,24 @@ public class Faiss extends NativeLibrary {
// Map that overrides OpenSearch score translation by space type of scores returned by faiss
private final static Map<SpaceType, Function<Float, Float>> SCORE_TRANSLATIONS = ImmutableMap.of(
SpaceType.INNER_PRODUCT,
rawScore -> SpaceType.INNER_PRODUCT.scoreTranslation(-1 * rawScore)
rawScore -> SpaceType.INNER_PRODUCT.scoreTranslation(-1 * rawScore),
// COSINESIMIL expects the raw score in 1 - cosine(x,y)
SpaceType.COSINESIMIL,
rawScore -> SpaceType.COSINESIMIL.scoreTranslation(1 - rawScore)
);

// Map that overrides radial search score threshold to faiss required distance, check more details in knn documentation:
// https://opensearch.org/docs/latest/search-plugins/knn/approximate-knn/#spaces
private final static Map<SpaceType, Function<Float, Float>> SCORE_TO_DISTANCE_TRANSFORMATIONS = ImmutableMap.<
SpaceType,
Function<Float, Float>>builder().put(SpaceType.INNER_PRODUCT, score -> score > 1 ? 1 - score : 1 / score - 1).build();
Function<Float, Float>>builder()
.put(SpaceType.INNER_PRODUCT, score -> score > 1 ? 1 - score : (1 / score) - 1)
.put(SpaceType.COSINESIMIL, score -> 2 * score - 1)
.build();

// Map of per-space-type conversions from a user-supplied radial search distance to the value handed to faiss.
// COSINESIMIL: distance d maps to (1 - d) — presumably because the user-facing cosine distance is
// 1 - cosine(x,y) while faiss, searching by inner product, needs cosine(x,y) itself; confirm.
private final static Map<SpaceType, Function<Float, Float>> DISTANCE_TRANSLATIONS = ImmutableMap.<
SpaceType,
Function<Float, Float>>builder().put(SpaceType.COSINESIMIL, distance -> 1 - distance).build();

// Package private so that the method resolving logic can access the methods
final static Map<String, KNNMethod> METHODS = ImmutableMap.of(METHOD_HNSW, new FaissHNSWMethod(), METHOD_IVF, new FaissIVFMethod());
Expand All @@ -53,7 +64,8 @@ public class Faiss extends NativeLibrary {
SCORE_TRANSLATIONS,
CURRENT_VERSION,
KNNConstants.FAISS_EXTENSION,
SCORE_TO_DISTANCE_TRANSFORMATIONS
SCORE_TO_DISTANCE_TRANSFORMATIONS,
DISTANCE_TRANSLATIONS
);

private final MethodResolver methodResolver;
Expand All @@ -71,16 +83,20 @@ private Faiss(
Map<SpaceType, Function<Float, Float>> scoreTranslation,
String currentVersion,
String extension,
Map<SpaceType, Function<Float, Float>> scoreTransform
Map<SpaceType, Function<Float, Float>> scoreTransform,
Map<SpaceType, Function<Float, Float>> distanceTransform
) {
super(methods, scoreTranslation, currentVersion, extension);
this.scoreTransform = scoreTransform;
this.distanceTransform = distanceTransform;
this.methodResolver = new FaissMethodResolver();
}

/**
 * Converts a radial search distance into the threshold value handed to faiss.
 * When a transformation is registered for the space type in {@code distanceTransform},
 * it is applied to the distance; otherwise the distance is returned unchanged.
 *
 * @param distance  the distance to convert
 * @param spaceType the space type used to look up a transformation
 * @return the transformed distance, or {@code distance} itself when no transformation is registered
 */
@Override
public Float distanceToRadialThreshold(Float distance, SpaceType spaceType) {
    // Faiss uses the distance as is unless a transformation is registered for the space type.
    // A single lookup replaces the previous containsKey + get pair.
    final Function<Float, Float> transform = this.distanceTransform.get(spaceType);
    return transform == null ? distance : transform.apply(distance);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ public class FaissHNSWMethod extends AbstractFaissMethod {
SpaceType.UNDEFINED,
SpaceType.HAMMING,
SpaceType.L2,
SpaceType.INNER_PRODUCT
SpaceType.INNER_PRODUCT,
SpaceType.COSINESIMIL
);

private final static MethodComponentContext DEFAULT_ENCODER_CONTEXT = new MethodComponentContext(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ public class FaissIVFMethod extends AbstractFaissMethod {
SpaceType.UNDEFINED,
SpaceType.L2,
SpaceType.INNER_PRODUCT,
SpaceType.HAMMING
SpaceType.HAMMING,
SpaceType.COSINESIMIL
);

private final static MethodComponentContext DEFAULT_ENCODER_CONTEXT = new MethodComponentContext(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,4 +109,5 @@ protected PerDimensionValidator getPerDimensionValidator() {
protected PerDimensionProcessor getPerDimensionProcessor() {
return PerDimensionProcessor.NOOP_PROCESSOR;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -675,7 +675,7 @@ protected void validatePreparse() {
protected abstract VectorValidator getVectorValidator();

/**
* Getter for per dimension validator during vector parsing
* Getter for per dimension validator during vector parsing, and before any transformation
*
* @return PerDimensionValidator
*/
Expand All @@ -688,6 +688,23 @@ protected void validatePreparse() {
*/
protected abstract PerDimensionProcessor getPerDimensionProcessor();

/**
 * Retrieves the vector transformer for the KNN vector field.
 * In {@code parseCreateField} the returned transformer is applied to each parsed vector
 * after the vector validator has run and before the fields are added to the document.
 * This base implementation always returns {@code VectorTransformerFactory.NOOP_VECTOR_TRANSFORMER};
 * subclasses can override this method to supply a different transformer.
 *
 * @return the VectorTransformer that will be used for vector transformations in this field
 */
protected VectorTransformer getVectorTransformer() {
return VectorTransformerFactory.NOOP_VECTOR_TRANSFORMER;
}

protected void parseCreateField(ParseContext context, int dimension, VectorDataType vectorDataType) throws IOException {
validatePreparse();

Expand All @@ -698,6 +715,7 @@ protected void parseCreateField(ParseContext context, int dimension, VectorDataT
}
final byte[] array = bytesArrayOptional.get();
getVectorValidator().validateVector(array);
getVectorTransformer().transform(array);
context.doc().addAll(getFieldsForByteVector(array));
} else if (VectorDataType.FLOAT == vectorDataType) {
Optional<float[]> floatsArrayOptional = getFloatsFromContext(context, dimension);
Expand All @@ -707,6 +725,7 @@ protected void parseCreateField(ParseContext context, int dimension, VectorDataT
}
final float[] array = floatsArrayOptional.get();
getVectorValidator().validateVector(array);
getVectorTransformer().transform(array);
context.doc().addAll(getFieldsForFloatVector(array));
} else {
throw new IllegalArgumentException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,18 @@
import org.opensearch.index.query.QueryShardException;
import org.opensearch.knn.index.KNNVectorIndexFieldData;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.KNNMethodContext;
import org.opensearch.knn.index.query.rescore.RescoreContext;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.search.aggregations.support.CoreValuesSourceType;
import org.opensearch.search.lookup.SearchLookup;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.function.Supplier;

import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.deserializeStoredVector;
Expand Down Expand Up @@ -115,4 +119,38 @@ public RescoreContext resolveRescoreContext(RescoreContext userProvidedContext)
Mode mode = knnMappingConfig.getMode();
return compressionLevel.getDefaultRescoreContext(mode, dimension);
}

/**
* Transforms a query vector based on the field's configuration. The transformation is performed
* in-place on the input vector according to either the KNN method context or the model ID.
*
* @param vector The float array to be transformed in-place. Must not be null.
* @throws IllegalStateException if neither KNN method context nor Model ID is configured
*
* The transformation process follows this order:
* 1. If vector is not FLOAT type, no transformation is performed
* 2. Attempts to use KNN method context if present
* 3. Falls back to model ID if KNN method context is not available
* 4. Throws exception if neither configuration is present
*/
public void transformQueryVector(float[] vector) {
VijayanB marked this conversation as resolved.
Show resolved Hide resolved
if (VectorDataType.FLOAT != vectorDataType) {
return;
}
final Optional<KNNMethodContext> knnMethodContext = knnMappingConfig.getKnnMethodContext();
if (knnMethodContext.isPresent()) {
KNNMethodContext context = knnMethodContext.get();
VectorTransformerFactory.getVectorTransformer(context.getKnnEngine(), context.getSpaceType()).transform(vector);
return;
}
final Optional<String> modelId = knnMappingConfig.getModelId();
if (modelId.isPresent()) {
ModelDao modelDao = ModelDao.OpenSearchKNNModelDao.getInstance();
final ModelMetadata metadata = modelDao.getMetadata(modelId.get());
VectorTransformerFactory.getVectorTransformer(metadata.getKnnEngine(), metadata.getSpaceType()).transform(vector);
return;
}
throw new IllegalStateException("Either KNN method context or Model Id should be configured");

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ public class MethodFieldMapper extends KNNVectorFieldMapper {
private final PerDimensionProcessor perDimensionProcessor;
private final PerDimensionValidator perDimensionValidator;
private final VectorValidator vectorValidator;
private final VectorTransformer vectorTransformer;

public static MethodFieldMapper createFieldMapper(
String fullname,
Expand Down Expand Up @@ -180,6 +181,7 @@ private MethodFieldMapper(
this.perDimensionProcessor = knnLibraryIndexingContext.getPerDimensionProcessor();
this.perDimensionValidator = knnLibraryIndexingContext.getPerDimensionValidator();
this.vectorValidator = knnLibraryIndexingContext.getVectorValidator();
this.vectorTransformer = knnLibraryIndexingContext.getVectorTransformer();
}

@Override
Expand All @@ -196,4 +198,9 @@ protected PerDimensionValidator getPerDimensionValidator() {
protected PerDimensionProcessor getPerDimensionProcessor() {
return perDimensionProcessor;
}

VijayanB marked this conversation as resolved.
Show resolved Hide resolved
@Override
protected VectorTransformer getVectorTransformer() {
return vectorTransformer;
}
}
Loading
Loading