-
Notifications
You must be signed in to change notification settings - Fork 120
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
SNOW-1491199 add ast-encoding test in precommit
- Loading branch information
1 parent
8383698
commit 95fd33e
Showing
5 changed files
with
268 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
## AST Tests | ||
|
||
This driver enables testing of the AST generation that will be used in the server-side Snowpark implementation, starting with Phase 0. | ||
|
||
All generated AST should be tested using this mechanism. To add a test, create a new file under `tests/ast/data`. Files look like the following example. The test driver sets up the session and looks at the accumulated lazy values in the resulting environment. | ||
|
||
N.B. No eager evaluation is permitted, as any intermediate batches will not be observed. This can easily be changed if necessary, however. | ||
|
||
```python | ||
## TEST CASE | ||
|
||
df = session.table(tables.table1) | ||
df = df.filter("STR LIKE '%e%'") | ||
|
||
## EXPECTED ENCODED AST | ||
|
||
[...] | ||
|
||
## EXPECTED UNPARSER OUTPUT | ||
|
||
res1 = session.table('table1') | ||
|
||
res2 = res1.filter('STR LIKE '%e%'') | ||
``` | ||
|
||
To generate the expected output the first time the test is run, or when the AST generation changes, run: | ||
```bash | ||
pytest --update-expectations tests/ast | ||
``` | ||
|
||
For these tests to work, the Unparser must be built in the monorepo: | ||
```bash | ||
cd my-monorepo-path | ||
cd Snowflake/unparser | ||
sbt assembly | ||
``` | ||
|
||
The location of the Unparser can be set either via the environment variable `SNOWPARK_UNPARSER_JAR` or via the _pytest_ commandline argument `--unparser-jar=<path>`. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,158 @@ | ||
# | ||
# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. | ||
# | ||
import logging | ||
import os | ||
from functools import cached_property | ||
|
||
import pytest | ||
|
||
from snowflake.snowpark import Session | ||
|
||
|
||
def default_unparser_path(): | ||
explicit = os.getenv("SNOWPARK_UNPARSER_JAR") | ||
default_default = f"{os.getenv('HOME')}/Snowflake/trunk/Snowpark/unparser/target/scala-2.13/unparser-assembly-0.1.jar" | ||
return explicit or default_default | ||
|
||
|
||
def pytest_addoption(parser): | ||
parser.addoption( | ||
"--unparser-jar", | ||
action="store", | ||
default=default_unparser_path(), | ||
type=str, | ||
help="Path to the Unparser JAR built in the monorepo. To build it, run `sbt assembly` from the unparser directory.", | ||
) | ||
parser.addoption( | ||
"--update-expectations", | ||
action="store_true", | ||
default=False, | ||
help="If set, overwrite test files with the actual output as the expected output.", | ||
) | ||
|
||
|
||
def pytest_configure(config): | ||
pytest.unparser_jar = config.getoption("--unparser-jar") | ||
if not os.path.exists(pytest.unparser_jar): | ||
pytest.unparser_jar = None | ||
logging.error( | ||
f"Unparser JAR not found at {pytest.unparser_jar}. " | ||
f"Please set the correct path with --unparser-jar or SNOWPARK_UNPARSER_JAR." | ||
) | ||
pytest.update_expectations = config.getoption("--update-expectations") | ||
|
||
|
||
class TestTables: | ||
def __init__(self, session) -> None: | ||
self._session = session | ||
|
||
@cached_property | ||
def table1(self) -> str: | ||
table_name: str = "table1" | ||
return self._save_table( | ||
table_name, | ||
[ | ||
[1, "one"], | ||
[2, "two"], | ||
[3, "three"], | ||
], | ||
schema=["num", "str"], | ||
) | ||
|
||
@cached_property | ||
def table2(self) -> str: | ||
table_name: str = "table2" | ||
return self._save_table( | ||
table_name, | ||
[ | ||
[1, [1, 2, 3], {"Ashi Garami": "Single Leg X"}, "Kimura"], | ||
[2, [11, 22], {"Sankaku": "Triangle"}, "Coffee"], | ||
[3, [], {}, "Tea"], | ||
], | ||
schema=["idx", "lists", "maps", "strs"], | ||
) | ||
|
||
@cached_property | ||
def df1_table(self) -> str: | ||
table_name: str = "df1" | ||
return self._save_table( | ||
table_name, | ||
[ | ||
[1, 2], | ||
[3, 4], | ||
], | ||
schema=["a", "b"], | ||
) | ||
|
||
@cached_property | ||
def df2_table(self) -> str: | ||
table_name: str = "df2" | ||
return self._save_table( | ||
table_name, | ||
[ | ||
[0, 1], | ||
[3, 4], | ||
], | ||
schema=["c", "d"], | ||
) | ||
|
||
@cached_property | ||
def df3_table(self) -> str: | ||
table_name: str = "df3" | ||
return self._save_table( | ||
table_name, | ||
[ | ||
[1, 2], | ||
], | ||
schema=["a", "b"], | ||
) | ||
|
||
@cached_property | ||
def df4_table(self) -> str: | ||
table_name: str = "df4" | ||
return self._save_table( | ||
table_name, | ||
[ | ||
[2, 1], | ||
], | ||
schema=["b", "a"], | ||
) | ||
|
||
@cached_property | ||
def double_quoted_table(self) -> str: | ||
table_name: str = '"the#qui.ck#bro.wn#""Fox""won\'t#jump!"' | ||
return self._save_table( | ||
table_name, | ||
[ | ||
[1, "one"], | ||
[2, "two"], | ||
[3, "three"], | ||
], | ||
schema=["num", 'Owner\'s""opinion.s'], | ||
) | ||
|
||
def _save_table(self, name: str, *args, **kwargs): | ||
kwargs.pop("_emit_ast", None) | ||
kwargs.pop("_ast_stmt", None) | ||
kwargs.pop("_ast", None) | ||
df = self._session.create_dataframe(*args, _emit_ast=False, **kwargs) | ||
logging.debug("Creating table %s", name) | ||
df.write.save_as_table(name, _emit_ast=False) | ||
return name | ||
|
||
|
||
# For test performance (especially integration tests), it would be very valuable to create the Snowpark session and the | ||
# temporary tables only once per test session. Unfortunately, the local testing features don't work well with any scope | ||
# setting above "function" (e.g. "module" or "session"). | ||
# TODO: SNOW-1763053 use scope="module" | ||
@pytest.fixture(scope="function") | ||
def session(): | ||
with Session.builder.config("local_testing", True).create() as s: | ||
s.ast_enabled = True | ||
yield s | ||
|
||
|
||
@pytest.fixture(scope="function") | ||
def tables(session): | ||
return TestTables(session) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
# | ||
# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. | ||
# | ||
|
||
from snowflake.snowpark.mock._connection import MockServerConnection | ||
|
||
|
||
def test_session(session): | ||
assert session.ast_enabled | ||
assert isinstance(session._conn, MockServerConnection) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters