Skip to content

Commit

Permalink
SNOW-1794355: [API Coverage] StructType (#2623)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-yuwang authored Nov 18, 2024
1 parent d4b03af commit 53291b4
Show file tree
Hide file tree
Showing 3 changed files with 184 additions and 6 deletions.
17 changes: 11 additions & 6 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,17 @@

#### New Features

- Added support for following methods in class `DataType`, derived class of `DataType` and `StructField`:
- `type_name`
- `simple_string`
- `json_value`
- `json`
- Added support for variables `keyType` and `valueType` in class `MapType`
- Added new methods and variables to enhance data type handling and JSON serialization/deserialization:
- To `DataType`, its derived classes, and `StructField`:
- `type_name`: Returns the type name of the data.
- `simple_string`: Provides a simple string representation of the data.
- `json_value`: Returns the data as a JSON-compatible value.
- `json`: Converts the data to a JSON string.
- To `ArrayType`, `MapType`, `StructField`, and `StructType`:
- `from_json`: Enables these types to be created from JSON data.
- To `MapType`:
- `keyType`: keys of the map
- `valueType`: values of the map

#### Improvements

Expand Down
90 changes: 90 additions & 0 deletions src/snowflake/snowpark/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,16 @@ def __repr__(self) -> str:
def is_primitive(self):
return False

@classmethod
def from_json(cls, json_dict: Dict[str, Any]) -> "ArrayType":
return ArrayType(
_parse_datatype_json_value(
json_dict["elementType"]
if "elementType" in json_dict
else json_dict["element_type"]
)
)

def simple_string(self) -> str:
return f"array<{self.element_type.simple_string()}>"

Expand All @@ -297,6 +307,7 @@ def json_value(self) -> Dict[str, Any]:

simpleString = simple_string
jsonValue = json_value
fromJson = from_json


class MapType(DataType):
Expand All @@ -318,6 +329,21 @@ def __repr__(self) -> str:
def is_primitive(self):
return False

@classmethod
def from_json(cls, json_dict: Dict[str, Any]) -> "MapType":
return MapType(
_parse_datatype_json_value(
json_dict["keyType"]
if "keyType" in json_dict
else json_dict["key_type"]
),
_parse_datatype_json_value(
json_dict["valueType"]
if "valueType" in json_dict
else json_dict["value_type"]
),
)

def simple_string(self) -> str:
return f"map<{self.key_type.simple_string()},{self.value_type.simple_string()}>"

Expand All @@ -338,6 +364,7 @@ def valueType(self):

simpleString = simple_string
jsonValue = json_value
fromJson = from_json


class VectorType(DataType):
Expand Down Expand Up @@ -461,6 +488,14 @@ def __repr__(self) -> str:
def __eq__(self, other):
return isinstance(other, self.__class__) and self.__dict__ == other.__dict__

@classmethod
def from_json(cls, json_dict: Dict[str, Any]) -> "StructField":
return StructField(
json_dict["name"],
_parse_datatype_json_value(json_dict["type"]),
json_dict["nullable"],
)

def simple_string(self) -> str:
return f"{self.name}:{self.datatype.simple_string()}"

Expand All @@ -482,6 +517,7 @@ def type_name(self) -> str:
typeName = type_name
simpleString = simple_string
jsonValue = json_value
fromJson = from_json


class StructType(DataType):
Expand Down Expand Up @@ -554,6 +590,10 @@ def names(self) -> List[str]:
"""Returns the list of names of the :class:`StructField`"""
return [f.name for f in self.fields]

@classmethod
def from_json(cls, json_dict: Dict[str, Any]) -> "StructType":
return StructType([StructField.fromJson(f) for f in json_dict["fields"]])

def simple_string(self) -> str:
return f"struct<{','.join(f.simple_string() for f in self)}>"

Expand All @@ -562,6 +602,8 @@ def json_value(self) -> Dict[str, Any]:

simpleString = simple_string
jsonValue = json_value
fieldNames = names
fromJson = from_json


class VariantType(DataType):
Expand Down Expand Up @@ -614,6 +656,54 @@ def get_snowflake_col_datatypes(self):
]


_atomic_types: List[Type[DataType]] = [
StringType,
BinaryType,
BooleanType,
DecimalType,
FloatType,
DoubleType,
ByteType,
ShortType,
IntegerType,
LongType,
DateType,
TimestampType,
NullType,
]
_all_atomic_types: Dict[str, Type[DataType]] = {t.typeName(): t for t in _atomic_types}

_complex_types: List[Type[Union[ArrayType, MapType, StructType]]] = [
ArrayType,
MapType,
StructType,
]
_all_complex_types: Dict[str, Type[Union[ArrayType, MapType, StructType]]] = {
v.typeName(): v for v in _complex_types
}

_FIXED_DECIMAL_PATTERN = re.compile(r"decimal\(\s*(\d+)\s*,\s*(\d+)\s*\)")


def _parse_datatype_json_value(json_value: Union[dict, str]) -> DataType:
if not isinstance(json_value, dict):
if json_value in _all_atomic_types.keys():
return _all_atomic_types[json_value]()
elif json_value == "decimal":
return DecimalType()
elif _FIXED_DECIMAL_PATTERN.match(json_value):
m = _FIXED_DECIMAL_PATTERN.match(json_value)
return DecimalType(int(m.group(1)), int(m.group(2))) # type: ignore[union-attr]
else:
raise ValueError(f"Cannot parse data type: {str(json_value)}")
else:
tpe = json_value["type"]
if tpe in _all_complex_types:
return _all_complex_types[tpe].fromJson(json_value)
else:
raise ValueError(f"Unsupported data type: {str(tpe)}")


#: The type hint for annotating Variant data when registering UDFs.
Variant = TypeVar("Variant")

Expand Down
83 changes: 83 additions & 0 deletions tests/unit/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1236,6 +1236,89 @@ def test_datatype(tpe, simple_string, json, type_name, json_value):
assert tpe.typeName() == type_name


@pytest.mark.parametrize(
"datatype, tpe",
[
(
MapType,
MapType(IntegerType(), StringType()),
),
(
MapType,
MapType(StringType(), MapType(IntegerType(), StringType())),
),
(
ArrayType,
ArrayType(IntegerType()),
),
(
ArrayType,
ArrayType(ArrayType(IntegerType())),
),
(
StructType,
StructType(
[
StructField(
"nested",
StructType(
[
StructField("A", IntegerType()),
StructField("B", StringType()),
]
),
)
]
),
),
(
StructField,
StructField("AA", StringType()),
),
(
StructType,
StructType(
[StructField("a", StringType()), StructField("b", IntegerType())]
),
),
(
StructField,
StructField("AA", DecimalType()),
),
(
StructField,
StructField("AA", DecimalType(20, 10)),
),
],
)
def test_structtype_from_json(datatype, tpe):
json_dict = tpe.json_value()
new_obj = datatype.from_json(json_dict)
assert new_obj == tpe


def test_from_json_wrong_data_type():
wrong_json = {
"name": "AA",
"type": "wrong_type",
"nullable": True,
}
with pytest.raises(ValueError, match="Cannot parse data type: wrong_type"):
StructField.from_json(wrong_json)

wrong_json = {
"name": "AA",
"type": {
"type": "wrong_type",
"key_type": "integer",
"value_type": "string",
},
"nullable": True,
}
with pytest.raises(ValueError, match="Unsupported data type: wrong_type"):
StructField.from_json(wrong_json)


def test_maptype_alias():
expected_key = StringType()
expected_value = IntegerType()
Expand Down

0 comments on commit 53291b4

Please sign in to comment.