Skip to content

Commit

Permalink
Add visualization tool (#41) (#68)
Browse files Browse the repository at this point in the history
* Visualization Tool



* fix build failure due to forbiddenApis



* Address review comments



* spotlessApply



* update default tool name



* update number of visualization be dynamic



---------


(cherry picked from commit 3774eb9)

Signed-off-by: Hailong Cui <[email protected]>
Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
  • Loading branch information
1 parent cd54a22 commit b0db422
Show file tree
Hide file tree
Showing 6 changed files with 425 additions and 1 deletion.
8 changes: 8 additions & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,14 @@ test {
systemProperty 'tests.security.manager', 'false'
}

jacocoTestReport {
dependsOn test
reports {
html.required = true // human readable
xml.required = true // for coverlay
}
}

spotless {
if (JavaVersion.current() >= JavaVersion.VERSION_17) {
// Spotless configuration for Java files
Expand Down
10 changes: 9 additions & 1 deletion src/main/java/org/opensearch/agent/ToolPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.opensearch.agent.tools.NeuralSparseSearchTool;
import org.opensearch.agent.tools.PPLTool;
import org.opensearch.agent.tools.VectorDBTool;
import org.opensearch.agent.tools.VisualizationsTool;
import org.opensearch.client.Client;
import org.opensearch.cluster.metadata.IndexNameExpressionResolver;
import org.opensearch.cluster.service.ClusterService;
Expand Down Expand Up @@ -56,13 +57,20 @@ public Collection<Object> createComponents(
this.xContentRegistry = xContentRegistry;

PPLTool.Factory.getInstance().init(client);
VisualizationsTool.Factory.getInstance().init(client);
NeuralSparseSearchTool.Factory.getInstance().init(client, xContentRegistry);
VectorDBTool.Factory.getInstance().init(client, xContentRegistry);
return Collections.emptyList();
}

@Override
public List<Tool.Factory<? extends Tool>> getToolFactories() {
return List.of(PPLTool.Factory.getInstance(), NeuralSparseSearchTool.Factory.getInstance(), VectorDBTool.Factory.getInstance());
return List
.of(
PPLTool.Factory.getInstance(),
NeuralSparseSearchTool.Factory.getInstance(),
VectorDBTool.Factory.getInstance(),
VisualizationsTool.Factory.getInstance()
);
}
}
171 changes: 171 additions & 0 deletions src/main/java/org/opensearch/agent/tools/VisualizationsTool.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.agent.tools;

import java.util.Arrays;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;

import org.opensearch.ExceptionsHelper;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.client.Client;
import org.opensearch.client.Requests;
import org.opensearch.core.action.ActionListener;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
import org.opensearch.search.SearchHits;
import org.opensearch.search.builder.SearchSourceBuilder;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Strings;

import lombok.Builder;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.log4j.Log4j2;

@Log4j2
@ToolAnnotation(VisualizationsTool.TYPE)
public class VisualizationsTool implements Tool {
public static final String NAME = "FindVisualizations";
public static final String TYPE = "VisualizationTool";
public static final String VERSION = "v1.0";

public static final String SAVED_OBJECT_TYPE = "visualization";

/**
* default number of visualizations returned
*/
private static final int DEFAULT_SIZE = 3;
private static final String DEFAULT_DESCRIPTION =
"Use this tool to find user created visualizations. This tool takes the visualization name as input and returns matching visualizations";
@Setter
@Getter
private String description = DEFAULT_DESCRIPTION;

@Getter
@Setter
private String name = NAME;
@Getter
@Setter
private String type = TYPE;
@Getter
private final String version = VERSION;
private final Client client;
@Getter
private final String index;
@Getter
private final int size;

@Builder
public VisualizationsTool(Client client, String index, int size) {
this.client = client;
this.index = index;
this.size = size;
}

@Override
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery();
boolQueryBuilder.must().add(QueryBuilders.termQuery("type", SAVED_OBJECT_TYPE));
boolQueryBuilder.must().add(QueryBuilders.matchQuery(SAVED_OBJECT_TYPE + ".title", parameters.get("input")));

SearchSourceBuilder searchSourceBuilder = SearchSourceBuilder.searchSource().query(boolQueryBuilder);
searchSourceBuilder.from(0).size(3);
SearchRequest searchRequest = Requests.searchRequest(index).source(searchSourceBuilder);

client.search(searchRequest, new ActionListener<>() {
@Override
public void onResponse(SearchResponse searchResponse) {
SearchHits hits = searchResponse.getHits();
StringBuilder visBuilder = new StringBuilder();
visBuilder.append("Title,Id\n");
if (hits.getTotalHits().value > 0) {
Arrays.stream(hits.getHits()).forEach(h -> {
String id = trimIdPrefix(h.getId());
Map<String, String> visMap = (Map<String, String>) h.getSourceAsMap().get(SAVED_OBJECT_TYPE);
String title = visMap.get("title");
visBuilder.append(String.format(Locale.ROOT, "%s,%s\n", title, id));
});

listener.onResponse((T) visBuilder.toString());
} else {
listener.onResponse((T) "No Visualization found");
}
}

@Override
public void onFailure(Exception e) {
if (ExceptionsHelper.unwrapCause(e) instanceof IndexNotFoundException) {
listener.onResponse((T) "No Visualization found");
} else {
listener.onFailure(e);
}
}
});
}

@VisibleForTesting
String trimIdPrefix(String id) {
id = Optional.ofNullable(id).orElse("");
if (id.startsWith(SAVED_OBJECT_TYPE)) {
String prefix = String.format(Locale.ROOT, "%s:", SAVED_OBJECT_TYPE);
return id.substring(prefix.length());
}
return id;
}

@Override
public boolean validate(Map<String, String> parameters) {
return parameters.containsKey("input") && !Strings.isNullOrEmpty(parameters.get("input"));
}

public static class Factory implements Tool.Factory<VisualizationsTool> {
private Client client;

private static VisualizationsTool.Factory INSTANCE;

public static VisualizationsTool.Factory getInstance() {
if (INSTANCE != null) {
return INSTANCE;
}
synchronized (VisualizationsTool.class) {
if (INSTANCE != null) {
return INSTANCE;
}
INSTANCE = new VisualizationsTool.Factory();
return INSTANCE;
}
}

public void init(Client client) {
this.client = client;
}

@Override
public VisualizationsTool create(Map<String, Object> params) {
String index = params.get("index") == null ? ".kibana" : (String) params.get("index");
String sizeStr = params.get("size") == null ? "3" : (String) params.get("size");
int size;
try {
size = Integer.parseInt(sizeStr);
} catch (NumberFormatException ignored) {
size = DEFAULT_SIZE;
}
return VisualizationsTool.builder().client(client).index(index).size(size).build();
}

@Override
public String getDefaultDescription() {
return DEFAULT_DESCRIPTION;
}
}
}
161 changes: 161 additions & 0 deletions src/test/java/org/opensearch/agent/tools/VisualizationsToolTests.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.agent.tools;

import static org.junit.Assert.assertEquals;

import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.Map;
import java.util.concurrent.CompletableFuture;

import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.ArgumentMatchers;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.client.Client;
import org.opensearch.common.xcontent.json.JsonXContent;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.DeprecationHandler;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.ml.common.spi.tools.Tool;

public class VisualizationsToolTests {
@Mock
private Client client;

private String searchResponse = "{}";
private String searchResponseNotFound = "{}";

@Before
public void setup() throws IOException {
MockitoAnnotations.openMocks(this);
VisualizationsTool.Factory.getInstance().init(client);
try (InputStream searchResponseIns = VisualizationsToolTests.class.getResourceAsStream("visualization.json")) {
if (searchResponseIns != null) {
searchResponse = new String(searchResponseIns.readAllBytes(), StandardCharsets.UTF_8);
}
}
try (InputStream searchResponseIns = VisualizationsToolTests.class.getResourceAsStream("visualization_not_found.json")) {
if (searchResponseIns != null) {
searchResponseNotFound = new String(searchResponseIns.readAllBytes(), StandardCharsets.UTF_8);
}
}
}

@Test
public void testToolIndexName() {
VisualizationsTool tool1 = VisualizationsTool.Factory.getInstance().create(Collections.emptyMap());
assertEquals(tool1.getIndex(), ".kibana");

VisualizationsTool tool2 = VisualizationsTool.Factory.getInstance().create(Map.of("index", "test-index"));
assertEquals(tool2.getIndex(), "test-index");
}

@Test
public void testNumberOfVisualizationReturned() {
VisualizationsTool tool1 = VisualizationsTool.Factory.getInstance().create(Collections.emptyMap());
assertEquals(tool1.getSize(), 3);

VisualizationsTool tool2 = VisualizationsTool.Factory.getInstance().create(Map.of("size", "1"));
assertEquals(tool2.getSize(), 1);

VisualizationsTool tool3 = VisualizationsTool.Factory.getInstance().create(Map.of("size", "badString"));
assertEquals(tool3.getSize(), 3);
}

@Test
public void testTrimPrefix() {
VisualizationsTool tool = VisualizationsTool.Factory.getInstance().create(Collections.emptyMap());
assertEquals(tool.trimIdPrefix(null), "");
assertEquals(tool.trimIdPrefix("abc"), "abc");
assertEquals(tool.trimIdPrefix("visualization:abc"), "abc");
}

@Test
public void testParameterValidation() {
VisualizationsTool tool = VisualizationsTool.Factory.getInstance().create(Collections.emptyMap());
Assert.assertFalse(tool.validate(Collections.emptyMap()));
Assert.assertFalse(tool.validate(Map.of("input", "")));
Assert.assertTrue(tool.validate(Map.of("input", "question")));
}

@Test
public void testRunToolWithVisualizationFound() throws Exception {
Tool tool = VisualizationsTool.Factory.getInstance().create(Collections.emptyMap());
final CompletableFuture<String> future = new CompletableFuture<>();
ActionListener<String> listener = ActionListener.wrap(future::complete, future::completeExceptionally);

ArgumentCaptor<ActionListener<SearchResponse>> searchResponseListener = ArgumentCaptor.forClass(ActionListener.class);
Mockito.doNothing().when(client).search(ArgumentMatchers.any(SearchRequest.class), searchResponseListener.capture());

Map<String, String> params = Map.of("input", "Sales by gender");

tool.run(params, listener);

SearchResponse response = SearchResponse
.fromXContent(
JsonXContent.jsonXContent.createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.IGNORE_DEPRECATIONS, searchResponse)
);
searchResponseListener.getValue().onResponse(response);

future.join();
assertEquals("Title,Id\n[Ecommerce]Sales by gender,aeb212e0-4c84-11e8-b3d7-01146121b73d\n", future.get());
}

@Test
public void testRunToolWithNoVisualizationFound() throws Exception {
Tool tool = VisualizationsTool.Factory.getInstance().create(Collections.emptyMap());
final CompletableFuture<String> future = new CompletableFuture<>();
ActionListener<String> listener = ActionListener.wrap(future::complete, future::completeExceptionally);

ArgumentCaptor<ActionListener<SearchResponse>> searchResponseListener = ArgumentCaptor.forClass(ActionListener.class);
Mockito.doNothing().when(client).search(ArgumentMatchers.any(SearchRequest.class), searchResponseListener.capture());

Map<String, String> params = Map.of("input", "Sales by gender");

tool.run(params, listener);

SearchResponse response = SearchResponse
.fromXContent(
JsonXContent.jsonXContent
.createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.IGNORE_DEPRECATIONS, searchResponseNotFound)
);
searchResponseListener.getValue().onResponse(response);

future.join();
assertEquals("No Visualization found", future.get());
}

@Test
public void testRunToolWithIndexNotExists() throws Exception {
Tool tool = VisualizationsTool.Factory.getInstance().create(Collections.emptyMap());
final CompletableFuture<String> future = new CompletableFuture<>();
ActionListener<String> listener = ActionListener.wrap(future::complete, future::completeExceptionally);

ArgumentCaptor<ActionListener<SearchResponse>> searchResponseListener = ArgumentCaptor.forClass(ActionListener.class);
Mockito.doNothing().when(client).search(ArgumentMatchers.any(SearchRequest.class), searchResponseListener.capture());

Map<String, String> params = Map.of("input", "Sales by gender");

tool.run(params, listener);

IndexNotFoundException notFoundException = new IndexNotFoundException("test-index");
searchResponseListener.getValue().onFailure(notFoundException);

future.join();
assertEquals("No Visualization found", future.get());
}
}
Loading

0 comments on commit b0db422

Please sign in to comment.