Skip to content

Commit

Permalink
Adding DataChain.column(...) and fixing functions and types (#226)
Browse files Browse the repository at this point in the history
* fixing sql to python

* added tests for sql to python and changed input type of sql to python

* changing docstring

* fixing tests and Decimal type conversion

* returning exception when column is not found

* changed docstring

* fixed typo

* added new exception type

* renaming error class

* skipping division expression tests for CH

* using new column method from dc

* updating studio branch

* return to develop
  • Loading branch information
ilongin authored Aug 6, 2024
1 parent d47aee3 commit bc7608b
Show file tree
Hide file tree
Showing 7 changed files with 175 additions and 26 deletions.
31 changes: 13 additions & 18 deletions src/datachain/lib/convert/sql_to_python.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,18 @@
from datetime import datetime
from decimal import Decimal
from typing import Any

from sqlalchemy import ARRAY, JSON, Boolean, DateTime, Float, Integer, String
from sqlalchemy import ColumnElement

from datachain.data_storage.sqlite import Column

SQL_TO_PYTHON = {
String: str,
Integer: int,
Float: float,
Boolean: bool,
DateTime: datetime,
ARRAY: list,
JSON: dict,
}
def sql_to_python(args_map: dict[str, ColumnElement]) -> dict[str, Any]:
res = {}
for name, sql_exp in args_map.items():
try:
type_ = sql_exp.type.python_type
if type_ == Decimal:
type_ = float
except NotImplementedError:
type_ = str
res[name] = type_


def sql_to_python(args_map: dict[str, Column]) -> dict[str, Any]:
return {
k: SQL_TO_PYTHON.get(type(v.type), str) # type: ignore[union-attr]
for k, v in args_map.items()
}
return res
24 changes: 24 additions & 0 deletions src/datachain/lib/dc.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
import sqlalchemy
from pydantic import BaseModel, create_model
from sqlalchemy.sql.functions import GenericFunction
from sqlalchemy.sql.sqltypes import NullType

from datachain import DataModel
from datachain.lib.convert.python_to_sql import python_to_sql
from datachain.lib.convert.values_to_tuples import values_to_tuples
from datachain.lib.data_model import DataType
from datachain.lib.dataset_info import DatasetInfo
Expand Down Expand Up @@ -110,6 +112,11 @@ def __init__(self, on: Sequence[str], right_on: Optional[Sequence[str]], msg: st
super().__init__(f"Merge error on='{on_str}'{right_on_str}: {msg}")


class DataChainColumnError(DataChainParamsError): # noqa: D101
def __init__(self, col_name, msg): # noqa: D107
super().__init__(f"Error for column {col_name}: {msg}")


OutputType = Union[None, DataType, Sequence[str], dict[str, DataType]]


Expand Down Expand Up @@ -225,6 +232,17 @@ def schema(self) -> dict[str, DataType]:
"""Get schema of the chain."""
return self._effective_signals_schema.values

def column(self, name: str) -> Column:
"""Returns Column instance with a type if name is found in current schema,
otherwise raises an exception.
"""
name_path = name.split(".")
for path, type_, _, _ in self.signals_schema.get_flat_tree():
if path == name_path:
return Column(name, python_to_sql(type_))

raise ValueError(f"Column with name {name} not found in the schema")

def print_schema(self) -> None:
"""Print schema of the chain."""
self._effective_signals_schema.print_tree()
Expand Down Expand Up @@ -829,6 +847,12 @@ def mutate(self, **kwargs) -> "Self":
)
```
"""
for col_name, expr in kwargs.items():
if not isinstance(expr, Column) and isinstance(expr.type, NullType):
raise DataChainColumnError(
col_name, f"Cannot infer type with expression {expr}"
)

mutated = {}
schema = self.signals_schema
for name, value in kwargs.items():
Expand Down
5 changes: 3 additions & 2 deletions src/datachain/sql/functions/__init__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
from sqlalchemy.sql.expression import func

from . import path, string
from . import array, path, string
from .array import avg
from .conditional import greatest, least
from .random import rand

count = func.count
sum = func.sum
avg = func.avg
min = func.min
max = func.max

__all__ = [
"array",
"avg",
"count",
"func",
Expand Down
8 changes: 8 additions & 0 deletions src/datachain/sql/functions/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,15 @@ class sip_hash_64(GenericFunction): # noqa: N801
inherit_cache = True


class avg(GenericFunction): # noqa: N801
type = Float()
package = "array"
name = "avg"
inherit_cache = True


compiler_not_implemented(cosine_distance)
compiler_not_implemented(euclidean_distance)
compiler_not_implemented(length)
compiler_not_implemented(sip_hash_64)
compiler_not_implemented(avg)
5 changes: 5 additions & 0 deletions src/datachain/sql/sqlite/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def setup():
compiles(conditional.least, "sqlite")(compile_least)
compiles(Values, "sqlite")(compile_values)
compiles(random.rand, "sqlite")(compile_rand)
compiles(array.avg, "sqlite")(compile_avg)

if load_usearch_extension(sqlite3.connect(":memory:")):
compiles(array.cosine_distance, "sqlite")(compile_cosine_distance_ext)
Expand Down Expand Up @@ -349,6 +350,10 @@ def compile_rand(element, compiler, **kwargs):
return compiler.process(func.random(), **kwargs)


def compile_avg(element, compiler, **kwargs):
return compiler.process(func.avg(*element.clauses.clauses), **kwargs)


def load_usearch_extension(conn) -> bool:
try:
# usearch is part of the vector optional dependencies
Expand Down
100 changes: 94 additions & 6 deletions tests/unit/lib/test_datachain.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from datachain import Column
from datachain.lib.data_model import DataModel
from datachain.lib.dc import C, DataChain, Sys
from datachain.lib.dc import C, DataChain, DataChainColumnError, Sys
from datachain.lib.file import File
from datachain.lib.signal_schema import (
SignalResolvingError,
Expand All @@ -19,6 +19,8 @@
)
from datachain.lib.udf_signature import UdfSignatureError
from datachain.lib.utils import DataChainParamsError
from datachain.sql import functions as func
from datachain.sql.types import Float, Int64, String
from tests.utils import skip_if_not_sqlite

DF_DATA = {
Expand Down Expand Up @@ -1254,14 +1256,20 @@ def test_column_math(test_session):
fib = [1, 1, 2, 3, 5, 8]
chain = DataChain.from_values(num=fib, session=test_session)

ch = chain.mutate(add2=Column("num") + 2)
ch = chain.mutate(add2=chain.column("num") + 2)
assert list(ch.collect("add2")) == [x + 2 for x in fib]

ch = chain.mutate(div2=Column("num") / 2.0)
assert list(ch.collect("div2")) == [x / 2.0 for x in fib]
ch2 = ch.mutate(x=1 - ch.column("add2"))
assert list(ch2.collect("x")) == [1 - (x + 2.0) for x in fib]


def test_column_math_division(test_session):
skip_if_not_sqlite()
fib = [1, 1, 2, 3, 5, 8]
chain = DataChain.from_values(num=fib, session=test_session)

ch2 = ch.mutate(x=1 - Column("div2"))
assert list(ch2.collect("x")) == [1 - (x / 2.0) for x in fib]
ch = chain.mutate(div2=chain.column("num") / 2.0)
assert list(ch.collect("div2")) == [x / 2.0 for x in fib]


def test_from_values_array_of_floats(test_session):
Expand Down Expand Up @@ -1409,3 +1417,83 @@ def test_rename_object_name_with_mutate(catalog):
assert ds.signals_schema.values.get("ids") is int
assert "file" not in ds.signals_schema.values
assert list(ds.order_by("my_file.name").collect("my_file.name")) == ["a", "b", "c"]


def test_column(catalog):
ds = DataChain.from_values(
ints=[1, 2], floats=[0.5, 0.5], file=[File(name="a"), File(name="b")]
)

c = ds.column("ints")
assert isinstance(c, Column)
assert c.name == "ints"
assert isinstance(c.type, Int64)

c = ds.column("floats")
assert isinstance(c, Column)
assert c.name == "floats"
assert isinstance(c.type, Float)

c = ds.column("file.name")
assert isinstance(c, Column)
assert c.name == "file__name"
assert isinstance(c.type, String)

with pytest.raises(ValueError):
c = ds.column("missing")


def test_mutate_with_subtraction():
ds = DataChain.from_values(id=[1, 2])
assert ds.mutate(new=ds.column("id") - 1).signals_schema.values["new"] is int


def test_mutate_with_addition():
ds = DataChain.from_values(id=[1, 2])
assert ds.mutate(new=ds.column("id") + 1).signals_schema.values["new"] is int


def test_mutate_with_division():
ds = DataChain.from_values(id=[1, 2])
assert ds.mutate(new=ds.column("id") / 10).signals_schema.values["new"] is float


def test_mutate_with_multiplication():
ds = DataChain.from_values(id=[1, 2])
assert ds.mutate(new=ds.column("id") * 10).signals_schema.values["new"] is int


def test_mutate_with_func():
ds = DataChain.from_values(id=[1, 2])
assert (
ds.mutate(new=func.avg(ds.column("id"))).signals_schema.values["new"] is float
)


def test_mutate_with_complex_expression():
ds = DataChain.from_values(id=[1, 2], name=["Jim", "Jon"])
assert (
ds.mutate(
new=(func.sum(ds.column("id"))) * (5 - func.min(ds.column("id")))
).signals_schema.values["new"]
is int
)


def test_mutate_with_saving():
skip_if_not_sqlite()
ds = DataChain.from_values(id=[1, 2])
ds = ds.mutate(new=ds.column("id") / 2).save("mutated")

ds = DataChain(name="mutated")
assert ds.signals_schema.values["new"] is float
assert list(ds.collect("new")) == [0.5, 1.0]


def test_mutate_with_expression_without_type(catalog):
with pytest.raises(DataChainColumnError) as excinfo:
DataChain.from_values(id=[1, 2]).mutate(new=(Column("id") - 1)).save()

assert str(excinfo.value) == (
"Error for column new: Cannot infer type with expression id - :id_1"
)
28 changes: 28 additions & 0 deletions tests/unit/lib/test_sql_to_python.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from sqlalchemy.sql.sqltypes import NullType

from datachain import Column
from datachain.lib.convert.sql_to_python import sql_to_python
from datachain.sql import functions as func
from datachain.sql.types import Float, Int64, String


def test_sql_columns_to_python_types():
assert sql_to_python(
{
"name": Column("name", String),
"age": Column("age", Int64),
"score": Column("score", Float),
}
) == {"name": str, "age": int, "score": float}


def test_sql_expression_to_python_types():
assert sql_to_python({"age": Column("age", Int64) - 2}) == {"age": int}


def test_sql_function_to_python_types():
assert sql_to_python({"age": func.avg(Column("age", Int64))}) == {"age": float}


def test_sql_to_python_types_default_type():
assert sql_to_python({"null": Column("null", NullType)}) == {"null": str}

0 comments on commit bc7608b

Please sign in to comment.