Skip to content

Commit

Permalink
refactor: improve code quality (#956)
Browse files Browse the repository at this point in the history
  • Loading branch information
sigoden authored Oct 29, 2024
1 parent bb542f6 commit bf07cdf
Show file tree
Hide file tree
Showing 8 changed files with 77 additions and 90 deletions.
24 changes: 14 additions & 10 deletions src/client/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ pub trait Client: Sync + Send {
handler.done();
ret.with_context(|| "Failed to call chat-completions api")
}
_ = watch_abort_signal(abort_signal) => {
_ = wait_abort_signal(&abort_signal) => {
handler.done();
Ok(())
},
Expand Down Expand Up @@ -401,20 +401,25 @@ pub fn create_openai_compatible_client_config(client: &str) -> Result<Option<(St

pub async fn call_chat_completions(
input: &Input,
extract_code: bool,
client: &dyn Client,
config: &GlobalConfig,
) -> Result<(String, Vec<ToolResult>)> {
let task = client.chat_completions(input.clone());
let ret = run_with_spinner(task, "Generating").await;
match ret {
Ok(ret) => {
let ChatCompletionsOutput {
text, tool_calls, ..
mut text,
tool_calls,
..
} = ret;
if !text.is_empty() {
config.read().print_markdown(&text)?;
if extract_code && text.trim_start().starts_with("```") {
text = extract_block(&text);
}
client.global_config().read().print_markdown(&text)?;
}
Ok((text, eval_tool_calls(config, tool_calls)?))
Ok((text, eval_tool_calls(client.global_config(), tool_calls)?))
}
Err(err) => Err(err),
}
Expand All @@ -423,15 +428,14 @@ pub async fn call_chat_completions(
pub async fn call_chat_completions_streaming(
input: &Input,
client: &dyn Client,
config: &GlobalConfig,
abort: AbortSignal,
abort_signal: AbortSignal,
) -> Result<(String, Vec<ToolResult>)> {
let (tx, rx) = unbounded_channel();
let mut handler = SseHandler::new(tx, abort.clone());
let mut handler = SseHandler::new(tx, abort_signal.clone());

let (send_ret, render_ret) = tokio::join!(
client.chat_completions_streaming(input, &mut handler),
render_stream(rx, config, abort.clone()),
render_stream(rx, client.global_config(), abort_signal.clone()),
);

render_ret?;
Expand All @@ -442,7 +446,7 @@ pub async fn call_chat_completions_streaming(
if !text.is_empty() && !text.ends_with('\n') {
println!();
}
Ok((text, eval_tool_calls(config, tool_calls)?))
Ok((text, eval_tool_calls(client.global_config(), tool_calls)?))
}
Err(err) => {
if !text.is_empty() {
Expand Down
12 changes: 6 additions & 6 deletions src/client/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,16 @@ use tokio::sync::mpsc::UnboundedSender;

pub struct SseHandler {
sender: UnboundedSender<SseEvent>,
abort: AbortSignal,
abort_signal: AbortSignal,
buffer: String,
tool_calls: Vec<ToolCall>,
}

impl SseHandler {
pub fn new(sender: UnboundedSender<SseEvent>, abort: AbortSignal) -> Self {
pub fn new(sender: UnboundedSender<SseEvent>, abort_signal: AbortSignal) -> Self {
Self {
sender,
abort,
abort_signal,
buffer: String::new(),
tool_calls: Vec::new(),
}
Expand All @@ -36,7 +36,7 @@ impl SseHandler {
.send(SseEvent::Text(text.to_string()))
.with_context(|| "Failed to send SseEvent:Text");
if let Err(err) = ret {
if self.abort.aborted() {
if self.abort_signal.aborted() {
return Ok(());
}
return Err(err);
Expand All @@ -48,7 +48,7 @@ impl SseHandler {
// debug!("HandleDone");
let ret = self.sender.send(SseEvent::Done);
if ret.is_err() {
if self.abort.aborted() {
if self.abort_signal.aborted() {
return;
}
warn!("Failed to send SseEvent:Done");
Expand All @@ -62,7 +62,7 @@ impl SseHandler {
}

pub fn abort(&self) -> AbortSignal {
self.abort.clone()
self.abort_signal.clone()
}

pub fn tool_calls(&self) -> &[ToolCall] {
Expand Down
45 changes: 8 additions & 37 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,12 @@ mod utils;
extern crate log;

use crate::cli::Cli;
use crate::client::{
call_chat_completions, call_chat_completions_streaming, list_chat_models, ChatCompletionsOutput,
};
use crate::client::{call_chat_completions, call_chat_completions_streaming, list_chat_models};
use crate::config::{
ensure_parent_exists, list_agents, load_env_file, Config, GlobalConfig, Input, WorkingMode,
CODE_ROLE, EXPLAIN_SHELL_ROLE, SHELL_ROLE, TEMP_SESSION_NAME,
};
use crate::function::{eval_tool_calls, need_send_tool_results};
use crate::function::need_send_tool_results;
use crate::render::render_error;
use crate::repl::Repl;
use crate::utils::*;
Expand Down Expand Up @@ -175,28 +173,9 @@ async fn start_directive(
let extract_code = !*IS_STDOUT_TERMINAL && code_mode;
config.write().before_chat_completion(&input)?;
let (output, tool_results) = if !input.stream() || extract_code {
let task = client.chat_completions(input.clone());
let ret = run_with_spinner(task, "Generating").await;
match ret {
Ok(ret) => {
let ChatCompletionsOutput {
mut text,
tool_calls,
..
} = ret;
if !text.is_empty() {
if extract_code && text.trim_start().starts_with("```") {
text = extract_block(&text);
}
config.read().print_markdown(&text)?;
}
(text, eval_tool_calls(config, tool_calls)?)
}
Err(err) => return Err(err),
}
call_chat_completions(&input, extract_code, client.as_ref()).await?
} else {
call_chat_completions_streaming(&input, client.as_ref(), config, abort_signal.clone())
.await?
call_chat_completions_streaming(&input, client.as_ref(), abort_signal.clone()).await?
};
config
.write()
Expand Down Expand Up @@ -225,15 +204,7 @@ async fn start_interactive(config: &GlobalConfig) -> Result<()> {
async fn shell_execute(config: &GlobalConfig, shell: &Shell, mut input: Input) -> Result<()> {
let client = input.create_client()?;
config.write().before_chat_completion(&input)?;
let ret = if *IS_STDOUT_TERMINAL {
let spinner = create_spinner("Generating").await;
let ret = client.chat_completions(input.clone()).await;
spinner.stop();
ret
} else {
client.chat_completions(input.clone()).await
};
let mut eval_str = ret?.text;
let (mut eval_str, _) = call_chat_completions(&input, false, client.as_ref()).await?;
if let Ok(true) = CODE_BLOCK_RE.is_match(&eval_str) {
eval_str = extract_block(&eval_str);
}
Expand Down Expand Up @@ -287,12 +258,12 @@ async fn shell_execute(config: &GlobalConfig, shell: &Shell, mut input: Input) -
"d" => {
let role = config.read().retrieve_role(EXPLAIN_SHELL_ROLE)?;
let input = Input::from_str(config, &eval_str, Some(role));
let abort = create_abort_signal();
let abort_signal = create_abort_signal();
if input.stream() {
call_chat_completions_streaming(&input, client.as_ref(), config, abort)
call_chat_completions_streaming(&input, client.as_ref(), abort_signal)
.await?;
} else {
call_chat_completions(&input, client.as_ref(), config).await?;
call_chat_completions(&input, false, client.as_ref()).await?;
}
println!();
continue;
Expand Down
6 changes: 3 additions & 3 deletions src/rag/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ impl Rag {
spinner.stop();
ret?;
}
_ = watch_abort_signal(abort_signal) => {
_ = wait_abort_signal(&abort_signal) => {
spinner.stop();
bail!("Aborted!")
},
Expand Down Expand Up @@ -142,7 +142,7 @@ impl Rag {
spinner.stop();
ret?;
}
_ = watch_abort_signal(abort_signal) => {
_ = wait_abort_signal(&abort_signal) => {
spinner.stop();
bail!("Aborted!")
},
Expand Down Expand Up @@ -320,7 +320,7 @@ impl Rag {
ret = self.hybird_search(text, top_k, min_score_vector_search, min_score_keyword_search, rerank_model) => {
ret
}
_ = watch_abort_signal(abort_signal) => {
_ = wait_abort_signal(&abort_signal) => {
bail!("Aborted!")
},
};
Expand Down
6 changes: 3 additions & 3 deletions src/render/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@ use tokio::sync::mpsc::UnboundedReceiver;
pub async fn render_stream(
rx: UnboundedReceiver<SseEvent>,
config: &GlobalConfig,
abort: AbortSignal,
abort_signal: AbortSignal,
) -> Result<()> {
let ret = if *IS_STDOUT_TERMINAL {
let render_options = config.read().render_options()?;
let mut render = MarkdownRender::init(render_options)?;
markdown_stream(rx, &mut render, &abort).await
markdown_stream(rx, &mut render, &abort_signal).await
} else {
raw_stream(rx, &abort).await
raw_stream(rx, &abort_signal).await
};
ret.map_err(|err| err.context("Failed to reader stream"))
}
Expand Down
37 changes: 13 additions & 24 deletions src/render/stream.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
use super::{MarkdownRender, SseEvent};

use crate::utils::{create_spinner, AbortSignal};
use crate::utils::{create_spinner, poll_abort_signal, AbortSignal};

use anyhow::Result;
use crossterm::{
cursor,
event::{self, Event, KeyCode, KeyModifiers},
queue, style,
cursor, queue, style,
terminal::{self, disable_raw_mode, enable_raw_mode},
};
use std::{
Expand All @@ -19,12 +17,12 @@ use tokio::sync::mpsc::UnboundedReceiver;
pub async fn markdown_stream(
rx: UnboundedReceiver<SseEvent>,
render: &mut MarkdownRender,
abort: &AbortSignal,
abort_signal: &AbortSignal,
) -> Result<()> {
enable_raw_mode()?;
let mut stdout = io::stdout();

let ret = markdown_stream_inner(rx, render, abort, &mut stdout).await;
let ret = markdown_stream_inner(rx, render, abort_signal, &mut stdout).await;

disable_raw_mode()?;

Expand All @@ -34,9 +32,12 @@ pub async fn markdown_stream(
ret
}

pub async fn raw_stream(mut rx: UnboundedReceiver<SseEvent>, abort: &AbortSignal) -> Result<()> {
pub async fn raw_stream(
mut rx: UnboundedReceiver<SseEvent>,
abort_signal: &AbortSignal,
) -> Result<()> {
loop {
if abort.aborted() {
if abort_signal.aborted() {
return Ok(());
}
if let Some(evt) = rx.recv().await {
Expand All @@ -57,7 +58,7 @@ pub async fn raw_stream(mut rx: UnboundedReceiver<SseEvent>, abort: &AbortSignal
async fn markdown_stream_inner(
mut rx: UnboundedReceiver<SseEvent>,
render: &mut MarkdownRender,
abort: &AbortSignal,
abort_signal: &AbortSignal,
writer: &mut Stdout,
) -> Result<()> {
let mut buffer = String::new();
Expand All @@ -68,7 +69,7 @@ async fn markdown_stream_inner(
let mut spinner = Some(create_spinner("Generating").await);

'outer: loop {
if abort.aborted() {
if abort_signal.aborted() {
return Ok(());
}
for reply_event in gather_events(&mut rx).await {
Expand Down Expand Up @@ -141,20 +142,8 @@ async fn markdown_stream_inner(
}
}

if crossterm::event::poll(Duration::from_millis(25))? {
if let Event::Key(key) = event::read()? {
match key.code {
KeyCode::Char('c') if key.modifiers == KeyModifiers::CONTROL => {
abort.set_ctrlc();
break;
}
KeyCode::Char('d') if key.modifiers == KeyModifiers::CONTROL => {
abort.set_ctrld();
break;
}
_ => {}
}
}
if poll_abort_signal(abort_signal)? {
break;
}
}

Expand Down
5 changes: 2 additions & 3 deletions src/repl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -658,10 +658,9 @@ async fn ask(
let client = input.create_client()?;
config.write().before_chat_completion(&input)?;
let (output, tool_results) = if input.stream() {
call_chat_completions_streaming(&input, client.as_ref(), config, abort_signal.clone())
.await?
call_chat_completions_streaming(&input, client.as_ref(), abort_signal.clone()).await?
} else {
call_chat_completions(&input, client.as_ref(), config).await?
call_chat_completions(&input, false, client.as_ref()).await?
};
config
.write()
Expand Down
32 changes: 28 additions & 4 deletions src/utils/abort_signal.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
use anyhow::Result;
use crossterm::event::{self, Event, KeyCode, KeyModifiers};
use std::{
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
time::Duration,
};

pub type AbortSignal = Arc<AbortSignalInner>;
Expand Down Expand Up @@ -54,11 +59,30 @@ impl AbortSignalInner {
}
}

pub async fn watch_abort_signal(abort_signal: AbortSignal) {
pub async fn wait_abort_signal(abort_signal: &AbortSignal) {
loop {
if abort_signal.aborted() {
break;
}
tokio::time::sleep(std::time::Duration::from_millis(25)).await;
}
}

pub fn poll_abort_signal(abort_signal: &AbortSignal) -> Result<bool> {
if crossterm::event::poll(Duration::from_millis(25))? {
if let Event::Key(key) = event::read()? {
match key.code {
KeyCode::Char('c') if key.modifiers == KeyModifiers::CONTROL => {
abort_signal.set_ctrlc();
return Ok(true);
}
KeyCode::Char('d') if key.modifiers == KeyModifiers::CONTROL => {
abort_signal.set_ctrld();
return Ok(true);
}
_ => {}
}
}
}
Ok(false)
}

0 comments on commit bf07cdf

Please sign in to comment.