Skip to content

Commit

Permalink
Restore the ability to use both API keys and tokens in chat_azure().
Browse files Browse the repository at this point in the history
This implements the suggestion from testers of #248, who rightly pointed
out that my assumption that API keys and Entra ID credentials are
mutually exclusive was incorrect.

Unit tests are included.

Signed-off-by: Aaron Jacobs <[email protected]>
  • Loading branch information
atheriel committed Jan 16, 2025
1 parent f1117d8 commit 23f377c
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 15 deletions.
31 changes: 16 additions & 15 deletions R/provider-azure.R
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,9 @@ chat_azure <- function(endpoint = azure_endpoint(),
check_string(deployment_id)
api_version <- set_default(api_version, "2024-06-01")
turns <- normalize_turns(turns, system_prompt)
check_exclusive(api_key, token, credentials, .require = FALSE)
check_exclusive(token, credentials, .require = FALSE)
check_string(api_key, allow_null = TRUE)
api_key <- api_key %||% Sys.getenv("AZURE_OPENAI_API_KEY")
check_string(token, allow_null = TRUE)
echo <- check_echo(echo)
if (is_list(credentials)) {
Expand All @@ -63,6 +64,7 @@ chat_azure <- function(endpoint = azure_endpoint(),
endpoint = endpoint,
deployment_id = deployment_id,
api_version = api_version,
api_key = api_key,
credentials = credentials,
extra_args = api_args
)
Expand All @@ -72,21 +74,21 @@ chat_azure <- function(endpoint = azure_endpoint(),
ProviderAzure <- new_class(
"ProviderAzure",
parent = ProviderOpenAI,
constructor = function(endpoint, deployment_id, api_version, credentials,
extra_args = list()) {
constructor = function(endpoint, deployment_id, api_version, api_key,
credentials, extra_args = list()) {
new_object(
ProviderOpenAI(
base_url = paste0(endpoint, "/openai/deployments/", deployment_id),
model = deployment_id,
api_key = "",
api_key = api_key,
extra_args = extra_args
),
api_version = api_version,
credentials = credentials
)
},
properties = list(
credentials = class_function | NULL,
credentials = class_function,
api_version = prop_string()
)
)
Expand All @@ -109,11 +111,10 @@ method(chat_request, ProviderAzure) <- function(provider,
req <- req_url_query(req, `api-version` = provider@api_version)
# Note: could use req_headers_redacted() here but it requires a very new
# httr2 version.
req <- req_headers(
req,
!!!provider@credentials(),
.redact = c("api-key", "Authorization")
)
if (nchar(provider@api_key)) {
req <- req_headers(req, `api-key` = provider@api_key, .redact = "api-key")
}
req <- req_headers(req, !!!provider@credentials(), .redact = "Authorization")
req <- req_retry(req, max_tries = 2)
req <- req_error(req, body = function(resp) resp_body_json(resp)$message)

Expand Down Expand Up @@ -150,15 +151,15 @@ method(chat_request, ProviderAzure) <- function(provider,
}

default_azure_credentials <- function(api_key = NULL, token = NULL) {
api_key <- api_key %||% Sys.getenv("AZURE_OPENAI_API_KEY")
if (nchar(api_key)) {
return(function() list(`api-key` = api_key))
}

if (!is.null(token)) {
return(function() list(Authorization = paste("Bearer", token)))
}

# If we have an API key, rely on that for credentials.
if (nchar(api_key)) {
return(function() list())
}

if (is_testing()) {
testthat::skip("no Azure credentials available")
}
Expand Down
17 changes: 17 additions & 0 deletions tests/testthat/_snaps/provider-azure.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,20 @@
* retry_on_failure: FALSE
* error_body: a function

---

Code
req
Message
<httr2_request>
POST
https://ai-hwickhamai260967855527.openai.azure.com/openai/deployments/gpt-4o-mini/chat/completions?api-version=2024-06-01
Headers:
* api-key: '<REDACTED>'
* Authorization: '<REDACTED>'
Body: json encoded data
Policies:
* retry_max_tries: 2
* retry_on_failure: FALSE
* error_body: a function

13 changes: 13 additions & 0 deletions tests/testthat/test-provider-azure.R
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ test_that("Azure request headers are generated correctly", {
endpoint = endpoint,
deployment_id = deployment_id,
api_version = "2024-06-01",
api_key = "key",
credentials = default_azure_credentials("key")
)
req <- chat_request(p, FALSE, list(turn))
Expand All @@ -36,8 +37,20 @@ test_that("Azure request headers are generated correctly", {
endpoint = endpoint,
deployment_id = deployment_id,
api_version = "2024-06-01",
api_key = "",
credentials = default_azure_credentials("", "token")
)
req <- chat_request(p, FALSE, list(turn))
expect_snapshot(req)

# Both.
p <- ProviderAzure(
endpoint = endpoint,
deployment_id = deployment_id,
api_version = "2024-06-01",
api_key = "key",
credentials = default_azure_credentials("key", "token")
)
req <- chat_request(p, FALSE, list(turn))
expect_snapshot(req)
})

0 comments on commit 23f377c

Please sign in to comment.