Skip to content

Commit

Permalink
Add MockConnectInfo (#1767)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidpdrsn authored Feb 17, 2023
1 parent cd86f7e commit 143c415
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 3 deletions.
1 change: 1 addition & 0 deletions axum/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **added:** Add `FormRejection::FailedToDeserializeFormBody` which is returned
if the request body couldn't be deserialized into the target type, as opposed
to `FailedToDeserializeForm` which is only for query parameters ([#1683])
- **added:** Add `MockConnectInfo` for setting `ConnectInfo` during tests

[#1683]: https://github.com/tokio-rs/axum/pull/1683
[#1690]: https://github.com/tokio-rs/axum/pull/1690
Expand Down
116 changes: 113 additions & 3 deletions axum/src/extract/connect_info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,15 +137,81 @@ where
type Rejection = <Extension<Self> as FromRequestParts<S>>::Rejection;

async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let Extension(connect_info) = Extension::<Self>::from_request_parts(parts, state).await?;
Ok(connect_info)
match Extension::<Self>::from_request_parts(parts, state).await {
Ok(Extension(connect_info)) => Ok(connect_info),
Err(err) => match parts.extensions.get::<MockConnectInfo<T>>() {
Some(MockConnectInfo(connect_info)) => Ok(Self(connect_info.clone())),
None => Err(err),
},
}
}
}

/// Middleware used to mock [`ConnectInfo`] during tests.
///
/// If you're accidentally using [`MockConnectInfo`] and
/// [`Router::into_make_service_with_connect_info`] at the same time then
/// [`Router::into_make_service_with_connect_info`] takes precedence.
///
/// # Example
///
/// ```
/// use axum::{
/// Router,
/// extract::connect_info::{MockConnectInfo, ConnectInfo},
/// body::Body,
/// routing::get,
/// http::{Request, StatusCode},
/// };
/// use std::net::SocketAddr;
/// use tower::ServiceExt;
///
/// async fn handler(ConnectInfo(addr): ConnectInfo<SocketAddr>) {}
///
/// // this router you can run with `app.into_make_service_with_connect_info::<SocketAddr>()`
/// fn app() -> Router {
/// Router::new().route("/", get(handler))
/// }
///
/// // use this router for tests
/// fn test_app() -> Router {
/// app().layer(MockConnectInfo(SocketAddr::from(([0, 0, 0, 0], 1337))))
/// }
///
/// // #[tokio::test]
/// async fn some_test() {
/// let app = test_app();
///
/// let request = Request::new(Body::empty());
/// let response = app.oneshot(request).await.unwrap();
/// assert_eq!(response.status(), StatusCode::OK);
/// }
/// #
/// # #[tokio::main]
/// # async fn main() {
/// # some_test().await;
/// # }
/// ```
///
/// [`Router::into_make_service_with_connect_info`]: crate::Router::into_make_service_with_connect_info
#[derive(Clone, Copy, Debug)]
pub struct MockConnectInfo<T>(pub T);

impl<S, T> Layer<S> for MockConnectInfo<T>
where
T: Clone + Send + Sync + 'static,
{
type Service = <Extension<Self> as Layer<S>>::Service;

fn layer(&self, inner: S) -> Self::Service {
Extension(self.clone()).layer(inner)
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::{routing::get, Router, Server};
use crate::{routing::get, test_helpers::TestClient, Router, Server};
use std::net::{SocketAddr, TcpListener};

#[crate::test]
Expand Down Expand Up @@ -214,4 +280,48 @@ mod tests {
let body = res.text().await.unwrap();
assert_eq!(body, "it worked!");
}

#[crate::test]
async fn mock_connect_info() {
async fn handler(ConnectInfo(addr): ConnectInfo<SocketAddr>) -> String {
format!("{addr}")
}

let app = Router::new()
.route("/", get(handler))
.layer(MockConnectInfo(SocketAddr::from(([0, 0, 0, 0], 1337))));

let client = TestClient::new(app);

let res = client.get("/").send().await;
let body = res.text().await;
assert!(body.starts_with("0.0.0.0:1337"));
}

#[crate::test]
async fn both_mock_and_real_connect_info() {
async fn handler(ConnectInfo(addr): ConnectInfo<SocketAddr>) -> String {
format!("{addr}")
}

let listener = TcpListener::bind("127.0.0.1:0").unwrap();
let addr = listener.local_addr().unwrap();

tokio::spawn(async move {
let app = Router::new()
.route("/", get(handler))
.layer(MockConnectInfo(SocketAddr::from(([0, 0, 0, 0], 1337))));

let server = Server::from_tcp(listener)
.unwrap()
.serve(app.into_make_service_with_connect_info::<SocketAddr>());
server.await.expect("server error");
});

let client = reqwest::Client::new();

let res = client.get(format!("http://{addr}")).send().await.unwrap();
let body = res.text().await.unwrap();
assert!(body.starts_with("127.0.0.1:"));
}
}
25 changes: 25 additions & 0 deletions examples/testing/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
//! cargo test -p example-testing
//! ```
use std::net::SocketAddr;

use axum::{
extract::ConnectInfo,
routing::{get, post},
Json, Router,
};
Expand Down Expand Up @@ -43,6 +46,10 @@ fn app() -> Router {
Json(serde_json::json!({ "data": payload.0 }))
}),
)
.route(
"/requires-connect-into",
get(|ConnectInfo(addr): ConnectInfo<SocketAddr>| async move { format!("Hi {addr}") }),
)
// We can still add middleware
.layer(TraceLayer::new_for_http())
}
Expand All @@ -52,6 +59,7 @@ mod tests {
use super::*;
use axum::{
body::Body,
extract::connect_info::MockConnectInfo,
http::{self, Request, StatusCode},
};
use serde_json::{json, Value};
Expand Down Expand Up @@ -164,4 +172,21 @@ mod tests {
let response = app.ready().await.unwrap().call(request).await.unwrap();
assert_eq!(response.status(), StatusCode::OK);
}

// Here we're calling `/requires-connect-into` which requires `ConnectInfo`
//
// That is normally set with `Router::into_make_service_with_connect_info` but we can't easily
// use that during tests. The solution is instead to set the `MockConnectInfo` layer during
// tests.
#[tokio::test]
async fn with_into_make_service_with_connect_info() {
let mut app = app().layer(MockConnectInfo(SocketAddr::from(([0, 0, 0, 0], 3000))));

let request = Request::builder()
.uri("/requires-connect-into")
.body(Body::empty())
.unwrap();
let response = app.ready().await.unwrap().call(request).await.unwrap();
assert_eq!(response.status(), StatusCode::OK);
}
}

0 comments on commit 143c415

Please sign in to comment.