Skip to content

Commit

Permalink
fix: Error instead of truncate if length mismatch for several str f…
Browse files Browse the repository at this point in the history
…unctions (#20781)
  • Loading branch information
etiennebacher authored Jan 20, 2025
1 parent bf57bde commit 099ee3c
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 10 deletions.
32 changes: 22 additions & 10 deletions crates/polars-plan/src/dsl/function_expr/strings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,7 @@ fn extract_many(
ascii_case_insensitive: bool,
overlapping: bool,
) -> PolarsResult<Column> {
_check_same_length(s, "extract_many")?;
let ca = s[0].str()?;
let patterns = &s[1];

Expand Down Expand Up @@ -524,6 +525,7 @@ pub(super) fn len_bytes(s: &Column) -> PolarsResult<Column> {

#[cfg(feature = "regex")]
pub(super) fn contains(s: &[Column], literal: bool, strict: bool) -> PolarsResult<Column> {
_check_same_length(s, "contains")?;
let ca = s[0].str()?;
let pat = s[1].str()?;
ca.contains_chunked(pat, literal, strict)
Expand All @@ -532,20 +534,23 @@ pub(super) fn contains(s: &[Column], literal: bool, strict: bool) -> PolarsResul

#[cfg(feature = "regex")]
pub(super) fn find(s: &[Column], literal: bool, strict: bool) -> PolarsResult<Column> {
_check_same_length(s, "find")?;
let ca = s[0].str()?;
let pat = s[1].str()?;
ca.find_chunked(pat, literal, strict)
.map(|ok| ok.into_column())
}

pub(super) fn ends_with(s: &[Column]) -> PolarsResult<Column> {
_check_same_length(s, "ends_with")?;
let ca = &s[0].str()?.as_binary();
let suffix = &s[1].str()?.as_binary();

Ok(ca.ends_with_chunked(suffix).into_column())
}

pub(super) fn starts_with(s: &[Column]) -> PolarsResult<Column> {
_check_same_length(s, "starts_with")?;
let ca = s[0].str()?;
let prefix = s[1].str()?;
Ok(ca.starts_with_chunked(prefix).into_column())
Expand Down Expand Up @@ -579,37 +584,43 @@ pub(super) fn pad_end(s: &Column, length: usize, fill_char: char) -> PolarsResul

#[cfg(feature = "string_pad")]
pub(super) fn zfill(s: &[Column]) -> PolarsResult<Column> {
_check_same_length(s, "zfill")?;
let ca = s[0].str()?;
let length_s = s[1].strict_cast(&DataType::UInt64)?;
let length = length_s.u64()?;
Ok(ca.zfill(length).into_column())
}

pub(super) fn strip_chars(s: &[Column]) -> PolarsResult<Column> {
_check_same_length(s, "strip_chars")?;
let ca = s[0].str()?;
let pat_s = &s[1];
ca.strip_chars(pat_s).map(|ok| ok.into_column())
}

pub(super) fn strip_chars_start(s: &[Column]) -> PolarsResult<Column> {
_check_same_length(s, "strip_chars_start")?;
let ca = s[0].str()?;
let pat_s = &s[1];
ca.strip_chars_start(pat_s).map(|ok| ok.into_column())
}

pub(super) fn strip_chars_end(s: &[Column]) -> PolarsResult<Column> {
_check_same_length(s, "strip_chars_end")?;
let ca = s[0].str()?;
let pat_s = &s[1];
ca.strip_chars_end(pat_s).map(|ok| ok.into_column())
}

pub(super) fn strip_prefix(s: &[Column]) -> PolarsResult<Column> {
_check_same_length(s, "strip_prefix")?;
let ca = s[0].str()?;
let prefix = s[1].str()?;
Ok(ca.strip_prefix(prefix).into_column())
}

pub(super) fn strip_suffix(s: &[Column]) -> PolarsResult<Column> {
_check_same_length(s, "strip_suffix")?;
let ca = s[0].str()?;
let suffix = s[1].str()?;
Ok(ca.strip_suffix(suffix).into_column())
Expand Down Expand Up @@ -1023,32 +1034,32 @@ fn _ensure_lengths(s: &[Column]) -> bool {
.all(|series| series.len() == 1 || series.len() == len)
}

pub(super) fn str_slice(s: &[Column]) -> PolarsResult<Column> {
fn _check_same_length(s: &[Column], fn_name: &str) -> Result<(), PolarsError> {
polars_ensure!(
_ensure_lengths(s),
ComputeError: "all series in `str_slice` should have equal or unit length",
ComputeError: "all series in `str.{}()` should have equal or unit length",
fn_name
);
Ok(())
}

pub(super) fn str_slice(s: &[Column]) -> PolarsResult<Column> {
_check_same_length(s, "slice")?;
let ca = s[0].str()?;
let offset = &s[1];
let length = &s[2];
Ok(ca.str_slice(offset, length)?.into_column())
}

pub(super) fn str_head(s: &[Column]) -> PolarsResult<Column> {
polars_ensure!(
_ensure_lengths(s),
ComputeError: "all series in `str_head` should have equal or unit length",
);
_check_same_length(s, "head")?;
let ca = s[0].str()?;
let n = &s[1];
Ok(ca.str_head(n)?.into_column())
}

pub(super) fn str_tail(s: &[Column]) -> PolarsResult<Column> {
polars_ensure!(
_ensure_lengths(s),
ComputeError: "all series in `str_tail` should have equal or unit length",
);
_check_same_length(s, "tail")?;
let ca = s[0].str()?;
let n = &s[1];
Ok(ca.str_tail(n)?.into_column())
Expand Down Expand Up @@ -1092,6 +1103,7 @@ pub(super) fn json_decode(

#[cfg(feature = "extract_jsonpath")]
pub(super) fn json_path_match(s: &[Column]) -> PolarsResult<Column> {
_check_same_length(s, "json_path_match")?;
let ca = s[0].str()?;
let pat = s[1].str()?;
Ok(ca.json_path_match(pat)?.into_column())
Expand Down
9 changes: 9 additions & 0 deletions py-polars/tests/unit/operations/namespaces/string/test_pad.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from __future__ import annotations

import pytest

import polars as pl
from polars.exceptions import ComputeError
from polars.testing import assert_frame_equal


Expand Down Expand Up @@ -88,6 +91,12 @@ def test_str_zfill_expr() -> None:
assert_frame_equal(out, expected)


def test_str_zfill_wrong_length() -> None:
df = pl.DataFrame({"num": ["-10", "-1", "0"]})
with pytest.raises(ComputeError, match="should have equal or unit length"):
df.select(pl.col("num").str.zfill(pl.Series([1, 2])))


def test_pad_end_unicode() -> None:
lf = pl.LazyFrame({"a": ["Café", "345", "東京", None]})

Expand Down
72 changes: 72 additions & 0 deletions py-polars/tests/unit/operations/namespaces/string/test_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,12 @@ def test_str_slice_expr() -> None:
df.select(pl.col("a").str.slice(0, -1))


def test_str_slice_wrong_length() -> None:
df = pl.DataFrame({"num": ["-10", "-1", "0"]})
with pytest.raises(ComputeError, match="should have equal or unit length"):
df.select(pl.col("num").str.slice(pl.Series([1, 2])))


@pytest.mark.parametrize(
("input", "n", "output"),
[
Expand Down Expand Up @@ -115,6 +121,12 @@ def test_str_head_expr() -> None:
assert_frame_equal(out, expected)


def test_str_head_wrong_length() -> None:
df = pl.DataFrame({"num": ["-10", "-1", "0"]})
with pytest.raises(ComputeError, match="should have equal or unit length"):
df.select(pl.col("num").str.head(pl.Series([1, 2])))


@pytest.mark.parametrize(
("input", "n", "output"),
[
Expand Down Expand Up @@ -176,6 +188,12 @@ def test_str_tail_expr() -> None:
assert_frame_equal(out, expected)


def test_str_tail_wrong_length() -> None:
df = pl.DataFrame({"num": ["-10", "-1", "0"]})
with pytest.raises(ComputeError, match="should have equal or unit length"):
df.select(pl.col("num").str.tail(pl.Series([1, 2])))


def test_str_slice_multibyte() -> None:
ref = "你好世界"
s = pl.Series([ref])
Expand Down Expand Up @@ -212,6 +230,12 @@ def test_str_contains() -> None:
assert_series_equal(s.str.contains("mes"), expected)


def test_str_contains_wrong_length() -> None:
df = pl.DataFrame({"num": ["-10", "-1", "0"]})
with pytest.raises(ComputeError, match="should have equal or unit length"):
df.select(pl.col("num").str.contains(pl.Series(["a", "b"]))) # type: ignore [arg-type]


def test_count_match_literal() -> None:
s = pl.Series(["12 dbc 3xy", "cat\\w", "1zy3\\d\\d", None])
out = s.str.count_matches(r"\d", literal=True)
Expand Down Expand Up @@ -338,6 +362,12 @@ def test_str_find_escaped_chars() -> None:
)


def test_str_find_wrong_length() -> None:
df = pl.DataFrame({"num": ["-10", "-1", "0"]})
with pytest.raises(ComputeError, match="should have equal or unit length"):
df.select(pl.col("num").str.find(pl.Series(["a", "b"]))) # type: ignore [arg-type]


def test_hex_decode_return_dtype() -> None:
data = {"a": ["68656c6c6f", "776f726c64"]}
expr = pl.col("a").str.decode("hex")
Expand Down Expand Up @@ -515,6 +545,12 @@ def test_str_strip_chars() -> None:
assert_series_equal(s.str.strip_chars(" hwo"), expected)


def test_str_strip_chars_wrong_length() -> None:
df = pl.DataFrame({"num": ["-10", "-1", "0"]})
with pytest.raises(ComputeError, match="should have equal or unit length"):
df.select(pl.col("num").str.strip_chars(pl.Series(["a", "b"])))


def test_str_strip_chars_start() -> None:
s = pl.Series([" hello ", "\t world"])
expected = pl.Series(["hello ", "world"])
Expand All @@ -527,6 +563,12 @@ def test_str_strip_chars_start() -> None:
assert_series_equal(s.str.strip_chars_start("hw "), expected)


def test_str_strip_chars_start_wrong_length() -> None:
df = pl.DataFrame({"num": ["-10", "-1", "0"]})
with pytest.raises(ComputeError, match="should have equal or unit length"):
df.select(pl.col("num").str.strip_chars_start(pl.Series(["a", "b"])))


def test_str_strip_chars_end() -> None:
s = pl.Series([" hello ", "world\t "])
expected = pl.Series([" hello", "world"])
Expand All @@ -539,6 +581,12 @@ def test_str_strip_chars_end() -> None:
assert_series_equal(s.str.strip_chars_end("odl \t"), expected)


def test_str_strip_chars_end_wrong_length() -> None:
df = pl.DataFrame({"num": ["-10", "-1", "0"]})
with pytest.raises(ComputeError, match="should have equal or unit length"):
df.select(pl.col("num").str.strip_chars_end(pl.Series(["a", "b"])))


def test_str_strip_whitespace() -> None:
s = pl.Series("a", ["trailing ", " leading", " both "])

Expand Down Expand Up @@ -579,6 +627,12 @@ def test_str_strip_prefix_suffix_expr() -> None:
}


def test_str_strip_prefix_wrong_length() -> None:
df = pl.DataFrame({"num": ["-10", "-1", "0"]})
with pytest.raises(ComputeError, match="should have equal or unit length"):
df.select(pl.col("num").str.strip_prefix(pl.Series(["a", "b"])))


def test_str_strip_suffix() -> None:
s = pl.Series(["foo:bar", "foo:barbar", "foo:foo", "bar", "", None])
expected = pl.Series(["foo:", "foo:bar", "foo:foo", "", "", None])
Expand All @@ -588,6 +642,12 @@ def test_str_strip_suffix() -> None:
assert_series_equal(s.str.strip_suffix(pl.lit(None, dtype=pl.String)), expected)


def test_str_strip_suffix_wrong_length() -> None:
df = pl.DataFrame({"num": ["-10", "-1", "0"]})
with pytest.raises(ComputeError, match="should have equal or unit length"):
df.select(pl.col("num").str.strip_suffix(pl.Series(["a", "b"])))


def test_str_split() -> None:
a = pl.Series("a", ["a, b", "a", "ab,c,de"])
for out in [a.str.split(","), pl.select(pl.lit(a).str.split(",")).to_series()]:
Expand Down Expand Up @@ -730,6 +790,12 @@ def test_json_path_match() -> None:
assert_frame_equal(out, expected)


def test_str_json_path_match_wrong_length() -> None:
df = pl.DataFrame({"num": ["-10", "-1", "0"]})
with pytest.raises(ComputeError, match="should have equal or unit length"):
df.select(pl.col("num").str.json_path_match(pl.Series(["a", "b"])))


def test_extract_regex() -> None:
s = pl.Series(
[
Expand Down Expand Up @@ -1799,6 +1865,12 @@ def test_extract_many() -> None:
assert f2.to_list() == [[0], [0, 5]]


def test_str_extract_many_wrong_length() -> None:
df = pl.DataFrame({"num": ["-10", "-1", "0"]})
with pytest.raises(ComputeError, match="should have equal or unit length"):
df.select(pl.col("num").str.extract_many(pl.Series(["a", "b"])))


def test_json_decode_raise_on_data_type_mismatch_13061() -> None:
assert_series_equal(
pl.Series(["null", "null"]).str.json_decode(infer_schema_length=1),
Expand Down

0 comments on commit 099ee3c

Please sign in to comment.