diff --git a/shardtree/CHANGELOG.md b/shardtree/CHANGELOG.md index 33d77eb..795c587 100644 --- a/shardtree/CHANGELOG.md +++ b/shardtree/CHANGELOG.md @@ -7,6 +7,13 @@ and this project adheres to Rust's notion of ## Unreleased +### Changed +- `shardtree::BatchInsertionResult.max_insert_position` now has type `Position` + instead of `Option` (all APIs return `Option` + and use `None` at that level to represent "no leaves inserted"). +- `shardtree::LocatedTree::from_parts` now returns `Option` (returning + `None` if the provided `Address` and `Tree` are inconsistent). + ## [0.5.0] - 2024-10-04 This release includes a significant refactoring and rework of several methods diff --git a/shardtree/src/batch.rs b/shardtree/src/batch.rs index cbd728d..59c8714 100644 --- a/shardtree/src/batch.rs +++ b/shardtree/src/batch.rs @@ -71,8 +71,8 @@ impl< values = res.remainder; subtree_root_addr = subtree_root_addr.next_at_level(); - max_insert_position = res.max_insert_position; - start = max_insert_position.unwrap() + 1; + max_insert_position = Some(res.max_insert_position); + start = res.max_insert_position + 1; all_incomplete.append(&mut res.incomplete); } else { break; @@ -102,7 +102,7 @@ pub struct BatchInsertionResult)> /// [`Node::Nil`]: crate::tree::Node::Nil pub incomplete: Vec, /// The maximum position at which a leaf was inserted. - pub max_insert_position: Option, + pub max_insert_position: Position, /// The positions of all leaves with [`Retention::Checkpoint`] retention that were inserted. pub checkpoints: BTreeMap, /// The unconsumed remainder of the iterator from which leaves were inserted, if the tree @@ -243,7 +243,7 @@ impl LocatedPrunableTree { subtree: to_insert, contains_marked, incomplete, - max_insert_position: Some(last_position), + max_insert_position: last_position, checkpoints, remainder: values, }, diff --git a/shardtree/src/lib.rs b/shardtree/src/lib.rs index 7779cb3..96fbc6e 100644 --- a/shardtree/src/lib.rs +++ b/shardtree/src/lib.rs @@ -406,13 +406,16 @@ impl< /// Adds a checkpoint at the rightmost leaf state of the tree. pub fn checkpoint(&mut self, checkpoint_id: C) -> Result> { + /// Pre-condition: `root_addr` must be the address of `root`. fn go( root_addr: Address, root: &PrunableTree, ) -> Option<(PrunableTree, Position)> { match &root.0 { Node::Parent { ann, left, right } => { - let (l_addr, r_addr) = root_addr.children().unwrap(); + let (l_addr, r_addr) = root_addr + .children() + .expect("has children because we checked `root` is a parent"); go(r_addr, right).map_or_else( || { go(l_addr, left).map(|(new_left, pos)| { @@ -765,7 +768,10 @@ impl< // Compute the roots of the left and right children and hash them together. // We skip computation in any subtrees that will not have data included in // the final result. - let (l_addr, r_addr) = cap.root_addr.children().unwrap(); + let (l_addr, r_addr) = cap + .root_addr + .children() + .expect("has children because we checked `cap.root` is a parent"); let l_result = if r_addr.contains(&target_addr) { None } else { @@ -1162,7 +1168,8 @@ impl< cur_addr = cur_addr.parent(); } - Ok(MerklePath::from_parts(witness, position).unwrap()) + Ok(MerklePath::from_parts(witness, position) + .expect("witness has length DEPTH because we extended it to the root")) } fn witness_internal( diff --git a/shardtree/src/prunable.rs b/shardtree/src/prunable.rs index fed767f..2514bdd 100644 --- a/shardtree/src/prunable.rs +++ b/shardtree/src/prunable.rs @@ -358,6 +358,7 @@ impl LocatedPrunableTree { /// Note that no actual leaf value may exist at this position, as it may have previously been /// pruned. pub fn max_position(&self) -> Option { + /// Pre-condition: `addr` must be the address of `root`. fn go( addr: Address, root: &Tree>, (H, RetentionFlags)>, @@ -369,7 +370,9 @@ impl LocatedPrunableTree { if ann.is_some() { Some(addr.max_position()) } else { - let (l_addr, r_addr) = addr.children().unwrap(); + let (l_addr, r_addr) = addr + .children() + .expect("has children because we checked `root` is a parent"); go(r_addr, right.as_ref()).or_else(|| go(l_addr, left.as_ref())) } } @@ -406,6 +409,7 @@ impl LocatedPrunableTree { /// Returns the positions of marked leaves in the tree. pub fn marked_positions(&self) -> BTreeSet { + /// Pre-condition: `root_addr` must be the address of `root`. fn go( root_addr: Address, root: &PrunableTree, @@ -413,7 +417,9 @@ impl LocatedPrunableTree { ) { match &root.0 { Node::Parent { left, right, .. } => { - let (l_addr, r_addr) = root_addr.children().unwrap(); + let (l_addr, r_addr) = root_addr + .children() + .expect("has children because we checked `root` is a parent"); go(l_addr, left.as_ref(), acc); go(r_addr, right.as_ref(), acc); } @@ -440,8 +446,10 @@ impl LocatedPrunableTree { /// Returns either the witness for the leaf at the specified position, or an error that /// describes the causes of failure. pub fn witness(&self, position: Position, truncate_at: Position) -> Result, QueryError> { - // traverse down to the desired leaf position, and then construct - // the authentication path on the way back up. + /// Traverse down to the desired leaf position, and then construct + /// the authentication path on the way back up. + // + /// Pre-condition: `root_addr` must be the address of `root`. fn go( root: &PrunableTree, root_addr: Address, @@ -450,7 +458,9 @@ impl LocatedPrunableTree { ) -> Result, Vec
> { match &root.0 { Node::Parent { left, right, .. } => { - let (l_addr, r_addr) = root_addr.children().unwrap(); + let (l_addr, r_addr) = root_addr + .children() + .expect("has children because we checked `root` is a parent"); if root_addr.level() > 1.into() { let r_start = r_addr.position_range_start(); if position < r_start { @@ -525,6 +535,7 @@ impl LocatedPrunableTree { /// subtree root with the specified position as its maximum position exists, or `None` /// otherwise. pub fn truncate_to_position(&self, position: Position) -> Option { + /// Pre-condition: `root_addr` must be the address of `root`. fn go( position: Position, root_addr: Address, @@ -532,7 +543,9 @@ impl LocatedPrunableTree { ) -> Option> { match &root.0 { Node::Parent { ann, left, right } => { - let (l_child, r_child) = root_addr.children().unwrap(); + let (l_child, r_child) = root_addr + .children() + .expect("has children because we checked `root` is a parent"); if position < r_child.position_range_start() { // we are truncating within the range of the left node, so recurse // to the left to truncate the left child and then reconstruct the @@ -586,8 +599,10 @@ impl LocatedPrunableTree { subtree: Self, contains_marked: bool, ) -> Result<(Self, Vec), InsertionError> { - // A function to recursively dig into the tree, creating a path downward and introducing - // empty nodes as necessary until we can insert the provided subtree. + /// A function to recursively dig into the tree, creating a path downward and introducing + /// empty nodes as necessary until we can insert the provided subtree. + /// + /// Pre-condition: `root_addr` must be the address of `into`. #[allow(clippy::type_complexity)] fn go( root_addr: Address, @@ -694,7 +709,9 @@ impl LocatedPrunableTree { Tree(Node::Parent { ann, left, right }) => { // In this case, we have an existing parent but we need to dig down farther // before we can insert the subtree that we're carrying for insertion. - let (l_addr, r_addr) = root_addr.children().unwrap(); + let (l_addr, r_addr) = root_addr + .children() + .expect("has children because we checked `into` is a parent"); if l_addr.contains(&subtree.root_addr) { let (new_left, incomplete) = go(l_addr, left.as_ref(), subtree, contains_marked)?; @@ -770,7 +787,7 @@ impl LocatedPrunableTree { if r.remainder.next().is_some() { Err(InsertionError::TreeFull) } else { - Ok((r.subtree, r.max_insert_position.unwrap(), checkpoint_id)) + Ok((r.subtree, r.max_insert_position, checkpoint_id)) } }) } @@ -892,6 +909,7 @@ impl LocatedPrunableTree { /// Clears the specified retention flags at all positions specified, pruning any branches /// that no longer need to be retained. pub fn clear_flags(&self, to_clear: BTreeMap) -> Self { + /// Pre-condition: `root_addr` must be the address of `root`. fn go( to_clear: &[(Position, RetentionFlags)], root_addr: Address, @@ -903,7 +921,9 @@ impl LocatedPrunableTree { } else { match &root.0 { Node::Parent { ann, left, right } => { - let (l_addr, r_addr) = root_addr.children().unwrap(); + let (l_addr, r_addr) = root_addr + .children() + .expect("has children because we checked `root` is a parent"); let p = to_clear.partition_point(|(p, _)| p < &l_addr.position_range_end()); trace!( @@ -1228,7 +1248,7 @@ mod tests { root in arb_prunable_tree(arb_char_str(), 8, 2^6) ) { let root_addr = Address::from_parts(Level::from(7), 0); - let tree = LocatedTree::from_parts(root_addr, root); + let tree = LocatedTree::from_parts(root_addr, root).unwrap(); let (to_clear, to_retain) = tree.flag_positions().into_iter().enumerate().fold( (BTreeMap::new(), BTreeMap::new()), diff --git a/shardtree/src/store/caching.rs b/shardtree/src/store/caching.rs index 1ac3dca..f472760 100644 --- a/shardtree/src/store/caching.rs +++ b/shardtree/src/store/caching.rs @@ -46,9 +46,11 @@ where let _ = cache.put_cap(backend.get_cap()?); backend.with_checkpoints(backend.checkpoint_count()?, |checkpoint_id, checkpoint| { + // TODO: Once MSRV is at least 1.82, replace this (and similar `expect()`s below) with: + // `let Ok(_) = cache.add_checkpoint(checkpoint_id.clone(), checkpoint.clone());` cache .add_checkpoint(checkpoint_id.clone(), checkpoint.clone()) - .unwrap(); + .expect("error type is Infallible"); Ok(()) })?; @@ -74,26 +76,37 @@ where } self.deferred_actions.clear(); - for shard_root in self.cache.get_shard_roots().unwrap() { + for shard_root in self + .cache + .get_shard_roots() + .expect("error type is Infallible") + { self.backend.put_shard( self.cache .get_shard(shard_root) - .unwrap() + .expect("error type is Infallible") .expect("known address"), )?; } - self.backend.put_cap(self.cache.get_cap().unwrap())?; + self.backend + .put_cap(self.cache.get_cap().expect("error type is Infallible"))?; - let mut checkpoints = Vec::with_capacity(self.cache.checkpoint_count().unwrap()); + let mut checkpoints = Vec::with_capacity( + self.cache + .checkpoint_count() + .expect("error type is Infallible"), + ); self.cache .with_checkpoints( - self.cache.checkpoint_count().unwrap(), + self.cache + .checkpoint_count() + .expect("error type is Infallible"), |checkpoint_id, checkpoint| { checkpoints.push((checkpoint_id.clone(), checkpoint.clone())); Ok(()) }, ) - .unwrap(); + .expect("error type is Infallible"); for (checkpoint_id, checkpoint) in checkpoints { self.backend.add_checkpoint(checkpoint_id, checkpoint)?; } diff --git a/shardtree/src/tree.rs b/shardtree/src/tree.rs index af33295..aec92f7 100644 --- a/shardtree/src/tree.rs +++ b/shardtree/src/tree.rs @@ -197,9 +197,35 @@ pub struct LocatedTree { } impl LocatedTree { - /// Constructs a new LocatedTree from its constituent parts - pub fn from_parts(root_addr: Address, root: Tree) -> Self { - LocatedTree { root_addr, root } + /// Constructs a new LocatedTree from its constituent parts. + /// + /// Returns `None` if `root_addr` is inconsistent with `root` (in particular, if the + /// level of `root_addr` is too small to contain `tree`). + pub fn from_parts(root_addr: Address, root: Tree) -> Option { + // In order to meet various pre-conditions throughout the crate, we require that + // no `Node::Parent` in `root` has a level of 0 relative to `root_addr`. + fn is_consistent(addr: Address, root: &Tree) -> bool { + match (&root.0, addr.children()) { + // Found an inconsistency! + (Node::Parent { .. }, None) => false, + // Check consistency of children recursively. + (Node::Parent { left, right, .. }, Some((l_addr, r_addr))) => { + is_consistent(l_addr, left) && is_consistent(r_addr, right) + } + + // Leaves are technically allowed to occur at any level, so we do not + // require `addr` to have no children. + (Node::Leaf { .. }, _) => true, + + // Nil nodes have no information, so we cannot verify that the data it + // represents is consistent with `root_addr`. Instead we rely on methods + // that mutate `LocatedTree` to verify that the insertion address is not + // inconsistent with `root_addr`. + (Node::Nil, _) => true, + } + } + + is_consistent(root_addr, &root).then_some(LocatedTree { root_addr, root }) } /// Returns the root address of this tree. @@ -234,10 +260,13 @@ impl LocatedTree { /// Returns the value at the specified position, if any. pub fn value_at_position(&self, position: Position) -> Option<&V> { + /// Pre-condition: `addr` must be the address of `root`. fn go(pos: Position, addr: Address, root: &Tree) -> Option<&V> { match &root.0 { Node::Parent { left, right, .. } => { - let (l_addr, r_addr) = addr.children().unwrap(); + let (l_addr, r_addr) = addr + .children() + .expect("has children because we checked `root` is a parent"); if l_addr.position_range().contains(&pos) { go(pos, l_addr, left) } else { @@ -306,6 +335,7 @@ impl LocatedTree { /// if the tree is terminated by a [`Node::Nil`] or leaf node before the specified address can /// be reached. pub fn subtree(&self, addr: Address) -> Option { + /// Pre-condition: `root_addr` must be the address of `root`. fn go( root_addr: Address, root: &Tree, @@ -319,7 +349,9 @@ impl LocatedTree { } else { match &root.0 { Node::Parent { left, right, .. } => { - let (l_addr, r_addr) = root_addr.children().unwrap(); + let (l_addr, r_addr) = root_addr + .children() + .expect("has children because we checked `root` is a parent"); if l_addr.contains(&addr) { go(l_addr, left.as_ref(), addr) } else { @@ -343,6 +375,7 @@ impl LocatedTree { /// If this root address of this tree is lower down in the tree than the level specified, /// the entire tree is returned as the sole element of the result vector. pub fn decompose_to_level(self, level: Level) -> Vec { + /// Pre-condition: `root_addr` must be the address of `root`. fn go( level: Level, root_addr: Address, @@ -353,7 +386,9 @@ impl LocatedTree { } else { match root.0 { Node::Parent { left, right, .. } => { - let (l_addr, r_addr) = root_addr.children().unwrap(); + let (l_addr, r_addr) = root_addr + .children() + .expect("has children because we checked `root` is a parent"); let mut l_decomposed = go( level, l_addr,