Skip to content

Commit

Permalink
Fix json parsing exception for NeuralSparseSearchTool and VectorDBTool (
Browse files Browse the repository at this point in the history
#203)

* fix the json string parsing

Signed-off-by: zhichao-aws <[email protected]>

* add it

Signed-off-by: zhichao-aws <[email protected]>

---------

Signed-off-by: zhichao-aws <[email protected]>
  • Loading branch information
zhichao-aws authored Feb 7, 2024
1 parent 60d9d11 commit 38f5847
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import static org.opensearch.ml.common.utils.StringUtils.gson;

import java.io.IOException;
import java.security.AccessController;
import java.security.PrivilegedExceptionAction;
import java.util.HashMap;
import java.util.Map;

Expand Down Expand Up @@ -109,7 +111,9 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)
StringBuilder contextBuilder = new StringBuilder();
for (SearchHit hit : hits) {
Map<String, Object> docContent = processResponse(hit);
contextBuilder.append(gson.toJson(docContent)).append("\n");
String docContentInString = AccessController
.doPrivileged((PrivilegedExceptionAction<String>) () -> gson.toJson(docContent));
contextBuilder.append(docContentInString).append("\n");
}
listener.onResponse((T) contextBuilder.toString());
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@

import static org.opensearch.ml.common.utils.StringUtils.gson;

import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.Map;

import org.apache.commons.lang3.StringUtils;
Expand Down Expand Up @@ -57,14 +60,15 @@ protected String getQueryBody(String queryText) {
"Parameter [" + EMBEDDING_FIELD + "] and [" + MODEL_ID_FIELD + "] can not be null or empty."
);
}
return "{\"query\":{\"neural_sparse\":{\""
+ embeddingField
+ "\":{\"query_text\":\""
+ queryText
+ "\",\"model_id\":\""
+ modelId
+ "\"}}}"
+ " }";

Map<String, Object> queryBody = Map
.of("query", Map.of("neural_sparse", Map.of(embeddingField, Map.of("query_text", queryText, "model_id", modelId))));

try {
return AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> gson.toJson(queryBody));
} catch (PrivilegedActionException e) {
throw new RuntimeException(e);
}
}

@Override
Expand Down
22 changes: 12 additions & 10 deletions src/main/java/org/opensearch/agent/tools/VectorDBTool.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@

import static org.opensearch.ml.common.utils.StringUtils.gson;

import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.Map;

import org.apache.commons.lang3.StringUtils;
Expand Down Expand Up @@ -65,16 +68,15 @@ protected String getQueryBody(String queryText) {
"Parameter [" + EMBEDDING_FIELD + "] and [" + MODEL_ID_FIELD + "] can not be null or empty."
);
}
return "{\"query\":{\"neural\":{\""
+ embeddingField
+ "\":{\"query_text\":\""
+ queryText
+ "\",\"model_id\":\""
+ modelId
+ "\",\"k\":"
+ k
+ "}}}"
+ " }";

Map<String, Object> queryBody = Map
.of("query", Map.of("neural", Map.of(embeddingField, Map.of("query_text", queryText, "model_id", modelId, "k", k))));

