SNOW-1890997: Support schema string without colon in create_dataframe (#2913)


1. Which Jira issue is this PR addressing? Make sure that there is an accompanying issue to your PR.


   Fixes SNOW-1890997

2. Fill out the following pre-review checklist:

   - [x] I am adding a new automated test(s) to verify correctness of my new code
   - [ ] If this test skips Local Testing mode, I'm requesting review from @snowflakedb/local-testing
   - [ ] I am adding new logging messages
   - [ ] I am adding a new telemetry message
   - [ ] I am adding new credentials
   - [ ] I am adding a new dependency
   - [ ] If this is a new feature/behavior, I'm adding the Local Testing parity changes.
   - [x] I acknowledge that I have ensured my changes to be thread-safe. Follow the link for more information: [Thread-safe Developer Guidelines](https://github.com/snowflakedb/snowpark-python/blob/main/CONTRIBUTING.md#thread-safe-development)

3. Please describe how your code solves the related issue.

   This change extends schema-string parsing in `create_dataframe` to accept field
   definitions without a colon (e.g. "x STRING, y INTEGER" or "arr array<int>"), in
   addition to the existing colon form ("x: STRING"). It also maps "decimal" to
   DecimalType, and "timestamp_ntz"/"timestamp_tz"/"timestamp_ltz" to TimestampType
   with the corresponding time zone.
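
   A minimal usage sketch of the new behavior (mirroring the tests added in this
   PR; it assumes an existing Snowpark `session`):

   ```python
   # Colon-free schema strings: "name TYPE" pairs separated by commas.
   df1 = session.create_dataframe([["hello"], ["world"]], schema="x STRING")
   df2 = session.create_dataframe([["a", 1], ["b", 2]], schema="x STRING, y INTEGER")
   # A bare "array" with no element type is also accepted for a single field.
   df3 = session.create_dataframe([[[1, 2]], [[3, 4]]], schema="arr array")
   # The existing colon form keeps working: schema="x: STRING, y: INTEGER".
   ```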
sfc-gh-jdu authored Jan 24, 2025
1 parent c6d83b8 commit cf804cd
Showing 3 changed files with 270 additions and 27 deletions.
83 changes: 69 additions & 14 deletions src/snowflake/snowpark/_internal/type_utils.py
@@ -10,6 +10,7 @@
import ctypes
import datetime
import decimal
import functools
import re
import sys
import typing # noqa: F401
@@ -966,8 +967,18 @@ def get_data_type_string_object_mappings(
DATA_TYPE_STRING_OBJECT_MAPPINGS["bigint"] = LongType
DATA_TYPE_STRING_OBJECT_MAPPINGS["number"] = DecimalType
DATA_TYPE_STRING_OBJECT_MAPPINGS["numeric"] = DecimalType
DATA_TYPE_STRING_OBJECT_MAPPINGS["decimal"] = DecimalType
DATA_TYPE_STRING_OBJECT_MAPPINGS["object"] = MapType
DATA_TYPE_STRING_OBJECT_MAPPINGS["array"] = ArrayType
DATA_TYPE_STRING_OBJECT_MAPPINGS["timestamp_ntz"] = functools.partial(
TimestampType, timezone=TimestampTimeZone.NTZ
)
DATA_TYPE_STRING_OBJECT_MAPPINGS["timestamp_tz"] = functools.partial(
TimestampType, timezone=TimestampTimeZone.TZ
)
DATA_TYPE_STRING_OBJECT_MAPPINGS["timestamp_ltz"] = functools.partial(
TimestampType, timezone=TimestampTimeZone.LTZ
)

DECIMAL_RE = re.compile(
r"^\s*(numeric|number|decimal)\s*\(\s*(\s*)(\d*)\s*,\s*(\d*)\s*\)\s*$"
@@ -1064,18 +1075,43 @@ def extract_nullable_keyword(type_str: str) -> Tuple[str, bool]:
return trimmed, True


def parse_struct_field_list(fields_str: str) -> StructType:
def find_top_level_colon(field_def: str) -> int:
"""
Returns the index of the first top-level colon in 'field_def',
or -1 if there is no top-level colon. A colon is considered top-level
if it is not enclosed in <...> or (...).
Example:
'a struct<i: integer>' => returns -1 (colon is nested).
'x: struct<i: integer>' => returns index of the colon after 'x'.
"""
bracket_depth = 0
for i, ch in enumerate(field_def):
if ch in ("<", "("):
bracket_depth += 1
elif ch in (">", ")"):
bracket_depth -= 1
elif ch == ":" and bracket_depth == 0:
return i
return -1


def parse_struct_field_list(fields_str: str) -> Optional[StructType]:
"""
Parse something like "a: int, b: string, c: array<int>"
into StructType([StructField('a', IntegerType()), ...]).
"""
fields = []
field_defs = split_top_level_comma_fields(fields_str)
for field_def in field_defs:
# Try splitting on colon first, else whitespace
if ":" in field_def:
left, right = field_def.split(":", 1)
# Find first top-level colon (if any)
colon_index = find_top_level_colon(field_def)
if colon_index != -1:
# We found a top-level colon => split on it
left = field_def[:colon_index]
right = field_def[colon_index + 1 :]
else:
# No top-level colon => fallback to whitespace-based split
parts = field_def.split(None, 1)
if len(parts) != 2:
raise ValueError(f"Cannot parse struct field definition: '{field_def}'")
@@ -1089,7 +1125,17 @@ def parse_struct_field_list(fields_str: str) -> StructType:
# 1) Check for trailing "NOT NULL" => sets nullable=False
base_type_str, nullable = extract_nullable_keyword(type_part)
# 2) Parse the base type
field_type = type_string_to_type_object(base_type_str)
try:
field_type = type_string_to_type_object(base_type_str)
except ValueError as ex:
# Spark supports both `x: int` and `x int`. Our original implementation only supported
# `x: int` and raised this error for `x int`. Handling the space form is tricky because
# a bare type such as `decimal(10, 2)` also contains a space and must stay valid as a
# schema string without a column name. Therefore, if this error is raised, we just catch
# it and return None, and in the next step the string can be processed again as a plain
# type string.
if "is not a supported type" in str(ex):
return None
raise ex
fields.append(StructField(field_name, field_type, nullable=nullable))

return StructType(fields)
@@ -1119,18 +1165,25 @@ def split_top_level_comma_fields(s: str) -> List[str]:

def is_likely_struct(s: str) -> bool:
"""
Heuristic: If there's a top-level comma or colon outside brackets,
treat it like a struct with multiple fields, e.g. "a: int, b: string".
Return True if there's a top-level colon, comma, or space.
e.g. "arr array<integer>" => top-level space => struct
"arr: array<int>" => colon => struct
"a: int, b: string" => comma => struct
"""
bracket_depth = 0
for c in s:
if c in ["<", "("]:
top_level_space_found = False
for ch in s:
if ch in ("<", "("):
bracket_depth += 1
elif c in [">", ")"]:
elif ch in (">", ")"):
bracket_depth -= 1
elif (c in [":", ","]) and bracket_depth == 0:
return True
return False
elif bracket_depth == 0:
if ch in [":", ","]:
return True
elif ch == " ":
top_level_space_found = True

return top_level_space_found


def type_string_to_type_object(type_str: str) -> DataType:
@@ -1141,7 +1194,9 @@ def type_string_to_type_object(type_str: str) -> DataType:
# First check if this might be a top-level multi-field struct
# (e.g. "a: int, b: string") even if not written as "struct<...>"
if is_likely_struct(type_str):
return parse_struct_field_list(type_str)
result = parse_struct_field_list(type_str)
if result is not None:
return result

# Check for array<...>
if ARRAY_RE.match(type_str):
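
As a quick sanity check of the parsing flow above, here is a sketch that calls the
internal `type_string_to_type_object` API directly; the expected results follow the
unit tests further below:

```python
from snowflake.snowpark._internal.type_utils import type_string_to_type_object

# A top-level comma or space makes the string an implicit struct
# (unquoted field names are upper-cased per Snowflake identifier rules).
print(type_string_to_type_object("x STRING, y INTEGER"))
# -> StructType with fields X: StringType, Y: IntegerType

# The space inside the parentheses is not top-level, so this stays a plain type.
print(type_string_to_type_object("decimal(10, 2)"))  # -> DecimalType(10, 2)

# A single space-separated field with a parameterized type.
print(type_string_to_type_object("arr array<int>"))
# -> StructType with field ARR: ArrayType(IntegerType())

# The new timestamp mappings carry the time zone variant.
print(type_string_to_type_object("timestamp_ltz"))
# -> TimestampType with TimestampTimeZone.LTZ
```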
90 changes: 85 additions & 5 deletions tests/integ/test_dataframe.py
@@ -4934,8 +4934,88 @@ def test_create_dataframe_implicit_struct_not_null_mixed(session):
assert result == expected_rows


def test_create_dataframe_implicit_struct_not_null_invalid(session):
data = [1, 2, 3]
schema_str = "int not null" # not a struct => ValueError
with pytest.raises(ValueError, match="'intnotnull' is not a supported type"):
session.create_dataframe(data, schema=schema_str)
def test_create_dataframe_arr_array(session):
"""
Verifies schema="arr array" is interpreted as a single field named "ARR"
with an 'ArrayType'. If your parser maps 'array' => ArrayType() by default,
the element type might be undefined or assumed.
"""
data = [
[[1, 2]],
[[3, 4]],
]
schema_str = "arr array"

df = session.create_dataframe(data, schema_str)
# Check schema
assert len(df.schema.fields) == 1
# Since "array" is mapped via your DATA_TYPE_STRING_OBJECT_MAPPINGS["array"] = ArrayType,
# we should expect an ArrayType(...) with no specific element type.
# For Snowpark, ArrayType() can be valid (element type might be 'AnyType' internally).
actual_field = df.schema.fields[0]
assert actual_field.name == "ARR"
assert isinstance(
actual_field.datatype, ArrayType
), f"Expected ArrayType(), got {actual_field.datatype}"
# default is nullable=True unless "not null" is specified
assert actual_field.nullable is True

# Check data
result = df.collect()
expected_rows = [Row(ARR="[\n 1,\n 2\n]"), Row(ARR="[\n 3,\n 4\n]")]
assert result == expected_rows


def test_create_dataframe_x_string(session):
"""
Verifies schema="x STRING" is interpreted as a single field named "X"
with type StringType().
"""
data = [
["hello"],
["world"],
]
schema_str = "x STRING"

df = session.create_dataframe(data, schema_str)
# Check schema
assert len(df.schema.fields) == 1
expected_field = StructField("X", StringType(), nullable=True)
assert df.schema.fields[0] == expected_field

# Check rows
result = df.collect()
expected_rows = [
Row(X="hello"),
Row(X="world"),
]
assert result == expected_rows


def test_create_dataframe_x_string_y_integer(session):
"""
Verifies schema="x STRING, y INTEGER" is interpreted as a struct with two fields:
'X' => StringType (nullable), 'Y' => IntegerType (nullable).
"""
data = [
["a", 1],
["b", 2],
]
schema_str = "x STRING, y INTEGER"

df = session.create_dataframe(data, schema_str)
# Check schema
assert len(df.schema.fields) == 2
expected_fields = [
StructField("X", StringType(), nullable=True),
StructField("Y", LongType(), nullable=True),
]
assert df.schema.fields == expected_fields

# Check rows
result = df.collect()
expected_rows = [
Row(X="a", Y=1),
Row(X="b", Y=2),
]
assert result == expected_rows
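
For contrast, a small pytest-style sketch of a definition that still fails to parse:
a single token has neither a top-level colon nor a whitespace-separated type, so
`parse_struct_field_list` raises the error shown in type_utils.py above. This mirrors
`test_parse_struct_field_list_malformed`, whose body is collapsed in this view.

```python
import pytest

from snowflake.snowpark._internal.type_utils import parse_struct_field_list


def test_malformed_single_token():
    # "badfield" has no colon and no whitespace-separated type part.
    with pytest.raises(ValueError, match="Cannot parse struct field definition"):
        parse_struct_field_list("badfield")
```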
124 changes: 116 additions & 8 deletions tests/unit/test_types.py
@@ -1520,6 +1520,10 @@ def test_type_string_to_type_object_number_decimal():
assert isinstance(dt, DecimalType), f"Expected DecimalType, got {dt}"
assert dt.precision == 10, f"Expected precision=10, got {dt.precision}"
assert dt.scale == 2, f"Expected scale=2, got {dt.scale}"
dt = type_string_to_type_object("decimal")
assert isinstance(dt, DecimalType), f"Expected DecimalType, got {dt}"
assert dt.precision == 38, f"Expected precision=38, got {dt.precision}"
assert dt.scale == 0, f"Expected scale=0, got {dt.scale}"


def test_type_string_to_type_object_numeric_decimal():
@@ -1553,6 +1557,21 @@ def test_type_string_to_type_object_text_with_length():
assert dt.length == 100, f"Expected length=100, got {dt.length}"


def test_type_string_to_type_object_timestamp():
dt = type_string_to_type_object("timestamp")
assert isinstance(dt, TimestampType)
assert dt.tz == TimestampTimeZone.DEFAULT
dt = type_string_to_type_object("timestamp_ntz")
assert isinstance(dt, TimestampType)
assert dt.tz == TimestampTimeZone.NTZ
dt = type_string_to_type_object("timestamp_tz")
assert isinstance(dt, TimestampType)
assert dt.tz == TimestampTimeZone.TZ
dt = type_string_to_type_object("timestamp_ltz")
assert isinstance(dt, TimestampType)
assert dt.tz == TimestampTimeZone.LTZ


def test_type_string_to_type_object_array_of_int():
dt = type_string_to_type_object("array<int>")
assert isinstance(dt, ArrayType), f"Expected ArrayType, got {dt}"
@@ -1776,6 +1795,20 @@ def test_type_string_to_type_object_implicit_struct_with_spaces():
), f"Expected {expected_field_col2}, got {dt.fields[1]}"


def test_type_string_to_type_object_implicit_struct_inner_colon():
dt = type_string_to_type_object("struct struct<i: integer not null>")
assert isinstance(dt, StructType), f"Expected StructType, got {dt}"
assert len(dt.fields) == 1, f"Expected 1 field, got {len(dt.fields)}"
expected_field_i = StructField(
"STRUCT",
StructType([StructField("I", IntegerType(), nullable=False)]),
nullable=True,
)
assert (
dt.fields[0] == expected_field_i
), f"Expected {expected_field_i}, got {dt.fields[0]}"


def test_type_string_to_type_object_implicit_struct_error():
"""
Check a malformed implicit struct that should raise ValueError
@@ -1877,16 +1910,91 @@ def test_parse_struct_field_list_malformed():
), f"Unexpected error message: {ex}"


def test_is_likely_struct_true():
# top-level colon => likely struct
s = "a: int, b: string"
assert is_likely_struct(s) is True, "Expected True for struct-like string"
def test_parse_struct_field_list_single_type_with_space():
s = "decimal(1, 2)"
assert parse_struct_field_list(s) is None, "Expected None for single type"


def test_is_likely_struct_false():
# No top-level colon or comma => not a struct
s = "array<int>"
assert is_likely_struct(s) is False, "Expected False for non-struct string"
def test_is_likely_struct_colon():
"""
Strings with a top-level colon (outside any <...> or (...))
should return True.
"""
s = "col: int"
assert is_likely_struct(s) is True


def test_is_likely_struct_comma():
"""
Strings with a top-level comma (outside brackets)
should return True (e.g. multiple fields).
"""
s = "col1: int, col2: string"
assert is_likely_struct(s) is True


def test_is_likely_struct_top_level_space():
"""
Strings with a top-level space (and no colon/comma)
should return True (single field, e.g. 'arr array<int>').
"""
s = "arr array<int>"
assert is_likely_struct(s) is True


def test_is_likely_struct_no_space_colon_comma():
"""
If there's no top-level space, colon, or comma,
we return False (likely a single-type definition like 'decimal(10,2)').
"""
s = "decimal(10,2)"
assert is_likely_struct(s) is False


def test_is_likely_struct_space_inside_brackets():
"""
Spaces inside <...> should not trigger struct mode.
E.g. 'array< int >' has spaces inside brackets,
but no top-level space => should return False.
"""
s = "array< int >"
assert is_likely_struct(s) is False


def test_is_likely_struct_comma_inside_brackets():
"""
Comma inside <...> is not top-level,
so it shouldn't make the string 'likely struct'.
Example: 'array<int, string>' is not a struct definition;
it's a single type definition for an array of multiple types
(typically invalid in Snowpark, but it exercises the bracket logic).
"""
s = "array<int, string>"
assert is_likely_struct(s) is False


def test_is_likely_struct_colon_inside_brackets():
"""
If a colon is inside brackets, e.g. 'map<int, struct<x: int>>',
that colon is not top-level => should return False.
"""
s = "map<int, struct<x: int>>"
assert is_likely_struct(s) is False


def test_is_likely_struct_complex_no_top_level_space():
"""
Example: 'struct<x int, y int>' => top-level
colon/space are inside <...> => so top-level
has none => returns False.
But note if you want 'struct<x: int, y: int>',
you might parse it differently. This test ensures
bracket-depth logic works properly.
"""
s = "struct<x int, y int>"
# top-level bracket depth covers entire string
# => no top-level space/colon/comma
assert is_likely_struct(s) is False


def test_extract_nullable_keyword_no_not_null():
