diff --git a/flytekit/core/array_node_map_task.py b/flytekit/core/array_node_map_task.py
index 87150c47c1..05690e175b 100644
--- a/flytekit/core/array_node_map_task.py
+++ b/flytekit/core/array_node_map_task.py
@@ -128,6 +128,9 @@ def __init__(
             **kwargs,
         )
 
+        self.sub_node_metadata: NodeMetadata = super().construct_node_metadata()
+        self.sub_node_metadata._name = self.name
+
     @property
     def name(self) -> str:
         return self._name
@@ -137,16 +140,13 @@ def python_interface(self):
         return self._collection_interface
 
     def construct_node_metadata(self) -> NodeMetadata:
-        # TODO: add support for other Flyte entities
+        """
+        This returns metadata for the parent ArrayNode, not the sub-node getting mapped over
+        """
         return NodeMetadata(
             name=self.name,
         )
 
-    def construct_sub_node_metadata(self) -> NodeMetadata:
-        nm = super().construct_node_metadata()
-        nm._name = self.name
-        return nm
-
     @property
     def min_success_ratio(self) -> Optional[float]:
         return self._min_success_ratio
diff --git a/flytekit/core/node.py b/flytekit/core/node.py
index ea089c6fd3..61ae41c060 100644
--- a/flytekit/core/node.py
+++ b/flytekit/core/node.py
@@ -124,6 +124,57 @@ def run_entity(self) -> Any:
     def metadata(self) -> _workflow_model.NodeMetadata:
         return self._metadata
 
+    def _override_node_metadata(
+        self,
+        name,
+        timeout: Optional[Union[int, datetime.timedelta]] = None,
+        retries: Optional[int] = None,
+        interruptible: typing.Optional[bool] = None,
+        cache: typing.Optional[bool] = None,
+        cache_version: typing.Optional[str] = None,
+        cache_serialize: typing.Optional[bool] = None,
+    ):
+        from flytekit.core.array_node_map_task import ArrayNodeMapTask
+
+        if isinstance(self.flyte_entity, ArrayNodeMapTask):
+            # override the sub-node's metadata
+            node_metadata = self.flyte_entity.sub_node_metadata
+        else:
+            node_metadata = self._metadata
+
+        if timeout is None:
+            node_metadata._timeout = datetime.timedelta()
+        elif isinstance(timeout, int):
+            node_metadata._timeout = datetime.timedelta(seconds=timeout)
+        elif isinstance(timeout, datetime.timedelta):
+            node_metadata._timeout = timeout
+        else:
+            raise ValueError("timeout should be duration represented as either a datetime.timedelta or int seconds")
+        if retries is not None:
+            assert_not_promise(retries, "retries")
+            node_metadata._retries = (
+                _literal_models.RetryStrategy(0) if retries is None else _literal_models.RetryStrategy(retries)
+            )
+
+        if interruptible is not None:
+            assert_not_promise(interruptible, "interruptible")
+            node_metadata._interruptible = interruptible
+
+        if name is not None:
+            node_metadata._name = name
+
+        if cache is not None:
+            assert_not_promise(cache, "cache")
+            node_metadata._cacheable = cache
+
+        if cache_version is not None:
+            assert_not_promise(cache_version, "cache_version")
+            node_metadata._cache_version = cache_version
+
+        if cache_serialize is not None:
+            assert_not_promise(cache_serialize, "cache_serialize")
+            node_metadata._cache_serializable = cache_serialize
+
     def with_overrides(
         self,
         node_name: Optional[str] = None,
@@ -174,27 +225,6 @@ def with_overrides(
             assert_no_promises_in_resources(resources)
             self._resources = resources
 
-        if timeout is None:
-            self._metadata._timeout = datetime.timedelta()
-        elif isinstance(timeout, int):
-            self._metadata._timeout = datetime.timedelta(seconds=timeout)
-        elif isinstance(timeout, datetime.timedelta):
-            self._metadata._timeout = timeout
-        else:
-            raise ValueError("timeout should be duration represented as either a datetime.timedelta or int seconds")
-        if retries is not None:
-            assert_not_promise(retries, "retries")
"retries") - self._metadata._retries = ( - _literal_models.RetryStrategy(0) if retries is None else _literal_models.RetryStrategy(retries) - ) - - if interruptible is not None: - assert_not_promise(interruptible, "interruptible") - self._metadata._interruptible = interruptible - - if name is not None: - self._metadata._name = name - if task_config is not None: logger.warning("This override is beta. We may want to revisit this in the future.") if not isinstance(task_config, type(self.run_entity._task_config)): @@ -209,17 +239,7 @@ def with_overrides( assert_not_promise(accelerator, "accelerator") self._extended_resources = tasks_pb2.ExtendedResources(gpu_accelerator=accelerator.to_flyte_idl()) - if cache is not None: - assert_not_promise(cache, "cache") - self._metadata._cacheable = cache - - if cache_version is not None: - assert_not_promise(cache_version, "cache_version") - self._metadata._cache_version = cache_version - - if cache_serialize is not None: - assert_not_promise(cache_serialize, "cache_serialize") - self._metadata._cache_serializable = cache_serialize + self._override_node_metadata(name, timeout, retries, interruptible, cache, cache_version, cache_serialize) return self diff --git a/flytekit/tools/translator.py b/flytekit/tools/translator.py index ee905a4218..e74f4c1c71 100644 --- a/flytekit/tools/translator.py +++ b/flytekit/tools/translator.py @@ -624,7 +624,7 @@ def get_serializable_array_node_map_task( ) node = workflow_model.Node( id=entity.name, - metadata=entity.construct_sub_node_metadata(), + metadata=entity.sub_node_metadata, inputs=node.bindings, upstream_node_ids=[], output_aliases=[], diff --git a/tests/flytekit/unit/core/test_array_node_map_task.py b/tests/flytekit/unit/core/test_array_node_map_task.py index ed1fc7fdd0..97693940e0 100644 --- a/tests/flytekit/unit/core/test_array_node_map_task.py +++ b/tests/flytekit/unit/core/test_array_node_map_task.py @@ -9,7 +9,7 @@ import pytest from flyteidl.core import workflow_pb2 as _core_workflow -from flytekit import dynamic, map_task, task, workflow, eager, PythonFunctionTask +from flytekit import dynamic, map_task, task, workflow, eager, PythonFunctionTask, Resources from flytekit.configuration import FastSerializationSettings, Image, ImageConfig, SerializationSettings from flytekit.core import context_manager from flytekit.core.array_node_map_task import ArrayNodeMapTask, ArrayNodeMapTaskResolver @@ -21,6 +21,7 @@ LiteralMap, LiteralOffloadedMetadata, ) +from flytekit.models.task import Resources as _resources_models from flytekit.tools.translator import get_serializable from flytekit.types.directory import FlyteDirectory @@ -349,16 +350,59 @@ def my_wf1() -> typing.List[typing.Optional[int]]: assert my_wf1() == [1, None, 3, 4] -def test_map_task_override(serialization_settings): - @task - def my_mappable_task(a: int) -> typing.Optional[str]: - return str(a) +@task +def my_mappable_task(a: int) -> typing.Optional[str]: + return str(a) + + +@task( + container_image="original-image", + timeout=timedelta(seconds=10), + interruptible=False, + retries=10, + cache=True, + cache_version="original-version", + requests=Resources(cpu=1) +) +def my_mappable_task_1(a: int) -> typing.Optional[str]: + return str(a) + + +@pytest.mark.parametrize( + "task_func", + [my_mappable_task, my_mappable_task_1] +) +def test_map_task_override(serialization_settings, task_func): + array_node_map_task = map_task(task_func) @workflow def wf(x: typing.List[int]): - map_task(my_mappable_task)(a=x).with_overrides(container_image="random:image") + 
+        array_node_map_task(a=x).with_overrides(
+            container_image="new-image",
+            timeout=timedelta(seconds=20),
+            interruptible=True,
+            retries=5,
+            cache=True,
+            cache_version="new-version",
+            requests=Resources(cpu=2)
+        )
+
+    assert wf.nodes[0]._container_image == "new-image"
+
+    od = OrderedDict()
+    wf_spec = get_serializable(od, serialization_settings, wf)
 
-    assert wf.nodes[0]._container_image == "random:image"
+    array_node = wf_spec.template.nodes[0]
+    assert array_node.metadata.timeout == timedelta()
+    sub_node_spec = array_node.array_node.node
+    assert sub_node_spec.metadata.timeout == timedelta(seconds=20)
+    assert sub_node_spec.metadata.interruptible
+    assert sub_node_spec.metadata.retries.retries == 5
+    assert sub_node_spec.metadata.cacheable
+    assert sub_node_spec.metadata.cache_version == "new-version"
+    assert sub_node_spec.target.overrides.resources.requests == [
+        _resources_models.ResourceEntry(_resources_models.ResourceName.CPU, "2")
+    ]
 
 
 def test_serialization_metadata(serialization_settings):
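
Illustrative usage (not part of the diff): a minimal sketch, mirroring the updated test above, of how user code exercises this change. The task and workflow names here are hypothetical; the with_overrides arguments are the ones routed through Node._override_node_metadata in this diff. With this change, overrides set on a mapped task land on the mapped sub-node's metadata (ArrayNodeMapTask.sub_node_metadata) rather than on the parent ArrayNode.

from datetime import timedelta
import typing

from flytekit import map_task, task, workflow


@task
def double(a: int) -> int:  # hypothetical example task
    return a * 2


@workflow
def my_wf(xs: typing.List[int]) -> typing.List[int]:
    # These overrides are recorded on the mapped sub-node's metadata, so each
    # mapped sub-task gets the retry/timeout/cache settings; the parent
    # ArrayNode keeps its own (default) node metadata.
    return map_task(double)(a=xs).with_overrides(
        retries=3,
        timeout=timedelta(minutes=5),
        cache=True,
        cache_version="1.0",
    )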