Skip to content

Commit

Permalink
Add filter support to shrink domain size
Browse files Browse the repository at this point in the history
  • Loading branch information
myl7 committed Apr 11, 2024
1 parent a3a8515 commit 6e5d85c
Show file tree
Hide file tree
Showing 4 changed files with 165 additions and 32 deletions.
19 changes: 8 additions & 11 deletions benches/dcf_full_eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@ use fss_rs::group::Group;

fn from_domain_range_size<const DOM_SZ: usize, const LAMBDA: usize, const CIPHER_N: usize>(
c: &mut Criterion,
filter_bitn: usize,
) {
let mut keys = [[0; 16]; CIPHER_N];
keys.iter_mut().for_each(|k| thread_rng().fill_bytes(k));
let keys_iter = std::array::from_fn(|i| &keys[i]);

let prg = Aes128MatyasMeyerOseasPrg::<LAMBDA, CIPHER_N>::new(keys_iter);
let dcf = DcfImpl::<DOM_SZ, LAMBDA, _>::new(prg);
let dcf = DcfImpl::<DOM_SZ, LAMBDA, _>::new_with_filter(prg, filter_bitn);

let mut s0s = [[0; LAMBDA]; 2];
s0s.iter_mut().for_each(|s0| thread_rng().fill_bytes(s0));
Expand All @@ -35,16 +36,12 @@ fn from_domain_range_size<const DOM_SZ: usize, const LAMBDA: usize, const CIPHER

let k = dcf.gen(&f, [&s0s[0], &s0s[1]]);

// TODO: Bit mask and 1 bit drop
let mut ys = vec![ByteGroup::zero(); 2usize.pow(DOM_SZ as u32 * 8 - 1)];
let mut ys = vec![ByteGroup::zero(); 1 << filter_bitn];
let mut ys_iter: Vec<_> = ys.iter_mut().collect();

c.bench_with_input(
BenchmarkId::new(
"dcf full_eval",
format!("{}b -> {}B", DOM_SZ * 8 - 1, LAMBDA),
),
&(DOM_SZ, LAMBDA),
BenchmarkId::new("dcf full_eval", format!("{}b -> {}B", filter_bitn, LAMBDA)),
&(DOM_SZ, LAMBDA, filter_bitn),
|b, &_| {
b.iter(|| {
dcf.full_eval(false, &k, &mut ys_iter);
Expand All @@ -55,9 +52,9 @@ fn from_domain_range_size<const DOM_SZ: usize, const LAMBDA: usize, const CIPHER

// TODO: Bit mask
fn bench(c: &mut Criterion) {
from_domain_range_size::<2, 16, 4>(c);
// from_domain_range_size::<2, 16, 4>(c); // 18
// from_domain_range_size::<2, 16, 4>(c); // 20
from_domain_range_size::<2, 16, 4>(c, 16);
from_domain_range_size::<3, 16, 4>(c, 18);
from_domain_range_size::<3, 16, 4>(c, 20);
}

criterion_group!(benches, bench);
Expand Down
16 changes: 7 additions & 9 deletions benches/dpf_full_eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@ use fss_rs::group::Group;

fn from_domain_range_size<const DOM_SZ: usize, const LAMBDA: usize, const CIPHER_N: usize>(
c: &mut Criterion,
filter_bitn: usize,
) {
let mut keys = [[0; 16]; CIPHER_N];
keys.iter_mut().for_each(|k| thread_rng().fill_bytes(k));
let keys_iter = std::array::from_fn(|i| &keys[i]);

let prg = Aes128MatyasMeyerOseasPrg::<LAMBDA, CIPHER_N>::new(keys_iter);
let dpf = DpfImpl::<DOM_SZ, LAMBDA, _>::new(prg);
let dpf = DpfImpl::<DOM_SZ, LAMBDA, _>::new_with_filter(prg, filter_bitn);

let mut s0s = [[0; LAMBDA]; 2];
s0s.iter_mut().for_each(|s0| thread_rng().fill_bytes(s0));
Expand All @@ -32,14 +33,11 @@ fn from_domain_range_size<const DOM_SZ: usize, const LAMBDA: usize, const CIPHER
let k = dpf.gen(&f, [&s0s[0], &s0s[1]]);

// TODO: Bit mask and 1 bit drop
let mut ys = vec![ByteGroup::zero(); 2usize.pow(DOM_SZ as u32 * 8 - 1)];
let mut ys = vec![ByteGroup::zero(); 1 << filter_bitn];
let mut ys_iter: Vec<_> = ys.iter_mut().collect();

c.bench_with_input(
BenchmarkId::new(
"dpf full_eval",
format!("{}b -> {}B", DOM_SZ * 8 - 1, LAMBDA),
),
BenchmarkId::new("dpf full_eval", format!("{}b -> {}B", filter_bitn, LAMBDA)),
&(DOM_SZ, LAMBDA),
|b, &_| {
b.iter(|| {
Expand All @@ -51,9 +49,9 @@ fn from_domain_range_size<const DOM_SZ: usize, const LAMBDA: usize, const CIPHER

// TODO: Bit mask
fn bench(c: &mut Criterion) {
from_domain_range_size::<2, 16, 4>(c);
// from_domain_range_size::<2, 16, 4>(c); // 18
// from_domain_range_size::<2, 16, 4>(c); // 20
from_domain_range_size::<2, 16, 4>(c, 16);
from_domain_range_size::<3, 16, 4>(c, 18);
from_domain_range_size::<3, 16, 4>(c, 20);
}

criterion_group!(benches, bench);
Expand Down
82 changes: 76 additions & 6 deletions src/dcf/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,14 +72,24 @@ where
P: Prg<LAMBDA>,
{
prg: P,
filter_bitn: usize,
}

impl<const DOM_SZ: usize, const LAMBDA: usize, P> DcfImpl<DOM_SZ, LAMBDA, P>
where
P: Prg<LAMBDA>,
{
pub fn new(prg: P) -> Self {
Self { prg }
Self {
prg,
filter_bitn: 8 * DOM_SZ,
}
}

// TODO
pub fn new_with_filter(prg: P, filter_bitn: usize) -> Self {
assert!(filter_bitn <= 8 * DOM_SZ && filter_bitn > 1);
Self { prg, filter_bitn }
}
}

Expand All @@ -94,7 +104,7 @@ where
{
fn gen(&self, f: &CmpFn<DOM_SZ, LAMBDA, G>, s0s: [&[u8; LAMBDA]; 2]) -> Share<LAMBDA, G> {
// The bit size of `$\alpha$`
let n = 8 * DOM_SZ;
let n = self.filter_bitn;
let mut v_alpha = G::zero();
// Set `$s^{(1)}_0$` and `$s^{(1)}_1$`
let mut ss_prev = [*s0s[0], *s0s[1]];
Expand Down Expand Up @@ -175,7 +185,7 @@ where

fn full_eval(&self, b: bool, k: &Share<LAMBDA, G>, ys: &mut [&mut G]) {
let n = k.cws.len();
assert_eq!(n, DOM_SZ * 8);
assert_eq!(n, self.filter_bitn);

let s = k.s0s[0];
let v = G::zero();
Expand Down Expand Up @@ -221,7 +231,7 @@ where
) where
G: Group<LAMBDA>,
{
assert_eq!(ys.len(), 1 << (DOM_SZ * 8 - layer_i));
assert_eq!(ys.len(), 1 << (self.filter_bitn - layer_i));
if ys.len() == 1 {
*ys[0] =
v + (G::from(s) + if t { k.cw_np1.clone() } else { G::zero() }).add_inverse_if(b);
Expand Down Expand Up @@ -257,7 +267,7 @@ where
G: Group<LAMBDA>,
{
let n = k.cws.len();
assert_eq!(n, DOM_SZ * 8);
assert_eq!(n, self.filter_bitn);
let v = y;

let mut s_prev = k.s0s[0];
Expand Down Expand Up @@ -382,6 +392,38 @@ mod tests {
assert_eq!(ys0, ys1);
}

#[test]
fn test_dcf_gen_then_eval_with_filter() {
let prg = Aes256HirosePrg::<16, 2>::new(std::array::from_fn(|i| KEYS[i]));
let dcf = DcfImpl::<16, 16, _>::new_with_filter(prg, 127);
let s0s: [[u8; 16]; 2] = thread_rng().gen();
let f = CmpFn {
alpha: ALPHAS[2].to_owned(),
beta: BETA.clone().into(),
bound: BoundState::GtBeta,
};
let k = dcf.gen(&f, [&s0s[0], &s0s[1]]);
let mut k0 = k.clone();
k0.s0s = vec![k0.s0s[0]];
let mut k1 = k.clone();
k1.s0s = vec![k1.s0s[1]];
let mut ys0 = vec![ByteGroup::zero(); ALPHAS.len()];
let mut ys1 = vec![ByteGroup::zero(); ALPHAS.len()];
dcf.eval(false, &k0, ALPHAS, &mut ys0.iter_mut().collect::<Vec<_>>());
dcf.eval(true, &k1, ALPHAS, &mut ys1.iter_mut().collect::<Vec<_>>());
ys0.iter_mut()
.zip(ys1.iter())
.for_each(|(y0, y1)| *y0 += y1.clone());
ys1 = vec![
ByteGroup::zero(),
ByteGroup::zero(),
ByteGroup::zero(),
ByteGroup::zero(),
BETA.clone().into(),
];
assert_eq!(ys0, ys1);
}

#[test]
fn test_dcf_gen_then_eval_not_zeros() {
let prg = Aes256HirosePrg::<16, 2>::new(std::array::from_fn(|i| KEYS[i]));
Expand All @@ -406,7 +448,7 @@ mod tests {
}

#[test]
fn test_dcf_full_domain_eval() {
fn test_dcf_full_eval() {
let x: [u8; 2] = ALPHAS[2][..2].try_into().unwrap();
let prg = Aes256HirosePrg::<16, 2>::new(std::array::from_fn(|i| KEYS[i]));
let dcf = DcfImpl::<2, 16, _>::new(prg);
Expand All @@ -430,4 +472,32 @@ mod tests {
assert_eq!(y0, y0_full);
}
}

#[test]
fn test_dcf_full_eval_with_filter() {
let x: [u8; 2] = ALPHAS[2][..2].try_into().unwrap();
let prg = Aes256HirosePrg::<16, 2>::new(std::array::from_fn(|i| KEYS[i]));
let dcf = DcfImpl::<2, 16, _>::new_with_filter(prg, 15);
let s0s: [[u8; 16]; 2] = thread_rng().gen();
let f = CmpFn {
alpha: x,
beta: BETA.clone().into(),
bound: BoundState::LtBeta,
};
let k = dcf.gen(&f, [&s0s[0], &s0s[1]]);
let mut k0 = k.clone();
k0.s0s = vec![k0.s0s[0]];
let xs: Vec<_> = (0u16..=u16::MAX >> 1)
.map(|i| (i << 1).to_be_bytes())
.collect();
assert_eq!(xs.len(), 1 << 15);
let xs0: Vec<_> = xs.iter().collect();
let mut ys0 = vec![ByteGroup::zero(); 1 << 15];
let mut ys0_full = vec![ByteGroup::zero(); 1 << 15];
dcf.eval(false, &k0, &xs0, &mut ys0.iter_mut().collect::<Vec<_>>());
dcf.full_eval(false, &k0, &mut ys0_full.iter_mut().collect::<Vec<_>>());
for (y0, y0_full) in ys0.iter().zip(ys0_full.iter()) {
assert_eq!(y0, y0_full);
}
}
}
80 changes: 74 additions & 6 deletions src/dpf/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,24 @@ where
P: Prg<LAMBDA>,
{
prg: P,
filter_bitn: usize,
}

impl<const DOM_SZ: usize, const LAMBDA: usize, P> DpfImpl<DOM_SZ, LAMBDA, P>
where
P: Prg<LAMBDA>,
{
pub fn new(prg: P) -> Self {
Self { prg }
Self {
prg,
filter_bitn: 8 * DOM_SZ,
}
}

// TODO
pub fn new_with_filter(prg: P, filter_bitn: usize) -> Self {
assert!(filter_bitn <= 8 * DOM_SZ && filter_bitn > 1);
Self { prg, filter_bitn }
}
}

Expand All @@ -68,7 +78,7 @@ where
{
fn gen(&self, f: &PointFn<DOM_SZ, LAMBDA, G>, s0s: [&[u8; LAMBDA]; 2]) -> Share<LAMBDA, G> {
// The bit size of `$\alpha$`
let n = 8 * DOM_SZ;
let n = self.filter_bitn;
// Set `$s^{(1)}_0$` and `$s^{(1)}_1$`
let mut ss_prev = [s0s[0].to_owned(), s0s[1].to_owned()];
// Set `$t^{(0)}_0$` and `$t^{(0)}_1$`
Expand Down Expand Up @@ -128,7 +138,7 @@ where

fn full_eval(&self, b: bool, k: &Share<LAMBDA, G>, ys: &mut [&mut G]) {
let n = k.cws.len();
assert_eq!(n, DOM_SZ * 8);
assert_eq!(n, self.filter_bitn);

let s = k.s0s[0];
let t = b;
Expand Down Expand Up @@ -173,7 +183,7 @@ where
) where
G: Group<LAMBDA>,
{
assert_eq!(ys.len(), 1 << (DOM_SZ * 8 - layer_i));
assert_eq!(ys.len(), 1 << (self.filter_bitn - layer_i));
if ys.len() == 1 {
*ys[0] = (Into::<G>::into(s) + if t { k.cw_np1.clone() } else { G::zero() })
.add_inverse_if(b);
Expand Down Expand Up @@ -205,7 +215,7 @@ where
G: Group<LAMBDA>,
{
let n = k.cws.len();
assert_eq!(n, DOM_SZ * 8);
assert_eq!(n, self.filter_bitn);
let v = y;

let mut s_prev = k.s0s[0];
Expand Down Expand Up @@ -280,6 +290,37 @@ mod tests {
assert_eq!(ys0, ys1);
}

#[test]
fn test_dpf_gen_then_eval_with_filter() {
let prg = Aes256HirosePrg::<16, 1>::new(std::array::from_fn(|i| KEYS[i]));
let dpf = DpfImpl::<16, 16, _>::new_with_filter(prg, 127);
let s0s: [[u8; 16]; 2] = thread_rng().gen();
let f = PointFn {
alpha: ALPHAS[2].to_owned(),
beta: BETA.clone().into(),
};
let k = dpf.gen(&f, [&s0s[0], &s0s[1]]);
let mut k0 = k.clone();
k0.s0s = vec![k0.s0s[0]];
let mut k1 = k.clone();
k1.s0s = vec![k1.s0s[1]];
let mut ys0 = vec![ByteGroup::zero(); ALPHAS.len()];
let mut ys1 = vec![ByteGroup::zero(); ALPHAS.len()];
dpf.eval(false, &k0, ALPHAS, &mut ys0.iter_mut().collect::<Vec<_>>());
dpf.eval(true, &k1, ALPHAS, &mut ys1.iter_mut().collect::<Vec<_>>());
ys0.iter_mut()
.zip(ys1.iter())
.for_each(|(y0, y1)| *y0 += y1.clone());
ys1 = vec![
ByteGroup::zero(),
ByteGroup::zero(),
BETA.clone().into(),
BETA.clone().into(),
ByteGroup::zero(),
];
assert_eq!(ys0, ys1);
}

#[test]
fn test_dpf_gen_then_eval_not_zeros() {
let prg = Aes256HirosePrg::<16, 1>::new(std::array::from_fn(|i| KEYS[i]));
Expand All @@ -303,7 +344,7 @@ mod tests {
}

#[test]
fn test_dpf_full_domain_eval() {
fn test_dpf_full_eval() {
let x: [u8; 2] = ALPHAS[2][..2].try_into().unwrap();
let prg = Aes256HirosePrg::<16, 1>::new(std::array::from_fn(|i| KEYS[i]));
let dpf = DpfImpl::<2, 16, _>::new(prg);
Expand All @@ -326,4 +367,31 @@ mod tests {
assert_eq!(y0, y0_full);
}
}

#[test]
fn test_dpf_full_eval_with_filter() {
let x: [u8; 2] = ALPHAS[2][..2].try_into().unwrap();
let prg = Aes256HirosePrg::<16, 1>::new(std::array::from_fn(|i| KEYS[i]));
let dpf = DpfImpl::<2, 16, _>::new_with_filter(prg, 15);
let s0s: [[u8; 16]; 2] = thread_rng().gen();
let f = PointFn {
alpha: x,
beta: BETA.clone().into(),
};
let k = dpf.gen(&f, [&s0s[0], &s0s[1]]);
let mut k0 = k.clone();
k0.s0s = vec![k0.s0s[0]];
let xs: Vec<_> = (0u16..=u16::MAX >> 1)
.map(|i| (i << 1).to_be_bytes())
.collect();
assert_eq!(xs.len(), 1 << 15);
let xs0: Vec<_> = xs.iter().collect();
let mut ys0 = vec![ByteGroup::zero(); 1 << 15];
let mut ys0_full = vec![ByteGroup::zero(); 1 << 15];
dpf.eval(false, &k0, &xs0, &mut ys0.iter_mut().collect::<Vec<_>>());
dpf.full_eval(false, &k0, &mut ys0_full.iter_mut().collect::<Vec<_>>());
for (y0, y0_full) in ys0.iter().zip(ys0_full.iter()) {
assert_eq!(y0, y0_full);
}
}
}

0 comments on commit 6e5d85c

Please sign in to comment.