From 2fab44c809d6ba326201419857ca7fe44196a9f0 Mon Sep 17 00:00:00 2001 From: losfair Date: Sun, 21 Jan 2024 19:06:07 +0800 Subject: [PATCH] fix: BufReader should not panic when used after cancellation --- monoio/src/io/util/buf_reader.rs | 71 +++++++++++++++++++++++--------- monoio/tests/buf_reader.rs | 33 +++++++++++++++ 2 files changed, 85 insertions(+), 19 deletions(-) create mode 100644 monoio/tests/buf_reader.rs diff --git a/monoio/src/io/util/buf_reader.rs b/monoio/src/io/util/buf_reader.rs index f081749d..7f31c0de 100644 --- a/monoio/src/io/util/buf_reader.rs +++ b/monoio/src/io/util/buf_reader.rs @@ -6,12 +6,38 @@ use crate::{ BufResult, }; +enum BufState { + /// This buffer is never used, in use, or used by a previously cancelled + /// read. + Unallocated(usize), + + /// This buffer is available. + Available(Box<[u8]>), +} + +impl BufState { + fn take(&mut self) -> Box<[u8]> { + let size = self.size(); + match std::mem::replace(self, BufState::Unallocated(size)) { + BufState::Unallocated(len) => vec![0u8; len].into(), + BufState::Available(buf) => buf, + } + } + + fn size(&self) -> usize { + match self { + BufState::Unallocated(len) => *len, + BufState::Available(buf) => buf.len(), + } + } +} + /// BufReader is a struct with a buffer. BufReader implements AsyncBufRead /// and AsyncReadRent, and if the inner io implements AsyncWriteRent, it /// will delegate the implementation. pub struct BufReader { inner: R, - buf: Option>, + buf: BufState, pos: usize, cap: usize, } @@ -26,10 +52,9 @@ impl BufReader { /// Create BufReader with given buffer size pub fn with_capacity(capacity: usize, inner: R) -> Self { - let buffer = vec![0; capacity]; Self { inner, - buf: Some(buffer.into_boxed_slice()), + buf: BufState::Unallocated(capacity), pos: 0, cap: 0, } @@ -59,7 +84,10 @@ impl BufReader { /// Unlike `fill_buf`, this will not attempt to fill the buffer if it is /// empty. pub fn buffer(&self) -> &[u8] { - &self.buf.as_ref().expect("unable to take buffer")[self.pos..self.cap] + match &self.buf { + BufState::Unallocated(_) => &[], + BufState::Available(buf) => &buf[self.pos..self.cap], + } } /// Invalidates all data in the internal buffer. @@ -75,8 +103,7 @@ impl AsyncReadRent for BufReader { // If we don't have any buffered data and we're doing a massive read // (larger than our internal buffer), bypass our internal buffer // entirely. - let owned_buf = self.buf.as_ref().unwrap(); - if self.pos == self.cap && buf.bytes_total() >= owned_buf.len() { + if self.pos == self.cap && buf.bytes_total() >= self.buf.size() { self.discard_buffer(); return self.inner.read(buf).await; } @@ -115,19 +142,19 @@ impl AsyncBufRead for BufReader { async fn fill_buf(&mut self) -> std::io::Result<&[u8]> { if self.pos == self.cap { // there's no buffered data - let buf = self - .buf - .take() - .expect("no buffer available, generated future must be awaited"); + let buf = self.buf.take(); let (res, buf_) = self.inner.read(buf).await; - self.buf = Some(buf_); + self.buf = BufState::Available(buf_); match res { Ok(n) => { self.pos = 0; self.cap = n; - return Ok(unsafe { - // We just put the buf into Option, so it must be Some. - &(self.buf.as_ref().unwrap_unchecked().as_ref())[self.pos..self.cap] + return Ok(match &self.buf { + BufState::Available(buf) => &buf[self.pos..self.cap], + BufState::Unallocated(_) => { + // We just put the buf into Option, so it must be Some. + unreachable!() + } }); } Err(e) => { @@ -135,11 +162,17 @@ impl AsyncBufRead for BufReader { } } } - Ok(&(self - .buf - .as_ref() - .expect("no buffer available, generated future must be awaited") - .as_ref())[self.pos..self.cap]) + match &self.buf { + BufState::Available(buf) => Ok(&buf[self.pos..self.cap]), + BufState::Unallocated(_) => { + // The `Unallocated` state only happens if: + // - nothing is read into this `BufReader` yet (pos == 0, cap == 0), or + // - a previous `fill_buf` was cancelled (pos == cap) + // Both cases are covered by the above `if` block, so it's impossible + // to reach here. + unreachable!("buf is unallocated"); + } + } } fn consume(&mut self, amt: usize) { diff --git a/monoio/tests/buf_reader.rs b/monoio/tests/buf_reader.rs new file mode 100644 index 00000000..00b18407 --- /dev/null +++ b/monoio/tests/buf_reader.rs @@ -0,0 +1,33 @@ +use futures::FutureExt; +use monoio::{ + io::{AsyncReadRentExt, BufReader}, + net::{TcpListener, TcpStream}, +}; + +#[monoio::test_all(timer_enabled = true)] +async fn buf_reader_use_after_cancel() { + let srv = TcpListener::bind("127.0.0.1:0").unwrap(); + let addr = srv.local_addr().unwrap(); + + monoio::spawn(async move { + let (mut stream, _) = srv.accept().await.unwrap(); + + // deadlock + let _ = stream.read_exact(vec![0u8; 1]).await; + }); + + let stream = TcpStream::connect(addr).await.unwrap(); + let mut stream = BufReader::new(stream); + + // Cancel the first read after a timeout + futures::select_biased! { + _ = monoio::time::sleep(std::time::Duration::from_millis(50)).fuse() => {}, + _ = stream.read_exact(vec![0u8; 1]).fuse() => unreachable!(), + } + + // Read again. This should not panic. + futures::select_biased! { + _ = monoio::time::sleep(std::time::Duration::from_millis(50)).fuse() => {}, + _ = stream.read_exact(vec![0u8; 1]).fuse() => unreachable!(), + } +}