Skip to content

Commit

Permalink
Unify Snowflake object name handling in the Snowpark AST (#2789)
Browse files Browse the repository at this point in the history
<!---
Please answer these questions before creating your pull request. Thanks!
--->

1. Which Jira issue is this PR addressing? Make sure that there is an
accompanying issue to your PR.

   <!---
   In this section, please add a Snowflake Jira issue number.

Note that if a corresponding GitHub issue exists, you should still
include
   the Snowflake Jira issue number. For example, for GitHub issue
#1400, you should
   add "SNOW-1335071" here.
    --->

   Fixes SNOW-1621205

2. Fill out the following pre-review checklist:

- [ ] I am adding a new automated test(s) to verify correctness of my
new code
- [ ] If this test skips Local Testing mode, I'm requesting review from
@snowflakedb/local-testing
   - [ ] I am adding new logging messages
   - [ ] I am adding a new telemetry message
   - [ ] I am adding new credentials
   - [ ] I am adding a new dependency
- [ ] If this is a new feature/behavior, I'm adding the Local Testing
parity changes.
- [x] I acknowledge that I have ensured my changes to be thread-safe.
Follow the link for more information: [Thread-safe Developer
Guidelines](https://github.com/snowflakedb/snowpark-python/blob/main/CONTRIBUTING.md#thread-safe-development)

3. Please describe how your code solves the related issue.

This is the client-side change corresponding to
snowflakedb/snowflake#240557

Unify Snowflake object name handling in the Snowpark AST.
Remove `FnName` and `SpTableName`. They both had `Flat` and `Structured`
variants, but ultimately designate Snowflake object names.
Introduce `data SpName` and `entity SpNameRef` for referring to
Snowflake objects by relative or fully qualified name.
Update `FnNameRefExpr`.
Use `SpNameRef` in a few places that used to use `List[String]`.
  • Loading branch information
sfc-gh-oplaton authored Jan 6, 2025
1 parent 2247ae5 commit 3a8612f
Show file tree
Hide file tree
Showing 118 changed files with 42,605 additions and 54,783 deletions.
54 changes: 33 additions & 21 deletions src/snowflake/snowpark/_internal/ast/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,17 @@ def build_proto_from_struct_type(
ast_field.nullable = field.nullable


def build_sp_name(name: Union[str, Iterable[str]], expr: proto.SpName) -> None:
if isinstance(name, str):
expr.sp_name_flat.name = name
elif isinstance(name, Iterable):
expr.sp_name_structured.name.extend(name)
else:
raise ValueError(
f"Invalid object name: {name}. The object name must be a string or an iterable of strings."
)


# TODO(SNOW-1491199) - This method is not covered by tests until the end of phase 0. Drop the pragma when it is covered.
def _set_fn_name(
name: Union[str, Iterable[str]], fn: proto.FnNameRefExpr
Expand All @@ -358,26 +369,27 @@ def _set_fn_name(
Raises:
ValueError: Raised if the function name is not a string or an iterable of strings.
"""
if isinstance(name, str):
fn.name.fn_name_flat.name = name # type: ignore[attr-defined] # TODO(SNOW-1491199) # "FnNameRefExpr" has no attribute "name"
elif isinstance(name, Iterable):
fn.name.fn_name_structured.name.extend(name) # type: ignore[attr-defined] # TODO(SNOW-1491199) # "FnNameRefExpr" has no attribute "name"
else:
raise ValueError(
f"Invalid function name: {name}. The function name must be a string or an iterable of strings."
)
try:
build_sp_name(name, fn.name.name)
except ValueError as e:
raise ValueError("Invalid function name") from e


# TODO(SNOW-1491199) - This method is not covered by tests until the end of phase 0. Drop the pragma when it is covered.
def build_sp_table_name( # type: ignore[no-untyped-def] # TODO(SNOW-1491199) # Function is missing a return type annotation
expr_builder: proto.SpTableName, name: Union[str, Iterable[str]]
): # pragma: no cover
if isinstance(name, str):
expr_builder.sp_table_name_flat.name = name
elif isinstance(name, Iterable):
expr_builder.sp_table_name_structured.name.extend(name)
else:
raise ValueError(f"Invalid name type {type(name)} for SpTableName entity.")
def build_sp_table_name(
expr_builder: proto.SpNameRef, name: Union[str, Iterable[str]]
) -> None: # pragma: no cover
try:
build_sp_name(name, expr_builder.name)
except ValueError as e:
raise ValueError("Invalid table name") from e


def build_sp_view_name(expr: proto.SpNameRef, name: Union[str, Iterable[str]]) -> None:
try:
build_sp_name(name, expr.name)
except ValueError as e:
raise ValueError("Invalid view name") from e


def build_function_expr(
Expand Down Expand Up @@ -1108,7 +1120,7 @@ def build_udf( # type: ignore[no-untyped-def] # TODO(SNOW-1491199) # Function i
ast.stage_location = stage_location
if imports is not None and len(imports) != 0:
for import_ in imports:
import_expr = proto.SpTableName()
import_expr = proto.SpNameRef()
build_sp_table_name(import_expr, import_)
ast.imports.append(import_expr)
if packages is not None and len(packages) != 0:
Expand Down Expand Up @@ -1197,7 +1209,7 @@ def build_udaf( # type: ignore[no-untyped-def] # TODO(SNOW-1491199) # Function
ast.stage_location.value = stage_location
if imports is not None and len(imports) != 0:
for import_ in imports:
import_expr = proto.SpTableName()
import_expr = proto.SpNameRef()
build_sp_table_name(import_expr, import_)
ast.imports.append(import_expr)
if packages is not None and len(packages) != 0:
Expand Down Expand Up @@ -1294,7 +1306,7 @@ def build_udtf( # type: ignore[no-untyped-def] # TODO(SNOW-1491199) # Function
ast.stage_location = stage_location
if imports is not None and len(imports) != 0:
for import_ in imports:
import_expr = proto.SpTableName()
import_expr = proto.SpNameRef()
build_sp_table_name(import_expr, import_)
ast.imports.append(import_expr)
if packages is not None and len(packages) != 0:
Expand Down Expand Up @@ -1406,7 +1418,7 @@ def build_sproc( # type: ignore[no-untyped-def] # TODO(SNOW-1491199) # Function
ast.stage_location = stage_location
if imports is not None and len(imports) != 0:
for import_ in imports:
import_expr = proto.SpTableName()
import_expr = proto.SpNameRef()
build_sp_table_name(import_expr, import_)
ast.imports.append(import_expr)
if packages is not None and len(packages) != 0:
Expand Down
Loading

0 comments on commit 3a8612f

Please sign in to comment.