Skip to content

Commit

Permalink
Make CapnprotoPayload{En,De}code composable.
Browse files Browse the repository at this point in the history
  • Loading branch information
mendess committed Jan 2, 2025
1 parent 3f9f5a1 commit 75ba2da
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 29 deletions.
2 changes: 1 addition & 1 deletion crates/daphne-server/src/storage_proxy_connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ impl<'d, B: DurableMethod + Debug, P: AsRef<[u8]>> RequestBuilder<'d, B, P> {

impl<'d, B: DurableMethod> RequestBuilder<'d, B, [u8; 0]> {
pub fn encode<T: CapnprotoPayloadEncode>(self, payload: &T) -> RequestBuilder<'d, B, Vec<u8>> {
self.with_body(payload.encode_to_bytes().unwrap())
self.with_body(payload.encode_to_bytes())
}

pub fn with_body<T: AsRef<[u8]>>(self, payload: T) -> RequestBuilder<'d, B, T> {
Expand Down
24 changes: 15 additions & 9 deletions crates/daphne-service-utils/src/capnproto_payload.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
// Copyright (c) 2024 Cloudflare, Inc. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause

use capnp::traits::{FromPointerBuilder, FromPointerReader};

pub trait CapnprotoPayloadEncode {
fn encode_to_builder(&self) -> capnp::message::Builder<capnp::message::HeapAllocator>;
type Builder<'a>: FromPointerBuilder<'a>;

fn encode_to_builder(&self, builder: Self::Builder<'_>);
}

pub trait CapnprotoPayloadEncodeExt {
fn encode_to_bytes(&self) -> capnp::Result<Vec<u8>>;
fn encode_to_bytes(&self) -> Vec<u8>;
}

pub trait CapnprotoPayloadDecode {
fn decode_from_reader(
reader: capnp::message::Reader<capnp::serialize::OwnedSegments>,
) -> capnp::Result<Self>
type Reader<'a>: FromPointerReader<'a>;

fn decode_from_reader(reader: Self::Reader<'_>) -> capnp::Result<Self>
where
Self: Sized;
}
Expand All @@ -27,11 +31,12 @@ impl<T> CapnprotoPayloadEncodeExt for T
where
T: CapnprotoPayloadEncode,
{
fn encode_to_bytes(&self) -> capnp::Result<Vec<u8>> {
fn encode_to_bytes(&self) -> Vec<u8> {
let mut message = capnp::message::Builder::new_default();
self.encode_to_builder(message.init_root::<T::Builder<'_>>());
let mut buf = Vec::new();
let message = self.encode_to_builder();
capnp::serialize_packed::write_message(&mut buf, &message)?;
Ok(buf)
capnp::serialize_packed::write_message(&mut buf, &message).expect("infalible");
buf
}
}

Expand All @@ -49,6 +54,7 @@ where
capnp::message::ReaderOptions::new(),
)?;

let reader = reader.get_root::<T::Reader<'_>>()?;
T::decode_from_reader(reader)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -69,16 +69,16 @@ pub struct AggregateStoreMergeOptions {
}

impl CapnprotoPayloadEncode for AggregateStoreMergeReq {
fn encode_to_builder(&self) -> capnp::message::Builder<capnp::message::HeapAllocator> {
type Builder<'a> = aggregate_store_merge_req::Builder<'a>;

fn encode_to_builder(&self, mut builder: Self::Builder<'_>) {
let Self {
contained_reports,
agg_share_delta,
options,
} = self;
let mut message = capnp::message::Builder::new_default();
let mut request = message.init_root::<aggregate_store_merge_req::Builder>();
{
let mut contained_reports = request.reborrow().init_contained_reports(
let mut contained_reports = builder.reborrow().init_contained_reports(
contained_reports
.len()
.try_into()
Expand All @@ -94,7 +94,7 @@ impl CapnprotoPayloadEncode for AggregateStoreMergeReq {
}
}
{
let mut agg_share_delta_packet = request.reborrow().init_agg_share_delta();
let mut agg_share_delta_packet = builder.reborrow().init_agg_share_delta();
agg_share_delta_packet.set_report_count(agg_share_delta.report_count);
agg_share_delta_packet.set_min_time(agg_share_delta.min_time);
agg_share_delta_packet.set_max_time(agg_share_delta.max_time);
Expand Down Expand Up @@ -157,20 +157,18 @@ impl CapnprotoPayloadEncode for AggregateStoreMergeReq {
let AggregateStoreMergeOptions {
skip_replay_protection,
} = options;
let mut options_packet = request.init_options();
let mut options_packet = builder.init_options();
options_packet.set_skip_replay_protection(*skip_replay_protection);
}
message
}
}

impl CapnprotoPayloadDecode for AggregateStoreMergeReq {
fn decode_from_reader(
reader: capnp::message::Reader<capnp::serialize::OwnedSegments>,
) -> capnp::Result<Self> {
let request = reader.get_root::<aggregate_store_merge_req::Reader>()?;
type Reader<'a> = aggregate_store_merge_req::Reader<'a>;

fn decode_from_reader(reader: Self::Reader<'_>) -> capnp::Result<Self> {
let agg_share_delta = {
let agg_share_delta = request.get_agg_share_delta()?;
let agg_share_delta = reader.get_agg_share_delta()?;
let data = {
macro_rules! make_decode {
($func_name:ident, $agg_share_type:ident, $field_trait:ident, $field_error:ident) => {
Expand Down Expand Up @@ -238,8 +236,7 @@ impl CapnprotoPayloadDecode for AggregateStoreMergeReq {
}
};
let contained_reports = {
request
.reborrow()
reader
.get_contained_reports()?
.into_iter()
.map(|report| {
Expand All @@ -257,7 +254,7 @@ impl CapnprotoPayloadDecode for AggregateStoreMergeReq {
contained_reports,
agg_share_delta,
options: AggregateStoreMergeOptions {
skip_replay_protection: request.get_options()?.get_skip_replay_protection(),
skip_replay_protection: reader.get_options()?.get_skip_replay_protection(),
},
})
}
Expand Down Expand Up @@ -352,8 +349,7 @@ mod test {
},
};
let other =
AggregateStoreMergeReq::decode_from_bytes(&this.encode_to_bytes().unwrap())
.unwrap();
AggregateStoreMergeReq::decode_from_bytes(&this.encode_to_bytes()).unwrap();
assert_eq!(this, other);
}
}
Expand Down Expand Up @@ -411,8 +407,7 @@ mod test {
},
};
let other =
AggregateStoreMergeReq::decode_from_bytes(&this.encode_to_bytes().unwrap())
.unwrap();
AggregateStoreMergeReq::decode_from_bytes(&this.encode_to_bytes()).unwrap();
assert_eq!(this, other);
}
}
Expand Down

0 comments on commit 75ba2da

Please sign in to comment.