Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] Add support for search using the "fields" parameter with knn_vector field #2394

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Add check to directly use ANN Search when filters match all docs. (#2320)[https://github.com/opensearch-project/k-NN/pull/2320]
### Bug Fixes
* Fixing the bug when a segment has no vector field present for disk based vector search (#2282)[https://github.com/opensearch-project/k-NN/pull/2282]
* Fixing the bug where search fails with "fields" parameter for an index with a knn_vector field (#2314)[https://github.com/opensearch-project/k-NN/pull/2314]
* Fix for NPE while merging segments after all the vector fields docs are deleted (#2365)[https://github.com/opensearch-project/k-NN/pull/2365]
* Allow validation for non knn index only after 2.17.0 (#2315)[https://github.com/opensearch-project/k-NN/pull/2315]
* Fixing the bug to prevent updating the index.knn setting after index creation (#2348)[https://github.com/opensearch-project/k-NN/pull/2348]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,14 @@
package org.opensearch.knn.index.mapper;

import lombok.Getter;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.search.DocValuesFieldExistsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.util.BytesRef;
import org.opensearch.index.fielddata.IndexFieldData;
import org.opensearch.index.mapper.MappedFieldType;
import org.opensearch.index.mapper.ArraySourceValueFetcher;
import org.opensearch.index.mapper.TextSearchInfo;
import org.opensearch.index.mapper.ValueFetcher;
import org.opensearch.index.query.QueryShardContext;
Expand All @@ -21,6 +24,8 @@
import org.opensearch.search.aggregations.support.CoreValuesSourceType;
import org.opensearch.search.lookup.SearchLookup;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Locale;
import java.util.Map;
import java.util.function.Supplier;
Expand All @@ -32,6 +37,7 @@
*/
@Getter
public class KNNVectorFieldType extends MappedFieldType {
private static final Logger logger = LogManager.getLogger(KNNVectorFieldType.class);
KNNMappingConfig knnMappingConfig;
VectorDataType vectorDataType;

Expand All @@ -51,7 +57,17 @@ public KNNVectorFieldType(String name, Map<String, String> metadata, VectorDataT

/**
 * Builds the {@link ValueFetcher} used when this field is requested through the search "fields" parameter.
 *
 * The fetcher is an {@link ArraySourceValueFetcher} created for this field's name. Its
 * {@code parseSourceValue} hands back values that are an {@link ArrayList} unchanged; any other
 * type is logged at WARN level and replaced with an empty list.
 *
 * @param context      shard context handed to the underlying fetcher
 * @param searchLookup not used by this implementation
 * @param format       not used by this implementation
 * @return a fetcher for this field's values
 */
@Override
public ValueFetcher valueFetcher(QueryShardContext context, SearchLookup searchLookup, String format) {
    // The earlier implementation threw UnsupportedOperationException here; it must not remain
    // ahead of the return, otherwise the return statement is unreachable and the class does not compile.
    return new ArraySourceValueFetcher(name(), context) {
        @Override
        protected Object parseSourceValue(Object value) {
            if (value instanceof ArrayList) {
                return value;
            } else {
                logger.warn("Expected type ArrayList for value, but got {} ", value.getClass());
                return Collections.emptyList();
            }
        }
    };
}

@Override
Expand Down
286 changes: 285 additions & 1 deletion src/test/java/org/opensearch/knn/index/OpenSearchIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

import com.google.common.collect.ImmutableList;
import com.google.common.primitives.Floats;
import java.util.Locale;

import lombok.SneakyThrows;
import org.apache.http.ParseException;
import org.junit.BeforeClass;
Expand All @@ -39,6 +39,7 @@
import java.net.URL;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.TreeMap;

Expand Down Expand Up @@ -924,6 +925,289 @@ public void testKNNIndex_whenBuildVectorGraphThresholdIsProvidedEndToEnd_thenBui
deleteKNNIndex(indexName);
}

/**
 * Checks the search "fields" parameter on an index whose documents carry three knn_vector fields.
 *
 * Twenty documents are indexed, then four searches are issued (match_all and knn, each with
 * "fields" set to "*" and to a subset). For every search the count reported by
 * parseSearchResponseFieldsCount is expected to be k for requested fields and 0 for fields
 * that were not requested.
 */
public void testKNNIndexSearchFieldsParameter() throws Exception {
    // The values 2, 3, 5 are presumably the field dimensions; they match the vector lengths indexed below.
    createKnnIndex(INDEX_NAME, createKnnIndexMapping(Arrays.asList("vector1", "vector2", "vector3"), Arrays.asList(2, 3, 5)));
    // Add docs with knn_vector fields
    for (int i = 1; i <= 20; i++) {
        Float[] vector1 = { (float) i, (float) (i + 1) };
        Float[] vector2 = { (float) i, (float) (i + 1), (float) (i + 2) };
        Float[] vector3 = { (float) i, (float) (i + 1), (float) (i + 2), (float) (i + 3), (float) (i + 4) };
        addKnnDoc(
            INDEX_NAME,
            Integer.toString(i),
            Arrays.asList("vector1", "vector2", "vector3"),
            Arrays.asList(vector1, vector2, vector3)
        );
    }
    int k = 10; // nearest 10 neighbors

    // Create match_all search body, all fields
    XContentBuilder builder = XContentFactory.jsonBuilder()
        .startObject()
        .field("fields", new String[] { "*" })
        .startObject("query")
        .startObject("match_all")
        .endObject()
        .endObject()
        .endObject();
    // Read each response body exactly once; re-reading the entity only works when it is repeatable.
    String matchAllEveryField = EntityUtils.toString(searchKNNIndex(INDEX_NAME, builder, k).getEntity());
    assertEquals(k, parseSearchResponseFieldsCount(matchAllEveryField, "vector1"));
    assertEquals(k, parseSearchResponseFieldsCount(matchAllEveryField, "vector2"));
    assertEquals(k, parseSearchResponseFieldsCount(matchAllEveryField, "vector3"));

    // Create match_all search body, some fields
    builder = XContentFactory.jsonBuilder()
        .startObject()
        .field("fields", new String[] { "vector1", "vector2" })
        .startObject("query")
        .startObject("match_all")
        .endObject()
        .endObject()
        .endObject();
    String matchAllSomeFields = EntityUtils.toString(searchKNNIndex(INDEX_NAME, builder, k).getEntity());
    assertEquals(k, parseSearchResponseFieldsCount(matchAllSomeFields, "vector1"));
    assertEquals(k, parseSearchResponseFieldsCount(matchAllSomeFields, "vector2"));
    assertEquals(0, parseSearchResponseFieldsCount(matchAllSomeFields, "vector3"));

    // Create knn search body, all fields
    builder = XContentFactory.jsonBuilder()
        .startObject()
        .field("fields", new String[] { "*" })
        .startObject("query")
        .startObject("knn")
        .startObject("vector2")
        .field("vector", new float[] { 2.0f, 2.0f, 2.0f })
        .field("k", k)
        .endObject()
        .endObject()
        .endObject()
        .endObject();
    String knnEveryField = EntityUtils.toString(searchKNNIndex(INDEX_NAME, builder, k).getEntity());
    assertEquals(k, parseSearchResponseFieldsCount(knnEveryField, "vector1"));
    assertEquals(k, parseSearchResponseFieldsCount(knnEveryField, "vector2"));
    assertEquals(k, parseSearchResponseFieldsCount(knnEveryField, "vector3"));

    // Create knn search body, some fields
    builder = XContentFactory.jsonBuilder()
        .startObject()
        .field("fields", new String[] { "vector1", "vector2" })
        .startObject("query")
        .startObject("knn")
        .startObject("vector2")
        .field("vector", new float[] { 2.0f, 2.0f, 2.0f })
        .field("k", k)
        .endObject()
        .endObject()
        .endObject()
        .endObject();
    String knnSomeFields = EntityUtils.toString(searchKNNIndex(INDEX_NAME, builder, k).getEntity());
    assertEquals(k, parseSearchResponseFieldsCount(knnSomeFields, "vector1"));
    assertEquals(k, parseSearchResponseFieldsCount(knnSomeFields, "vector2"));
    assertEquals(0, parseSearchResponseFieldsCount(knnSomeFields, "vector3"));
}

/**
 * Checks the search "fields" parameter on an index that mixes knn_vector fields with float fields.
 *
 * The mapping declares vector1 (dimension 2), vector2 (dimension 3), float1 and float2. Twenty
 * documents populate all four fields, then match_all and knn searches are run with "fields" set
 * to "*" and to the subset {vector1, float2}. Requested fields are expected to yield a count of k
 * from parseSearchResponseFieldsCount; fields not requested are expected to yield 0.
 */
public void testKNNIndexSearchFieldsParameterWithOtherFields() throws Exception {
    XContentBuilder xContentBuilder = XContentFactory.jsonBuilder()
        .startObject()
        .startObject("properties")
        .startObject("vector1")
        .field("type", "knn_vector")
        .field("dimension", "2")
        .endObject()
        .startObject("vector2")
        .field("type", "knn_vector")
        .field("dimension", "3")
        .endObject()
        .startObject("float1")
        .field("type", "float")
        .endObject()
        .startObject("float2")
        .field("type", "float")
        .endObject()
        .endObject()
        .endObject();
    createKnnIndex(INDEX_NAME, xContentBuilder.toString());
    // Add docs with knn_vector and other fields
    for (int i = 1; i <= 20; i++) {
        Float[] vector1 = { (float) i, (float) (i + 1) };
        Float[] vector2 = { (float) i, (float) (i + 1), (float) (i + 2) };
        Float[] float1 = { (float) i };
        Float[] float2 = { (float) (i + 1) };
        addKnnDoc(
            INDEX_NAME,
            Integer.toString(i),
            Arrays.asList("vector1", "vector2", "float1", "float2"),
            Arrays.asList(vector1, vector2, float1, float2)
        );
    }
    int k = 10; // nearest 10 neighbors

    // Create match_all search body, all fields
    XContentBuilder builder = XContentFactory.jsonBuilder()
        .startObject()
        .field("fields", new String[] { "*" })
        .startObject("query")
        .startObject("match_all")
        .endObject()
        .endObject()
        .endObject();
    // Read each response body exactly once; re-reading the entity only works when it is repeatable.
    String matchAllEveryField = EntityUtils.toString(searchKNNIndex(INDEX_NAME, builder, k).getEntity());
    assertEquals(k, parseSearchResponseFieldsCount(matchAllEveryField, "vector1"));
    assertEquals(k, parseSearchResponseFieldsCount(matchAllEveryField, "vector2"));
    assertEquals(k, parseSearchResponseFieldsCount(matchAllEveryField, "float1"));
    assertEquals(k, parseSearchResponseFieldsCount(matchAllEveryField, "float2"));

    // Create match_all search body, some fields
    builder = XContentFactory.jsonBuilder()
        .startObject()
        .field("fields", new String[] { "vector1", "float2" })
        .startObject("query")
        .startObject("match_all")
        .endObject()
        .endObject()
        .endObject();
    String matchAllSomeFields = EntityUtils.toString(searchKNNIndex(INDEX_NAME, builder, k).getEntity());
    assertEquals(k, parseSearchResponseFieldsCount(matchAllSomeFields, "vector1"));
    assertEquals(0, parseSearchResponseFieldsCount(matchAllSomeFields, "vector2"));
    assertEquals(0, parseSearchResponseFieldsCount(matchAllSomeFields, "float1"));
    assertEquals(k, parseSearchResponseFieldsCount(matchAllSomeFields, "float2"));

    // Create knn search body, all fields
    builder = XContentFactory.jsonBuilder()
        .startObject()
        .field("fields", new String[] { "*" })
        .startObject("query")
        .startObject("knn")
        .startObject("vector2")
        .field("vector", new float[] { 2.0f, 2.0f, 2.0f })
        .field("k", k)
        .endObject()
        .endObject()
        .endObject()
        .endObject();
    String knnEveryField = EntityUtils.toString(searchKNNIndex(INDEX_NAME, builder, k).getEntity());
    assertEquals(k, parseSearchResponseFieldsCount(knnEveryField, "vector1"));
    assertEquals(k, parseSearchResponseFieldsCount(knnEveryField, "vector2"));
    assertEquals(k, parseSearchResponseFieldsCount(knnEveryField, "float1"));
    assertEquals(k, parseSearchResponseFieldsCount(knnEveryField, "float2"));

    // Create knn search body, some fields
    builder = XContentFactory.jsonBuilder()
        .startObject()
        .field("fields", new String[] { "vector1", "float2" })
        .startObject("query")
        .startObject("knn")
        .startObject("vector2")
        .field("vector", new float[] { 2.0f, 2.0f, 2.0f })
        .field("k", k)
        .endObject()
        .endObject()
        .endObject()
        .endObject();
    String knnSomeFields = EntityUtils.toString(searchKNNIndex(INDEX_NAME, builder, k).getEntity());
    assertEquals(k, parseSearchResponseFieldsCount(knnSomeFields, "vector1"));
    assertEquals(0, parseSearchResponseFieldsCount(knnSomeFields, "vector2"));
    assertEquals(0, parseSearchResponseFieldsCount(knnSomeFields, "float1"));
    assertEquals(k, parseSearchResponseFieldsCount(knnSomeFields, "float2"));
}

/**
 * Checks the search "fields" parameter when some documents hold only vectors and others only text.
 *
 * Documents 1-20 are indexed with vector1 and vector2 only; documents 21-40 are indexed with
 * text1 only. A match query on text1 and a knn query on vector2 are each run twice, once asking
 * for the text field and once asking for the vector fields. The assertions expect a count of k
 * only for fields that were both requested and indexed on the documents the query targets, and 0
 * in every other case.
 */
public void testKNNIndexSearchFieldsParameterDocsWithOnlyOtherFields() throws Exception {
    XContentBuilder xContentBuilder = XContentFactory.jsonBuilder()
        .startObject()
        .startObject("properties")
        .startObject("vector1")
        .field("type", "knn_vector")
        .field("dimension", "2")
        .endObject()
        .startObject("vector2")
        .field("type", "knn_vector")
        .field("dimension", "3")
        .endObject()
        .startObject("text1")
        .field("type", "text")
        .endObject()
        .endObject()
        .endObject();
    createKnnIndex(INDEX_NAME, xContentBuilder.toString());
    // Add knn_vector docs
    for (int i = 1; i <= 20; i++) {
        Float[] vector1 = { (float) i, (float) (i + 1) };
        Float[] vector2 = { (float) i, (float) (i + 1), (float) (i + 2) };
        addKnnDoc(INDEX_NAME, Integer.toString(i), Arrays.asList("vector1", "vector2"), Arrays.asList(vector1, vector2));
    }
    // Add non knn_vector docs
    for (int i = 21; i <= 40; i++) {
        addNonKNNDoc(INDEX_NAME, Integer.toString(i), "text1", "text " + i);
    }
    int k = 10; // nearest 10 neighbors

    // Create match search body, all non vector fields
    XContentBuilder builder = XContentFactory.jsonBuilder()
        .startObject()
        .field("fields", new String[] { "text1" })
        .startObject("query")
        .startObject("match")
        .field("text1", "text")
        .endObject()
        .endObject()
        .endObject();
    // Read each response body exactly once; re-reading the entity only works when it is repeatable.
    String matchTextField = EntityUtils.toString(searchKNNIndex(INDEX_NAME, builder, k).getEntity());
    assertEquals(0, parseSearchResponseFieldsCount(matchTextField, "vector1"));
    assertEquals(0, parseSearchResponseFieldsCount(matchTextField, "vector2"));
    assertEquals(k, parseSearchResponseFieldsCount(matchTextField, "text1"));

    // Create match search body, all vector fields
    builder = XContentFactory.jsonBuilder()
        .startObject()
        .field("fields", new String[] { "vector1", "vector2" })
        .startObject("query")
        .startObject("match")
        .field("text1", "text")
        .endObject()
        .endObject()
        .endObject();
    String matchVectorFields = EntityUtils.toString(searchKNNIndex(INDEX_NAME, builder, k).getEntity());
    assertEquals(0, parseSearchResponseFieldsCount(matchVectorFields, "vector1"));
    assertEquals(0, parseSearchResponseFieldsCount(matchVectorFields, "vector2"));
    assertEquals(0, parseSearchResponseFieldsCount(matchVectorFields, "text1"));

    // Create knn search body, all vector fields
    builder = XContentFactory.jsonBuilder()
        .startObject()
        .field("fields", new String[] { "vector1", "vector2" })
        .startObject("query")
        .startObject("knn")
        .startObject("vector2")
        .field("vector", new float[] { 2.0f, 2.0f, 2.0f })
        .field("k", k)
        .endObject()
        .endObject()
        .endObject()
        .endObject();
    String knnVectorFields = EntityUtils.toString(searchKNNIndex(INDEX_NAME, builder, k).getEntity());
    assertEquals(k, parseSearchResponseFieldsCount(knnVectorFields, "vector1"));
    assertEquals(k, parseSearchResponseFieldsCount(knnVectorFields, "vector2"));
    assertEquals(0, parseSearchResponseFieldsCount(knnVectorFields, "text1"));

    // Create knn search body, all non vector fields
    builder = XContentFactory.jsonBuilder()
        .startObject()
        .field("fields", new String[] { "text1" })
        .startObject("query")
        .startObject("knn")
        .startObject("vector2")
        .field("vector", new float[] { 2.0f, 2.0f, 2.0f })
        .field("k", k)
        .endObject()
        .endObject()
        .endObject()
        .endObject();
    String knnTextField = EntityUtils.toString(searchKNNIndex(INDEX_NAME, builder, k).getEntity());
    assertEquals(0, parseSearchResponseFieldsCount(knnTextField, "vector1"));
    assertEquals(0, parseSearchResponseFieldsCount(knnTextField, "vector2"));
    assertEquals(0, parseSearchResponseFieldsCount(knnTextField, "text1"));
}

private List<KNNResult> getResults(final String indexName, final String fieldName, final float[] vector, final int k)
throws IOException, ParseException {
final Response searchResponseField = searchKNNIndex(indexName, new KNNQueryBuilder(fieldName, vector, k), k);
Expand Down
Loading
Loading