Skip to content

Commit

Permalink
feat: support azure-openai (sigoden#166)
Browse files Browse the repository at this point in the history
  • Loading branch information
sigoden authored and rooct committed Nov 30, 2023
1 parent e2af540 commit 8ac400b
Show file tree
Hide file tree
Showing 4 changed files with 217 additions and 3 deletions.
10 changes: 9 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,21 @@ clients: # Setup AIs
api_key: sk-xxx # OpenAI api key, alternative to OPENAI_API_KEY
organization_id: org-xxx # Organization ID. Optional

# See https://learn.microsoft.com/en-us/azure/ai-services/openai/chatgpt-quickstart
- type: azure-openai # Azure OpenAI configuration
api_base: https://RESOURCE.openai.azure.com # Azure OpenAI base URL
api_key: xxx # Azure OpenAI API key, alternative to AZURE_OPENAI_KEY
models: # Support models
- name: MyGPT4 # Model deployment name
max_tokens: 8192

# See https://github.com/go-skynet/LocalAI
- type: localai # LocalAI configuration
url: http://localhost:8080/v1/chat/completions # LocalAI api server
api_key: xxx # API key, alternative to LOCALAI_API_KEY
models: # Support models
- name: gpt4all-j
max_tokens: 4096
max_tokens: 8192
```
> You can use `.info` to view the current configuration file path and roles file path.
Expand Down
12 changes: 11 additions & 1 deletion config.example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,22 @@ clients: # Setup AIs
proxy: socks5://127.0.0.1:1080
connect_timeout: 10

# See https://learn.microsoft.com/en-us/azure/ai-services/openai/chatgpt-quickstart
- type: azure-openai # Azure OpenAI configuration
api_base: https://RESOURCE.openai.azure.com # Azure OpenAI base URL
api_key: xxx # Azure OpenAI API key, alternative to AZURE_OPENAI_KEY
models: # Support models
- name: MyGPT4 # Model deployment name
max_tokens: 8192
proxy: socks5://127.0.0.1:1080 # Set proxy server. Optional
connect_timeout: 10 # Set a timeout in seconds for connect to gpt. Optional

# See https://github.com/go-skynet/LocalAI
- type: localai # LocalAI configuration
url: http://localhost:8080/v1/chat/completions # LocalAI api server
api_key: xxx # API key, alternative to LOCALAI_API_KEY
models: # Support models
- name: gpt4all-j
max_tokens: 4096
max_tokens: 8192
proxy: socks5://127.0.0.1:1080
connect_timeout: 10
184 changes: 184 additions & 0 deletions src/client/azure_openai.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
use super::openai::{openai_send_message, openai_send_message_streaming};
use super::{set_proxy, Client, ClientConfig, ModelInfo};

use crate::config::SharedConfig;
use crate::repl::ReplyStreamHandler;

use anyhow::{anyhow, Context, Result};
use async_trait::async_trait;
use inquire::{Confirm, Text};
use reqwest::{Client as ReqwestClient, RequestBuilder};
use serde::Deserialize;
use serde_json::json;
use std::env;
use std::time::Duration;

/// Chat client targeting the Azure OpenAI chat-completions endpoint.
#[allow(clippy::module_name_repetitions)]
#[derive(Debug)]
pub struct AzureOpenAIClient {
    // Application-wide shared config (selected model, message history, temperature).
    global_config: SharedConfig,
    // The `azure-openai` entry parsed from the `clients` section of the config file.
    local_config: AzureOpenAIConfig,
    // Info for the currently selected model; `model_info.name` is the Azure
    // deployment name used in the request URL.
    model_info: ModelInfo,
}

/// Deserialized `azure-openai` client section of the user's config file.
#[derive(Debug, Clone, Deserialize)]
pub struct AzureOpenAIConfig {
    /// Azure resource base URL, e.g. `https://RESOURCE.openai.azure.com`.
    pub api_base: String,
    /// API key; when absent, the `AZURE_OPENAI_KEY` env var is used instead.
    pub api_key: Option<String>,
    /// Model deployments available under this Azure resource.
    pub models: Vec<AzureOpenAIModel>,
    /// Optional proxy URL applied when building the HTTP client.
    pub proxy: Option<String>,
    /// Connect timeout in seconds for the HTTP client (defaults to 10 when unset).
    pub connect_timeout: Option<u64>,
}

