Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[stdlib] fix: remove DictEntry.__copyinit__ in Dict methods (Pointer, speed up) #3824

Open
wants to merge 2 commits into
base: nightly
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 27 additions & 22 deletions stdlib/src/collections/dict.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -317,33 +317,35 @@ struct _DictIndex:
fn __moveinit__(out self, owned existing: Self):
self.data = existing.data

# Read the entry index stored at `slot` of the index buffer.
#
# The buffer behind `self.data` is reinterpreted as Int8, Int16, Int32 or
# Int64 elements, chosen by comparing `reserved` against 128, 2**16 - 2 and
# 2**32 - 2, and the loaded element is returned converted to Int.
#
# NOTE(review): every branch below carries two consecutive `return` lines, one
# masking the slot with `reserved - 1` and one using `slot` unchanged. This
# looks like the before/after lines of a diff whose +/- markers were lost;
# confirm which form is the intended one before relying on this text.
@always_inline
fn get_index(self, reserved: Int, slot: Int) -> Int:
if reserved <= 128:
var data = self.data.bitcast[Int8]()
return int(data.load(slot & (reserved - 1)))
return int(data.load(slot))
elif reserved <= 2**16 - 2:
var data = self.data.bitcast[Int16]()
return int(data.load(slot & (reserved - 1)))
return int(data.load(slot))
elif reserved <= 2**32 - 2:
var data = self.data.bitcast[Int32]()
return int(data.load(slot & (reserved - 1)))
return int(data.load(slot))
else:
var data = self.data.bitcast[Int64]()
return int(data.load(slot & (reserved - 1)))
return int(data.load(slot))

# Write `value` into position `slot` of the index buffer.
#
# Uses the same width selection as `get_index`: the buffer behind `self.data`
# is reinterpreted as Int8, Int16, Int32 or Int64 elements depending on
# `reserved` before the store.
#
# NOTE(review): every branch below carries two consecutive `return` lines, one
# masking the slot with `reserved - 1` and one using `slot` unchanged. This
# looks like the before/after lines of a diff whose +/- markers were lost;
# confirm which form is the intended one before relying on this text.
@always_inline
fn set_index(mut self, reserved: Int, slot: Int, value: Int):
if reserved <= 128:
var data = self.data.bitcast[Int8]()
return data.store(slot & (reserved - 1), value)
return data.store(slot, value)
elif reserved <= 2**16 - 2:
var data = self.data.bitcast[Int16]()
return data.store(slot & (reserved - 1), value)
return data.store(slot, value)
elif reserved <= 2**32 - 2:
var data = self.data.bitcast[Int32]()
return data.store(slot & (reserved - 1), value)
return data.store(slot, value)
else:
var data = self.data.bitcast[Int64]()
return data.store(slot & (reserved - 1), value)
return data.store(slot, value)

