From 158f4091294542a423b8f50b43f81cfd90c6a38d Mon Sep 17 00:00:00 2001 From: J S <49557684+svilupp@users.noreply.github.com> Date: Sun, 3 Nov 2024 14:37:48 +0000 Subject: [PATCH] Anthropic computer use (#226) --- CHANGELOG.md | 1 + Project.toml | 2 +- src/PromptingTools.jl | 3 +- src/extraction.jl | 29 +++++++++++ src/llm_anthropic.jl | 112 +++++++++++++++++++++++++++++++++++++----- src/llm_openai.jl | 40 +++++++++------ src/llm_shared.jl | 6 +++ test/extraction.jl | 18 ++++++- test/llm_anthropic.jl | 79 ++++++++++++++++++++++++++++- test/llm_openai.jl | 18 ++++++- test/llm_shared.jl | 15 ++++-- 11 files changed, 283 insertions(+), 40 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4f9858683..db38bb788 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added support for Ollama streaming with schema `OllamaSchema` (see `?StreamCallback` for more information). Schema `OllamaManaged` is NOT supported (it's legacy and will be removed in the future). - Moved the implementation of streaming callbacks to a new `StreamCallbacks` package. - Added new error types for tool execution to enable better error handling and reporting (see `?AbstractToolError`). +- Added support for Anthropic's new pre-trained tools via `ToolRef` (see `?ToolRef`), to enable the feature, use the `:computer_use` beta header (eg, `aitools(..., betas = [:computer_use])`). ### Fixed - Fixed a bug in `call_cost` where the cost was not calculated if any non-AIMessages were provided in the conversation. diff --git a/Project.toml b/Project.toml index 4efaba87a..8e18c2de4 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "PromptingTools" uuid = "670122d1-24a8-4d70-bfce-740807c42192" authors = ["J S @svilupp and contributors"] -version = "0.60.0-DEV" +version = "0.60.0" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" diff --git a/src/PromptingTools.jl b/src/PromptingTools.jl index e0e0f1895..cc7cce720 100644 --- a/src/PromptingTools.jl +++ b/src/PromptingTools.jl @@ -49,7 +49,8 @@ const RESERVED_KWARGS = [ :no_system_message, :aiprefill, :name_user, - :name_assistant + :name_assistant, + :betas ] # export replace_words, recursive_splitter, split_by_length, call_cost, auth_header # for debugging only diff --git a/src/extraction.jl b/src/extraction.jl index ce2d73958..0c2fa2514 100644 --- a/src/extraction.jl +++ b/src/extraction.jl @@ -48,6 +48,31 @@ Base.@kwdef struct Tool <: AbstractTool end Base.show(io::IO, t::AbstractTool) = dump(io, t; maxdepth = 1) +""" + ToolRef(ref::Symbol, callable::Any) + +Represents a reference to a tool with a symbolic name and a callable object (to call during tool execution). +It can be rendered with a `render` method and a prompt schema. + +# Arguments +- `ref::Symbol`: The symbolic name of the tool. +- `callable::Any`: The callable object of the tool, eg, a type or a function. + +# Examples +```julia +# Define a tool with a symbolic name and a callable object +tool = ToolRef(:computer, println) + +# Show the rendered tool signature +PT.render(PT.AnthropicSchema(), tool) +``` +""" +Base.@kwdef struct ToolRef <: AbstractTool + ref::Symbol + callable::Any = identity +end +Base.show(io::IO, t::ToolRef) = print(io, "ToolRef($(t.ref))") + ### Useful Error Types """ AbstractToolError @@ -556,6 +581,10 @@ function tool_call_signature( end return Dict(tool.name => tool) end +function tool_call_signature( + tool::ToolRef; kwargs...) + Dict(string(tool.ref) => tool) +end ## Add support for function signatures function tool_call_signature(f::Function; kwargs...) diff --git a/src/llm_anthropic.jl b/src/llm_anthropic.jl index e999a0427..b5fc83951 100644 --- a/src/llm_anthropic.jl +++ b/src/llm_anthropic.jl @@ -102,14 +102,75 @@ Renders the tool signatures into the Anthropic format. function render(schema::AbstractAnthropicSchema, tools::Vector{<:AbstractTool}; kwargs...) - tools = [Dict(:name => tool.name, - :description => isnothing(tool.description) ? "" : tool.description, - :input_schema => tool.parameters) for tool in tools] - return tools + [render(schema, tool; kwargs...) for tool in tools] +end +function render(schema::AbstractAnthropicSchema, + tool::AbstractTool; + kwargs...) + return Dict( + :name => tool.name, + :description => isnothing(tool.description) ? "" : tool.description, + :input_schema => tool.parameters + ) end """ - anthropic_extra_headers + render(schema::AbstractAnthropicSchema, + tool::ToolRef; + kwargs...) + +Renders the tool reference into the Anthropic format. + +Available tools: +- `:computer`: A tool for using the computer. +- `:str_replace_editor`: A tool for replacing text in a string. +- `:bash`: A tool for running bash commands. +""" +function render(schema::AbstractAnthropicSchema, + tool::ToolRef; + kwargs...) + ## WARNING: We ignore the tool name here, because the names are strict + rendered = if tool.ref == :computer + Dict( + "type" => "computer_20241022", + "name" => "computer", + "display_width_px" => 1024, + "display_height_px" => 768, + "display_number" => 1 + ) + elseif tool.ref == :str_replace_editor + Dict( + "type" => "text_editor_20241022", + "name" => "str_replace_editor" + ) + elseif tool.ref == :bash + Dict( + "type" => "bash_20241022", + "name" => "bash" + ) + else + throw(ArgumentError("Unknown tool reference: $(tool.ref)")) + end + return rendered +end + +""" + BETA_HEADERS_ANTHROPIC + +A vector of symbols representing the beta features to be used. + +Allowed: +- `:tools`: Enables tools in the conversation. +- `:cache`: Enables prompt caching. +- `:long_output`: Enables long outputs (up to 8K tokens) with Anthropic's Sonnet 3.5. +- `:computer_use`: Enables the use of the computer tool. +""" +const BETA_HEADERS_ANTHROPIC = [:tools, :cache, :long_output, :computer_use] + +""" + anthropic_extra_headers(; + has_tools = false, has_cache = false, has_long_output = false, + betas::Union{Nothing, Vector{Symbol}} = nothing) Adds API version and beta headers to the request. @@ -117,20 +178,31 @@ Adds API version and beta headers to the request. - `has_tools`: Enables tools in the conversation. - `has_cache`: Enables prompt caching. - `has_long_output`: Enables long outputs (up to 8K tokens) with Anthropic's Sonnet 3.5. +- `betas`: A vector of symbols representing the beta features to be used. Currently only `:computer_use`, `:long_output`, `:tools` and `:cache` are supported. + +Refer to `BETA_HEADERS_ANTHROPIC` for the allowed beta features. """ function anthropic_extra_headers(; - has_tools = false, has_cache = false, has_long_output = false) + has_tools = false, has_cache = false, has_long_output = false, + betas::Union{Nothing, Vector{Symbol}} = nothing) + global BETA_HEADERS_ANTHROPIC + betas_parsed = isnothing(betas) ? Symbol[] : betas + @assert all(b -> b in BETA_HEADERS_ANTHROPIC, betas_parsed) "Unknown beta feature: $(setdiff(betas_parsed, BETA_HEADERS_ANTHROPIC))" + ## extra_headers = ["anthropic-version" => "2023-06-01"] beta_headers = String[] - if has_tools + if has_tools || :tools in betas_parsed push!(beta_headers, "tools-2024-04-04") end - if has_cache + if has_cache || :cache in betas_parsed push!(beta_headers, "prompt-caching-2024-07-31") end - if has_long_output + if has_long_output || :long_output in betas_parsed push!(beta_headers, "max-tokens-3-5-sonnet-2024-07-15") end + if :computer_use in betas_parsed + push!(beta_headers, "computer-use-2024-10-22") + end if !isempty(beta_headers) extra_headers = [extra_headers..., "anthropic-beta" => join(beta_headers, ",")] end @@ -150,6 +222,7 @@ end stream::Bool = false, url::String = "https://api.anthropic.com/v1", cache::Union{Nothing, Symbol} = nothing, + betas::Union{Nothing, Vector{Symbol}} = nothing, kwargs...) Simple wrapper for a call to Anthropic API. @@ -165,6 +238,7 @@ Simple wrapper for a call to Anthropic API. - `stream`: A boolean indicating whether to stream the response. Defaults to `false`. - `url`: The URL of the Ollama API. Defaults to "localhost". - `cache`: A symbol representing the caching strategy to be used. Currently only `nothing` (no caching), `:system`, `:tools`,`:last` and `:all` are supported. +- `betas`: A vector of symbols representing the beta features to be used. Currently only `:tools` and `:cache` are supported. - `kwargs`: Prompt variables to be used to fill the prompt/template """ function anthropic_api( @@ -179,6 +253,7 @@ function anthropic_api( streamcallback::Any = nothing, url::String = "https://api.anthropic.com/v1", cache::Union{Nothing, Symbol} = nothing, + betas::Union{Nothing, Vector{Symbol}} = nothing, kwargs...) @assert endpoint in ["messages"] "Only 'messages' endpoint is supported." ## @@ -191,7 +266,8 @@ function anthropic_api( ## Build the headers extra_headers = anthropic_extra_headers(; has_tools = haskey(kwargs, :tools), has_cache = !isnothing(cache), - has_long_output = (max_tokens > 4096 && model in ["claude-3-5-sonnet-20240620"])) + has_long_output = (max_tokens > 4096 && model in ["claude-3-5-sonnet-20240620"]), + betas = betas) headers = auth_header( api_key; bearer = false, x_api_key = true, extra_headers) @@ -234,6 +310,7 @@ end aiprefill::Union{Nothing, AbstractString} = nothing, http_kwargs::NamedTuple = NamedTuple(), api_kwargs::NamedTuple = NamedTuple(), cache::Union{Nothing, Symbol} = nothing, + betas::Union{Nothing, Vector{Symbol}} = nothing, kwargs...) Generate an AI response based on a given prompt using the Anthropic API. @@ -259,6 +336,7 @@ Generate an AI response based on a given prompt using the Anthropic API. - `:tools`: Caches the tool definitions (and everything before them) - `:last`: Caches the last message in the conversation (and everything before it) - `:all`: Cache trigger points are inserted in all of the above places (ie, higher likelyhood of cache hit, but also slightly higher cost) +- `betas::Union{Nothing, Vector{Symbol}}`: A vector of symbols representing the beta features to be used. See `?anthropic_extra_headers` for details. - `kwargs`: Prompt variables to be used to fill the prompt/template Note: At the moment, the cache is only allowed for prompt segments over 1024 tokens (in some cases, over 2048 tokens). You'll get an error if you try to cache short prompts. @@ -351,6 +429,7 @@ function aigenerate( aiprefill::Union{Nothing, AbstractString} = nothing, http_kwargs::NamedTuple = NamedTuple(), api_kwargs::NamedTuple = NamedTuple(), cache::Union{Nothing, Symbol} = nothing, + betas::Union{Nothing, Vector{Symbol}} = nothing, kwargs...) ## global MODEL_ALIASES @@ -364,7 +443,8 @@ function aigenerate( if !dry_run time = @elapsed resp = anthropic_api( prompt_schema, conv_rendered.conversation; api_key, - conv_rendered.system, endpoint = "messages", model = model_id, streamcallback, http_kwargs, cache, + conv_rendered.system, endpoint = "messages", model = model_id, + streamcallback, http_kwargs, cache, betas, api_kwargs...) tokens_prompt = get(resp.response[:usage], :input_tokens, 0) tokens_completion = get(resp.response[:usage], :output_tokens, 0) @@ -420,6 +500,7 @@ end retries = 5, readtimeout = 120), api_kwargs::NamedTuple = NamedTuple(), cache::Union{Nothing, Symbol} = nothing, + betas::Union{Nothing, Vector{Symbol}} = nothing, kwargs...) Extract required information (defined by a struct **`return_type`**) from the provided prompt by leveraging Anthropic's function calling mode. @@ -452,6 +533,7 @@ It's effectively a light wrapper around `aigenerate` call, which requires additi - `:tools`: Caches the tool definitions (and everything before them) - `:last`: Caches the last message in the conversation (and everything before it) - `:all`: Cache trigger points are inserted in all of the above places (ie, higher likelyhood of cache hit, but also slightly higher cost) +- `betas::Union{Nothing, Vector{Symbol}}`: A vector of symbols representing the beta features to be used. See `?anthropic_extra_headers` for details. - `kwargs`: Prompt variables to be used to fill the prompt/template Note: At the moment, the cache is only allowed for prompt segments over 1024 tokens (in some cases, over 2048 tokens). You'll get an error if you try to cache short prompts. @@ -580,6 +662,7 @@ function aiextract(prompt_schema::AbstractAnthropicSchema, prompt::ALLOWED_PROMP retries = 5, readtimeout = 120), api_kwargs::NamedTuple = (; tool_choice = nothing), cache::Union{Nothing, Symbol} = nothing, + betas::Union{Nothing, Vector{Symbol}} = nothing, kwargs...) ## global MODEL_ALIASES @@ -622,7 +705,7 @@ function aiextract(prompt_schema::AbstractAnthropicSchema, prompt::ALLOWED_PROMP if !dry_run time = @elapsed resp = anthropic_api( prompt_schema, conv_rendered.conversation; api_key, - conv_rendered.system, endpoint = "messages", model = model_id, cache, http_kwargs, + conv_rendered.system, endpoint = "messages", model = model_id, cache, http_kwargs, betas, api_kwargs...) tokens_prompt = get(resp.response[:usage], :input_tokens, 0) tokens_completion = get(resp.response[:usage], :output_tokens, 0) @@ -681,6 +764,7 @@ end conversation::AbstractVector{<:AbstractMessage} = AbstractMessage[], no_system_message::Bool = false, cache::Union{Nothing, Symbol} = nothing, + betas::Union{Nothing, Vector{Symbol}} = nothing, http_kwargs::NamedTuple = (retry_non_idempotent = true, retries = 5, readtimeout = 120), api_kwargs::NamedTuple = (; @@ -706,6 +790,7 @@ Differences to `aiextract`: Can provide infinitely many tools (including Functio - `conversation`: An optional vector of `AbstractMessage` objects representing the conversation history. - `no_system_message::Bool = false`: Whether to exclude the system message from the conversation history. - `cache::Union{Nothing, Symbol} = nothing`: Whether to cache the prompt. Defaults to `nothing`. +- `betas::Union{Nothing, Vector{Symbol}} = nothing`: A vector of symbols representing the beta features to be used. See `?anthropic_extra_headers` for details. - `http_kwargs`: A named tuple of HTTP keyword arguments. - `api_kwargs`: A named tuple of API keyword arguments. Several important arguments are highlighted below: - `tool_choice`: The choice of tool mode. Can be "auto", "exact", or can depend on the provided.. Defaults to `nothing`, which translates to "auto". @@ -761,6 +846,7 @@ function aitools(prompt_schema::AbstractAnthropicSchema, prompt::ALLOWED_PROMPT_ conversation::AbstractVector{<:AbstractMessage} = AbstractMessage[], no_system_message::Bool = false, cache::Union{Nothing, Symbol} = nothing, + betas::Union{Nothing, Vector{Symbol}} = nothing, http_kwargs::NamedTuple = (retry_non_idempotent = true, retries = 5, readtimeout = 120), api_kwargs::NamedTuple = (; @@ -800,7 +886,7 @@ function aitools(prompt_schema::AbstractAnthropicSchema, prompt::ALLOWED_PROMPT_ if !dry_run time = @elapsed resp = anthropic_api( prompt_schema, conv_rendered.conversation; api_key, - conv_rendered.system, endpoint = "messages", model = model_id, cache, http_kwargs, + conv_rendered.system, endpoint = "messages", model = model_id, cache, http_kwargs, betas, api_kwargs...) tokens_prompt = get(resp.response[:usage], :input_tokens, 0) tokens_completion = get(resp.response[:usage], :output_tokens, 0) diff --git a/src/llm_openai.jl b/src/llm_openai.jl index 0b6f65075..902d2b515 100644 --- a/src/llm_openai.jl +++ b/src/llm_openai.jl @@ -98,23 +98,31 @@ function render(schema::AbstractOpenAISchema, tools::Vector{<:AbstractTool}; json_mode::Union{Nothing, Bool} = nothing, kwargs...) - output = Dict{Symbol, Any}[] - for tool in tools - rendered = Dict(:type => "function", - :function => Dict( - :parameters => tool.parameters, :name => tool.name)) - ## Add strict flag - tool.strict == true && (rendered[:function][:strict] = tool.strict) - if json_mode == true - rendered[:function][:schema] = pop!(rendered[:function], :parameters) - else - ## Add description if not in JSON mode - !isnothing(tool.description) && - (rendered[:function][:description] = tool.description) - end - push!(output, rendered) + [render(schema, tool; json_mode, kwargs...) for tool in tools] +end +function render(schema::AbstractOpenAISchema, + tool::AbstractTool; + json_mode::Union{Nothing, Bool} = nothing, + kwargs...) + rendered = Dict(:type => "function", + :function => Dict( + :parameters => tool.parameters, :name => tool.name)) + ## Add strict flag + tool.strict == true && (rendered[:function][:strict] = tool.strict) + if json_mode == true + rendered[:function][:schema] = pop!(rendered[:function], :parameters) + else + ## Add description if not in JSON mode + !isnothing(tool.description) && + (rendered[:function][:description] = tool.description) end - return output + return rendered +end +function render(schema::AbstractOpenAISchema, + tool::ToolRef; + json_mode::Union{Nothing, Bool} = nothing, + kwargs...) + throw(ArgumentError("Function `render` is not implemented for the provided schema ($(typeof(schema))) and $(typeof(tool)).")) end """ diff --git a/src/llm_shared.jl b/src/llm_shared.jl index 2acb8968b..fcb6e5702 100644 --- a/src/llm_shared.jl +++ b/src/llm_shared.jl @@ -99,6 +99,12 @@ function render(schema::AbstractPromptSchema, kwargs...) render(schema, collect(values(tools)); kwargs...) end +# For ToolRef +function render(schema::AbstractPromptSchema, + tool::AbstractTool; + kwargs...) + throw(ArgumentError("Function `render` is not implemented for the provided schema ($(typeof(schema))) and $(typeof(tool)).")) +end """ finalize_outputs(prompt::ALLOWED_PROMPT_TYPE, conv_rendered::Any, diff --git a/test/extraction.jl b/test/extraction.jl index 90a7b3a30..2c4efce8a 100644 --- a/test/extraction.jl +++ b/test/extraction.jl @@ -4,7 +4,7 @@ using PromptingTools: tool_call_signature, set_properties_strict!, update_field_descriptions!, generate_struct using PromptingTools: Tool, isabstracttool, execute_tool, parse_tool, get_arg_names, get_arg_types, get_method, get_function, remove_field!, - tool_call_signature + tool_call_signature, ToolRef using PromptingTools: AbstractToolError, ToolNotFoundError, ToolExecutionError, ToolGenericError @@ -77,6 +77,17 @@ end @test isabstracttool(tool) == true @test isabstracttool(tool_struct) == true @test isabstracttool(my_test_function) == false + + ## ToolRef + tool = ToolRef(:computer, println) + @test tool isa ToolRef + @test tool.ref == :computer + @test tool.callable == println + io = IOBuffer() + show(io, tool) + output = String(take!(io)) + @test occursin("ToolRef", output) + @test occursin("computer", output) end @testset "has_null_type" begin @@ -744,6 +755,11 @@ end @test tool2.parameters["properties"]["age"]["type"] == "integer" @test tool2.parameters["properties"]["height"]["type"] == "integer" @test tool2.parameters["properties"]["weight"]["type"] == "number" + + ## ToolRef + tool = ToolRef(:computer, println) + tool_map = tool_call_signature(tool) + @test tool_map == Dict("computer" => tool) end @testset "parse_tool" begin diff --git a/test/llm_anthropic.jl b/test/llm_anthropic.jl index 9008b04e8..036f16a78 100644 --- a/test/llm_anthropic.jl +++ b/test/llm_anthropic.jl @@ -3,7 +3,7 @@ using PromptingTools: AIMessage, SystemMessage, AbstractMessage using PromptingTools: UserMessage, UserMessageWithImages, DataMessage, AIToolRequest, ToolMessage, Tool using PromptingTools: call_cost, anthropic_api, function_call_signature, - anthropic_extra_headers + anthropic_extra_headers, ToolRef, BETA_HEADERS_ANTHROPIC @testset "render-Anthropic" begin schema = AnthropicSchema() @@ -271,6 +271,41 @@ end rendered = render(schema, tool_map) @test length(rendered) == 2 @test Set(t[:name] for t in rendered) == Set(["get_weather", "get_time"]) + + ## ToolRef + schema = AnthropicSchema() + + # Test computer tool rendering + computer_tool = ToolRef(ref = :computer) + rendered = render(schema, computer_tool) + @test rendered["type"] == "computer_20241022" + @test rendered["name"] == "computer" + @test rendered["display_width_px"] == 1024 + @test rendered["display_height_px"] == 768 + @test rendered["display_number"] == 1 + + # Test text editor tool rendering + editor_tool = ToolRef(ref = :str_replace_editor) + rendered = render(schema, editor_tool) + @test rendered["type"] == "text_editor_20241022" + @test rendered["name"] == "str_replace_editor" + + # Test bash tool rendering + bash_tool = ToolRef(ref = :bash) + rendered = render(schema, bash_tool) + @test rendered["type"] == "bash_20241022" + @test rendered["name"] == "bash" + + # Test invalid tool reference + @test_throws ArgumentError render(schema, ToolRef(ref = :invalid_tool)) + + # Test rendering multiple tool refs + tools = [computer_tool, editor_tool, bash_tool] + rendered = render(schema, tools) + @test length(rendered) == 3 + @test rendered[1]["name"] == "computer" + @test rendered[2]["name"] == "str_replace_editor" + @test rendered[3]["name"] == "bash" end @testset "anthropic_extra_headers" begin @@ -295,6 +330,48 @@ end "anthropic-version" => "2023-06-01", "anthropic-beta" => "tools-2024-04-04,prompt-caching-2024-07-31,max-tokens-3-5-sonnet-2024-07-15" ] + + # Test with betas + @test anthropic_extra_headers(betas = [:tools]) == [ + "anthropic-version" => "2023-06-01", + "anthropic-beta" => "tools-2024-04-04" + ] + + @test anthropic_extra_headers(betas = [:cache]) == [ + "anthropic-version" => "2023-06-01", + "anthropic-beta" => "prompt-caching-2024-07-31" + ] + + @test anthropic_extra_headers(betas = [:long_output]) == [ + "anthropic-version" => "2023-06-01", + "anthropic-beta" => "max-tokens-3-5-sonnet-2024-07-15" + ] + + @test anthropic_extra_headers(betas = [:computer_use]) == [ + "anthropic-version" => "2023-06-01", + "anthropic-beta" => "computer-use-2024-10-22" + ] + + # Test multiple betas + @test anthropic_extra_headers(betas = [:tools, :cache, :computer_use]) == [ + "anthropic-version" => "2023-06-01", + "anthropic-beta" => "tools-2024-04-04,prompt-caching-2024-07-31,computer-use-2024-10-22" + ] + + # Test all betas + @test anthropic_extra_headers(betas = BETA_HEADERS_ANTHROPIC) == [ + "anthropic-version" => "2023-06-01", + "anthropic-beta" => "tools-2024-04-04,prompt-caching-2024-07-31,max-tokens-3-5-sonnet-2024-07-15,computer-use-2024-10-22" + ] + + # Test invalid beta + @test_throws AssertionError anthropic_extra_headers(betas = [:invalid_beta]) + + # Test mixing has_* flags with betas + @test anthropic_extra_headers(has_tools = true, betas = [:cache]) == [ + "anthropic-version" => "2023-06-01", + "anthropic-beta" => "tools-2024-04-04,prompt-caching-2024-07-31" + ] end @testset "anthropic_api" begin diff --git a/test/llm_openai.jl b/test/llm_openai.jl index 4439a1a39..dd514b719 100644 --- a/test/llm_openai.jl +++ b/test/llm_openai.jl @@ -1,7 +1,7 @@ using PromptingTools: TestEchoOpenAISchema, render, OpenAISchema, role4render using PromptingTools: AIMessage, SystemMessage, AbstractMessage using PromptingTools: UserMessage, UserMessageWithImages, DataMessage, AIToolRequest, - ToolMessage, Tool + ToolMessage, Tool, ToolRef using PromptingTools: CustomProvider, CustomOpenAISchema, MistralOpenAISchema, MODEL_EMBEDDING, MODEL_IMAGE_GENERATION @@ -200,7 +200,7 @@ using PromptingTools: pick_tokenizer, OPENAI_TOKEN_IDS_GPT35_GPT4, OPENAI_TOKEN_ messages = [ SystemMessage("System message"), UserMessage("User message"), - AIToolRequest(;content="content") + AIToolRequest(; content = "content") ] conversation = render(schema, messages) expected_output = Dict{String, Any}[ @@ -323,6 +323,20 @@ end rendered = render(schema, [strict_tool]) @test rendered[1][:function][:strict] == true + + ## ToolRef rendering + schema = OpenAISchema() + + # Test that rendering ToolRef throws ArgumentError + tool = ToolRef(ref = :computer) + @test_throws ArgumentError render(schema, tool) + + # Test with json_mode=true + @test_throws ArgumentError render(schema, tool; json_mode = true) + + # Test with multiple tools + tools = [ToolRef(ref = :computer), ToolRef(ref = :str_replace_editor)] + @test_throws ArgumentError render(schema, tools) end @testset "OpenAI.build_url,OpenAI.auth_header" begin diff --git a/test/llm_shared.jl b/test/llm_shared.jl index d604bb56a..50951bda4 100644 --- a/test/llm_shared.jl +++ b/test/llm_shared.jl @@ -1,7 +1,7 @@ using PromptingTools: render, NoSchema, AbstractPromptSchema using PromptingTools: AIMessage, SystemMessage, AbstractMessage, AbstractChatMessage using PromptingTools: UserMessage, UserMessageWithImages, DataMessage, AIToolRequest, - ToolMessage + ToolMessage, ToolRef using PromptingTools: finalize_outputs, role4render @testset "render-NoSchema" begin @@ -214,11 +214,16 @@ using PromptingTools: finalize_outputs, role4render [Tool(; name = "f", description = "f", callable = () -> nothing)]) ## different ways to enter tools for rendering - opt1=render(OpenAISchema(), - [Tool(; name = "f", description = "f", callable = () -> nothing)]) - opt2=render(OpenAISchema(), - Dict("f" => Tool(; name = "f", description = "f", callable = () -> nothing))) + opt1 = render(OpenAISchema(), + [Tool(; name = "f", description = "f", callable = () -> nothing)]) + opt2 = render(OpenAISchema(), + Dict("f" => Tool(; name = "f", description = "f", callable = () -> nothing))) @test opt1 == opt2 + + ## ToolRef + schema = NoSchema() + tool = ToolRef(:computer, println) + @test_throws ArgumentError render(schema, tool) end @testset "finalize_outputs" begin