Skip to content

Commit

Permalink
Use real type to check whether to load parameter value
Browse files Browse the repository at this point in the history
* Also tidy up types

Signed-off-by: Elliot Gunton <[email protected]>
  • Loading branch information
elliotgunton committed Feb 26, 2024
1 parent 4dd83a6 commit ac8fab7
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 24 deletions.
31 changes: 19 additions & 12 deletions src/hera/workflows/_runner/script_annotations_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union, cast

from hera.shared._pydantic import BaseModel, get_fields
from hera.shared.serialization import serialize
from hera.workflows import Artifact, Parameter
from hera.workflows.artifact import ArtifactLoader
Expand Down Expand Up @@ -118,7 +119,7 @@ def get_annotated_artifact_value(artifact_annotation: Artifact) -> Union[Path, A
raise RuntimeError(f"Artifact {artifact_annotation.name} was not given a value")


T = TypeVar("T", bound=Union[RunnerInputV1, RunnerInputV2])
T = TypeVar("T", bound=Type[BaseModel])


def map_runner_input(
Expand All @@ -130,9 +131,14 @@ def map_runner_input(
If the field is annotated, we look for the kwarg with the name from the annotation (Parameter or Artifact).
Otherwise, we look for the kwarg with the name of the field.
"""
from hera.workflows._runner.util import _get_type

input_model_obj = {}

def load_parameter_value(value: str) -> Any:
def load_parameter_value(value: str, value_type: type) -> Any:
if issubclass(_get_type(value_type), str):
return value

try:
return json.loads(value)
except json.JSONDecodeError:
Expand All @@ -144,18 +150,19 @@ def map_field(
) -> Any:
annotation = runner_input_class.__annotations__[field]
if get_origin(annotation) is Annotated:
annotation = get_args(annotation)[1]
if isinstance(annotation, Parameter):
assert not annotation.output
return load_parameter_value(_get_annotated_input_param_value(field, annotation, kwargs))
meta_annotation = get_args(annotation)[1]
if isinstance(meta_annotation, Parameter):
assert not meta_annotation.output
return load_parameter_value(
_get_annotated_input_param_value(field, meta_annotation, kwargs), get_args(annotation)[0]
)

if isinstance(annotation, Artifact):
return get_annotated_artifact_value(annotation)
if isinstance(meta_annotation, Artifact):
return get_annotated_artifact_value(meta_annotation)

# change to _parse to better deal with raw strings and json-serialised strings
return load_parameter_value(kwargs[field])
return load_parameter_value(kwargs[field], annotation)

for field in runner_input_class.__fields__:
for field in get_fields(runner_input_class):
input_model_obj[field] = map_field(field, kwargs)

return cast(T, runner_input_class.parse_raw(json.dumps(input_model_obj)))
Expand Down Expand Up @@ -288,7 +295,7 @@ def _save_dummy_outputs(
if os.environ.get("hera__script_pydantic_io", None) is None:
raise ValueError("hera__script_pydantic_io environment variable is not set")

for field, _ in dest.__fields__:
for field in get_fields(dest):
if field in {"exit_code", "result"}:
continue

Expand Down
23 changes: 11 additions & 12 deletions src/hera/workflows/_runner/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
)


def _ignore_unmatched_kwargs(f: Callable):
def _ignore_unmatched_kwargs(f: Callable) -> Callable:
"""Make function ignore unmatched kwargs.
If the function already has the catch all **kwargs, do nothing.
Expand Down Expand Up @@ -73,7 +73,7 @@ def _is_kwarg_of(key: str, f: Callable) -> bool:
)


def _parse(value: str, key: str, f: Callable):
def _parse(value: str, key: str, f: Callable) -> Any:
"""Parse a value to the correct type.
Args:
Expand Down Expand Up @@ -106,10 +106,7 @@ def _parse(value: str, key: str, f: Callable):
return value


def _get_type(key: str, f: Callable) -> Optional[type]:
type_ = inspect.signature(f).parameters[key].annotation
if type_ is inspect.Parameter.empty:
return None
def _get_type(type_: type) -> type:
if get_origin(type_) is None:
return type_
origin_type = cast(type, get_origin(type_))
Expand All @@ -136,13 +133,15 @@ def _get_unannotated_type(key: str, f: Callable) -> Optional[type]:

def _is_str_kwarg_of(key: str, f: Callable) -> bool:
"""Check if param `key` of function `f` has a type annotation of a subclass of str."""
type_ = _get_type(key, f)
if type_ is None:
func_param_annotation = inspect.signature(f).parameters[key].annotation
if func_param_annotation is inspect.Parameter.empty:
return False

type_ = _get_type(func_param_annotation)
return issubclass(type_, str)


def _is_artifact_loaded(key: str, f: Callable):
def _is_artifact_loaded(key: str, f: Callable) -> bool:
"""Check if param `key` of function `f` is actually an Artifact that has already been loaded."""
param = inspect.signature(f).parameters[key]
return (
Expand All @@ -152,7 +151,7 @@ def _is_artifact_loaded(key: str, f: Callable):
)


def _is_output_kwarg(key: str, f: Callable):
def _is_output_kwarg(key: str, f: Callable) -> bool:
"""Check if param `key` of function `f` is an output Artifact/Parameter."""
param = inspect.signature(f).parameters[key]
return (
Expand Down Expand Up @@ -228,7 +227,7 @@ def _runner(entrypoint: str, kwargs_list: List) -> Any:
return function(**kwargs)


def _parse_args():
def _parse_args() -> argparse.Namespace:
"""Creates an argparse for the runner function.
The returned argparse takes a module and function name as flags and a path to a json file as an argument.
Expand All @@ -239,7 +238,7 @@ def _parse_args():
return parser.parse_args()


def _run():
def _run() -> None:
"""Runs a function from a specific path using parsed arguments from Argo.
Note that this prints the result of the function to stdout, which is the normal mode of operation for Argo. Any
Expand Down
14 changes: 14 additions & 0 deletions tests/test_unit/test_script_annotations_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,3 +160,17 @@ class MyInput(RunnerInput):
a_dict={"a-key": "a-value"},
a_list=[1, 2, 3],
)


def test_map_runner_input_strings():
"""Test the parsing logic when str type fields are passed json-serialized strings."""

class MyInput(RunnerInput):
a_dict_str: str
a_list_str: str

kwargs = {"a_dict_str": json.dumps({"key": "value"}), "a_list_str": json.dumps([1, 2, 3])}
assert map_runner_input(MyInput, kwargs) == MyInput(
a_dict_str=json.dumps({"key": "value"}),
a_list_str=json.dumps([1, 2, 3]),
)

0 comments on commit ac8fab7

Please sign in to comment.