diff --git a/build.gradle b/build.gradle index fba99bc2..49411109 100644 --- a/build.gradle +++ b/build.gradle @@ -76,7 +76,7 @@ configurations { zipArchive all { resolutionStrategy { - force "org.mockito:mockito-core:5.8.0" + force "org.mockito:mockito-core:${versions.mockito}" force "com.google.guava:guava:32.1.3-jre" // CVE for 31.1 force("org.eclipse.platform:org.eclipse.core.runtime:3.30.0") // CVE for < 3.29.0, forces JDK17 for spotless } @@ -107,8 +107,8 @@ dependencies { compileOnly group: 'com.google.code.gson', name: 'gson', version: '2.10.1' compileOnly "org.apache.logging.log4j:log4j-slf4j-impl:2.22.0" compileOnly group: 'org.json', name: 'json', version: '20231013' - implementation("com.google.guava:guava:32.1.3-jre") - implementation group: 'org.apache.commons', name: 'commons-lang3', version: '3.13.0' + compileOnly("com.google.guava:guava:32.1.3-jre") + compileOnly group: 'org.apache.commons', name: 'commons-lang3', version: '3.10' // Plugin dependencies compileOnly group: 'org.opensearch', name:'opensearch-ml-client', version: "${version}" @@ -128,8 +128,8 @@ dependencies { testImplementation group: 'junit', name: 'junit', version: '4.13.2' testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.8.0' testImplementation group: 'org.mockito', name: 'mockito-inline', version: '5.2.0' - testImplementation("net.bytebuddy:byte-buddy:1.14.7") - testImplementation("net.bytebuddy:byte-buddy-agent:1.14.7") + testImplementation("net.bytebuddy:byte-buddy:1.14.7") + testImplementation("net.bytebuddy:byte-buddy-agent:1.14.7") testImplementation 'org.junit.jupiter:junit-jupiter-api:5.10.1' testImplementation 'org.mockito:mockito-junit-jupiter:5.8.0' testImplementation "com.nhaarman.mockitokotlin2:mockito-kotlin:2.2.0" diff --git a/src/main/java/org/opensearch/agent/ToolPlugin.java b/src/main/java/org/opensearch/agent/ToolPlugin.java index eba2f6a1..8e3d0844 100644 --- a/src/main/java/org/opensearch/agent/ToolPlugin.java +++ 
b/src/main/java/org/opensearch/agent/ToolPlugin.java @@ -10,7 +10,9 @@ import java.util.List; import java.util.function.Supplier; +import org.opensearch.agent.tools.NeuralSparseSearchTool; import org.opensearch.agent.tools.PPLTool; +import org.opensearch.agent.tools.VectorDBTool; import org.opensearch.client.Client; import org.opensearch.cluster.metadata.IndexNameExpressionResolver; import org.opensearch.cluster.service.ClusterService; @@ -54,11 +56,13 @@ public Collection createComponents( this.xContentRegistry = xContentRegistry; PPLTool.Factory.getInstance().init(client); + NeuralSparseSearchTool.Factory.getInstance().init(client, xContentRegistry); + VectorDBTool.Factory.getInstance().init(client, xContentRegistry); return Collections.emptyList(); } @Override public List> getToolFactories() { - return List.of(PPLTool.Factory.getInstance()); + return List.of(PPLTool.Factory.getInstance(), NeuralSparseSearchTool.Factory.getInstance(), VectorDBTool.Factory.getInstance()); } } diff --git a/src/main/java/org/opensearch/agent/tools/AbstractRetrieverTool.java b/src/main/java/org/opensearch/agent/tools/AbstractRetrieverTool.java new file mode 100644 index 00000000..dba48070 --- /dev/null +++ b/src/main/java/org/opensearch/agent/tools/AbstractRetrieverTool.java @@ -0,0 +1,140 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools; + +import static org.opensearch.ml.common.utils.StringUtils.gson; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import org.apache.commons.lang3.StringUtils; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.Client; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import 
org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.search.SearchHit; +import org.opensearch.search.builder.SearchSourceBuilder; + +import lombok.Getter; +import lombok.Setter; +import lombok.extern.log4j.Log4j2; + +/** + * Abstract tool supports search paradigms in neural-search plugin. + */ +@Log4j2 +@Getter +@Setter +public abstract class AbstractRetrieverTool implements Tool { + public static final String DEFAULT_DESCRIPTION = "Use this tool to search data in OpenSearch index."; + public static final String INPUT_FIELD = "input"; + public static final String INDEX_FIELD = "index"; + public static final String SOURCE_FIELD = "source_field"; + public static final String DOC_SIZE_FIELD = "doc_size"; + public static final int DEFAULT_DOC_SIZE = 2; + + protected String description = DEFAULT_DESCRIPTION; + protected Client client; + protected NamedXContentRegistry xContentRegistry; + protected String index; + protected String[] sourceFields; + protected Integer docSize; + protected String version; + + protected AbstractRetrieverTool( + Client client, + NamedXContentRegistry xContentRegistry, + String index, + String[] sourceFields, + Integer docSize + ) { + this.client = client; + this.xContentRegistry = xContentRegistry; + this.index = index; + this.sourceFields = sourceFields; + this.docSize = docSize == null ? 
DEFAULT_DOC_SIZE : docSize;
+    }
+
+    protected abstract String getQueryBody(String queryText); // returns the search source JSON for the given text
+
+    private SearchRequest buildSearchRequest(Map parameters) throws IOException {
+        String question = parameters.get(INPUT_FIELD);
+        if (StringUtils.isBlank(question)) {
+            throw new IllegalArgumentException("[" + INPUT_FIELD + "] is null or empty, can not process it.");
+        }
+
+        String query = getQueryBody(question);
+        SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
+        XContentParser queryParser = XContentType.JSON.xContent().createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, query);
+        try (queryParser) { // the parser is Closeable: release it even when parsing throws
+            searchSourceBuilder.parseXContent(queryParser);
+        }
+        searchSourceBuilder.fetchSource(sourceFields, null).size(docSize); // applied after parsing so they override the parsed body
+        return new SearchRequest().source(searchSourceBuilder).indices(index);
+    }
+
+    @Override
+    public void run(Map parameters, ActionListener listener) {
+        SearchRequest searchRequest;
+        try {
+            searchRequest = buildSearchRequest(parameters);
+        } catch (Exception e) { // any failure while building the request is reported through the listener, not thrown
+            log.error("Failed to build search request.", e);
+            listener.onFailure(e);
+            return;
+        }
+
+        ActionListener actionListener = ActionListener.wrap(r -> {
+            SearchHit[] hits = r.getHits().getHits();
+
+            if (hits != null && hits.length > 0) {
+                StringBuilder contextBuilder = new StringBuilder();
+                for (int i = 0; i < hits.length; i++) {
+                    SearchHit hit = hits[i];
+                    Map docContent = new HashMap<>();
+                    docContent.put("_index", hit.getIndex());
+                    docContent.put("_id", hit.getId());
+                    docContent.put("_score", hit.getScore());
+                    docContent.put("_source", hit.getSourceAsMap());
+                    contextBuilder.append(gson.toJson(docContent)).append("\n"); // one JSON document per line
+                }
+                listener.onResponse((T) contextBuilder.toString());
+            } else {
+                listener.onResponse((T) "Can not get any match from search result.");
+            }
+        }, e -> {
+            log.error("Failed to search index.", e);
+            listener.onFailure(e);
+        });
+        client.search(searchRequest, actionListener);
+    }
+
+    
@Override
+    public boolean validate(Map parameters) {
+        return parameters != null && parameters.size() > 0 && !StringUtils.isBlank(parameters.get(INPUT_FIELD));
+    }
+
+    protected static abstract class Factory implements Tool.Factory {
+        protected Client client;
+        protected NamedXContentRegistry xContentRegistry;
+
+        public void init(Client client, NamedXContentRegistry xContentRegistry) {
+            this.client = client;
+            this.xContentRegistry = xContentRegistry;
+        }
+
+        @Override
+        public String getDefaultDescription() {
+            return DEFAULT_DESCRIPTION;
+        }
+    }
+}
diff --git a/src/main/java/org/opensearch/agent/tools/NeuralSparseSearchTool.java b/src/main/java/org/opensearch/agent/tools/NeuralSparseSearchTool.java
new file mode 100644
index 00000000..40c57aba
--- /dev/null
+++ b/src/main/java/org/opensearch/agent/tools/NeuralSparseSearchTool.java
@@ -0,0 +1,131 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.agent.tools;
+
+import static org.opensearch.ml.common.utils.StringUtils.gson;
+
+import java.util.Map;
+
+import org.apache.commons.lang3.StringUtils;
+import org.opensearch.client.Client;
+import org.opensearch.core.xcontent.NamedXContentRegistry;
+import org.opensearch.ml.common.spi.tools.ToolAnnotation;
+
+import lombok.Builder;
+import lombok.Getter;
+import lombok.Setter;
+import lombok.extern.log4j.Log4j2;
+
+/**
+ * This tool supports neural_sparse search with sparse encoding models and rank_features field.
+ */
+@Log4j2
+@Getter
+@Setter
+@ToolAnnotation(NeuralSparseSearchTool.TYPE)
+public class NeuralSparseSearchTool extends AbstractRetrieverTool {
+    public static final String TYPE = "NeuralSparseSearchTool";
+    public static final String MODEL_ID_FIELD = "model_id";
+    public static final String EMBEDDING_FIELD = "embedding_field";
+
+    private String name = TYPE;
+    private String modelId;
+    private String embeddingField;
+
+    @Builder
+    public NeuralSparseSearchTool(
+        Client client,
+        NamedXContentRegistry xContentRegistry,
+        String index,
+        String embeddingField,
+        String[] sourceFields,
+        Integer docSize,
+        String modelId
+    ) {
+        super(client, xContentRegistry, index, sourceFields, docSize);
+        this.modelId = modelId;
+        this.embeddingField = embeddingField;
+    }
+
+    /**
+     * Builds the search source JSON holding a single neural_sparse query.
+     * The query text is serialized through gson, so quotes, backslashes and control
+     * characters in it are escaped instead of corrupting the generated JSON.
+     *
+     * @param queryText text to search for
+     * @return the query body as a JSON string
+     * @throws IllegalArgumentException if the embedding field or the model id is blank
+     */
+    @Override
+    protected String getQueryBody(String queryText) {
+        if (StringUtils.isBlank(embeddingField) || StringUtils.isBlank(modelId)) {
+            throw new IllegalArgumentException(
+                "Parameter [" + EMBEDDING_FIELD + "] and [" + MODEL_ID_FIELD + "] can not be null or empty."
+            );
+        }
+        return "{\"query\":{\"neural_sparse\":{\""
+            + embeddingField
+            + "\":{\"query_text\":"
+            + gson.toJson(queryText)
+            + ",\"model_id\":\""
+            + modelId
+            + "\"}}}"
+            + " }";
+    }
+
+    @Override
+    public String getType() {
+        return TYPE;
+    }
+
+    public static class Factory extends AbstractRetrieverTool.Factory {
+        // volatile: getInstance() reads INSTANCE before entering the synchronized block
+        private static volatile Factory INSTANCE;
+
+        /**
+         * Returns the shared factory, creating it on the first call.
+         *
+         * @return the singleton factory
+         */
+        public static Factory getInstance() {
+            if (INSTANCE != null) {
+                return INSTANCE;
+            }
+            synchronized (NeuralSparseSearchTool.class) {
+                if (INSTANCE != null) {
+                    return INSTANCE;
+                }
+                INSTANCE = new Factory();
+                return INSTANCE;
+            }
+        }
+
+        /**
+         * Creates a tool from the given parameters.
+         *
+         * @param params tool parameters; source_field must be a JSON array string and doc_size a numeric string
+         * @return a tool that uses this factory's client and xContentRegistry
+         */
+        @Override
+        public NeuralSparseSearchTool create(Map params) {
+            String index = (String) params.get(INDEX_FIELD);
+            String embeddingField = (String) params.get(EMBEDDING_FIELD);
+            String[] sourceFields = gson.fromJson((String) params.get(SOURCE_FIELD), String[].class);
+            String modelId = (String) params.get(MODEL_ID_FIELD);
+            Integer docSize = params.containsKey(DOC_SIZE_FIELD) ? Integer.parseInt((String) params.get(DOC_SIZE_FIELD)) : DEFAULT_DOC_SIZE;
+            return NeuralSparseSearchTool
+                .builder()
+                .client(client)
+                .xContentRegistry(xContentRegistry)
+                .index(index)
+                .embeddingField(embeddingField)
+                .sourceFields(sourceFields)
+                .modelId(modelId)
+                .docSize(docSize)
+                .build();
+        }
+    }
+}
diff --git a/src/main/java/org/opensearch/agent/tools/VectorDBTool.java b/src/main/java/org/opensearch/agent/tools/VectorDBTool.java
new file mode 100644
index 00000000..428b9f14
--- /dev/null
+++ b/src/main/java/org/opensearch/agent/tools/VectorDBTool.java
@@ -0,0 +1,119 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.agent.tools;
+
+import static org.opensearch.ml.common.utils.StringUtils.gson;
+
+import java.util.Map;
+
+import org.apache.commons.lang3.StringUtils;
+import org.opensearch.client.Client;
+import org.opensearch.core.xcontent.NamedXContentRegistry;
+import org.opensearch.ml.common.spi.tools.ToolAnnotation;
+
+import lombok.Builder;
+import lombok.Getter;
+import lombok.Setter;
+import lombok.extern.log4j.Log4j2;
+
+/**
+ * This tool supports neural search with embedding models and knn index.
+ */
+@Log4j2
+@Getter
+@Setter
+@ToolAnnotation(VectorDBTool.TYPE)
+public class VectorDBTool extends AbstractRetrieverTool {
+    public static final String TYPE = "VectorDBTool";
+    public static final String MODEL_ID_FIELD = "model_id";
+    public static final String EMBEDDING_FIELD = "embedding_field";
+    public static final String K_FIELD = "k";
+    public static final Integer DEFAULT_K = 10;
+
+    private String name = TYPE;
+    private String modelId;
+    private String embeddingField;
+    private Integer k;
+
+    @Builder
+    public VectorDBTool(
+        Client client,
+        NamedXContentRegistry xContentRegistry,
+        String index,
+        String embeddingField,
+        String[] sourceFields,
+        Integer docSize,
+        String modelId,
+        Integer k
+    ) {
+        super(client, xContentRegistry, index, sourceFields, docSize);
+        this.modelId = modelId;
+        this.embeddingField = embeddingField;
+        this.k = k == null ? DEFAULT_K : k; // builder callers may omit k; never emit "k":null in the query
+    }
+
+    @Override
+    protected String getQueryBody(String queryText) {
+        if (StringUtils.isBlank(embeddingField) || StringUtils.isBlank(modelId)) {
+            throw new IllegalArgumentException(
+                "Parameter [" + EMBEDDING_FIELD + "] and [" + MODEL_ID_FIELD + "] can not be null or empty."
+            );
+        }
+        return "{\"query\":{\"neural\":{\""
+            + embeddingField
+            + "\":{\"query_text\":"
+            + gson.toJson(queryText) // serialized as a JSON string so quotes/backslashes in the text are escaped
+            + ",\"model_id\":\""
+            + modelId
+            + "\",\"k\":"
+            + k
+            + "}}}"
+            + " }";
+    }
+
+    @Override
+    public String getType() {
+        return TYPE;
+    }
+
+    public static class Factory extends AbstractRetrieverTool.Factory {
+        private static volatile VectorDBTool.Factory INSTANCE; // volatile: read outside the synchronized block below
+
+        public static VectorDBTool.Factory getInstance() {
+            if (INSTANCE != null) {
+                return INSTANCE;
+            }
+            synchronized (VectorDBTool.class) {
+                if (INSTANCE != null) {
+                    return INSTANCE;
+                }
+                INSTANCE = new VectorDBTool.Factory();
+                return INSTANCE;
+            }
+        }
+
+        @Override
+        public VectorDBTool create(Map params) {
+            String index = (String) params.get(INDEX_FIELD);
+            String embeddingField = (String) params.get(EMBEDDING_FIELD);
+            String[] sourceFields = gson.fromJson((String) params.get(SOURCE_FIELD), String[].class);
+            String modelId = (String) params.get(MODEL_ID_FIELD);
+            Integer docSize = params.containsKey(DOC_SIZE_FIELD) ? Integer.parseInt((String) params.get(DOC_SIZE_FIELD)) : DEFAULT_DOC_SIZE;
+            Integer k = params.containsKey(K_FIELD) ? Integer.parseInt((String) params.get(K_FIELD)) : DEFAULT_K;
+            return VectorDBTool
+                .builder()
+                .client(client)
+                .xContentRegistry(xContentRegistry)
+                .index(index)
+                .embeddingField(embeddingField)
+                .sourceFields(sourceFields)
+                .modelId(modelId)
+                .docSize(docSize)
+                .k(k)
+                .build();
+        }
+    }
+}
diff --git a/src/test/java/org/opensearch/agent/tools/AbstractRetrieverToolTests.java b/src/test/java/org/opensearch/agent/tools/AbstractRetrieverToolTests.java
new file mode 100644
index 00000000..5e0faa9c
--- /dev/null
+++ b/src/test/java/org/opensearch/agent/tools/AbstractRetrieverToolTests.java
@@ -0,0 +1,183 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.agent.tools;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertThrows;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+import java.io.InputStream;
+import java.nio.charset.StandardCharsets;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.CompletableFuture;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.Mockito;
+import org.opensearch.action.search.SearchRequest;
+import org.opensearch.action.search.SearchResponse;
+import org.opensearch.client.Client;
+import org.opensearch.common.settings.Settings;
+import org.opensearch.common.xcontent.json.JsonXContent;
+import org.opensearch.core.action.ActionListener;
+import org.opensearch.core.xcontent.DeprecationHandler;
+import org.opensearch.core.xcontent.NamedXContentRegistry;
+import org.opensearch.search.SearchModule;
+
+import lombok.SneakyThrows;
+
+public class AbstractRetrieverToolTests {
+    static public final String TEST_QUERY = "{\"query\":{\"match_all\":{}}}";
+    
static public final String TEST_INDEX = "test index"; + static public final String[] TEST_SOURCE_FIELDS = new String[] { "test 1", "test 2" }; + static public final Integer TEST_DOC_SIZE = 3; + static public final NamedXContentRegistry TEST_XCONTENT_REGISTRY_FOR_QUERY = new NamedXContentRegistry( + new SearchModule(Settings.EMPTY, List.of()).getNamedXContents() + ); + + private String mockedSearchResponseString; + private String mockedEmptySearchResponseString; + private AbstractRetrieverTool mockedImpl; + + @Before + @SneakyThrows + public void setup() { + try (InputStream searchResponseIns = AbstractRetrieverTool.class.getResourceAsStream("retrieval_tool_search_response.json")) { + if (searchResponseIns != null) { + mockedSearchResponseString = new String(searchResponseIns.readAllBytes(), StandardCharsets.UTF_8); + } + } + try (InputStream searchResponseIns = AbstractRetrieverTool.class.getResourceAsStream("retrieval_tool_empty_search_response.json")) { + if (searchResponseIns != null) { + mockedEmptySearchResponseString = new String(searchResponseIns.readAllBytes(), StandardCharsets.UTF_8); + } + } + + mockedImpl = Mockito + .mock( + AbstractRetrieverTool.class, + Mockito + .withSettings() + .useConstructor(null, TEST_XCONTENT_REGISTRY_FOR_QUERY, TEST_INDEX, TEST_SOURCE_FIELDS, TEST_DOC_SIZE) + .defaultAnswer(Mockito.CALLS_REAL_METHODS) + ); + when(mockedImpl.getQueryBody(any(String.class))).thenReturn(TEST_QUERY); + } + + @Test + @SneakyThrows + public void testRunAsyncWithSearchResults() { + Client client = mock(Client.class); + SearchResponse mockedSearchResponse = SearchResponse + .fromXContent( + JsonXContent.jsonXContent + .createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.IGNORE_DEPRECATIONS, mockedSearchResponseString) + ); + doAnswer(invocation -> { + SearchRequest searchRequest = invocation.getArgument(0); + assertEquals((long) TEST_DOC_SIZE, (long) searchRequest.source().size()); + ActionListener listener = invocation.getArgument(1); + 
listener.onResponse(mockedSearchResponse); + return null; + }).when(client).search(any(), any()); + mockedImpl.setClient(client); + + final CompletableFuture future = new CompletableFuture<>(); + ActionListener listener = ActionListener.wrap(r -> { future.complete(r); }, e -> { future.completeExceptionally(e); }); + + mockedImpl.run(Map.of(AbstractRetrieverTool.INPUT_FIELD, "hello world"), listener); + + future.join(); + assertEquals( + "{\"_index\":\"hybrid-index\",\"_source\":{\"passage_text\":\"Company test_mock have a history of 100 years.\"},\"_id\":\"1\",\"_score\":89.2917}\n" + + "{\"_index\":\"hybrid-index\",\"_source\":{\"passage_text\":\"the price of the api is 2$ per invokation\"},\"_id\":\"2\",\"_score\":0.10702579}\n", + future.get() + ); + } + + @Test + @SneakyThrows + public void testRunAsyncWithEmptySearchResponse() { + Client client = mock(Client.class); + SearchResponse mockedEmptySearchResponse = SearchResponse + .fromXContent( + JsonXContent.jsonXContent + .createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.IGNORE_DEPRECATIONS, mockedEmptySearchResponseString) + ); + doAnswer(invocation -> { + SearchRequest searchRequest = invocation.getArgument(0); + assertEquals((long) TEST_DOC_SIZE, (long) searchRequest.source().size()); + ActionListener listener = invocation.getArgument(1); + listener.onResponse(mockedEmptySearchResponse); + return null; + }).when(client).search(any(), any()); + mockedImpl.setClient(client); + + final CompletableFuture future = new CompletableFuture<>(); + ActionListener listener = ActionListener.wrap(r -> { future.complete(r); }, e -> { future.completeExceptionally(e); }); + + mockedImpl.run(Map.of(AbstractRetrieverTool.INPUT_FIELD, "hello world"), listener); + + future.join(); + assertEquals("Can not get any match from search result.", future.get()); + } + + @Test + @SneakyThrows + public void testRunAsyncWithIllegalQueryThenListenerOnFailure() { + Client client = mock(Client.class); + 
mockedImpl.setClient(client); + + final CompletableFuture future1 = new CompletableFuture<>(); + ActionListener listener1 = ActionListener.wrap(future1::complete, future1::completeExceptionally); + mockedImpl.run(Map.of(AbstractRetrieverTool.INPUT_FIELD, ""), listener1); + + Exception exception1 = assertThrows(Exception.class, future1::join); + assertTrue(exception1.getCause() instanceof IllegalArgumentException); + assertEquals(exception1.getCause().getMessage(), "[input] is null or empty, can not process it."); + + final CompletableFuture future2 = new CompletableFuture<>(); + ActionListener listener2 = ActionListener.wrap(future2::complete, future2::completeExceptionally); + mockedImpl.run(Map.of(AbstractRetrieverTool.INPUT_FIELD, " "), listener2); + + Exception exception2 = assertThrows(Exception.class, future2::join); + assertTrue(exception2.getCause() instanceof IllegalArgumentException); + assertEquals(exception2.getCause().getMessage(), "[input] is null or empty, can not process it."); + + final CompletableFuture future3 = new CompletableFuture<>(); + ActionListener listener3 = ActionListener.wrap(future3::complete, future3::completeExceptionally); + mockedImpl.run(Map.of("test", "hello world"), listener3); + + Exception exception3 = assertThrows(Exception.class, future3::join); + assertTrue(exception3.getCause() instanceof IllegalArgumentException); + assertEquals(exception3.getCause().getMessage(), "[input] is null or empty, can not process it."); + + final CompletableFuture future4 = new CompletableFuture<>(); + ActionListener listener4 = ActionListener.wrap(future4::complete, future4::completeExceptionally); + mockedImpl.run(null, listener4); + + Exception exception4 = assertThrows(Exception.class, future4::join); + assertTrue(exception4.getCause() instanceof NullPointerException); + } + + @Test + @SneakyThrows + public void testValidate() { + assertTrue(mockedImpl.validate(Map.of(AbstractRetrieverTool.INPUT_FIELD, "hi"))); + 
assertFalse(mockedImpl.validate(Map.of(AbstractRetrieverTool.INPUT_FIELD, ""))); + assertFalse(mockedImpl.validate(Map.of(AbstractRetrieverTool.INPUT_FIELD, " "))); + assertFalse(mockedImpl.validate(Map.of("test", " "))); + assertFalse(mockedImpl.validate(new HashMap<>())); + assertFalse(mockedImpl.validate(null)); + } +} diff --git a/src/test/java/org/opensearch/agent/tools/NeuralSparseSearchToolTests.java b/src/test/java/org/opensearch/agent/tools/NeuralSparseSearchToolTests.java new file mode 100644 index 00000000..fac45f54 --- /dev/null +++ b/src/test/java/org/opensearch/agent/tools/NeuralSparseSearchToolTests.java @@ -0,0 +1,116 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import static org.opensearch.ml.common.utils.StringUtils.gson; + +import java.util.HashMap; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; + +import com.google.gson.JsonSyntaxException; + +import lombok.SneakyThrows; + +public class NeuralSparseSearchToolTests { + public static final String TEST_QUERY_TEXT = "123fsd23134sdfouh"; + public static final String TEST_EMBEDDING_FIELD = "test embedding"; + public static final String TEST_MODEL_ID = "123fsd23134"; + private Map params = new HashMap<>(); + + @Before + public void setup() { + params.put(NeuralSparseSearchTool.INDEX_FIELD, AbstractRetrieverToolTests.TEST_INDEX); + params.put(NeuralSparseSearchTool.EMBEDDING_FIELD, TEST_EMBEDDING_FIELD); + params.put(NeuralSparseSearchTool.SOURCE_FIELD, gson.toJson(AbstractRetrieverToolTests.TEST_SOURCE_FIELDS)); + params.put(NeuralSparseSearchTool.MODEL_ID_FIELD, TEST_MODEL_ID); + params.put(NeuralSparseSearchTool.DOC_SIZE_FIELD, AbstractRetrieverToolTests.TEST_DOC_SIZE.toString()); + } + + @Test + @SneakyThrows + public void testCreateTool() { + NeuralSparseSearchTool tool = 
NeuralSparseSearchTool.Factory.getInstance().create(params); + assertEquals(AbstractRetrieverToolTests.TEST_INDEX, tool.getIndex()); + assertEquals(TEST_EMBEDDING_FIELD, tool.getEmbeddingField()); + assertEquals(AbstractRetrieverToolTests.TEST_SOURCE_FIELDS, tool.getSourceFields()); + assertEquals(TEST_MODEL_ID, tool.getModelId()); + assertEquals(AbstractRetrieverToolTests.TEST_DOC_SIZE, tool.getDocSize()); + assertEquals("NeuralSparseSearchTool", tool.getType()); + assertEquals("NeuralSparseSearchTool", tool.getName()); + assertEquals( + "Use this tool to search data in OpenSearch index.", + NeuralSparseSearchTool.Factory.getInstance().getDefaultDescription() + ); + } + + @Test + @SneakyThrows + public void testGetQueryBody() { + NeuralSparseSearchTool tool = NeuralSparseSearchTool.Factory.getInstance().create(params); + assertEquals( + "{\"query\":{\"neural_sparse\":{\"test embedding\":{\"" + + "query_text\":\"123fsd23134sdfouh\",\"model_id\":\"123fsd23134\"}}} }", + tool.getQueryBody(TEST_QUERY_TEXT) + ); + } + + @Test + @SneakyThrows + public void testGetQueryBodyWithIllegalParams() { + Map illegalParams1 = new HashMap<>(params); + illegalParams1.remove(NeuralSparseSearchTool.MODEL_ID_FIELD); + NeuralSparseSearchTool tool1 = NeuralSparseSearchTool.Factory.getInstance().create(illegalParams1); + Exception exception1 = assertThrows( + IllegalArgumentException.class, + () -> tool1.getQueryBody(AbstractRetrieverToolTests.TEST_QUERY) + ); + assertEquals("Parameter [embedding_field] and [model_id] can not be null or empty.", exception1.getMessage()); + + Map illegalParams2 = new HashMap<>(params); + illegalParams2.remove(NeuralSparseSearchTool.EMBEDDING_FIELD); + NeuralSparseSearchTool tool2 = NeuralSparseSearchTool.Factory.getInstance().create(illegalParams2); + Exception exception2 = assertThrows( + IllegalArgumentException.class, + () -> tool2.getQueryBody(AbstractRetrieverToolTests.TEST_QUERY) + ); + assertEquals("Parameter [embedding_field] and [model_id] can 
not be null or empty.", exception2.getMessage()); + } + + @Test + @SneakyThrows + public void testCreateToolsParseParams() { + assertThrows( + ClassCastException.class, + () -> NeuralSparseSearchTool.Factory.getInstance().create(Map.of(NeuralSparseSearchTool.INDEX_FIELD, 123)) + ); + + assertThrows( + ClassCastException.class, + () -> NeuralSparseSearchTool.Factory.getInstance().create(Map.of(NeuralSparseSearchTool.EMBEDDING_FIELD, 123)) + ); + + assertThrows( + ClassCastException.class, + () -> NeuralSparseSearchTool.Factory.getInstance().create(Map.of(NeuralSparseSearchTool.MODEL_ID_FIELD, 123)) + ); + + assertThrows( + JsonSyntaxException.class, + () -> NeuralSparseSearchTool.Factory.getInstance().create(Map.of(NeuralSparseSearchTool.SOURCE_FIELD, "123")) + ); + + // although it will be parsed as integer, but the parameters value should always be String + assertThrows( + ClassCastException.class, + () -> NeuralSparseSearchTool.Factory.getInstance().create(Map.of(NeuralSparseSearchTool.DOC_SIZE_FIELD, 123)) + ); + } +} diff --git a/src/test/java/org/opensearch/agent/tools/VectorDBToolTests.java b/src/test/java/org/opensearch/agent/tools/VectorDBToolTests.java new file mode 100644 index 00000000..cc67604f --- /dev/null +++ b/src/test/java/org/opensearch/agent/tools/VectorDBToolTests.java @@ -0,0 +1,103 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import static org.opensearch.ml.common.utils.StringUtils.gson; + +import java.util.HashMap; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; + +import com.google.gson.JsonSyntaxException; + +import lombok.SneakyThrows; + +public class VectorDBToolTests { + public static final String TEST_QUERY_TEXT = "123fsd23134sdfouh"; + public static final String TEST_EMBEDDING_FIELD = "test embedding"; + public static final String 
TEST_MODEL_ID = "123fsd23134"; + public static final Integer TEST_K = 123; + private Map params = new HashMap<>(); + + @Before + public void setup() { + params.put(VectorDBTool.INDEX_FIELD, AbstractRetrieverToolTests.TEST_INDEX); + params.put(VectorDBTool.EMBEDDING_FIELD, TEST_EMBEDDING_FIELD); + params.put(VectorDBTool.SOURCE_FIELD, gson.toJson(AbstractRetrieverToolTests.TEST_SOURCE_FIELDS)); + params.put(VectorDBTool.MODEL_ID_FIELD, TEST_MODEL_ID); + params.put(VectorDBTool.DOC_SIZE_FIELD, AbstractRetrieverToolTests.TEST_DOC_SIZE.toString()); + params.put(VectorDBTool.K_FIELD, TEST_K.toString()); + } + + @Test + @SneakyThrows + public void testCreateTool() { + VectorDBTool tool = VectorDBTool.Factory.getInstance().create(params); + assertEquals(AbstractRetrieverToolTests.TEST_INDEX, tool.getIndex()); + assertEquals(TEST_EMBEDDING_FIELD, tool.getEmbeddingField()); + assertEquals(AbstractRetrieverToolTests.TEST_SOURCE_FIELDS, tool.getSourceFields()); + assertEquals(TEST_MODEL_ID, tool.getModelId()); + assertEquals(AbstractRetrieverToolTests.TEST_DOC_SIZE, tool.getDocSize()); + assertEquals(TEST_K, tool.getK()); + assertEquals("VectorDBTool", tool.getType()); + assertEquals("VectorDBTool", tool.getName()); + assertEquals("Use this tool to search data in OpenSearch index.", VectorDBTool.Factory.getInstance().getDefaultDescription()); + } + + @Test + @SneakyThrows + public void testGetQueryBody() { + VectorDBTool tool = VectorDBTool.Factory.getInstance().create(params); + assertEquals( + "{\"query\":{\"neural\":{\"test embedding\":{\"" + + "query_text\":\"123fsd23134sdfouh\",\"model_id\":\"123fsd23134\",\"k\":123}}} }", + tool.getQueryBody(TEST_QUERY_TEXT) + ); + } + + @Test + @SneakyThrows + public void testGetQueryBodyWithIllegalParams() { + Map illegalParams1 = new HashMap<>(params); + illegalParams1.remove(VectorDBTool.MODEL_ID_FIELD); + VectorDBTool tool1 = VectorDBTool.Factory.getInstance().create(illegalParams1); + Exception exception1 = assertThrows( + 
IllegalArgumentException.class, + () -> tool1.getQueryBody(AbstractRetrieverToolTests.TEST_QUERY) + ); + assertEquals("Parameter [embedding_field] and [model_id] can not be null or empty.", exception1.getMessage()); + + Map illegalParams2 = new HashMap<>(params); + illegalParams2.remove(VectorDBTool.EMBEDDING_FIELD); + VectorDBTool tool2 = VectorDBTool.Factory.getInstance().create(illegalParams2); + Exception exception2 = assertThrows( + IllegalArgumentException.class, + () -> tool2.getQueryBody(AbstractRetrieverToolTests.TEST_QUERY) + ); + assertEquals("Parameter [embedding_field] and [model_id] can not be null or empty.", exception2.getMessage()); + } + + @Test + @SneakyThrows + public void testCreateToolsParseParams() { + assertThrows(ClassCastException.class, () -> VectorDBTool.Factory.getInstance().create(Map.of(VectorDBTool.INDEX_FIELD, 123))); + + assertThrows(ClassCastException.class, () -> VectorDBTool.Factory.getInstance().create(Map.of(VectorDBTool.EMBEDDING_FIELD, 123))); + + assertThrows(ClassCastException.class, () -> VectorDBTool.Factory.getInstance().create(Map.of(VectorDBTool.MODEL_ID_FIELD, 123))); + + assertThrows(JsonSyntaxException.class, () -> VectorDBTool.Factory.getInstance().create(Map.of(VectorDBTool.SOURCE_FIELD, "123"))); + + // although it will be parsed as integer, but the parameters value should always be String + assertThrows(ClassCastException.class, () -> VectorDBTool.Factory.getInstance().create(Map.of(VectorDBTool.DOC_SIZE_FIELD, 123))); + + assertThrows(ClassCastException.class, () -> VectorDBTool.Factory.getInstance().create(Map.of(VectorDBTool.K_FIELD, 123))); + } +} diff --git a/src/test/resources/org/opensearch/agent/tools/retrieval_tool_empty_search_response.json b/src/test/resources/org/opensearch/agent/tools/retrieval_tool_empty_search_response.json new file mode 100644 index 00000000..7ca6bfa7 --- /dev/null +++ b/src/test/resources/org/opensearch/agent/tools/retrieval_tool_empty_search_response.json @@ -0,0 +1,18 @@ +{ + 
"took": 4, + "timed_out": false, + "_shards": { + "total": 1, + "successful": 1, + "skipped": 0, + "failed": 0 + }, + "hits": { + "total": { + "value": 0, + "relation": "eq" + }, + "max_score": null, + "hits": [] + } +} \ No newline at end of file diff --git a/src/test/resources/org/opensearch/agent/tools/retrieval_tool_search_response.json b/src/test/resources/org/opensearch/agent/tools/retrieval_tool_search_response.json new file mode 100644 index 00000000..7e66dd60 --- /dev/null +++ b/src/test/resources/org/opensearch/agent/tools/retrieval_tool_search_response.json @@ -0,0 +1,35 @@ +{ + "took": 201, + "timed_out": false, + "_shards": { + "total": 1, + "successful": 1, + "skipped": 0, + "failed": 0 + }, + "hits": { + "total": { + "value": 2, + "relation": "eq" + }, + "max_score": 89.2917, + "hits": [ + { + "_index": "hybrid-index", + "_id": "1", + "_score": 89.2917, + "_source": { + "passage_text": "Company test_mock have a history of 100 years." + } + }, + { + "_index": "hybrid-index", + "_id": "2", + "_score": 0.10702579, + "_source": { + "passage_text": "the price of the api is 2$ per invokation" + } + } + ] + } +} \ No newline at end of file