From 894f571eb3ef200a6ba02831ddf033bfe1ed6f51 Mon Sep 17 00:00:00 2001 From: Valeryi Date: Wed, 11 Oct 2023 23:10:11 +0100 Subject: [PATCH 1/6] temp --- src/global_context.rs | 7 ++++--- src/http_server.rs | 2 +- src/lsp.rs | 2 +- src/main.rs | 2 +- src/scratchpads/chat_llama2.rs | 8 +++++--- src/scratchpads/mod.rs | 2 +- src/vecdb_search.rs | 2 +- 7 files changed, 14 insertions(+), 11 deletions(-) diff --git a/src/global_context.rs b/src/global_context.rs index c872da5..ebc05ad 100644 --- a/src/global_context.rs +++ b/src/global_context.rs @@ -10,7 +10,7 @@ use std::io::Write; use crate::caps::CodeAssistantCaps; use crate::completion_cache::CompletionCache; use crate::telemetry_storage; -use crate::vecdb_search::VecdbSearch; +use crate::vecdb_search::VecdbSearchTest; #[derive(Debug, StructOpt, Clone)] @@ -34,7 +34,7 @@ pub struct CommandLine { } -#[derive(Debug)] +// #[derive(Debug)] pub struct GlobalContext { pub http_client: reqwest::Client, pub ask_shutdown_sender: Arc>>, @@ -45,7 +45,7 @@ pub struct GlobalContext { pub cmdline: CommandLine, pub completions_cache: Arc>, pub telemetry: Arc>, - pub vecdb_search: Arc>, + pub vecdb_search: Arc>, } @@ -124,6 +124,7 @@ pub async fn create_global_context( cmdline: cmdline.clone(), completions_cache: Arc::new(StdRwLock::new(CompletionCache::new())), telemetry: Arc::new(StdRwLock::new(telemetry_storage::Storage::new())), + vecdb_search: Arc::new(Mutex::new(crate::vecdb_search::VecdbSearchTest::new())), }; (Arc::new(ARwLock::new(cx)), ask_shutdown_receiver, cmdline) } diff --git a/src/http_server.rs b/src/http_server.rs index 20bc77e..d15505f 100644 --- a/src/http_server.rs +++ b/src/http_server.rs @@ -193,7 +193,7 @@ async fn handle_v1_chat( ScratchError::new(StatusCode::INTERNAL_SERVER_ERROR,format!("Tokenizer: {}", e)) )?; - let vecdb_search = ; + let vecdb_search = global_context.read().await.vecdb_search.clone(); let mut scratchpad = scratchpads::create_chat_scratchpad( chat_post.clone(), &scratchpad_name, diff --git a/src/lsp.rs b/src/lsp.rs index 96cc0d1..f5acf11 100644 --- a/src/lsp.rs +++ b/src/lsp.rs @@ -61,7 +61,7 @@ impl Document { } } -#[derive(Debug)] +// #[derive(Debug)] GlobalContext does not implement Debug pub struct Backend { pub gcx: Arc>, pub client: tower_lsp::Client, diff --git a/src/main.rs b/src/main.rs index 28a5256..87b564f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -55,7 +55,7 @@ async fn main() { info!("started"); info!("cache dir: {}", cache_dir.display()); test_vecdb().await; - return; + // return; let gcx2 = gcx.clone(); let gcx3 = gcx.clone(); diff --git a/src/scratchpads/chat_llama2.rs b/src/scratchpads/chat_llama2.rs index 9aa95fe..2ccac87 100644 --- a/src/scratchpads/chat_llama2.rs +++ b/src/scratchpads/chat_llama2.rs @@ -24,14 +24,14 @@ pub struct ChatLlama2 { pub keyword_s: String, // "SYSTEM:" keyword means it's not one token pub keyword_slash_s: String, pub default_system_message: String, - pub vecdb_search: Arc>>, + pub vecdb_search: Arc>, } impl ChatLlama2 { pub fn new( tokenizer: Arc>, post: ChatPost, - vecdb_search: Arc>>, + vecdb_search: Arc>, ) -> Self { ChatLlama2 { t: HasTokenizerAndEot::new(tokenizer), @@ -40,7 +40,7 @@ impl ChatLlama2 { keyword_s: "".to_string(), keyword_slash_s: "".to_string(), default_system_message: "".to_string(), - vecdb_search: vecdb_search + vecdb_search } } } @@ -98,6 +98,8 @@ impl ScratchpadAbstract for ChatLlama2 { prompt.push_str("[INST]"); } } + let vdb_suggestion = self.vecdb_search.search(prompt.as_str()); + // This only supports assistant, not suggestions for 
user self.dd.role = "assistant".to_string(); if DEBUG { diff --git a/src/scratchpads/mod.rs b/src/scratchpads/mod.rs index f9076ab..b324d40 100644 --- a/src/scratchpads/mod.rs +++ b/src/scratchpads/mod.rs @@ -46,7 +46,7 @@ pub fn create_chat_scratchpad( scratchpad_name: &str, scratchpad_patch: &serde_json::Value, tokenizer_arc: Arc>, - vecdb_search: Arc>>, + vecdb_search: Arc>, ) -> Result, String> { let mut result: Box; if scratchpad_name == "CHAT-GENERIC" { diff --git a/src/vecdb_search.rs b/src/vecdb_search.rs index 8c2b0c8..3424f12 100644 --- a/src/vecdb_search.rs +++ b/src/vecdb_search.rs @@ -51,7 +51,7 @@ impl VecdbSearch for VecdbSearchTest { headers.insert(CONTENT_TYPE, HeaderValue::from_str("application/json").unwrap()); let body = json!({ "texts": [query], - "account": "smc", + "account": "XXX", "top_k": 3, }); let res = reqwest::Client::new() From ba3757aa4c513112dbe21786ade9f7d8b6a6c75c Mon Sep 17 00:00:00 2001 From: Valeryi Date: Thu, 12 Oct 2023 20:08:47 +0100 Subject: [PATCH 2/6] temp; does not work --- src/http_server.rs | 4 +- src/main.rs | 2 +- src/scratchpads/chat_llama2.rs | 10 +++-- src/scratchpads/mod.rs | 2 +- src/vecdb_search.rs | 82 +++++++++++++++++++++++++++++----- 5 files changed, 83 insertions(+), 17 deletions(-) diff --git a/src/http_server.rs b/src/http_server.rs index d15505f..002003b 100644 --- a/src/http_server.rs +++ b/src/http_server.rs @@ -21,6 +21,7 @@ use crate::custom_error::ScratchError; use crate::telemetry_basic; use crate::telemetry_snippets; use crate::completion_cache; +use crate::vecdb_search::VecdbSearch; async fn _get_caps_and_tokenizer( @@ -193,7 +194,8 @@ async fn handle_v1_chat( ScratchError::new(StatusCode::INTERNAL_SERVER_ERROR,format!("Tokenizer: {}", e)) )?; - let vecdb_search = global_context.read().await.vecdb_search.clone(); + let mut vecdb_search = global_context.read().await.vecdb_search.lock().unwrap().clone(); + // let res = vecdb_search.search("").await; let mut scratchpad = scratchpads::create_chat_scratchpad( chat_post.clone(), &scratchpad_name, diff --git a/src/main.rs b/src/main.rs index 87b564f..12c4b43 100644 --- a/src/main.rs +++ b/src/main.rs @@ -26,7 +26,7 @@ use crate::vecdb_search::VecdbSearch; async fn test_vecdb() { let mut v = vecdb_search::VecdbSearchTest::new(); - let res = v.search("ParallelTasksV3").await; + let res = v.sync_search("ParallelTasksV3"); info!("{:?}", res); } diff --git a/src/scratchpads/chat_llama2.rs b/src/scratchpads/chat_llama2.rs index 2ccac87..6707b52 100644 --- a/src/scratchpads/chat_llama2.rs +++ b/src/scratchpads/chat_llama2.rs @@ -6,9 +6,12 @@ use crate::call_validation::ChatMessage; use crate::call_validation::SamplingParameters; use crate::scratchpads::chat_utils_limit_history::limit_messages_history; use crate::vecdb_search; +use crate::vecdb_search::VecdbSearch; use std::sync::Arc; use std::sync::RwLock as StdRwLock; use std::sync::Mutex; +use async_trait::async_trait; + use tokenizers::Tokenizer; use tracing::info; @@ -24,14 +27,14 @@ pub struct ChatLlama2 { pub keyword_s: String, // "SYSTEM:" keyword means it's not one token pub keyword_slash_s: String, pub default_system_message: String, - pub vecdb_search: Arc>, + pub vecdb_search: vecdb_search::VecdbSearchTest, } impl ChatLlama2 { pub fn new( tokenizer: Arc>, post: ChatPost, - vecdb_search: Arc>, + vecdb_search: vecdb_search::VecdbSearchTest ) -> Self { ChatLlama2 { t: HasTokenizerAndEot::new(tokenizer), @@ -98,11 +101,12 @@ impl ScratchpadAbstract for ChatLlama2 { prompt.push_str("[INST]"); } } - let vdb_suggestion = 
self.vecdb_search.search(prompt.as_str()); + let vdb_suggestion = self.vecdb_search.sync_search("abc"); // This only supports assistant, not suggestions for user self.dd.role = "assistant".to_string(); if DEBUG { + info!("llama2 chat vdb_suggestion {:?}", vdb_suggestion); info!("llama2 chat prompt\n{}", prompt); info!("llama2 chat re-encode whole prompt again gives {} tokes", self.t.count_tokens(prompt.as_str())?); } diff --git a/src/scratchpads/mod.rs b/src/scratchpads/mod.rs index b324d40..350be32 100644 --- a/src/scratchpads/mod.rs +++ b/src/scratchpads/mod.rs @@ -46,7 +46,7 @@ pub fn create_chat_scratchpad( scratchpad_name: &str, scratchpad_patch: &serde_json::Value, tokenizer_arc: Arc>, - vecdb_search: Arc>, + vecdb_search: vecdb_search::VecdbSearchTest, ) -> Result, String> { let mut result: Box; if scratchpad_name == "CHAT-GENERIC" { diff --git a/src/vecdb_search.rs b/src/vecdb_search.rs index 3424f12..61d28bf 100644 --- a/src/vecdb_search.rs +++ b/src/vecdb_search.rs @@ -26,6 +26,8 @@ pub trait VecdbSearch: Send { &mut self, query: &str, ) -> Result; + + fn sync_search(&mut self, query: &str) -> Result; } #[derive(Debug, Clone)] @@ -39,6 +41,8 @@ impl VecdbSearchTest { } } +// unsafe impl Send for VecdbSearchTest {} + #[async_trait] impl VecdbSearch for VecdbSearchTest { async fn search( @@ -60,17 +64,6 @@ impl VecdbSearch for VecdbSearchTest { .body(body.to_string()) .send() .await.map_err(|e| format!("Vecdb search HTTP error (1): {}", e))?; - // print Allow header - // println!("{:?}", res.headers().get("allow")); - - // let x = VecdbResult { - // results: vec![VecdbResultRec { - // file_name: "test.txt".to_string(), - // text: "test".to_string(), - // score: "0.0".to_string(), - // }], - // }; - // info!("example: {:?}", serde_json::to_string(&x).unwrap()); let body = res.text().await.map_err(|e| format!("Vecdb search HTTP error (2): {}", e))?; // info!("Vecdb search result: {:?}", &body); @@ -84,5 +77,72 @@ impl VecdbSearch for VecdbSearchTest { // info!("Vecdb search result: {:?}", &result0); Ok(result0) } + + fn sync_search(&mut self, query: &str) -> Result { + let url = "http://127.0.0.1:8008/v1/vdb-search".to_string(); + let mut headers = HeaderMap::new(); + // headers.insert(AUTHORIZATION, HeaderValue::from_str(&format!("Bearer {}", self.token)).unwrap()); + headers.insert(CONTENT_TYPE, HeaderValue::from_str("application/json").unwrap()); + let body = json!({ + "texts": [query], + "account": "XXX", + "top_k": 3, + }); + + let res = reqwest::blocking::Client::new() + .post(&url) + .headers(headers) + .body(body.to_string()) + .send() + .map_err(|e| format!("Vecdb search HTTP error (1): {}", e)); + + let body = res?.text().map_err(|e| format!("Vecdb search HTTP error (2): {}", e))?; + let result: Vec = serde_json::from_str(&body).map_err(|e| { + format!("vecdb JSON problem: {}", e) + })?; + if result.is_empty() { + return Err("Vecdb search result is empty".to_string()); + } + let result0 = result[0].clone(); + Ok(result0) + } + } + +// trait SyncVecdbSearch { +// fn sync_search(&self, query: &str) -> Result; +// } +// +// impl SyncVecdbSearch for VecdbSearchTest { +// fn sync_search(&self, query: &str) -> Result { +// let url = "http://127.0.0.1:8008/v1/vdb-search".to_string(); +// let mut headers = HeaderMap::new(); +// // headers.insert(AUTHORIZATION, HeaderValue::from_str(&format!("Bearer {}", self.token)).unwrap()); +// headers.insert(CONTENT_TYPE, HeaderValue::from_str("application/json").unwrap()); +// let body = json!({ +// "texts": [query], +// "account": "XXX", +// 
"top_k": 3, +// }); +// +// let client = reqwest::blocking::Client::new(); +// let res = client +// .post(&url) +// .headers(headers) +// .body(body.to_string()) +// .send() +// .map_err(|e| format!("Vecdb search HTTP error (1): {}", e))?; +// +// let body = res.text().map_err(|e| format!("Vecdb search HTTP error (2): {}", e))?; +// // info!("Vecdb search result: {:?}", &body); +// let result: Vec = serde_json::from_str(&body).map_err(|e| { +// format!("vecdb JSON problem: {}", e) +// })?; +// if result.is_empty() { +// return Err("Vecdb search result is empty".to_string()); +// } +// let result0 = result[0].clone(); +// Ok(result0) +// } +// } From d6bfc8fe178994affc6666d71add21f0ebf6bb19 Mon Sep 17 00:00:00 2001 From: Oleg Klimov Date: Fri, 13 Oct 2023 13:16:05 +0200 Subject: [PATCH 3/6] WIP --- src/http_server.rs | 3 +- src/main.rs | 3 +- src/scratchpads/chat_llama2.rs | 13 +++---- src/scratchpads/mod.rs | 2 +- src/vecdb_search.rs | 70 ---------------------------------- 5 files changed, 9 insertions(+), 82 deletions(-) diff --git a/src/http_server.rs b/src/http_server.rs index 002003b..c40c39e 100644 --- a/src/http_server.rs +++ b/src/http_server.rs @@ -194,8 +194,7 @@ async fn handle_v1_chat( ScratchError::new(StatusCode::INTERNAL_SERVER_ERROR,format!("Tokenizer: {}", e)) )?; - let mut vecdb_search = global_context.read().await.vecdb_search.lock().unwrap().clone(); - // let res = vecdb_search.search("").await; + let mut vecdb_search = global_context.read().await.vecdb_search.clone(); let mut scratchpad = scratchpads::create_chat_scratchpad( chat_post.clone(), &scratchpad_name, diff --git a/src/main.rs b/src/main.rs index 12c4b43..b636620 100644 --- a/src/main.rs +++ b/src/main.rs @@ -26,7 +26,7 @@ use crate::vecdb_search::VecdbSearch; async fn test_vecdb() { let mut v = vecdb_search::VecdbSearchTest::new(); - let res = v.sync_search("ParallelTasksV3"); + let res = v.search("ParallelTasksV3").await; info!("{:?}", res); } @@ -55,7 +55,6 @@ async fn main() { info!("started"); info!("cache dir: {}", cache_dir.display()); test_vecdb().await; - // return; let gcx2 = gcx.clone(); let gcx3 = gcx.clone(); diff --git a/src/scratchpads/chat_llama2.rs b/src/scratchpads/chat_llama2.rs index 6707b52..64f2faa 100644 --- a/src/scratchpads/chat_llama2.rs +++ b/src/scratchpads/chat_llama2.rs @@ -5,12 +5,11 @@ use crate::call_validation::ChatPost; use crate::call_validation::ChatMessage; use crate::call_validation::SamplingParameters; use crate::scratchpads::chat_utils_limit_history::limit_messages_history; -use crate::vecdb_search; -use crate::vecdb_search::VecdbSearch; use std::sync::Arc; use std::sync::RwLock as StdRwLock; use std::sync::Mutex; -use async_trait::async_trait; +use crate::vecdb_search; +use crate::vecdb_search::VecdbSearch; use tokenizers::Tokenizer; @@ -27,14 +26,14 @@ pub struct ChatLlama2 { pub keyword_s: String, // "SYSTEM:" keyword means it's not one token pub keyword_slash_s: String, pub default_system_message: String, - pub vecdb_search: vecdb_search::VecdbSearchTest, + pub vecdb_search: Arc>, } impl ChatLlama2 { pub fn new( tokenizer: Arc>, post: ChatPost, - vecdb_search: vecdb_search::VecdbSearchTest + vecdb_search: Arc>, ) -> Self { ChatLlama2 { t: HasTokenizerAndEot::new(tokenizer), @@ -101,12 +100,12 @@ impl ScratchpadAbstract for ChatLlama2 { prompt.push_str("[INST]"); } } - let vdb_suggestion = self.vecdb_search.sync_search("abc"); + // let vdb_suggestion = self.vecdb_search.sync_search("abc"); // This only supports assistant, not suggestions for user self.dd.role = 
"assistant".to_string(); if DEBUG { - info!("llama2 chat vdb_suggestion {:?}", vdb_suggestion); + // info!("llama2 chat vdb_suggestion {:?}", vdb_suggestion); info!("llama2 chat prompt\n{}", prompt); info!("llama2 chat re-encode whole prompt again gives {} tokes", self.t.count_tokens(prompt.as_str())?); } diff --git a/src/scratchpads/mod.rs b/src/scratchpads/mod.rs index 350be32..b324d40 100644 --- a/src/scratchpads/mod.rs +++ b/src/scratchpads/mod.rs @@ -46,7 +46,7 @@ pub fn create_chat_scratchpad( scratchpad_name: &str, scratchpad_patch: &serde_json::Value, tokenizer_arc: Arc>, - vecdb_search: vecdb_search::VecdbSearchTest, + vecdb_search: Arc>, ) -> Result, String> { let mut result: Box; if scratchpad_name == "CHAT-GENERIC" { diff --git a/src/vecdb_search.rs b/src/vecdb_search.rs index 61d28bf..2d9bfa9 100644 --- a/src/vecdb_search.rs +++ b/src/vecdb_search.rs @@ -26,8 +26,6 @@ pub trait VecdbSearch: Send { &mut self, query: &str, ) -> Result; - - fn sync_search(&mut self, query: &str) -> Result; } #[derive(Debug, Clone)] @@ -77,72 +75,4 @@ impl VecdbSearch for VecdbSearchTest { // info!("Vecdb search result: {:?}", &result0); Ok(result0) } - - fn sync_search(&mut self, query: &str) -> Result { - let url = "http://127.0.0.1:8008/v1/vdb-search".to_string(); - let mut headers = HeaderMap::new(); - // headers.insert(AUTHORIZATION, HeaderValue::from_str(&format!("Bearer {}", self.token)).unwrap()); - headers.insert(CONTENT_TYPE, HeaderValue::from_str("application/json").unwrap()); - let body = json!({ - "texts": [query], - "account": "XXX", - "top_k": 3, - }); - - let res = reqwest::blocking::Client::new() - .post(&url) - .headers(headers) - .body(body.to_string()) - .send() - .map_err(|e| format!("Vecdb search HTTP error (1): {}", e)); - - let body = res?.text().map_err(|e| format!("Vecdb search HTTP error (2): {}", e))?; - let result: Vec = serde_json::from_str(&body).map_err(|e| { - format!("vecdb JSON problem: {}", e) - })?; - if result.is_empty() { - return Err("Vecdb search result is empty".to_string()); - } - let result0 = result[0].clone(); - Ok(result0) - } - } - - -// trait SyncVecdbSearch { -// fn sync_search(&self, query: &str) -> Result; -// } -// -// impl SyncVecdbSearch for VecdbSearchTest { -// fn sync_search(&self, query: &str) -> Result { -// let url = "http://127.0.0.1:8008/v1/vdb-search".to_string(); -// let mut headers = HeaderMap::new(); -// // headers.insert(AUTHORIZATION, HeaderValue::from_str(&format!("Bearer {}", self.token)).unwrap()); -// headers.insert(CONTENT_TYPE, HeaderValue::from_str("application/json").unwrap()); -// let body = json!({ -// "texts": [query], -// "account": "XXX", -// "top_k": 3, -// }); -// -// let client = reqwest::blocking::Client::new(); -// let res = client -// .post(&url) -// .headers(headers) -// .body(body.to_string()) -// .send() -// .map_err(|e| format!("Vecdb search HTTP error (1): {}", e))?; -// -// let body = res.text().map_err(|e| format!("Vecdb search HTTP error (2): {}", e))?; -// // info!("Vecdb search result: {:?}", &body); -// let result: Vec = serde_json::from_str(&body).map_err(|e| { -// format!("vecdb JSON problem: {}", e) -// })?; -// if result.is_empty() { -// return Err("Vecdb search result is empty".to_string()); -// } -// let result0 = result[0].clone(); -// Ok(result0) -// } -// } From 91d9961b05ddffe6b34e7da318f6163937e9d057 Mon Sep 17 00:00:00 2001 From: Oleg Klimov Date: Fri, 13 Oct 2023 13:21:13 +0200 Subject: [PATCH 4/6] async prompt --- src/http_server.rs | 8 ++++---- src/scratchpad_abstract.rs | 4 +++- 
src/scratchpads/chat_generic.rs | 4 +++- src/scratchpads/chat_llama2.rs | 7 +++++-- src/scratchpads/completion_single_file_fim.rs | 5 ++++- 5 files changed, 19 insertions(+), 9 deletions(-) diff --git a/src/http_server.rs b/src/http_server.rs index c40c39e..2f950dd 100644 --- a/src/http_server.rs +++ b/src/http_server.rs @@ -21,7 +21,7 @@ use crate::custom_error::ScratchError; use crate::telemetry_basic; use crate::telemetry_snippets; use crate::completion_cache; -use crate::vecdb_search::VecdbSearch; +// use crate::vecdb_search::VecdbSearch; async fn _get_caps_and_tokenizer( @@ -156,7 +156,7 @@ async fn handle_v1_code_completion( let prompt = scratchpad.prompt( 2048, &mut code_completion_post.parameters, - ).map_err(|e| + ).await.map_err(|e| ScratchError::new(StatusCode::INTERNAL_SERVER_ERROR, format!("Prompt: {}", e)) )?; // info!("prompt {:?}\n{}", t1.elapsed(), prompt); @@ -194,7 +194,7 @@ async fn handle_v1_chat( ScratchError::new(StatusCode::INTERNAL_SERVER_ERROR,format!("Tokenizer: {}", e)) )?; - let mut vecdb_search = global_context.read().await.vecdb_search.clone(); + let vecdb_search = global_context.read().await.vecdb_search.clone(); let mut scratchpad = scratchpads::create_chat_scratchpad( chat_post.clone(), &scratchpad_name, @@ -208,7 +208,7 @@ async fn handle_v1_chat( let prompt = scratchpad.prompt( 2048, &mut chat_post.parameters, - ).map_err(|e| + ).await.map_err(|e| ScratchError::new(StatusCode::INTERNAL_SERVER_ERROR, format!("Prompt: {}", e)) )?; // info!("chat prompt {:?}\n{}", t1.elapsed(), prompt); diff --git a/src/scratchpad_abstract.rs b/src/scratchpad_abstract.rs index 227cfc8..563ee60 100644 --- a/src/scratchpad_abstract.rs +++ b/src/scratchpad_abstract.rs @@ -3,15 +3,17 @@ use std::sync::Arc; use std::sync::RwLock; use tokenizers::Tokenizer; use crate::call_validation::SamplingParameters; +use async_trait::async_trait; +#[async_trait] pub trait ScratchpadAbstract: Send { fn apply_model_adaptation_patch( &mut self, patch: &serde_json::Value, ) -> Result<(), String>; - fn prompt( + async fn prompt( &mut self, context_size: usize, sampling_parameters_to_patch: &mut SamplingParameters, diff --git a/src/scratchpads/chat_generic.rs b/src/scratchpads/chat_generic.rs index 1bb39c7..ffe3f17 100644 --- a/src/scratchpads/chat_generic.rs +++ b/src/scratchpads/chat_generic.rs @@ -7,6 +7,7 @@ use crate::call_validation::SamplingParameters; use crate::scratchpads::chat_utils_limit_history::limit_messages_history; use std::sync::Arc; use std::sync::RwLock; +use async_trait::async_trait; use tokenizers::Tokenizer; use tracing::info; @@ -44,6 +45,7 @@ impl GenericChatScratchpad { } } +#[async_trait] impl ScratchpadAbstract for GenericChatScratchpad { fn apply_model_adaptation_patch( &mut self, @@ -68,7 +70,7 @@ impl ScratchpadAbstract for GenericChatScratchpad { Ok(()) } - fn prompt( + async fn prompt( &mut self, context_size: usize, sampling_parameters_to_patch: &mut SamplingParameters, diff --git a/src/scratchpads/chat_llama2.rs b/src/scratchpads/chat_llama2.rs index 64f2faa..ed89b7e 100644 --- a/src/scratchpads/chat_llama2.rs +++ b/src/scratchpads/chat_llama2.rs @@ -8,7 +8,9 @@ use crate::scratchpads::chat_utils_limit_history::limit_messages_history; use std::sync::Arc; use std::sync::RwLock as StdRwLock; use std::sync::Mutex; -use crate::vecdb_search; +use async_trait::async_trait; + +// use crate::vecdb_search; use crate::vecdb_search::VecdbSearch; @@ -47,6 +49,7 @@ impl ChatLlama2 { } } +#[async_trait] impl ScratchpadAbstract for ChatLlama2 { fn apply_model_adaptation_patch( 
&mut self, @@ -64,7 +67,7 @@ impl ScratchpadAbstract for ChatLlama2 { Ok(()) } - fn prompt( + async fn prompt( &mut self, context_size: usize, sampling_parameters_to_patch: &mut SamplingParameters, diff --git a/src/scratchpads/completion_single_file_fim.rs b/src/scratchpads/completion_single_file_fim.rs index 5536c07..0e63585 100644 --- a/src/scratchpads/completion_single_file_fim.rs +++ b/src/scratchpads/completion_single_file_fim.rs @@ -8,6 +8,8 @@ use std::sync::RwLock as StdRwLock; use tokenizers::Tokenizer; use ropey::Rope; use tracing::info; +use async_trait::async_trait; + use crate::completion_cache; use crate::telemetry_storage; use crate::telemetry_snippets; @@ -42,6 +44,7 @@ impl SingleFileFIM { } +#[async_trait] impl ScratchpadAbstract for SingleFileFIM { fn apply_model_adaptation_patch( &mut self, @@ -59,7 +62,7 @@ impl ScratchpadAbstract for SingleFileFIM { Ok(()) } - fn prompt( + async fn prompt( &mut self, context_size: usize, sampling_parameters_to_patch: &mut SamplingParameters, From 8006cf1e7571e2c83dd381e4daaeae212e458054 Mon Sep 17 00:00:00 2001 From: Oleg Klimov Date: Fri, 13 Oct 2023 15:00:39 +0200 Subject: [PATCH 5/6] llama2 async prompt() + vdb --- src/global_context.rs | 7 ++++--- src/main.rs | 10 ---------- src/scratchpads/chat_llama2.rs | 29 ++++++++++++++++++----------- src/scratchpads/mod.rs | 5 +++-- src/vecdb_search.rs | 2 +- 5 files changed, 26 insertions(+), 27 deletions(-) diff --git a/src/global_context.rs b/src/global_context.rs index ebc05ad..abb8539 100644 --- a/src/global_context.rs +++ b/src/global_context.rs @@ -3,6 +3,7 @@ use std::collections::HashMap; use std::path::PathBuf; use std::sync::{Arc, Mutex}; use std::sync::RwLock as StdRwLock; +use tokio::sync::Mutex as AMutex; use tokio::sync::RwLock as ARwLock; use tokenizers::Tokenizer; use structopt::StructOpt; @@ -10,7 +11,7 @@ use std::io::Write; use crate::caps::CodeAssistantCaps; use crate::completion_cache::CompletionCache; use crate::telemetry_storage; -use crate::vecdb_search::VecdbSearchTest; +use crate::vecdb_search::VecdbSearch; #[derive(Debug, StructOpt, Clone)] @@ -45,7 +46,7 @@ pub struct GlobalContext { pub cmdline: CommandLine, pub completions_cache: Arc>, pub telemetry: Arc>, - pub vecdb_search: Arc>, + pub vecdb_search: Arc>>, } @@ -124,7 +125,7 @@ pub async fn create_global_context( cmdline: cmdline.clone(), completions_cache: Arc::new(StdRwLock::new(CompletionCache::new())), telemetry: Arc::new(StdRwLock::new(telemetry_storage::Storage::new())), - vecdb_search: Arc::new(Mutex::new(crate::vecdb_search::VecdbSearchTest::new())), + vecdb_search: Arc::new(AMutex::new(Box::new(crate::vecdb_search::VecdbSearchTest::new()))), }; (Arc::new(ARwLock::new(cx)), ask_shutdown_receiver, cmdline) } diff --git a/src/main.rs b/src/main.rs index b636620..fd261cc 100644 --- a/src/main.rs +++ b/src/main.rs @@ -20,15 +20,6 @@ mod telemetry_snippets; mod telemetry_storage; mod vecdb_search; mod lsp; -use crate::vecdb_search::VecdbSearch; - - -async fn test_vecdb() -{ - let mut v = vecdb_search::VecdbSearchTest::new(); - let res = v.search("ParallelTasksV3").await; - info!("{:?}", res); -} #[tokio::main] @@ -54,7 +45,6 @@ async fn main() { .init(); info!("started"); info!("cache dir: {}", cache_dir.display()); - test_vecdb().await; let gcx2 = gcx.clone(); let gcx3 = gcx.clone(); diff --git a/src/scratchpads/chat_llama2.rs b/src/scratchpads/chat_llama2.rs index ed89b7e..fa19a89 100644 --- a/src/scratchpads/chat_llama2.rs +++ b/src/scratchpads/chat_llama2.rs @@ -1,3 +1,10 @@ +use tracing::info; +use 
std::sync::Arc; +use std::sync::RwLock as StdRwLock; +use tokio::sync::Mutex as AMutex; +use tokenizers::Tokenizer; +use async_trait::async_trait; + use crate::scratchpad_abstract::ScratchpadAbstract; use crate::scratchpad_abstract::HasTokenizerAndEot; use crate::scratchpads::chat_utils_deltadelta::DeltaDeltaChatStreamer; @@ -5,18 +12,9 @@ use crate::call_validation::ChatPost; use crate::call_validation::ChatMessage; use crate::call_validation::SamplingParameters; use crate::scratchpads::chat_utils_limit_history::limit_messages_history; -use std::sync::Arc; -use std::sync::RwLock as StdRwLock; -use std::sync::Mutex; -use async_trait::async_trait; - -// use crate::vecdb_search; use crate::vecdb_search::VecdbSearch; -use tokenizers::Tokenizer; -use tracing::info; - const DEBUG: bool = true; @@ -28,14 +26,15 @@ pub struct ChatLlama2 { pub keyword_s: String, // "SYSTEM:" keyword means it's not one token pub keyword_slash_s: String, pub default_system_message: String, - pub vecdb_search: Arc>, + pub vecdb_search: Arc>>, } + impl ChatLlama2 { pub fn new( tokenizer: Arc>, post: ChatPost, - vecdb_search: Arc>, + vecdb_search: Arc>>, ) -> Self { ChatLlama2 { t: HasTokenizerAndEot::new(tokenizer), @@ -72,6 +71,14 @@ impl ScratchpadAbstract for ChatLlama2 { context_size: usize, sampling_parameters_to_patch: &mut SamplingParameters, ) -> Result { + let my_vdb = self.vecdb_search.clone(); + let vdb_result; + { + let mut vecdb_locked = my_vdb.lock().await; + vdb_result = vecdb_locked.search("ParallelTasksV3").await; + } + info!("llama2 vdb_result {:?}", vdb_result); + let limited_msgs: Vec = limit_messages_history(&self.t, &self.post, context_size, &self.default_system_message)?; sampling_parameters_to_patch.stop = Some(self.dd.stop_list.clone()); // loosely adapted from https://huggingface.co/spaces/huggingface-projects/llama-2-13b-chat/blob/main/model.py#L24 diff --git a/src/scratchpads/mod.rs b/src/scratchpads/mod.rs index b324d40..e6b008f 100644 --- a/src/scratchpads/mod.rs +++ b/src/scratchpads/mod.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use std::sync::RwLock as StdRwLock; -use std::sync::Mutex; +use tokio::sync::Mutex as AMutex; + use tokenizers::Tokenizer; pub mod completion_single_file_fim; @@ -46,7 +47,7 @@ pub fn create_chat_scratchpad( scratchpad_name: &str, scratchpad_patch: &serde_json::Value, tokenizer_arc: Arc>, - vecdb_search: Arc>, + vecdb_search: Arc>>, ) -> Result, String> { let mut result: Box; if scratchpad_name == "CHAT-GENERIC" { diff --git a/src/vecdb_search.rs b/src/vecdb_search.rs index 2d9bfa9..71acdb2 100644 --- a/src/vecdb_search.rs +++ b/src/vecdb_search.rs @@ -4,7 +4,7 @@ use reqwest::header::HeaderMap; use reqwest::header::HeaderValue; use serde::{Deserialize, Serialize}; use serde_json::json; -use tracing::info; +// use tracing::info; use async_trait::async_trait; From fd9fd05f8a3310c6589d8c77cc0af218e60d99a3 Mon Sep 17 00:00:00 2001 From: Valeryi Date: Sun, 15 Oct 2023 20:36:51 +0100 Subject: [PATCH 6/6] any chat now works with vecdb --- src/scratchpads/chat_generic.rs | 8 ++++- src/scratchpads/chat_llama2.rs | 13 ++------ src/scratchpads/mod.rs | 2 +- src/vecdb_search.rs | 56 ++++++++++++++++++++++++++++++++- 4 files changed, 65 insertions(+), 14 deletions(-) diff --git a/src/scratchpads/chat_generic.rs b/src/scratchpads/chat_generic.rs index ffe3f17..335b82e 100644 --- a/src/scratchpads/chat_generic.rs +++ b/src/scratchpads/chat_generic.rs @@ -5,9 +5,12 @@ use crate::call_validation::ChatPost; use crate::call_validation::ChatMessage; use 
crate::call_validation::SamplingParameters; use crate::scratchpads::chat_utils_limit_history::limit_messages_history; +use crate::vecdb_search::{VecdbSearch, embed_vecdb_results}; + use std::sync::Arc; use std::sync::RwLock; use async_trait::async_trait; +use tokio::sync::Mutex as AMutex; use tokenizers::Tokenizer; use tracing::info; @@ -15,7 +18,6 @@ use tracing::info; const DEBUG: bool = true; -#[derive(Debug)] pub struct GenericChatScratchpad { pub t: HasTokenizerAndEot, pub dd: DeltaDeltaChatStreamer, @@ -25,12 +27,14 @@ pub struct GenericChatScratchpad { pub keyword_user: String, pub keyword_asst: String, pub default_system_message: String, + pub vecdb_search: Arc>>, } impl GenericChatScratchpad { pub fn new( tokenizer: Arc>, post: ChatPost, + vecdb_search: Arc>>, ) -> Self { GenericChatScratchpad { t: HasTokenizerAndEot::new(tokenizer), @@ -41,6 +45,7 @@ impl GenericChatScratchpad { keyword_user: "".to_string(), keyword_asst: "".to_string(), default_system_message: "".to_string(), + vecdb_search } } } @@ -75,6 +80,7 @@ impl ScratchpadAbstract for GenericChatScratchpad { context_size: usize, sampling_parameters_to_patch: &mut SamplingParameters, ) -> Result { + embed_vecdb_results(self.vecdb_search.clone(), &mut self.post, 3).await; let limited_msgs: Vec = limit_messages_history(&self.t, &self.post, context_size, &self.default_system_message)?; sampling_parameters_to_patch.stop = Some(self.dd.stop_list.clone()); // adapted from https://huggingface.co/spaces/huggingface-projects/llama-2-13b-chat/blob/main/model.py#L24 diff --git a/src/scratchpads/chat_llama2.rs b/src/scratchpads/chat_llama2.rs index fa19a89..dfaaeb7 100644 --- a/src/scratchpads/chat_llama2.rs +++ b/src/scratchpads/chat_llama2.rs @@ -12,7 +12,7 @@ use crate::call_validation::ChatPost; use crate::call_validation::ChatMessage; use crate::call_validation::SamplingParameters; use crate::scratchpads::chat_utils_limit_history::limit_messages_history; -use crate::vecdb_search::VecdbSearch; +use crate::vecdb_search::{VecdbSearch, embed_vecdb_results}; const DEBUG: bool = true; @@ -71,14 +71,7 @@ impl ScratchpadAbstract for ChatLlama2 { context_size: usize, sampling_parameters_to_patch: &mut SamplingParameters, ) -> Result { - let my_vdb = self.vecdb_search.clone(); - let vdb_result; - { - let mut vecdb_locked = my_vdb.lock().await; - vdb_result = vecdb_locked.search("ParallelTasksV3").await; - } - info!("llama2 vdb_result {:?}", vdb_result); - + embed_vecdb_results(self.vecdb_search.clone(), &mut self.post, 3).await; let limited_msgs: Vec = limit_messages_history(&self.t, &self.post, context_size, &self.default_system_message)?; sampling_parameters_to_patch.stop = Some(self.dd.stop_list.clone()); // loosely adapted from https://huggingface.co/spaces/huggingface-projects/llama-2-13b-chat/blob/main/model.py#L24 @@ -110,8 +103,6 @@ impl ScratchpadAbstract for ChatLlama2 { prompt.push_str("[INST]"); } } - // let vdb_suggestion = self.vecdb_search.sync_search("abc"); - // This only supports assistant, not suggestions for user self.dd.role = "assistant".to_string(); if DEBUG { diff --git a/src/scratchpads/mod.rs b/src/scratchpads/mod.rs index e6b008f..ad7e670 100644 --- a/src/scratchpads/mod.rs +++ b/src/scratchpads/mod.rs @@ -51,7 +51,7 @@ pub fn create_chat_scratchpad( ) -> Result, String> { let mut result: Box; if scratchpad_name == "CHAT-GENERIC" { - result = Box::new(chat_generic::GenericChatScratchpad::new(tokenizer_arc, post)); + result = Box::new(chat_generic::GenericChatScratchpad::new(tokenizer_arc, post, vecdb_search)); } 
else if scratchpad_name == "CHAT-LLAMA2" { result = Box::new(chat_llama2::ChatLlama2::new(tokenizer_arc, post, vecdb_search)); } else { diff --git a/src/vecdb_search.rs b/src/vecdb_search.rs index 71acdb2..e13a067 100644 --- a/src/vecdb_search.rs +++ b/src/vecdb_search.rs @@ -1,10 +1,13 @@ +use crate::call_validation::{ChatMessage, ChatPost}; // use reqwest::header::AUTHORIZATION; use reqwest::header::CONTENT_TYPE; use reqwest::header::HeaderMap; use reqwest::header::HeaderValue; use serde::{Deserialize, Serialize}; use serde_json::json; -// use tracing::info; + +use std::sync::Arc; +use tokio::sync::Mutex as AMutex; use async_trait::async_trait; @@ -20,6 +23,57 @@ pub struct VecdbResult { pub results: Vec, } +pub async fn embed_vecdb_results( + vecdb_search: Arc>>, + post: &mut ChatPost, + limit_examples_cnt: usize, +) { + let my_vdb = vecdb_search.clone(); + let latest_msg_cont = &post.messages.last().unwrap().content; + let mut vecdb_locked = my_vdb.lock().await; + let vdb_resp = vecdb_locked.search(&latest_msg_cont).await; + let vdb_cont = vecdb_resp_to_prompt(&vdb_resp, limit_examples_cnt); + if vdb_cont.len() > 0 { + post.messages = [ + &post.messages[..post.messages.len() -1], + &[ChatMessage { + role: "user".to_string(), + content: vdb_cont, + }], + &post.messages[post.messages.len() -1..], + ].concat(); + } +} + + +fn vecdb_resp_to_prompt( + resp: &Result, + limit_examples_cnt: usize, +) -> String { + let mut cont = "".to_string(); + match resp { + Ok(resp) => { + cont.push_str("CONTEXT:\n"); + for i in 0..limit_examples_cnt { + if i >= resp.results.len() { + break; + } + cont.push_str("FILENAME:\n"); + cont.push_str(resp.results[i].file_name.clone().as_str()); + cont.push_str("\nTEXT:"); + cont.push_str(resp.results[i].text.clone().as_str()); + cont.push_str("\n"); + } + cont.push_str("\nRefer to the context to answer my next question.\n"); + cont + } + Err(e) => { + format!("Vecdb error: {}", e); + cont + } + } +} + #[async_trait] pub trait VecdbSearch: Send { async fn search(
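
Taken together, the series lands on this flow: GlobalContext owns an Arc<AMutex<Box<dyn VecdbSearch>>>, create_chat_scratchpad() hands that handle to each chat scratchpad, and every async prompt() call awaits embed_vecdb_results(), which searches with the newest user message and splices the retrieved files in as an extra "user" message just before it. The sketch below distills that final behaviour into a standalone program; ChatMessage, ChatPost and the MockVecdb backend are simplified stand-ins (the real crate POSTs to an HTTP vdb-search endpoint), and the generic parameters are my reading of the final commit rather than text taken verbatim from it.

// Standalone sketch of the final vecdb-to-chat flow.
// Assumed deps: tokio (with "macros" and "rt" features) and async-trait.
use std::sync::Arc;
use async_trait::async_trait;
use tokio::sync::Mutex as AMutex;

struct ChatMessage { role: String, content: String }
struct ChatPost { messages: Vec<ChatMessage> }

struct VecdbResultRec { file_name: String, text: String, score: String }
struct VecdbResult { results: Vec<VecdbResultRec> }

#[async_trait]
trait VecdbSearch: Send {
    async fn search(&mut self, query: &str) -> Result<VecdbResult, String>;
}

// Stand-in backend returning a canned hit instead of calling the vdb-search service.
struct MockVecdb;

#[async_trait]
impl VecdbSearch for MockVecdb {
    async fn search(&mut self, _query: &str) -> Result<VecdbResult, String> {
        Ok(VecdbResult { results: vec![VecdbResultRec {
            file_name: "src/parallel_tasks.rs".to_string(),   // hypothetical hit
            text: "pub struct ParallelTasksV3 { /* ... */ }".to_string(),
            score: "0.9".to_string(),
        }] })
    }
}

// Mirrors embed_vecdb_results() from the last commit: query with the newest user
// message, format the hits as a CONTEXT block, and insert that block as an extra
// "user" message right before the newest message.
async fn embed_vecdb_results(
    vecdb_search: Arc<AMutex<Box<dyn VecdbSearch>>>,
    post: &mut ChatPost,
    limit_examples_cnt: usize,
) {
    let latest = post.messages.last().unwrap().content.clone();
    let resp = vecdb_search.lock().await.search(&latest).await;
    let mut cont = String::new();
    if let Ok(resp) = resp {
        cont.push_str("CONTEXT:\n");
        for rec in resp.results.iter().take(limit_examples_cnt) {
            cont.push_str(&format!("FILENAME:\n{}\nTEXT:{}\n", rec.file_name, rec.text));
        }
        cont.push_str("\nRefer to the context to answer my next question.\n");
    }
    if !cont.is_empty() {
        let n = post.messages.len();
        post.messages.insert(n - 1, ChatMessage { role: "user".to_string(), content: cont });
    }
}

#[tokio::main]
async fn main() {
    // Same ownership shape the patches give GlobalContext.vecdb_search.
    let vecdb: Arc<AMutex<Box<dyn VecdbSearch>>> =
        Arc::new(AMutex::new(Box::new(MockVecdb)));
    let mut post = ChatPost { messages: vec![ChatMessage {
        role: "user".to_string(),
        content: "How does ParallelTasksV3 work?".to_string(),
    }] };
    embed_vecdb_results(vecdb, &mut post, 3).await;
    for m in &post.messages {
        println!("--- {} ---\n{}", m.role, m.content);
    }
}

Running the sketch prints the injected CONTEXT message followed by the original question, which mirrors the message-list shape the scratchpads then turn into a prompt and tokenize.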