SNOW-1871175: Add support for specifying a schema string for DataFrame.create_dataframe #2828

Open · wants to merge 3 commits into main
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -15,6 +15,7 @@
- `nullifzero`
- `snowflake_cortex_sentiment`
- Added `Catalog` class to manage Snowflake objects. It can be accessed via `Session.catalog`.
- Added support for specifying a schema string (including implicit struct syntax) when calling `DataFrame.create_dataframe`.

#### Improvements

149 changes: 149 additions & 0 deletions src/snowflake/snowpark/_internal/type_utils.py
@@ -966,6 +966,15 @@ def get_data_type_string_object_mappings(
STRING_RE = re.compile(r"^\s*(varchar|string|text)\s*\(\s*(\d*)\s*\)\s*$")
# support type string format like " string ( 23 ) "

ARRAY_RE = re.compile(r"(?i)^\s*array\s*<")
# support type string format like starting with "array<..."

MAP_RE = re.compile(r"(?i)^\s*map\s*<")
# support type string format like starting with "map<..."

STRUCT_RE = re.compile(r"(?i)^\s*struct\s*<")
# support type string format like starting with "struct<..."


def get_number_precision_scale(type_str: str) -> Optional[Tuple[int, int]]:
decimal_matches = DECIMAL_RE.match(type_str)
@@ -979,7 +988,147 @@ def get_string_length(type_str: str) -> Optional[int]:
return int(string_matches.group(2))


def extract_bracket_content(type_str: str, keyword: str) -> str:
"""
Given a string that starts with e.g. "array<", return the content inside the top-level <...>,
e.g. "array<int>" => "int". Nested brackets such as "array<array<...>>" are handled.
Raises ValueError on a mismatched or missing bracket.
"""
prefix_pattern = rf"(?i)^\s*{keyword}\s*<"
match = re.match(prefix_pattern, type_str)
if not match:
raise ValueError(
f"'{type_str}' does not match expected '{keyword}<...>' syntax."
)

start_index = match.end() - 1 # position at '<'
bracket_depth = 0
inside_chars: List[str] = []
i = start_index
while i < len(type_str):
c = type_str[i]
if c == "<":
bracket_depth += 1
# skip storing the outermost '<' (the depth 0 -> 1 transition);
# only nested brackets are kept in 'inside_chars'
if bracket_depth > 1:
inside_chars.append(c)
Comment on lines +1011 to +1015

Contributor:
Is this form allowed, e.g. array<<<...>>>?

Collaborator Author (@sfc-gh-jdu, Jan 9, 2025):
Yeah, it's for something like array<array<...>>; added a comment here.
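
For reference, a minimal sketch of how the nested and malformed cases flow through extract_bracket_content (hypothetical REPL session; the import path is the private module from this diff's header, not a public API):

    from snowflake.snowpark._internal.type_utils import extract_bracket_content

    extract_bracket_content("array<array<int>>", "array")  # -> "array<int>" (parsed recursively)
    extract_bracket_content("array<<<int>>>", "array")     # -> "<<int>>", which then fails type parsing
    extract_bracket_content("array<int", "array")          # raises ValueError: missing closing '>'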

elif c == ">":
bracket_depth -= 1
if bracket_depth < 0:
raise ValueError(f"Mismatched '>' in '{type_str}'")
if bracket_depth == 0:
if i != len(type_str) - 1:
raise ValueError(
f"Unexpected characters after closing '>' in '{type_str}'"
)
# done
return "".join(inside_chars).strip()
inside_chars.append(c)
else:
inside_chars.append(c)
i += 1

raise ValueError(f"Missing closing '>' in '{type_str}'.")


def parse_struct_field_list(fields_str: str) -> StructType:
"""
Parse something like "a: int, b: string, c: array<int>"
into StructType([StructField('a', IntegerType()), ...]).
"""
fields = []
field_defs = split_top_level_comma_fields(fields_str)
for field_def in field_defs:
# Try splitting on colon first, else whitespace
if ":" in field_def:
left, right = field_def.split(":", 1)
Contributor:
Should we consider multiple-colon cases like "a:b:c", or is this already handled by upstream/downstream logic?

Collaborator Author:
Nah, PySpark's simpleString format only considers the first colon or whitespace.
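
To illustrate the first-colon rule (hypothetical snippet mirroring the split above):

    field_def = "a:b:c"
    left, right = field_def.split(":", 1)  # left == "a", right == "b:c"
    # "b:c" is then treated as the type string; it is not a valid type, so
    # type_string_to_type_object ultimately raises ValueError deeper in the parser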

else:
parts = field_def.split(None, 1)
if len(parts) != 2:
raise ValueError(f"Cannot parse struct field definition: '{field_def}'")
left, right = parts[0], parts[1]

field_name = left.strip()
type_part = right.strip()
if not field_name:
raise ValueError(f"Struct field missing name in '{field_def}'")

field_type = type_string_to_type_object(type_part)
fields.append(StructField(field_name, field_type, nullable=True))

return StructType(fields)


def split_top_level_comma_fields(s: str) -> List[str]:
"""
Splits 's' by commas not enclosed in matching brackets.
Example: "int, array<long>, decimal(10,2)" => ["int", "array<long>", "decimal(10,2)"].
"""
parts = []
bracket_depth = 0
start_idx = 0
for i, c in enumerate(s):
if c in ["<", "("]:
bracket_depth += 1
elif c in [">", ")"]:
bracket_depth -= 1
if bracket_depth < 0:
raise ValueError(f"Mismatched bracket in '{s}'.")
Comment on lines +1071 to +1077

Contributor (@sfc-gh-aling, Jan 9, 2025):
I feel this bracket-check logic is repeated multiple times. Do you think it's possible to check bracket matching just once, as an initial step over the whole input string, so the downstream logic can focus only on extracting names and types?

Collaborator Author:
We need to track brackets to split fields and extract names and types anyway. There is indeed some duplication in validating that the bracket expression is well formed, and we could remove it, but keeping it makes each function self-contained, so maybe let's keep it. The checks are also covered by the tests.
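
For what it's worth, the reviewer's one-pass pre-check could look roughly like this (hypothetical helper, not part of this PR):

    def validate_brackets(s: str) -> None:
        # Hypothetical one-time bracket pre-check sketched from the review
        # discussion; the PR keeps per-function depth tracking instead so
        # each helper stays self-contained.
        pairs = {">": "<", ")": "("}
        stack = []
        for c in s:
            if c in ("<", "("):
                stack.append(c)
            elif c in pairs:
                if not stack or stack.pop() != pairs[c]:
                    raise ValueError(f"Mismatched '{c}' in '{s}'.")
        if stack:
            raise ValueError(f"Missing closing bracket in '{s}'.")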

elif c == "," and bracket_depth == 0:
parts.append(s[start_idx:i].strip())
start_idx = i + 1
parts.append(s[start_idx:].strip())
return parts


def is_likely_struct(s: str) -> bool:
"""
Heuristic: If there's a top-level comma or colon outside brackets,
treat it like a struct with multiple fields, e.g. "a: int, b: string".
"""
bracket_depth = 0
for c in s:
if c in ["<", "("]:
bracket_depth += 1
elif c in [">", ")"]:
bracket_depth -= 1
elif (c in [":", ","]) and bracket_depth == 0:
return True
return False
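
A few quick probes of this heuristic (hypothetical REPL session):

    is_likely_struct("a: int, b: string")  # True  -- top-level colon and comma
    is_likely_struct("decimal(10, 2)")     # False -- the comma sits inside parentheses
    is_likely_struct("array<int>")         # False -- no top-level ':' or ','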


def type_string_to_type_object(type_str: str) -> DataType:
type_str = type_str.strip()
if not type_str:
raise ValueError("Empty type string")

# First check if this might be a top-level multi-field struct
# (e.g. "a: int, b: string") even if not written as "struct<...>"
if is_likely_struct(type_str):
return parse_struct_field_list(type_str)

# Check for array<...>
if ARRAY_RE.match(type_str):
inner = extract_bracket_content(type_str, "array")
element_type = type_string_to_type_object(inner)
return ArrayType(element_type)

# Check for map<key, value>
if MAP_RE.match(type_str):
inner = extract_bracket_content(type_str, "map")
parts = split_top_level_comma_fields(inner)
if len(parts) != 2:
raise ValueError(f"Invalid map type definition: '{type_str}'")
key_type = type_string_to_type_object(parts[0])
val_type = type_string_to_type_object(parts[1])
return MapType(key_type, val_type)

# Check for explicit struct<...>
if STRUCT_RE.match(type_str):
inner = extract_bracket_content(type_str, "struct")
return parse_struct_field_list(inner)

precision_scale = get_number_precision_scale(type_str)
if precision_scale:
return DecimalType(*precision_scale)
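Putting the new helpers together, a short sketch of what the parser accepts (hypothetical REPL session; the import path is the private module from this diff's header, not a public API):

    from snowflake.snowpark._internal.type_utils import type_string_to_type_object

    type_string_to_type_object("array<int>")               # ArrayType wrapping an integer type
    type_string_to_type_object("map<string, array<int>>")  # MapType(StringType(), ArrayType(...))
    type_string_to_type_object("a: int, b: string")        # implicit struct -> StructType, two fields
    type_string_to_type_object("struct<a: int>")           # explicit struct -> StructType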
26 changes: 22 additions & 4 deletions src/snowflake/snowpark/session.py
@@ -93,6 +93,7 @@
infer_schema,
infer_type,
merge_type,
type_string_to_type_object,
)
from snowflake.snowpark._internal.udf_utils import generate_call_python_sp_sql
from snowflake.snowpark._internal.utils import (
@@ -3044,7 +3045,7 @@ def write_pandas(
def create_dataframe(
self,
data: Union[List, Tuple, "pandas.DataFrame"],
schema: Optional[Union[StructType, Iterable[str], str]] = None,
_emit_ast: bool = True,
) -> DataFrame:
"""Creates a new DataFrame containing the specified values from the local data.
@@ -3061,9 +3062,15 @@ def create_dataframe(
``data`` will constitute a row in the DataFrame.
schema: A :class:`~snowflake.snowpark.types.StructType` containing names and
data types of columns, a schema string, a list of column names, or ``None``.

- When passing a **string**, it can be either an *explicit* struct
(e.g. ``"struct<a: int, b: string>"``) or an *implicit* struct
(e.g. ``"a: int, b: string"``). Internally, the string is parsed and
converted into a :class:`StructType` using Snowpark's type parsing.
- When ``schema`` is a list of column names or ``None``, the schema of the
DataFrame will be inferred from the data across all rows.

To improve performance, provide a schema. This avoids the need to infer data types
with large data sets.

Examples::
@@ -3093,6 +3100,10 @@
>>> session.create_dataframe(pd.DataFrame([(1, 2, 3, 4)], columns=["a", "b", "c", "d"])).collect()
[Row(a=1, b=2, c=3, d=4)]

>>> # create a dataframe using an implicit struct schema string
>>> session.create_dataframe([[10, 20], [30, 40]], schema="x: int, y: int").collect()
[Row(X=10, Y=20), Row(X=30, Y=40)]
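
>>> # an explicit struct<...> schema string behaves the same way (sketch
>>> # mirroring test_create_dataframe_with_explicit_struct_string below)
>>> session.create_dataframe([["hello", 3.14]], schema="struct<a: string, b: double>").collect()
[Row(A='hello', B=3.14)]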

Note:
When `data` is a pandas DataFrame, `snowflake.connector.pandas_tools.write_pandas` is called, which
requires permission to (1) CREATE STAGE (2) CREATE TABLE and (3) CREATE FILE FORMAT under the current
@@ -3173,6 +3184,13 @@ def create_dataframe(
# infer the schema based on the data
names = None
schema_query = None
if isinstance(schema, str):
    parsed_schema = type_string_to_type_object(schema)
    if not isinstance(parsed_schema, StructType):
        raise ValueError(
            f"Invalid schema string: {schema}. "
            "You should provide a valid schema string representing a struct type."
        )
    schema = parsed_schema
if isinstance(schema, StructType):
new_schema = schema
# SELECT query has an undefined behavior for nullability, so if the schema requires non-nullable column and
152 changes: 152 additions & 0 deletions tests/integ/test_dataframe.py
@@ -4447,3 +4447,155 @@ def test_map_negative(session):
output_types=[IntegerType(), StringType()],
output_column_names=["a", "b", "c"],
)


def test_create_dataframe_with_implicit_struct_simple(session):
"""
Test an implicit struct string with two integer columns.
"""
data = [
[1, 2],
[3, 4],
]
# The new feature: implicit struct string "col1: int, col2: int"
schema_str = "col1: int, col2: int"

# Create the dataframe
df = session.create_dataframe(data, schema=schema_str)
# Check schema
# We expect the schema to be a StructType with 2 fields
assert isinstance(df.schema, StructType)
assert len(df.schema.fields) == 2
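# Note: "int" round-trips through Snowflake as NUMBER(38, 0), so Snowpark
# reports the columns as LongType rather than IntegerType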
expected_fields = [
StructField("COL1", LongType(), nullable=True),
StructField("COL2", LongType(), nullable=True),
]
assert df.schema.fields == expected_fields

# Collect rows
result = df.collect()
expected_rows = [
Row(COL1=1, COL2=2),
Row(COL1=3, COL2=4),
]
assert result == expected_rows


def test_create_dataframe_with_implicit_struct_nested(session):
"""
Test an implicit struct string with nested array and decimal columns.
"""
data = [
[["1", "2"], Decimal("3.14")],
[["5", "6"], Decimal("2.72")],
]
# Nested schema: first column is array<int>, second is decimal(10,2)
schema_str = "arr: array<int>, val: decimal(10,2)"

df = session.create_dataframe(data, schema=schema_str)
# Verify schema
assert len(df.schema.fields) == 2
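# Note: Snowflake's semi-structured ARRAY does not retain element types,
# so the round-tripped schema reports ArrayType(StringType())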
expected_fields = [
StructField("ARR", ArrayType(StringType()), nullable=True),
StructField("VAL", DecimalType(10, 2), nullable=True),
]
assert df.schema.fields == expected_fields

# Verify rows
result = df.collect()
expected_rows = [
Row(ARR='[\n "1",\n "2"\n]', VAL=Decimal("3.14")),
Row(ARR='[\n "5",\n "6"\n]', VAL=Decimal("2.72")),
]
assert result == expected_rows


def test_create_dataframe_with_explicit_struct_string(session):
"""
Test an explicit struct string "struct<colA: string, colB: double>"
to confirm it also works (even though it's not strictly 'implicit').
"""
data = [
["hello", 3.14],
["world", 2.72],
]
schema_str = "struct<colA: string, colB: double>"

df = session.create_dataframe(data, schema=schema_str)
# Verify schema
assert len(df.schema.fields) == 2
expected_fields = [
StructField("COLA", StringType(), nullable=True),
StructField("COLB", DoubleType(), nullable=True),
]
assert df.schema.fields == expected_fields

# Verify rows
result = df.collect()
expected_rows = [
Row(COLA="hello", COLB=3.14),
Row(COLA="world", COLB=2.72),
]
assert result == expected_rows


def test_create_dataframe_with_implicit_struct_malformed(session):
"""
Test malformed implicit struct string, which should raise an error.
"""
data = [[1, 2]]
# Missing type for second column
schema_str = "col1: int, col2"

with pytest.raises(ValueError) as ex_info:
session.create_dataframe(data, schema=schema_str)
# Check that the error message mentions the problem
assert (
"col2" in str(ex_info.value).lower()
), f"Unexpected error message: {ex_info.value}"


def test_create_dataframe_with_implicit_struct_datetime(session):
"""
Another example mixing basic data with boolean and dates, ensuring
the implicit struct string handles them properly.
"""
data = [
[True, datetime.date(2020, 1, 1)],
[False, datetime.date(2021, 12, 31)],
]
schema_str = "flag: boolean, d: date"

df = session.create_dataframe(data, schema=schema_str)
# Check schema
assert len(df.schema.fields) == 2
expected_fields = [
StructField("FLAG", BooleanType(), nullable=True),
StructField("D", DateType(), nullable=True),
]
assert df.schema.fields == expected_fields

# Check rows
result = df.collect()
expected_rows = [
Row(FLAG=True, D=datetime.date(2020, 1, 1)),
Row(FLAG=False, D=datetime.date(2021, 12, 31)),
]
assert result == expected_rows


def test_create_dataframe_invalid_schema_string_not_struct(session):
"""
Verifies that a non-struct schema string (e.g. "int") raises ValueError
because the resulting type is not an instance of StructType.
"""
data = [1, 2, 3]
# "int" does not represent a struct, so we expect a ValueError
with pytest.raises(ValueError) as ex_info:
session.create_dataframe(data, schema="int")

# Check that the error message mentions both "invalid schema string" and "struct type"
err_msg = str(ex_info.value).lower()
assert (
"invalid schema string" in err_msg and "struct type" in err_msg
), f"Expected error message about invalid schema string or struct type. Got: {ex_info.value}"