diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 3183a959..87f9899c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -38,7 +38,7 @@ jobs: java: - 11 - 17 - - 21 + - 21.0.1 name: Build and Test skills plugin on Linux runs-on: ubuntu-latest container: @@ -98,7 +98,7 @@ jobs: java: - 11 - 17 - - 21 + - 21.0.1 name: Build and Test skills plugin on Windows needs: Get-CI-Image-Tag runs-on: windows-latest diff --git a/src/main/java/org/opensearch/agent/tools/RAGTool.java b/src/main/java/org/opensearch/agent/tools/RAGTool.java index e3670bd0..6c341b05 100644 --- a/src/main/java/org/opensearch/agent/tools/RAGTool.java +++ b/src/main/java/org/opensearch/agent/tools/RAGTool.java @@ -6,7 +6,7 @@ package org.opensearch.agent.tools; import static org.apache.commons.lang3.StringEscapeUtils.escapeJson; -import static org.opensearch.agent.tools.VectorDBTool.DEFAULT_K; +import static org.opensearch.agent.tools.AbstractRetrieverTool.*; import static org.opensearch.ml.common.utils.StringUtils.gson; import static org.opensearch.ml.common.utils.StringUtils.toJson; @@ -21,10 +21,10 @@ import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; -import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.ml.common.spi.tools.Parser; +import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.spi.tools.ToolAnnotation; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; @@ -44,25 +44,28 @@ @Setter @Getter @ToolAnnotation(RAGTool.TYPE) -public class RAGTool extends AbstractRetrieverTool { +public class RAGTool implements Tool { public static final String TYPE = "RAGTool"; public static String DEFAULT_DESCRIPTION = "Use this tool to retrieve helpful information to optimize the output of the large language model to answer questions."; public static final String INFERENCE_MODEL_ID_FIELD = "inference_model_id"; public static final String EMBEDDING_MODEL_ID_FIELD = "embedding_model_id"; + public static final String INDEX_FIELD = "index"; + public static final String SOURCE_FIELD = "source_field"; + public static final String DOC_SIZE_FIELD = "doc_size"; public static final String EMBEDDING_FIELD = "embedding_field"; public static final String OUTPUT_FIELD = "output_field"; + public static final String QUERY_TYPE = "query_type"; + public static final String CONTENT_GENERATION_FIELD = "enable_Content_Generation"; + public static final String K_FIELD = "k"; + private final AbstractRetrieverTool queryTool; private String name = TYPE; private String description = DEFAULT_DESCRIPTION; private Client client; private String inferenceModelId; + private Boolean enableContentGeneration; private NamedXContentRegistry xContentRegistry; - private String index; - private String embeddingField; - private String[] sourceFields; - private String embeddingModelId; - private Integer docSize; - private Integer k; + private String queryType; @Setter private Parser inputParser; @Setter @@ -72,24 +75,15 @@ public class RAGTool extends AbstractRetrieverTool { public RAGTool( Client client, NamedXContentRegistry xContentRegistry, - String index, - String embeddingField, - String[] sourceFields, - Integer k, - Integer docSize, - String embeddingModelId, - String inferenceModelId + String inferenceModelId, + Boolean enableContentGeneration, + AbstractRetrieverTool queryTool ) { - super(client, xContentRegistry, index, sourceFields, docSize); this.client = client; this.xContentRegistry = xContentRegistry; - this.index = index; - this.embeddingField = embeddingField; - this.sourceFields = sourceFields; - this.embeddingModelId = embeddingModelId; - this.docSize = docSize == null ? DEFAULT_DOC_SIZE : docSize; - this.k = k == null ? DEFAULT_K : k; this.inferenceModelId = inferenceModelId; + this.enableContentGeneration = enableContentGeneration; + this.queryTool = queryTool; outputParser = new Parser() { @Override public Object parse(Object o) { @@ -99,13 +93,6 @@ public Object parse(Object o) { }; } - // getQueryBody is not used in RAGTool - @Override - protected String getQueryBody(String queryText) { - return queryText; - } - - @Override public void run(Map parameters, ActionListener listener) { String input = null; @@ -121,22 +108,14 @@ public void run(Map parameters, ActionListener listener) return; } - Map params = new HashMap<>(); - VectorDBTool.Factory.getInstance().init(client, xContentRegistry); - params.put(VectorDBTool.INDEX_FIELD, this.index); - params.put(VectorDBTool.EMBEDDING_FIELD, this.embeddingField); - params.put(VectorDBTool.SOURCE_FIELD, gson.toJson(this.sourceFields)); - params.put(VectorDBTool.MODEL_ID_FIELD, this.embeddingModelId); - params.put(VectorDBTool.DOC_SIZE_FIELD, String.valueOf(this.docSize)); - params.put(VectorDBTool.K_FIELD, String.valueOf(this.k)); - VectorDBTool vectorDBTool = VectorDBTool.Factory.getInstance().create(params); - String embeddingInput = input; ActionListener actionListener = ActionListener.wrap(r -> { - T vectorDBToolOutput; - + T queryToolOutput; + if (!this.enableContentGeneration) { + listener.onResponse(r); + } if (r.equals("Can not get any match from search result.")) { - vectorDBToolOutput = (T) ""; + queryToolOutput = (T) ""; } else { Gson gson = new Gson(); String[] hits = r.toString().split("\n"); @@ -151,31 +130,21 @@ public void run(Map parameters, ActionListener listener) resultBuilder.append("_source: ").append(source.toString()).append("\n"); } - vectorDBToolOutput = (T) gson.toJson(resultBuilder.toString()); + queryToolOutput = (T) gson.toJson(resultBuilder.toString()); } Map tmpParameters = new HashMap<>(); tmpParameters.putAll(parameters); - if (vectorDBToolOutput instanceof List - && !((List) vectorDBToolOutput).isEmpty() - && ((List) vectorDBToolOutput).get(0) instanceof ModelTensors) { - ModelTensors tensors = (ModelTensors) ((List) vectorDBToolOutput).get(0); - Object response = tensors.getMlModelTensors().get(0).getDataAsMap().get("response"); - tmpParameters.put(OUTPUT_FIELD, response + ""); - } else if (vectorDBToolOutput instanceof ModelTensor) { - tmpParameters.put(OUTPUT_FIELD, escapeJson(toJson(((ModelTensor) vectorDBToolOutput).getDataAsMap()))); + if (queryToolOutput instanceof String) { + tmpParameters.put(OUTPUT_FIELD, (String) queryToolOutput); } else { - if (vectorDBToolOutput instanceof String) { - tmpParameters.put(OUTPUT_FIELD, (String) vectorDBToolOutput); - } else { - tmpParameters.put(OUTPUT_FIELD, escapeJson(toJson(vectorDBToolOutput.toString()))); - } + tmpParameters.put(OUTPUT_FIELD, escapeJson(toJson(queryToolOutput.toString()))); } RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(tmpParameters).build(); MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build(); - ActionRequest request = new MLPredictionTaskRequest(inferenceModelId, mlInput, null); + ActionRequest request = new MLPredictionTaskRequest(this.inferenceModelId, mlInput, null); client.execute(MLPredictionTaskAction.INSTANCE, request, ActionListener.wrap(resp -> { ModelTensorOutput modelTensorOutput = (ModelTensorOutput) resp.getOutput(); @@ -186,33 +155,33 @@ public void run(Map parameters, ActionListener listener) listener.onResponse((T) outputParser.parse(modelTensorOutput.getMlModelOutputs())); } }, e -> { - log.error("Failed to run model " + inferenceModelId, e); + log.error("Failed to run model " + this.inferenceModelId, e); listener.onFailure(e); })); }, e -> { log.error("Failed to search index.", e); listener.onFailure(e); }); - vectorDBTool.run(Map.of(VectorDBTool.INPUT_FIELD, embeddingInput), actionListener); - + this.queryTool.run(Map.of(INPUT_FIELD, embeddingInput), actionListener); } - @Override public String getType() { return TYPE; } @Override + public String getVersion() { + return null; + } + public String getName() { return this.name; } - @Override public void setName(String s) { this.name = s; } - @Override public boolean validate(Map parameters) { if (parameters == null || parameters.size() == 0) { return false; @@ -224,7 +193,7 @@ public boolean validate(Map parameters) { /** * Factory class to create RAGTool */ - public static class Factory extends AbstractRetrieverTool.Factory { + public static class Factory implements Tool.Factory { private Client client; private NamedXContentRegistry xContentRegistry; @@ -250,23 +219,40 @@ public void init(Client client, NamedXContentRegistry xContentRegistry) { @Override public RAGTool create(Map params) { + String queryType = params.containsKey(QUERY_TYPE) ? (String) params.get(QUERY_TYPE) : "neural"; String embeddingModelId = (String) params.get(EMBEDDING_MODEL_ID_FIELD); - String index = (String) params.get(INDEX_FIELD); - String embeddingField = (String) params.get(EMBEDDING_FIELD); - String[] sourceFields = gson.fromJson((String) params.get(SOURCE_FIELD), String[].class); - String inferenceModelId = (String) params.get(INFERENCE_MODEL_ID_FIELD); - Integer docSize = params.containsKey(DOC_SIZE_FIELD) ? Integer.parseInt((String) params.get(DOC_SIZE_FIELD)) : 2; - return RAGTool - .builder() - .client(client) - .xContentRegistry(xContentRegistry) - .index(index) - .embeddingField(embeddingField) - .sourceFields(sourceFields) - .embeddingModelId(embeddingModelId) - .docSize(docSize) - .inferenceModelId(inferenceModelId) - .build(); + Boolean enableContentGeneration = params.containsKey(CONTENT_GENERATION_FIELD) + ? Boolean.parseBoolean((String) params.get(CONTENT_GENERATION_FIELD)) + : true; + String inferenceModelId = enableContentGeneration ? (String) params.get(INFERENCE_MODEL_ID_FIELD) : ""; + switch (queryType) { + case "neural_sparse": + params.put(NeuralSparseSearchTool.MODEL_ID_FIELD, embeddingModelId); + NeuralSparseSearchTool neuralSparseSearchTool = NeuralSparseSearchTool.Factory.getInstance().create(params); + return RAGTool + .builder() + .client(client) + .xContentRegistry(xContentRegistry) + .inferenceModelId(inferenceModelId) + .enableContentGeneration(enableContentGeneration) + .queryTool(neuralSparseSearchTool) + .build(); + case "neural": + params.put(VectorDBTool.MODEL_ID_FIELD, embeddingModelId); + VectorDBTool vectorDBTool = VectorDBTool.Factory.getInstance().create(params); + return RAGTool + .builder() + .client(client) + .xContentRegistry(xContentRegistry) + .inferenceModelId(inferenceModelId) + .enableContentGeneration(enableContentGeneration) + .queryTool(vectorDBTool) + .build(); + default: + log.error("Failed to read queryType, please input neural_sparse or neural."); + throw new IllegalArgumentException("Failed to read queryType, please input neural_sparse or neural."); + } + } @Override diff --git a/src/main/java/org/opensearch/agent/tools/VectorDBTool.java b/src/main/java/org/opensearch/agent/tools/VectorDBTool.java index dfbbed26..dd83cb46 100644 --- a/src/main/java/org/opensearch/agent/tools/VectorDBTool.java +++ b/src/main/java/org/opensearch/agent/tools/VectorDBTool.java @@ -128,5 +128,10 @@ public String getDefaultType() { public String getDefaultVersion() { return null; } + + @Override + public String getDefaultDescription() { + return DEFAULT_DESCRIPTION; + } } } diff --git a/src/test/java/org/opensearch/agent/tools/RAGToolTests.java b/src/test/java/org/opensearch/agent/tools/RAGToolTests.java index 8ef43468..4696c12c 100644 --- a/src/test/java/org/opensearch/agent/tools/RAGToolTests.java +++ b/src/test/java/org/opensearch/agent/tools/RAGToolTests.java @@ -53,20 +53,14 @@ public class RAGToolTests { public static final String TEST_EMBEDDING_FIELD = "test_embedding"; public static final String TEST_EMBEDDING_MODEL_ID = "1234"; public static final String TEST_INFERENCE_MODEL_ID = "1234"; + public static final String TEST_NEURAL_QUERY_TYPE = "neural"; + public static final String TEST_NEURAL_SPARSE_QUERY_TYPE = "neural_sparse"; - public static final String TEST_NEURAL_QUERY = "{\"query\":{\"neural\":{\"" - + TEST_EMBEDDING_FIELD - + "\":{\"query_text\":\"" - + TEST_QUERY_TEXT - + "\",\"model_id\":\"" - + TEST_EMBEDDING_MODEL_ID - + "\",\"k\":" - + DEFAULT_K - + "}}}" - + " }";; + static public final NamedXContentRegistry TEST_XCONTENT_REGISTRY_FOR_NEURAL_QUERY = getQueryNamedXContentRegistry(); private RAGTool ragTool; private String mockedSearchResponseString; private String mockedEmptySearchResponseString; + private String mockedNeuralSparseSearchResponseString; @Mock private Parser mockOutputParser; @Mock @@ -89,10 +83,16 @@ public void setup() { } } + try (InputStream searchResponseIns = AbstractRetrieverTool.class.getResourceAsStream("neural_sparse_tool_search_response.json")) { + if (searchResponseIns != null) { + mockedNeuralSparseSearchResponseString = new String(searchResponseIns.readAllBytes(), StandardCharsets.UTF_8); + } + } client = mock(Client.class); listener = mock(ActionListener.class); - RAGTool.Factory.getInstance().init(client, TEST_XCONTENT_REGISTRY_FOR_QUERY); - + RAGTool.Factory.getInstance().init(client, TEST_XCONTENT_REGISTRY_FOR_NEURAL_QUERY); + VectorDBTool.Factory.getInstance().init(client, TEST_XCONTENT_REGISTRY_FOR_NEURAL_QUERY); + NeuralSparseSearchTool.Factory.getInstance().init(client, TEST_XCONTENT_REGISTRY_FOR_NEURAL_QUERY); params = new HashMap<>(); params.put(RAGTool.INDEX_FIELD, TEST_INDEX); params.put(RAGTool.EMBEDDING_FIELD, TEST_EMBEDDING_FIELD); @@ -100,7 +100,9 @@ public void setup() { params.put(RAGTool.EMBEDDING_MODEL_ID_FIELD, TEST_EMBEDDING_MODEL_ID); params.put(RAGTool.INFERENCE_MODEL_ID_FIELD, TEST_INFERENCE_MODEL_ID); params.put(RAGTool.DOC_SIZE_FIELD, AbstractRetrieverToolTests.TEST_DOC_SIZE.toString()); - params.put(VectorDBTool.K_FIELD, DEFAULT_K); + params.put(RAGTool.K_FIELD, DEFAULT_K.toString()); + params.put(RAGTool.QUERY_TYPE, TEST_NEURAL_QUERY_TYPE); + params.put(RAGTool.CONTENT_GENERATION_FIELD, "true"); ragTool = RAGTool.Factory.getInstance().create(params); } @@ -118,12 +120,6 @@ public void testValidate() { public void testGetAttributes() { assertEquals(ragTool.getVersion(), null); assertEquals(ragTool.getType(), RAGTool.TYPE); - assertEquals(ragTool.getIndex(), TEST_INDEX); - assertEquals(ragTool.getDocSize(), TEST_DOC_SIZE); - assertEquals(ragTool.getSourceFields(), TEST_SOURCE_FIELDS); - assertEquals(ragTool.getEmbeddingField(), TEST_EMBEDDING_FIELD); - assertEquals(ragTool.getEmbeddingModelId(), TEST_EMBEDDING_MODEL_ID); - assertEquals(ragTool.getK(), DEFAULT_K); assertEquals(ragTool.getInferenceModelId(), TEST_INFERENCE_MODEL_ID); } @@ -134,15 +130,10 @@ public void testSetName() { assertEquals(ragTool.getName(), "test-tool"); } - @Test - public void testGetQueryBodySuccess() { - assertEquals(ragTool.getQueryBody(TEST_QUERY_TEXT), TEST_QUERY_TEXT); - } - @Test public void testOutputParser() throws IOException { - NamedXContentRegistry mockNamedXContentRegistry = getNeuralQueryNamedXContentRegistry(); + NamedXContentRegistry mockNamedXContentRegistry = getQueryNamedXContentRegistry(); ragTool.setXContentRegistry(mockNamedXContentRegistry); ModelTensorOutput mlModelTensorOutput = getMlModelTensorOutput(); @@ -167,7 +158,7 @@ public void testOutputParser() throws IOException { }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(), any()); ragTool.setOutputParser(mockOutputParser); - ragTool.run(Map.of(INPUT_FIELD, "hello?"), listener); + ragTool.run(Map.of(INPUT_FIELD, TEST_QUERY_TEXT), listener); verify(client).search(any(), any()); verify(client).execute(any(), any(), any()); @@ -175,7 +166,7 @@ public void testOutputParser() throws IOException { @Test public void testRunWithEmptySearchResponse() throws IOException { - NamedXContentRegistry mockNamedXContentRegistry = getNeuralQueryNamedXContentRegistry(); + NamedXContentRegistry mockNamedXContentRegistry = getQueryNamedXContentRegistry(); ragTool.setXContentRegistry(mockNamedXContentRegistry); ModelTensorOutput mlModelTensorOutput = getMlModelTensorOutput(); @@ -198,14 +189,68 @@ public void testRunWithEmptySearchResponse() throws IOException { actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); return null; }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(), any()); - ragTool.run(Map.of(INPUT_FIELD, "hello?"), listener); + ragTool.run(Map.of(INPUT_FIELD, TEST_QUERY_TEXT), listener); + verify(client).search(any(), any()); + verify(client).execute(any(), any(), any()); + } + + @Test + public void testRunWithNeuralSparseQueryType() throws IOException { + + Map paramsWithNeuralSparse = new HashMap<>(params); + paramsWithNeuralSparse.put(RAGTool.QUERY_TYPE, TEST_NEURAL_SPARSE_QUERY_TYPE); + + RAGTool rAGtoolWithNeuralSparseQuery = RAGTool.Factory.getInstance().create(paramsWithNeuralSparse); + + NamedXContentRegistry mockNamedXContentRegistry = getQueryNamedXContentRegistry(); + rAGtoolWithNeuralSparseQuery.setXContentRegistry(mockNamedXContentRegistry); + + ModelTensorOutput mlModelTensorOutput = getMlModelTensorOutput(); + SearchResponse mockedNeuralSparseSearchResponse = SearchResponse + .fromXContent( + JsonXContent.jsonXContent + .createParser( + NamedXContentRegistry.EMPTY, + DeprecationHandler.IGNORE_DEPRECATIONS, + mockedNeuralSparseSearchResponseString + ) + ); + + doAnswer(invocation -> { + SearchRequest searchRequest = invocation.getArgument(0); + assertEquals((long) TEST_DOC_SIZE, (long) searchRequest.source().size()); + ActionListener listener = invocation.getArgument(1); + listener.onResponse(mockedNeuralSparseSearchResponse); + return null; + }).when(client).search(any(), any()); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(), any()); + rAGtoolWithNeuralSparseQuery.run(Map.of(INPUT_FIELD, TEST_QUERY_TEXT), listener); verify(client).search(any(), any()); verify(client).execute(any(), any(), any()); } + @Test + public void testRunWithInvalidQueryType() throws IOException { + + RAGTool.Factory.getInstance().init(client, TEST_XCONTENT_REGISTRY_FOR_NEURAL_QUERY); + Map paramsWithInvalidQueryType = new HashMap<>(params); + paramsWithInvalidQueryType.put(RAGTool.QUERY_TYPE, "sparse"); + try { + RAGTool rAGtoolWithInvalidQueryType = RAGTool.Factory.getInstance().create(paramsWithInvalidQueryType); + } catch (IllegalArgumentException e) { + assertEquals("Failed to read queryType, please input neural_sparse or neural.", e.getMessage()); + } + + } + @Test public void testRunWithQuestionJson() throws IOException { - NamedXContentRegistry mockNamedXContentRegistry = getNeuralQueryNamedXContentRegistry(); + NamedXContentRegistry mockNamedXContentRegistry = getQueryNamedXContentRegistry(); ragTool.setXContentRegistry(mockNamedXContentRegistry); ModelTensorOutput mlModelTensorOutput = getMlModelTensorOutput(); @@ -233,10 +278,84 @@ public void testRunWithQuestionJson() throws IOException { verify(client).execute(any(), any(), any()); } + @Test + public void testRunEmptyResponseWithNotEnableContentGeneration() throws IOException { + ActionListener mockListener = mock(ActionListener.class); + Map paramsWithNotEnableContentGeneration = new HashMap<>(params); + paramsWithNotEnableContentGeneration.put(RAGTool.CONTENT_GENERATION_FIELD, "false"); + + RAGTool rAGtoolWithNotEnableContentGeneration = RAGTool.Factory.getInstance().create(paramsWithNotEnableContentGeneration); + + NamedXContentRegistry mockNamedXContentRegistry = getQueryNamedXContentRegistry(); + rAGtoolWithNotEnableContentGeneration.setXContentRegistry(mockNamedXContentRegistry); + + SearchResponse mockedEmptySearchResponse = SearchResponse + .fromXContent( + JsonXContent.jsonXContent + .createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.IGNORE_DEPRECATIONS, mockedEmptySearchResponseString) + ); + + doAnswer(invocation -> { + SearchRequest searchRequest = invocation.getArgument(0); + assertEquals((long) TEST_DOC_SIZE, (long) searchRequest.source().size()); + ActionListener listener = invocation.getArgument(1); + listener.onResponse(mockedEmptySearchResponse); + return null; + }).when(client).search(any(), any()); + rAGtoolWithNotEnableContentGeneration.run(Map.of(INPUT_FIELD, "{question:'what is the population in seattle?'}"), mockListener); + + verify(client).search(any(), any()); + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(String.class); + verify(mockListener).onResponse(responseCaptor.capture()); + assertEquals("Can not get any match from search result.", responseCaptor.getValue()); + + } + + @Test + public void testRunResponseWithNotEnableContentGeneration() throws IOException { + ActionListener mockListener = mock(ActionListener.class); + Map paramsWithNotEnableContentGeneration = new HashMap<>(params); + paramsWithNotEnableContentGeneration.put(RAGTool.CONTENT_GENERATION_FIELD, "false"); + + RAGTool rAGtoolWithNotEnableContentGeneration = RAGTool.Factory.getInstance().create(paramsWithNotEnableContentGeneration); + + NamedXContentRegistry mockNamedXContentRegistry = getQueryNamedXContentRegistry(); + rAGtoolWithNotEnableContentGeneration.setXContentRegistry(mockNamedXContentRegistry); + + SearchResponse mockedNeuralSparseSearchResponse = SearchResponse + .fromXContent( + JsonXContent.jsonXContent + .createParser( + NamedXContentRegistry.EMPTY, + DeprecationHandler.IGNORE_DEPRECATIONS, + mockedNeuralSparseSearchResponseString + ) + ); + + doAnswer(invocation -> { + SearchRequest searchRequest = invocation.getArgument(0); + assertEquals((long) TEST_DOC_SIZE, (long) searchRequest.source().size()); + ActionListener listener = invocation.getArgument(1); + listener.onResponse(mockedNeuralSparseSearchResponse); + return null; + }).when(client).search(any(), any()); + rAGtoolWithNotEnableContentGeneration.run(Map.of(INPUT_FIELD, "{question:'what is the population in seattle?'}"), mockListener); + + verify(client).search(any(), any()); + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(String.class); + verify(mockListener).onResponse(responseCaptor.capture()); + assertEquals( + "{\"_index\":\"my-nlp-index\",\"_source\":{\"passage_text\":\"Hello world\",\"passage_embedding\":{\"!\":0.8708904,\"door\":0.8587369,\"hi\":2.3929274,\"worlds\":2.7839446,\"yes\":0.75845814,\"##world\":2.5432441,\"born\":0.2682308,\"nothing\":0.8625516,\"goodbye\":0.17146169,\"greeting\":0.96817183,\"birth\":1.2788506,\"come\":0.1623208,\"global\":0.4371151,\"it\":0.42951578,\"life\":1.5750692,\"thanks\":0.26481047,\"world\":4.7300377,\"tiny\":0.5462298,\"earth\":2.6555297,\"universe\":2.0308156,\"worldwide\":1.3903781,\"hello\":6.696973,\"so\":0.20279501,\"?\":0.67785245},\"id\":\"s1\"},\"_id\":\"1\",\"_score\":30.0029}\n" + + "{\"_index\":\"my-nlp-index\",\"_source\":{\"passage_text\":\"Hi planet\",\"passage_embedding\":{\"hi\":4.338913,\"planets\":2.7755864,\"planet\":5.0969057,\"mars\":1.7405145,\"earth\":2.6087382,\"hello\":3.3210192},\"id\":\"s2\"},\"_id\":\"2\",\"_score\":16.480486}\n", + responseCaptor.getValue() + ); + + } + @Test @SneakyThrows public void testRunWithRuntimeExceptionDuringSearch() { - NamedXContentRegistry mockNamedXContentRegistry = getNeuralQueryNamedXContentRegistry(); + NamedXContentRegistry mockNamedXContentRegistry = getQueryNamedXContentRegistry(); ragTool.setXContentRegistry(mockNamedXContentRegistry); doAnswer(invocation -> { SearchRequest searchRequest = invocation.getArgument(0); @@ -245,7 +364,7 @@ public void testRunWithRuntimeExceptionDuringSearch() { actionListener.onFailure(new RuntimeException("Failed to search index")); return null; }).when(client).search(any(), any()); - ragTool.run(Map.of(INPUT_FIELD, "hello?"), listener); + ragTool.run(Map.of(INPUT_FIELD, TEST_QUERY_TEXT), listener); verify(listener).onFailure(any(RuntimeException.class)); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(listener).onFailure(argumentCaptor.capture()); @@ -255,7 +374,7 @@ public void testRunWithRuntimeExceptionDuringSearch() { @Test @SneakyThrows public void testRunWithRuntimeExceptionDuringExecute() { - NamedXContentRegistry mockNamedXContentRegistry = getNeuralQueryNamedXContentRegistry(); + NamedXContentRegistry mockNamedXContentRegistry = getQueryNamedXContentRegistry(); ragTool.setXContentRegistry(mockNamedXContentRegistry); SearchResponse mockedSearchResponse = SearchResponse @@ -278,7 +397,7 @@ public void testRunWithRuntimeExceptionDuringExecute() { return null; }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(), any()); - ragTool.run(Map.of(INPUT_FIELD, "hello?"), listener); + ragTool.run(Map.of(INPUT_FIELD, TEST_QUERY_TEXT), listener); verify(listener).onFailure(any(RuntimeException.class)); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(listener).onFailure(argumentCaptor.capture()); @@ -292,50 +411,83 @@ public void testRunWithEmptyInput() { } @Test - public void testFactory() { + public void testFactoryNeuralQuery() { RAGTool.Factory factoryMock = new RAGTool.Factory(); - RAGTool.Factory.getInstance().init(client, TEST_XCONTENT_REGISTRY_FOR_QUERY); - factoryMock.init(client, TEST_XCONTENT_REGISTRY_FOR_QUERY); + RAGTool.Factory.getInstance().init(client, TEST_XCONTENT_REGISTRY_FOR_NEURAL_QUERY); + factoryMock.init(client, TEST_XCONTENT_REGISTRY_FOR_NEURAL_QUERY); String defaultDescription = factoryMock.getDefaultDescription(); assertEquals(RAGTool.DEFAULT_DESCRIPTION, defaultDescription); + assertEquals(factoryMock.getDefaultType(), RAGTool.TYPE); + assertEquals(factoryMock.getDefaultVersion(), null); assertNotNull(RAGTool.Factory.getInstance()); + RAGTool rAGtool1 = factoryMock.create(params); + VectorDBTool.Factory.getInstance().init(client, TEST_XCONTENT_REGISTRY_FOR_NEURAL_QUERY); + params.put(VectorDBTool.MODEL_ID_FIELD, TEST_EMBEDDING_MODEL_ID); + VectorDBTool queryTool = VectorDBTool.Factory.getInstance().create(params); + RAGTool rAGtool2 = new RAGTool(client, TEST_XCONTENT_REGISTRY_FOR_NEURAL_QUERY, TEST_INFERENCE_MODEL_ID, true, queryTool); - RAGTool rAGtool2 = new RAGTool( - client, - TEST_XCONTENT_REGISTRY_FOR_QUERY, - TEST_INDEX, - TEST_EMBEDDING_FIELD, - TEST_SOURCE_FIELDS, - DEFAULT_K, - TEST_DOC_SIZE, - TEST_EMBEDDING_MODEL_ID, - TEST_INFERENCE_MODEL_ID - ); + assertEquals(rAGtool1.getClient(), rAGtool2.getClient()); + assertEquals(rAGtool1.getInferenceModelId(), rAGtool2.getInferenceModelId()); + assertEquals(rAGtool1.getName(), rAGtool2.getName()); + assertEquals(rAGtool1.getQueryTool().getDocSize(), rAGtool2.getQueryTool().getDocSize()); + assertEquals(rAGtool1.getQueryTool().getIndex(), rAGtool2.getQueryTool().getIndex()); + assertEquals(rAGtool1.getQueryTool().getSourceFields(), rAGtool2.getQueryTool().getSourceFields()); + assertEquals(rAGtool1.getXContentRegistry(), rAGtool2.getXContentRegistry()); + assertEquals(rAGtool1.getQueryType(), rAGtool2.getQueryType()); + } + + @Test + public void testFactoryNeuralSparseQuery() { + RAGTool.Factory factoryMock = new RAGTool.Factory(); + RAGTool.Factory.getInstance().init(client, TEST_XCONTENT_REGISTRY_FOR_NEURAL_QUERY); + factoryMock.init(client, TEST_XCONTENT_REGISTRY_FOR_NEURAL_QUERY); + + String defaultDescription = factoryMock.getDefaultDescription(); + assertEquals(RAGTool.DEFAULT_DESCRIPTION, defaultDescription); + assertNotNull(RAGTool.Factory.getInstance()); + assertEquals(factoryMock.getDefaultType(), RAGTool.TYPE); + assertEquals(factoryMock.getDefaultVersion(), null); + + RAGTool rAGtool1 = factoryMock.create(params); + NeuralSparseSearchTool.Factory.getInstance().init(client, TEST_XCONTENT_REGISTRY_FOR_NEURAL_QUERY); + NeuralSparseSearchTool queryTool = NeuralSparseSearchTool.Factory.getInstance().create(params); + RAGTool rAGtool2 = new RAGTool(client, TEST_XCONTENT_REGISTRY_FOR_NEURAL_QUERY, TEST_INFERENCE_MODEL_ID, true, queryTool); assertEquals(rAGtool1.getClient(), rAGtool2.getClient()); - assertEquals(rAGtool1.getK(), rAGtool2.getK()); assertEquals(rAGtool1.getInferenceModelId(), rAGtool2.getInferenceModelId()); assertEquals(rAGtool1.getName(), rAGtool2.getName()); - assertEquals(rAGtool1.getDocSize(), rAGtool2.getDocSize()); - assertEquals(rAGtool1.getIndex(), rAGtool2.getIndex()); - assertEquals(rAGtool1.getEmbeddingModelId(), rAGtool2.getEmbeddingModelId()); - assertEquals(rAGtool1.getEmbeddingField(), rAGtool2.getEmbeddingField()); - assertEquals(rAGtool1.getSourceFields(), rAGtool2.getSourceFields()); + assertEquals(rAGtool1.getQueryTool().getDocSize(), rAGtool2.getQueryTool().getDocSize()); + assertEquals(rAGtool1.getQueryTool().getIndex(), rAGtool2.getQueryTool().getIndex()); + assertEquals(rAGtool1.getQueryTool().getSourceFields(), rAGtool2.getQueryTool().getSourceFields()); assertEquals(rAGtool1.getXContentRegistry(), rAGtool2.getXContentRegistry()); + assertEquals(rAGtool1.getQueryType(), rAGtool2.getQueryType()); } - private static NamedXContentRegistry getNeuralQueryNamedXContentRegistry() { + private static NamedXContentRegistry getQueryNamedXContentRegistry() { QueryBuilder matchAllQueryBuilder = new MatchAllQueryBuilder(); List entries = new ArrayList<>(); - NamedXContentRegistry.Entry entry = new NamedXContentRegistry.Entry(QueryBuilder.class, new ParseField("neural"), (p, c) -> { - p.map(); - return matchAllQueryBuilder; - }); - entries.add(entry); + NamedXContentRegistry.Entry neural_query_entry = new NamedXContentRegistry.Entry( + QueryBuilder.class, + new ParseField("neural"), + (p, c) -> { + p.map(); + return matchAllQueryBuilder; + } + ); + entries.add(neural_query_entry); + NamedXContentRegistry.Entry neural_sparse_query_entry = new NamedXContentRegistry.Entry( + QueryBuilder.class, + new ParseField("neural_sparse"), + (p, c) -> { + p.map(); + return matchAllQueryBuilder; + } + ); + entries.add(neural_sparse_query_entry); NamedXContentRegistry mockNamedXContentRegistry = new NamedXContentRegistry(entries); return mockNamedXContentRegistry; } diff --git a/src/test/java/org/opensearch/agent/tools/VectorDBToolTests.java b/src/test/java/org/opensearch/agent/tools/VectorDBToolTests.java index cc67604f..cce80d5b 100644 --- a/src/test/java/org/opensearch/agent/tools/VectorDBToolTests.java +++ b/src/test/java/org/opensearch/agent/tools/VectorDBToolTests.java @@ -48,7 +48,7 @@ public void testCreateTool() { assertEquals(TEST_K, tool.getK()); assertEquals("VectorDBTool", tool.getType()); assertEquals("VectorDBTool", tool.getName()); - assertEquals("Use this tool to search data in OpenSearch index.", VectorDBTool.Factory.getInstance().getDefaultDescription()); + assertEquals(VectorDBTool.DEFAULT_DESCRIPTION, VectorDBTool.Factory.getInstance().getDefaultDescription()); } @Test diff --git a/src/test/resources/org/opensearch/agent/tools/neural_sparse_tool_search_response.json b/src/test/resources/org/opensearch/agent/tools/neural_sparse_tool_search_response.json new file mode 100644 index 00000000..196e8a04 --- /dev/null +++ b/src/test/resources/org/opensearch/agent/tools/neural_sparse_tool_search_response.json @@ -0,0 +1,71 @@ +{ + "took" : 688, + "timed_out" : false, + "_shards" : { + "total" : 1, + "successful" : 1, + "skipped" : 0, + "failed" : 0 + }, + "hits" : { + "total" : { + "value" : 2, + "relation" : "eq" + }, + "max_score" : 30.0029, + "hits" : [ + { + "_index" : "my-nlp-index", + "_id" : "1", + "_score" : 30.0029, + "_source" : { + "passage_text" : "Hello world", + "passage_embedding" : { + "!" : 0.8708904, + "door" : 0.8587369, + "hi" : 2.3929274, + "worlds" : 2.7839446, + "yes" : 0.75845814, + "##world" : 2.5432441, + "born" : 0.2682308, + "nothing" : 0.8625516, + "goodbye" : 0.17146169, + "greeting" : 0.96817183, + "birth" : 1.2788506, + "come" : 0.1623208, + "global" : 0.4371151, + "it" : 0.42951578, + "life" : 1.5750692, + "thanks" : 0.26481047, + "world" : 4.7300377, + "tiny" : 0.5462298, + "earth" : 2.6555297, + "universe" : 2.0308156, + "worldwide" : 1.3903781, + "hello" : 6.696973, + "so" : 0.20279501, + "?" : 0.67785245 + }, + "id" : "s1" + } + }, + { + "_index" : "my-nlp-index", + "_id" : "2", + "_score" : 16.480486, + "_source" : { + "passage_text" : "Hi planet", + "passage_embedding" : { + "hi" : 4.338913, + "planets" : 2.7755864, + "planet" : 5.0969057, + "mars" : 1.7405145, + "earth" : 2.6087382, + "hello" : 3.3210192 + }, + "id" : "s2" + } + } + ] + } +} \ No newline at end of file