diff --git a/CHANGELOG.md b/CHANGELOG.md
index 18d2d7a1257..dac7ffe3aca 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -42,6 +42,7 @@
- `try_to_binary`
- Added `Catalog` class to manage snowflake objects. It can be accessed via `Session.catalog`.
+- Added support for specifying a schema string (including implicit struct syntax) when calling `Session.create_dataframe`.
#### Improvements
diff --git a/src/snowflake/snowpark/_internal/type_utils.py b/src/snowflake/snowpark/_internal/type_utils.py
index 7b071cf4adb..92d987b293a 100644
--- a/src/snowflake/snowpark/_internal/type_utils.py
+++ b/src/snowflake/snowpark/_internal/type_utils.py
@@ -969,6 +969,17 @@ def get_data_type_string_object_mappings(
STRING_RE = re.compile(r"^\s*(varchar|string|text)\s*\(\s*(\d*)\s*\)\s*$")
# support type string format like " string ( 23 ) "
+ARRAY_RE = re.compile(r"(?i)^\s*array\s*<")
+# support type string format like starting with "array<..."
+
+MAP_RE = re.compile(r"(?i)^\s*map\s*<")
+# support type string format like starting with "map<..."
+
+STRUCT_RE = re.compile(r"(?i)^\s*struct\s*<")
+# support type string format like starting with "struct<..."
+
+_NOT_NULL_PATTERN = re.compile(r"^(?P<base>.*?)\s+not\s+null\s*$", re.IGNORECASE)
+
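+# A quick sketch of what these prefixes match (illustrative only, not part of
+# the runtime path):
+#   ARRAY_RE.match("array<int>")            -> matches
+#   MAP_RE.match(" map< int , string >")    -> matches
+#   _NOT_NULL_PATTERN.match("int not null").group("base") == "int"
+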
def get_number_precision_scale(type_str: str) -> Optional[Tuple[int, int]]:
decimal_matches = DECIMAL_RE.match(type_str)
@@ -982,7 +993,169 @@ def get_string_length(type_str: str) -> Optional[int]:
return int(string_matches.group(2))
+def extract_bracket_content(type_str: str, keyword: str) -> str:
+ """
+    Given a string that starts with e.g. "array<", returns the content inside the
+    top-level <...>, e.g. "array<int>" => "int". Nested brackets are preserved,
+    e.g. "array<array<string>>" => "array<string>".
+ Raises ValueError on mismatched or missing bracket.
+ """
+ type_str = type_str.strip()
+ prefix_pattern = rf"(?i)^\s*{keyword}\s*<"
+ match = re.match(prefix_pattern, type_str)
+ if not match:
+ raise ValueError(
+ f"'{type_str}' does not match expected '{keyword}<...>' syntax."
+ )
+
+ start_index = match.end() - 1 # position at '<'
+ bracket_depth = 0
+ inside_chars: List[str] = []
+ i = start_index
+ while i < len(type_str):
+ c = type_str[i]
+ if c == "<":
+ bracket_depth += 1
+            # skip the outermost opening bracket (depth 0 -> 1);
+            # nested opening brackets are kept in 'inside_chars'
+ if bracket_depth > 1:
+ inside_chars.append(c)
+ elif c == ">":
+ bracket_depth -= 1
+ if bracket_depth < 0:
+ raise ValueError(f"Mismatched '>' in '{type_str}'")
+ if bracket_depth == 0:
+ if i != len(type_str) - 1:
+ raise ValueError(
+ f"Unexpected characters after closing '>' in '{type_str}'"
+ )
+ # done
+ return "".join(inside_chars).strip()
+ inside_chars.append(c)
+ else:
+ inside_chars.append(c)
+ i += 1
+
+ raise ValueError(f"Missing closing '>' in '{type_str}'.")
+
+
+def extract_nullable_keyword(type_str: str) -> Tuple[str, bool]:
+ """
+ Checks if `type_str` ends with something like 'NOT NULL' (ignoring
+ case and allowing arbitrary space between NOT and NULL). If found,
+ return the type substring minus that part, along with nullable=False.
+ Otherwise, return (type_str, True).
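+
+    For example (illustrative)::
+
+        extract_nullable_keyword("int NOT NULL")   # -> ("int", False)
+        extract_nullable_keyword("decimal(10,2)")  # -> ("decimal(10,2)", True)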
+ """
+ trimmed = type_str.strip()
+ match = _NOT_NULL_PATTERN.match(trimmed)
+ if match:
+ # Group 'base' is everything before 'not null'
+ base_type_str = match.group("base").strip()
+ return base_type_str, False
+
+ # By default, the field is nullable
+ return trimmed, True
+
+
+def parse_struct_field_list(fields_str: str) -> StructType:
+ """
+    Parse something like "a: int, b: string, c: array<int>"
+ into StructType([StructField('a', IntegerType()), ...]).
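+    A trailing "NOT NULL" on a field sets nullable=False, e.g.
+    "a: int not null" => StructField('a', IntegerType(), nullable=False).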
+ """
+ fields = []
+ field_defs = split_top_level_comma_fields(fields_str)
+ for field_def in field_defs:
+ # Try splitting on colon first, else whitespace
+ if ":" in field_def:
+ left, right = field_def.split(":", 1)
+ else:
+ parts = field_def.split(None, 1)
+ if len(parts) != 2:
+ raise ValueError(f"Cannot parse struct field definition: '{field_def}'")
+ left, right = parts[0], parts[1]
+
+ field_name = left.strip()
+ type_part = right.strip()
+ if not field_name:
+ raise ValueError(f"Struct field missing name in '{field_def}'")
+
+ # 1) Check for trailing "NOT NULL" => sets nullable=False
+ base_type_str, nullable = extract_nullable_keyword(type_part)
+ # 2) Parse the base type
+ field_type = type_string_to_type_object(base_type_str)
+ fields.append(StructField(field_name, field_type, nullable=nullable))
+
+ return StructType(fields)
+
+
+def split_top_level_comma_fields(s: str) -> List[str]:
+ """
+ Splits 's' by commas not enclosed in matching brackets.
+    Example: "int, array<int>, decimal(10,2)" => ["int", "array<int>", "decimal(10,2)"].
+ """
+ parts = []
+ bracket_depth = 0
+ start_idx = 0
+ for i, c in enumerate(s):
+ if c in ["<", "("]:
+ bracket_depth += 1
+ elif c in [">", ")"]:
+ bracket_depth -= 1
+ if bracket_depth < 0:
+ raise ValueError(f"Mismatched bracket in '{s}'.")
+ elif c == "," and bracket_depth == 0:
+ parts.append(s[start_idx:i].strip())
+ start_idx = i + 1
+ parts.append(s[start_idx:].strip())
+ return parts
+
+
+def is_likely_struct(s: str) -> bool:
+ """
+ Heuristic: If there's a top-level comma or colon outside brackets,
+ treat it like a struct with multiple fields, e.g. "a: int, b: string".
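+
+    For example (illustrative)::
+
+        is_likely_struct("a: int, b: string")  # True: top-level colon/comma
+        is_likely_struct("map<string, int>")   # False: comma is nested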
+ """
+ bracket_depth = 0
+ for c in s:
+ if c in ["<", "("]:
+ bracket_depth += 1
+ elif c in [">", ")"]:
+ bracket_depth -= 1
+ elif (c in [":", ","]) and bracket_depth == 0:
+ return True
+ return False
+
+
def type_string_to_type_object(type_str: str) -> DataType:
+ type_str = type_str.strip()
+ if not type_str:
+ raise ValueError("Empty type string")
+
+ # First check if this might be a top-level multi-field struct
+ # (e.g. "a: int, b: string") even if not written as "struct<...>"
+ if is_likely_struct(type_str):
+ return parse_struct_field_list(type_str)
+
+ # Check for array<...>
+ if ARRAY_RE.match(type_str):
+ inner = extract_bracket_content(type_str, "array")
+ element_type = type_string_to_type_object(inner)
+ return ArrayType(element_type)
+
+    # Check for map<...>
+ if MAP_RE.match(type_str):
+ inner = extract_bracket_content(type_str, "map")
+ parts = split_top_level_comma_fields(inner)
+ if len(parts) != 2:
+ raise ValueError(f"Invalid map type definition: '{type_str}'")
+ key_type = type_string_to_type_object(parts[0])
+ val_type = type_string_to_type_object(parts[1])
+ return MapType(key_type, val_type)
+
+ # Check for explicit struct<...>
+ if STRUCT_RE.match(type_str):
+ inner = extract_bracket_content(type_str, "struct")
+ return parse_struct_field_list(inner)
+
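+    # Otherwise fall through to the pre-existing scalar handling below,
+    # e.g. "decimal(10,2)" -> DecimalType(10, 2) via DECIMAL_RE, and
+    # "string(50)" -> StringType via STRING_RE.
+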
precision_scale = get_number_precision_scale(type_str)
if precision_scale:
return DecimalType(*precision_scale)
diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py
index e63274b5069..113e07443ca 100644
--- a/src/snowflake/snowpark/session.py
+++ b/src/snowflake/snowpark/session.py
@@ -95,6 +95,7 @@
infer_schema,
infer_type,
merge_type,
+ type_string_to_type_object,
)
from snowflake.snowpark._internal.udf_utils import generate_call_python_sp_sql
from snowflake.snowpark._internal.utils import (
@@ -3029,7 +3030,7 @@ def write_pandas(
def create_dataframe(
self,
data: Union[List, Tuple, "pandas.DataFrame"],
- schema: Optional[Union[StructType, Iterable[str]]] = None,
+ schema: Optional[Union[StructType, Iterable[str], str]] = None,
_emit_ast: bool = True,
) -> DataFrame:
"""Creates a new DataFrame containing the specified values from the local data.
@@ -3046,9 +3047,15 @@ def create_dataframe(
``data`` will constitute a row in the DataFrame.
schema: A :class:`~snowflake.snowpark.types.StructType` containing names and
data types of columns, or a list of column names, or ``None``.
- When ``schema`` is a list of column names or ``None``, the schema of the
- DataFrame will be inferred from the data across all rows. To improve
- performance, provide a schema. This avoids the need to infer data types
+
+            - When ``schema`` is a **string**, it can be either an *explicit* struct
+              (e.g. ``"struct<a: int, b: string>"``) or an *implicit* struct
+              (e.g. ``"a: int, b: string"``). Internally, the string is parsed and
+              converted into a :class:`StructType` using Snowpark's type parsing.
+ - When ``schema`` is a list of column names or ``None``, the schema of the
+ DataFrame will be inferred from the data across all rows.
+
+ To improve performance, provide a schema. This avoids the need to infer data types
with large data sets.
Examples::
@@ -3078,6 +3085,10 @@ def create_dataframe(
>>> session.create_dataframe(pd.DataFrame([(1, 2, 3, 4)], columns=["a", "b", "c", "d"])).collect()
[Row(a=1, b=2, c=3, d=4)]
+ >>> # create a dataframe using an implicit struct schema string
+ >>> session.create_dataframe([[10, 20], [30, 40]], schema="x: int, y: int").collect()
+ [Row(X=10, Y=20), Row(X=30, Y=40)]
+
Note:
When `data` is a pandas DataFrame, `snowflake.connector.pandas_tools.write_pandas` is called, which
requires permission to (1) CREATE STAGE (2) CREATE TABLE and (3) CREATE FILE FORMAT under the current
@@ -3156,6 +3167,13 @@ def create_dataframe(
# infer the schema based on the data
names = None
schema_query = None
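+        # A schema string (e.g. "a: int, b: string" or "struct<a: int>") is
+        # parsed into a StructType first, then handled by the regular
+        # StructType branch below.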
+ if isinstance(schema, str):
+ schema = type_string_to_type_object(schema)
+ if not isinstance(schema, StructType):
+ raise ValueError(
+ f"Invalid schema string: {schema}. "
+ f"You should provide a valid schema string representing a struct type."
+ )
if isinstance(schema, StructType):
new_schema = schema
# SELECT query has an undefined behavior for nullability, so if the schema requires non-nullable column and
diff --git a/tests/integ/test_dataframe.py b/tests/integ/test_dataframe.py
index db861ad4a95..cbb282e7a85 100644
--- a/tests/integ/test_dataframe.py
+++ b/tests/integ/test_dataframe.py
@@ -4479,3 +4479,277 @@ def test_SNOW_1879403_replace_with_lit(session):
).collect()
Utils.check_answer(ans, [Row("orange"), Row("orange pie"), Row("orange juice")])
+
+
+def test_create_dataframe_with_implicit_struct_simple(session):
+ """
+ Test an implicit struct string with two integer columns.
+ """
+ data = [
+ [1, 2],
+ [3, 4],
+ ]
+ # The new feature: implicit struct string "col1: int, col2: int"
+ schema_str = "col1: int, col2: int"
+
+ # Create the dataframe
+ df = session.create_dataframe(data, schema=schema_str)
+ # Check schema
+ # We expect the schema to be a StructType with 2 fields
+ assert isinstance(df.schema, StructType)
+ assert len(df.schema.fields) == 2
+ expected_fields = [
+ StructField("COL1", LongType(), nullable=True),
+ StructField("COL2", LongType(), nullable=True),
+ ]
+ assert df.schema.fields == expected_fields
+
+ # Collect rows
+ result = df.collect()
+ expected_rows = [
+ Row(COL1=1, COL2=2),
+ Row(COL1=3, COL2=4),
+ ]
+ assert result == expected_rows
+
+
+def test_create_dataframe_with_implicit_struct_nested(session):
+ """
+ Test an implicit struct string with nested array and decimal columns.
+ """
+ data = [
+ [["1", "2"], Decimal("3.14")],
+ [["5", "6"], Decimal("2.72")],
+ ]
+    # Nested schema: first column is array<string>, second is decimal(10,2)
+    schema_str = "arr: array<string>, val: decimal(10,2)"
+
+ df = session.create_dataframe(data, schema=schema_str)
+ # Verify schema
+ assert len(df.schema.fields) == 2
+ expected_fields = [
+ StructField("ARR", ArrayType(StringType()), nullable=True),
+ StructField("VAL", DecimalType(10, 2), nullable=True),
+ ]
+ assert df.schema.fields == expected_fields
+
+ # Verify rows
+ result = df.collect()
+ expected_rows = [
+ Row(ARR='[\n "1",\n "2"\n]', VAL=Decimal("3.14")),
+ Row(ARR='[\n "5",\n "6"\n]', VAL=Decimal("2.72")),
+ ]
+ assert result == expected_rows
+
+
+def test_create_dataframe_with_explicit_struct_string(session):
+ """
+    Test an explicit struct string "struct<colA: string, colB: double>"
+ to confirm it also works (even though it's not strictly 'implicit').
+ """
+ data = [
+ ["hello", 3.14],
+ ["world", 2.72],
+ ]
+ schema_str = "struct"
+
+ df = session.create_dataframe(data, schema=schema_str)
+ # Verify schema
+ assert len(df.schema.fields) == 2
+ expected_fields = [
+ StructField("COLA", StringType(), nullable=True),
+ StructField("COLB", DoubleType(), nullable=True),
+ ]
+ assert df.schema.fields == expected_fields
+
+ # Verify rows
+ result = df.collect()
+ expected_rows = [
+ Row(COLA="hello", COLB=3.14),
+ Row(COLA="world", COLB=2.72),
+ ]
+ assert result == expected_rows
+
+
+def test_create_dataframe_with_implicit_struct_malformed(session):
+ """
+ Test malformed implicit struct string, which should raise an error.
+ """
+ data = [[1, 2]]
+ # Missing type for second column
+ schema_str = "col1: int, col2"
+
+ with pytest.raises(ValueError) as ex_info:
+ session.create_dataframe(data, schema=schema_str)
+ # Check that the error message mentions the problem
+ assert (
+ "col2" in str(ex_info.value).lower()
+ ), f"Unexpected error message: {ex_info.value}"
+
+
+def test_create_dataframe_with_implicit_struct_datetime(session):
+ """
+ Another example mixing basic data with boolean and dates, ensuring
+ the implicit struct string handles them properly.
+ """
+ data = [
+ [True, datetime.date(2020, 1, 1)],
+ [False, datetime.date(2021, 12, 31)],
+ ]
+ schema_str = "flag: boolean, d: date"
+
+ df = session.create_dataframe(data, schema=schema_str)
+ # Check schema
+ assert len(df.schema.fields) == 2
+ expected_fields = [
+ StructField("FLAG", BooleanType(), nullable=True),
+ StructField("D", DateType(), nullable=True),
+ ]
+ assert df.schema.fields == expected_fields
+
+ # Check rows
+ result = df.collect()
+ expected_rows = [
+ Row(FLAG=True, D=datetime.date(2020, 1, 1)),
+ Row(FLAG=False, D=datetime.date(2021, 12, 31)),
+ ]
+ assert result == expected_rows
+
+
+def test_create_dataframe_invalid_schema_string_not_struct(session):
+ """
+ Verifies that a non-struct schema string (e.g. "int") raises ValueError
+ because the resulting type is not an instance of StructType.
+ """
+ data = [1, 2, 3]
+ # "int" does not represent a struct, so we expect a ValueError
+ with pytest.raises(ValueError) as ex_info:
+ session.create_dataframe(data, schema="int")
+
+ # Check that the error message mentions "Invalid schema string" or "struct type"
+ err_msg = str(ex_info.value).lower()
+ assert (
+ "invalid schema string" in err_msg and "struct type" in err_msg
+ ), f"Expected error message about invalid schema string or struct type. Got: {ex_info.value}"
+
+
+def test_create_dataframe_implicit_struct_not_null_single(session):
+ """
+ Test a schema with one NOT NULL field.
+ """
+ data = [
+ [1],
+ [2],
+ ]
+ # One field 'col1: int not null'
+ schema_str = "col1: int NOT NULL"
+
+ df = session.create_dataframe(data, schema=schema_str)
+ # Verify schema
+ assert isinstance(df.schema, StructType)
+ assert len(df.schema.fields) == 1
+
+ expected_field = StructField("COL1", LongType(), nullable=False)
+ assert df.schema.fields[0] == expected_field
+
+ # Collect rows
+ result = df.collect()
+ expected_rows = [Row(COL1=1), Row(COL1=2)]
+ assert result == expected_rows
+
+
+def test_create_dataframe_implicit_struct_not_null_multiple(session):
+ """
+ Test a schema with multiple fields, one of which is NOT NULL.
+ """
+ data = [
+ [10, "foo"],
+ [20, "bar"],
+ ]
+ schema_str = "col1: int not null, col2: string"
+
+ df = session.create_dataframe(data, schema=schema_str)
+ # Verify schema
+ assert len(df.schema.fields) == 2
+
+ expected_fields = [
+ StructField("COL1", LongType(), nullable=False),
+ StructField("COL2", StringType(), nullable=True),
+ ]
+ assert df.schema.fields == expected_fields
+
+ # Verify rows
+ result = df.collect()
+ expected_rows = [
+ Row(COL1=10, COL2="foo"),
+ Row(COL1=20, COL2="bar"),
+ ]
+ assert result == expected_rows
+
+
+def test_create_dataframe_implicit_struct_not_null_nested(session):
+ """
+ Test a schema with nested array and a NOT NULL decimal field.
+ """
+ data = [
+ [["1", "2"], Decimal("3.14")],
+ [["5", "6"], Decimal("2.72")],
+ ]
+ schema_str = "arr: array, val: decimal(10,2) NOT NULL"
+
+ df = session.create_dataframe(data, schema=schema_str)
+ # Verify schema
+ assert len(df.schema.fields) == 2
+
+ expected_fields = [
+ StructField("ARR", ArrayType(StringType()), nullable=True),
+ StructField("VAL", DecimalType(10, 2), nullable=False),
+ ]
+ assert df.schema.fields == expected_fields
+
+ # Verify rows
+ result = df.collect()
+ expected_rows = [
+ Row(ARR='[\n "1",\n "2"\n]', VAL=Decimal("3.14")),
+ Row(ARR='[\n "5",\n "6"\n]', VAL=Decimal("2.72")),
+ ]
+ assert result == expected_rows
+
+
+def test_create_dataframe_implicit_struct_not_null_mixed(session):
+ """
+ Test a schema mixing NOT NULL columns with normal columns,
+ plus various data types like boolean or date.
+ """
+ data = [
+ [True, datetime.date(2020, 1, 1), "Hello"],
+ [False, datetime.date(2021, 1, 2), "World"],
+ ]
+ schema_str = "flag: boolean not null, dt: date, txt: string not null"
+
+ df = session.create_dataframe(data, schema=schema_str)
+ # Verify schema
+ assert len(df.schema.fields) == 3
+
+ expected_fields = [
+ StructField("FLAG", BooleanType(), nullable=False),
+ StructField("DT", df.schema.fields[1].datatype, nullable=True),
+ StructField("TXT", StringType(), nullable=False),
+ ]
+
+ assert df.schema.fields == expected_fields
+
+ # Verify rows
+ result = df.collect()
+ expected_rows = [
+ Row(FLAG=True, DT=datetime.date(2020, 1, 1), TXT="Hello"),
+ Row(FLAG=False, DT=datetime.date(2021, 1, 2), TXT="World"),
+ ]
+ assert result == expected_rows
+
+
+def test_create_dataframe_implicit_struct_not_null_invalid(session):
+ data = [1, 2, 3]
+ schema_str = "int not null" # not a struct => ValueError
+ with pytest.raises(ValueError, match="'intnotnull' is not a supported type"):
+ session.create_dataframe(data, schema=schema_str)
diff --git a/tests/unit/test_types.py b/tests/unit/test_types.py
index 7085d3ef6d5..dd479fe855e 100644
--- a/tests/unit/test_types.py
+++ b/tests/unit/test_types.py
@@ -43,6 +43,12 @@
retrieve_func_defaults_from_source,
retrieve_func_type_hints_from_source,
snow_type_to_dtype_str,
+ type_string_to_type_object,
+ is_likely_struct,
+ parse_struct_field_list,
+ split_top_level_comma_fields,
+ extract_bracket_content,
+ extract_nullable_keyword,
)
from snowflake.snowpark.types import (
ArrayType,
@@ -1457,3 +1463,459 @@ def test_maptype_alias():
assert tpe.valueType == tpe.value_type
assert tpe.keyType == tpe.key_type
+
+
+def test_type_string_to_type_object_basic_int():
+ dt = type_string_to_type_object("int")
+ assert isinstance(dt, IntegerType), f"Expected IntegerType, got {dt}"
+
+
+def test_type_string_to_type_object_smallint():
+ dt = type_string_to_type_object("smallint")
+ assert isinstance(dt, ShortType), f"Expected ShortType, got {dt}"
+
+
+def test_type_string_to_type_object_byteint():
+ dt = type_string_to_type_object("byteint")
+ assert isinstance(dt, ByteType), f"Expected ByteType, got {dt}"
+
+
+def test_type_string_to_type_object_bigint():
+ dt = type_string_to_type_object("bigint")
+ assert isinstance(dt, LongType), f"Expected LongType, got {dt}"
+
+
+def test_type_string_to_type_object_number_decimal():
+ # For number(precision, scale) => DecimalType
+ dt = type_string_to_type_object("number(10,2)")
+ assert isinstance(dt, DecimalType), f"Expected DecimalType, got {dt}"
+ assert dt.precision == 10, f"Expected precision=10, got {dt.precision}"
+ assert dt.scale == 2, f"Expected scale=2, got {dt.scale}"
+
+
+def test_type_string_to_type_object_numeric_decimal():
+ dt = type_string_to_type_object("numeric(20, 5)")
+ assert isinstance(dt, DecimalType), f"Expected DecimalType, got {dt}"
+ assert dt.precision == 20, f"Expected precision=20, got {dt.precision}"
+ assert dt.scale == 5, f"Expected scale=5, got {dt.scale}"
+
+
+def test_type_string_to_type_object_decimal_spaces():
+ # Check spaces inside parentheses
+ dt = type_string_to_type_object(" decimal ( 2 , 1 ) ")
+ assert isinstance(dt, DecimalType), f"Expected DecimalType, got {dt}"
+ assert dt.precision == 2, f"Expected precision=2, got {dt.precision}"
+ assert dt.scale == 1, f"Expected scale=1, got {dt.scale}"
+
+
+def test_type_string_to_type_object_string_with_length():
+ dt = type_string_to_type_object("string(50)")
+ assert isinstance(dt, StringType), f"Expected StringType, got {dt}"
+ # Snowpark's StringType typically doesn't store length internally,
+    # but type_string_to_type_object returns StringType(50) here, so check
+    # the length when the attribute is present.
+ if hasattr(dt, "length"):
+ assert dt.length == 50, f"Expected length=50, got {dt.length}"
+
+
+def test_type_string_to_type_object_text_with_length():
+ dt = type_string_to_type_object("text(100)")
+ assert isinstance(dt, StringType), f"Expected StringType, got {dt}"
+ if hasattr(dt, "length"):
+ assert dt.length == 100, f"Expected length=100, got {dt.length}"
+
+
+def test_type_string_to_type_object_array_of_int():
+ dt = type_string_to_type_object("array")
+ assert isinstance(dt, ArrayType), f"Expected ArrayType, got {dt}"
+ assert isinstance(
+ dt.element_type, IntegerType
+ ), f"Expected element_type=IntegerType, got {dt.element_type}"
+
+
+def test_type_string_to_type_object_array_of_decimal():
+ dt = type_string_to_type_object("array")
+ assert isinstance(dt, ArrayType), f"Expected ArrayType, got {dt}"
+ assert isinstance(
+ dt.element_type, DecimalType
+ ), f"Expected element_type=DecimalType, got {dt.element_type}"
+ assert dt.element_type.precision == 10
+ assert dt.element_type.scale == 2
+
+ try:
+ type_string_to_type_object("array")
+ raise AssertionError("Expected ValueError for not a supported type")
+ except ValueError as ex:
+ assert "is not a supported type" in str(
+ ex
+ ), f"Expected not a supported type, got: {ex}"
+
+
+def test_type_string_to_type_object_map_of_int_string():
+ dt = type_string_to_type_object("map")
+ assert isinstance(dt, MapType), f"Expected MapType, got {dt}"
+ assert isinstance(
+ dt.key_type, IntegerType
+ ), f"Expected key_type=IntegerType, got {dt.key_type}"
+ assert isinstance(
+ dt.value_type, StringType
+ ), f"Expected value_type=StringType, got {dt.value_type}"
+
+
+def test_type_string_to_type_object_map_of_array_decimal():
+ dt = type_string_to_type_object("map< array, decimal(12,5)>")
+ assert isinstance(dt, MapType), f"Expected MapType, got {dt}"
+ assert isinstance(
+ dt.key_type, ArrayType
+ ), f"Expected key_type=ArrayType, got {dt.key_type}"
+ assert isinstance(
+ dt.key_type.element_type, IntegerType
+ ), f"Expected key_type.element_type=IntegerType, got {dt.key_type.element_type}"
+ assert isinstance(
+ dt.value_type, DecimalType
+ ), f"Expected value_type=DecimalType, got {dt.value_type}"
+ assert dt.value_type.precision == 12
+ assert dt.value_type.scale == 5
+
+
+def test_type_string_to_type_object_explicit_struct_simple():
+ dt = type_string_to_type_object("struct")
+ assert isinstance(dt, StructType), f"Expected StructType, got {dt}"
+ assert len(dt.fields) == 2, f"Expected 2 fields, got {len(dt.fields)}"
+
+ # Now assert exact StructField matches
+ expected_field_a = StructField("a", IntegerType(), nullable=True)
+ expected_field_b = StructField("b", StringType(), nullable=True)
+ assert (
+ dt.fields[0] == expected_field_a
+ ), f"Expected {expected_field_a}, got {dt.fields[0]}"
+ assert (
+ dt.fields[1] == expected_field_b
+ ), f"Expected {expected_field_b}, got {dt.fields[1]}"
+
+
+def test_type_string_to_type_object_explicit_struct_nested():
+    dt = type_string_to_type_object(
+        "struct<x: array<int>, y: map<string, decimal(5,2)>>"
+    )
+ )
+ assert isinstance(dt, StructType), f"Expected StructType, got {dt}"
+ assert len(dt.fields) == 2, f"Expected 2 fields, got {len(dt.fields)}"
+
+ # Check each field directly against StructField(...)
+ expected_field_x = StructField("x", ArrayType(IntegerType()), nullable=True)
+ expected_field_y = StructField(
+ "y", MapType(StringType(), DecimalType(5, 2)), nullable=True
+ )
+
+ assert (
+ dt.fields[0] == expected_field_x
+ ), f"Expected {expected_field_x}, got {dt.fields[0]}"
+ assert (
+ dt.fields[1] == expected_field_y
+ ), f"Expected {expected_field_y}, got {dt.fields[1]}"
+
+
+def test_type_string_to_type_object_unknown_type():
+ try:
+ type_string_to_type_object("unknown_type")
+ raise AssertionError("Expected ValueError for unknown type")
+ except ValueError as ex:
+ assert "unknown_type" in str(
+ ex
+ ), f"Error message doesn't mention 'unknown_type': {ex}"
+
+
+def test_type_string_to_type_object_mismatched_bracket_array():
+ try:
+ type_string_to_type_object("array>"))
+ raise AssertionError("Expected ValueError for mismatched bracket")
+ except ValueError as ex:
+ assert "Unexpected characters after closing '>' in" in str(
+ ex
+ ), f"Expected Unexpected characters after closing '>' error, got: {ex}"
+
+
+def test_type_string_to_type_object_bad_decimal():
+ try:
+ type_string_to_type_object("decimal(10,2,5)")
+ raise AssertionError("Expected ValueError for a malformed decimal argument")
+ except ValueError:
+ # "decimal(10,2,5)" doesn't match the DECIMAL_RE regex => unknown type => ValueError
+ pass
+
+
+def test_type_string_to_type_object_bad_struct_syntax():
+ try:
+ type_string_to_type_object("struct'.
+ """
+ dt = type_string_to_type_object("a: int, b: string")
+ assert isinstance(dt, StructType), f"Expected StructType, got {dt}"
+ assert len(dt.fields) == 2, f"Expected 2 fields, got {len(dt.fields)}"
+
+ expected_field_a = StructField("a", IntegerType(), nullable=True)
+ expected_field_b = StructField("b", StringType(), nullable=True)
+
+ assert (
+ dt.fields[0] == expected_field_a
+ ), f"Expected {expected_field_a}, got {dt.fields[0]}"
+ assert (
+ dt.fields[1] == expected_field_b
+ ), f"Expected {expected_field_b}, got {dt.fields[1]}"
+
+
+def test_type_string_to_type_object_implicit_struct_single_field():
+ """
+    Even a single 'name: type' with no commas should parse to StructType,
+    since the parser treats it as an implicit struct.
+ """
+ dt = type_string_to_type_object("c: decimal(10,2)")
+ assert isinstance(dt, StructType), f"Expected StructType, got {dt}"
+ assert len(dt.fields) == 1, f"Expected 1 field, got {len(dt.fields)}"
+
+ expected_field_c = StructField("c", DecimalType(10, 2), nullable=True)
+ assert (
+ dt.fields[0] == expected_field_c
+ ), f"Expected {expected_field_c}, got {dt.fields[0]}"
+
+
+def test_type_string_to_type_object_implicit_struct_nested():
+ """
+ Test an implicit struct with multiple fields,
+ including nested array/map types.
+ """
+ dt = type_string_to_type_object("arr: array, kv: map")
+ assert isinstance(dt, StructType), f"Expected StructType, got {dt}"
+ assert len(dt.fields) == 2, f"Expected 2 fields, got {len(dt.fields)}"
+
+ expected_field_arr = StructField("arr", ArrayType(IntegerType()), nullable=True)
+ expected_field_kv = StructField(
+ "kv", MapType(StringType(), DecimalType(5, 2)), nullable=True
+ )
+
+ assert (
+ dt.fields[0] == expected_field_arr
+ ), f"Expected {expected_field_arr}, got {dt.fields[0]}"
+ assert (
+ dt.fields[1] == expected_field_kv
+ ), f"Expected {expected_field_kv}, got {dt.fields[1]}"
+
+
+def test_type_string_to_type_object_implicit_struct_with_spaces():
+ """
+ Test spacing variations. E.g. " col1 : int , col2 : map< string , decimal(5,2) > ".
+ """
+ dt = type_string_to_type_object(
+ " col1 : int , col2 : map< string , decimal( 5 , 2 ) > "
+ )
+ assert isinstance(dt, StructType), f"Expected StructType, got {dt}"
+ assert len(dt.fields) == 2, f"Expected 2 fields, got {len(dt.fields)}"
+
+ expected_field_col1 = StructField("col1", IntegerType(), nullable=True)
+ expected_field_col2 = StructField(
+ "col2", MapType(StringType(), DecimalType(5, 2)), nullable=True
+ )
+
+ assert (
+ dt.fields[0] == expected_field_col1
+ ), f"Expected {expected_field_col1}, got {dt.fields[0]}"
+ assert (
+ dt.fields[1] == expected_field_col2
+ ), f"Expected {expected_field_col2}, got {dt.fields[1]}"
+
+
+def test_type_string_to_type_object_implicit_struct_error():
+ """
+ Check a malformed implicit struct that should raise ValueError
+ (e.g. trailing comma or missing bracket for nested).
+ """
+ try:
+ type_string_to_type_object("a: int, b:")
+ raise AssertionError("Expected ValueError for malformed struct (b: )")
+ except ValueError as ex:
+ # We expect an error message about Empty type string
+ assert "Empty type string" in str(
+ ex
+ ), f"Expected error 'Empty type string', got: {ex}"
+
+ try:
+ type_string_to_type_object("arr: array'
+ assert "Missing closing" in str(
+ ex
+ ), f"Expected Missing closing error, got: {ex}"
+
+
+def test_extract_bracket_content_array_ok():
+ s = "array"
+ # We expect to extract "int" from inside <...>
+ content = extract_bracket_content(s, keyword="array")
+ assert content == "int", f"Expected 'int', got {content}"
+
+
+def test_extract_bracket_content_map_spaces():
+ s = " map< int , string >"
+ content = extract_bracket_content(s, keyword="map")
+ assert content == "int , string", f"Expected 'int , string', got {content}"
+
+
+def test_extract_bracket_content_missing_closing():
+ s = "array'")
+ except ValueError as ex:
+ assert (
+ "Missing closing" in str(ex) or "mismatched" in str(ex).lower()
+ ), f"Error does not mention missing bracket: {ex}"
+
+
+def test_extract_bracket_content_mismatched_extra_close():
+ s = "struct>"
+ try:
+ extract_bracket_content(s, keyword="struct")
+ raise AssertionError("Expected ValueError for extra '>'")
+ except ValueError as ex:
+ assert "Unexpected characters after closing '>' in" in str(
+ ex
+ ), f"Error does not mention Unexpected characters after closing '>' in: {ex}"
+
+
+def test_split_top_level_comma_fields_no_brackets():
+ s = "int, string, decimal(10,2)"
+ parts = split_top_level_comma_fields(s)
+ assert parts == ["int", "string", "decimal(10,2)"], f"Got unexpected parts: {parts}"
+
+
+def test_split_top_level_comma_fields_nested_brackets():
+ s = "int, array, decimal(10,2), map>"
+ parts = split_top_level_comma_fields(s)
+ assert parts == [
+ "int",
+ "array",
+ "decimal(10,2)",
+ "map>",
+ ], f"Got unexpected parts: {parts}"
+
+
+def test_parse_struct_field_list_simple():
+ s = "a: int, b: string"
+ struct_type = parse_struct_field_list(s)
+ assert (
+ len(struct_type.fields) == 2
+ ), f"Expected 2 fields, got {len(struct_type.fields)}"
+ # Direct equality checks on each StructField
+ from snowflake.snowpark.types import StructField, IntegerType, StringType
+
+ assert struct_type.fields[0] == StructField("a", IntegerType(), nullable=True)
+ assert struct_type.fields[1] == StructField("b", StringType(), nullable=True)
+
+
+def test_parse_struct_field_list_malformed():
+ s = "col1: int, col2"
+ try:
+ parse_struct_field_list(s)
+ raise AssertionError("Expected ValueError for missing type in 'col2'")
+ except ValueError as ex:
+ assert (
+ "Cannot parse struct field definition" in str(ex)
+ or "missing" in str(ex).lower()
+ ), f"Unexpected error message: {ex}"
+
+
+def test_is_likely_struct_true():
+ # top-level colon => likely struct
+ s = "a: int, b: string"
+ assert is_likely_struct(s) is True, "Expected True for struct-like string"
+
+
+def test_is_likely_struct_false():
+ # No top-level colon or comma => not a struct
+ s = "array"
+ assert is_likely_struct(s) is False, "Expected False for non-struct string"
+
+
+def test_extract_nullable_keyword_no_not_null():
+ """
+ Verifies that if there's no NOT NULL keyword, the function
+ returns the original string and nullable=True.
+ """
+ base_str, is_nullable = extract_nullable_keyword("integer")
+ assert base_str == "integer"
+ assert is_nullable is True
+
+
+def test_extract_nullable_keyword_case_insensitive():
+ """
+ Verifies that NOT NULL is matched regardless of case,
+ and the returned base_str excludes that portion.
+ """
+ base_str, is_nullable = extract_nullable_keyword("INT NOT NULL")
+ assert base_str == "INT"
+ assert is_nullable is False
+
+
+def test_extract_nullable_keyword_weird_spacing():
+ """
+ Verifies that arbitrary spacing in 'not null' is handled,
+ returning the correct base_str and nullable=False.
+ """
+ base_str, is_nullable = extract_nullable_keyword("decimal(10,2) not null")
+ assert base_str == "decimal(10,2)"
+ assert is_nullable is False
+
+
+def test_extract_nullable_keyword_random_case():
+ """
+ Verifies that random case usage like 'NoT nUlL' is detected,
+ returning nullable=False.
+ """
+ base_str, is_nullable = extract_nullable_keyword("decimal(10,2) NoT nUlL")
+ assert base_str == "decimal(10,2)"
+ assert is_nullable is False
+
+
+def test_extract_nullable_keyword_with_leading_trailing_spaces():
+ """
+ Verifies leading/trailing whitespace is stripped properly,
+ and the base_str excludes 'not null'.
+ """
+ base_str, is_nullable = extract_nullable_keyword(" decimal(10,2) not null ")
+ assert base_str == "decimal(10,2)"
+ assert is_nullable is False
+
+
+def test_extract_nullable_keyword_mix_of_no_keywords():
+ """
+ If there's a keyword 'null' alone (no 'not'),
+ it's not recognized by this pattern, so we treat it as normal text.
+ """
+ base_str, is_nullable = extract_nullable_keyword("mytype null")
+ # This doesn't match 'NOT NULL', so it returns original string with is_nullable=True
+ assert base_str == "mytype null"
+ assert is_nullable is True