Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
* fixPPLAllowedFields

* rename variables

---------

(cherry picked from commit 9174e4c)

Signed-off-by: xinyual <[email protected]>
Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Signed-off-by: yuye-aws <[email protected]>
  • Loading branch information
2 people authored and yuye-aws committed Apr 26, 2024
1 parent 4f27764 commit c4465e6
Showing 1 changed file with 38 additions and 11 deletions.
49 changes: 38 additions & 11 deletions src/main/java/org/opensearch/agent/tools/PPLTool.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@

package org.opensearch.agent.tools;

import static org.opensearch.ml.common.CommonValue.*;

import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
Expand All @@ -16,9 +14,11 @@
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.StringJoiner;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
Expand Down Expand Up @@ -93,14 +93,35 @@ public class PPLTool implements Tool {

private static Gson gson = new Gson();

private static Map<String, String> defaultPromptDict;
private static Map<String, String> DEFAULT_PROMPT_DICT;

private static Set<String> ALLOWED_FIELDS_TYPE;

static {
ALLOWED_FIELDS_TYPE = new HashSet<>(); // from
// https://github.com/opensearch-project/sql/blob/2.x/docs/user/ppl/general/datatypes.rst#data-types-mapping
ALLOWED_FIELDS_TYPE.add("boolean");
ALLOWED_FIELDS_TYPE.add("byte");
ALLOWED_FIELDS_TYPE.add("short");
ALLOWED_FIELDS_TYPE.add("integer");
ALLOWED_FIELDS_TYPE.add("long");
ALLOWED_FIELDS_TYPE.add("float");
ALLOWED_FIELDS_TYPE.add("half_float");
ALLOWED_FIELDS_TYPE.add("scaled_float");
ALLOWED_FIELDS_TYPE.add("double");
ALLOWED_FIELDS_TYPE.add("keyword");
ALLOWED_FIELDS_TYPE.add("text");
ALLOWED_FIELDS_TYPE.add("date");
ALLOWED_FIELDS_TYPE.add("ip");
ALLOWED_FIELDS_TYPE.add("binary");
ALLOWED_FIELDS_TYPE.add("object");
ALLOWED_FIELDS_TYPE.add("nested");

try {
defaultPromptDict = loadDefaultPromptDict();
DEFAULT_PROMPT_DICT = loadDefaultPromptDict();
} catch (IOException e) {
log.error("fail to load default prompt dict" + e.getMessage());
defaultPromptDict = new HashMap<>();
DEFAULT_PROMPT_DICT = new HashMap<>();
}
}

Expand All @@ -127,7 +148,7 @@ public PPLTool(Client client, String modelId, String contextPrompt, String pplMo
this.modelId = modelId;
this.pplModelType = PPLModelType.from(pplModelType);
if (contextPrompt.isEmpty()) {
this.contextPrompt = this.defaultPromptDict.getOrDefault(this.pplModelType.toString(), "");
this.contextPrompt = this.DEFAULT_PROMPT_DICT.getOrDefault(this.pplModelType.toString(), "");
} else {
this.contextPrompt = contextPrompt;
}
Expand All @@ -147,13 +168,15 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)
+ indexName
);
}
SearchRequest searchRequest = buildSearchRequest(indexName);

GetMappingsRequest getMappingsRequest = buildGetMappingRequest(indexName);
client.admin().indices().getMappings(getMappingsRequest, ActionListener.<GetMappingsResponse>wrap(getMappingsResponse -> {
Map<String, MappingMetadata> mappings = getMappingsResponse.getMappings();
if (mappings.size() == 0) {
throw new IllegalArgumentException("No matching mapping with index name: " + indexName);
}
String firstIndexName = (String) mappings.keySet().toArray()[0];
SearchRequest searchRequest = buildSearchRequest(firstIndexName);
client.search(searchRequest, ActionListener.<SearchResponse>wrap(searchResponse -> {
SearchHit[] searchHits = searchResponse.getHits().getHits();
String tableInfo = constructTableInfo(searchHits, mappings);
Expand Down Expand Up @@ -318,13 +341,17 @@ private String constructTableInfo(SearchHit[] searchHits, Map<String, MappingMet
extractSamples(sampleSource, fieldsToSample, "");

for (String key : sortedKeys) {
String line = "- " + key + ": " + fieldsToType.get(key) + " (" + fieldsToSample.get(key) + ")";
tableInfoJoiner.add(line);
if (ALLOWED_FIELDS_TYPE.contains(fieldsToType.get(key))) {
String line = "- " + key + ": " + fieldsToType.get(key) + " (" + fieldsToSample.get(key) + ")";
tableInfoJoiner.add(line);
}
}
} else {
for (String key : sortedKeys) {
String line = "- " + key + ": " + fieldsToType.get(key);
tableInfoJoiner.add(line);
if (ALLOWED_FIELDS_TYPE.contains(fieldsToType.get(key))) {
String line = "- " + key + ": " + fieldsToType.get(key);
tableInfoJoiner.add(line);
}
}
}

Expand Down

0 comments on commit c4465e6

Please sign in to comment.