diff --git a/crates/daphne-server/src/storage_proxy_connection/mod.rs b/crates/daphne-server/src/storage_proxy_connection/mod.rs index 555d0084..702bfc5f 100644 --- a/crates/daphne-server/src/storage_proxy_connection/mod.rs +++ b/crates/daphne-server/src/storage_proxy_connection/mod.rs @@ -7,7 +7,7 @@ use std::fmt::Debug; use axum::http::StatusCode; use daphne_service_utils::{ - capnproto_payload::{CapnprotoPayloadEncode, CapnprotoPayloadEncodeExt as _}, + capnproto::{CapnprotoPayloadEncode, CapnprotoPayloadEncodeExt as _}, durable_requests::{bindings::DurableMethod, DurableRequest, ObjectIdFrom, DO_PATH_PREFIX}, }; use serde::de::DeserializeOwned; diff --git a/crates/daphne-service-utils/build.rs b/crates/daphne-service-utils/build.rs index 02c29fec..6b9d9f13 100644 --- a/crates/daphne-service-utils/build.rs +++ b/crates/daphne-service-utils/build.rs @@ -4,6 +4,7 @@ fn main() { #[cfg(feature = "durable_requests")] ::capnpc::CompilerCommand::new() + .file("./src/capnproto/base.capnp") .file("./src/durable_requests/durable_request.capnp") .run() .expect("compiling schema"); diff --git a/crates/daphne-service-utils/src/capnproto/base.capnp b/crates/daphne-service-utils/src/capnproto/base.capnp new file mode 100644 index 00000000..ab8668f4 --- /dev/null +++ b/crates/daphne-service-utils/src/capnproto/base.capnp @@ -0,0 +1,36 @@ +# Copyright (c) 2024 Cloudflare, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +@0xba869f168ff63e77; + +enum DapVersion @0xb5b2c8705a8b22d5 { + draft09 @0; + draftLatest @1; +} + +# [u8; 32] +struct U8L32 @0x9e42cda292792294 { + fst @0 :UInt64; + snd @1 :UInt64; + thr @2 :UInt64; + frh @3 :UInt64; +} + +# [u8; 16] +struct U8L16 @0x9e3f65b13f71cfcb { + fst @0 :UInt64; + snd @1 :UInt64; +} + +struct PartialBatchSelector { + union { + timeInterval @0 :Void; + leaderSelectedByBatchId @1 :BatchId; + } +} + +using ReportId = U8L16; +using BatchId = U8L32; +using TaskId = U8L32; +using AggregationJobId = U8L16; +using Time = UInt64; diff --git a/crates/daphne-service-utils/src/capnproto/mod.rs b/crates/daphne-service-utils/src/capnproto/mod.rs new file mode 100644 index 00000000..ed7e7c98 --- /dev/null +++ b/crates/daphne-service-utils/src/capnproto/mod.rs @@ -0,0 +1,231 @@ +// Copyright (c) 2024 Cloudflare, Inc. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause + +use crate::base_capnp::{self, partial_batch_selector, u8_l16, u8_l32}; +use capnp::struct_list; +use capnp::traits::{FromPointerBuilder, FromPointerReader}; +use daphne::{ + messages::{AggregationJobId, BatchId, PartialBatchSelector, ReportId, TaskId}, + DapVersion, +}; + +pub trait CapnprotoPayloadEncode { + type Builder<'a>: FromPointerBuilder<'a>; + + fn encode_to_builder(&self, builder: Self::Builder<'_>); +} + +pub trait CapnprotoPayloadEncodeExt { + fn encode_to_bytes(&self) -> Vec; +} + +pub trait CapnprotoPayloadDecode { + type Reader<'a>: FromPointerReader<'a>; + + fn decode_from_reader(reader: Self::Reader<'_>) -> capnp::Result + where + Self: Sized; +} + +pub trait CapnprotoPayloadDecodeExt { + fn decode_from_bytes(bytes: &[u8]) -> capnp::Result + where + Self: Sized; +} + +impl CapnprotoPayloadEncodeExt for T +where + T: CapnprotoPayloadEncode, +{ + fn encode_to_bytes(&self) -> Vec { + let mut message = capnp::message::Builder::new_default(); + self.encode_to_builder(message.init_root::>()); + let mut buf = Vec::new(); + capnp::serialize_packed::write_message(&mut buf, &message).expect("infalible"); + buf + } +} + +impl CapnprotoPayloadDecodeExt for T +where + T: CapnprotoPayloadDecode, +{ + fn decode_from_bytes(bytes: &[u8]) -> capnp::Result + where + Self: Sized, + { + let mut cursor = std::io::Cursor::new(bytes); + let reader = capnp::serialize_packed::read_message( + &mut cursor, + capnp::message::ReaderOptions::new(), + )?; + + let reader = reader.get_root::>()?; + T::decode_from_reader(reader) + } +} + +impl CapnprotoPayloadEncode for &'_ T +where + T: CapnprotoPayloadEncode, +{ + type Builder<'a> = T::Builder<'a>; + + fn encode_to_builder(&self, builder: Self::Builder<'_>) { + T::encode_to_builder(self, builder); + } +} + +impl From for DapVersion { + fn from(val: base_capnp::DapVersion) -> Self { + match val { + base_capnp::DapVersion::Draft09 => DapVersion::Draft09, + base_capnp::DapVersion::DraftLatest => DapVersion::Latest, + } + } +} + +impl From for base_capnp::DapVersion { + fn from(value: DapVersion) -> Self { + match value { + DapVersion::Draft09 => base_capnp::DapVersion::Draft09, + DapVersion::Latest => base_capnp::DapVersion::DraftLatest, + } + } +} + +impl CapnprotoPayloadEncode for [u8; 32] { + type Builder<'a> = u8_l32::Builder<'a>; + + fn encode_to_builder(&self, mut builder: Self::Builder<'_>) { + builder.set_fst(u64::from_le_bytes(self[0..8].try_into().unwrap())); + builder.set_snd(u64::from_le_bytes(self[8..16].try_into().unwrap())); + builder.set_thr(u64::from_le_bytes(self[16..24].try_into().unwrap())); + builder.set_frh(u64::from_le_bytes(self[24..32].try_into().unwrap())); + } +} + +impl CapnprotoPayloadDecode for [u8; 32] { + type Reader<'a> = u8_l32::Reader<'a>; + + fn decode_from_reader(reader: Self::Reader<'_>) -> capnp::Result + where + Self: Sized, + { + let mut array = [0; 32]; + array[0..8].copy_from_slice(&reader.get_fst().to_le_bytes()); + array[8..16].copy_from_slice(&reader.get_snd().to_le_bytes()); + array[16..24].copy_from_slice(&reader.get_thr().to_le_bytes()); + array[24..32].copy_from_slice(&reader.get_frh().to_le_bytes()); + Ok(array) + } +} + +impl CapnprotoPayloadEncode for [u8; 16] { + type Builder<'a> = u8_l16::Builder<'a>; + + fn encode_to_builder(&self, mut builder: Self::Builder<'_>) { + builder.set_fst(u64::from_le_bytes(self[0..8].try_into().unwrap())); + builder.set_snd(u64::from_le_bytes(self[8..16].try_into().unwrap())); + } +} + +impl CapnprotoPayloadDecode for [u8; 16] { + type Reader<'a> = u8_l16::Reader<'a>; + + fn decode_from_reader(reader: Self::Reader<'_>) -> capnp::Result + where + Self: Sized, + { + let mut array = [0; 16]; + array[0..8].copy_from_slice(&reader.get_fst().to_le_bytes()); + array[8..16].copy_from_slice(&reader.get_snd().to_le_bytes()); + Ok(array) + } +} + +macro_rules! capnp_encode_ids { + ($($id:ident => $inner:ident),*$(,)?) => { + $( + impl CapnprotoPayloadEncode for $id { + type Builder<'a> = $inner::Builder<'a>; + + fn encode_to_builder(&self, builder: Self::Builder<'_>) { + self.0.encode_to_builder(builder) + } + } + + impl CapnprotoPayloadDecode for $id { + type Reader<'a> = $inner::Reader<'a>; + + fn decode_from_reader(reader: Self::Reader<'_>) -> capnp::Result + where + Self: Sized, + { + <_>::decode_from_reader(reader).map(Self) + } + } + )* + }; +} + +capnp_encode_ids! { + TaskId => u8_l32, + ReportId => u8_l16, + BatchId => u8_l32, + AggregationJobId => u8_l16, +} + +impl CapnprotoPayloadEncode for PartialBatchSelector { + type Builder<'a> = partial_batch_selector::Builder<'a>; + + fn encode_to_builder(&self, mut builder: Self::Builder<'_>) { + match self { + PartialBatchSelector::TimeInterval => builder.set_time_interval(()), + PartialBatchSelector::LeaderSelectedByBatchId { batch_id } => { + batch_id.encode_to_builder(builder.init_leader_selected_by_batch_id()); + } + } + } +} + +impl CapnprotoPayloadDecode for PartialBatchSelector { + type Reader<'a> = partial_batch_selector::Reader<'a>; + + fn decode_from_reader(reader: Self::Reader<'_>) -> capnp::Result { + match reader.which()? { + partial_batch_selector::Which::TimeInterval(()) => Ok(Self::TimeInterval), + partial_batch_selector::Which::LeaderSelectedByBatchId(reader) => { + Ok(Self::LeaderSelectedByBatchId { + batch_id: <_>::decode_from_reader(reader?)?, + }) + } + } + } +} + +pub fn encode_list(list: I, mut builder: struct_list::Builder<'_, O>) +where + I: IntoIterator, + O: for<'b> capnp::traits::OwnedStruct< + Builder<'b> = ::Builder<'b>, + >, +{ + for (i, item) in list.into_iter().enumerate() { + item.encode_to_builder(builder.reborrow().get(i.try_into().unwrap())); + } +} + +pub fn decode_list(reader: struct_list::Reader<'_, O>) -> capnp::Result +where + T: CapnprotoPayloadDecode, + C: FromIterator, + O: for<'b> capnp::traits::OwnedStruct = T::Reader<'b>>, +{ + reader.into_iter().map(T::decode_from_reader).collect() +} + +pub fn usize_to_capnp_len(u: usize) -> u32 { + u.try_into() + .expect("capnp can't encode more that u32::MAX of something") +} diff --git a/crates/daphne-service-utils/src/capnproto_payload.rs b/crates/daphne-service-utils/src/capnproto_payload.rs deleted file mode 100644 index e3c2b6aa..00000000 --- a/crates/daphne-service-utils/src/capnproto_payload.rs +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright (c) 2024 Cloudflare, Inc. All rights reserved. -// SPDX-License-Identifier: BSD-3-Clause - -use capnp::traits::{FromPointerBuilder, FromPointerReader}; - -pub trait CapnprotoPayloadEncode { - type Builder<'a>: FromPointerBuilder<'a>; - - fn encode_to_builder(&self, builder: Self::Builder<'_>); -} - -pub trait CapnprotoPayloadEncodeExt { - fn encode_to_bytes(&self) -> Vec; -} - -pub trait CapnprotoPayloadDecode { - type Reader<'a>: FromPointerReader<'a>; - - fn decode_from_reader(reader: Self::Reader<'_>) -> capnp::Result - where - Self: Sized; -} - -pub trait CapnprotoPayloadDecodeExt { - fn decode_from_bytes(bytes: &[u8]) -> capnp::Result - where - Self: Sized; -} - -impl CapnprotoPayloadEncodeExt for T -where - T: CapnprotoPayloadEncode, -{ - fn encode_to_bytes(&self) -> Vec { - let mut message = capnp::message::Builder::new_default(); - self.encode_to_builder(message.init_root::>()); - let mut buf = Vec::new(); - capnp::serialize_packed::write_message(&mut buf, &message).expect("infalible"); - buf - } -} - -impl CapnprotoPayloadDecodeExt for T -where - T: CapnprotoPayloadDecode, -{ - fn decode_from_bytes(bytes: &[u8]) -> capnp::Result - where - Self: Sized, - { - let mut cursor = std::io::Cursor::new(bytes); - let reader = capnp::serialize_packed::read_message( - &mut cursor, - capnp::message::ReaderOptions::new(), - )?; - - let reader = reader.get_root::>()?; - T::decode_from_reader(reader) - } -} diff --git a/crates/daphne-service-utils/src/durable_requests/bindings/aggregate_store.rs b/crates/daphne-service-utils/src/durable_requests/bindings/aggregate_store.rs index eb8c4cca..0f5b1181 100644 --- a/crates/daphne-service-utils/src/durable_requests/bindings/aggregate_store.rs +++ b/crates/daphne-service-utils/src/durable_requests/bindings/aggregate_store.rs @@ -11,7 +11,7 @@ use daphne::{ use serde::{Deserialize, Serialize}; use crate::{ - capnproto_payload::{CapnprotoPayloadDecode, CapnprotoPayloadEncode}, + capnproto::{CapnprotoPayloadDecode, CapnprotoPayloadEncode}, durable_request_capnp::{aggregate_store_merge_req, dap_aggregate_share}, durable_requests::ObjectIdFrom, }; @@ -284,9 +284,7 @@ mod test { }; use rand::{thread_rng, Rng}; - use crate::capnproto_payload::{ - CapnprotoPayloadDecodeExt as _, CapnprotoPayloadEncodeExt as _, - }; + use crate::capnproto::{CapnprotoPayloadDecodeExt as _, CapnprotoPayloadEncodeExt as _}; use super::*; diff --git a/crates/daphne-service-utils/src/lib.rs b/crates/daphne-service-utils/src/lib.rs index 92e03846..fbd06364 100644 --- a/crates/daphne-service-utils/src/lib.rs +++ b/crates/daphne-service-utils/src/lib.rs @@ -5,7 +5,7 @@ pub mod bearer_token; #[cfg(feature = "durable_requests")] -pub mod capnproto_payload; +pub mod capnproto; #[cfg(feature = "durable_requests")] pub mod durable_requests; pub mod http_headers; @@ -13,6 +13,15 @@ pub mod http_headers; pub mod test_route_types; // the generated code expects this module to be defined at the root of the library. +#[cfg(feature = "durable_requests")] +#[doc(hidden)] +pub mod base_capnp { + #![allow(dead_code)] + #![allow(clippy::pedantic)] + #![allow(clippy::needless_lifetimes)] + include!(concat!(env!("OUT_DIR"), "/src/capnproto/base_capnp.rs")); +} + #[cfg(feature = "durable_requests")] mod durable_request_capnp { #![allow(dead_code)] diff --git a/crates/daphne-worker/src/durable/mod.rs b/crates/daphne-worker/src/durable/mod.rs index b9a61c0b..1c2fdd2c 100644 --- a/crates/daphne-worker/src/durable/mod.rs +++ b/crates/daphne-worker/src/durable/mod.rs @@ -25,7 +25,7 @@ pub(crate) mod test_state_cleaner; use crate::tracing_utils::shorten_paths; use daphne_service_utils::{ - capnproto_payload::{CapnprotoPayloadDecode, CapnprotoPayloadDecodeExt}, + capnproto::{CapnprotoPayloadDecode, CapnprotoPayloadDecodeExt}, durable_requests::bindings::DurableMethod, }; use serde::{Deserialize, Serialize}; diff --git a/crates/daphne-worker/src/storage/mod.rs b/crates/daphne-worker/src/storage/mod.rs index 9065ddb4..99bf5289 100644 --- a/crates/daphne-worker/src/storage/mod.rs +++ b/crates/daphne-worker/src/storage/mod.rs @@ -6,7 +6,7 @@ pub(crate) mod kv; use crate::storage_proxy; use axum::http::StatusCode; use daphne_service_utils::{ - capnproto_payload::{CapnprotoPayloadEncode, CapnprotoPayloadEncodeExt}, + capnproto::{CapnprotoPayloadEncode, CapnprotoPayloadEncodeExt}, durable_requests::{bindings::DurableMethod, DurableRequest, ObjectIdFrom}, }; pub(crate) use kv::Kv;