Skip to content

Commit

Permalink
Fix mistakes in column export
Browse files Browse the repository at this point in the history
  • Loading branch information
finn-rudolph committed Nov 30, 2024
1 parent 5ffc955 commit 9d6b946
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 18 deletions.
2 changes: 1 addition & 1 deletion src/pydiverse/transform/_internal/pipe/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def update(self, vb: Verb, rcache: Cache | None = None):
self._update(new_select=vb.select)

elif isinstance(vb, Rename):
self.update(
self._update(
new_cols={
(new_name if (new_name := vb.name_map.get(name)) else name): col
for name, col in self.cols.items()
Expand Down
11 changes: 5 additions & 6 deletions src/pydiverse/transform/_internal/pipe/verbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,11 +276,11 @@ def rename(table: Table, name_map: dict[str, str]) -> Pipeable:

if d := set(name_map) - set(table._cache.cols):
raise ValueError(
f"no column with name `{next(d)}` in table `{table._ast.name}`"
f"no column with name `{next(iter(d))}` in table `{table._ast.name}`"
)

if d := (set(table._cache.cols) - set(name_map)) | set(name_map.values()):
raise ValueError(f"duplicate column name `{next(d)}`")
if d := (set(table._cache.cols) - set(name_map)) & set(name_map.values()):
raise ValueError(f"duplicate column name `{next(iter(d))}`")

new._cache.update(new._ast)
return new
Expand All @@ -292,6 +292,8 @@ def mutate(**kwargs: ColExpr) -> Pipeable: ...

@verb
def mutate(table: Table, **kwargs: ColExpr) -> Pipeable:
if len(kwargs) == 0:
return table
return table >> _mutate(*map(list, zip(*kwargs.items(), strict=True)))


Expand All @@ -302,9 +304,6 @@ def _mutate(
values: list[ColExpr],
uuids: list[uuid.UUID] | None = None,
) -> Pipeable:
if len(names) == 0:
return table

new = copy.copy(table)
if uuids is None:
uuids = [uuid.uuid1() for _ in names]
Expand Down
14 changes: 3 additions & 11 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from pydiverse.transform import C, Table
from pydiverse.transform._internal.pipe.pipeable import Pipeable, inverse_partial, verb
from pydiverse.transform._internal.pipe.verbs import join, mutate, select
from pydiverse.transform._internal.pipe.verbs import mutate, select


@pytest.fixture
Expand All @@ -27,23 +27,15 @@ def test_getattr(self, tbl1):
_ = tbl1.colXXX

def test_getitem(self, tbl1):
assert tbl1.col1 is tbl1["col1"]
assert tbl1.col2 is tbl1["col2"]

assert tbl1.col2 is tbl1[tbl1.col2.name]
assert tbl1.col2 is tbl1[C.col2.name]
assert tbl1.col1._uuid == tbl1["col1"]._uuid
assert tbl1.col2._uuid is tbl1["col2"]._uuid

def test_iter(self, tbl1, tbl2):
assert repr(list(tbl1)) == repr([tbl1.col1, tbl1.col2])
assert repr(list(tbl2)) == repr([tbl2.col1, tbl2.col2, tbl2.col3])

assert repr(list(tbl2 >> select(tbl2.col2))) == repr([tbl2.col2])

joined = tbl1 >> join(
tbl2 >> select(tbl2.col3), tbl1.col1 == tbl2.col2, "left", suffix="_2"
)
assert repr(list(joined)) == repr([tbl1.col1, tbl1.col2, joined.col3_2])

def test_dir(self, tbl1):
assert dir(tbl1) == ["col1", "col2"]
assert dir(tbl1 >> mutate(x=tbl1.col1)) == ["col1", "col2", "x"]
Expand Down

0 comments on commit 9d6b946

Please sign in to comment.