diff --git a/stdlib/scripts/check_licenses.mojo b/stdlib/scripts/check_licenses.mojo index 0b60a0f0dd..f707328ae4 100644 --- a/stdlib/scripts/check_licenses.mojo +++ b/stdlib/scripts/check_licenses.mojo @@ -35,7 +35,8 @@ def main(): # this is the current file continue file_path = Path(target_paths[i]) - if not file_path.read_text().startswith(LICENSE): + var text: String = file_path.read_text() + if not text.as_string_slice().startswith(LICENSE): files_without_license.append(file_path) if len(files_without_license) > 0: diff --git a/stdlib/src/collections/__init__.mojo b/stdlib/src/collections/__init__.mojo index 97f58c9c88..123ed4db6b 100644 --- a/stdlib/src/collections/__init__.mojo +++ b/stdlib/src/collections/__init__.mojo @@ -21,3 +21,4 @@ from .list import List from .optional import Optional, OptionalReg from .set import Set from .vector import InlinedFixedVector +from .linked_list import LinkedList diff --git a/stdlib/src/collections/linked_list.mojo b/stdlib/src/collections/linked_list.mojo new file mode 100644 index 0000000000..2fd6e07d19 --- /dev/null +++ b/stdlib/src/collections/linked_list.mojo @@ -0,0 +1,657 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2024, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +from memory import UnsafePointer, Pointer +from sys.info import alignof +from collections import Optional + + +@value +struct _LinkedListNode[T: CollectionElement](): + alias PointerT = UnsafePointer[Self] + + var data: T + var next: Self.PointerT + var prev: Self.PointerT + + fn __init__(out self, owned data: T): + self.data = data + self.next = Self.PointerT() + self.prev = Self.PointerT() + + +struct LinkedList[T: CollectionElement]( + CollectionElement, CollectionElementNew, Sized, Boolable +): + """The `LinkedList` type is a dynamically-allocated doubly linked list. + + It supports pushing to the front and back in O(1). + + Parameters: + T: The type of the elements. + """ + + alias NodeT = _LinkedListNode[T] + alias PointerT = UnsafePointer[Self.NodeT] + + var _length: Int + var _head: Self.PointerT + var _tail: Self.PointerT + + fn __init__(out self): + """Default construct an empty list.""" + self._length = 0 + self._head = Self.PointerT() + self._tail = Self.PointerT() + + fn __init__(out self, owned *elems: T): + """ + Construct a list with the provided elements. + + Args: + elems: The elements to add to the list. + """ + self = Self(elements=elems^) + + fn __init__(out self, *, owned elements: VariadicListMem[T, _]): + """ + Construct a list from a `VariadicListMem`. + + Args: + elements: The elements to add to the list. + """ + self = Self() + + var length = len(elements) + + for i in range(length): + var src = UnsafePointer.address_of(elements[i]) + var node = Self.PointerT.alloc(1) + var dst = UnsafePointer.address_of(node[].data) + src.move_pointee_into(dst) + node[].next = Self.PointerT() + node[].prev = self._tail + if not self._tail: + self._head = node + self._tail = node + else: + self._tail[].next = node + self._tail = node + + # Do not destroy the elements when their backing storage goes away. + __mlir_op.`lit.ownership.mark_destroyed`( + __get_mvalue_as_litref(elements) + ) + + self._length = length + + fn push_back(mut self, owned elem: T): + """ + Append the provided element to the end of the list. + O(1) time complexity. + + Args: + elem: The element to append to the list. + """ + var node = Self.PointerT.alloc(1) + var data = UnsafePointer.address_of(node[].data) + data.init_pointee_move(elem^) + node[].prev = self._tail + if self._tail: + self._tail[].next = node + self._tail = node + self._length += 1 + if not self._head: + self._head = node + + fn append(mut self, owned elem: T): + """ + Append the provided element to the end of the list. + O(1) time complexity. Alias for list compatibility. + + Args: + elem: The element to append to the list. + """ + self.push_back(elem^) + + fn push_front(mut self, owned elem: T): + """ + Append the provided element to the front of the list. + O(1) time complexity. + + Args: + elem: The element to prepend to the list. + """ + var node = Self.PointerT.alloc(1) + node.init_pointee_move(Self.NodeT(elem^)) + node[].next = self._head + if self._head: + self._head[].prev = node + self._head = node + self._length += 1 + if not self._tail: + self._tail = node + + fn insert(mut self, owned idx: Int, owned elem: T) raises: + """ + Insert an element `elem` into the list at index `idx`. + + Args: + idx: The index to insert `elem` at. + elem: The item to insert into the list. + """ + var i = max(0, index(idx) if idx >= 0 else index(idx) + len(self)) + + if i == 0: + var node = Self.PointerT.alloc(1) + node.init_pointee_move(Self.NodeT(elem^)) + + if self._head: + node[].next = self._head + self._head[].prev = node + + self._head = node + + if not self._tail: + self._tail = node + + self._length += 1 + return + + i -= 1 + + var current = self._get_nth(i) + if current: + var next = current[].next + var node = Self.PointerT.alloc(1) + if not node: + raise "OOM" + var data = UnsafePointer.address_of(node[].data) + data[] = elem^ + node[].next = next + node[].prev = current + if next: + next[].prev = node + current[].next = node + if node[].next == Self.PointerT(): + self._tail = node + if node[].prev == Self.PointerT(): + self._head = node + self._length += 1 + else: + raise "index out of bounds" + + fn head(ref self) -> Optional[Pointer[T, __origin_of(self)]]: + """ + Gets a reference to the head of the list if one exists. + O(1) time complexity. + + Returns: + An reference to the head if one is present. + """ + if self._head: + return Optional[Pointer[T, __origin_of(self)]]( + Pointer.address_of(self._head[].data) + ) + else: + return Optional[Pointer[T, __origin_of(self)]]() + + fn tail(ref self) -> Optional[Pointer[T, __origin_of(self)]]: + """ + Gets a reference to the tail of the list if one exists. + O(1) time complexity. + + + Returns: + An reference to the tail if one is present. + """ + if self._tail: + return Optional[Pointer[T, __origin_of(self)]]( + Pointer.address_of(self._tail[].data) + ) + else: + return Optional[Pointer[T, __origin_of(self)]]() + + fn _get_nth(ref self, read idx: Int) -> Self.PointerT: + debug_assert(-len(self) <= idx < len(self), "index out of range") + + if Int(idx) >= 0: + if self._length <= idx: + return Self.PointerT() + + var cursor = UInt(idx) + var current = self._head + + while cursor > 0 and current: + current = current[].next + cursor -= 1 + + if cursor > 0: + return Self.PointerT() + else: + return current + else: + # print(index(idx)) + var cursor = (idx * -1) - 1 + # print(cursor) + var current = self._tail + + while cursor > 0 and current: + # print("loop") + current = current[].prev + cursor -= 1 + + if cursor > 0: + return Self.PointerT() + else: + # print(current) + # var c = self._head + # for i in range(len(self)): + # print(c) + # c = c[].next + return current + + fn __getitem__[I: Indexer](ref self, read idx: I) raises -> ref [self] T: + """ + Returns a reference the indicated element if it exists. + O(len(self)) time complexity. + + Parameters: + I: The type of indexer to use. + + Args: + idx: The index of the element to retrieve. Negative numbers are converted into an offset from the tail. + + Returns: + A `Pointer` to the element at the provided index, if it exists. + """ + var current = self._get_nth(Int(idx)) + + if not current: + raise "index out of bounds" + else: + return UnsafePointer[T].address_of(current[].data)[] + + fn __setitem__[I: Indexer](ref self, read idx: I, owned value: T) raises: + """ + Sets the item at index `idx` to `value`, destroying the current value. + O(len(self)) time complexity. + + Parameters: + I: The type of indexer to use. + + Args: + idx: The index of the element to retrieve. Negative numbers are converted into an offset from the tail. + value: The value to emplace into the list. + + Raises: + Raises if given an out of bounds index. + """ + var current = self._get_nth(Int(idx)) + + if not current: + raise "index out of bounds" + + var data = UnsafePointer.address_of(current[].data) + data.init_pointee_move(value^) + + fn pop(mut self) -> Optional[T]: + """ + Remove the last element of the list. + + Returns: + The element, if it was found. + """ + return self.pop(len(self) - 1) + + fn pop[I: Indexer](mut self, owned i: I) -> Optional[T]: + """ + Remove the ith element of the list, counting from the tail if + given a negative index. + + Parameters: + I: The type of index to use. + + Args: + i: The index of the element to get. + + Returns: + The element, if it was found. + """ + var current = self._get_nth(Int(i)) + + if not current: + return Optional[T]() + else: + var node = current[] + if node.prev: + node.prev[].next = node.next + else: + self._head = node.next + if node.next: + node.next[].prev = node.prev + else: + self._tail = node.prev + + var data = node.data^ + + # Aside from T, destructor is trivial + __mlir_op.`lit.ownership.mark_destroyed`( + __get_mvalue_as_litref(node) + ) + current.free() + self._length -= 1 + return Optional[T](data) + + fn clear(mut self): + """Removes all elements from the list.""" + var current = self._head + while current: + var old = current + current = current[].next + old.destroy_pointee() + old.free() + + self._head = Self.PointerT() + self._tail = Self.PointerT() + self._length = 0 + + fn __len__(read self) -> Int: + """ + Returns the number of elements in the list. + + Returns: + The length of the list. + """ + return self._length + + fn empty(self) -> Bool: + """ + Whether the list is empty. + + Returns: + Whether the list is empty or not. + """ + return self._head == Self.PointerT() + + fn __bool__(self) -> Bool: + """ + Casts self to `Bool` based on whether the list is empty or not. + + Returns: + Whether the list is empty or not. + """ + return not self.empty() + + fn __copyinit__(out self, read existing: Self): + """Creates a deepcopy of the given list. + + Args: + existing: The list to copy. + """ + self = Self() + var n = existing._head + while n: + self.push_back(n[].data) + n = n[].next + + fn copy(read self) -> Self: + """ + Creates a deepcopy of this list and return it. + + Returns: + A copy of this list. + """ + return Self.__copyinit__(self) + + fn __moveinit__(out self, owned existing: Self): + """Move data of an existing list into a new one. + + Args: + existing: The existing list. + """ + self._length = existing._length + self._head = existing._head + self._tail = existing._tail + + fn __del__(owned self): + """Destroy all elements in the list and free its memory.""" + + var current = self._head + while current: + var prev = current + current = current[].next + prev.destroy_pointee() + + fn reverse(mut self): + """Reverses the list in-place.""" + var current = self._head + + while current: + current[].next, current[].prev = current[].prev, current[].next + current = current[].prev + + self._head, self._tail = self._tail, self._head + + fn __reversed__(self) -> Self: + """ + Create a reversed copy of the list. + + Returns: + A reversed copy of the list. + """ + var rev = Self() + + var current = self._tail + while current: + rev.push_back(current[].data) + current = current[].prev + + return rev + + fn extend(mut self, owned other: Self): + """ + Extends the list with another. + O(1) time complexity. + + Args: + other: The list to append to this one. + """ + if self._tail: + self._tail[].next = other._head + if other._head: + other._head[].prev = self._tail + if other._tail: + self._tail = other._tail + + self._length += other._length + else: + self._head = other._head + self._tail = other._tail + self._length = other._length + + other._head = Self.PointerT() + other._tail = Self.PointerT() + + fn count[ + T: EqualityComparableCollectionElement + ](self: LinkedList[T], read elem: T) -> UInt: + """ + Count the occurrences of `elem` in the list. + + Parameters: + T: The list element type, used to conditionally enable the function. + + Args: + elem: The element to search for. + + Returns: + The number of occurrences of `elem` in the list. + """ + var current = self._head + var count = 0 + while current: + if current[].data == elem: + count += 1 + + current = current[].next + + return count + + fn __contains__[ + T: EqualityComparableCollectionElement, // + ](self: LinkedList[T], value: T) -> Bool: + """ + Checks if the list contains `value`. + + Parameters: + T: The list element type, used to conditionally enable the function. + + Args: + value: The value to search for in the list. + + Returns: + Whether the list contains `value`. + """ + var current = self._head + while current: + if current[].data == value: + return True + current = current[].next + + return False + + fn __eq__[ + T: EqualityComparableCollectionElement, // + ](read self: LinkedList[T], read other: LinkedList[T]) -> Bool: + """ + Checks if the two lists are equal. + + Parameters: + T: The list element type, used to conditionally enable the function. + + Args: + other: The list to compare to. + + Returns: + Whether the lists are equal. + """ + if self._length != other._length: + return False + + var self_cursor = self._head + var other_cursor = other._head + + while self_cursor: + if self_cursor[].data != other_cursor[].data: + return False + + self_cursor = self_cursor[].next + other_cursor = other_cursor[].next + + return True + + fn __ne__[ + T: EqualityComparableCollectionElement, // + ](self: LinkedList[T], other: LinkedList[T]) -> Bool: + """ + Checks if the two lists are not equal. + + Parameters: + T: The list element type, used to conditionally enable the function. + + Args: + other: The list to compare to. + + Returns: + Whether the lists are not equal. + """ + return not (self == other) + + @no_inline + fn __str__[ + U: RepresentableCollectionElement, // + ](self: LinkedList[U]) raises -> String: + """Returns a string representation of a `List`. + + Note that since we can't condition methods on a trait yet, + the way to call this method is a bit special. Here is an example below: + + ```mojo + var my_list = LinkedList[Int](1, 2, 3) + print(my_list.__str__()) + ``` + + When the compiler supports conditional methods, then a simple `str(my_list)` will + be enough. + + The elements' type must implement the `__repr__()` method for this to work. + + Parameters: + U: The type of the elements in the list. Must implement the + traits `Representable` and `CollectionElement`. + + Returns: + A string representation of the list. + """ + var output = String() + self.write_to(output) + return output^ + + @no_inline + fn write_to[ + W: Writer, U: RepresentableCollectionElement, // + ](self: LinkedList[U], mut writer: W) raises: + """Write `my_list.__str__()` to a `Writer`. + + Parameters: + W: A type conforming to the Writable trait. + U: The type of the List elements. Must have the trait `RepresentableCollectionElement`. + + Args: + writer: The object to write to. + """ + writer.write("[") + for i in range(len(self)): + writer.write(repr(self[i])) + if i < len(self) - 1: + writer.write(", ") + writer.write("]") + + @no_inline + fn __repr__[ + U: RepresentableCollectionElement, // + ](self: LinkedList[U]) raises -> String: + """Returns a string representation of a `List`. + + Note that since we can't condition methods on a trait yet, + the way to call this method is a bit special. Here is an example below: + + ```mojo + var my_list = LinkedList[Int](1, 2, 3) + print(my_list.__repr__()) + ``` + + When the compiler supports conditional methods, then a simple `repr(my_list)` will + be enough. + + The elements' type must implement the `__repr__()` for this to work. + + Parameters: + U: The type of the elements in the list. Must implement the + traits `Representable` and `CollectionElement`. + + Returns: + A string representation of the list. + """ + return self.__str__() diff --git a/stdlib/test/collections/test_linked_list.mojo b/stdlib/test/collections/test_linked_list.mojo new file mode 100644 index 0000000000..13511abe78 --- /dev/null +++ b/stdlib/test/collections/test_linked_list.mojo @@ -0,0 +1,501 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2024, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # +# RUN: %mojo %s + +from collections import LinkedList +from sys.info import sizeof + +from memory import UnsafePointer, Span +from test_utils import CopyCounter, MoveCounter +from testing import assert_equal, assert_false, assert_raises, assert_true + + +def test_list(): + var list = LinkedList[Int]() + + for i in range(5): + list.push_back(i) + + assert_equal(5, len(list)) + assert_equal(0, list[0]) + assert_equal(1, list[1]) + assert_equal(2, list[2]) + assert_equal(3, list[3]) + assert_equal(4, list[4]) + + assert_equal(0, list[-5]) + assert_equal(3, list[-2]) + assert_equal(4, list[-1]) + + list[2] = -2 + assert_equal(-2, list[2]) + + list[-5] = 5 + assert_equal(5, list[-5]) + list[-2] = 3 + assert_equal(3, list[-2]) + list[-1] = 7 + assert_equal(7, list[-1]) + + +def test_list_clear(): + var list = LinkedList[Int](1, 2, 3) + assert_equal(len(list), 3) + list.clear() + + assert_equal(len(list), 0) + + +def test_list_to_bool_conversion(): + assert_false(LinkedList[String]()) + assert_true(LinkedList[String]("a")) + assert_true(LinkedList[String]("", "a")) + assert_true(LinkedList[String]("")) + + +def test_list_pop(): + var list = LinkedList[Int]() + # Test pop with index + for i in range(6): + list.push_back(i) + + assert_equal(6, len(list)) + + # try popping from index 3 for 3 times + for i in range(3, 6): + assert_equal(i, list.pop(3).value()) + + # list should have 3 elements now + assert_equal(3, len(list)) + assert_equal(0, list[0]) + assert_equal(1, list[1]) + assert_equal(2, list[2]) + + # Test pop with negative index + for i in range(0, 2): + var popped = list.pop(-len(list)) + assert_true(popped) + assert_equal(i, popped.value()) + + # test default index as well + assert_equal(2, list.pop().value()) + list.push_back(2) + assert_equal(2, list.pop().value()) + + # list should be empty now + assert_equal(0, len(list)) + + +def test_list_variadic_constructor(): + var l = LinkedList[Int](2, 4, 6) + assert_equal(3, len(l)) + assert_equal(2, l[0]) + assert_equal(4, l[1]) + assert_equal(6, l[2]) + + l.push_back(8) + assert_equal(4, len(l)) + assert_equal(8, l[3]) + + # + # Test variadic construct copying behavior + # + + var l2 = LinkedList[CopyCounter]( + CopyCounter(), CopyCounter(), CopyCounter() + ) + + assert_equal(len(l2), 3) + assert_equal(l2[0].copy_count, 0) + assert_equal(l2[1].copy_count, 0) + assert_equal(l2[2].copy_count, 0) + + +def test_list_reverse(): + # + # Test reversing the list [] + # + + var vec = LinkedList[Int]() + + assert_equal(len(vec), 0) + + vec.reverse() + + assert_equal(len(vec), 0) + + # + # Test reversing the list [123] + # + + vec = LinkedList[Int]() + + vec.push_back(123) + + assert_equal(len(vec), 1) + assert_equal(vec[0], 123) + + vec.reverse() + + assert_equal(len(vec), 1) + assert_equal(vec[0], 123) + + # + # Test reversing the list ["one", "two", "three"] + # + + var vec2 = LinkedList[String]("one", "two", "three") + + assert_equal(len(vec2), 3) + assert_equal(vec2[0], "one") + assert_equal(vec2[1], "two") + assert_equal(vec2[2], "three") + + vec2.reverse() + + assert_equal(len(vec2), 3) + assert_equal(vec2[0], "three") + assert_equal(vec2[1], "two") + assert_equal(vec2[2], "one") + + # + # Test reversing the list [5, 10] + # + + vec = LinkedList[Int]() + vec.push_back(5) + vec.push_back(10) + + assert_equal(len(vec), 2) + assert_equal(vec[0], 5) + assert_equal(vec[1], 10) + + vec.reverse() + + assert_equal(len(vec), 2) + assert_equal(vec[0], 10) + assert_equal(vec[1], 5) + + +def test_list_insert(): + # + # Test the list [1, 2, 3] created with insert + # + + var v1 = LinkedList[Int]() + v1.insert(len(v1), 1) + v1.insert(len(v1), 3) + v1.insert(1, 2) + + assert_equal(len(v1), 3) + assert_equal(v1[0], 1) + assert_equal(v1[1], 2) + assert_equal(v1[2], 3) + + print(v1.__str__()) + + # + # Test the list [1, 2, 3, 4, 5] created with negative and positive index + # + + var v2 = LinkedList[Int]() + v2.insert(-1729, 2) + v2.insert(len(v2), 3) + v2.insert(len(v2), 5) + v2.insert(-1, 4) + v2.insert(-len(v2), 1) + print(v2.__str__()) + + assert_equal(len(v2), 5) + assert_equal(v2[0], 1) + assert_equal(v2[1], 2) + assert_equal(v2[2], 3) + assert_equal(v2[3], 4) + assert_equal(v2[4], 5) + + # + # Test the list [1, 2, 3, 4] created with negative index + # + + var v3 = LinkedList[Int]() + v3.insert(-11, 4) + v3.insert(-13, 3) + v3.insert(-17, 2) + v3.insert(-19, 1) + + assert_equal(len(v3), 4) + assert_equal(v3[0], 1) + assert_equal(v3[1], 2) + assert_equal(v3[2], 3) + assert_equal(v3[3], 4) + + # + # Test the list [1, 2, 3, 4, 5, 6, 7, 8] created with insert + # + + var v4 = LinkedList[Int]() + for i in range(4): + v4.insert(0, 4 - i) + v4.insert(len(v4), 4 + i + 1) + + for i in range(len(v4)): + assert_equal(v4[i], i + 1) + + +def test_list_extend_non_trivial(): + # Tests three things: + # - extend() for non-plain-old-data types + # - extend() with mixed-length self and other lists + # - extend() using optimal number of __moveinit__() calls + + # Preallocate with enough capacity to avoid reallocation making the + # move count checks below flaky. + var v1 = LinkedList[MoveCounter[String]]() + v1.push_back(MoveCounter[String]("Hello")) + v1.push_back(MoveCounter[String]("World")) + + var v2 = LinkedList[MoveCounter[String]]() + v2.push_back(MoveCounter[String]("Foo")) + v2.push_back(MoveCounter[String]("Bar")) + v2.push_back(MoveCounter[String]("Baz")) + + v1.extend(v2^) + + assert_equal(len(v1), 5) + assert_equal(v1[0].value, "Hello") + assert_equal(v1[1].value, "World") + assert_equal(v1[2].value, "Foo") + assert_equal(v1[3].value, "Bar") + assert_equal(v1[4].value, "Baz") + + assert_equal(v1[0].move_count, 1) + assert_equal(v1[1].move_count, 1) + assert_equal(v1[2].move_count, 1) + assert_equal(v1[3].move_count, 1) + assert_equal(v1[4].move_count, 1) + + +def test_2d_dynamic_list(): + var list = LinkedList[LinkedList[Int]]() + + for i in range(2): + var v = LinkedList[Int]() + for j in range(3): + v.push_back(i + j) + list.push_back(v) + + assert_equal(0, list[0][0]) + assert_equal(1, list[0][1]) + assert_equal(2, list[0][2]) + assert_equal(1, list[1][0]) + assert_equal(2, list[1][1]) + assert_equal(3, list[1][2]) + + assert_equal(2, len(list)) + + assert_equal(3, len(list[0])) + + list[0].clear() + assert_equal(0, len(list[0])) + + list.clear() + assert_equal(0, len(list)) + + +def test_list_explicit_copy(): + var list = LinkedList[CopyCounter]() + list.push_back(CopyCounter()) + var list_copy = list.copy() + assert_equal(0, list[0].copy_count) + assert_equal(1, list_copy[0].copy_count) + + var l2 = LinkedList[Int]() + for i in range(10): + l2.push_back(i) + + var l2_copy = l2.copy() + assert_equal(len(l2), len(l2_copy)) + for i in range(len(l2)): + assert_equal(l2[i], l2_copy[i]) + + +@value +struct CopyCountedStruct(CollectionElement): + var counter: CopyCounter + var value: String + + fn __init__(out self, *, other: Self): + self.counter = other.counter.copy() + self.value = other.value.copy() + + @implicit + fn __init__(out self, value: String): + self.counter = CopyCounter() + self.value = value + + +def test_no_extra_copies_with_sugared_set_by_field(): + var list = LinkedList[LinkedList[CopyCountedStruct]]() + var child_list = LinkedList[CopyCountedStruct]() + child_list.push_back(CopyCountedStruct("Hello")) + child_list.push_back(CopyCountedStruct("World")) + + # No copies here. Constructing with LinkedList[CopyCountedStruct](CopyCountedStruct("Hello")) is a copy. + assert_equal(0, child_list[0].counter.copy_count) + assert_equal(0, child_list[1].counter.copy_count) + + list.push_back(child_list^) + + assert_equal(0, list[0][0].counter.copy_count) + assert_equal(0, list[0][1].counter.copy_count) + + # list[0][1] makes a copy for reasons I cannot determine + list.__getitem__(0).__getitem__(1).value = "Mojo" + + assert_equal(0, list[0][0].counter.copy_count) + assert_equal(0, list[0][1].counter.copy_count) + + assert_equal("Mojo", list[0][1].value) + + assert_equal(0, list[0][0].counter.copy_count) + assert_equal(0, list[0][1].counter.copy_count) + + +def test_list_boolable(): + assert_true(LinkedList[Int](1)) + assert_false(LinkedList[Int]()) + + +def test_list_count(): + var list = LinkedList[Int](1, 2, 3, 2, 5, 6, 7, 8, 9, 10) + assert_equal(1, list.count(1)) + assert_equal(2, list.count(2)) + assert_equal(0, list.count(4)) + + var list2 = LinkedList[Int]() + assert_equal(0, list2.count(1)) + + +def test_list_contains(): + var x = LinkedList[Int](1, 2, 3) + assert_false(0 in x) + assert_true(1 in x) + assert_false(4 in x) + + # TODO: implement LinkedList.__eq__ for Self[ComparableCollectionElement] + # var y = LinkedList[LinkedList[Int]]() + # y.push_back(LinkedList(1,2)) + # assert_equal(LinkedList(1,2) in y,True) + # assert_equal(LinkedList(0,1) in y,False) + + +def test_list_eq_ne(): + var l1 = LinkedList[Int](1, 2, 3) + var l2 = LinkedList[Int](1, 2, 3) + assert_true(l1 == l2) + assert_false(l1 != l2) + + var l3 = LinkedList[Int](1, 2, 3, 4) + assert_false(l1 == l3) + assert_true(l1 != l3) + + var l4 = LinkedList[Int]() + var l5 = LinkedList[Int]() + assert_true(l4 == l5) + assert_true(l1 != l4) + + var l6 = LinkedList[String]("a", "b", "c") + var l7 = LinkedList[String]("a", "b", "c") + var l8 = LinkedList[String]("a", "b") + assert_true(l6 == l7) + assert_false(l6 != l7) + assert_false(l6 == l8) + + +def test_indexing(): + var l = LinkedList[Int](1, 2, 3) + assert_equal(l[Int(1)], 2) + assert_equal(l[False], 1) + assert_equal(l[True], 2) + assert_equal(l[2], 3) + + +# ===-------------------------------------------------------------------===# +# LinkedList dtor tests +# ===-------------------------------------------------------------------===# +var g_dtor_count: Int = 0 + + +struct DtorCounter(CollectionElement): + # NOTE: payload is required because LinkedList does not support zero sized structs. + var payload: Int + + fn __init__(out self): + self.payload = 0 + + fn __init__(out self, *, other: Self): + self.payload = other.payload + + fn __copyinit__(out self, existing: Self, /): + self.payload = existing.payload + + fn __moveinit__(out self, owned existing: Self, /): + self.payload = existing.payload + existing.payload = 0 + + fn __del__(owned self): + g_dtor_count += 1 + + +def inner_test_list_dtor(): + # explicitly reset global counter + g_dtor_count = 0 + + var l = LinkedList[DtorCounter]() + assert_equal(g_dtor_count, 0) + + l.push_back(DtorCounter()) + assert_equal(g_dtor_count, 0) + + l^.__del__() + assert_equal(g_dtor_count, 1) + + +def test_list_dtor(): + # call another function to force the destruction of the list + inner_test_list_dtor() + + # verify we still only ran the destructor once + assert_equal(g_dtor_count, 1) + + +# ===-------------------------------------------------------------------===# +# main +# ===-------------------------------------------------------------------===# +fn main() raises: + test_list() + test_list_clear() + test_list_to_bool_conversion() + test_list_pop() + test_list_variadic_constructor() + test_list_reverse() + test_list_extend_non_trivial() + test_list_explicit_copy() + test_no_extra_copies_with_sugared_set_by_field() + test_2d_dynamic_list() + test_list_boolable() + test_list_count() + test_list_contains() + test_indexing() + test_list_dtor() + test_list_insert()