Skip to content

Commit

Permalink
Resolve unwrap in build body and when acquiring RwLock (#79)
Browse files Browse the repository at this point in the history
* fix: Use parking_lot to avoid PoisenError when acquiring a RwLock<_>

* fix: Propagate error when building a HTTP request

- Add error::Error variant
- Add test for an invalid request

* feat: Warn when unwrap is used
  • Loading branch information
threema-donat authored Apr 26, 2024
1 parent a3860ae commit ed825b4
Show file tree
Hide file tree
Showing 6 changed files with 96 additions and 34 deletions.
49 changes: 49 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ ring = { version = "0.17", features = ["std"], optional = true }
hyper-rustls = { version = "0.26.0", default-features = false, features = ["http2", "webpki-roots", "ring"] }
rustls-pemfile = "2.1.1"
rustls = "0.22.0"
parking_lot = "0.12"

[dev-dependencies]
argparse = "0.2"
Expand Down
62 changes: 35 additions & 27 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ impl Client {
/// See [ErrorReason](enum.ErrorReason.html) for possible errors.
#[cfg_attr(feature = "tracing", ::tracing::instrument)]
pub async fn send<T: PayloadLike>(&self, payload: T) -> Result<Response, Error> {
let request = self.build_request(payload);
let request = self.build_request(payload)?;
let requesting = self.http_client.request(request);

let response = requesting.await?;
Expand Down Expand Up @@ -152,7 +152,7 @@ impl Client {
}
}

fn build_request<T: PayloadLike>(&self, payload: T) -> hyper::Request<BoxBody<Bytes, Infallible>> {
fn build_request<T: PayloadLike>(&self, payload: T) -> Result<hyper::Request<BoxBody<Bytes, Infallible>>, Error> {
let path = format!("https://{}/3/device/{}", self.endpoint, payload.get_device_token());

let mut builder = hyper::Request::builder()
Expand Down Expand Up @@ -180,18 +180,16 @@ impl Client {
builder = builder.header("apns-topic", apns_topic.as_bytes());
}
if let Some(ref signer) = self.signer {
let auth = signer
.with_signature(|signature| format!("Bearer {}", signature))
.unwrap();
let auth = signer.with_signature(|signature| format!("Bearer {}", signature))?;

builder = builder.header(AUTHORIZATION, auth.as_bytes());
}

let payload_json = payload.to_json_string().unwrap();
let payload_json = payload.to_json_string()?;
builder = builder.header(CONTENT_LENGTH, format!("{}", payload_json.len()).as_bytes());

let request_body = Full::from(payload_json.into_bytes()).boxed();
builder.body(request_body).unwrap()
builder.body(request_body).map_err(Error::BuildRequestError)
}
}

Expand Down Expand Up @@ -247,7 +245,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(default_connector(), None, Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();
let uri = format!("{}", request.uri());

assert_eq!("https://api.push.apple.com/3/device/a_test_id", &uri);
Expand All @@ -258,7 +256,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(default_connector(), None, Endpoint::Sandbox);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();
let uri = format!("{}", request.uri());

assert_eq!("https://api.development.push.apple.com/3/device/a_test_id", &uri);
Expand All @@ -269,17 +267,27 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(default_connector(), None, Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();

assert_eq!(&Method::POST, request.method());
}

#[test]
fn test_request_invalid() {
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("\r\n", Default::default());
let client = Client::new(default_connector(), None, Endpoint::Production);
let request = client.build_request(payload);

assert!(matches!(request, Err(Error::BuildRequestError(_))));
}

#[test]
fn test_request_content_type() {
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(default_connector(), None, Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();

assert_eq!("application/json", request.headers().get(CONTENT_TYPE).unwrap());
}
Expand All @@ -289,7 +297,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(default_connector(), None, Endpoint::Production);
let request = client.build_request(payload.clone());
let request = client.build_request(payload.clone()).unwrap();
let payload_json = payload.to_json_string().unwrap();
let content_length = request.headers().get(CONTENT_LENGTH).unwrap().to_str().unwrap();

Expand All @@ -301,7 +309,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(default_connector(), None, Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();

assert_eq!(None, request.headers().get(AUTHORIZATION));
}
Expand All @@ -319,7 +327,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(default_connector(), Some(signer), Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();

assert_ne!(None, request.headers().get(AUTHORIZATION));
}
Expand All @@ -333,7 +341,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
};
let payload = builder.build("a_test_id", options);
let client = Client::new(default_connector(), None, Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();
let apns_push_type = request.headers().get("apns-push-type").unwrap();

assert_eq!("background", apns_push_type);
Expand All @@ -344,7 +352,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(default_connector(), None, Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();
let apns_priority = request.headers().get("apns-priority");

assert_eq!(None, apns_priority);
Expand All @@ -363,7 +371,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
);

let client = Client::new(default_connector(), None, Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();
let apns_priority = request.headers().get("apns-priority").unwrap();

assert_eq!("5", apns_priority);
Expand All @@ -382,7 +390,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
);

let client = Client::new(default_connector(), None, Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();
let apns_priority = request.headers().get("apns-priority").unwrap();

assert_eq!("10", apns_priority);
Expand All @@ -395,7 +403,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
let payload = builder.build("a_test_id", Default::default());

let client = Client::new(default_connector(), None, Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();
let apns_id = request.headers().get("apns-id");

assert_eq!(None, apns_id);
Expand All @@ -414,7 +422,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
);

let client = Client::new(default_connector(), None, Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();
let apns_id = request.headers().get("apns-id").unwrap();

assert_eq!("a-test-apns-id", apns_id);
Expand All @@ -427,7 +435,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
let payload = builder.build("a_test_id", Default::default());

let client = Client::new(default_connector(), None, Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();
let apns_expiration = request.headers().get("apns-expiration");

assert_eq!(None, apns_expiration);
Expand All @@ -446,7 +454,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
);

let client = Client::new(default_connector(), None, Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();
let apns_expiration = request.headers().get("apns-expiration").unwrap();

assert_eq!("420", apns_expiration);
Expand All @@ -459,7 +467,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
let payload = builder.build("a_test_id", Default::default());

let client = Client::new(default_connector(), None, Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();
let apns_collapse_id = request.headers().get("apns-collapse-id");

assert_eq!(None, apns_collapse_id);
Expand All @@ -478,7 +486,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
);

let client = Client::new(default_connector(), None, Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();
let apns_collapse_id = request.headers().get("apns-collapse-id").unwrap();

assert_eq!("a_collapse_id", apns_collapse_id);
Expand All @@ -491,7 +499,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
let payload = builder.build("a_test_id", Default::default());

let client = Client::new(default_connector(), None, Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();
let apns_topic = request.headers().get("apns-topic");

assert_eq!(None, apns_topic);
Expand All @@ -510,7 +518,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
);

let client = Client::new(default_connector(), None, Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();
let apns_topic = request.headers().get("apns-topic").unwrap();

assert_eq!("a_topic", apns_topic);
Expand All @@ -521,7 +529,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(default_connector(), None, Endpoint::Production);
let request = client.build_request(payload.clone());
let request = client.build_request(payload.clone()).unwrap();

let body = request.into_body().collect().await.unwrap().to_bytes();
let body_str = String::from_utf8(body.to_vec()).unwrap();
Expand Down
4 changes: 4 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ pub enum Error {
#[error("Error building TLS config: {0}")]
Tls(#[from] rustls::Error),

/// Error while creating the HTTP request
#[error("Failed to construct HTTP request: {0}")]
BuildRequestError(#[source] http::Error),

/// Unexpected private key (only EC keys are supported).
#[cfg(all(not(feature = "openssl"), feature = "ring"))]
#[error("Unexpected private key: {0}")]
Expand Down
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@
//! }
//! # }
//! ```
#![warn(clippy::unwrap_used)]

#[cfg(not(any(feature = "openssl", feature = "ring")))]
compile_error!("either feature \"openssl\" or feature \"ring\" has to be enabled");

Expand Down
12 changes: 5 additions & 7 deletions src/signer.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
use crate::error::Error;
use parking_lot::RwLock;
use std::io::Read;
use std::sync::Arc;
use std::{
sync::RwLock,
time::{Duration, SystemTime, UNIX_EPOCH},
};
use std::time::{Duration, SystemTime, UNIX_EPOCH};

use base64::prelude::*;
#[cfg(feature = "openssl")]
Expand Down Expand Up @@ -138,7 +136,7 @@ impl Signer {
self.renew()?;
}

let signature = self.signature.read().unwrap();
let signature = self.signature.read();

#[cfg(feature = "tracing")]
{
Expand Down Expand Up @@ -191,7 +189,7 @@ impl Signer {
);
}

let mut signature = self.signature.write().unwrap();
let mut signature = self.signature.write();

*signature = Signature {
key: Self::create_signature(&self.secret, &self.key_id, &self.team_id, issued_at)?,
Expand All @@ -202,7 +200,7 @@ impl Signer {
}

fn is_expired(&self) -> bool {
let sig = self.signature.read().unwrap();
let sig = self.signature.read();
let expiry = get_time() - sig.issued_at;
expiry >= self.expire_after_s.as_secs() as i64
}
Expand Down

0 comments on commit ed825b4

Please sign in to comment.