Skip to content

Commit

Permalink
Refactor and optimize String.replace()
Browse files Browse the repository at this point in the history
Signed-off-by: martinvuyk <[email protected]>
  • Loading branch information
martinvuyk committed Dec 11, 2024
1 parent 2ace785 commit 5655490
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 53 deletions.
96 changes: 44 additions & 52 deletions stdlib/src/collections/string.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -1670,14 +1670,15 @@ struct String(
"""Return the number of non-overlapping occurrences of substring
`substr` in the string.
If `substr` is empty, returns the number of empty strings between characters,
which is the length of the string plus one.
Args:
substr: The substring to count.
substr: The substring to count.
Returns:
The number of occurrences of `substr`.
The number of occurrences of `substr`.
Notes:
If `substr` is empty, returns the number of empty strings between characters,
which is the length of the string plus one.
"""
if not substr:
return len(self) + 1
Expand Down Expand Up @@ -1901,51 +1902,54 @@ struct String(
Returns:
The string where all occurrences of `old` are replaced with `new`.
"""
if not old:
return self._interleave(new)

var occurrences = self.count(old)
if occurrences == -1:
return self

var self_start = self.unsafe_ptr()
var self_ptr = self.unsafe_ptr()
var s_ptr = self.unsafe_ptr()
var new_ptr = new.unsafe_ptr()

var self_len = self.byte_length()
var s_len = self.byte_length()
var old_len = old.byte_length()
var new_len = new.byte_length()

var res = Self._buffer_type()
res.reserve(self_len + (old_len - new_len) * occurrences + 1)

for _ in range(occurrences):
var curr_offset = int(self_ptr) - int(self_start)

var idx = self.find(old, curr_offset)
if old_len == 0:
var capacity = s_len + new_len * self.byte_length() + 1
var res_ptr = UnsafePointer[Byte].alloc(capacity)
var offset = 0
for s in self:
memcpy(res_ptr + offset, new_ptr, new_len)
offset += new_len
memcpy(res_ptr + offset, s.unsafe_ptr(), s.byte_length())
offset += s.byte_length()
res_ptr[capacity - 1] = 0
return String(ptr=res_ptr, length=capacity)

# FIXME(#3792): this should use self.as_bytes().count(old), which will be
# faster: returning unicode offsets has overhead, and it will also return
# fewer bytes than necessary and cause a segfault.
var occurrences = self.count(old)
if occurrences == 0:
return self

debug_assert(idx >= 0, "expected to find occurrence during find")
var capacity = s_len + (new_len - old_len) * occurrences + 1
var res_ptr = UnsafePointer[Byte].alloc(capacity)
var s_offset = 0
var res_offset = 0

while s_offset < s_len:
# FIXME(#3548): this should use raw bytes self.as_bytes().find(...)
var idx = self.find(old, s_offset)
if idx == -1:
memcpy(res_ptr + res_offset, s_ptr + s_offset, s_len - s_offset)
break
# Copy preceding unchanged chars
for _ in range(curr_offset, idx):
res.append(self_ptr[])
self_ptr += 1

var length = idx - s_offset
memcpy(res_ptr + res_offset, s_ptr + s_offset, length)
res_offset += length
s_offset += length + old_len
# Insert a copy of the new replacement string
for i in range(new_len):
res.append(new_ptr[i])
memcpy(res_ptr + res_offset, new_ptr, new_len)
res_offset += new_len

self_ptr += old_len

while True:
var val = self_ptr[]
if val == 0:
break
res.append(self_ptr[])
self_ptr += 1

res.append(0)
return String(res^)
res_ptr[capacity - 1] = 0
return String(ptr=res_ptr, length=capacity)

fn strip(self, chars: StringSlice) -> StringSlice[__origin_of(self)]:
"""Return a copy of the string with leading and trailing characters
Expand Down Expand Up @@ -2030,18 +2034,6 @@ struct String(
"""
hasher._update_with_bytes(self.unsafe_ptr(), self.byte_length())

fn _interleave(self, val: String) -> String:
    """Return a copy of this string with `val` inserted before every byte.

    Args:
        val: The string to insert in front of each byte of `self`.

    Returns:
        A new string made of `val` followed by the first byte of `self`,
        `val` followed by the second byte, and so on.

    Notes:
        The interleaving is done per byte, not per codepoint.
        NOTE(review): for multi-byte UTF-8 content this presumably splits a
        codepoint's bytes apart with copies of `val` - confirm whether
        callers rely on that.
    """
    var res = Self._buffer_type()
    var val_ptr = val.unsafe_ptr()
    var self_ptr = self.unsafe_ptr()
    var val_len = val.byte_length()
    var self_len = self.byte_length()
    # Every byte of `self` contributes `val_len` inserted bytes plus itself,
    # and one trailing 0 byte is appended at the end. Reserving the exact
    # total avoids regrowing the buffer while appending.
    res.reserve((val_len + 1) * self_len + 1)
    for i in range(self_len):
        for j in range(val_len):
            res.append(val_ptr[j])
        res.append(self_ptr[i])
    res.append(0)
    return String(res^)

fn lower(self) -> String:
"""Returns a copy of the string with all cased characters
converted to lowercase.
Expand Down
3 changes: 2 additions & 1 deletion stdlib/test/python/my_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ def __init__(self, bar):

class AbstractPerson(ABC):
@abstractmethod
def method(self): ...
def method(self):
...


def my_function(name):
Expand Down

0 comments on commit 5655490

Please sign in to comment.