Skip to content

Commit

Permalink
Fix linting
Browse files Browse the repository at this point in the history
Signed-off-by: Elliot Gunton <[email protected]>
  • Loading branch information
elliotgunton committed Feb 26, 2024
1 parent 6a7ce69 commit 466eeba
Showing 1 changed file with 22 additions and 15 deletions.
37 changes: 22 additions & 15 deletions src/hera/workflows/_runner/script_annotations_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,26 @@ def _get_outputs_path(destination: Union[Parameter, Artifact]) -> Path:
return path


def _get_annotated_input_param_value(
func_param_name: str,
param_annotation: Parameter,
kwargs: Dict[str, str],
) -> str:
if param_annotation.name in kwargs:
return kwargs[param_annotation.name]

if func_param_name in kwargs:
return kwargs[func_param_name]

raise RuntimeError(
f"Parameter {param_annotation.name if param_annotation.name else func_param_name} was not given a value"
)


def get_annotated_param_value(
func_param_name: str,
param_annotation: Parameter,
kwargs: dict[str, Union[Path, str]],
kwargs: Dict[str, str],
) -> Union[Path, str]:
"""Get the value from a given function param and its annotation.
Expand All @@ -62,16 +78,7 @@ def get_annotated_param_value(
# Automatically create the parent directory (if required)
path.parent.mkdir(parents=True, exist_ok=True)
return path

if param_annotation.name in kwargs:
return kwargs[param_annotation.name]

if func_param_name in kwargs:
return kwargs[func_param_name]

raise RuntimeError(
f"Parameter {param_annotation.name if param_annotation.name else func_param_name} was not given a value"
)
return _get_annotated_input_param_value(func_param_name, param_annotation, kwargs)


def get_annotated_artifact_value(artifact_annotation: Artifact) -> Union[Path, Any]:
Expand Down Expand Up @@ -116,7 +123,7 @@ def get_annotated_artifact_value(artifact_annotation: Artifact) -> Union[Path, A

def map_runner_input(
runner_input_class: T,
kwargs: dict[str, Union[Path, str]],
kwargs: Dict[str, str],
) -> T:
"""Map argo input kwargs to the fields of the given RunnerInput, return an instance of the class.
Expand All @@ -133,14 +140,14 @@ def load_parameter_value(value: str) -> Any:

def map_field(
field: str,
kwargs: dict[str, Union[Path, str]],
kwargs: Dict[str, str],
) -> Any:
annotation = runner_input_class.__annotations__[field]
if get_origin(annotation) is Annotated:
annotation = get_args(annotation)[1]
if isinstance(annotation, Parameter):
assert not annotation.output
return load_parameter_value(get_annotated_param_value(field, annotation, kwargs))
return load_parameter_value(_get_annotated_input_param_value(field, annotation, kwargs))

if isinstance(annotation, Artifact):
return get_annotated_artifact_value(annotation)
Expand All @@ -151,7 +158,7 @@ def map_field(
for field in runner_input_class.__fields__:
input_model_obj[field] = map_field(field, kwargs)

return runner_input_class.parse_raw(json.dumps(input_model_obj))
return cast(T, runner_input_class.parse_raw(json.dumps(input_model_obj)))


def _map_argo_inputs_to_function(function: Callable, kwargs: Dict[str, str]) -> Dict:
Expand Down

0 comments on commit 466eeba

Please sign in to comment.