diff --git a/stdlib/src/collections/dict.mojo b/stdlib/src/collections/dict.mojo index d35229a17e..cce16f171c 100644 --- a/stdlib/src/collections/dict.mojo +++ b/stdlib/src/collections/dict.mojo @@ -317,33 +317,35 @@ struct _DictIndex: fn __moveinit__(out self, owned existing: Self): self.data = existing.data + @always_inline fn get_index(self, reserved: Int, slot: Int) -> Int: if reserved <= 128: var data = self.data.bitcast[Int8]() - return int(data.load(slot & (reserved - 1))) + return int(data.load(slot)) elif reserved <= 2**16 - 2: var data = self.data.bitcast[Int16]() - return int(data.load(slot & (reserved - 1))) + return int(data.load(slot)) elif reserved <= 2**32 - 2: var data = self.data.bitcast[Int32]() - return int(data.load(slot & (reserved - 1))) + return int(data.load(slot)) else: var data = self.data.bitcast[Int64]() - return int(data.load(slot & (reserved - 1))) + return int(data.load(slot)) + @always_inline fn set_index(mut self, reserved: Int, slot: Int, value: Int): if reserved <= 128: var data = self.data.bitcast[Int8]() - return data.store(slot & (reserved - 1), value) + return data.store(slot, value) elif reserved <= 2**16 - 2: var data = self.data.bitcast[Int16]() - return data.store(slot & (reserved - 1), value) + return data.store(slot, value) elif reserved <= 2**32 - 2: var data = self.data.bitcast[Int32]() - return data.store(slot & (reserved - 1), value) + return data.store(slot, value) else: var data = self.data.bitcast[Int64]() - return data.store(slot & (reserved - 1), value) + return data.store(slot, value) fn __del__(owned self): self.data.free() @@ -845,7 +847,6 @@ struct Dict[K: KeyElement, V: CollectionElement]( var entry = Pointer.address_of(self._entries[index]) debug_assert(entry[].__bool__(), "entry in index must be full") var entry_value = entry[].unsafe_take() - entry[] = None self.size -= 1 return entry_value^.reap_value() raise "KeyError" @@ -980,9 +981,11 @@ struct Dict[K: KeyElement, V: CollectionElement]( self.size += 1 self._n_entries += 1 + @always_inline fn _get_index(self, slot: Int) -> Int: return self._index.get_index(self._reserved(), slot) + @always_inline fn _set_index(mut self, slot: Int, index: Int): return self._index.set_index(self._reserved(), slot, index) @@ -1000,6 +1003,7 @@ struct Dict[K: KeyElement, V: CollectionElement]( return slot self._next_index_slot(slot, perturb) + @always_inline fn _find_index(self, hash: Int, key: K) -> (Bool, Int, Int): # Return (found, slot, index) var slot = hash & (self._reserved() - 1) @@ -1011,9 +1015,9 @@ struct Dict[K: KeyElement, V: CollectionElement]( elif index == Self.REMOVED: pass else: - var entry = self._entries[index] - debug_assert(entry.__bool__(), "entry in index must be full") - if hash == entry.value().hash and key == entry.value().key: + var entry = Pointer.address_of(self._entries[index]) + debug_assert(entry[].__bool__(), "entry in index must be full") + if hash == entry[].value().hash and key == entry[].value().key: return (True, slot, index) self._next_index_slot(slot, perturb) @@ -1030,30 +1034,31 @@ struct Dict[K: KeyElement, V: CollectionElement]( return var _reserved = self._reserved() * 2 self.size = 0 + var old_n_entries = self._n_entries self._n_entries = 0 var old_entries = self._entries^ self._entries = self._new_entries(_reserved) self._index = _DictIndex(self._reserved()) - for i in range(len(old_entries)): - var entry = old_entries[i] - if entry: - self._insert[safe_context=True](entry.unsafe_take()) + for i in range(old_n_entries): + var entry = Pointer.address_of(old_entries[i]) + if entry[]: + self._insert[safe_context=True](entry[].unsafe_take()) fn _compact(mut self): self._index = _DictIndex(self._reserved()) var right = 0 for left in range(self.size): - while not self._entries[right]: + var r_entry = Pointer.address_of(self._entries[right]) + while not r_entry[]: right += 1 debug_assert(right < self._reserved(), "Invalid dict state") - var entry = self._entries[right] - debug_assert(entry.__bool__(), "Logic error") - var slot = self._find_empty_index(entry.value().hash) + r_entry = Pointer.address_of(self._entries[right]) + debug_assert(r_entry[].__bool__(), "Logic error") + var slot = self._find_empty_index(r_entry[].value().hash) self._set_index(slot, left) if left != right: - self._entries[left] = entry.unsafe_take() - entry = None + self._entries[left] = r_entry[].unsafe_take() right += 1 self._n_entries = self.size diff --git a/stdlib/test/collections/test_dict.mojo b/stdlib/test/collections/test_dict.mojo index 4f8b1e965c..52f8d0f72e 100644 --- a/stdlib/test/collections/test_dict.mojo +++ b/stdlib/test/collections/test_dict.mojo @@ -15,8 +15,9 @@ from collections import Dict, KeyElement, Optional from collections.dict import OwnedKwargsDict -from test_utils import CopyCounter +from test_utils import CopyCounter, ValueDestructorRecorder from testing import assert_equal, assert_false, assert_raises, assert_true +from memory import UnsafePointer def test_dict_construction(): @@ -602,6 +603,41 @@ fn test_dict_setdefault() raises: assert_equal(0, other_dict["b"].copy_count) +def test_dict_entries_del_count(): + alias dict_type = Dict[Int, ValueDestructorRecorder] + for entries_to_insert in range(16): + var x = List[Int]() + var y = dict_type() + var result = 0 + var expected_result = 0 + for i in range(entries_to_insert): + expected_result += i + + for i in range(entries_to_insert): + y[i] = ValueDestructorRecorder(i, UnsafePointer.address_of(x)) + assert_equal(len(x), 0) + + for i in range(entries_to_insert): + result += y[i].value + assert_equal(len(x), 0) + __type_of(y).__del__(y^) + assert_equal(len(x), entries_to_insert) + assert_equal(result, expected_result) + + y = dict_type() + x.clear() + result = 0 + for i in range(entries_to_insert): + y[i] = ValueDestructorRecorder(i, UnsafePointer.address_of(x)) + assert_equal(len(x), 0) + for i in range(entries_to_insert): + result += y.pop(i).value + assert_equal(len(x), entries_to_insert) + __type_of(y).__del__(y^) + assert_equal(len(x), entries_to_insert) + assert_equal(result, expected_result) + + def main(): test_dict() test_dict_fromkeys() @@ -615,3 +651,4 @@ def main(): test_clear() test_init_initial_capacity() test_dict_setdefault() + test_dict_entries_del_count()