diff --git a/reasoning/src/commonMain/kotlin/com/xebia/functional/xef/reasoning/tools/ReActAgent.kt b/reasoning/src/commonMain/kotlin/com/xebia/functional/xef/reasoning/tools/ReActAgent.kt index 25088ebda..1f8577308 100644 --- a/reasoning/src/commonMain/kotlin/com/xebia/functional/xef/reasoning/tools/ReActAgent.kt +++ b/reasoning/src/commonMain/kotlin/com/xebia/functional/xef/reasoning/tools/ReActAgent.kt @@ -19,7 +19,19 @@ class ReActAgent( private val scope: Conversation, private val tools: List, private val maxIterations: Int = 10, - private val configuration: PromptConfiguration = PromptConfiguration(temperature = 0.0) + private val configuration: PromptConfiguration = PromptConfiguration(temperature = 0.0), + private val critique: suspend ReActAgent.(String, Finish) -> Critique = + { input: String, finish: Finish -> + critiqueCall(input, finish) + }, + private val decide: suspend ReActAgent.(String, Int, ThoughtObservation) -> Decide = + { input: String, iterations: Int, thought: ThoughtObservation -> + decideCall(input, iterations, thought) + }, + private val finish: suspend ReActAgent.(String) -> Finish = { input: String -> + finishCall(input) + }, + private val runTool: suspend ReActAgent.() -> RunTool = { runToolCall() }, ) : Conversation by scope { sealed class Result { @@ -32,7 +44,7 @@ class ReActAgent( data class Finish(val result: String) : Result() } - private data class ThoughtObservation(val thought: String, val observation: String) + data class ThoughtObservation(val thought: String, val observation: String) @Serializable enum class NextStep { @@ -161,11 +173,11 @@ class ReActAgent( if (currentIteration > maxIterations) { emit(Result.MaxIterationsReached("🤷‍ Max iterations reached")) } else { - val decide = decideCall(prompt = prompt, thought = thought, iterations = currentIteration) + val decide = decide(prompt, currentIteration, thought) emit(Result.Log("🤖 I decided : ${decide.thought}")) when (decide.nextStep) { NextStep.RunTool -> { - val runTool = runToolCall() + val runTool = runTool() val tool = tools.find { it.name.equals(runTool.tool, ignoreCase = true) } if (tool == null) { emit(Result.Log("🤖 I don't know how to use the tool ${runTool.tool}")) @@ -187,8 +199,8 @@ class ReActAgent( } } NextStep.Finish -> { - val result = finishCall(prompt = prompt) - val critique = critiqueCall(prompt = prompt, finish = result) + val result = finish(prompt) + val critique = critique(prompt, result) emit(Result.Log("🤖 After critiquing the answer I decided : ${critique.thought}")) when (critique.outcome) { CompleteAnswerForUserRequest -> emit(Result.Finish(result.result))