Skip to content

Commit

Permalink
fixed executor_data propagation from lattice to default electrons
Browse files Browse the repository at this point in the history
  • Loading branch information
kessler-frost committed Nov 21, 2023
1 parent 21ea564 commit 5779789
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 5 deletions.
24 changes: 23 additions & 1 deletion covalent/_workflow/electron.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

"""Class corresponding to computation nodes."""


import contextlib
import inspect
import json
import operator
Expand Down Expand Up @@ -359,6 +361,26 @@ def get_item(e, key):

raise StopIteration

def _is_metadata_field_empty(self, key: str) -> bool:
"""
Checks if the metadata field is empty.
Args:
key: Name of the metadata field.
Returns:
True if metadata field is empty, else False.
"""

if self.get_metadata(key) is None:
return True

with contextlib.suppress(TypeError):
if len(self.get_metadata(key)) == 0:
return True

return False

def __call__(self, *args, **kwargs) -> Union[Any, "Electron"]:
"""
Function to execute the electron.
Expand Down Expand Up @@ -391,7 +413,7 @@ def __call__(self, *args, **kwargs) -> Union[Any, "Electron"]:
if (
k not in consumable_constraints
and k in DEFAULT_METADATA_VALUES
and self.get_metadata(k) is None
and self._is_metadata_field_empty(k)
):
meta = active_lattice.get_metadata(k) or DEFAULT_METADATA_VALUES[k]
self.set_metadata(k, meta)
Expand Down
8 changes: 4 additions & 4 deletions covalent/_workflow/lattice.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def build_graph(self, *args, **kwargs) -> None:

named_args, named_kwargs = get_named_params(workflow_function, args, kwargs)
new_args = [v for _, v in named_args.items()]
new_kwargs = {k: v for k, v in named_kwargs.items()}
new_kwargs = dict(named_kwargs.items())

self.inputs = TransportableObject({"args": args, "kwargs": kwargs})
self.named_args = TransportableObject(named_args)
Expand All @@ -215,7 +215,7 @@ def build_graph(self, *args, **kwargs) -> None:
new_metadata = {
name: DEFAULT_METADATA_VALUES[name]
for name in constraint_names
if not self.metadata[name]
if self.metadata[name] is None
}
new_metadata = encode_metadata(new_metadata)

Expand Down Expand Up @@ -330,8 +330,8 @@ def lattice(
# Add custom metadata fields here
deps_bash: Union[DepsBash, list, str] = None,
deps_pip: Union[DepsPip, list] = None,
call_before: Union[List[DepsCall], DepsCall] = [],
call_after: Union[List[DepsCall], DepsCall] = [],
call_before: Union[List[DepsCall], DepsCall] = None,
call_after: Union[List[DepsCall], DepsCall] = None,
triggers: Union["BaseTrigger", List["BaseTrigger"]] = None,
# e.g. schedule: True, whether to use a custom scheduling logic or not
) -> Lattice:
Expand Down

0 comments on commit 5779789

Please sign in to comment.