Skip to content

Commit

Permalink
Enhance RagTool to choose neural sparse query type (opensearch-project#140) (opensearch-project#152)
Browse files Browse the repository at this point in the history

* enhance RagTool to choose neural sparse query type

* Work around JDK 21.0.2 bug impacting scaling executors

* Modify RAGTool factory create to initiate subtools

* Add enableContentGeneration to RAGTool

* Map embedding_model_id to model_id for NeuralSparseTool

* Map embedding_model_id to model_id for NeuralSparseTool

---------

(cherry picked from commit d7945d5)

Signed-off-by: Mingshi Liu <[email protected]>
Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Signed-off-by: yuye-aws <[email protected]>
  • Loading branch information
2 people authored and yuye-aws committed Apr 26, 2024
1 parent 8809bf5 commit 993d25c
Show file tree
Hide file tree
Showing 6 changed files with 357 additions and 143 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ jobs:
java:
- 11
- 17
- 21
- 21.0.1
name: Build and Test skills plugin on Linux
runs-on: ubuntu-latest
container:
Expand Down Expand Up @@ -98,7 +98,7 @@ jobs:
java:
- 11
- 17
- 21
- 21.0.1
name: Build and Test skills plugin on Windows
needs: Get-CI-Image-Tag
runs-on: windows-latest
Expand Down
148 changes: 67 additions & 81 deletions src/main/java/org/opensearch/agent/tools/RAGTool.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
package org.opensearch.agent.tools;

import static org.apache.commons.lang3.StringEscapeUtils.escapeJson;
import static org.opensearch.agent.tools.VectorDBTool.DEFAULT_K;
import static org.opensearch.agent.tools.AbstractRetrieverTool.*;
import static org.opensearch.ml.common.utils.StringUtils.gson;
import static org.opensearch.ml.common.utils.StringUtils.toJson;

Expand All @@ -21,10 +21,10 @@
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.spi.tools.Parser;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
Expand All @@ -44,25 +44,28 @@
@Setter
@Getter
@ToolAnnotation(RAGTool.TYPE)
public class RAGTool extends AbstractRetrieverTool {
public class RAGTool implements Tool {
public static final String TYPE = "RAGTool";
public static String DEFAULT_DESCRIPTION =
"Use this tool to retrieve helpful information to optimize the output of the large language model to answer questions.";
public static final String INFERENCE_MODEL_ID_FIELD = "inference_model_id";
public static final String EMBEDDING_MODEL_ID_FIELD = "embedding_model_id";
public static final String INDEX_FIELD = "index";
public static final String SOURCE_FIELD = "source_field";
public static final String DOC_SIZE_FIELD = "doc_size";
public static final String EMBEDDING_FIELD = "embedding_field";
public static final String OUTPUT_FIELD = "output_field";
public static final String QUERY_TYPE = "query_type";
public static final String CONTENT_GENERATION_FIELD = "enable_Content_Generation";
public static final String K_FIELD = "k";
private final AbstractRetrieverTool queryTool;
private String name = TYPE;
private String description = DEFAULT_DESCRIPTION;
private Client client;
private String inferenceModelId;
private Boolean enableContentGeneration;
private NamedXContentRegistry xContentRegistry;
private String index;
private String embeddingField;
private String[] sourceFields;
private String embeddingModelId;
private Integer docSize;
private Integer k;
private String queryType;
@Setter
private Parser inputParser;
@Setter
Expand All @@ -72,24 +75,15 @@ public class RAGTool extends AbstractRetrieverTool {
public RAGTool(
Client client,
NamedXContentRegistry xContentRegistry,
String index,
String embeddingField,
String[] sourceFields,
Integer k,
Integer docSize,
String embeddingModelId,
String inferenceModelId
String inferenceModelId,
Boolean enableContentGeneration,
AbstractRetrieverTool queryTool
) {
super(client, xContentRegistry, index, sourceFields, docSize);
this.client = client;
this.xContentRegistry = xContentRegistry;
this.index = index;
this.embeddingField = embeddingField;
this.sourceFields = sourceFields;
this.embeddingModelId = embeddingModelId;
this.docSize = docSize == null ? DEFAULT_DOC_SIZE : docSize;
this.k = k == null ? DEFAULT_K : k;
this.inferenceModelId = inferenceModelId;
this.enableContentGeneration = enableContentGeneration;
this.queryTool = queryTool;
outputParser = new Parser() {
@Override
public Object parse(Object o) {
Expand All @@ -99,13 +93,6 @@ public Object parse(Object o) {
};
}

// getQueryBody is not used in RAGTool
@Override
protected String getQueryBody(String queryText) {
return queryText;
}

@Override
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
String input = null;

Expand All @@ -121,22 +108,14 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)
return;
}

Map<String, Object> params = new HashMap<>();
VectorDBTool.Factory.getInstance().init(client, xContentRegistry);
params.put(VectorDBTool.INDEX_FIELD, this.index);
params.put(VectorDBTool.EMBEDDING_FIELD, this.embeddingField);
params.put(VectorDBTool.SOURCE_FIELD, gson.toJson(this.sourceFields));
params.put(VectorDBTool.MODEL_ID_FIELD, this.embeddingModelId);
params.put(VectorDBTool.DOC_SIZE_FIELD, String.valueOf(this.docSize));
params.put(VectorDBTool.K_FIELD, String.valueOf(this.k));
VectorDBTool vectorDBTool = VectorDBTool.Factory.getInstance().create(params);

String embeddingInput = input;
ActionListener actionListener = ActionListener.<T>wrap(r -> {
T vectorDBToolOutput;

T queryToolOutput;
if (!this.enableContentGeneration) {
listener.onResponse(r);
}
if (r.equals("Can not get any match from search result.")) {
vectorDBToolOutput = (T) "";
queryToolOutput = (T) "";
} else {
Gson gson = new Gson();
String[] hits = r.toString().split("\n");
Expand All @@ -151,31 +130,21 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)
resultBuilder.append("_source: ").append(source.toString()).append("\n");
}

vectorDBToolOutput = (T) gson.toJson(resultBuilder.toString());
queryToolOutput = (T) gson.toJson(resultBuilder.toString());
}

Map<String, String> tmpParameters = new HashMap<>();
tmpParameters.putAll(parameters);

if (vectorDBToolOutput instanceof List
&& !((List) vectorDBToolOutput).isEmpty()
&& ((List) vectorDBToolOutput).get(0) instanceof ModelTensors) {
ModelTensors tensors = (ModelTensors) ((List) vectorDBToolOutput).get(0);
Object response = tensors.getMlModelTensors().get(0).getDataAsMap().get("response");
tmpParameters.put(OUTPUT_FIELD, response + "");
} else if (vectorDBToolOutput instanceof ModelTensor) {
tmpParameters.put(OUTPUT_FIELD, escapeJson(toJson(((ModelTensor) vectorDBToolOutput).getDataAsMap())));
if (queryToolOutput instanceof String) {
tmpParameters.put(OUTPUT_FIELD, (String) queryToolOutput);
} else {
if (vectorDBToolOutput instanceof String) {
tmpParameters.put(OUTPUT_FIELD, (String) vectorDBToolOutput);
} else {
tmpParameters.put(OUTPUT_FIELD, escapeJson(toJson(vectorDBToolOutput.toString())));
}
tmpParameters.put(OUTPUT_FIELD, escapeJson(toJson(queryToolOutput.toString())));
}

RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(tmpParameters).build();
MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build();
ActionRequest request = new MLPredictionTaskRequest(inferenceModelId, mlInput, null);
ActionRequest request = new MLPredictionTaskRequest(this.inferenceModelId, mlInput, null);

client.execute(MLPredictionTaskAction.INSTANCE, request, ActionListener.wrap(resp -> {
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) resp.getOutput();
Expand All @@ -186,33 +155,33 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)
listener.onResponse((T) outputParser.parse(modelTensorOutput.getMlModelOutputs()));
}
}, e -> {
log.error("Failed to run model " + inferenceModelId, e);
log.error("Failed to run model " + this.inferenceModelId, e);
listener.onFailure(e);
}));
}, e -> {
log.error("Failed to search index.", e);
listener.onFailure(e);
});
vectorDBTool.run(Map.of(VectorDBTool.INPUT_FIELD, embeddingInput), actionListener);

