diff --git a/src/pydiverse/transform/_internal/pipe/table.py b/src/pydiverse/transform/_internal/pipe/table.py index 03e0855..0927a4d 100644 --- a/src/pydiverse/transform/_internal/pipe/table.py +++ b/src/pydiverse/transform/_internal/pipe/table.py @@ -194,7 +194,7 @@ def update(self, vb: Verb, rcache: Cache | None = None): self._update(new_select=vb.select) elif isinstance(vb, Rename): - self.update( + self._update( new_cols={ (new_name if (new_name := vb.name_map.get(name)) else name): col for name, col in self.cols.items() diff --git a/src/pydiverse/transform/_internal/pipe/verbs.py b/src/pydiverse/transform/_internal/pipe/verbs.py index 25b15cc..2aa4853 100644 --- a/src/pydiverse/transform/_internal/pipe/verbs.py +++ b/src/pydiverse/transform/_internal/pipe/verbs.py @@ -276,11 +276,11 @@ def rename(table: Table, name_map: dict[str, str]) -> Pipeable: if d := set(name_map) - set(table._cache.cols): raise ValueError( - f"no column with name `{next(d)}` in table `{table._ast.name}`" + f"no column with name `{next(iter(d))}` in table `{table._ast.name}`" ) - if d := (set(table._cache.cols) - set(name_map)) | set(name_map.values()): - raise ValueError(f"duplicate column name `{next(d)}`") + if d := (set(table._cache.cols) - set(name_map)) & set(name_map.values()): + raise ValueError(f"duplicate column name `{next(iter(d))}`") new._cache.update(new._ast) return new @@ -292,6 +292,8 @@ def mutate(**kwargs: ColExpr) -> Pipeable: ... @verb def mutate(table: Table, **kwargs: ColExpr) -> Pipeable: + if len(kwargs) == 0: + return table return table >> _mutate(*map(list, zip(*kwargs.items(), strict=True))) @@ -302,9 +304,6 @@ def _mutate( values: list[ColExpr], uuids: list[uuid.UUID] | None = None, ) -> Pipeable: - if len(names) == 0: - return table - new = copy.copy(table) if uuids is None: uuids = [uuid.uuid1() for _ in names] diff --git a/tests/test_core.py b/tests/test_core.py index 367a487..f7b4d52 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -5,7 +5,7 @@ from pydiverse.transform import C, Table from pydiverse.transform._internal.pipe.pipeable import Pipeable, inverse_partial, verb -from pydiverse.transform._internal.pipe.verbs import join, mutate, select +from pydiverse.transform._internal.pipe.verbs import mutate, select @pytest.fixture @@ -27,11 +27,8 @@ def test_getattr(self, tbl1): _ = tbl1.colXXX def test_getitem(self, tbl1): - assert tbl1.col1 is tbl1["col1"] - assert tbl1.col2 is tbl1["col2"] - - assert tbl1.col2 is tbl1[tbl1.col2.name] - assert tbl1.col2 is tbl1[C.col2.name] + assert tbl1.col1._uuid == tbl1["col1"]._uuid + assert tbl1.col2._uuid is tbl1["col2"]._uuid def test_iter(self, tbl1, tbl2): assert repr(list(tbl1)) == repr([tbl1.col1, tbl1.col2]) @@ -39,11 +36,6 @@ def test_iter(self, tbl1, tbl2): assert repr(list(tbl2 >> select(tbl2.col2))) == repr([tbl2.col2]) - joined = tbl1 >> join( - tbl2 >> select(tbl2.col3), tbl1.col1 == tbl2.col2, "left", suffix="_2" - ) - assert repr(list(joined)) == repr([tbl1.col1, tbl1.col2, joined.col3_2]) - def test_dir(self, tbl1): assert dir(tbl1) == ["col1", "col2"] assert dir(tbl1 >> mutate(x=tbl1.col1)) == ["col1", "col2", "x"]