Skip to content

Commit

Permalink
add more uts
Browse files Browse the repository at this point in the history
Signed-off-by: xinyual <[email protected]>
  • Loading branch information
xinyual committed Feb 2, 2024
1 parent a9a20ed commit c7439f4
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,46 @@ public void testParsingJsonBlockFromResponse2() {
assertEquals("parsed final answer", modelTensor2.getResult());
}

@Test
public void testParsingJsonBlockFromResponse3() {
// Prepare the response with JSON block
String jsonBlock = "{\"thought\":\"parsed thought\", \"action\":\"parsed action\", "
+ "\"action_input\":{\"a\":\"n\"}, \"final_answer\":\"parsed final answer\"}";
String responseWithJsonBlock = "Some text```json" + jsonBlock + "```More text";

// Mock LLM response to not contain "thought" but contain "response" with JSON block
Map<String, String> llmResponse = new HashMap<>();
llmResponse.put("response", responseWithJsonBlock);
doAnswer(getLLMAnswer(llmResponse))
.when(client)
.execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class));

// Create an MLAgent and run the MLChatAgentRunner
MLAgent mlAgent = createMLAgentWithTools();
Map<String, String> params = new HashMap<>();
params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_id");
params.put("verbose", "true");
mlChatAgentRunner.run(mlAgent, params, agentActionListener);

// Capture the response passed to the listener
ArgumentCaptor<Object> responseCaptor = ArgumentCaptor.forClass(Object.class);
verify(agentActionListener).onResponse(responseCaptor.capture());

// Extract the captured response
Object capturedResponse = responseCaptor.getValue();
assertTrue(capturedResponse instanceof ModelTensorOutput);
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) capturedResponse;

ModelTensor parentInteractionModelTensor = modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(1);
ModelTensor modelTensor1 = modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors().get(0);
ModelTensor modelTensor2 = modelTensorOutput.getMlModelOutputs().get(2).getMlModelTensors().get(0);

// Verify that the parsed values from JSON block are correctly set
assertEquals("parent_interaction_id", parentInteractionModelTensor.getResult());
assertEquals("Thought: parsed thought", modelTensor1.getResult());
assertEquals("parsed final answer", modelTensor2.getResult());
}

@Test
public void testRunWithIncludeOutputNotSet() {
LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,15 @@ public void testAgenttestRunMethod() {
public void testAgentWithChatAgentInput() {
Map<String, String> parameters = new HashMap<>();
parameters.put("testKey", "testValue");
Map<String, String> chatAgentInput = ImmutableMap.of("input", gson.toJson(parameters));
Map<String, String> chatAgentInput = new HashMap<>();
chatAgentInput.put("input", gson.toJson(parameters));
doTestRunMethod(chatAgentInput);
}

@Test
public void testAgentWithChatAgentInputWrongFormat() {
Map<String, String> chatAgentInput = new HashMap<>();
chatAgentInput.put("input", "wrong format");
doTestRunMethod(chatAgentInput);
}

Expand Down

0 comments on commit c7439f4

Please sign in to comment.