Skip to content

Commit

Permalink
Add a more flexible credentials mechanism for chat_azure().
Browse files Browse the repository at this point in the history
In lieu of adding support for Azure authentication packages directly,
this commit adds a mechanism that at least allows them to be used and
refreshed manually (see #195 and #196): a `credentials` argument that
takes a function, similar to what we have for `chat_cortex()` today.

The `credentials` function is called on every request to Azure, making
it possible to refresh tokens that have expired prior to their use.

I also did some internal refactoring of the `ProviderAzure` class in the
process, and removed the need to set `api_key = ""` to use token-based
authentication.

Unit tests are included.

Signed-off-by: Aaron Jacobs <[email protected]>
  • Loading branch information
atheriel committed Jan 10, 2025
1 parent 3e440c2 commit 6247cfc
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 30 deletions.
86 changes: 62 additions & 24 deletions R/provider-azure.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,16 @@ NULL
#' value of the `AZURE_OPENAI_ENDPOINT` envinronment variable.
#' @param deployment_id Deployment id for the model you want to use.
#' @param api_version The API version to use.
#' @param api_key The API key to use for authentication. You generally should
#' not supply this directly, but instead set the `AZURE_OPENAI_API_KEY` environment
#' variable.
#' @param token Azure token for authentication. This is typically not required for
#' Azure OpenAI API calls, but can be used if your setup requires it.
#' @param api_key An API key to use for authentication. You generally should not
#' supply this directly, but instead set the `AZURE_OPENAI_API_KEY`
#' environment variable.
#' @param token A literal Azure token to use for authentication.
#' @param credentials A list of authentication headers to pass into
#' [`httr2::req_headers()`], a function that returns them, or `NULL` to use
#' `token` or `api_key` to generate these headers instead. This is an escape
#' hatch that allows users to incorporate Azure credentials generated by other
#' packages into \pkg{ellmer}, or to manage the lifetime of credentials that
#' need to be refreshed.
#' @inheritParams chat_openai
#' @inherit chat_openai return
#' @export
Expand All @@ -34,46 +39,59 @@ chat_azure <- function(endpoint = azure_endpoint(),
api_version = NULL,
system_prompt = NULL,
turns = NULL,
api_key = azure_key(),
api_key = NULL,
token = NULL,
credentials = NULL,
api_args = list(),
echo = c("none", "text", "all")) {
check_string(endpoint)
check_string(deployment_id)
api_version <- set_default(api_version, "2024-06-01")
turns <- normalize_turns(turns, system_prompt)
check_exclusive(api_key, token, credentials, .require = FALSE)
check_string(api_key, allow_null = TRUE)
check_string(token, allow_null = TRUE)
echo <- check_echo(echo)

base_url <- paste0(endpoint, "/openai/deployments/", deployment_id)
if (is_list(credentials)) {
static_credentials <- force(credentials)
credentials <- function() static_credentials
}
check_function(credentials, allow_null = TRUE)
credentials <- credentials %||% default_azure_credentials(api_key, token)

provider <- ProviderAzure(
base_url = base_url,
endpoint = endpoint,
model = deployment_id,
deployment_id = deployment_id,
api_version = api_version,
token = token,
extra_args = api_args,
api_key = api_key
credentials = credentials,
extra_args = api_args
)
Chat$new(provider = provider, turns = turns, echo = echo)
}

ProviderAzure <- new_class(
"ProviderAzure",
parent = ProviderOpenAI,
constructor = function(endpoint, deployment_id, api_version, credentials,
extra_args = list()) {
new_object(
ProviderOpenAI(
base_url = paste0(endpoint, "/openai/deployments/", deployment_id),
model = deployment_id,
api_key = "",
extra_args = extra_args
),
api_version = api_version,
credentials = credentials
)
},
properties = list(
api_key = prop_string(),
token = prop_string(allow_null = TRUE),
endpoint = prop_string(),
credentials = class_function | NULL,
api_version = prop_string()
)
)

# https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/switching-endpoints#api-key
azure_key <- function() {
key_get("AZURE_OPENAI_API_KEY")
}

azure_endpoint <- function() {
key_get("AZURE_OPENAI_ENDPOINT")
}
Expand All @@ -89,10 +107,13 @@ method(chat_request, ProviderAzure) <- function(provider,
req <- request(provider@base_url)
req <- req_url_path_append(req, "/chat/completions")
req <- req_url_query(req, `api-version` = provider@api_version)
req <- req_headers(req, `api-key` = provider@api_key, .redact = "api-key")
if (!is.null(provider@token)) {
req <- req_auth_bearer_token(req, provider@token)
}
# Note: could use req_headers_redacted() here but it requires a very new
# httr2 version.
req <- req_headers(
req,
!!!provider@credentials(),
.redact = c("api-key", "Authorization")
)
req <- req_retry(req, max_tries = 2)
req <- req_error(req, body = function(resp) resp_body_json(resp)$message)

Expand Down Expand Up @@ -127,3 +148,20 @@ method(chat_request, ProviderAzure) <- function(provider,

req
}

default_azure_credentials <- function(api_key = NULL, token = NULL) {
api_key <- api_key %||% Sys.getenv("AZURE_OPENAI_API_KEY")
if (nchar(api_key)) {
return(function() list(`api-key` = api_key))
}

if (!is.null(token)) {
return(function() list(Authorization = paste("Bearer", token)))
}

if (is_testing()) {
testthat::skip("no Azure credentials available")
}

cli::cli_abort("No Azure credentials are available.")
}
19 changes: 13 additions & 6 deletions man/chat_azure.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

32 changes: 32 additions & 0 deletions tests/testthat/_snaps/provider-azure.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Azure request headers are generated correctly

Code
req
Message
<httr2_request>
POST
https://ai-hwickhamai260967855527.openai.azure.com/openai/deployments/gpt-4o-mini/chat/completions?api-version=2024-06-01
Headers:
* api-key: '<REDACTED>'
Body: json encoded data
Policies:
* retry_max_tries: 2
* retry_on_failure: FALSE
* error_body: a function

---

Code
req
Message
<httr2_request>
POST
https://ai-hwickhamai260967855527.openai.azure.com/openai/deployments/gpt-4o-mini/chat/completions?api-version=2024-06-01
Headers:
* Authorization: '<REDACTED>'
Body: json encoded data
Policies:
* retry_max_tries: 2
* retry_on_failure: FALSE
* error_body: a function

29 changes: 29 additions & 0 deletions tests/testthat/test-provider-azure.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,32 @@ test_that("can make simple request", {
expect_match(resp, "2")
expect_equal(chat$last_turn()@tokens, c(44, 1))
})

test_that("Azure request headers are generated correctly", {
turn <- Turn(
role = "user",
contents = list(ContentText("What is 1 + 1?"))
)
endpoint <- "https://ai-hwickhamai260967855527.openai.azure.com"
deployment_id <- "gpt-4o-mini"

# API key.
p <- ProviderAzure(
endpoint = endpoint,
deployment_id = deployment_id,
api_version = "2024-06-01",
credentials = default_azure_credentials("key")
)
req <- chat_request(p, FALSE, list(turn))
expect_snapshot(req)

# Token.
p <- ProviderAzure(
endpoint = endpoint,
deployment_id = deployment_id,
api_version = "2024-06-01",
credentials = default_azure_credentials("", "token")
)
req <- chat_request(p, FALSE, list(turn))
expect_snapshot(req)
})

0 comments on commit 6247cfc

Please sign in to comment.