Skip to content

Commit

Permalink
add RepC option1
Browse files Browse the repository at this point in the history
Signed-off-by: xinyual <[email protected]>
  • Loading branch information
xinyual committed Feb 21, 2024
1 parent 2c23a4c commit b169b2f
Show file tree
Hide file tree
Showing 7 changed files with 283 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ public class MLAgent implements ToXContentObject, Writeable {
public static final String CREATED_TIME_FIELD = "created_time";
public static final String LAST_UPDATED_TIME_FIELD = "last_updated_time";
public static final String APP_TYPE_FIELD = "app_type";
public static final String TOOL_SELECTION_FIELD = "tool_selection";

private String name;
private String type;
Expand All @@ -49,6 +50,7 @@ public class MLAgent implements ToXContentObject, Writeable {
private List<MLToolSpec> tools;
private Map<String, String> parameters;
private MLMemorySpec memory;
private MLToolSelectionSpec toolSelectionSpec;

private Instant createdTime;
private Instant lastUpdateTime;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.agent;

import lombok.Builder;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;

import java.io.IOException;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;

@EqualsAndHashCode
@Getter
public class MLToolSelectionSpec implements ToXContentObject {
    public static final String TOOL_SELECTION_TYPE_FIELD = "type";
    public static final String TOOL_SELECTION_MODEL_ID_FIELD = "model_id";

    /** Tool-selection strategy; defaults to "original" when not supplied. */
    private String type;
    /** Id of the model used for tool selection; optional (nullable). */
    private String model_id;

    /**
     * Builds a tool-selection spec.
     *
     * @param type     selection strategy; {@code null} defaults to "original"
     * @param model_id optional model id used for selection; may be {@code null}
     */
    @Builder(toBuilder = true)
    public MLToolSelectionSpec(String type,
                               String model_id) {
        // BUG FIX: the original assigned the local parameter ("type = \"original\";")
        // inside the null branch, leaving this.type null and causing writeTo()
        // (which uses writeString, not writeOptionalString) to fail.
        if (type == null) {
            this.type = "original";
        } else {
            this.type = type;
        }
        this.model_id = model_id;
    }

    /**
     * Deserializes a spec from a stream; mirrors {@link #writeTo(StreamOutput)}.
     */
    public MLToolSelectionSpec(StreamInput input) throws IOException {
        type = input.readString();
        model_id = input.readOptionalString();
    }

    /** Serializes this spec; type is always present, model_id is optional. */
    public void writeTo(StreamOutput out) throws IOException {
        out.writeString(type);
        out.writeOptionalString(model_id);
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        if (type != null) {
            // BUG FIX: the original emitted the type value under
            // TOOL_SELECTION_MODEL_ID_FIELD, producing a duplicate "model_id"
            // key and never serializing "type".
            builder.field(TOOL_SELECTION_TYPE_FIELD, type);
        }
        if (model_id != null) {
            builder.field(TOOL_SELECTION_MODEL_ID_FIELD, model_id);
        }
        builder.endObject();
        return builder;
    }

    /**
     * Parses a spec from XContent of the form
     * {@code {"type": "...", "model_id": "..."}}; unknown fields are skipped.
     */
    public static MLToolSelectionSpec parse(XContentParser parser) throws IOException {
        String type = null;
        String model_id = null;

        ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
        while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
            String fieldName = parser.currentName();
            parser.nextToken();

            switch (fieldName) {
                case TOOL_SELECTION_TYPE_FIELD:
                    type = parser.text();
                    break;
                case TOOL_SELECTION_MODEL_ID_FIELD:
                    model_id = parser.text();
                    break;
                default:
                    parser.skipChildren();
                    break;
            }
        }
        return MLToolSelectionSpec.builder()
                .type(type)
                .model_id(model_id)
                .build();
    }

    /** Stream-based factory; equivalent to {@link #MLToolSelectionSpec(StreamInput)}. */
    public static MLToolSelectionSpec fromStream(StreamInput in) throws IOException {
        return new MLToolSelectionSpec(in);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,8 @@ public <T> T createPredictPayload(Map<String, String> parameters) {
payload = fillNullParameters(parameters, payload);
StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
payload = substitutor.replace(payload);

log.info("to llm");
log.info(payload);
if (!isJson(payload)) {
throw new IllegalArgumentException("Invalid payload: " + payload);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@
import org.opensearch.ml.common.agent.MLToolSpec;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.common.transport.MLTaskResponse;

public class AgentUtils {

Expand Down Expand Up @@ -294,4 +296,11 @@ public static List<String> getToolNames(Map<String, Tool> tools) {
}
return inputTools;
}

/**
 * Wraps a raw key/value model output in the standard tensor hierarchy
 * (ModelTensor -> ModelTensors -> ModelTensorOutput) so it can be delivered
 * as a regular {@link MLTaskResponse}. The map is exposed as the data map of
 * a single tensor named "nextOutput".
 */
public static MLTaskResponse constructNextStepResponseFromMap(Map<String, String> map) {
    ModelTensor tensor = new ModelTensor("nextOutput", null, null, null, null, null, map);
    ModelTensorOutput output = new ModelTensorOutput(List.of(new ModelTensors(List.of(tensor))));
    return MLTaskResponse.builder().output(output).build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,12 @@
import static org.opensearch.ml.common.utils.StringUtils.gson;
import static org.opensearch.ml.common.utils.StringUtils.isJson;
import static org.opensearch.ml.common.utils.StringUtils.toJson;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.DISABLE_TRACE;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_CHAT_HISTORY_PREFIX;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_PREFIX;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_SUFFIX;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.RESPONSE_FORMAT_INSTRUCTION;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_RESPONSE;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.VERBOSE;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.createTools;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.extractModelResponseJson;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMessageHistoryLimit;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMlToolSpecs;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getToolNames;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.outputToOutputString;
import static org.opensearch.ml.engine.algorithms.agent.PromptTemplate.CHAT_HISTORY_PREFIX;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.*;
import static org.opensearch.ml.engine.algorithms.agent.PromptTemplate.*;

import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
Expand Down Expand Up @@ -177,7 +167,7 @@ private void runAgent(MLAgent mlAgent, Map<String, String> params, ActionListene
Map<String, MLToolSpec> toolSpecMap = new HashMap<>();
createTools(toolFactories, params, toolSpecs, tools, toolSpecMap);

runReAct(mlAgent.getLlm(), tools, toolSpecMap, params, memory, sessionId, listener);
runReAct(mlAgent.getLlm(), tools, toolSpecMap, params, memory, sessionId, listener, mlAgent.getParameters());
}

private void runReAct(
Expand All @@ -187,7 +177,8 @@ private void runReAct(
Map<String, String> parameters,
Memory memory,
String sessionId,
ActionListener<Object> listener
ActionListener<Object> listener,
Map<String, String> agentParameters
) {
final List<String> inputTools = getToolNames(tools);
String question = parameters.get(MLAgentExecutor.QUESTION);
Expand Down Expand Up @@ -234,9 +225,14 @@ private void runReAct(
if (finalI % 2 == 0) {
MLTaskResponse llmResponse = (MLTaskResponse) output;
ModelTensorOutput tmpModelTensorOutput = (ModelTensorOutput) llmResponse.getOutput();
List<String> llmResponsePatterns = gson.fromJson(parameters.get("llm_response_pattern"), List.class);
Map<String, String> modelOutput = parseLLMOutput(tmpModelTensorOutput, llmResponsePatterns);

// List<String> llmResponsePatterns = gson.fromJson(parameters.get("llm_response_pattern"), List.class);
// Map<String, String> modelOutput = parseLLMOutput(tmpModelTensorOutput, llmResponsePatterns);
Map<String, String> modelOutput = (Map<String, String>) tmpModelTensorOutput
.getMlModelOutputs()
.get(0)
.getMlModelTensors()
.get(0)
.getDataAsMap();
String thought = String.valueOf(modelOutput.get(THOUGHT));
String action = String.valueOf(modelOutput.get(ACTION));
String actionInput = String.valueOf(modelOutput.get(ACTION_INPUT));
Expand Down Expand Up @@ -371,6 +367,8 @@ private void runReAct(
listener.onResponse(ModelTensorOutput.builder().mlModelOutputs(finalModelTensors).build());
}
} else {
getNextStep(llm, tmpParameters, parameters, (ActionListener<MLTaskResponse>) nextStepListener, tools, inputTools, scratchpadBuilder, agentParameters);
/*
ActionRequest request = new MLPredictionTaskRequest(
llm.getModelId(),
RemoteInferenceMLInput
Expand All @@ -380,6 +378,7 @@ private void runReAct(
.build()
);
client.execute(MLPredictionTaskAction.INSTANCE, request, (ActionListener<MLTaskResponse>) nextStepListener);
*/
}
}
}, e -> {
Expand All @@ -391,6 +390,8 @@ private void runReAct(
}
}

getNextStep(llm, tmpParameters, parameters, firstListener, tools, inputTools, scratchpadBuilder, agentParameters);
/*
ActionRequest request = new MLPredictionTaskRequest(
llm.getModelId(),
RemoteInferenceMLInput
Expand All @@ -400,6 +401,144 @@ private void runReAct(
.build()
);
client.execute(MLPredictionTaskAction.INSTANCE, request, firstListener);
*/

}

/**
 * Produces the next ReAct step as an MLTaskResponse delivered to {@code nextListener}.
 *
 * Two strategies, chosen via the agent-level "tool_selection" parameter
 * (a JSON object with "type" and "model_id" keys):
 * - "original" (default): one LLM call that returns thought/action/action_input together.
 * - otherwise: a three-stage pipeline — (1) LLM call to extract the thought (or final
 *   answer), (2) a dedicated tool-selection model picks the action for that thought,
 *   (3) a second LLM call fills in the action input for the chosen action.
 *
 * @param tmpParameters   mutable per-call parameters; prompt fields are overwritten here
 * @param parameters      original request parameters (source of "llm_response_pattern")
 * @param nextListener    receives the assembled step output or any failure
 * @param agentParameters agent-level config holding the optional "tool_selection" JSON
 */
private void getNextStep(
    LLMSpec llm,
    Map<String, String> tmpParameters,
    Map<String, String> parameters,
    ActionListener<MLTaskResponse> nextListener,
    Map<String, Tool> tools,
    List<String> inputTools,
    StringBuilder scratchpadBuilder,
    Map<String, String> agentParameters
) {
    // Read the optional tool-selection config; fall back to the "original"
    // single-call behavior with an empty model id.
    String toolSelectionType = "original";
    String toolSelectionModelId;
    if (agentParameters.containsKey("tool_selection")) {
        Map<String, String> toolSelectionConfig = gson.fromJson(agentParameters.get("tool_selection"), Map.class);
        toolSelectionType = toolSelectionConfig.getOrDefault("type", "original");
        toolSelectionModelId = toolSelectionConfig.getOrDefault("model_id", "");
    } else {
        toolSelectionModelId = "";
    }
    if (toolSelectionType.equals("original")) {
        // Single LLM call: parse thought/action/action_input from one response.
        ActionRequest request = new MLPredictionTaskRequest(
            llm.getModelId(),
            RemoteInferenceMLInput
                .builder()
                .algorithm(FunctionName.REMOTE)
                .inputDataset(RemoteInferenceInputDataSet.builder().parameters(tmpParameters).build())
                .build()
        );
        client.execute(MLPredictionTaskAction.INSTANCE, request, ActionListener.wrap(output -> {
            MLTaskResponse llmResponse = (MLTaskResponse) output;
            ModelTensorOutput tmpModelTensorOutput = (ModelTensorOutput) llmResponse.getOutput();
            List<String> llmResponsePatterns = gson.fromJson(parameters.get("llm_response_pattern"), List.class);
            Map<String, String> modelOutput = parseLLMOutput(tmpModelTensorOutput, llmResponsePatterns);
            nextListener.onResponse(constructNextStepResponseFromMap(modelOutput));
        }, nextListener::onFailure));
    } else {
        // Stage 1: ask the LLM only for the next thought (or the final answer).
        tmpParameters.put("prompt.format_instruction", PROMPT_FORMAT_INSTRUCTION_FOR_THOUGHT_EXTRACT);
        tmpParameters.put("prompt.suffix", PROMPT_TEMPLATE_SUFFIX_FOR_THOUGHT_EXTRACT);
        String thoughtPrompt = constructLLMPrompt(tools, parameters, inputTools, tmpParameters);
        StringSubstitutor tmpSubstitutor = new StringSubstitutor(
            Map.of(SCRATCHPAD, scratchpadBuilder.toString()),
            "${parameters.",
            "}"
        );
        tmpParameters.put("prompt", tmpSubstitutor.replace(thoughtPrompt));
        // Accumulates thought/action/action_input across the async stages.
        Map<String, String> nextStepOutput = new HashMap<>();
        ActionRequest thoughtRequest = new MLPredictionTaskRequest(
            llm.getModelId(),
            RemoteInferenceMLInput
                .builder()
                .algorithm(FunctionName.REMOTE)
                .inputDataset(RemoteInferenceInputDataSet.builder().parameters(tmpParameters).build())
                .build()
        );
        client.execute(MLPredictionTaskAction.INSTANCE, thoughtRequest, ActionListener.<MLTaskResponse>wrap(thoughtOutput -> {
            ModelTensorOutput thoughtModelTensorOutput = (ModelTensorOutput) thoughtOutput.getOutput();
            Map<String, String> thoughtMap = parseLLMOutput(
                thoughtModelTensorOutput,
                gson.fromJson(parameters.get("llm_response_pattern"), List.class)
            );
            if (thoughtMap.containsKey(FINAL_ANSWER)) {
                // The LLM already produced the final answer — short-circuit the
                // remaining stages.
                nextStepOutput.put(THOUGHT, thoughtMap.getOrDefault(THOUGHT, ""));
                nextStepOutput.put(FINAL_ANSWER, thoughtMap.get(FINAL_ANSWER));
                nextStepOutput.put(THOUGHT_RESPONSE, gson.toJson(nextStepOutput));
                nextListener.onResponse(constructNextStepResponseFromMap(nextStepOutput));
            } else {
                String thought = thoughtMap.get(THOUGHT);
                nextStepOutput.put(THOUGHT, thought);
                // Stage 2: let the dedicated tool-selection model pick the
                // action, given the thought and the tool name->type catalog.
                Map<String, String> nameToType = new HashMap<>();
                for (Map.Entry<String, Tool> entry : tools.entrySet()) {
                    nameToType.put(entry.getKey(), entry.getValue().getType());
                }
                // doPrivileged: gson serialization may need reflection permissions
                // under the plugin security manager.
                String questionString = AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> gson.toJson(List.of(thought)));
                String nameToTypeString = AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> gson.toJson(nameToType));
                ActionRequest getActionRequest = new MLPredictionTaskRequest(
                    toolSelectionModelId,
                    RemoteInferenceMLInput
                        .builder()
                        .algorithm(FunctionName.REMOTE)
                        .inputDataset(
                            RemoteInferenceInputDataSet
                                .builder()
                                .parameters(ImmutableMap.of("questions", questionString, "nameToType", nameToTypeString))
                                .build()
                        )
                        .build()
                );
                client.execute(MLPredictionTaskAction.INSTANCE, getActionRequest, ActionListener.<MLTaskResponse>wrap(actionOutput -> {
                    ModelTensorOutput actionModelTensorOutput = (ModelTensorOutput) actionOutput.getOutput();
                    String action = (String) actionModelTensorOutput
                        .getMlModelOutputs()
                        .get(0)
                        .getMlModelTensors()
                        .get(0)
                        .getDataAsMap()
                        .get("response");

                    nextStepOutput.put(ACTION, action);
                    // Stage 3: ask the LLM for the action input, restricting the
                    // prompt to the single selected tool.
                    tmpParameters.put("prompt.format_instruction", PROMPT_FORMAT_INSTRUCTION_FOR_ACTION_INPUT);
                    tmpParameters.put("prompt.suffix", PROMPT_TEMPLATE_SUFFIX_FOR_ACTION_INPUT);
                    tmpParameters.put("current_thought", thought);
                    String actionPrompt = constructLLMPrompt(tools, parameters, List.of(action), tmpParameters);
                    tmpParameters.put("prompt", tmpSubstitutor.replace(actionPrompt));
                    ActionRequest getActionInputRequest = new MLPredictionTaskRequest(
                        llm.getModelId(),
                        RemoteInferenceMLInput
                            .builder()
                            .algorithm(FunctionName.REMOTE)
                            .inputDataset(RemoteInferenceInputDataSet.builder().parameters(tmpParameters).build())
                            .build()
                    );
                    client
                        .execute(
                            MLPredictionTaskAction.INSTANCE,
                            getActionInputRequest,
                            ActionListener.<MLTaskResponse>wrap(actionInputOutput -> {
                                ModelTensorOutput actionInputModelTensorOutput = (ModelTensorOutput) actionInputOutput.getOutput();
                                String actionInput = (String) parseLLMOutput(
                                    actionInputModelTensorOutput,
                                    gson.fromJson(parameters.get("llm_response_pattern"), List.class)
                                ).get(ACTION_INPUT);
                                nextStepOutput.put(ACTION_INPUT, actionInput);
                                nextStepOutput.put(THOUGHT_RESPONSE, gson.toJson(nextStepOutput));
                                nextListener.onResponse(constructNextStepResponseFromMap(nextStepOutput));
                            }, nextListener::onFailure)
                        );
                }, nextListener::onFailure));
            }
        }, nextListener::onFailure));
    }

}

private static Map<String, String> parseLLMOutput(ModelTensorOutput tmpModelTensorOutput, List<String> llmResponsePatterns) {
Expand Down
Loading

0 comments on commit b169b2f

Please sign in to comment.