try {
return AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> gson.toJson(queryBody));
} catch (PrivilegedActionException e) {
throw new RuntimeException(e);
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,19 @@ public void testCreateTool() {
@SneakyThrows
public void testGetQueryBody() {
NeuralSparseSearchTool tool = NeuralSparseSearchTool.Factory.getInstance().create(params);
assertEquals(
"{\"query\":{\"neural_sparse\":{\"test embedding\":{\""
+ "query_text\":\"123fsd23134sdfouh\",\"model_id\":\"123fsd23134\"}}} }",
tool.getQueryBody(TEST_QUERY_TEXT)
);
Map<String, Map<String, Map<String, Map<String, String>>>> queryBody = gson.fromJson(tool.getQueryBody(TEST_QUERY_TEXT), Map.class);
assertEquals("123fsd23134sdfouh", queryBody.get("query").get("neural_sparse").get("test embedding").get("query_text"));
assertEquals("123fsd23134", queryBody.get("query").get("neural_sparse").get("test embedding").get("model_id"));
}

@Test
@SneakyThrows
public void testGetQueryBodyWithJsonObjectString() {
NeuralSparseSearchTool tool = NeuralSparseSearchTool.Factory.getInstance().create(params);
String jsonInput = gson.toJson(Map.of("hi", "a"));
Map<String, Map<String, Map<String, Map<String, String>>>> queryBody = gson.fromJson(tool.getQueryBody(jsonInput), Map.class);
assertEquals("{\"hi\":\"a\"}", queryBody.get("query").get("neural_sparse").get("test embedding").get("query_text"));
assertEquals("123fsd23134", queryBody.get("query").get("neural_sparse").get("test embedding").get("model_id"));
}

@Test
Expand Down
20 changes: 15 additions & 5 deletions src/test/java/org/opensearch/agent/tools/VectorDBToolTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,21 @@ public void testCreateTool() {
@SneakyThrows
public void testGetQueryBody() {
VectorDBTool tool = VectorDBTool.Factory.getInstance().create(params);
assertEquals(
"{\"query\":{\"neural\":{\"test embedding\":{\""
+ "query_text\":\"123fsd23134sdfouh\",\"model_id\":\"123fsd23134\",\"k\":123}}} }",
tool.getQueryBody(TEST_QUERY_TEXT)
);
Map<String, Map<String, Map<String, Map<String, String>>>> queryBody = gson.fromJson(tool.getQueryBody(TEST_QUERY_TEXT), Map.class);
assertEquals("123fsd23134sdfouh", queryBody.get("query").get("neural").get("test embedding").get("query_text"));
assertEquals("123fsd23134", queryBody.get("query").get("neural").get("test embedding").get("model_id"));
assertEquals(123.0, queryBody.get("query").get("neural").get("test embedding").get("k"));
}

@Test
@SneakyThrows
public void testGetQueryBodyWithJsonObjectString() {
VectorDBTool tool = VectorDBTool.Factory.getInstance().create(params);
String jsonInput = gson.toJson(Map.of("hi", "a"));
Map<String, Map<String, Map<String, Map<String, String>>>> queryBody = gson.fromJson(tool.getQueryBody(jsonInput), Map.class);
assertEquals("{\"hi\":\"a\"}", queryBody.get("query").get("neural").get("test embedding").get("query_text"));
assertEquals("123fsd23134", queryBody.get("query").get("neural").get("test embedding").get("model_id"));
assertEquals(123.0, queryBody.get("query").get("neural").get("test embedding").get("k"));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.containsString;
import static org.opensearch.ml.common.utils.StringUtils.gson;

import java.nio.file.Files;
import java.nio.file.Path;
Expand Down Expand Up @@ -114,6 +115,16 @@ public void testNeuralSparseSearchToolInFlowAgent() {
exception.getMessage(),
allOf(containsString("[input] is null or empty, can not process it."), containsString("IllegalArgumentException"))
);

// use json string input
String jsonInput = gson.toJson(Map.of("parameters", Map.of("question", gson.toJson(Map.of("hi", "a")))));
String result3 = executeAgent(agentId, jsonInput);
assertEquals(
"The agent execute response not equal with expected.",
"{\"_index\":\"test_index\",\"_source\":{\"text\":\"text doc 3\"},\"_id\":\"2\",\"_score\":2.4136734}\n"
+ "{\"_index\":\"test_index\",\"_source\":{\"text\":\"text doc 2\"},\"_id\":\"1\",\"_score\":1.2068367}\n",
result3
);
}

public void testNeuralSparseSearchToolInFlowAgent_withIllegalSourceField_thenGetEmptySource() {
Expand Down

0 comments on commit 38f5847

Please sign in to comment.