/// One entry under `models:` in the `azure-openai` config section.
#[derive(Debug, Clone, Deserialize)]
pub struct AzureOpenAIModel {
    // Azure deployment name; becomes the `deployments/{name}` URL path segment.
    name: String,
    // Maximum number of tokens supported by this deployment.
    max_tokens: usize,
}

#[async_trait]
impl Client for AzureOpenAIClient {
    /// Expose the shared global configuration to the generic client machinery.
    fn get_config(&self) -> &SharedConfig {
        &self.global_config
    }

    /// Send one non-streaming chat request and return the full reply text.
    async fn send_message_inner(&self, content: &str) -> Result<String> {
        openai_send_message(self.request_builder(content, false)?).await
    }

    /// Send a streaming chat request, forwarding reply chunks to `handler`.
    async fn send_message_streaming_inner(
        &self,
        content: &str,
        handler: &mut ReplyStreamHandler,
    ) -> Result<()> {
        openai_send_message_streaming(self.request_builder(content, true)?, handler).await
    }
}

impl AzureOpenAIClient {
    /// Construct the Azure OpenAI client for the currently selected model.
    ///
    /// Returns `None` when the active model belongs to another client type or
    /// when the matching config entry is not an `azure-openai` entry.
    pub fn init(global_config: SharedConfig) -> Option<Box<dyn Client>> {
        let model_info = global_config.read().model_info.clone();
        if model_info.client != AzureOpenAIClient::name() {
            return None;
        }
        let local_config = {
            if let ClientConfig::AzureOpenAI(c) = &global_config.read().clients[model_info.index] {
                c.clone()
            } else {
                return None;
            }
        };
        Some(Box::new(Self {
            global_config,
            local_config,
            model_info,
        }))
    }

    /// Identifier for this client type in config files and CLI output.
    pub fn name() -> &'static str {
        "azure-openai"
    }

    /// List all model deployments declared by the config entry at `index` in
    /// the `clients` array.
    pub fn list_models(local_config: &AzureOpenAIConfig, index: usize) -> Vec<ModelInfo> {
        local_config
            .models
            .iter()
            .map(|v| ModelInfo::new(Self::name(), &v.name, v.max_tokens, index))
            .collect()
    }

    /// Interactively build a YAML `clients` snippet for an Azure OpenAI setup.
    ///
    /// # Errors
    /// Fails when a prompt is aborted, or when the entered max-tokens value is
    /// not a positive integer (validated here so we never write a config that
    /// would fail to deserialize later).
    pub fn create_config() -> Result<String> {
        let mut client_config = format!("clients:\n - type: {}\n", Self::name());

        let api_base = Text::new("api_base:")
            .prompt()
            .map_err(|_| anyhow!("An error happened when asking for api base, try again later."))?;

        client_config.push_str(&format!(" api_base: {api_base}\n"));

        // Only ask for a key when the env-var fallback isn't already set.
        if env::var("AZURE_OPENAI_KEY").is_err() {
            let api_key = Text::new("API key:").prompt().map_err(|_| {
                anyhow!("An error happened when asking for api key, try again later.")
            })?;

            client_config.push_str(&format!(" api_key: {api_key}\n"));
        }

        let model_name = Text::new("Model Name:").prompt().map_err(|_| {
            anyhow!("An error happened when asking for model name, try again later.")
        })?;

        let max_tokens = Text::new("Max tokens:").prompt().map_err(|_| {
            anyhow!("An error happened when asking for max tokens, try again later.")
        })?;
        // Validate up front: `max_tokens` deserializes into a `usize`, so a
        // non-numeric answer would yield a broken config file.
        let max_tokens: usize = max_tokens
            .trim()
            .parse()
            .map_err(|_| anyhow!("Max tokens must be a positive integer, try again later."))?;

        let ans = Confirm::new("Use proxy?")
            .with_default(false)
            .prompt()
            .map_err(|_| anyhow!("Not finish questionnaire, try again later."))?;

        if ans {
            let proxy = Text::new("Set proxy:").prompt().map_err(|_| {
                anyhow!("An error happened when asking for proxy, try again later.")
            })?;
            client_config.push_str(&format!(" proxy: {proxy}\n"));
        }

        client_config.push_str(&format!(
            " models:\n - name: {model_name}\n max_tokens: {max_tokens}\n"
        ));

        Ok(client_config)
    }

    /// Build the HTTP request for one chat-completions call.
    ///
    /// `stream` selects server-sent-event streaming. The API key is taken from
    /// the config first, then from the `AZURE_OPENAI_KEY` env var.
    ///
    /// # Errors
    /// Fails when message building fails, the proxy is invalid, or the HTTP
    /// client cannot be constructed.
    fn request_builder(&self, content: &str, stream: bool) -> Result<RequestBuilder> {
        // Take a single read lock instead of re-acquiring it per setting.
        let (messages, temperature) = {
            let config = self.global_config.read();
            (config.build_messages(content)?, config.get_temperature())
        };

        // `body` is created from an object literal, so index assignment below
        // always inserts into the top-level map.
        let mut body = json!({
            "messages": messages,
        });
        if let Some(v) = temperature {
            body["temperature"] = json!(v);
        }
        if stream {
            body["stream"] = json!(true);
        }

        let client = {
            let mut builder = ReqwestClient::builder();
            builder = set_proxy(builder, &self.local_config.proxy)?;
            let timeout = Duration::from_secs(self.local_config.connect_timeout.unwrap_or(10));
            builder
                .connect_timeout(timeout)
                .build()
                .with_context(|| "Failed to build client")?
        };

        // Normalize trailing slashes so the joined URL never doubles them.
        let api_base = self.local_config.api_base.trim_end_matches('/');
        let url = format!(
            "{api_base}/openai/deployments/{}/chat/completions?api-version=2023-05-15",
            self.model_info.name
        );

        let mut builder = client.post(url);

        // Config key takes precedence over the environment variable.
        let api_key = self
            .local_config
            .api_key
            .clone()
            .or_else(|| env::var("AZURE_OPENAI_KEY").ok());
        if let Some(api_key) = api_key {
            builder = builder.header("api-key", api_key);
        }
        builder = builder.json(&body);

        Ok(builder)
    }
}
14 changes: 13 additions & 1 deletion src/client/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
pub mod azure_openai;
pub mod localai;
pub mod openai;

