From c77383766de0afaee2d48604c43f69157f72fad4 Mon Sep 17 00:00:00 2001 From: Jianzhun Du Date: Mon, 6 Jan 2025 22:04:14 -0800 Subject: [PATCH 1/3] init --- CHANGELOG.md | 1 + .../snowpark/_internal/type_utils.py | 149 +++++++ src/snowflake/snowpark/session.py | 28 +- tests/integ/test_dataframe.py | 152 +++++++ tests/unit/test_types.py | 392 ++++++++++++++++++ 5 files changed, 717 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b48523e8f51..81cc232aee6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ - `nullifzero` - `snowflake_cortex_sentiment` - Added `Catalog` class to manage snowflake objects. It can be accessed via `Session.catalog`. +- Added support for specifying a schema string (including implicit struct syntax) when calling `Session.create_dataframe`. #### Improvements diff --git a/src/snowflake/snowpark/_internal/type_utils.py b/src/snowflake/snowpark/_internal/type_utils.py index 0910a2a4aae..fcab1b1d6f7 100644 --- a/src/snowflake/snowpark/_internal/type_utils.py +++ b/src/snowflake/snowpark/_internal/type_utils.py @@ -966,6 +966,15 @@ def get_data_type_string_object_mappings( STRING_RE = re.compile(r"^\s*(varchar|string|text)\s*\(\s*(\d*)\s*\)\s*$") # support type string format like " string ( 23 ) " +ARRAY_RE = re.compile(r"(?i)^\s*array\s*<") +# support type string format like starting with "array<..." + +MAP_RE = re.compile(r"(?i)^\s*map\s*<") +# support type string format like starting with "map<..." + +STRUCT_RE = re.compile(r"(?i)^\s*struct\s*<") +# support type string format like starting with "array<..." starting with "struct<..." + def get_number_precision_scale(type_str: str) -> Optional[Tuple[int, int]]: decimal_matches = DECIMAL_RE.match(type_str) @@ -979,7 +988,147 @@ def get_string_length(type_str: str) -> Optional[int]: return int(string_matches.group(2)) +def extract_bracket_content(type_str: str, keyword: str) -> str: + """ + Given a string that starts with e.g. 
'array<', returns the content inside the top-level <...>. + e.g., "array" => "int" + Raises ValueError on mismatched or missing bracket. + """ + prefix_pattern = rf"(?i)^\s*{keyword}\s*<" + match = re.match(prefix_pattern, type_str) + if not match: + raise ValueError( + f"'{type_str}' does not match expected '{keyword}<...>' syntax." + ) + + start_index = match.end() - 1 # position at '<' + bracket_depth = 0 + inside_chars: List[str] = [] + i = start_index + while i < len(type_str): + c = type_str[i] + if c == "<": + bracket_depth += 1 + # we don't store the opening bracket in 'inside_chars' + # if bracket_depth was 0 -> 1, to skip the outer bracket + if bracket_depth > 1: + inside_chars.append(c) + elif c == ">": + bracket_depth -= 1 + if bracket_depth < 0: + raise ValueError(f"Mismatched '>' in '{type_str}'") + if bracket_depth == 0: + if i != len(type_str) - 1: + raise ValueError( + f"Unexpected characters after closing '>' in '{type_str}'" + ) + # done + return "".join(inside_chars).strip() + inside_chars.append(c) + else: + inside_chars.append(c) + i += 1 + + raise ValueError(f"Missing closing '>' in '{type_str}'.") + + +def parse_struct_field_list(fields_str: str) -> StructType: + """ + Parse something like "a: int, b: string, c: array" + into StructType([StructField('a', IntegerType()), ...]). 
+ """ + fields = [] + field_defs = split_top_level_comma_fields(fields_str) + for field_def in field_defs: + # Try splitting on colon first, else whitespace + if ":" in field_def: + left, right = field_def.split(":", 1) + else: + parts = field_def.split(None, 1) + if len(parts) != 2: + raise ValueError(f"Cannot parse struct field definition: '{field_def}'") + left, right = parts[0], parts[1] + + field_name = left.strip() + type_part = right.strip() + if not field_name: + raise ValueError(f"Struct field missing name in '{field_def}'") + + field_type = type_string_to_type_object(type_part) + fields.append(StructField(field_name, field_type, nullable=True)) + + return StructType(fields) + + +def split_top_level_comma_fields(s: str) -> List[str]: + """ + Splits 's' by commas not enclosed in matching brackets. + Example: "int, array, decimal(10,2)" => ["int", "array", "decimal(10,2)"]. + """ + parts = [] + bracket_depth = 0 + start_idx = 0 + for i, c in enumerate(s): + if c in ["<", "("]: + bracket_depth += 1 + elif c in [">", ")"]: + bracket_depth -= 1 + if bracket_depth < 0: + raise ValueError(f"Mismatched bracket in '{s}'.") + elif c == "," and bracket_depth == 0: + parts.append(s[start_idx:i].strip()) + start_idx = i + 1 + parts.append(s[start_idx:].strip()) + return parts + + +def is_likely_struct(s: str) -> bool: + """ + Heuristic: If there's a top-level comma or colon outside brackets, + treat it like a struct with multiple fields, e.g. "a: int, b: string". + """ + bracket_depth = 0 + for c in s: + if c in ["<", "("]: + bracket_depth += 1 + elif c in [">", ")"]: + bracket_depth -= 1 + elif (c in [":", ","]) and bracket_depth == 0: + return True + return False + + def type_string_to_type_object(type_str: str) -> DataType: + type_str = type_str.strip() + if not type_str: + raise ValueError("Empty type string") + + # First check if this might be a top-level multi-field struct + # (e.g. 
"a: int, b: string") even if not written as "struct<...>" + if is_likely_struct(type_str): + return parse_struct_field_list(type_str) + + # Check for array<...> + if ARRAY_RE.match(type_str): + inner = extract_bracket_content(type_str, "array") + element_type = type_string_to_type_object(inner) + return ArrayType(element_type) + + # Check for map + if MAP_RE.match(type_str): + inner = extract_bracket_content(type_str, "map") + parts = split_top_level_comma_fields(inner) + if len(parts) != 2: + raise ValueError(f"Invalid map type definition: '{type_str}'") + key_type = type_string_to_type_object(parts[0]) + val_type = type_string_to_type_object(parts[1]) + return MapType(key_type, val_type) + + # Check for explicit struct<...> + if STRUCT_RE.match(type_str): + inner = extract_bracket_content(type_str, "struct") + return parse_struct_field_list(inner) + precision_scale = get_number_precision_scale(type_str) if precision_scale: return DecimalType(*precision_scale) diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 88d2a4a32d0..90496aab269 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -93,6 +93,7 @@ infer_schema, infer_type, merge_type, + type_string_to_type_object, ) from snowflake.snowpark._internal.udf_utils import generate_call_python_sp_sql from snowflake.snowpark._internal.utils import ( @@ -3044,7 +3045,7 @@ def write_pandas( def create_dataframe( self, data: Union[List, Tuple, "pandas.DataFrame"], - schema: Optional[Union[StructType, Iterable[str]]] = None, + schema: Optional[Union[StructType, Iterable[str], str]] = None, _emit_ast: bool = True, ) -> DataFrame: """Creates a new DataFrame containing the specified values from the local data. @@ -3061,9 +3062,15 @@ def create_dataframe( ``data`` will constitute a row in the DataFrame. schema: A :class:`~snowflake.snowpark.types.StructType` containing names and data types of columns, or a list of column names, or ``None``. 
- When ``schema`` is a list of column names or ``None``, the schema of the - DataFrame will be inferred from the data across all rows. To improve - performance, provide a schema. This avoids the need to infer data types + + - When passing a **string**, it can be either an *explicit* struct + (e.g. ``"struct<a: int, b: string>"``) or an *implicit* struct + (e.g. ``"a: int, b: string"``). Internally, the string is parsed and + converted into a :class:`StructType` using Snowpark's type parsing. + - When ``schema`` is a list of column names or ``None``, the schema of the + DataFrame will be inferred from the data across all rows. + + To improve performance, provide a schema. This avoids the need to infer data types with large data sets. Examples:: @@ -3093,6 +3100,10 @@ def create_dataframe( >>> session.create_dataframe(pd.DataFrame([(1, 2, 3, 4)], columns=["a", "b", "c", "d"])).collect() [Row(a=1, b=2, c=3, d=4)] + >>> # create a dataframe using an implicit struct schema string + >>> session.create_dataframe([[10, 20], [30, 40]], schema="x: int, y: int").collect() + [Row(X=10, Y=20), Row(X=30, Y=40)] + Note: When `data` is a pandas DataFrame, `snowflake.connector.pandas_tools.write_pandas` is called, which requires permission to (1) CREATE STAGE (2) CREATE TABLE and (3) CREATE FILE FORMAT under the current @@ -3173,7 +3184,14 @@ def create_dataframe( # infer the schema based on the data names = None schema_query = None - if isinstance(schema, StructType): + if isinstance(schema, str): + schema = type_string_to_type_object(schema) + if not isinstance(schema, StructType): + raise ValueError( + f"Invalid schema string: {schema}. " + f"You should provide a valid schema string representing a struct type." + ) + if isinstance(schema, (StructType, str)): new_schema = schema # SELECT query has an undefined behavior for nullability, so if the schema requires non-nullable column and # all columns are primitive type columns, we use a temp table to lock in the nullabilities. 
diff --git a/tests/integ/test_dataframe.py b/tests/integ/test_dataframe.py index 7c91222181b..22fd4ba5eb6 100644 --- a/tests/integ/test_dataframe.py +++ b/tests/integ/test_dataframe.py @@ -4447,3 +4447,155 @@ def test_map_negative(session): output_types=[IntegerType(), StringType()], output_column_names=["a", "b", "c"], ) + + +def test_create_dataframe_with_implicit_struct_simple(session): + """ + Test an implicit struct string with two integer columns. + """ + data = [ + [1, 2], + [3, 4], + ] + # The new feature: implicit struct string "col1: int, col2: int" + schema_str = "col1: int, col2: int" + + # Create the dataframe + df = session.create_dataframe(data, schema=schema_str) + # Check schema + # We expect the schema to be a StructType with 2 fields + assert isinstance(df.schema, StructType) + assert len(df.schema.fields) == 2 + expected_fields = [ + StructField("COL1", LongType(), nullable=True), + StructField("COL2", LongType(), nullable=True), + ] + assert df.schema.fields == expected_fields + + # Collect rows + result = df.collect() + expected_rows = [ + Row(COL1=1, COL2=2), + Row(COL1=3, COL2=4), + ] + assert result == expected_rows + + +def test_create_dataframe_with_implicit_struct_nested(session): + """ + Test an implicit struct string with nested array and decimal columns. 
+ """ + data = [ + [["1", "2"], Decimal("3.14")], + [["5", "6"], Decimal("2.72")], + ] + # Nested schema: first column is array, second is decimal(10,2) + schema_str = "arr: array, val: decimal(10,2)" + + df = session.create_dataframe(data, schema=schema_str) + # Verify schema + assert len(df.schema.fields) == 2 + expected_fields = [ + StructField("ARR", ArrayType(StringType()), nullable=True), + StructField("VAL", DecimalType(10, 2), nullable=True), + ] + assert df.schema.fields == expected_fields + + # Verify rows + result = df.collect() + expected_rows = [ + Row(ARR='[\n "1",\n "2"\n]', VAL=Decimal("3.14")), + Row(ARR='[\n "5",\n "6"\n]', VAL=Decimal("2.72")), + ] + assert result == expected_rows + + +def test_create_dataframe_with_explicit_struct_string(session): + """ + Test an explicit struct string "struct" + to confirm it also works (even though it's not strictly 'implicit'). + """ + data = [ + ["hello", 3.14], + ["world", 2.72], + ] + schema_str = "struct" + + df = session.create_dataframe(data, schema=schema_str) + # Verify schema + assert len(df.schema.fields) == 2 + expected_fields = [ + StructField("COLA", StringType(), nullable=True), + StructField("COLB", DoubleType(), nullable=True), + ] + assert df.schema.fields == expected_fields + + # Verify rows + result = df.collect() + expected_rows = [ + Row(COLA="hello", COLB=3.14), + Row(COLA="world", COLB=2.72), + ] + assert result == expected_rows + + +def test_create_dataframe_with_implicit_struct_malformed(session): + """ + Test malformed implicit struct string, which should raise an error. 
+ """ + data = [[1, 2]] + # Missing type for second column + schema_str = "col1: int, col2" + + with pytest.raises(ValueError) as ex_info: + session.create_dataframe(data, schema=schema_str) + # Check that the error message mentions the problem + assert ( + "col2" in str(ex_info.value).lower() + ), f"Unexpected error message: {ex_info.value}" + + +def test_create_dataframe_with_implicit_struct_datetime(session): + """ + Another example mixing basic data with boolean and dates, ensuring + the implicit struct string handles them properly. + """ + data = [ + [True, datetime.date(2020, 1, 1)], + [False, datetime.date(2021, 12, 31)], + ] + schema_str = "flag: boolean, d: date" + + df = session.create_dataframe(data, schema=schema_str) + # Check schema + assert len(df.schema.fields) == 2 + expected_fields = [ + StructField("FLAG", BooleanType(), nullable=True), + StructField("D", DateType(), nullable=True), + ] + assert df.schema.fields == expected_fields + + # Check rows + result = df.collect() + expected_rows = [ + Row(FLAG=True, D=datetime.date(2020, 1, 1)), + Row(FLAG=False, D=datetime.date(2021, 12, 31)), + ] + assert result == expected_rows + + +def test_create_dataframe_invalid_schema_string_not_struct(session): + """ + Verifies that a non-struct schema string (e.g. "int") raises ValueError + because the resulting type is not an instance of StructType. + """ + data = [1, 2, 3] + # "int" does not represent a struct, so we expect a ValueError + with pytest.raises(ValueError) as ex_info: + session.create_dataframe(data, schema="int") + + # Check that the error message mentions "Invalid schema string" or "struct type" + err_msg = str(ex_info.value).lower() + assert ( + "invalid schema string" in err_msg and "struct type" in err_msg + ), f"Expected error message about invalid schema string or struct type. 
Got: {ex_info.value}" diff --git a/tests/unit/test_types.py b/tests/unit/test_types.py index db5355d1cec..ae0998e4b64 100644 --- a/tests/unit/test_types.py +++ b/tests/unit/test_types.py @@ -43,6 +43,11 @@ retrieve_func_defaults_from_source, retrieve_func_type_hints_from_source, snow_type_to_dtype_str, + type_string_to_type_object, + is_likely_struct, + parse_struct_field_list, + split_top_level_comma_fields, + extract_bracket_content, ) from snowflake.snowpark.types import ( ArrayType, @@ -1440,3 +1445,390 @@ def test_maptype_alias(): assert tpe.valueType == tpe.value_type assert tpe.keyType == tpe.key_type + + +def test_type_string_to_type_object_basic_int(): + dt = type_string_to_type_object("int") + assert isinstance(dt, IntegerType), f"Expected IntegerType, got {dt}" + + +def test_type_string_to_type_object_smallint(): + dt = type_string_to_type_object("smallint") + assert isinstance(dt, ShortType), f"Expected ShortType, got {dt}" + + +def test_type_string_to_type_object_byteint(): + dt = type_string_to_type_object("byteint") + assert isinstance(dt, ByteType), f"Expected ByteType, got {dt}" + + +def test_type_string_to_type_object_bigint(): + dt = type_string_to_type_object("bigint") + assert isinstance(dt, LongType), f"Expected LongType, got {dt}" + + +def test_type_string_to_type_object_number_decimal(): + # For number(precision, scale) => DecimalType + dt = type_string_to_type_object("number(10,2)") + assert isinstance(dt, DecimalType), f"Expected DecimalType, got {dt}" + assert dt.precision == 10, f"Expected precision=10, got {dt.precision}" + assert dt.scale == 2, f"Expected scale=2, got {dt.scale}" + + +def test_type_string_to_type_object_numeric_decimal(): + dt = type_string_to_type_object("numeric(20, 5)") + assert isinstance(dt, DecimalType), f"Expected DecimalType, got {dt}" + assert dt.precision == 20, f"Expected precision=20, got {dt.precision}" + assert dt.scale == 5, f"Expected scale=5, got {dt.scale}" + + +def 
test_type_string_to_type_object_decimal_spaces(): + # Check spaces inside parentheses + dt = type_string_to_type_object(" decimal ( 2 , 1 ) ") + assert isinstance(dt, DecimalType), f"Expected DecimalType, got {dt}" + assert dt.precision == 2, f"Expected precision=2, got {dt.precision}" + assert dt.scale == 1, f"Expected scale=1, got {dt.scale}" + + +def test_type_string_to_type_object_string_with_length(): + dt = type_string_to_type_object("string(50)") + assert isinstance(dt, StringType), f"Expected StringType, got {dt}" + # Snowpark's StringType typically doesn't store length internally, + # but here, you're returning StringType(50) in your code, so let's check + if hasattr(dt, "length"): + assert dt.length == 50, f"Expected length=50, got {dt.length}" + + +def test_type_string_to_type_object_text_with_length(): + dt = type_string_to_type_object("text(100)") + assert isinstance(dt, StringType), f"Expected StringType, got {dt}" + if hasattr(dt, "length"): + assert dt.length == 100, f"Expected length=100, got {dt.length}" + + +def test_type_string_to_type_object_array_of_int(): + dt = type_string_to_type_object("array") + assert isinstance(dt, ArrayType), f"Expected ArrayType, got {dt}" + assert isinstance( + dt.element_type, IntegerType + ), f"Expected element_type=IntegerType, got {dt.element_type}" + + +def test_type_string_to_type_object_array_of_decimal(): + dt = type_string_to_type_object("array") + assert isinstance(dt, ArrayType), f"Expected ArrayType, got {dt}" + assert isinstance( + dt.element_type, DecimalType + ), f"Expected element_type=DecimalType, got {dt.element_type}" + assert dt.element_type.precision == 10 + assert dt.element_type.scale == 2 + + +def test_type_string_to_type_object_map_of_int_string(): + dt = type_string_to_type_object("map") + assert isinstance(dt, MapType), f"Expected MapType, got {dt}" + assert isinstance( + dt.key_type, IntegerType + ), f"Expected key_type=IntegerType, got {dt.key_type}" + assert isinstance( + dt.value_type, 
StringType + ), f"Expected value_type=StringType, got {dt.value_type}" + + +def test_type_string_to_type_object_map_of_array_decimal(): + dt = type_string_to_type_object("map< array, decimal(12,5)>") + assert isinstance(dt, MapType), f"Expected MapType, got {dt}" + assert isinstance( + dt.key_type, ArrayType + ), f"Expected key_type=ArrayType, got {dt.key_type}" + assert isinstance( + dt.key_type.element_type, IntegerType + ), f"Expected key_type.element_type=IntegerType, got {dt.key_type.element_type}" + assert isinstance( + dt.value_type, DecimalType + ), f"Expected value_type=DecimalType, got {dt.value_type}" + assert dt.value_type.precision == 12 + assert dt.value_type.scale == 5 + + +def test_type_string_to_type_object_explicit_struct_simple(): + dt = type_string_to_type_object("struct") + assert isinstance(dt, StructType), f"Expected StructType, got {dt}" + assert len(dt.fields) == 2, f"Expected 2 fields, got {len(dt.fields)}" + + # Now assert exact StructField matches + expected_field_a = StructField("a", IntegerType(), nullable=True) + expected_field_b = StructField("b", StringType(), nullable=True) + assert ( + dt.fields[0] == expected_field_a + ), f"Expected {expected_field_a}, got {dt.fields[0]}" + assert ( + dt.fields[1] == expected_field_b + ), f"Expected {expected_field_b}, got {dt.fields[1]}" + + +def test_type_string_to_type_object_explicit_struct_nested(): + dt = type_string_to_type_object( + "struct, y: map>" + ) + assert isinstance(dt, StructType), f"Expected StructType, got {dt}" + assert len(dt.fields) == 2, f"Expected 2 fields, got {len(dt.fields)}" + + # Check each field directly against StructField(...) 
+ expected_field_x = StructField("x", ArrayType(IntegerType()), nullable=True) + expected_field_y = StructField( + "y", MapType(StringType(), DecimalType(5, 2)), nullable=True + ) + + assert ( + dt.fields[0] == expected_field_x + ), f"Expected {expected_field_x}, got {dt.fields[0]}" + assert ( + dt.fields[1] == expected_field_y + ), f"Expected {expected_field_y}, got {dt.fields[1]}" + + +def test_type_string_to_type_object_unknown_type(): + try: + type_string_to_type_object("unknown_type") + raise AssertionError("Expected ValueError for unknown type") + except ValueError as ex: + assert "unknown_type" in str( + ex + ), f"Error message doesn't mention 'unknown_type': {ex}" + + +def test_type_string_to_type_object_mismatched_bracket_array(): + try: + type_string_to_type_object("array>")) + raise AssertionError("Expected ValueError for mismatched bracket") + except ValueError as ex: + assert "Unexpected characters after closing '>' in" in str( + ex + ), f"Expected Unexpected characters after closing '>' error, got: {ex}" + + +def test_type_string_to_type_object_bad_decimal(): + try: + type_string_to_type_object("decimal(10,2,5)") + raise AssertionError("Expected ValueError for a malformed decimal argument") + except ValueError: + # "decimal(10,2,5)" doesn't match the DECIMAL_RE regex => unknown type => ValueError + pass + + +def test_type_string_to_type_object_bad_struct_syntax(): + try: + type_string_to_type_object("struct'. 
+ """ + dt = type_string_to_type_object("a: int, b: string") + assert isinstance(dt, StructType), f"Expected StructType, got {dt}" + assert len(dt.fields) == 2, f"Expected 2 fields, got {len(dt.fields)}" + + expected_field_a = StructField("a", IntegerType(), nullable=True) + expected_field_b = StructField("b", StringType(), nullable=True) + + assert ( + dt.fields[0] == expected_field_a + ), f"Expected {expected_field_a}, got {dt.fields[0]}" + assert ( + dt.fields[1] == expected_field_b + ), f"Expected {expected_field_b}, got {dt.fields[1]}" + + +def test_type_string_to_type_object_implicit_struct_single_field(): + """ + Even a single 'name: type' with no commas should parse to StructType + if your parser logic treats it as an implicit struct. + """ + dt = type_string_to_type_object("c: decimal(10,2)") + assert isinstance(dt, StructType), f"Expected StructType, got {dt}" + assert len(dt.fields) == 1, f"Expected 1 field, got {len(dt.fields)}" + + expected_field_c = StructField("c", DecimalType(10, 2), nullable=True) + assert ( + dt.fields[0] == expected_field_c + ), f"Expected {expected_field_c}, got {dt.fields[0]}" + + +def test_type_string_to_type_object_implicit_struct_nested(): + """ + Test an implicit struct with multiple fields, + including nested array/map types. + """ + dt = type_string_to_type_object("arr: array, kv: map") + assert isinstance(dt, StructType), f"Expected StructType, got {dt}" + assert len(dt.fields) == 2, f"Expected 2 fields, got {len(dt.fields)}" + + expected_field_arr = StructField("arr", ArrayType(IntegerType()), nullable=True) + expected_field_kv = StructField( + "kv", MapType(StringType(), DecimalType(5, 2)), nullable=True + ) + + assert ( + dt.fields[0] == expected_field_arr + ), f"Expected {expected_field_arr}, got {dt.fields[0]}" + assert ( + dt.fields[1] == expected_field_kv + ), f"Expected {expected_field_kv}, got {dt.fields[1]}" + + +def test_type_string_to_type_object_implicit_struct_with_spaces(): + """ + Test spacing variations. 
E.g. " col1 : int , col2 : map< string , decimal(5,2) > ". + """ + dt = type_string_to_type_object( + " col1 : int , col2 : map< string , decimal(5,2) > " + ) + assert isinstance(dt, StructType), f"Expected StructType, got {dt}" + assert len(dt.fields) == 2, f"Expected 2 fields, got {len(dt.fields)}" + + expected_field_col1 = StructField("col1", IntegerType(), nullable=True) + expected_field_col2 = StructField( + "col2", MapType(StringType(), DecimalType(5, 2)), nullable=True + ) + + assert ( + dt.fields[0] == expected_field_col1 + ), f"Expected {expected_field_col1}, got {dt.fields[0]}" + assert ( + dt.fields[1] == expected_field_col2 + ), f"Expected {expected_field_col2}, got {dt.fields[1]}" + + +def test_type_string_to_type_object_implicit_struct_error(): + """ + Check a malformed implicit struct that should raise ValueError + (e.g. trailing comma or missing bracket for nested). + """ + try: + type_string_to_type_object("a: int, b:") + raise AssertionError("Expected ValueError for malformed struct (b: )") + except ValueError as ex: + # We expect an error message about Empty type string + assert "Empty type string" in str( + ex + ), f"Expected error 'Empty type string', got: {ex}" + + try: + type_string_to_type_object("arr: array' + assert "Missing closing" in str( + ex + ), f"Expected Missing closing error, got: {ex}" + + +def test_extract_bracket_content_array_ok(): + s = "array" + # We expect to extract "int" from inside <...> + content = extract_bracket_content(s, keyword="array") + assert content == "int", f"Expected 'int', got {content}" + + +def test_extract_bracket_content_map_spaces(): + s = " map< int , string >" + content = extract_bracket_content(s, keyword="map") + assert content == "int , string", f"Expected 'int , string', got {content}" + + +def test_extract_bracket_content_missing_closing(): + s = "array'") + except ValueError as ex: + assert ( + "Missing closing" in str(ex) or "mismatched" in str(ex).lower() + ), f"Error does not mention missing 
bracket: {ex}" + + +def test_extract_bracket_content_mismatched_extra_close(): + s = "struct>" + try: + extract_bracket_content(s, keyword="struct") + raise AssertionError("Expected ValueError for extra '>'") + except ValueError as ex: + assert "Unexpected characters after closing '>' in" in str( + ex + ), f"Error does not mention Unexpected characters after closing '>' in: {ex}" + + +def test_split_top_level_comma_fields_no_brackets(): + s = "int, string, decimal(10,2)" + parts = split_top_level_comma_fields(s) + assert parts == ["int", "string", "decimal(10,2)"], f"Got unexpected parts: {parts}" + + +def test_split_top_level_comma_fields_nested_brackets(): + s = "int, array, decimal(10,2), map>" + parts = split_top_level_comma_fields(s) + assert parts == [ + "int", + "array", + "decimal(10,2)", + "map>", + ], f"Got unexpected parts: {parts}" + + +def test_parse_struct_field_list_simple(): + s = "a: int, b: string" + struct_type = parse_struct_field_list(s) + assert ( + len(struct_type.fields) == 2 + ), f"Expected 2 fields, got {len(struct_type.fields)}" + # Direct equality checks on each StructField + from snowflake.snowpark.types import StructField, IntegerType, StringType + + assert struct_type.fields[0] == StructField("a", IntegerType(), nullable=True) + assert struct_type.fields[1] == StructField("b", StringType(), nullable=True) + + +def test_parse_struct_field_list_malformed(): + s = "col1: int, col2" + try: + parse_struct_field_list(s) + raise AssertionError("Expected ValueError for missing type in 'col2'") + except ValueError as ex: + assert ( + "Cannot parse struct field definition" in str(ex) + or "missing" in str(ex).lower() + ), f"Unexpected error message: {ex}" + + +def test_is_likely_struct_true(): + # top-level colon => likely struct + s = "a: int, b: string" + assert is_likely_struct(s) is True, "Expected True for struct-like string" + + +def test_is_likely_struct_false(): + # No top-level colon or comma => not a struct + s = "array" + assert 
is_likely_struct(s) is False, "Expected False for non-struct string" From 0f073dcfc87239f365c7bf905c73652b62b630dd Mon Sep 17 00:00:00 2001 From: Jianzhun Du Date: Tue, 7 Jan 2025 10:14:26 -0800 Subject: [PATCH 2/3] fix --- src/snowflake/snowpark/_internal/type_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/snowflake/snowpark/_internal/type_utils.py b/src/snowflake/snowpark/_internal/type_utils.py index fcab1b1d6f7..63f9b167950 100644 --- a/src/snowflake/snowpark/_internal/type_utils.py +++ b/src/snowflake/snowpark/_internal/type_utils.py @@ -973,7 +973,7 @@ def get_data_type_string_object_mappings( # support type string format like starting with "map<..." STRUCT_RE = re.compile(r"(?i)^\s*struct\s*<") -# support type string format like starting with "array<..." starting with "struct<..." +# support type string format like starting with "struct<..." def get_number_precision_scale(type_str: str) -> Optional[Tuple[int, int]]: From 64fe8b428ba513e90e9099815fef7ef4b68df669 Mon Sep 17 00:00:00 2001 From: Jianzhun Du Date: Thu, 9 Jan 2025 15:21:22 -0800 Subject: [PATCH 3/3] address comment --- src/snowflake/snowpark/_internal/type_utils.py | 4 ++-- src/snowflake/snowpark/session.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/snowflake/snowpark/_internal/type_utils.py b/src/snowflake/snowpark/_internal/type_utils.py index 63f9b167950..aa68c1e3044 100644 --- a/src/snowflake/snowpark/_internal/type_utils.py +++ b/src/snowflake/snowpark/_internal/type_utils.py @@ -990,8 +990,8 @@ def get_string_length(type_str: str) -> Optional[int]: def extract_bracket_content(type_str: str, keyword: str) -> str: """ - Given a string that starts with e.g. 'array<', returns the content inside the top-level <...>. - e.g., "array" => "int" + Given a string that starts with e.g. "array<", returns the content inside the top-level <...>. + e.g., "array" => "int". It also parses the nested array like "array>". 
Raises ValueError on mismatched or missing bracket. """ prefix_pattern = rf"(?i)^\s*{keyword}\s*<" diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 90496aab269..7c252e0c119 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -3191,7 +3191,7 @@ def create_dataframe( f"Invalid schema string: {schema}. " f"You should provide a valid schema string representing a struct type." ) - if isinstance(schema, (StructType, str)): + if isinstance(schema, StructType): new_schema = schema # SELECT query has an undefined behavior for nullability, so if the schema requires non-nullable column and # all columns are primitive type columns, we use a temp table to lock in the nullabilities.