this.queryTool.run(Map.of(INPUT_FIELD, embeddingInput), actionListener);
}

/**
 * Returns the type identifier of this tool.
 *
 * @return the {@code TYPE} constant declared on this class ({@code "RAGTool"})
 */
@Override
public String getType() {
return TYPE;
}

/**
 * Returns the version of this tool.
 *
 * @return always {@code null}; this implementation does not define a version string
 */
@Override
public String getVersion() {
return null;
}

/**
 * Returns the current name of this tool.
 *
 * @return the value of the {@code name} field
 */
public String getName() {
return this.name;
}

/**
 * Replaces the name of this tool.
 *
 * @param s the new name; stored as-is in the {@code name} field without validation
 */
@Override
public void setName(String s) {
this.name = s;
}

@Override
public boolean validate(Map<String, String> parameters) {
if (parameters == null || parameters.size() == 0) {
return false;
Expand All @@ -224,7 +193,7 @@ public boolean validate(Map<String, String> parameters) {
/**
* Factory class to create RAGTool
*/
public static class Factory extends AbstractRetrieverTool.Factory<RAGTool> {
public static class Factory implements Tool.Factory<RAGTool> {
private Client client;
private NamedXContentRegistry xContentRegistry;

Expand All @@ -250,23 +219,40 @@ public void init(Client client, NamedXContentRegistry xContentRegistry) {

@Override
public RAGTool create(Map<String, Object> params) {
String queryType = params.containsKey(QUERY_TYPE) ? (String) params.get(QUERY_TYPE) : "neural";
String embeddingModelId = (String) params.get(EMBEDDING_MODEL_ID_FIELD);
String index = (String) params.get(INDEX_FIELD);
String embeddingField = (String) params.get(EMBEDDING_FIELD);
String[] sourceFields = gson.fromJson((String) params.get(SOURCE_FIELD), String[].class);
String inferenceModelId = (String) params.get(INFERENCE_MODEL_ID_FIELD);
Integer docSize = params.containsKey(DOC_SIZE_FIELD) ? Integer.parseInt((String) params.get(DOC_SIZE_FIELD)) : 2;
return RAGTool
.builder()
.client(client)
.xContentRegistry(xContentRegistry)
.index(index)
.embeddingField(embeddingField)
.sourceFields(sourceFields)
.embeddingModelId(embeddingModelId)
.docSize(docSize)
.inferenceModelId(inferenceModelId)
.build();
Boolean enableContentGeneration = params.containsKey(CONTENT_GENERATION_FIELD)
? Boolean.parseBoolean((String) params.get(CONTENT_GENERATION_FIELD))
: true;
String inferenceModelId = enableContentGeneration ? (String) params.get(INFERENCE_MODEL_ID_FIELD) : "";
switch (queryType) {
case "neural_sparse":
params.put(NeuralSparseSearchTool.MODEL_ID_FIELD, embeddingModelId);
NeuralSparseSearchTool neuralSparseSearchTool = NeuralSparseSearchTool.Factory.getInstance().create(params);
return RAGTool
.builder()
.client(client)
.xContentRegistry(xContentRegistry)
.inferenceModelId(inferenceModelId)
.enableContentGeneration(enableContentGeneration)
.queryTool(neuralSparseSearchTool)
.build();
case "neural":
params.put(VectorDBTool.MODEL_ID_FIELD, embeddingModelId);
VectorDBTool vectorDBTool = VectorDBTool.Factory.getInstance().create(params);
return RAGTool
.builder()
.client(client)
.xContentRegistry(xContentRegistry)
.inferenceModelId(inferenceModelId)
.enableContentGeneration(enableContentGeneration)
.queryTool(vectorDBTool)
.build();
default:
log.error("Failed to read queryType, please input neural_sparse or neural.");
throw new IllegalArgumentException("Failed to read queryType, please input neural_sparse or neural.");
}

}

@Override
Expand Down
5 changes: 5 additions & 0 deletions src/main/java/org/opensearch/agent/tools/VectorDBTool.java
Original file line number Diff line number Diff line change
Expand Up @@ -128,5 +128,10 @@ public String getDefaultType() {
public String getDefaultVersion() {
// No default version is supplied by this factory; callers always receive null.
return null;
}

/**
 * Returns the default description for tools created by this factory.
 *
 * @return the {@code DEFAULT_DESCRIPTION} constant that is in scope for this factory
 */
@Override
public String getDefaultDescription() {
return DEFAULT_DESCRIPTION;
}
}
}
Loading

0 comments on commit 993d25c

Please sign in to comment.