diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index 17b7fc2330..127c90bb4d 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **added:** Add `FormRejection::FailedToDeserializeFormBody` which is returned if the request body couldn't be deserialized into the target type, as opposed to `FailedToDeserializeForm` which is only for query parameters ([#1683]) +- **added:** Add `MockConnectInfo` for setting `ConnectInfo` during tests [#1683]: https://github.com/tokio-rs/axum/pull/1683 [#1690]: https://github.com/tokio-rs/axum/pull/1690 diff --git a/axum/src/extract/connect_info.rs b/axum/src/extract/connect_info.rs index 2ab24863b0..7dd0c47d49 100644 --- a/axum/src/extract/connect_info.rs +++ b/axum/src/extract/connect_info.rs @@ -137,15 +137,81 @@ where type Rejection = as FromRequestParts>::Rejection; async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { - let Extension(connect_info) = Extension::::from_request_parts(parts, state).await?; - Ok(connect_info) + match Extension::::from_request_parts(parts, state).await { + Ok(Extension(connect_info)) => Ok(connect_info), + Err(err) => match parts.extensions.get::>() { + Some(MockConnectInfo(connect_info)) => Ok(Self(connect_info.clone())), + None => Err(err), + }, + } + } +} + +/// Middleware used to mock [`ConnectInfo`] during tests. +/// +/// If you're accidentally using [`MockConnectInfo`] and +/// [`Router::into_make_service_with_connect_info`] at the same time then +/// [`Router::into_make_service_with_connect_info`] takes precedence. +/// +/// # Example +/// +/// ``` +/// use axum::{ +/// Router, +/// extract::connect_info::{MockConnectInfo, ConnectInfo}, +/// body::Body, +/// routing::get, +/// http::{Request, StatusCode}, +/// }; +/// use std::net::SocketAddr; +/// use tower::ServiceExt; +/// +/// async fn handler(ConnectInfo(addr): ConnectInfo) {} +/// +/// // this router you can run with `app.into_make_service_with_connect_info::()` +/// fn app() -> Router { +/// Router::new().route("/", get(handler)) +/// } +/// +/// // use this router for tests +/// fn test_app() -> Router { +/// app().layer(MockConnectInfo(SocketAddr::from(([0, 0, 0, 0], 1337)))) +/// } +/// +/// // #[tokio::test] +/// async fn some_test() { +/// let app = test_app(); +/// +/// let request = Request::new(Body::empty()); +/// let response = app.oneshot(request).await.unwrap(); +/// assert_eq!(response.status(), StatusCode::OK); +/// } +/// # +/// # #[tokio::main] +/// # async fn main() { +/// # some_test().await; +/// # } +/// ``` +/// +/// [`Router::into_make_service_with_connect_info`]: crate::Router::into_make_service_with_connect_info +#[derive(Clone, Copy, Debug)] +pub struct MockConnectInfo(pub T); + +impl Layer for MockConnectInfo +where + T: Clone + Send + Sync + 'static, +{ + type Service = as Layer>::Service; + + fn layer(&self, inner: S) -> Self::Service { + Extension(self.clone()).layer(inner) } } #[cfg(test)] mod tests { use super::*; - use crate::{routing::get, Router, Server}; + use crate::{routing::get, test_helpers::TestClient, Router, Server}; use std::net::{SocketAddr, TcpListener}; #[crate::test] @@ -214,4 +280,48 @@ mod tests { let body = res.text().await.unwrap(); assert_eq!(body, "it worked!"); } + + #[crate::test] + async fn mock_connect_info() { + async fn handler(ConnectInfo(addr): ConnectInfo) -> String { + format!("{addr}") + } + + let app = Router::new() + .route("/", get(handler)) + .layer(MockConnectInfo(SocketAddr::from(([0, 0, 0, 0], 1337)))); + + let client = TestClient::new(app); + + let res = client.get("/").send().await; + let body = res.text().await; + assert!(body.starts_with("0.0.0.0:1337")); + } + + #[crate::test] + async fn both_mock_and_real_connect_info() { + async fn handler(ConnectInfo(addr): ConnectInfo) -> String { + format!("{addr}") + } + + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let addr = listener.local_addr().unwrap(); + + tokio::spawn(async move { + let app = Router::new() + .route("/", get(handler)) + .layer(MockConnectInfo(SocketAddr::from(([0, 0, 0, 0], 1337)))); + + let server = Server::from_tcp(listener) + .unwrap() + .serve(app.into_make_service_with_connect_info::()); + server.await.expect("server error"); + }); + + let client = reqwest::Client::new(); + + let res = client.get(format!("http://{addr}")).send().await.unwrap(); + let body = res.text().await.unwrap(); + assert!(body.starts_with("127.0.0.1:")); + } } diff --git a/examples/testing/src/main.rs b/examples/testing/src/main.rs index b2041bef9e..b5f7466b46 100644 --- a/examples/testing/src/main.rs +++ b/examples/testing/src/main.rs @@ -4,7 +4,10 @@ //! cargo test -p example-testing //! ``` +use std::net::SocketAddr; + use axum::{ + extract::ConnectInfo, routing::{get, post}, Json, Router, }; @@ -43,6 +46,10 @@ fn app() -> Router { Json(serde_json::json!({ "data": payload.0 })) }), ) + .route( + "/requires-connect-into", + get(|ConnectInfo(addr): ConnectInfo| async move { format!("Hi {addr}") }), + ) // We can still add middleware .layer(TraceLayer::new_for_http()) } @@ -52,6 +59,7 @@ mod tests { use super::*; use axum::{ body::Body, + extract::connect_info::MockConnectInfo, http::{self, Request, StatusCode}, }; use serde_json::{json, Value}; @@ -164,4 +172,21 @@ mod tests { let response = app.ready().await.unwrap().call(request).await.unwrap(); assert_eq!(response.status(), StatusCode::OK); } + + // Here we're calling `/requires-connect-into` which requires `ConnectInfo` + // + // That is normally set with `Router::into_make_service_with_connect_info` but we can't easily + // use that during tests. The solution is instead to set the `MockConnectInfo` layer during + // tests. + #[tokio::test] + async fn with_into_make_service_with_connect_info() { + let mut app = app().layer(MockConnectInfo(SocketAddr::from(([0, 0, 0, 0], 3000)))); + + let request = Request::builder() + .uri("/requires-connect-into") + .body(Body::empty()) + .unwrap(); + let response = app.ready().await.unwrap().call(request).await.unwrap(); + assert_eq!(response.status(), StatusCode::OK); + } }