Skip to content

Commit

Permalink
[Backport 2.x] feature: Add AbstractRetriverTool, VectorDBTool, Neura…
Browse files Browse the repository at this point in the history
…lSparseTools (#58)

* Merge pull request #40 from zhichao-aws/SearchTools

feature: Add AbstractRetrieverTool, VectorDBTool, NeuralSparseSearchTool
(cherry picked from commit c088f77)

* fix commons-lang3 version (#45) (#59)

Signed-off-by: zhichao-aws <[email protected]>

---------

Signed-off-by: zhichao-aws <[email protected]>
Co-authored-by: zane-neo <[email protected]>
  • Loading branch information
zhichao-aws and zane-neo authored Dec 25, 2023
1 parent c931029 commit c6e18c6
Show file tree
Hide file tree
Showing 10 changed files with 834 additions and 6 deletions.
10 changes: 5 additions & 5 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ configurations {
zipArchive
all {
resolutionStrategy {
force "org.mockito:mockito-core:5.8.0"
force "org.mockito:mockito-core:${versions.mockito}"
force "com.google.guava:guava:32.1.3-jre" // CVE for 31.1
force("org.eclipse.platform:org.eclipse.core.runtime:3.30.0") // CVE for < 3.29.0, forces JDK17 for spotless
}
Expand Down Expand Up @@ -107,8 +107,8 @@ dependencies {
compileOnly group: 'com.google.code.gson', name: 'gson', version: '2.10.1'
compileOnly "org.apache.logging.log4j:log4j-slf4j-impl:2.22.0"
compileOnly group: 'org.json', name: 'json', version: '20231013'
implementation("com.google.guava:guava:32.1.3-jre")
implementation group: 'org.apache.commons', name: 'commons-lang3', version: '3.13.0'
compileOnly("com.google.guava:guava:32.1.3-jre")
compileOnly group: 'org.apache.commons', name: 'commons-lang3', version: '3.10'

// Plugin dependencies
compileOnly group: 'org.opensearch', name:'opensearch-ml-client', version: "${version}"
Expand All @@ -128,8 +128,8 @@ dependencies {
testImplementation group: 'junit', name: 'junit', version: '4.13.2'
testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.8.0'
testImplementation group: 'org.mockito', name: 'mockito-inline', version: '5.2.0'
testImplementation("net.bytebuddy:byte-buddy:1.14.7")
testImplementation("net.bytebuddy:byte-buddy-agent:1.14.7")
testImplementation("net.bytebuddy:byte-buddy:1.14.7")
testImplementation("net.bytebuddy:byte-buddy-agent:1.14.7")
testImplementation 'org.junit.jupiter:junit-jupiter-api:5.10.1'
testImplementation 'org.mockito:mockito-junit-jupiter:5.8.0'
testImplementation "com.nhaarman.mockitokotlin2:mockito-kotlin:2.2.0"
Expand Down
6 changes: 5 additions & 1 deletion src/main/java/org/opensearch/agent/ToolPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
import java.util.List;
import java.util.function.Supplier;

import org.opensearch.agent.tools.NeuralSparseSearchTool;
import org.opensearch.agent.tools.PPLTool;
import org.opensearch.agent.tools.VectorDBTool;
import org.opensearch.client.Client;
import org.opensearch.cluster.metadata.IndexNameExpressionResolver;
import org.opensearch.cluster.service.ClusterService;
Expand Down Expand Up @@ -54,11 +56,13 @@ public Collection<Object> createComponents(
this.xContentRegistry = xContentRegistry;

PPLTool.Factory.getInstance().init(client);
NeuralSparseSearchTool.Factory.getInstance().init(client, xContentRegistry);
VectorDBTool.Factory.getInstance().init(client, xContentRegistry);
return Collections.emptyList();
}

@Override
public List<Tool.Factory<? extends Tool>> getToolFactories() {
return List.of(PPLTool.Factory.getInstance());
return List.of(PPLTool.Factory.getInstance(), NeuralSparseSearchTool.Factory.getInstance(), VectorDBTool.Factory.getInstance());
}
}
140 changes: 140 additions & 0 deletions src/main/java/org/opensearch/agent/tools/AbstractRetrieverTool.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.agent.tools;

import static org.opensearch.ml.common.utils.StringUtils.gson;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

import org.apache.commons.lang3.StringUtils;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.client.Client;
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.search.SearchHit;
import org.opensearch.search.builder.SearchSourceBuilder;

import lombok.Getter;
import lombok.Setter;
import lombok.extern.log4j.Log4j2;

/**
 * Abstract base for tools that support the search paradigms of the neural-search plugin.
 *
 * <p>A concrete tool only supplies the JSON query body via {@link #getQueryBody(String)}; this
 * class builds the search request from it, executes the search through the {@link Client} and
 * reports every hit to the listener as one JSON object per line.
 */
@Log4j2
@Getter
@Setter
public abstract class AbstractRetrieverTool implements Tool {
    public static final String DEFAULT_DESCRIPTION = "Use this tool to search data in OpenSearch index.";
    public static final String INPUT_FIELD = "input";
    public static final String INDEX_FIELD = "index";
    public static final String SOURCE_FIELD = "source_field";
    public static final String DOC_SIZE_FIELD = "doc_size";
    public static final int DEFAULT_DOC_SIZE = 2;

    protected String description = DEFAULT_DESCRIPTION;
    protected Client client;
    protected NamedXContentRegistry xContentRegistry;
    protected String index;
    protected String[] sourceFields;
    protected Integer docSize;
    protected String version;

    /**
     * @param client           client used to execute the search
     * @param xContentRegistry registry handed to the parser that reads the query body
     * @param index            index the search request targets
     * @param sourceFields     fields passed to {@code fetchSource} as the includes list
     * @param docSize          number of hits to request; {@link #DEFAULT_DOC_SIZE} when {@code null}
     */
    protected AbstractRetrieverTool(
        Client client,
        NamedXContentRegistry xContentRegistry,
        String index,
        String[] sourceFields,
        Integer docSize
    ) {
        this.client = client;
        this.xContentRegistry = xContentRegistry;
        this.index = index;
        this.sourceFields = sourceFields;
        this.docSize = docSize == null ? DEFAULT_DOC_SIZE : docSize;
    }

    /**
     * Produces the JSON search body for the given query text.
     *
     * @param queryText the non-blank value of the {@link #INPUT_FIELD} parameter
     * @return a JSON document that {@link SearchSourceBuilder#parseXContent} can read
     */
    protected abstract String getQueryBody(String queryText);

    /**
     * Turns the tool parameters into a search request against {@link #index}.
     *
     * @param parameters tool parameters; must contain a non-blank {@link #INPUT_FIELD}
     * @return the search request, limited to {@link #docSize} hits and {@link #sourceFields}
     * @throws IllegalArgumentException if the input parameter is missing or blank
     * @throws IOException              if the query body cannot be parsed
     */
    private SearchRequest buildSearchRequest(Map<String, String> parameters) throws IOException {
        String question = parameters.get(INPUT_FIELD);
        if (StringUtils.isBlank(question)) {
            throw new IllegalArgumentException("[" + INPUT_FIELD + "] is null or empty, can not process it.");
        }

        String query = getQueryBody(question);
        SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
        // The parser is a Closeable; release it as soon as the source builder has consumed the query.
        try (
            XContentParser queryParser = XContentType.JSON
                .xContent()
                .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, query)
        ) {
            searchSourceBuilder.parseXContent(queryParser);
        }
        searchSourceBuilder.fetchSource(sourceFields, null);
        searchSourceBuilder.size(docSize);
        return new SearchRequest().source(searchSourceBuilder).indices(index);
    }

    /**
     * Executes the search and hands the result to {@code listener}.
     *
     * <p>On success the listener receives a {@code String}: one JSON object per hit (with the keys
     * {@code _index}, {@code _id}, {@code _score} and {@code _source}), each followed by a newline,
     * or a fixed message when there are no hits. Failures while building the request or while
     * searching are logged and forwarded to {@link ActionListener#onFailure}.
     *
     * @param parameters tool parameters, see {@link #buildSearchRequest(Map)}
     * @param listener   receiver of the textual result or of the failure
     */
    @Override
    @SuppressWarnings("unchecked") // the Tool contract is generic, but this tool always answers with a String
    public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
        SearchRequest searchRequest;
        try {
            searchRequest = buildSearchRequest(parameters);
        } catch (Exception e) {
            log.error("Failed to build search request.", e);
            listener.onFailure(e);
            return;
        }

        ActionListener<SearchResponse> actionListener = ActionListener.<SearchResponse>wrap(response -> {
            SearchHit[] hits = response.getHits().getHits();

            if (hits != null && hits.length > 0) {
                StringBuilder contextBuilder = new StringBuilder();
                for (SearchHit hit : hits) {
                    Map<String, Object> docContent = new HashMap<>();
                    docContent.put("_index", hit.getIndex());
                    docContent.put("_id", hit.getId());
                    docContent.put("_score", hit.getScore());
                    docContent.put("_source", hit.getSourceAsMap());
                    contextBuilder.append(gson.toJson(docContent)).append("\n");
                }
                listener.onResponse((T) contextBuilder.toString());
            } else {
                listener.onResponse((T) "Can not get any match from search result.");
            }
        }, e -> {
            log.error("Failed to search index.", e);
            listener.onFailure(e);
        });
        client.search(searchRequest, actionListener);
    }

    /**
     * @param parameters tool parameters to check
     * @return {@code true} when the map is non-empty and holds a non-blank {@link #INPUT_FIELD}
     */
    @Override
    public boolean validate(Map<String, String> parameters) {
        return parameters != null && !parameters.isEmpty() && !StringUtils.isBlank(parameters.get(INPUT_FIELD));
    }

    /**
     * Base factory for retriever tools; keeps the client and registry that created tools are given.
     *
     * @param <T> concrete tool type produced by the factory
     */
    protected static abstract class Factory<T extends Tool> implements Tool.Factory<T> {
        protected Client client;
        protected NamedXContentRegistry xContentRegistry;

        /**
         * Stores the collaborators available to subclasses when they create tools.
         *
         * @param client           client to pass to created tools
         * @param xContentRegistry registry to pass to created tools
         */
        public void init(Client client, NamedXContentRegistry xContentRegistry) {
            this.client = client;
            this.xContentRegistry = xContentRegistry;
        }

        @Override
        public String getDefaultDescription() {
            return DEFAULT_DESCRIPTION;
        }
    }
}
110 changes: 110 additions & 0 deletions src/main/java/org/opensearch/agent/tools/NeuralSparseSearchTool.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.agent.tools;

import static org.opensearch.ml.common.utils.StringUtils.gson;

import java.util.Map;

import org.apache.commons.lang3.StringUtils;
import org.opensearch.client.Client;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.spi.tools.ToolAnnotation;

import lombok.Builder;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.log4j.Log4j2;

/**
 * This tool supports neural_sparse search with sparse encoding models and rank_features field.
 *
 * <p>It contributes the {@code neural_sparse} query body; building and executing the search
 * request is inherited from {@link AbstractRetrieverTool}.
 */
@Log4j2
@Getter
@Setter
@ToolAnnotation(NeuralSparseSearchTool.TYPE)
public class NeuralSparseSearchTool extends AbstractRetrieverTool {
    public static final String TYPE = "NeuralSparseSearchTool";
    public static final String MODEL_ID_FIELD = "model_id";
    public static final String EMBEDDING_FIELD = "embedding_field";

    private String name = TYPE;
    private String modelId;
    private String embeddingField;

    /**
     * @param client           client used to execute the search
     * @param xContentRegistry registry handed to the parser that reads the query body
     * @param index            index the search request targets
     * @param embeddingField   name of the field the {@code neural_sparse} clause is applied to
     * @param sourceFields     fields to include from each hit's source
     * @param docSize          number of hits to request; the superclass default when {@code null}
     * @param modelId          value of the {@code model_id} entry in the query
     */
    @Builder
    public NeuralSparseSearchTool(
        Client client,
        NamedXContentRegistry xContentRegistry,
        String index,
        String embeddingField,
        String[] sourceFields,
        Integer docSize,
        String modelId
    ) {
        super(client, xContentRegistry, index, sourceFields, docSize);
        this.modelId = modelId;
        this.embeddingField = embeddingField;
    }

    /**
     * Builds the {@code neural_sparse} query for the given text.
     *
     * <p>The field name, query text and model id are serialized as JSON string literals rather
     * than pasted in verbatim, so quotes, backslashes or control characters in any of them can
     * neither break the document nor change the structure of the query.
     *
     * @param queryText text to search for
     * @return the JSON query body
     * @throws IllegalArgumentException if the embedding field or the model id is blank
     */
    @Override
    protected String getQueryBody(String queryText) {
        if (StringUtils.isBlank(embeddingField) || StringUtils.isBlank(modelId)) {
            throw new IllegalArgumentException(
                "Parameter [" + EMBEDDING_FIELD + "] and [" + MODEL_ID_FIELD + "] can not be null or empty."
            );
        }
        // gson.toJson(String) emits the surrounding quotes itself, hence no literal quotes around the values.
        return "{\"query\":{\"neural_sparse\":{"
            + gson.toJson(embeddingField)
            + ":{\"query_text\":"
            + gson.toJson(queryText)
            + ",\"model_id\":"
            + gson.toJson(modelId)
            + "}}}"
            + " }";
    }

    @Override
    public String getType() {
        return TYPE;
    }

    /**
     * Lazily created singleton factory for {@link NeuralSparseSearchTool}.
     */
    public static class Factory extends AbstractRetrieverTool.Factory<NeuralSparseSearchTool> {
        // volatile: the field is read outside the synchronized block in getInstance(), and without it
        // that unsynchronized read has no visibility guarantee for the newly constructed instance.
        private static volatile Factory INSTANCE;

        /**
         * @return the shared factory instance, created on first use
         */
        public static Factory getInstance() {
            if (INSTANCE != null) {
                return INSTANCE;
            }
            synchronized (NeuralSparseSearchTool.class) {
                if (INSTANCE != null) {
                    return INSTANCE;
                }
                INSTANCE = new Factory();
                return INSTANCE;
            }
        }

        /**
         * Creates a tool from its configuration.
         *
         * @param params configuration keyed by {@code index}, {@code embedding_field},
         *               {@code source_field} (a JSON array of field names), {@code model_id} and
         *               the optional {@code doc_size}, which may be a number or a numeric string
         * @return the configured tool
         * @throws NumberFormatException if {@code doc_size} is a string that is not an integer
         */
        @Override
        public NeuralSparseSearchTool create(Map<String, Object> params) {
            String index = (String) params.get(INDEX_FIELD);
            String embeddingField = (String) params.get(EMBEDDING_FIELD);
            String[] sourceFields = gson.fromJson((String) params.get(SOURCE_FIELD), String[].class);
            String modelId = (String) params.get(MODEL_ID_FIELD);

            // The map values are Objects: accept a numeric doc_size as well as its string form, and
            // treat an absent or null entry as "use the default" instead of failing on a cast/parse.
            Object rawDocSize = params.get(DOC_SIZE_FIELD);
            Integer docSize;
            if (rawDocSize == null) {
                docSize = DEFAULT_DOC_SIZE;
            } else if (rawDocSize instanceof Number) {
                docSize = ((Number) rawDocSize).intValue();
            } else {
                docSize = Integer.parseInt(rawDocSize.toString());
            }

            return NeuralSparseSearchTool
                .builder()
                .client(client)
                .xContentRegistry(xContentRegistry)
                .index(index)
                .embeddingField(embeddingField)
                .sourceFields(sourceFields)
                .modelId(modelId)
                .docSize(docSize)
                .build();
        }
    }
}
Loading

0 comments on commit c6e18c6

Please sign in to comment.