Skip to content

Commit

Permalink
Support pass prompt to CreateAlertTool (#452)
Browse files Browse the repository at this point in the history
* Support pass prompt to CreateAlertTool

Signed-off-by: Heng Qian <[email protected]>

* Use Claude as ModelType if passed-in modelType is empty

Signed-off-by: Heng Qian <[email protected]>

* Fix UT

Signed-off-by: Heng Qian <[email protected]>

* Fix spotlessApply

Signed-off-by: Heng Qian <[email protected]>

* Fix spotlessApply

Signed-off-by: Heng Qian <[email protected]>

---------

Signed-off-by: Heng Qian <[email protected]>
Co-authored-by: zane-neo <[email protected]>
  • Loading branch information
qianheng-aws and zane-neo authored Nov 22, 2024
1 parent 34d2005 commit a95cfae
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 25 deletions.
58 changes: 43 additions & 15 deletions src/main/java/org/opensearch/agent/tools/CreateAlertTool.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
Expand Down Expand Up @@ -62,29 +63,55 @@ public class CreateAlertTool implements Tool {
private String description = DEFAULT_DESCRIPTION;

private final Client client;
@Getter
private final String modelId;
private final String TOOL_PROMPT_TEMPLATE;
@Getter
private final String modelType;
@Getter
private final String toolPrompt;

private static final String MODEL_ID = "model_id";
private static final String PROMPT_FILE_PATH = "CreateAlertDefaultPrompt.json";
private static final String DEFAULT_QUESTION = "Create an alert as your recommendation based on the context";
private static final Map<String, String> promptDict = ToolHelper.loadDefaultPromptDictFromFile(CreateAlertTool.class, PROMPT_FILE_PATH);

public CreateAlertTool(Client client, String modelId, String modelType) {
public enum ModelType {
CLAUDE,
OPENAI;

public static ModelType from(String value) {
if (value.isEmpty()) {
return ModelType.CLAUDE;
}
try {
return ModelType.valueOf(value.toUpperCase(Locale.ROOT));
} catch (Exception e) {
log.error("Wrong Model type, should be CLAUDE or OPENAI");
return ModelType.CLAUDE;
}
}
}

public CreateAlertTool(Client client, String modelId, String modelType, String prompt) {
this.client = client;
this.modelId = modelId;
if (!promptDict.containsKey(modelType)) {
throw new IllegalArgumentException(
LoggerMessageFormat
.format(
null,
"Failed to find the right prompt for modelType: {}, this tool supports prompts for these models: [{}]",
modelType,
String.join(",", promptDict.keySet())
)
);
this.modelType = String.valueOf(ModelType.from(modelType));
if (prompt.isEmpty()) {
if (!promptDict.containsKey(this.modelType)) {
throw new IllegalArgumentException(
LoggerMessageFormat
.format(
null,
"Failed to find the right prompt for modelType: {}, this tool supports prompts for these models: [{}]",
modelType,
String.join(",", promptDict.keySet())
)
);
}
this.toolPrompt = promptDict.get(this.modelType);
} else {
this.toolPrompt = prompt;
}
TOOL_PROMPT_TEMPLATE = promptDict.get(modelType);
}

@Override
Expand Down Expand Up @@ -205,7 +232,7 @@ private ActionRequest constructMLPredictRequest(Map<String, String> tmpParams, S
tmpParams.putIfAbsent("chat_history", "");
tmpParams.putIfAbsent("question", DEFAULT_QUESTION); // In case no question is provided, use a default question.
StringSubstitutor substitute = new StringSubstitutor(tmpParams, "${parameters.", "}");
String finalToolPrompt = substitute.replace(TOOL_PROMPT_TEMPLATE);
String finalToolPrompt = substitute.replace(toolPrompt);
tmpParams.put("prompt", finalToolPrompt);

RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(tmpParams).build();
Expand Down Expand Up @@ -279,7 +306,8 @@ public CreateAlertTool create(Map<String, Object> params) {
throw new IllegalArgumentException("model_id cannot be null or blank.");
}
String modelType = (String) params.getOrDefault("model_type", ModelType.CLAUDE.toString());
return new CreateAlertTool(client, modelId, modelType);
String prompt = (String) params.getOrDefault("prompt", "");
return new CreateAlertTool(client, modelId, modelType, prompt);
}

@Override
Expand Down
36 changes: 26 additions & 10 deletions src/test/java/org/opensearch/agent/tools/CreateAlertToolTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -159,16 +159,32 @@ public void testTool_WithBlankModelId() {

@Test
public void testTool_WithNonSupportedModelType() {
Exception exception = assertThrows(
IllegalArgumentException.class,
() -> CreateAlertTool.Factory
.getInstance()
.create(ImmutableMap.of("model_id", "modelId", "model_type", "non_supported_modelType"))
);
assertEquals(
"Failed to find the right prompt for modelType: non_supported_modelType, this tool supports prompts for these models: [CLAUDE,OPENAI]",
exception.getMessage()
);
CreateAlertTool alertTool = CreateAlertTool.Factory
.getInstance()
.create(ImmutableMap.of("model_id", "modelId", "model_type", "non_supported_modelType"));
assertEquals("CLAUDE", alertTool.getModelType());
}

@Test
public void testTool_WithEmptyModelType() {
CreateAlertTool alertTool = CreateAlertTool.Factory.getInstance().create(ImmutableMap.of("model_id", "modelId", "model_type", ""));
assertEquals("CLAUDE", alertTool.getModelType());
}

@Test
public void testToolWithCustomPrompt() {
CreateAlertTool tool = CreateAlertTool.Factory
.getInstance()
.create(ImmutableMap.of("model_id", "modelId", "prompt", "custom prompt"));
assertEquals(CreateAlertTool.TYPE, tool.getName());
assertEquals("modelId", tool.getModelId());
assertEquals("custom prompt", tool.getToolPrompt());

tool
.run(
ImmutableMap.of("indices", mockedIndexName),
ActionListener.<String>wrap(response -> assertEquals(jsonResponse, response), log::info)
);
}

@Test
Expand Down

0 comments on commit a95cfae

Please sign in to comment.