-
Notifications
You must be signed in to change notification settings - Fork 30
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Visualization Tool * fix build failure due to forbiddenApis * Address review comments * spotlessApply * update default tool name * update number of visualization be dynamic --------- (cherry picked from commit 3774eb9) Signed-off-by: Hailong Cui <[email protected]> Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
- Loading branch information
1 parent
cd54a22
commit b0db422
Showing
6 changed files
with
425 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
171 changes: 171 additions & 0 deletions
171
src/main/java/org/opensearch/agent/tools/VisualizationsTool.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,171 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.agent.tools; | ||
|
||
import java.util.Arrays; | ||
import java.util.Locale; | ||
import java.util.Map; | ||
import java.util.Optional; | ||
|
||
import org.opensearch.ExceptionsHelper; | ||
import org.opensearch.action.search.SearchRequest; | ||
import org.opensearch.action.search.SearchResponse; | ||
import org.opensearch.client.Client; | ||
import org.opensearch.client.Requests; | ||
import org.opensearch.core.action.ActionListener; | ||
import org.opensearch.index.IndexNotFoundException; | ||
import org.opensearch.index.query.BoolQueryBuilder; | ||
import org.opensearch.index.query.QueryBuilders; | ||
import org.opensearch.ml.common.spi.tools.Tool; | ||
import org.opensearch.ml.common.spi.tools.ToolAnnotation; | ||
import org.opensearch.search.SearchHits; | ||
import org.opensearch.search.builder.SearchSourceBuilder; | ||
|
||
import com.google.common.annotations.VisibleForTesting; | ||
import com.google.common.base.Strings; | ||
|
||
import lombok.Builder; | ||
import lombok.Getter; | ||
import lombok.Setter; | ||
import lombok.extern.log4j.Log4j2; | ||
|
||
@Log4j2 | ||
@ToolAnnotation(VisualizationsTool.TYPE) | ||
public class VisualizationsTool implements Tool { | ||
public static final String NAME = "FindVisualizations"; | ||
public static final String TYPE = "VisualizationTool"; | ||
public static final String VERSION = "v1.0"; | ||
|
||
public static final String SAVED_OBJECT_TYPE = "visualization"; | ||
|
||
/** | ||
* default number of visualizations returned | ||
*/ | ||
private static final int DEFAULT_SIZE = 3; | ||
private static final String DEFAULT_DESCRIPTION = | ||
"Use this tool to find user created visualizations. This tool takes the visualization name as input and returns matching visualizations"; | ||
@Setter | ||
@Getter | ||
private String description = DEFAULT_DESCRIPTION; | ||
|
||
@Getter | ||
@Setter | ||
private String name = NAME; | ||
@Getter | ||
@Setter | ||
private String type = TYPE; | ||
@Getter | ||
private final String version = VERSION; | ||
private final Client client; | ||
@Getter | ||
private final String index; | ||
@Getter | ||
private final int size; | ||
|
||
@Builder | ||
public VisualizationsTool(Client client, String index, int size) { | ||
this.client = client; | ||
this.index = index; | ||
this.size = size; | ||
} | ||
|
||
@Override | ||
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) { | ||
BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery(); | ||
boolQueryBuilder.must().add(QueryBuilders.termQuery("type", SAVED_OBJECT_TYPE)); | ||
boolQueryBuilder.must().add(QueryBuilders.matchQuery(SAVED_OBJECT_TYPE + ".title", parameters.get("input"))); | ||
|
||
SearchSourceBuilder searchSourceBuilder = SearchSourceBuilder.searchSource().query(boolQueryBuilder); | ||
searchSourceBuilder.from(0).size(3); | ||
SearchRequest searchRequest = Requests.searchRequest(index).source(searchSourceBuilder); | ||
|
||
client.search(searchRequest, new ActionListener<>() { | ||
@Override | ||
public void onResponse(SearchResponse searchResponse) { | ||
SearchHits hits = searchResponse.getHits(); | ||
StringBuilder visBuilder = new StringBuilder(); | ||
visBuilder.append("Title,Id\n"); | ||
if (hits.getTotalHits().value > 0) { | ||
Arrays.stream(hits.getHits()).forEach(h -> { | ||
String id = trimIdPrefix(h.getId()); | ||
Map<String, String> visMap = (Map<String, String>) h.getSourceAsMap().get(SAVED_OBJECT_TYPE); | ||
String title = visMap.get("title"); | ||
visBuilder.append(String.format(Locale.ROOT, "%s,%s\n", title, id)); | ||
}); | ||
|
||
listener.onResponse((T) visBuilder.toString()); | ||
} else { | ||
listener.onResponse((T) "No Visualization found"); | ||
} | ||
} | ||
|
||
@Override | ||
public void onFailure(Exception e) { | ||
if (ExceptionsHelper.unwrapCause(e) instanceof IndexNotFoundException) { | ||
listener.onResponse((T) "No Visualization found"); | ||
} else { | ||
listener.onFailure(e); | ||
} | ||
} | ||
}); | ||
} | ||
|
||
@VisibleForTesting | ||
String trimIdPrefix(String id) { | ||
id = Optional.ofNullable(id).orElse(""); | ||
if (id.startsWith(SAVED_OBJECT_TYPE)) { | ||
String prefix = String.format(Locale.ROOT, "%s:", SAVED_OBJECT_TYPE); | ||
return id.substring(prefix.length()); | ||
} | ||
return id; | ||
} | ||
|
||
@Override | ||
public boolean validate(Map<String, String> parameters) { | ||
return parameters.containsKey("input") && !Strings.isNullOrEmpty(parameters.get("input")); | ||
} | ||
|
||
public static class Factory implements Tool.Factory<VisualizationsTool> { | ||
private Client client; | ||
|
||
private static VisualizationsTool.Factory INSTANCE; | ||
|
||
public static VisualizationsTool.Factory getInstance() { | ||
if (INSTANCE != null) { | ||
return INSTANCE; | ||
} | ||
synchronized (VisualizationsTool.class) { | ||
if (INSTANCE != null) { | ||
return INSTANCE; | ||
} | ||
INSTANCE = new VisualizationsTool.Factory(); | ||
return INSTANCE; | ||
} | ||
} | ||
|
||
public void init(Client client) { | ||
this.client = client; | ||
} | ||
|
||
@Override | ||
public VisualizationsTool create(Map<String, Object> params) { | ||
String index = params.get("index") == null ? ".kibana" : (String) params.get("index"); | ||
String sizeStr = params.get("size") == null ? "3" : (String) params.get("size"); | ||
int size; | ||
try { | ||
size = Integer.parseInt(sizeStr); | ||
} catch (NumberFormatException ignored) { | ||
size = DEFAULT_SIZE; | ||
} | ||
return VisualizationsTool.builder().client(client).index(index).size(size).build(); | ||
} | ||
|
||
@Override | ||
public String getDefaultDescription() { | ||
return DEFAULT_DESCRIPTION; | ||
} | ||
} | ||
} |
161 changes: 161 additions & 0 deletions
161
src/test/java/org/opensearch/agent/tools/VisualizationsToolTests.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,161 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.agent.tools; | ||
|
||
import static org.junit.Assert.assertEquals; | ||
|
||
import java.io.IOException; | ||
import java.io.InputStream; | ||
import java.nio.charset.StandardCharsets; | ||
import java.util.Collections; | ||
import java.util.Map; | ||
import java.util.concurrent.CompletableFuture; | ||
|
||
import org.junit.Assert; | ||
import org.junit.Before; | ||
import org.junit.Test; | ||
import org.mockito.ArgumentCaptor; | ||
import org.mockito.ArgumentMatchers; | ||
import org.mockito.Mock; | ||
import org.mockito.Mockito; | ||
import org.mockito.MockitoAnnotations; | ||
import org.opensearch.action.search.SearchRequest; | ||
import org.opensearch.action.search.SearchResponse; | ||
import org.opensearch.client.Client; | ||
import org.opensearch.common.xcontent.json.JsonXContent; | ||
import org.opensearch.core.action.ActionListener; | ||
import org.opensearch.core.xcontent.DeprecationHandler; | ||
import org.opensearch.core.xcontent.NamedXContentRegistry; | ||
import org.opensearch.index.IndexNotFoundException; | ||
import org.opensearch.ml.common.spi.tools.Tool; | ||
|
||
public class VisualizationsToolTests { | ||
@Mock | ||
private Client client; | ||
|
||
private String searchResponse = "{}"; | ||
private String searchResponseNotFound = "{}"; | ||
|
||
@Before | ||
public void setup() throws IOException { | ||
MockitoAnnotations.openMocks(this); | ||
VisualizationsTool.Factory.getInstance().init(client); | ||
try (InputStream searchResponseIns = VisualizationsToolTests.class.getResourceAsStream("visualization.json")) { | ||
if (searchResponseIns != null) { | ||
searchResponse = new String(searchResponseIns.readAllBytes(), StandardCharsets.UTF_8); | ||
} | ||
} | ||
try (InputStream searchResponseIns = VisualizationsToolTests.class.getResourceAsStream("visualization_not_found.json")) { | ||
if (searchResponseIns != null) { | ||
searchResponseNotFound = new String(searchResponseIns.readAllBytes(), StandardCharsets.UTF_8); | ||
} | ||
} | ||
} | ||
|
||
@Test | ||
public void testToolIndexName() { | ||
VisualizationsTool tool1 = VisualizationsTool.Factory.getInstance().create(Collections.emptyMap()); | ||
assertEquals(tool1.getIndex(), ".kibana"); | ||
|
||
VisualizationsTool tool2 = VisualizationsTool.Factory.getInstance().create(Map.of("index", "test-index")); | ||
assertEquals(tool2.getIndex(), "test-index"); | ||
} | ||
|
||
@Test | ||
public void testNumberOfVisualizationReturned() { | ||
VisualizationsTool tool1 = VisualizationsTool.Factory.getInstance().create(Collections.emptyMap()); | ||
assertEquals(tool1.getSize(), 3); | ||
|
||
VisualizationsTool tool2 = VisualizationsTool.Factory.getInstance().create(Map.of("size", "1")); | ||
assertEquals(tool2.getSize(), 1); | ||
|
||
VisualizationsTool tool3 = VisualizationsTool.Factory.getInstance().create(Map.of("size", "badString")); | ||
assertEquals(tool3.getSize(), 3); | ||
} | ||
|
||
@Test | ||
public void testTrimPrefix() { | ||
VisualizationsTool tool = VisualizationsTool.Factory.getInstance().create(Collections.emptyMap()); | ||
assertEquals(tool.trimIdPrefix(null), ""); | ||
assertEquals(tool.trimIdPrefix("abc"), "abc"); | ||
assertEquals(tool.trimIdPrefix("visualization:abc"), "abc"); | ||
} | ||
|
||
@Test | ||
public void testParameterValidation() { | ||
VisualizationsTool tool = VisualizationsTool.Factory.getInstance().create(Collections.emptyMap()); | ||
Assert.assertFalse(tool.validate(Collections.emptyMap())); | ||
Assert.assertFalse(tool.validate(Map.of("input", ""))); | ||
Assert.assertTrue(tool.validate(Map.of("input", "question"))); | ||
} | ||
|
||
@Test | ||
public void testRunToolWithVisualizationFound() throws Exception { | ||
Tool tool = VisualizationsTool.Factory.getInstance().create(Collections.emptyMap()); | ||
final CompletableFuture<String> future = new CompletableFuture<>(); | ||
ActionListener<String> listener = ActionListener.wrap(future::complete, future::completeExceptionally); | ||
|
||
ArgumentCaptor<ActionListener<SearchResponse>> searchResponseListener = ArgumentCaptor.forClass(ActionListener.class); | ||
Mockito.doNothing().when(client).search(ArgumentMatchers.any(SearchRequest.class), searchResponseListener.capture()); | ||
|
||
Map<String, String> params = Map.of("input", "Sales by gender"); | ||
|
||
tool.run(params, listener); | ||
|
||
SearchResponse response = SearchResponse | ||
.fromXContent( | ||
JsonXContent.jsonXContent.createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.IGNORE_DEPRECATIONS, searchResponse) | ||
); | ||
searchResponseListener.getValue().onResponse(response); | ||
|
||
future.join(); | ||
assertEquals("Title,Id\n[Ecommerce]Sales by gender,aeb212e0-4c84-11e8-b3d7-01146121b73d\n", future.get()); | ||
} | ||
|
||
@Test | ||
public void testRunToolWithNoVisualizationFound() throws Exception { | ||
Tool tool = VisualizationsTool.Factory.getInstance().create(Collections.emptyMap()); | ||
final CompletableFuture<String> future = new CompletableFuture<>(); | ||
ActionListener<String> listener = ActionListener.wrap(future::complete, future::completeExceptionally); | ||
|
||
ArgumentCaptor<ActionListener<SearchResponse>> searchResponseListener = ArgumentCaptor.forClass(ActionListener.class); | ||
Mockito.doNothing().when(client).search(ArgumentMatchers.any(SearchRequest.class), searchResponseListener.capture()); | ||
|
||
Map<String, String> params = Map.of("input", "Sales by gender"); | ||
|
||
tool.run(params, listener); | ||
|
||
SearchResponse response = SearchResponse | ||
.fromXContent( | ||
JsonXContent.jsonXContent | ||
.createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.IGNORE_DEPRECATIONS, searchResponseNotFound) | ||
); | ||
searchResponseListener.getValue().onResponse(response); | ||
|
||
future.join(); | ||
assertEquals("No Visualization found", future.get()); | ||
} | ||
|
||
@Test | ||
public void testRunToolWithIndexNotExists() throws Exception { | ||
Tool tool = VisualizationsTool.Factory.getInstance().create(Collections.emptyMap()); | ||
final CompletableFuture<String> future = new CompletableFuture<>(); | ||
ActionListener<String> listener = ActionListener.wrap(future::complete, future::completeExceptionally); | ||
|
||
ArgumentCaptor<ActionListener<SearchResponse>> searchResponseListener = ArgumentCaptor.forClass(ActionListener.class); | ||
Mockito.doNothing().when(client).search(ArgumentMatchers.any(SearchRequest.class), searchResponseListener.capture()); | ||
|
||
Map<String, String> params = Map.of("input", "Sales by gender"); | ||
|
||
tool.run(params, listener); | ||
|
||
IndexNotFoundException notFoundException = new IndexNotFoundException("test-index"); | ||
searchResponseListener.getValue().onFailure(notFoundException); | ||
|
||
future.join(); | ||
assertEquals("No Visualization found", future.get()); | ||
} | ||
} |
Oops, something went wrong.