Skip to content

Commit

Permalink
RUST-1894 Retry KMS requests on transient errors (#1281)
Browse files Browse the repository at this point in the history
  • Loading branch information
isabelatkinson authored Jan 14, 2025
1 parent fac8592 commit 2dadbed
Show file tree
Hide file tree
Showing 6 changed files with 223 additions and 33 deletions.
2 changes: 1 addition & 1 deletion .config/nextest.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[profile.default]
test-threads = 1
default-filter = 'not test(test::happy_eyeballs)'
default-filter = 'not test(test::happy_eyeballs) and not test(kms_retry)'

[profile.ci]
failure-output = "final"
Expand Down
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ time = "0.3.9"
tokio = { version = ">= 0.0.0", features = ["fs", "parking_lot"] }
tracing-subscriber = "0.3.16"
regex = "1.6.0"
reqwest = { version = "0.12.2", features = ["rustls-tls"] }
serde-hex = "0.1.0"
serde_path_to_error = "0.1"

Expand Down
1 change: 1 addition & 0 deletions src/client/csfle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ impl ClientState {
let mut builder = Crypt::builder()
.kms_providers(&opts.kms_providers.credentials_doc()?)?
.use_need_kms_credentials_state()
.retry_kms(true)?
.use_range_v2()?;
if let Some(m) = &opts.schema_map {
builder = builder.schema_map(&bson::to_document(m)?)?;
Expand Down
1 change: 1 addition & 0 deletions src/client/csfle/client_encryption.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ impl ClientEncryption {
let crypt = Crypt::builder()
.kms_providers(&kms_providers.credentials_doc()?)?
.use_need_kms_credentials_state()
.retry_kms(true)?
.use_range_v2()?
.build()?;
let exec = CryptExecutor::new_explicit(
Expand Down
90 changes: 58 additions & 32 deletions src/client/csfle/state_machine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,20 @@ use std::{
convert::TryInto,
ops::DerefMut,
path::{Path, PathBuf},
time::Duration,
};

use bson::{rawdoc, Document, RawDocument, RawDocumentBuf};
use futures_util::{stream, TryStreamExt};
use mongocrypt::ctx::{Ctx, KmsProviderType, State};
use mongocrypt::ctx::{Ctx, KmsCtx, KmsProviderType, State};
use rayon::ThreadPool;
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
sync::{oneshot, Mutex},
};

use crate::{
client::{options::ServerAddress, WeakClient},
client::{csfle::options::KmsProvidersTlsOptions, options::ServerAddress, WeakClient},
error::{Error, Result},
operation::{run_command::RunCommand, RawOutput},
options::ReadConcern,
Expand Down Expand Up @@ -174,37 +175,62 @@ impl CryptExecutor {
State::NeedKms => {
let ctx = result_mut(&mut ctx)?;
let scope = ctx.kms_scope();
let mut kms_ctxen: Vec<Result<_>> = vec![];
while let Some(kms_ctx) = scope.next_kms_ctx() {
kms_ctxen.push(Ok(kms_ctx));

async fn execute(
kms_ctx: &mut KmsCtx<'_>,
tls_options: Option<&KmsProvidersTlsOptions>,
) -> Result<()> {
let endpoint = kms_ctx.endpoint()?;
let addr = ServerAddress::parse(endpoint)?;
let provider = kms_ctx.kms_provider()?;
let tls_options = tls_options
.and_then(|tls| tls.get(&provider))
.cloned()
.unwrap_or_default();
let mut stream =
AsyncStream::connect(addr, Some(&TlsConfig::new(tls_options)?)).await?;
stream.write_all(kms_ctx.message()?).await?;
let mut buf = vec![0];
while kms_ctx.bytes_needed() > 0 {
let buf_size = kms_ctx.bytes_needed().try_into().map_err(|e| {
Error::internal(format!("buffer size overflow: {}", e))
})?;
buf.resize(buf_size, 0);
let count = stream.read(&mut buf).await?;
kms_ctx.feed(&buf[0..count])?;
}
Ok(())
}

loop {
let mut kms_contexts: Vec<Result<_>> = Vec::new();
while let Some(kms_ctx) = scope.next_kms_ctx() {
kms_contexts.push(Ok(kms_ctx));
}
if kms_contexts.is_empty() {
break;
}

stream::iter(kms_contexts)
.try_for_each_concurrent(None, |mut kms_ctx| async move {
let sleep_micros =
u64::try_from(kms_ctx.sleep_micros()).unwrap_or(0);
if sleep_micros > 0 {
tokio::time::sleep(Duration::from_micros(sleep_micros)).await;
}

if let Err(error) =
execute(&mut kms_ctx, self.kms_providers.tls_options()).await
{
if !kms_ctx.retry_failure() {
return Err(error);
}
}

Ok(())
})
.await?;
}
stream::iter(kms_ctxen)
.try_for_each_concurrent(None, |mut kms_ctx| async move {
let endpoint = kms_ctx.endpoint()?;
let addr = ServerAddress::parse(endpoint)?;
let provider = kms_ctx.kms_provider()?;
let tls_options = self
.kms_providers
.tls_options()
.and_then(|tls| tls.get(&provider))
.cloned()
.unwrap_or_default();
let mut stream =
AsyncStream::connect(addr, Some(&TlsConfig::new(tls_options)?))
.await?;
stream.write_all(kms_ctx.message()?).await?;
let mut buf = vec![0];
while kms_ctx.bytes_needed() > 0 {
let buf_size = kms_ctx.bytes_needed().try_into().map_err(|e| {
Error::internal(format!("buffer size overflow: {}", e))
})?;
buf.resize(buf_size, 0);
let count = stream.read(&mut buf).await?;
kms_ctx.feed(&buf[0..count])?;
}
Ok(())
})
.await?;
}
State::NeedKmsCredentials => {
let ctx = result_mut(&mut ctx)?;
Expand Down
161 changes: 161 additions & 0 deletions src/test/csfle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3493,6 +3493,167 @@ async fn range_explicit_encryption_defaults() -> Result<()> {
Ok(())
}

// Prose Test 24. KMS Retry Tests
//
// Arms failpoints on the mock KMS server at 127.0.0.1:9003 and checks that
// createDataKey/encrypt succeed after a single failure for each provider and
// that createDataKey fails once four network failures have been configured.
#[tokio::test]
// using openssl causes errors after configuring a network failpoint
#[cfg(not(feature = "openssl-tls"))]
async fn kms_retry() {
    use reqwest::{Certificate, Client as HttpClient};

    let endpoint = "127.0.0.1:9003";

    // CA certificate used to trust the failpoint server's TLS endpoint.
    let ca_path = PathBuf::from(std::env::var("CSFLE_TLS_CERT_DIR").unwrap()).join("ca.pem");
    let ca_pem = std::fs::read(&ca_path).unwrap();

    // Returns the (not yet awaited) request that arms `count` failures of the
    // given `kind` on the failpoint server.
    let set_failpoint = |kind: &str, count: u8| {
        // create a fresh client for each request to avoid hangs
        let root_certificate = Certificate::from_pem(&ca_pem).unwrap();
        let url = format!("https://localhost:9003/set_failpoint/{}", kind);
        let body = format!("{{\"count\":{}}}", count);
        HttpClient::builder()
            .add_root_certificate(root_certificate)
            .build()
            .unwrap()
            .post(url)
            .body(body)
            .send()
    };

    // Point the Azure and GCP providers at the failpoint server.
    let aws_provider = AWS_KMS.clone();
    let mut azure_provider = AZURE_KMS.clone();
    azure_provider
        .1
        .insert("identityPlatformEndpoint", endpoint);
    let mut gcp_provider = GCP_KMS.clone();
    gcp_provider.1.insert("endpoint", endpoint);
    let mut providers = vec![aws_provider, azure_provider, gcp_provider];

    // Every provider uses the TLS options of the test client.
    let tls_options = get_client_options().await.tls_options();
    for provider in &mut providers {
        provider.2 = tls_options.clone();
    }

    let key_vault_client = Client::for_test().await.into_client();
    let client_encryption = ClientEncryption::new(
        key_vault_client,
        Namespace::new("keyvault", "datakeys"),
        providers,
    )
    .unwrap();

    let aws_master_key = AwsMasterKey::builder()
        .region("foo")
        .key("bar")
        .endpoint(endpoint.to_string())
        .build();
    let azure_master_key = AzureMasterKey::builder()
        .key_vault_endpoint(endpoint)
        .key_name("foo")
        .build();
    let gcp_master_key = GcpMasterKey::builder()
        .project_id("foo")
        .location("bar")
        .key_ring("baz")
        .key_name("qux")
        .endpoint(endpoint.to_string())
        .build();

    // Case 1: createDataKey and encrypt with TCP retry ("network")
    // Case 2: createDataKey and encrypt with HTTP retry ("http")
    for failpoint in ["network", "http"] {
        // AWS
        set_failpoint(failpoint, 1).await.unwrap();
        let data_key = client_encryption
            .create_data_key(aws_master_key.clone())
            .await
            .unwrap();
        set_failpoint(failpoint, 1).await.unwrap();
        client_encryption
            .encrypt(123, data_key, Algorithm::Deterministic)
            .await
            .unwrap();

        // Azure
        set_failpoint(failpoint, 1).await.unwrap();
        let data_key = client_encryption
            .create_data_key(azure_master_key.clone())
            .await
            .unwrap();
        set_failpoint(failpoint, 1).await.unwrap();
        client_encryption
            .encrypt(123, data_key, Algorithm::Deterministic)
            .await
            .unwrap();

        // GCP
        set_failpoint(failpoint, 1).await.unwrap();
        let data_key = client_encryption
            .create_data_key(gcp_master_key.clone())
            .await
            .unwrap();
        set_failpoint(failpoint, 1).await.unwrap();
        client_encryption
            .encrypt(123, data_key, Algorithm::Deterministic)
            .await
            .unwrap();
    }

    // Case 3: createDataKey fails after too many retries

    // AWS
    set_failpoint("network", 4).await.unwrap();
    client_encryption
        .create_data_key(aws_master_key)
        .await
        .unwrap_err();

    // Azure
    set_failpoint("network", 4).await.unwrap();
    client_encryption
        .create_data_key(azure_master_key)
        .await
        .unwrap_err();

    // GCP
    set_failpoint("network", 4).await.unwrap();
    client_encryption
        .create_data_key(gcp_master_key)
        .await
        .unwrap_err();
}

// FLE 2.0 Documentation Example
#[tokio::test]
async fn fle2_example() -> Result<()> {
Expand Down

0 comments on commit 2dadbed

Please sign in to comment.