feat: add function calling and token count
Gmin2 committed Dec 26, 2024
1 parent d706c8f commit 7953568
Showing 4 changed files with 284 additions and 21 deletions.
106 changes: 98 additions & 8 deletions src/providers/vertexai/models.rs
@@ -5,6 +5,7 @@ use crate::models::embeddings::{
Embeddings, EmbeddingsInput, EmbeddingsRequest, EmbeddingsResponse,
};
use crate::models::streaming::{ChatCompletionChunk, Choice, ChoiceDelta};
use crate::models::tool_calls::{ChatMessageToolCall, FunctionCall};
use crate::models::usage::Usage;
use serde::{Deserialize, Serialize};

@@ -14,6 +15,9 @@ pub(crate) struct VertexAIChatCompletionRequest {
pub contents: Vec<Content>,
#[serde(rename = "generation_config")]
pub generation_config: Option<GenerationConfig>,
#[serde(rename = "tools")]
#[serde(skip_serializing_if = "Vec::is_empty")]
pub tools: Vec<VertexAITool>,
}

#[derive(Deserialize, Serialize, Clone, Debug)]
@@ -42,7 +46,10 @@ pub(crate) struct Content {

#[derive(Deserialize, Serialize, Clone, Debug)]
pub(crate) struct Part {
pub text: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub text: Option<String>,
#[serde(rename = "functionCall", skip_serializing_if = "Option::is_none")]
pub function_call: Option<VertexFunctionCall>,
}

#[derive(Deserialize, Serialize, Clone, Debug)]
@@ -63,6 +70,8 @@ pub(crate) struct GenerateContentResponse {
pub safety_ratings: Option<Vec<SafetyRating>>,
#[serde(rename = "avgLogprobs")]
pub avg_logprobs: Option<f32>,
#[serde(rename = "functionCall")]
pub function_call: Option<FunctionCall>,
}

#[derive(Deserialize, Serialize, Clone, Debug)]
@@ -140,6 +149,40 @@ pub(crate) struct VertexAIEmbeddingStatistics {
pub token_count: u32,
}

#[derive(Deserialize, Serialize, Clone, Debug)]
pub(crate) struct Tool {
pub function_declarations: Vec<FunctionDeclaration>,
}

#[derive(Deserialize, Serialize, Clone, Debug)]
pub(crate) struct FunctionDeclaration {
pub name: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub parameters: Option<serde_json::Value>,
}

#[derive(Deserialize, Serialize, Clone, Debug)]
pub(crate) struct VertexAITool {
pub function_declarations: Vec<VertexAIFunctionDeclaration>,
}

#[derive(Deserialize, Serialize, Clone, Debug)]
pub(crate) struct VertexAIFunctionDeclaration {
pub name: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub parameters: Option<serde_json::Value>,
}

#[derive(Deserialize, Serialize, Clone, Debug)]
pub(crate) struct VertexFunctionCall {
pub name: String,
pub args: serde_json::Value,
}

impl From<crate::models::chat::ChatCompletionRequest> for VertexAIChatCompletionRequest {
fn from(request: crate::models::chat::ChatCompletionRequest) -> Self {
let contents = request
@@ -162,11 +205,32 @@ impl From<crate::models::chat::ChatCompletionRequest> for VertexAIChatCompletion
"assistant" => "model".to_string(),
_ => "user".to_string(),
},
parts: vec![Part { text }],
parts: vec![Part {
text: Some(text),
function_call: None,
}],
}
})
.collect();

let tools = if let Some(tools) = request.tools {
vec![VertexAITool {
function_declarations: tools
.into_iter()
.map(|tool| VertexAIFunctionDeclaration {
name: tool.function.name,
description: tool.function.description,
parameters: tool
.function
.parameters
.map(|p| serde_json::to_value(p).unwrap_or_default()),
})
.collect(),
}]
} else {
Vec::new()
};

VertexAIChatCompletionRequest {
contents,
generation_config: Some(GenerationConfig {
@@ -176,6 +240,7 @@ impl From<crate::models::chat::ChatCompletionRequest> for VertexAIChatCompletion
candidate_count: request.n.map(|n| n as i32),
max_output_tokens: request.max_tokens.or(Some(default_max_tokens())),
}),
tools,
}
}
}
@@ -187,10 +252,24 @@ impl From<VertexAIChatCompletionResponse> for ChatCompletion {
.into_iter()
.enumerate()
.map(|(index, candidate)| {
let content = if let Some(part) = candidate.content.parts.first() {
ChatMessageContent::String(part.text.clone())
let (content, tool_calls) = if let Some(part) = candidate.content.parts.first() {
match (&part.text, &part.function_call) {
(Some(text), None) => (ChatMessageContent::String(text.clone()), None),
(None, Some(func_call)) => (
ChatMessageContent::String(String::new()),
Some(vec![ChatMessageToolCall {
id: uuid::Uuid::new_v4().to_string(),
function: FunctionCall {
name: func_call.name.clone(),
arguments: func_call.args.to_string(),
},
r#type: "function".to_string(),
}]),
),
_ => (ChatMessageContent::String(String::new()), None),
}
} else {
ChatMessageContent::String(String::new())
(ChatMessageContent::String(String::new()), None)
};

ChatCompletionChoice {
@@ -199,21 +278,32 @@ impl From<VertexAIChatCompletionResponse> for ChatCompletion {
role: "assistant".to_string(),
content: Some(content),
name: None,
tool_calls: None,
tool_calls,
},
finish_reason: Some(candidate.finish_reason),
logprobs: None,
}
})
.collect();

let usage = response
.usage_metadata
.map(|metadata| Usage {
prompt_tokens: metadata.prompt_token_count as u32,
completion_tokens: metadata.candidates_token_count as u32,
total_tokens: metadata.total_token_count as u32,
completion_tokens_details: None,
prompt_tokens_details: None,
})
.unwrap_or_default();

ChatCompletion {
id: uuid::Uuid::new_v4().to_string(),
object: None,
created: None,
model: "gemini-pro".to_string(),
choices,
usage: crate::models::usage::Usage::default(),
usage,
system_fingerprint: None,
}
}
@@ -231,7 +321,7 @@ impl From<VertexAIStreamChunk> for ChatCompletionChunk {
delta: ChoiceDelta {
content: candidate
.content
.and_then(|c| c.parts.first().map(|p| p.text.clone())),
.and_then(|c| c.parts.first().and_then(|p| p.text.clone())),
role: Some("assistant".to_string()),
tool_calls: None,
},
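For reference, a minimal sketch (not part of the diff) of how the converted request serializes. The types and fields follow the crate-internal structs in models.rs above; the tool definition itself is illustrative:

```rust
use serde_json::json;

// Hypothetical request built from the structs declared above.
let request = VertexAIChatCompletionRequest {
    contents: vec![Content {
        role: "user".to_string(),
        parts: vec![Part {
            text: Some("What's the weather in London?".to_string()),
            function_call: None,
        }],
    }],
    generation_config: None,
    tools: vec![VertexAITool {
        function_declarations: vec![VertexAIFunctionDeclaration {
            name: "get_weather".to_string(),
            description: Some("Get the current weather in a location".to_string()),
            parameters: Some(json!({
                "type": "object",
                "properties": { "location": { "type": "string" } },
                "required": ["location"]
            })),
        }],
    }],
};

// `skip_serializing_if = "Vec::is_empty"` omits `tools` entirely when no
// tools are supplied; with one declared, the serialized body includes it.
let body = serde_json::to_string(&request).expect("request serializes");
assert!(body.contains("\"function_declarations\""));
assert!(body.contains("\"get_weather\""));
```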
31 changes: 23 additions & 8 deletions src/providers/vertexai/provider.rs
@@ -87,10 +87,19 @@ impl Provider for VertexAIProvider {
);
headers.insert("Content-Type", HeaderValue::from_static("application/json"));

let request_body = json!({
"contents": request.contents,
"generation_config": request.generation_config,
});
// Streaming and non-streaming requests share the same body, so a single
// construction covers both paths.
let request_body = json!({
"contents": request.contents,
"generation_config": request.generation_config,
"tools": request.tools,
});

let response = self
.http_client
@@ -125,9 +134,17 @@ impl Provider for VertexAIProvider {

Ok(ChatCompletionResponse::Stream(Box::pin(stream)))
} else {
let response_text = response.text().await.map_err(|e| {
eprintln!("Failed to get response text: {}", e);
StatusCode::INTERNAL_SERVER_ERROR
})?;

let vertex_response: VertexAIChatCompletionResponse =
response.json().await.map_err(|e| {
eprintln!("VertexAI API response error: {}", e);
serde_json::from_str(&response_text).map_err(|e| {
eprintln!(
"Failed to parse response: {}. Response was: {}",
e, response_text
);
StatusCode::INTERNAL_SERVER_ERROR
})?;

@@ -172,8 +189,6 @@ impl Provider for VertexAIProvider {
model = model
);

println!("Request {:?}", request);

let mut headers = HeaderMap::new();
headers.insert(
"Authorization",
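The non-streaming path now reads the body as text before deserializing, so a parse failure can log the raw payload. The same pattern in isolation, as a hedged sketch (`parse_with_context` is a hypothetical helper, not the provider's API; the error type is simplified to `String`):

```rust
use serde::de::DeserializeOwned;

// Hypothetical helper illustrating the text-then-parse pattern: keep the
// raw body around so deserialization errors are debuggable.
async fn parse_with_context<T: DeserializeOwned>(
    response: reqwest::Response,
) -> Result<T, String> {
    // Consume the response as text first...
    let text = response
        .text()
        .await
        .map_err(|e| format!("Failed to get response text: {}", e))?;
    // ...then parse, reporting the offending payload on failure.
    serde_json::from_str(&text)
        .map_err(|e| format!("Failed to parse response: {}. Response was: {}", e, text))
}
```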
94 changes: 89 additions & 5 deletions src/providers/vertexai/tests.rs
@@ -3,12 +3,14 @@ mod tests {
use crate::models::chat::{ChatCompletion, ChatCompletionRequest};
use crate::models::content::{ChatCompletionMessage, ChatMessageContent};
use crate::models::embeddings::{EmbeddingsInput, EmbeddingsRequest};
use crate::models::tool_definition::{FunctionDefinition, ToolDefinition};
use crate::providers::provider::Provider;
use crate::providers::vertexai::models::{
Content, GenerateContentResponse, Part, UsageMetadata, VertexAIChatCompletionRequest,
VertexAIChatCompletionResponse, VertexAIEmbeddingsRequest,
VertexAIChatCompletionResponse, VertexAIEmbeddingsRequest, VertexFunctionCall,
};
use crate::providers::vertexai::provider::VertexAIProvider;
use serde_json::json;
use std::collections::HashMap;

fn create_test_config() -> ProviderConfig {
@@ -33,7 +35,9 @@ mod tests {
model: "gemini-pro".to_string(),
messages: vec![ChatCompletionMessage {
role: "user".to_string(),
content: Some(ChatMessageContent::String("Test message".to_string())),
content: Some(ChatMessageContent::String(
"What's the weather in London?".to_string(),
)),
name: None,
tool_calls: None,
}],
@@ -90,9 +94,15 @@

assert_eq!(vertex_request.contents.len(), 2);
assert_eq!(vertex_request.contents[0].role, "user");
assert_eq!(vertex_request.contents[0].parts[0].text, "Hello");
assert_eq!(
vertex_request.contents[0].parts[0].text,
Some("Hello".to_string())
);
assert_eq!(vertex_request.contents[1].role, "model");
assert_eq!(vertex_request.contents[1].parts[0].text, "Hi there!");
assert_eq!(
vertex_request.contents[1].parts[0].text,
Some("Hi there!".to_string())
);

let gen_config = vertex_request.generation_config.unwrap();
assert_eq!(gen_config.temperature, Some(0.7));
@@ -107,12 +117,14 @@
content: Content {
role: "model".to_string(),
parts: vec![Part {
text: "Generated response".to_string(),
text: Some("Generated response".to_string()),
function_call: None,
}],
},
finish_reason: "stop".to_string(),
safety_ratings: None,
avg_logprobs: None,
function_call: None,
}],
usage_metadata: Some(UsageMetadata {
prompt_token_count: 10,
@@ -227,4 +239,76 @@ mod tests {
assert_eq!(vertex_request.instances[0].content, "test text");
assert!(vertex_request.parameters.unwrap().auto_truncate.unwrap());
}

#[test]
fn test_function_calling_request() {
let mut chat_request = create_test_chat_request();
chat_request.tools = Some(vec![ToolDefinition {
function: FunctionDefinition {
name: "get_weather".to_string(),
description: Some("Get the current weather in a location".to_string()),
parameters: Some(HashMap::from([
("type".to_string(), json!("object")),
(
"properties".to_string(),
json!({
"location": {
"type": "string",
"description": "The city name"
}
}),
),
("required".to_string(), json!(["location"])),
])),
strict: None,
},
tool_type: "function".to_string(),
}]);

let vertex_request: VertexAIChatCompletionRequest = chat_request.into();

assert!(!vertex_request.tools.is_empty());
assert_eq!(
vertex_request.tools[0].function_declarations[0].name,
"get_weather"
);
assert_eq!(
vertex_request.tools[0].function_declarations[0].description,
Some("Get the current weather in a location".to_string())
);
}

#[test]
fn test_function_calling_response() {
let vertex_response = VertexAIChatCompletionResponse {
candidates: vec![GenerateContentResponse {
content: Content {
role: "model".to_string(),
parts: vec![Part {
text: None,
function_call: Some(VertexFunctionCall {
name: "get_weather".to_string(),
args: json!({"location": "London"}),
}),
}],
},
finish_reason: "stop".to_string(),
safety_ratings: None,
avg_logprobs: None,
function_call: None,
}],
usage_metadata: None,
model_version: None,
};

let chat_completion: ChatCompletion = vertex_response.into();

let tool_calls = chat_completion.choices[0]
.message
.tool_calls
.as_ref()
.unwrap();
assert_eq!(tool_calls[0].function.name, "get_weather");
assert_eq!(tool_calls[0].function.arguments, r#"{"location":"London"}"#);
}
}
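A note on the final assertion: `serde_json::Value`'s `Display` (and hence `to_string()`) emits compact JSON, which is why the function arguments round-trip without whitespace. A tiny self-contained check:

```rust
use serde_json::json;

fn main() {
    // json! builds a serde_json::Value; to_string() uses the compact
    // serializer, so keys and values are joined without spaces.
    let args = json!({"location": "London"});
    assert_eq!(args.to_string(), r#"{"location":"London"}"#);
}
```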