-
Notifications
You must be signed in to change notification settings - Fork 119
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
SNOW-1871175: Add support for specifying a schema string for DataFrame.create_dataframe
#2828
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -966,6 +966,15 @@ def get_data_type_string_object_mappings( | |
STRING_RE = re.compile(r"^\s*(varchar|string|text)\s*\(\s*(\d*)\s*\)\s*$") | ||
# support type string format like " string ( 23 ) " | ||
|
||
ARRAY_RE = re.compile(r"(?i)^\s*array\s*<") | ||
# matches type strings that start with "array<" (case-insensitive, leading whitespace allowed) | ||
|
||
MAP_RE = re.compile(r"(?i)^\s*map\s*<") | ||
# matches type strings that start with "map<" (case-insensitive, leading whitespace allowed) | ||
|
||
STRUCT_RE = re.compile(r"(?i)^\s*struct\s*<") | ||
# matches type strings that start with "struct<" (case-insensitive, leading whitespace allowed) | ||
|
||
|
||
def get_number_precision_scale(type_str: str) -> Optional[Tuple[int, int]]: | ||
decimal_matches = DECIMAL_RE.match(type_str) | ||
|
@@ -979,7 +988,147 @@ def get_string_length(type_str: str) -> Optional[int]: | |
return int(string_matches.group(2)) | ||
|
||
|
||
def extract_bracket_content(type_str: str, keyword: str) -> str:
    """
    Return the text inside the outermost <...> of a "keyword<...>" type string.

    e.g., "array<int>" => "int" and "array<array<int>>" => "array<int>":
    nested brackets are kept verbatim in the result. The keyword is matched
    case-insensitively, and whitespace around the whole string, around the
    keyword, and around the returned content is ignored.

    Args:
        type_str: The type string to parse, e.g. "map<string, int>".
        keyword: The literal type keyword expected before "<", e.g. "array".

    Returns:
        The stripped content between the outermost "<" and its matching ">".

    Raises:
        ValueError: If type_str does not start with "keyword<", if the
            outermost bracket is never closed, or if non-whitespace
            characters follow the closing ">".
    """
    # Escape the keyword so it is always matched literally.
    prefix_pattern = rf"(?i)^\s*{re.escape(keyword)}\s*<"
    match = re.match(prefix_pattern, type_str)
    if not match:
        raise ValueError(
            f"'{type_str}' does not match expected '{keyword}<...>' syntax."
        )

    content_start = match.end()  # first character after the outer '<'
    bracket_depth = 1  # the outer '<' was already consumed by the prefix match
    for i in range(content_start, len(type_str)):
        c = type_str[i]
        if c == "<":
            bracket_depth += 1
        elif c == ">":
            bracket_depth -= 1
            if bracket_depth == 0:
                # Leading whitespace is accepted by the prefix pattern, so
                # trailing whitespace is tolerated too; anything else is an error.
                if type_str[i + 1 :].strip():
                    raise ValueError(
                        f"Unexpected characters after closing '>' in '{type_str}'"
                    )
                return type_str[content_start:i].strip()

    raise ValueError(f"Missing closing '>' in '{type_str}'.")
|
||
|
||
def _find_top_level_colon(field_def: str) -> int:
    """Return the index of the first ':' outside any <...> or (...), or -1 if none."""
    depth = 0
    for idx, ch in enumerate(field_def):
        if ch in "<(":
            depth += 1
        elif ch in ">)":
            depth -= 1
        elif ch == ":" and depth == 0:
            return idx
    return -1


def parse_struct_field_list(fields_str: str) -> StructType:
    """
    Parse something like "a: int, b: string, c: array<int>"
    into StructType([StructField('a', IntegerType()), ...]).

    Each comma-separated field is "name: type" or "name type". Every field is
    created with nullable=True.

    Raises:
        ValueError: If a field definition cannot be split into a name and a
            type, or if the name is empty. Errors raised while parsing the
            field's type string propagate unchanged.
    """
    fields = []
    for field_def in split_top_level_comma_fields(fields_str):
        # Only the first colon separates name from type (this follows PySpark's
        # simpleString format), so "a:b:c" is name "a" with type string "b:c".
        # The colon must be at the top level: in "a struct<b: int>" the colon
        # belongs to the nested struct, and the field is whitespace-separated.
        colon_idx = _find_top_level_colon(field_def)
        if colon_idx >= 0:
            left, right = field_def[:colon_idx], field_def[colon_idx + 1 :]
        else:
            parts = field_def.split(None, 1)
            if len(parts) != 2:
                raise ValueError(f"Cannot parse struct field definition: '{field_def}'")
            left, right = parts

        field_name = left.strip()
        type_part = right.strip()
        if not field_name:
            raise ValueError(f"Struct field missing name in '{field_def}'")

        field_type = type_string_to_type_object(type_part)
        fields.append(StructField(field_name, field_type, nullable=True))

    return StructType(fields)
|
||
|
||
def split_top_level_comma_fields(s: str) -> List[str]:
    """
    Splits 's' by commas not enclosed in matching brackets.
    Example: "int, array<long>, decimal(10,2)" => ["int", "array<long>", "decimal(10,2)"].

    Both "<...>" and "(...)" count as brackets. Each returned part is stripped
    of surrounding whitespace; a string without a top-level comma yields a
    single-element list.

    Raises:
        ValueError: If a closing bracket has no matching opening bracket, or
            if an opening bracket is never closed.
    """
    parts = []
    bracket_depth = 0
    start_idx = 0
    # Bracket validation is deliberately done here even though other helpers
    # also check brackets, so this function stays self-contained.
    for i, c in enumerate(s):
        if c in "<(":
            bracket_depth += 1
        elif c in ">)":
            bracket_depth -= 1
            if bracket_depth < 0:
                raise ValueError(f"Mismatched bracket in '{s}'.")
        elif c == "," and bracket_depth == 0:
            parts.append(s[start_idx:i].strip())
            start_idx = i + 1
    if bracket_depth != 0:
        # An opening bracket was never closed; without this check the
        # malformed tail would be silently returned as a single field.
        raise ValueError(f"Mismatched bracket in '{s}'.")
    parts.append(s[start_idx:].strip())
    return parts
|
||
|
||
def is_likely_struct(s: str) -> bool:
    """
    Heuristic check for a bare struct field list such as "a: int, b: string".

    Returns True as soon as a ':' or ',' is seen while the running count of
    opened-minus-closed '<'/'(' brackets is zero; returns False if the scan
    reaches the end of the string without finding one.
    """
    depth = 0
    for ch in s:
        if ch == "<" or ch == "(":
            depth += 1
            continue
        if ch == ">" or ch == ")":
            depth -= 1
            continue
        # Separators nested inside brackets belong to an inner type.
        if depth == 0 and (ch == ":" or ch == ","):
            return True
    return False
|
||
|
||
def type_string_to_type_object(type_str: str) -> DataType: | ||
type_str = type_str.strip() | ||
if not type_str: | ||
raise ValueError("Empty type string") | ||
|
||
# First check if this might be a top-level multi-field struct | ||
# (e.g. "a: int, b: string") even if not written as "struct<...>" | ||
if is_likely_struct(type_str): | ||
return parse_struct_field_list(type_str) | ||
|
||
# Check for array<...> | ||
if ARRAY_RE.match(type_str): | ||
inner = extract_bracket_content(type_str, "array") | ||
element_type = type_string_to_type_object(inner) | ||
return ArrayType(element_type) | ||
|
||
# Check for map<key, value> | ||
if MAP_RE.match(type_str): | ||
inner = extract_bracket_content(type_str, "map") | ||
parts = split_top_level_comma_fields(inner) | ||
if len(parts) != 2: | ||
raise ValueError(f"Invalid map type definition: '{type_str}'") | ||
key_type = type_string_to_type_object(parts[0]) | ||
val_type = type_string_to_type_object(parts[1]) | ||
return MapType(key_type, val_type) | ||
|
||
# Check for explicit struct<...> | ||
if STRUCT_RE.match(type_str): | ||
inner = extract_bracket_content(type_str, "struct") | ||
return parse_struct_field_list(inner) | ||
|
||
precision_scale = get_number_precision_scale(type_str) | ||
if precision_scale: | ||
return DecimalType(*precision_scale) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is a form like
array<<<...>>>
allowed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, for nested types such as
array<array<...>>
; I added a comment here.