use self::{
azure_openai::{AzureOpenAIClient, AzureOpenAIConfig},
localai::LocalAIConfig,
openai::{OpenAIClient, OpenAIConfig},
};
Expand All @@ -27,6 +29,8 @@ pub enum ClientConfig {
OpenAI(OpenAIConfig),
#[serde(rename = "localai")]
LocalAI(LocalAIConfig),
#[serde(rename = "azure-openai")]
AzureOpenAI(AzureOpenAIConfig),
}

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -128,6 +132,7 @@ pub trait Client {
pub fn init_client(config: SharedConfig) -> Result<Box<dyn Client>> {
OpenAIClient::init(config.clone())
.or_else(|| LocalAIClient::init(config.clone()))
.or_else(|| AzureOpenAIClient::init(config.clone()))
.ok_or_else(|| {
let model_info = config.read().model_info.clone();
anyhow!(
Expand All @@ -139,14 +144,20 @@ pub fn init_client(config: SharedConfig) -> Result<Box<dyn Client>> {
}

pub fn all_clients() -> Vec<&'static str> {
vec![OpenAIClient::name(), LocalAIClient::name()]
vec![
OpenAIClient::name(),
LocalAIClient::name(),
AzureOpenAIClient::name(),
]
}

pub fn create_client_config(client: &str) -> Result<String> {
if client == OpenAIClient::name() {
OpenAIClient::create_config()
} else if client == LocalAIClient::name() {
LocalAIClient::create_config()
} else if client == AzureOpenAIClient::name() {
AzureOpenAIClient::create_config()
} else {
bail!("Unknown client {}", &client)
}
Expand All @@ -160,6 +171,7 @@ pub fn list_models(config: &Config) -> Vec<ModelInfo> {
.flat_map(|(i, v)| match v {
ClientConfig::OpenAI(c) => OpenAIClient::list_models(c, i),
ClientConfig::LocalAI(c) => LocalAIClient::list_models(c, i),
ClientConfig::AzureOpenAI(c) => AzureOpenAIClient::list_models(c, i),
})
.collect()
}
Expand Down

0 comments on commit 8ac400b

Please sign in to comment.