Skip to content

Commit

Permalink
refactor to async
Browse files Browse the repository at this point in the history
  • Loading branch information
Tevinthuku committed Mar 2, 2024
1 parent 7ef7cbf commit e18f0cd
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 41 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions web-server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ edition = "2021"
[dependencies]
bytes = "1.5.0"
crossbeam = "0.8.4"
tokio = { version = "1.36.0", features = ["full"] }
152 changes: 111 additions & 41 deletions web-server/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,28 +1,30 @@
use std::{
borrow::Cow,
env,
fs::File,
io::{Read, Write},
use std::{borrow::Cow, env, future::Future, num::NonZeroUsize, path::Path};
use tokio::{fs::File, signal};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::{TcpListener, TcpStream},
num::NonZeroUsize,
path::Path,
runtime::Builder,
select,
};

use bytes::{BufMut, BytesMut};
use crossbeam::channel::unbounded;
use crossbeam::channel::Receiver;
use crossbeam::channel::{unbounded, Sender};

fn main() -> std::io::Result<()> {
#[tokio::main]
async fn main() -> std::io::Result<()> {
let address = env::var("ADDRESS").unwrap_or("127.0.0.1:80".to_owned());
let file_directory = env::var("FILE_DIRECTORY").unwrap_or("./www".to_owned());
run_server(&address, file_directory)
run_server(&address, file_directory, signal::ctrl_c()).await
}

fn run_server(address: &str, file_directory: String) -> std::io::Result<()> {
async fn run_server(
address: &str,
file_directory: String,
shutdown: impl Future,
) -> std::io::Result<()> {
let (sender, receiver) = unbounded::<TcpStream>();

let listener = TcpListener::bind(address)?;

let available_parallelism = std::thread::available_parallelism().map_or(2, NonZeroUsize::get);

let mut threads = Vec::with_capacity(available_parallelism);
Expand All @@ -34,15 +36,16 @@ fn run_server(address: &str, file_directory: String) -> std::io::Result<()> {
threads.push(thread);
}

for stream in listener.incoming() {
sender.send(stream?).map_err(|err| {
std::io::Error::new(
std::io::ErrorKind::Other,
format!("Failed to send stream to receiver: {err:?}"),
)
})?;
select! {
_ = run_server_inner(address, sender.clone()) => {}
ctrl = signal::ctrl_c() => {
ctrl?;
},
_ = shutdown => {}
}

drop(sender);

for thread in threads {
thread.join().map_err(|err| {
std::io::Error::new(
Expand All @@ -55,16 +58,33 @@ fn run_server(address: &str, file_directory: String) -> std::io::Result<()> {
Ok(())
}

fn process_requests(receiver: Receiver<TcpStream>, file_directory: String) {
for stream in receiver {
handle_client(stream, &file_directory).unwrap();
async fn run_server_inner(address: &str, sender: Sender<TcpStream>) -> std::io::Result<()> {
let listener = TcpListener::bind(address).await?;

loop {
let (stream, _) = listener.accept().await?;
sender.send(stream).map_err(|err| {
std::io::Error::new(
std::io::ErrorKind::Other,
format!("Failed to send stream to receiver: {err:?}"),
)
})?;
}
}

fn handle_client(mut stream: TcpStream, file_directory: &str) -> std::io::Result<()> {
let mut buffer = [0; 1024];
fn process_requests(receiver: Receiver<TcpStream>, file_directory: String) {
let rt = Builder::new_current_thread().enable_all().build().unwrap();
rt.block_on(async move {
for stream in receiver {
handle_client(stream, &file_directory).await.unwrap();
}
})
}

async fn handle_client(mut stream: TcpStream, file_directory: &str) -> std::io::Result<()> {
let mut buffer = BytesMut::with_capacity(1024);

let read_bytes = stream.read(&mut buffer)?;
let read_bytes = stream.read_buf(&mut buffer).await?;

let buffer = &buffer[0..read_bytes];

Expand All @@ -83,19 +103,19 @@ fn handle_client(mut stream: TcpStream, file_directory: &str) -> std::io::Result

let path = Path::new(file_directory).join(path);
let mut buffer = Vec::with_capacity(1024 * 1024);
let file = File::open(path);
let file = File::open(path).await;

let mut file = match file {
Ok(file) => file,
Err(_) => {
let mut response = BytesMut::with_capacity(512);
response.put_slice(request.http_version);
response.put_slice(b" 404 Not Found\r\n\r\n");
return stream.write_all(&response);
return stream.write_all(&response).await;
}
};

file.read_to_end(buffer.as_mut())?;
file.read_to_end(buffer.as_mut()).await?;
buffer
};

Expand All @@ -104,7 +124,7 @@ fn handle_client(mut stream: TcpStream, file_directory: &str) -> std::io::Result
response.put_slice(b" 200 OK\r\n\r\n");
response.put_slice(&file_to_send);
response.put_slice(b"\r\n");
stream.write_all(&response)
stream.write_all(&response).await
}

fn parse_request(first_line: &[u8]) -> std::io::Result<Request<'_>> {
Expand Down Expand Up @@ -133,35 +153,85 @@ struct Request<'a> {

#[cfg(test)]
mod tests {
use std::{
io::{Read, Write},

use std::{io, time::Duration};

use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::TcpStream,
thread,
sync::oneshot,
time,
};

use crate::run_server;

#[test]
fn test_connection() {
let address = "127.0.0.1:8080";
let start_thread = thread::spawn(|| run_server(address, "./www".to_owned()));
async fn connect(address: &str) -> io::Result<TcpStream> {
let mut backoff = 1;
loop {
match TcpStream::connect(address).await {
Ok(socket) => return Ok(socket),
Err(err) => {
if backoff > 30 {
return Err(err);
}
}
}
time::sleep(Duration::from_secs(backoff)).await;
backoff *= 2;
}
}

thread::sleep(std::time::Duration::from_secs(1));
#[tokio::test]
async fn test_connection() {
let (send, recv) = oneshot::channel::<()>();

let mut stream = TcpStream::connect(address).expect("Failed to connect to server");
println!("Connected to server");
let address = "127.0.0.1:0";
let join_handle = tokio::spawn(async move {
run_server(address, "./www".to_owned(), recv).await.unwrap();
});
let mut stream = connect(address).await.unwrap();
stream
.write_all(b"GET / HTTP/1.1\r\n\r\n")
.await
.expect("Failed to write to stream");

let mut response = String::new();
stream
.read_to_string(&mut response)
.await
.expect("Failed to read from stream");

drop(stream);

assert!(response.contains("HTTP/1.1 200 OK"));
// let _ = start_thread.join().expect("Failed to join thread");
send.send(()).expect("Failed to send shutdown signal");
let _ = join_handle.await;
}

#[tokio::test]
async fn test_not_found() {
let (send, recv) = oneshot::channel::<()>();

let address = "127.0.0.1:0";
let join_handle = tokio::spawn(async move {
run_server(address, "./www".to_owned(), recv).await.unwrap();
});
let mut stream = connect(address).await.unwrap();
stream
.write_all(b"GET /notfound.html HTTP/1.1\r\n\r\n")
.await
.expect("Failed to write to stream");

let mut response = String::new();
stream
.read_to_string(&mut response)
.await
.expect("Failed to read from stream");

drop(stream);

assert!(response.contains("HTTP/1.1 404 Not Found"));
send.send(()).expect("Failed to send shutdown signal");
let _ = join_handle.await;
}
}

0 comments on commit e18f0cd

Please sign in to comment.