fn __del__(owned self):
self.data.free()
Expand Down Expand Up @@ -845,7 +847,6 @@ struct Dict[K: KeyElement, V: CollectionElement](
var entry = Pointer.address_of(self._entries[index])
debug_assert(entry[].__bool__(), "entry in index must be full")
var entry_value = entry[].unsafe_take()
entry[] = None
self.size -= 1
return entry_value^.reap_value()
raise "KeyError"
Expand Down Expand Up @@ -980,9 +981,11 @@ struct Dict[K: KeyElement, V: CollectionElement](
self.size += 1
self._n_entries += 1

# Return the entry index recorded for `slot`, forwarding to
# `self._index.get_index` together with the current `self._reserved()` value.
@always_inline
fn _get_index(self, slot: Int) -> Int:
return self._index.get_index(self._reserved(), slot)

# Record `index` for `slot`, forwarding to `self._index.set_index` together
# with the current `self._reserved()` value.
@always_inline
fn _set_index(mut self, slot: Int, index: Int):
return self._index.set_index(self._reserved(), slot, index)

Expand All @@ -1000,6 +1003,7 @@ struct Dict[K: KeyElement, V: CollectionElement](
return slot
self._next_index_slot(slot, perturb)

@always_inline
fn _find_index(self, hash: Int, key: K) -> (Bool, Int, Int):
# Return (found, slot, index)
var slot = hash & (self._reserved() - 1)
Expand All @@ -1011,9 +1015,9 @@ struct Dict[K: KeyElement, V: CollectionElement](
elif index == Self.REMOVED:
pass
else:
var entry = self._entries[index]
debug_assert(entry.__bool__(), "entry in index must be full")
if hash == entry.value().hash and key == entry.value().key:
var entry = Pointer.address_of(self._entries[index])
debug_assert(entry[].__bool__(), "entry in index must be full")
if hash == entry[].value().hash and key == entry[].value().key:
return (True, slot, index)
self._next_index_slot(slot, perturb)

Expand All @@ -1030,30 +1034,31 @@ struct Dict[K: KeyElement, V: CollectionElement](
return
var _reserved = self._reserved() * 2
self.size = 0
var old_n_entries = self._n_entries
self._n_entries = 0
var old_entries = self._entries^
self._entries = self._new_entries(_reserved)
self._index = _DictIndex(self._reserved())

for i in range(len(old_entries)):
var entry = old_entries[i]
if entry:
self._insert[safe_context=True](entry.unsafe_take())
for i in range(old_n_entries):
var entry = Pointer.address_of(old_entries[i])
if entry[]:
self._insert[safe_context=True](entry[].unsafe_take())

fn _compact(mut self):
self._index = _DictIndex(self._reserved())
var right = 0
for left in range(self.size):
while not self._entries[right]:
var r_entry = Pointer.address_of(self._entries[right])
while not r_entry[]:
right += 1
debug_assert(right < self._reserved(), "Invalid dict state")
var entry = self._entries[right]
debug_assert(entry.__bool__(), "Logic error")
var slot = self._find_empty_index(entry.value().hash)
r_entry = Pointer.address_of(self._entries[right])
debug_assert(r_entry[].__bool__(), "Logic error")
var slot = self._find_empty_index(r_entry[].value().hash)
self._set_index(slot, left)
if left != right:
self._entries[left] = entry.unsafe_take()
entry = None
self._entries[left] = r_entry[].unsafe_take()
right += 1

self._n_entries = self.size
Expand Down
39 changes: 38 additions & 1 deletion stdlib/test/collections/test_dict.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
from collections import Dict, KeyElement, Optional
from collections.dict import OwnedKwargsDict

from test_utils import CopyCounter
from test_utils import CopyCounter, ValueDestructorRecorder
from testing import assert_equal, assert_false, assert_raises, assert_true
from memory import UnsafePointer


def test_dict_construction():
Expand Down Expand Up @@ -602,6 +603,41 @@ fn test_dict_setdefault() raises:
assert_equal(0, other_dict["b"].copy_count)


# Checks how often values stored in a Dict are destroyed, for dict sizes
# 0 through 15, by comparing len(x) against the number of inserted entries.
# NOTE(review): presumably ValueDestructorRecorder appends to the list behind
# the given pointer when it is destroyed, so len(x) counts destructor runs;
# confirm against test_utils.
def test_dict_entries_del_count():
alias dict_type = Dict[Int, ValueDestructorRecorder]
for entries_to_insert in range(16):
var x = List[Int]()
var y = dict_type()
var result = 0
var expected_result = 0
# Expected sum of the inserted values: 0 + 1 + ... + (entries_to_insert - 1).
for i in range(entries_to_insert):
expected_result += i

# Phase 1: insert, read every value back, then destroy the dict explicitly.
# Nothing is expected to be recorded in x until the dict itself is destroyed.
for i in range(entries_to_insert):
y[i] = ValueDestructorRecorder(i, UnsafePointer.address_of(x))
assert_equal(len(x), 0)

for i in range(entries_to_insert):
result += y[i].value
assert_equal(len(x), 0)
# Run the dict destructor by hand; afterwards x is expected to hold one record
# per inserted entry.
__type_of(y).__del__(y^)
assert_equal(len(x), entries_to_insert)
assert_equal(result, expected_result)

# Phase 2: same setup, but every entry is removed with pop() before the dict
# is destroyed.
y = dict_type()
x.clear()
result = 0
for i in range(entries_to_insert):
y[i] = ValueDestructorRecorder(i, UnsafePointer.address_of(x))
assert_equal(len(x), 0)
for i in range(entries_to_insert):
result += y.pop(i).value
assert_equal(len(x), entries_to_insert)
# Destroying the emptied dict is expected to leave the recorded count unchanged.
__type_of(y).__del__(y^)
assert_equal(len(x), entries_to_insert)
assert_equal(result, expected_result)


def main():
test_dict()
test_dict_fromkeys()
Expand All @@ -615,3 +651,4 @@ def main():
test_clear()
test_init_initial_capacity()
test_dict_setdefault()
test_dict_entries_del_count()
Loading