From ac8fab72dcf4c8e148bab2f04d05f248d64129f3 Mon Sep 17 00:00:00 2001 From: Elliot Gunton Date: Mon, 26 Feb 2024 14:44:28 +0000 Subject: [PATCH] Use real type to check whether to load parameter value * Also tidy up types Signed-off-by: Elliot Gunton --- .../_runner/script_annotations_util.py | 31 ++++++++++++------- src/hera/workflows/_runner/util.py | 23 +++++++------- .../test_unit/test_script_annotations_util.py | 14 +++++++++ 3 files changed, 44 insertions(+), 24 deletions(-) diff --git a/src/hera/workflows/_runner/script_annotations_util.py b/src/hera/workflows/_runner/script_annotations_util.py index 924b53de7..9a4838991 100644 --- a/src/hera/workflows/_runner/script_annotations_util.py +++ b/src/hera/workflows/_runner/script_annotations_util.py @@ -5,6 +5,7 @@ from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union, cast +from hera.shared._pydantic import BaseModel, get_fields from hera.shared.serialization import serialize from hera.workflows import Artifact, Parameter from hera.workflows.artifact import ArtifactLoader @@ -118,7 +119,7 @@ def get_annotated_artifact_value(artifact_annotation: Artifact) -> Union[Path, A raise RuntimeError(f"Artifact {artifact_annotation.name} was not given a value") -T = TypeVar("T", bound=Union[RunnerInputV1, RunnerInputV2]) +T = TypeVar("T", bound=Type[BaseModel]) def map_runner_input( @@ -130,9 +131,14 @@ def map_runner_input( If the field is annotated, we look for the kwarg with the name from the annotation (Parameter or Artifact). Otherwise, we look for the kwarg with the name of the field. """ + from hera.workflows._runner.util import _get_type + input_model_obj = {} - def load_parameter_value(value: str) -> Any: + def load_parameter_value(value: str, value_type: type) -> Any: + if issubclass(_get_type(value_type), str): + return value + try: return json.loads(value) except json.JSONDecodeError: @@ -144,18 +150,19 @@ def map_field( ) -> Any: annotation = runner_input_class.__annotations__[field] if get_origin(annotation) is Annotated: - annotation = get_args(annotation)[1] - if isinstance(annotation, Parameter): - assert not annotation.output - return load_parameter_value(_get_annotated_input_param_value(field, annotation, kwargs)) + meta_annotation = get_args(annotation)[1] + if isinstance(meta_annotation, Parameter): + assert not meta_annotation.output + return load_parameter_value( + _get_annotated_input_param_value(field, meta_annotation, kwargs), get_args(annotation)[0] + ) - if isinstance(annotation, Artifact): - return get_annotated_artifact_value(annotation) + if isinstance(meta_annotation, Artifact): + return get_annotated_artifact_value(meta_annotation) - # change to _parse to better deal with raw strings and json-serialised strings - return load_parameter_value(kwargs[field]) + return load_parameter_value(kwargs[field], annotation) - for field in runner_input_class.__fields__: + for field in get_fields(runner_input_class): input_model_obj[field] = map_field(field, kwargs) return cast(T, runner_input_class.parse_raw(json.dumps(input_model_obj))) @@ -288,7 +295,7 @@ def _save_dummy_outputs( if os.environ.get("hera__script_pydantic_io", None) is None: raise ValueError("hera__script_pydantic_io environment variable is not set") - for field, _ in dest.__fields__: + for field in get_fields(dest): if field in {"exit_code", "result"}: continue diff --git a/src/hera/workflows/_runner/util.py b/src/hera/workflows/_runner/util.py index 98c385142..c27b014eb 100644 --- a/src/hera/workflows/_runner/util.py +++ b/src/hera/workflows/_runner/util.py @@ -42,7 +42,7 @@ ) -def _ignore_unmatched_kwargs(f: Callable): +def _ignore_unmatched_kwargs(f: Callable) -> Callable: """Make function ignore unmatched kwargs. If the function already has the catch all **kwargs, do nothing. @@ -73,7 +73,7 @@ def _is_kwarg_of(key: str, f: Callable) -> bool: ) -def _parse(value: str, key: str, f: Callable): +def _parse(value: str, key: str, f: Callable) -> Any: """Parse a value to the correct type. Args: @@ -106,10 +106,7 @@ def _parse(value: str, key: str, f: Callable): return value -def _get_type(key: str, f: Callable) -> Optional[type]: - type_ = inspect.signature(f).parameters[key].annotation - if type_ is inspect.Parameter.empty: - return None +def _get_type(type_: type) -> type: if get_origin(type_) is None: return type_ origin_type = cast(type, get_origin(type_)) @@ -136,13 +133,15 @@ def _get_unannotated_type(key: str, f: Callable) -> Optional[type]: def _is_str_kwarg_of(key: str, f: Callable) -> bool: """Check if param `key` of function `f` has a type annotation of a subclass of str.""" - type_ = _get_type(key, f) - if type_ is None: + func_param_annotation = inspect.signature(f).parameters[key].annotation + if func_param_annotation is inspect.Parameter.empty: return False + + type_ = _get_type(func_param_annotation) return issubclass(type_, str) -def _is_artifact_loaded(key: str, f: Callable): +def _is_artifact_loaded(key: str, f: Callable) -> bool: """Check if param `key` of function `f` is actually an Artifact that has already been loaded.""" param = inspect.signature(f).parameters[key] return ( @@ -152,7 +151,7 @@ def _is_artifact_loaded(key: str, f: Callable): ) -def _is_output_kwarg(key: str, f: Callable): +def _is_output_kwarg(key: str, f: Callable) -> bool: """Check if param `key` of function `f` is an output Artifact/Parameter.""" param = inspect.signature(f).parameters[key] return ( @@ -228,7 +227,7 @@ def _runner(entrypoint: str, kwargs_list: List) -> Any: return function(**kwargs) -def _parse_args(): +def _parse_args() -> argparse.Namespace: """Creates an argparse for the runner function. The returned argparse takes a module and function name as flags and a path to a json file as an argument. @@ -239,7 +238,7 @@ def _parse_args(): return parser.parse_args() -def _run(): +def _run() -> None: """Runs a function from a specific path using parsed arguments from Argo. Note that this prints the result of the function to stdout, which is the normal mode of operation for Argo. Any diff --git a/tests/test_unit/test_script_annotations_util.py b/tests/test_unit/test_script_annotations_util.py index 3e9ca4dac..991f5145d 100644 --- a/tests/test_unit/test_script_annotations_util.py +++ b/tests/test_unit/test_script_annotations_util.py @@ -160,3 +160,17 @@ class MyInput(RunnerInput): a_dict={"a-key": "a-value"}, a_list=[1, 2, 3], ) + + +def test_map_runner_input_strings(): + """Test the parsing logic when str type fields are passed json-serialized strings.""" + + class MyInput(RunnerInput): + a_dict_str: str + a_list_str: str + + kwargs = {"a_dict_str": json.dumps({"key": "value"}), "a_list_str": json.dumps([1, 2, 3])} + assert map_runner_input(MyInput, kwargs) == MyInput( + a_dict_str=json.dumps({"key": "value"}), + a_list_str=json.dumps([1, 2, 3]), + )