Skip to content

Commit

Permalink
Add retain() and retain_in() to Table
Browse files Browse the repository at this point in the history
  • Loading branch information
cberner committed Mar 17, 2024
1 parent 20aed98 commit c251a1c
Show file tree
Hide file tree
Showing 5 changed files with 142 additions and 0 deletions.
8 changes: 8 additions & 0 deletions fuzz/fuzz_targets/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,14 @@ pub(crate) enum FuzzOperation {
modulus: U64Between<1, 8>,
reversed: bool,
},
Retain {
modulus: U64Between<1, 8>,
},
RetainIn {
start_key: BoundedU64<KEY_SPACE>,
len: BoundedU64<KEY_SPACE>,
modulus: U64Between<1, 8>,
},
Range {
start_key: BoundedU64<KEY_SPACE>,
len: BoundedU64<KEY_SPACE>,
Expand Down
18 changes: 18 additions & 0 deletions fuzz/fuzz_targets/fuzz_redb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,12 @@ fn handle_multimap_table_op(op: &FuzzOperation, reference: &mut BTreeMap<u64, BT
FuzzOperation::DrainFilter { .. } => {
// no-op. Multimap tables don't support this
}
FuzzOperation::Retain { .. } => {
// no-op. Multimap tables don't support this
}
FuzzOperation::RetainIn { .. } => {
// no-op. Multimap tables don't support this
}
FuzzOperation::Range {
start_key,
len,
Expand Down Expand Up @@ -428,6 +434,18 @@ fn handle_table_op(op: &FuzzOperation, reference: &mut BTreeMap<u64, usize>, tab
panic!();
}
}
FuzzOperation::RetainIn { start_key, len, modulus } => {
let start = start_key.value;
let end = start + len.value;
let modulus = modulus.value;
table.retain_in(|x, _| x % modulus == 0, start..end)?;
reference.retain(|x, _| (*x < start || *x >= end) || *x % modulus == 0);
}
FuzzOperation::Retain { modulus } => {
let modulus = modulus.value;
table.retain(|x, _| x % modulus == 0)?;
reference.retain(|x, _| *x % modulus == 0);
}
FuzzOperation::Range {
start_key,
len,
Expand Down
24 changes: 24 additions & 0 deletions src/table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,30 @@ impl<'txn, K: Key + 'static, V: Value + 'static> Table<'txn, K, V> {
.map(DrainFilter::new)
}

/// Applies `predicate` to all key-value pairs. All entries for which
/// `predicate` evaluates to `false` are removed.
///
pub fn retain<F: for<'f> Fn(K::SelfType<'f>, V::SelfType<'f>) -> bool>(
&mut self,
predicate: F,
) -> Result {
self.tree.retain_in::<K::SelfType<'_>, F>(predicate, ..)
}

/// Applies `predicate` to all key-value pairs in the range `start..end`. All entries for which
/// `predicate` evaluates to `false` are removed.
///
pub fn retain_in<'a, KR, F: for<'f> Fn(K::SelfType<'f>, V::SelfType<'f>) -> bool>(
&mut self,
predicate: F,
range: impl RangeBounds<KR> + 'a,
) -> Result
where
KR: Borrow<K::SelfType<'a>> + 'a,
{
self.tree.retain_in(predicate, range)
}

/// Insert mapping of the given key to the given value
///
/// Returns the old value, if the key was present in the table
Expand Down
23 changes: 23 additions & 0 deletions src/tree_store/btree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,29 @@ impl<K: Key + 'static, V: Value + 'static> BtreeMut<'_, K, V> {
Ok(result)
}

pub(crate) fn retain_in<'a, KR, F: for<'f> Fn(K::SelfType<'f>, V::SelfType<'f>) -> bool>(
&mut self,
predicate: F,
range: impl RangeBounds<KR> + 'a,
) -> Result
where
KR: Borrow<K::SelfType<'a>> + 'a,
{
let iter = self.range(&range)?;
let mut freed = vec![];
let mut operation: MutateHelper<'_, '_, K, V> =
MutateHelper::new_do_not_modify(&mut self.root, self.mem.clone(), &mut freed);
for entry in iter {
let entry = entry?;
if !predicate(entry.key(), entry.value()) {
assert!(operation.delete(&entry.key())?.is_some());
}
}
self.freed_pages.lock().unwrap().extend_from_slice(&freed);

Ok(())
}

pub(crate) fn len(&self) -> Result<u64> {
self.read_tree()?.len()
}
Expand Down
69 changes: 69 additions & 0 deletions tests/basic_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,75 @@ fn drain() {
write_txn.abort().unwrap();
}

#[test]
fn retain() {
let tmpfile = create_tempfile();
let db = Database::create(tmpfile.path()).unwrap();
let write_txn = db.begin_write().unwrap();
{
let mut table = write_txn.open_table(U64_TABLE).unwrap();
for i in 0..10 {
table.insert(&i, &i).unwrap();
}
// Test retain uncommitted data
table.retain(|k, _| k >= 5).unwrap();
for i in 0..5 {
assert!(table.insert(&i, &i).unwrap().is_none());
}
assert_eq!(table.len().unwrap(), 10);

// Test matching on the value
table.retain(|_, v| v >= 5).unwrap();
for i in 0..5 {
assert!(table.insert(&i, &i).unwrap().is_none());
}
assert_eq!(table.len().unwrap(), 10);

// Test retain_in
table.retain_in(|_, _| false, ..5).unwrap();
for i in 0..5 {
assert!(table.insert(&i, &i).unwrap().is_none());
}
assert_eq!(table.len().unwrap(), 10);
}
write_txn.commit().unwrap();

let write_txn = db.begin_write().unwrap();
{
let mut table = write_txn.open_table(U64_TABLE).unwrap();
assert_eq!(table.len().unwrap(), 10);
table.retain(|x, _| x >= 5).unwrap();
assert_eq!(table.len().unwrap(), 5);

let mut i = 5u64;
for item in table.range(0..10).unwrap() {
let (k, v) = item.unwrap();
assert_eq!(i, k.value());
assert_eq!(i, v.value());
i += 1;
}
}
write_txn.abort().unwrap();

let write_txn = db.begin_write().unwrap();
{
let mut table = write_txn.open_table(U64_TABLE).unwrap();
table.retain(|x, _| x % 2 == 0).unwrap();
}
write_txn.commit().unwrap();

let read_txn = db.begin_write().unwrap();
{
let table = read_txn.open_table(U64_TABLE).unwrap();
assert_eq!(table.len().unwrap(), 5);
for entry in table.iter().unwrap() {
let (k, v) = entry.unwrap();
assert_eq!(k.value() % 2, 0);
assert_eq!(k.value(), v.value());
}
}
}

#[test]
fn drain_filter() {
let tmpfile = create_tempfile();
Expand Down

0 comments on commit c251a1c

Please sign in to comment.