Skip to content

Commit

Permalink
Unit test for tracker with callback handler
Browse files Browse the repository at this point in the history
Signed-off-by: Sreekanth <[email protected]>
  • Loading branch information
BulkBeing committed Jan 21, 2025
1 parent bd99be6 commit 8ba4992
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 3 deletions.
2 changes: 1 addition & 1 deletion rust/numaflow-core/src/source.rs
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ impl Source {
let offset = message.offset.clone().expect("offset can never be none");

// insert the offset and the ack one shot in the tracker.
tracker_handle.insert(&message, resp_ack_tx).await?;
tracker_handle.insert(message, resp_ack_tx).await?;

// store the ack one shot in the batch to invoke ack later.
ack_batch.push((offset, resp_ack_rx));
Expand Down
153 changes: 152 additions & 1 deletion rust/numaflow-core/src/tracker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -420,15 +420,63 @@ impl TrackerHandle {

#[cfg(test)]
mod tests {
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

use axum::routing::{get, post};
use axum::{http::StatusCode, Router};
use tokio::sync::oneshot;
use tokio::time::{timeout, Duration};

use crate::message::MessageID;
use crate::message::{MessageID, Metadata};

use super::*;

type Result<T> = std::result::Result<T, Box<dyn std::error::Error>>;

#[test]
fn test_message_to_callback_info_conversion() {
let offset = Bytes::from_static(b"offset1");
let mut message = Message {
keys: Arc::from([]),
tags: None,
value: Bytes::from_static(b"test"),
offset: None,
event_time: Default::default(),
id: MessageID {
vertex_name: "in".into(),
offset: offset.clone(),
index: 1,
},
headers: HashMap::new(),
metadata: None,
};

let callback_info: super::Result<CallbackInfo> = TryFrom::try_from(&message);
assert!(callback_info.is_err());

const CALLBACK_URL: &str = "https://localhost/v1/process/callback";
let headers = [
(DEFAULT_CALLBACK_URL_HEADER_KEY, CALLBACK_URL),
(DEFAULT_ID_HEADER, "1234"),
]
.into_iter()
.map(|(k, v)| (k.to_string(), v.to_string()))
.collect();
message.headers = headers;

const FROM_VERTEX_NAME: &str = "source-vetext";
message.metadata = Some(Metadata {
previous_vertex: FROM_VERTEX_NAME.into(),
});

let callback_info: CallbackInfo = TryFrom::try_from(&message).unwrap();
assert_eq!(callback_info.id, "1234");
assert_eq!(callback_info.callback_url, CALLBACK_URL);
assert_eq!(callback_info.from_vertex, FROM_VERTEX_NAME);
assert_eq!(callback_info.responses, vec![None]);
}

#[tokio::test]
async fn test_insert_update_delete() {
let handle = TrackerHandle::new(None);
Expand Down Expand Up @@ -590,4 +638,107 @@ mod tests {
assert_eq!(result.unwrap(), ReadAck::Nak);
assert!(handle.is_empty().await.unwrap(), "Tracker should be empty");
}

#[tokio::test]
async fn test_tracker_with_callback_handler() -> Result<()> {
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let port = listener
.local_addr()
.map_err(|e| format!("Failed to bind to 127.0.0.1:0: error={e:?}"))?
.port();

let server_addr = format!("127.0.0.1:{port}");
let callback_url = format!("http://{server_addr}/v1/process/callback");

let request_count = Arc::new(AtomicUsize::new(0));
let router = Router::new()
.route("/livez", get(|| async { StatusCode::OK }))
.route(
"/v1/process/callback",
post({
let req_count = Arc::clone(&request_count);
|| async move {
req_count.fetch_add(1, Ordering::Relaxed);
StatusCode::OK
}
}),
);

let server = tokio::spawn(async move {
axum::serve(listener, router).await.unwrap();
});

let client = reqwest::Client::builder()
.timeout(Duration::from_secs(2))
.build()?;

// Wait for the server to be ready
let mut server_ready = false;
let health_url = format!("http://{server_addr}/livez");
for _ in 0..10 {
let Ok(resp) = client.get(&health_url).send().await else {
tokio::time::sleep(Duration::from_millis(5)).await;
continue;
};
if resp.status().is_success() {
server_ready = true;
break;
}
tokio::time::sleep(Duration::from_millis(5)).await;
}
assert!(server_ready, "Server is not ready");

let callback_handler = CallbackHandler::new("test".into(), 10);
let handle = TrackerHandle::new(Some(callback_handler));
let (ack_send, ack_recv) = oneshot::channel();

let headers = [
(DEFAULT_CALLBACK_URL_HEADER_KEY, callback_url),
(DEFAULT_ID_HEADER, "1234".into()),
]
.into_iter()
.map(|(k, v)| (k.to_string(), v.to_string()))
.collect();

let offset = Bytes::from_static(b"offset1");
let message = Message {
keys: Arc::from([]),
tags: None,
value: Bytes::from_static(b"test"),
offset: None,
event_time: Default::default(),
id: MessageID {
vertex_name: "in".into(),
offset: offset.clone(),
index: 1,
},
headers,
metadata: Some(Metadata {
previous_vertex: "source-vertex".into(),
}),
};

// Insert a new message
handle.insert(&message, ack_send).await.unwrap();
handle.update_eof(offset).await.unwrap();

// Verify that the message was discarded and Ack was received
let result = timeout(Duration::from_secs(1), ack_recv).await.unwrap();
assert!(result.is_ok(), "Ack should be received");
assert_eq!(result.unwrap(), ReadAck::Ack);
assert!(handle.is_empty().await.unwrap(), "Tracker should be empty");

// Callback request is made after sending data on ack_send channel.
let mut received_callback_request = false;
for _ in 0..5 {
tokio::time::sleep(Duration::from_millis(10)).await;
received_callback_request = request_count.load(Ordering::Relaxed) == 1;
if received_callback_request {
break;
}
}
assert!(received_callback_request, "Expected one callback request");
server.abort();
Ok(())
}
}
2 changes: 1 addition & 1 deletion rust/serving/src/callback.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,10 +167,10 @@ mod tests {
use crate::pipeline::PipelineDCG;
use crate::test_utils::get_port;
use crate::{AppState, Settings};
use axum::http::StatusCode;
use axum::routing::{get, post};
use axum::{Json, Router};
use axum_server::tls_rustls::RustlsConfig;
use reqwest::StatusCode;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Duration;
Expand Down

0 comments on commit 8ba4992

Please sign in to comment.