Skip to content

Commit

Permalink
Improvement to groq
Browse files Browse the repository at this point in the history
Fixes #207
  • Loading branch information
hadley committed Dec 17, 2024
1 parent eecd7df commit c9ebdc6
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 8 deletions.
100 changes: 94 additions & 6 deletions R/provider-groq.R
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
#' @include provider-openai.R
NULL

#' Chat with a model hosted on Groq
#'
#' @description
Expand All @@ -17,20 +20,105 @@ chat_groq <- function(system_prompt = NULL,
seed = NULL,
api_args = list(),
echo = NULL) {

turns <- normalize_turns(turns, system_prompt)
model <- set_default(model, "llama3-8b-8192")
echo <- check_echo(echo)

if (is_testing() && is.null(seed)) {
seed <- seed %||% 1014
}

chat_openai(
system_prompt = system_prompt,
turns = turns,
provider <- ProviderGroq(
base_url = base_url,
api_key = api_key,
model = model,
seed = seed,
api_args = api_args,
echo = echo
extra_args = api_args,
api_key = api_key
)
Chat$new(provider = provider, turns = turns, echo = echo)
}

ProviderGroq <- new_class("ProviderGroq", parent = ProviderOpenAI)


# method(value_turn, ProviderGroq) <- function(provider, result, has_type = FALSE) {
# if (has_name(result$choices[[1]], "delta")) { # streaming
# message <- result$choices[[1]]$delta
# } else {
# message <- result$choices[[1]]$message
# }

# if (has_type) {
# json <- jsonlite::parse_json(message$content[[1]])
# content <- list(ContentJson(json))
# } else {
# content <- lapply(message$content, as_content)
# }

# if (has_name(message, "tool_calls")) {
# calls <- lapply(message$tool_calls, function(call) {
# name <- call$`function`$name
# # TODO: record parsing error
# args <- jsonlite::parse_json(call$`function`$arguments)
# ContentToolRequest(name = name, arguments = args, id = call$id)
# })
# content <- c(content, calls)
# }
# tokens <- c(
# result$usage$prompt_tokens %||% NA_integer_,
# result$usage$completion_tokens %||% NA_integer_
# )
# tokens_log(paste0("OpenAI-", provider@base_url), tokens)

# Turn(message$role, content, json = result, tokens = tokens)
# }


method(as_json, list(ProviderGroq, Turn)) <- function(provider, x) {
if (x@role == "assistant") {
# Tool requests come out of content and go into own argument
is_tool <- map_lgl(x@contents, S7_inherits, ContentToolRequest)
tool_calls <- as_json(provider, x@contents[is_tool])

# Grok contents is just a string. Hopefully it never sends back more
# than a single text response.
content <- x@contents[!is_tool][[1]]@text

list(
compact(list(role = "assistant", content = content, tool_calls = tool_calls))
)
} else {
as_json(super(provider, ProviderOpenAI), x)
}
}

# method(as_json, list(ProviderGroq, TypeObject)) <- function(provider, x) {
# if (x@additional_properties) {
# cli::cli_abort("{.arg .additional_properties} not supported for Groq.")
# }
# required <- map_lgl(x@properties, function(prop) prop@required)

# compact(list(
# type = "object",
# description = x@description,
# properties = as_json(provider, x@properties),
# required = as.list(names2(x@properties)[required])
# ))
# }


# method(as_json, list(ProviderGroq, ToolDef)) <- function(provider, x) {
# list(
# type = "function",
# "function" = compact(list(
# name = x@name,
# description = x@description,
# parameters = as_json(provider, x@arguments)
# ))
# )
# }

groq_key <- function() {
key_get("GROQ_API_KEY")
}
4 changes: 2 additions & 2 deletions R/provider-openai.R
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ method(chat_request, ProviderOpenAI) <- function(provider,
stream = stream,
stream_options = if (stream) list(include_usage = TRUE),
tools = tools,
tool_choice = "auto",
response_format = response_format,
!!!extra_args
))
Expand Down Expand Up @@ -298,7 +299,6 @@ method(as_json, list(ProviderOpenAI, TypeObject)) <- function(provider, x) {
type = "object",
description = x@description %||% "",
properties = properties,
required = as.list(names),
additionalProperties = x@additional_properties
required = as.list(names)
)
}

0 comments on commit c9ebdc6

Please sign in to comment.