Skip to content

Commit

Permalink
Allow overriding React agent steps (#515)
Browse files Browse the repository at this point in the history
  • Loading branch information
raulraja authored and gerson24 committed Oct 31, 2023
1 parent 3584319 commit aa50881
Showing 1 changed file with 18 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,19 @@ class ReActAgent(
private val scope: Conversation,
private val tools: List<Tool>,
private val maxIterations: Int = 10,
private val configuration: PromptConfiguration = PromptConfiguration(temperature = 0.0)
private val configuration: PromptConfiguration = PromptConfiguration(temperature = 0.0),
private val critique: suspend ReActAgent.(String, Finish) -> Critique =
{ input: String, finish: Finish ->
critiqueCall(input, finish)
},
private val decide: suspend ReActAgent.(String, Int, ThoughtObservation) -> Decide =
{ input: String, iterations: Int, thought: ThoughtObservation ->
decideCall(input, iterations, thought)
},
private val finish: suspend ReActAgent.(String) -> Finish = { input: String ->
finishCall(input)
},
private val runTool: suspend ReActAgent.() -> RunTool = { runToolCall() },
) : Conversation by scope {

sealed class Result {
Expand All @@ -32,7 +44,7 @@ class ReActAgent(
data class Finish(val result: String) : Result()
}

private data class ThoughtObservation(val thought: String, val observation: String)
data class ThoughtObservation(val thought: String, val observation: String)

@Serializable
enum class NextStep {
Expand Down Expand Up @@ -161,11 +173,11 @@ class ReActAgent(
if (currentIteration > maxIterations) {
emit(Result.MaxIterationsReached("🤷‍ Max iterations reached"))
} else {
val decide = decideCall(prompt = prompt, thought = thought, iterations = currentIteration)
val decide = decide(prompt, currentIteration, thought)
emit(Result.Log("🤖 I decided : ${decide.thought}"))
when (decide.nextStep) {
NextStep.RunTool -> {
val runTool = runToolCall()
val runTool = runTool()
val tool = tools.find { it.name.equals(runTool.tool, ignoreCase = true) }
if (tool == null) {
emit(Result.Log("🤖 I don't know how to use the tool ${runTool.tool}"))
Expand All @@ -187,8 +199,8 @@ class ReActAgent(
}
}
NextStep.Finish -> {
val result = finishCall(prompt = prompt)
val critique = critiqueCall(prompt = prompt, finish = result)
val result = finish(prompt)
val critique = critique(prompt, result)
emit(Result.Log("🤖 After critiquing the answer I decided : ${critique.thought}"))
when (critique.outcome) {
CompleteAnswerForUserRequest -> emit(Result.Finish(result.result))
Expand Down

0 comments on commit aa50881

Please sign in to comment.