From 9d20a4a21ba520b61b781417a82a10639a4dfff2 Mon Sep 17 00:00:00 2001 From: rd4com <144297616+rd4com@users.noreply.github.com> Date: Thu, 5 Dec 2024 22:21:05 +0100 Subject: [PATCH] Add `List.map(fn(mut T)->None)` Signed-off-by: rd4com <144297616+rd4com@users.noreply.github.com> --- stdlib/src/collections/list.mojo | 29 ++++++++++++++++++++++++++ stdlib/test/collections/test_list.mojo | 16 ++++++++++++++ 2 files changed, 45 insertions(+) diff --git a/stdlib/src/collections/list.mojo b/stdlib/src/collections/list.mojo index d4791dd79ef..0d5d90f5fe2 100644 --- a/stdlib/src/collections/list.mojo +++ b/stdlib/src/collections/list.mojo @@ -930,6 +930,35 @@ struct List[T: CollectionElement, hint_trivial_type: Bool = False]( """ return self.data + fn map(ref self, func: fn (mut T) -> None) -> Self: + """Map the values of the list into a new list trough a function. + + Args: + func: The function used on every elements to create the new list. + + Returns: + A new `List` created by calling `func` on every elements of `self`. + + For example: + ```mojo + fn MyFunc(mut e: Int): + e+=1 + + var MyList = List(0, 1, 2).map(MyFunc) + + print( + MyList[0] == 1, + MyList[1] == 2, + MyList[2] == 3, + ) + ```. + + """ + var tmp = self + for i in tmp: + func(i[]) + return tmp + fn _clip(value: Int, start: Int, end: Int) -> Int: return max(start, min(value, end)) diff --git a/stdlib/test/collections/test_list.mojo b/stdlib/test/collections/test_list.mojo index cd173fc66cc..ca100390827 100644 --- a/stdlib/test/collections/test_list.mojo +++ b/stdlib/test/collections/test_list.mojo @@ -924,6 +924,21 @@ def test_list_repr(): assert_equal(empty.__repr__(), "[]") +def test_list_map(): + fn MyFunc(mut e: Int): + e += 1 + + var lst = List(0, 1, 2).map(MyFunc) + for e in range(len(lst)): + assert_equal(lst[e], e + 1) + + lst = List(0, 1, 2) + var lst2 = lst.map(MyFunc) + for e in range(len(lst)): + assert_equal(lst[e], e) + assert_equal(lst2[e], e + 1) + + # ===-------------------------------------------------------------------===# # main # ===-------------------------------------------------------------------===# @@ -962,3 +977,4 @@ def main(): test_indexing() test_list_dtor() test_list_repr() + test_list_map()