From cfc88d8d2d67854ca6b0c9fb8dbe6621eed87f8b Mon Sep 17 00:00:00 2001 From: gaobinlong Date: Thu, 14 Nov 2024 14:16:02 +0800 Subject: [PATCH] CreateAnomalyDetectorTool supports empty model_type (#457) * CreateAnomalyDetectorTool supports empty model_type Signed-off-by: gaobinlong * Optimize code Signed-off-by: gaobinlong * Add some comment Signed-off-by: gaobinlong --------- Signed-off-by: gaobinlong --- .../agent/tools/CreateAnomalyDetectorTool.java | 6 +++++- .../tools/CreateAnomalyDetectorToolTests.java | 16 ++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/src/main/java/org/opensearch/agent/tools/CreateAnomalyDetectorTool.java b/src/main/java/org/opensearch/agent/tools/CreateAnomalyDetectorTool.java index 60892940..e1540669 100644 --- a/src/main/java/org/opensearch/agent/tools/CreateAnomalyDetectorTool.java +++ b/src/main/java/org/opensearch/agent/tools/CreateAnomalyDetectorTool.java @@ -429,7 +429,11 @@ public CreateAnomalyDetectorTool create(Map map) { throw new IllegalArgumentException("model_id cannot be empty."); } String modelType = (String) map.getOrDefault("model_type", ModelType.CLAUDE.toString()); - if (!ModelType.OPENAI.toString().equalsIgnoreCase(modelType) && !ModelType.CLAUDE.toString().equalsIgnoreCase(modelType)) { + // if model type is empty, use the default value + if (modelType.isEmpty()) { + modelType = ModelType.CLAUDE.toString(); + } else if (!ModelType.OPENAI.toString().equalsIgnoreCase(modelType) + && !ModelType.CLAUDE.toString().equalsIgnoreCase(modelType)) { throw new IllegalArgumentException("Unsupported model_type: " + modelType); } String prompt = (String) map.getOrDefault("prompt", ""); diff --git a/src/test/java/org/opensearch/agent/tools/CreateAnomalyDetectorToolTests.java b/src/test/java/org/opensearch/agent/tools/CreateAnomalyDetectorToolTests.java index 30ca722d..5aacc12b 100644 --- a/src/test/java/org/opensearch/agent/tools/CreateAnomalyDetectorToolTests.java +++ b/src/test/java/org/opensearch/agent/tools/CreateAnomalyDetectorToolTests.java @@ -260,6 +260,22 @@ public void testToolWithCustomPrompt() { ); } + @Test + public void testToolWithEmptyModelType() { + CreateAnomalyDetectorTool tool = CreateAnomalyDetectorTool.Factory + .getInstance() + .create(ImmutableMap.of("model_id", "modelId", "model_type", "")); + assertEquals(CreateAnomalyDetectorTool.TYPE, tool.getName()); + assertEquals("modelId", tool.getModelId()); + assertEquals("CLAUDE", tool.getModelType().toString()); + + tool + .run( + ImmutableMap.of("index", mockedIndexName), + ActionListener.wrap(response -> assertEquals(mockedResult, response), log::info) + ); + } + private void createMappings() { indexMappings = new HashMap<>(); indexMappings