Skip to content

Commit

Permalink
move common function to utils
Browse files Browse the repository at this point in the history
Signed-off-by: xinyual <[email protected]>
  • Loading branch information
xinyual committed Jan 30, 2024
1 parent 57e62e3 commit fd458c7
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.TOOL_DESCRIPTIONS;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.TOOL_NAMES;

import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand All @@ -23,6 +26,8 @@
import java.util.regex.Pattern;

import org.apache.commons.text.StringSubstitutor;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.spi.tools.Tool;

public class AgentUtils {
Expand Down Expand Up @@ -152,4 +157,32 @@ public static String extractModelResponseJson(String text) {
throw new IllegalArgumentException("Model output is invalid");
}
}

public static String outputToOutputString(Object output) throws PrivilegedActionException {
String outputString;
if (output instanceof ModelTensorOutput) {
ModelTensor outputModel = ((ModelTensorOutput) output).getMlModelOutputs().get(0).getMlModelTensors().get(0);
if (outputModel.getDataAsMap() != null) {
outputString = AccessController
.doPrivileged((PrivilegedExceptionAction<String>) () -> gson.toJson(outputModel.getDataAsMap()));
} else {
outputString = outputModel.getResult();
}
} else if (output instanceof String) {
outputString = (String) output;
} else {
outputString = AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> gson.toJson(output));
}
return outputString;
}

public static String parseInputFromLLMReturn(Map<String, ?> retMap) {
Object actionInput = retMap.get("action_input");
if (actionInput instanceof Map) {
return gson.toJson(actionInput);
} else {
return String.valueOf(actionInput);
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,8 @@
import static org.opensearch.ml.common.conversation.ActionConstants.ADDITIONAL_INFO_FIELD;
import static org.opensearch.ml.common.conversation.ActionConstants.AI_RESPONSE_FIELD;
import static org.opensearch.ml.common.utils.StringUtils.gson;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.extractModelResponseJson;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.*;

import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
Expand Down Expand Up @@ -657,34 +654,4 @@ private void runReAct(
client.execute(MLPredictionTaskAction.INSTANCE, request, firstListener);
}

private String outputToOutputString(Object output) throws PrivilegedActionException {
String outputString;
if (output instanceof ModelTensorOutput) {
ModelTensor outputModel = ((ModelTensorOutput) output).getMlModelOutputs().get(0).getMlModelTensors().get(0);
if (outputModel.getDataAsMap() != null) {
outputString = AccessController
.doPrivileged((PrivilegedExceptionAction<String>) () -> gson.toJson(outputModel.getDataAsMap()));
} else {
outputString = outputModel.getResult();
}
} else if (output instanceof String) {
outputString = (String) output;
} else {
outputString = AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> gson.toJson(output));
}
return outputString;
}

private String parseInputFromLLMReturn(Map<String, ?> retMap){
Object actionInput = retMap.get("action_input");
if (actionInput instanceof Map)
{
return gson.toJson(actionInput);
}
else {
return String.valueOf(actionInput);
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,10 @@ private void extractFromChatParameters(Map<String, String> parameters) {
try {
Map<String, String> chatParameters = gson.fromJson(parameters.get("input"), Map.class);
parameters.putAll(chatParameters);
} catch (Exception exception) {
log.info("fail extract parameters from key 'input' due to" + exception.getMessage());
} finally {
return ;
return;
}
}
}
Expand Down

0 comments on commit fd458c7

Please sign in to comment.