From 39a4f3e942e1337436bb90d78d8cf53c490b5d6b Mon Sep 17 00:00:00 2001
From: Future-Outlier
Date: Wed, 15 Jan 2025 12:17:33 +0800
Subject: [PATCH] Pydantic Transformer guess python type

Signed-off-by: Future-Outlier
---
 .../pydantic_transformer/transformer.py       | 91 ++++++++++++++++++-
 1 file changed, 89 insertions(+), 2 deletions(-)

diff --git a/flytekit/extras/pydantic_transformer/transformer.py b/flytekit/extras/pydantic_transformer/transformer.py
index e9048d8880..047bac1421 100644
--- a/flytekit/extras/pydantic_transformer/transformer.py
+++ b/flytekit/extras/pydantic_transformer/transformer.py
@@ -1,11 +1,11 @@
 import json
 import os
-from typing import Type
+from typing import Any, List, Optional, Type
 
 import msgpack
 from google.protobuf import json_format as _json_format
 from google.protobuf import struct_pb2 as _struct
-from pydantic import BaseModel
+from pydantic import BaseModel, create_model
 
 from flytekit import FlyteContext
 from flytekit.core.constants import CACHE_KEY_METADATA, FLYTE_USE_OLD_DC_FORMAT, MESSAGEPACK, SERIALIZATION_FORMAT
@@ -103,5 +103,92 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
         python_val = expected_python_type.model_validate_json(json_str, strict=False, context={"deserialize": True})
         return python_val
 
+    def guess_python_type(self, literal_type: LiteralType) -> Type[BaseModel]:
+        """
+        Reconstruct a Pydantic BaseModel subclass from the JSON schema stored in LiteralType metadata.
+
+        :param literal_type: Literal type whose ``metadata`` is expected to hold the model's JSON schema.
+        :return: A dynamically created model named after the schema ``title`` ("DynamicModel" if absent).
+        :raises ValueError: If the metadata is not a dict, i.e. there is no schema to rebuild a model from.
+        :raises TypeTransformerFailedError: If Pydantic fails to create the model.
+        """
+        schema = literal_type.metadata
+        if not isinstance(schema, dict):
+            # Without this guard a missing schema surfaced as an AttributeError on ``None.get``.
+            raise ValueError("Cannot guess a Pydantic model: the literal type metadata holds no JSON schema")
+        return self._build_model(schema, "DynamicModel", "Failed to create Pydantic model from schema")
+
+    def _map_json_type_to_python(self, field_info: dict) -> Any:
+        """
+        Map a JSON schema fragment to the Python type used for a Pydantic model field.
+
+        Objects with ``properties`` become nested models, arrays with ``items`` become ``List[item]``, and a
+        type list containing ``"null"`` becomes ``Optional[...]``. Anything unrecognised maps to ``Any``.
+
+        :param field_info: JSON schema fragment describing a single field.
+        :return: The Python type (or typing construct) for the field.
+        """
+        json_type = field_info.get("type")
+        nullable = False
+        if isinstance(json_type, list):
+            # A type list such as ["string", "null"] is a union; "null" marks the field as nullable.
+            nullable = "null" in json_type
+            non_null_types = [t for t in json_type if t != "null"]
+            # Only a single remaining type can be mapped precisely; other unions fall back to Any.
+            json_type = non_null_types[0] if len(non_null_types) == 1 else None
+
+        type_mapping = {
+            "string": str,
+            "integer": int,
+            "number": float,
+            "boolean": bool,
+            "object": dict,
+            "array": list,
+        }
+        python_type = type_mapping.get(json_type, Any)
+
+        if python_type is dict and "properties" in field_info:
+            # Objects that declare properties are rebuilt as nested Pydantic models.
+            python_type = self._create_nested_model(field_info)
+        elif python_type is list and "items" in field_info:
+            python_type = List[self._map_json_type_to_python(field_info["items"])]
+
+        return Optional[python_type] if nullable else python_type
+
+    def _create_nested_model(self, field_info: dict) -> Type[BaseModel]:
+        """
+        Create a nested Pydantic model for an object schema found inside another schema.
+
+        :param field_info: JSON schema fragment of the nested object.
+        :return: A dynamically created model named after the fragment ``title`` ("NestedModel" if absent).
+        :raises TypeTransformerFailedError: If Pydantic fails to create the model.
+        """
+        return self._build_model(field_info, "NestedModel", "Failed to create nested Pydantic model")
+
+    def _build_model(self, schema: dict, default_name: str, error_prefix: str) -> Type[BaseModel]:
+        """
+        Build a Pydantic model from the ``title``, ``properties`` and ``required`` keys of an object schema.
+
+        :param schema: JSON schema of the object.
+        :param default_name: Model name to use when the schema has no ``title``.
+        :param error_prefix: Text placed before the cause in the error raised when creation fails.
+        :return: The dynamically created model.
+        :raises TypeTransformerFailedError: If Pydantic fails to create the model.
+        """
+        required_fields = set(schema.get("required", []))
+        field_definitions = {}
+        for field_name, field_info in schema.get("properties", {}).items():
+            field_type = self._map_json_type_to_python(field_info)
+            if field_name in required_fields:
+                field_definitions[field_name] = (field_type, ...)
+            else:
+                # A non-required field defaults to None, so its annotation has to accept None as well.
+                field_definitions[field_name] = (Optional[field_type], None)
+
+        try:
+            return create_model(schema.get("title", default_name), **field_definitions)
+        except Exception as e:
+            raise TypeTransformerFailedError(f"{error_prefix}: {e}") from e
+
 
 TypeEngine.register(PydanticTransformer())