Aling/fix diamond join #2851

Draft · wants to merge 10 commits into main
39 changes: 34 additions & 5 deletions src/snowflake/snowpark/_internal/analyzer/analyzer.py
@@ -151,7 +151,10 @@
)
from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages
from snowflake.snowpark._internal.telemetry import TelemetryField
from snowflake.snowpark._internal.utils import quote_name
from snowflake.snowpark._internal.utils import (
quote_name,
merge_multiple_dicts_with_assertion,
)
from snowflake.snowpark.types import BooleanType, _NumericType

ARRAY_BIND_THRESHOLD = 512
@@ -165,8 +168,11 @@ def __init__(self, session: "snowflake.snowpark.session.Session") -> None:
self.session = session
self.plan_builder = SnowflakePlanBuilder(self.session)
self.generated_alias_maps = {}
        # key: (expr_id, snowflake_plan_uuid) -> value: alias
self.generated_alias_maps_v2 = {}
self.subquery_plans = []
self.alias_maps_to_use: Optional[Dict[uuid.UUID, str]] = None
self.alias_maps_to_use_v2: Optional[Dict[uuid.UUID, str]] = None

def analyze(
self,
@@ -366,9 +372,13 @@ def analyze(
return expr.sql

if isinstance(expr, Attribute):
assert self.alias_maps_to_use is not None
name = self.alias_maps_to_use.get(expr.expr_id, expr.name)
return quote_name(name)
# assert self.alias_maps_to_use is not None
# name = self.alias_maps_to_use.get(expr.expr_id, expr.name)
assert self.alias_maps_to_use_v2 is not None
name2 = self.alias_maps_to_use_v2.get(
(expr.expr_id, expr.snowflake_plan_uuid), expr.name
)
return quote_name(name2)

if isinstance(expr, UnresolvedAttribute):
if expr.df_alias:
@@ -630,6 +640,10 @@ def unary_expression_extractor(
quoted_name = quote_name(expr.name)
if isinstance(expr.child, Attribute):
self.generated_alias_maps[expr.child.expr_id] = quoted_name
# assert expr.child.snowflake_plan_uuid is not None
self.generated_alias_maps_v2[
(expr.child.expr_id, expr.child.snowflake_plan_uuid)
] = quoted_name
assert self.alias_maps_to_use is not None
for k, v in self.alias_maps_to_use.items():
if v == expr.child.name:
@@ -724,6 +738,7 @@ def window_frame_boundary(
boundary: Expression,
df_aliased_col_name_to_real_col_name: DefaultDict[str, Dict[str, str]],
) -> str:

# it means interval preceding
if isinstance(boundary, UnaryMinus) and isinstance(boundary.child, Interval):
return window_frame_boundary_expression(
@@ -772,10 +787,15 @@ def to_sql_try_avoid_cast(
def resolve(self, logical_plan: LogicalPlan) -> SnowflakePlan:
self.subquery_plans = []
self.generated_alias_maps = {}
self.generated_alias_maps_v2 = {}

result = self.do_resolve(logical_plan)

# result is a snowflake plan
result.add_aliases(self.generated_alias_maps)
result.add_aliases_v2(self.generated_alias_maps_v2)
# for k, v in self.generated_alias_maps_v2.items():
# new_dict = {k: (v, result.uuid)}
# result.add_aliases_v2(new_dict)

if self.subquery_plans:
result = result.with_subqueries(self.subquery_plans)
@@ -803,8 +823,11 @@ def do_resolve(self, logical_plan: LogicalPlan) -> SnowflakePlan:
if isinstance(logical_plan, Selectable):
# Selectable doesn't have children. It already has the expr_to_alias dict.
self.alias_maps_to_use = logical_plan.expr_to_alias.copy()
self.alias_maps_to_use_v2 = logical_plan.expr_to_alias_v2.copy()
else:

use_maps = {}
use_maps_v2 = {}
# get counts of expr_to_alias keys
counts = Counter()
for v in resolved_children.values():
@@ -821,6 +844,12 @@ def do_resolve(self, logical_plan: LogicalPlan) -> SnowflakePlan:

self.alias_maps_to_use = use_maps

use_maps_v2 = merge_multiple_dicts_with_assertion(
*[v.expr_to_alias_v2 for v in resolved_children.values()]
)

self.alias_maps_to_use_v2 = use_maps_v2

res = self.do_resolve_with_resolved_children(
logical_plan, resolved_children, df_aliased_col_name_to_real_col_name
)
23 changes: 21 additions & 2 deletions src/snowflake/snowpark/_internal/analyzer/expression.py
@@ -223,20 +223,39 @@ def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:


class Attribute(Expression, NamedExpression):
def __init__(self, name: str, datatype: DataType, nullable: bool = True) -> None:
def __init__(
self,
name: str,
datatype: DataType,
nullable: bool = True,
*,
snowflake_plan_uuid: str = None,
) -> None:
super().__init__()
self.name = name
self.datatype: DataType = datatype
self.nullable = nullable
# non-breaking way to add snowflake_plan_uuid
self.snowflake_plan_uuid = snowflake_plan_uuid

def with_name(self, new_name: str) -> "Attribute":
def with_name(
self, new_name: str, *, snowflake_plan_uuid: str = None
) -> "Attribute":
assert snowflake_plan_uuid
if self.name == new_name:
# # lazy update snowflake_plan_uuid
# if not self.snowflake_plan_uuid:
# self.snowflake_plan_uuid = snowflake_plan_uuid
# else:
# # one attribute can only belong to one snowflake plan
# assert self.snowflake_plan_uuid == snowflake_plan_uuid
return self
else:
return Attribute(
snowflake.snowpark._internal.utils.quote_name(new_name),
self.datatype,
self.nullable,
snowflake_plan_uuid=snowflake_plan_uuid,
)
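    # Illustration (not part of this diff): with_name now takes a keyword-only
    # snowflake_plan_uuid and asserts that it is provided. With illustrative values
    # (StringType comes from snowflake.snowpark.types; plan.uuid is assumed):
    #   attr = Attribute('"A"', StringType())
    #   renamed = attr.with_name("B", snowflake_plan_uuid=plan.uuid)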

@property
src/snowflake/snowpark/_internal/analyzer/select_statement.py
@@ -236,6 +236,7 @@ def __init__(
self._column_states: Optional[ColumnStateDict] = None
self._snowflake_plan: Optional[SnowflakePlan] = None
self.expr_to_alias = {}
self.expr_to_alias_v2 = {}
self.df_aliased_col_name_to_real_col_name: DefaultDict[
str, Dict[str, str]
] = defaultdict(dict)
@@ -315,6 +316,7 @@ def get_snowflake_plan(self, skip_schema_query) -> SnowflakePlan:
df_aliased_col_name_to_real_col_name=self.df_aliased_col_name_to_real_col_name,
source_plan=self,
referenced_ctes=self.referenced_ctes,
expr_to_alias_v2=self.expr_to_alias_v2,
)
# set api_calls to self._snowflake_plan outside of the above constructor
# because the constructor copy api_calls.
@@ -588,6 +590,7 @@ def __init__(self, snowflake_plan: LogicalPlan, *, analyzer: "Analyzer") -> None
else analyzer.resolve(snowflake_plan)
)
self.expr_to_alias.update(self._snowflake_plan.expr_to_alias)
self.expr_to_alias_v2.update(self._snowflake_plan.expr_to_alias_v2)
self.df_aliased_col_name_to_real_col_name.update(
self._snowflake_plan.df_aliased_col_name_to_real_col_name
)
@@ -680,6 +683,7 @@ def __init__(
self._projection_in_str = None
self._query_params = None
self.expr_to_alias.update(self.from_.expr_to_alias)
self.expr_to_alias_v2.update(self.from_.expr_to_alias_v2)
self.df_aliased_col_name_to_real_col_name.update(
self.from_.df_aliased_col_name_to_real_col_name
)
@@ -1042,6 +1046,7 @@ def select(self, cols: List[Expression]) -> "SelectStatement":
new.expr_to_alias = copy(
self.expr_to_alias
) # use copy because we don't want two plans to share the same list. If one mutates, the other ones won't be impacted.
new.expr_to_alias_v2 = copy(self.expr_to_alias_v2)
new.flatten_disabled = self.flatten_disabled
# no need to flatten the projection complexity since the select projection is already flattened.
new._merge_projection_complexity_with_subquery = False
58 changes: 56 additions & 2 deletions src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
@@ -121,6 +121,14 @@
else:
from collections.abc import Iterable

iiid = 0


def get_next_id():
global iiid
iiid += 1
return str(iiid)


class SnowflakePlan(LogicalPlan):
class Decorator:
@@ -226,12 +234,14 @@ def __init__(
referenced_ctes: Optional[Dict[WithQueryBlock, int]] = None,
*,
session: "snowflake.snowpark.session.Session",
expr_to_alias_v2: Optional[Dict] = None,
) -> None:
super().__init__()
self.queries = queries
self.schema_query = schema_query
self.post_actions = post_actions if post_actions else []
self.expr_to_alias = expr_to_alias if expr_to_alias else {}
self.expr_to_alias_v2 = expr_to_alias_v2 if expr_to_alias_v2 else {}
self.session = session
self.source_plan = source_plan
self.is_ddl_on_temp_object = is_ddl_on_temp_object
@@ -258,7 +268,8 @@ def __init__(
self._cumulative_node_complexity: Optional[Dict[PlanNodeCategory, int]] = None
# UUID for the plan to uniquely identify the SnowflakePlan object. We also use this
# to UUID track queries that are generated from the same plan.
self._uuid = str(uuid.uuid4())
self._uuid = get_next_id()
# self._uuid = str(uuid.uuid4())
# Metadata for the plan
self._metadata: PlanMetadata = infer_metadata(
self.source_plan,
@@ -328,6 +339,7 @@ def with_subqueries(self, subquery_plans: List["SnowflakePlan"]) -> "SnowflakePl
new_schema_query,
post_actions=new_post_actions,
expr_to_alias=self.expr_to_alias,
expr_to_alias_v2=self.expr_to_alias_v2,
session=self.session,
source_plan=self.source_plan,
api_calls=api_calls,
@@ -361,11 +373,16 @@ def attributes(self) -> List[Attribute]:
# No simplifier case relies on this schema_query change to update SHOW TABLES to a nested sql friendly query.
if not self.schema_query or not self.session.sql_simplifier_enabled:
self.schema_query = schema_value_statement(attributes)
for attr in attributes:
attr.plan_uuid = self.uuid
return attributes

@cached_property
def output(self) -> List[Attribute]:
return [Attribute(a.name, a.datatype, a.nullable) for a in self.attributes]
return [
Attribute(a.name, a.datatype, a.nullable, snowflake_plan_uuid=self.uuid)
for a in self.attributes
]

@property
def output_dict(self) -> Dict[str, Any]:
@@ -457,6 +474,9 @@ def __copy__(self) -> "SnowflakePlan":
self.df_aliased_col_name_to_real_col_name,
session=self.session,
referenced_ctes=self.referenced_ctes,
expr_to_alias_v2=dict(self.expr_to_alias_v2)
if self.expr_to_alias_v2
else None,
)
else:
return SnowflakePlan(
@@ -470,6 +490,9 @@ def __copy__(self) -> "SnowflakePlan":
self.df_aliased_col_name_to_real_col_name,
session=self.session,
referenced_ctes=self.referenced_ctes,
expr_to_alias_v2=dict(self.expr_to_alias_v2)
if self.expr_to_alias_v2
else None,
)

def __deepcopy__(self, memodict={}) -> "SnowflakePlan": # noqa: B006
@@ -487,6 +510,9 @@ def __deepcopy__(self, memodict={}) -> "SnowflakePlan":  # noqa: B006
expr_to_alias=copy.deepcopy(self.expr_to_alias)
if self.expr_to_alias
else None,
expr_to_alias_v2=copy.deepcopy(self.expr_to_alias_v2)
if self.expr_to_alias_v2
else None,
source_plan=copied_source_plan,
is_ddl_on_temp_object=self.is_ddl_on_temp_object,
api_calls=copy.deepcopy(self.api_calls) if self.api_calls else None,
@@ -509,6 +535,18 @@ def __deepcopy__(self, memodict={}) -> "SnowflakePlan":  # noqa: B006
def add_aliases(self, to_add: Dict) -> None:
self.expr_to_alias = {**self.expr_to_alias, **to_add}

def add_aliases_v2(self, to_add: Dict) -> None:
# conflicted = False
for key in self.expr_to_alias_v2.keys() & to_add.keys(): # Find common keys
if self.expr_to_alias_v2[key] != to_add[key]:
# conflicted = True
# print(
# f"need to overwrite, the expr has been realiased for key '{key}', old value {self.expr_to_alias_v2[key]} -> new value {to_add[key]}"
# )
self.expr_to_alias_v2[key] = to_add[key]
for key in to_add.keys() - self.expr_to_alias_v2.keys(): # Find new keys
self.expr_to_alias_v2[key] = to_add[key]
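    # Illustration (not part of this diff): unlike merge_multiple_dicts_with_assertion,
    # add_aliases_v2 overwrites the stored alias when a key has been re-aliased and
    # adds keys it has not seen before. With illustrative values:
    #   plan.expr_to_alias_v2 = {("expr_1", "plan_1"): '"A"'}
    #   plan.add_aliases_v2({("expr_1", "plan_1"): '"B"', ("expr_2", "plan_1"): '"C"'})
    #   assert plan.expr_to_alias_v2 == {("expr_1", "plan_1"): '"B"', ("expr_2", "plan_1"): '"C"'}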


class SnowflakePlanBuilder:
def __init__(
@@ -592,6 +630,21 @@ def build_binary(
}.items()
if k not in common_columns
}

from snowflake.snowpark._internal.utils import (
merge_multiple_dicts_with_assertion,
)

new_expr_to_alias_v2 = merge_multiple_dicts_with_assertion(
select_left.expr_to_alias_v2, select_right.expr_to_alias_v2
)
# new_expr_to_alias_v2 = {
# k: v
# for k, v in {
# **select_left.expr_to_alias_v2,
# **select_right.expr_to_alias_v2,
# }.items()
# }
api_calls = [*select_left.api_calls, *select_right.api_calls]

# Need to do a deduplication to avoid repeated query.
Expand Down Expand Up @@ -639,6 +692,7 @@ def build_binary(
api_calls=api_calls,
session=self.session,
referenced_ctes=referenced_ctes,
expr_to_alias_v2=new_expr_to_alias_v2,
)

def query(
21 changes: 21 additions & 0 deletions src/snowflake/snowpark/_internal/utils.py
@@ -1457,3 +1457,24 @@ def next(self) -> int:


global_counter: GlobalCounter = GlobalCounter()


def merge_multiple_dicts_with_assertion(*dicts):
"""
Merge multiple dictionaries with an assertion that checks if values for
duplicate keys are the same.

:param dicts: Multiple dictionaries to merge
:return: Merged dictionary
"""
merged_dict = {}

for d in dicts:
for key, value in d.items():
if key in merged_dict:
assert (
merged_dict[key] == value
), f"Conflict for key '{key}': {merged_dict[key]} != {value}"
merged_dict[key] = value

return merged_dict
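
For reference, a minimal usage sketch of the new helper (the dict contents below are illustrative, not taken from the PR):

left = {("expr_1", "plan_1"): '"A"'}
right = {("expr_1", "plan_1"): '"A"', ("expr_2", "plan_2"): '"B"'}
merged = merge_multiple_dicts_with_assertion(left, right)
assert merged == {("expr_1", "plan_1"): '"A"', ("expr_2", "plan_2"): '"B"'}
# A conflicting value for a shared key fails the assertion:
# merge_multiple_dicts_with_assertion({"k": 1}, {"k": 2})  # AssertionError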