From 57797893a17b1dd56c360529d4c8d7b596c94015 Mon Sep 17 00:00:00 2001 From: sankalp Date: Tue, 21 Nov 2023 07:29:33 -0500 Subject: [PATCH] fixed executor_data propagation from lattice to default electrons --- covalent/_workflow/electron.py | 24 +++++++++++++++++++++++- covalent/_workflow/lattice.py | 8 ++++---- 2 files changed, 27 insertions(+), 5 deletions(-) diff --git a/covalent/_workflow/electron.py b/covalent/_workflow/electron.py index 67f3044ba..aa319b630 100644 --- a/covalent/_workflow/electron.py +++ b/covalent/_workflow/electron.py @@ -16,6 +16,8 @@ """Class corresponding to computation nodes.""" + +import contextlib import inspect import json import operator @@ -359,6 +361,26 @@ def get_item(e, key): raise StopIteration + def _is_metadata_field_empty(self, key: str) -> bool: + """ + Checks if the metadata field is empty. + + Args: + key: Name of the metadata field. + + Returns: + True if metadata field is empty, else False. + """ + + if self.get_metadata(key) is None: + return True + + with contextlib.suppress(TypeError): + if len(self.get_metadata(key)) == 0: + return True + + return False + def __call__(self, *args, **kwargs) -> Union[Any, "Electron"]: """ Function to execute the electron. @@ -391,7 +413,7 @@ def __call__(self, *args, **kwargs) -> Union[Any, "Electron"]: if ( k not in consumable_constraints and k in DEFAULT_METADATA_VALUES - and self.get_metadata(k) is None + and self._is_metadata_field_empty(k) ): meta = active_lattice.get_metadata(k) or DEFAULT_METADATA_VALUES[k] self.set_metadata(k, meta) diff --git a/covalent/_workflow/lattice.py b/covalent/_workflow/lattice.py index aa09ac033..2c4159d99 100644 --- a/covalent/_workflow/lattice.py +++ b/covalent/_workflow/lattice.py @@ -203,7 +203,7 @@ def build_graph(self, *args, **kwargs) -> None: named_args, named_kwargs = get_named_params(workflow_function, args, kwargs) new_args = [v for _, v in named_args.items()] - new_kwargs = {k: v for k, v in named_kwargs.items()} + new_kwargs = dict(named_kwargs.items()) self.inputs = TransportableObject({"args": args, "kwargs": kwargs}) self.named_args = TransportableObject(named_args) @@ -215,7 +215,7 @@ def build_graph(self, *args, **kwargs) -> None: new_metadata = { name: DEFAULT_METADATA_VALUES[name] for name in constraint_names - if not self.metadata[name] + if self.metadata[name] is None } new_metadata = encode_metadata(new_metadata) @@ -330,8 +330,8 @@ def lattice( # Add custom metadata fields here deps_bash: Union[DepsBash, list, str] = None, deps_pip: Union[DepsPip, list] = None, - call_before: Union[List[DepsCall], DepsCall] = [], - call_after: Union[List[DepsCall], DepsCall] = [], + call_before: Union[List[DepsCall], DepsCall] = None, + call_after: Union[List[DepsCall], DepsCall] = None, triggers: Union["BaseTrigger", List["BaseTrigger"]] = None, # e.g. schedule: True, whether to use a custom scheduling logic or not ) -> Lattice: