diff --git a/python/experiment/model/frontends/dsl.py b/python/experiment/model/frontends/dsl.py index f7c9586..5258a76 100644 --- a/python/experiment/model/frontends/dsl.py +++ b/python/experiment/model/frontends/dsl.py @@ -203,9 +203,9 @@ def __str__(self): def split( self, - scopes: typing.Iterable[typing.Tuple[str]], + scopes: typing.Iterable[typing.Tuple[str, ...]], ) -> typing.Tuple[ - typing.Tuple[str], + typing.Tuple[str, ...], typing.Optional[str], ]: """Utility method to infer the partition of an OutputReference location into stepName and fileRef @@ -819,7 +819,7 @@ def _replace_many_parameter_references( def replace_parameter_references( value: ParameterValueType, - all_scopes: typing.Dict[typing.Tuple[str], "ScopeStack.Scope"], + all_scopes: typing.Dict[typing.Tuple[str, ...], "ScopeStack.Scope"], location: typing.Iterable[str], is_replica: bool, variables: typing.Optional[typing.Dict[str, ParameterValueType]] = None @@ -1045,7 +1045,7 @@ def fold_in_defaults_of_parameters(self): def resolve_parameter_references_of_instance( self: "ScopeStack.Scope", - all_scopes: typing.Dict[typing.Tuple[str], "ScopeStack.Scope"] + all_scopes: typing.Dict[typing.Tuple[str, ...], "ScopeStack.Scope"] ) -> typing.List[Exception]: errors: typing.List[Exception] = [] for idx, (name, value) in enumerate(self.parameters.items()): @@ -1069,7 +1069,7 @@ def resolve_parameter_references_of_instance( def resolve_output_references_of_instance( self: "ScopeStack.Scope", - all_scopes: typing.Dict[typing.Tuple[str], "ScopeStack.Scope"], + all_scopes: typing.Dict[typing.Tuple[str, ...], "ScopeStack.Scope"], ensure_references_point_to_sibling_steps: bool, ) -> typing.List[Exception]: errors: typing.List[Exception] = [] @@ -1107,7 +1107,7 @@ def resolve_output_references_of_instance( def resolve_legacy_data_references( self: "ScopeStack.Scope", - all_scopes: typing.Dict[typing.Tuple[str], "ScopeStack.Scope"], + all_scopes: typing.Dict[typing.Tuple[str, ...], "ScopeStack.Scope"], ) -> typing.List[Exception]: errors: typing.List[Exception] = [] for idx, (name, value) in enumerate(self.parameters.items()): @@ -1160,7 +1160,7 @@ def replace_step_references( self, value: ParameterValueType, field: typing.List[str], - all_scopes: typing.Dict[typing.Tuple[str], "ScopeStack.Scope"], + all_scopes: typing.Dict[typing.Tuple[str, ...], "ScopeStack.Scope"], ensure_references_point_to_sibling_steps: bool = True, ) -> ParameterValueType: """Rewrites all references to steps in a value to their absolute form @@ -1190,7 +1190,6 @@ def replace_step_references( sibling_steps = [] if len(self.location) > 1: uid_parent = tuple(self.location[:-1]) - parent_workflow_name = uid_parent[-1] if len(uid_parent) > 0 else "**missing**" try: parent_scope = all_scopes[uid_parent] @@ -1257,7 +1256,15 @@ def __init__(self): # The root of the namespace is what the entrypoint invokes, it's name is always `entry-instance`. # The values are the ScopeEntries which are effectively instances of a Template i.e. # The instance name, definition, and parameters of a Template - self.scopes: typing.Dict[typing.Tuple[str], ScopeStack.Scope] = {} + self.scopes: typing.Dict[typing.Tuple[str, ...], ScopeStack.Scope] = {} + + # VV: Keys are "locations" of Component template instances, values are either True or False indicating + # whether the component is replicating or not + self.replicating_components: typing.Dict[typing.Tuple[str, ...], bool] = {} + + # VV: Keys are "locations" of Component template instances, values are either True or False indicating + # whether the component is aggregating or not + self.aggregating_components: typing.Dict[typing.Tuple[str, ...], bool] = {} def can_template_replicate(self, location: typing.Iterable[str]) -> bool: """Returns whether the template can replicate @@ -1277,27 +1284,79 @@ def can_template_replicate(self, location: typing.Iterable[str]) -> bool: If the location does not map to a known scope """ - location = list(location) + scope: ScopeStack.Scope = self.scopes[tuple(location)] + if isinstance(scope.template, Workflow): + return False + + # VV: This component is using a variable for its workflowAttributes.replicate field. + # For now, we can assume that this means the component is replicating + if scope.template.workflowAttributes.replicate not in ["0", 0, "", None]: + return True + + if scope.template.workflowAttributes.aggregate is True: + self.aggregating_components[tuple(location)] = True + return False + + # VV: This component doesn't replicate, let's start walking from its producers all the way to the + # root of the experiment. Stop when one of the components is either replicating or aggregating. + # If you reach the entrypoint then this component is not replicating. - # VV: Iterate scopes starting from the CURRENT template and moving up till you reach: - # 1. an aggregating component -> replica = False - # 2. a POTENTIALLY replicating component -> replica = True - # 3. the entrypoint -> replica = False + pattern_vanilla = re.compile(OutputReferenceVanilla) + pattern_nested = re.compile(OutputReferenceNested) - while location: - scope = self.scopes[tuple(location)] - location = location[:-1] + component_locations_to_check = [scope.location] - if not isinstance(scope.template, Component): + while component_locations_to_check: + location = component_locations_to_check.pop() + + scope: ScopeStack.Scope = self.scopes[tuple(location)] + if isinstance(scope.template, Workflow): continue + # VV: This component is using a variable for its workflowAttributes.replicate field. + # For now, we can assume that this means the component is replicating if scope.template.workflowAttributes.replicate not in ["0", 0, "", None]: return True - elif ( - isinstance(scope.template.workflowAttributes.aggregate, bool) - and scope.template.workflowAttributes.aggregate is True - ): - return False + + for value in scope.parameters.values(): + if not isinstance(value, str): + continue + + for pattern in [pattern_vanilla, pattern_nested]: + for match in pattern.finditer(value): + ref = OutputReference.from_str(match.group(0)) + location = ref.location + + producer = None + while location: + try: + producer = self.scopes[tuple(location)] + if isinstance(producer.template, Component) is False: + continue + break + except KeyError: + # VV: This location doesn't map to a component. The OutputReference must be pointing + # to an output of a step. Trim one level and check whether that points to a known step. + location = location[:-1] + + if not producer: + continue + + replicating = self.replicating_components.get(tuple(location)) + + if (replicating + or producer.template.workflowAttributes.replicate not in ["0", 0, "", None]): + self.replicating_components[tuple(scope.location)] = True + return True + + if (self.aggregating_components.get(tuple(producer.location), False) is True + or producer.template.workflowAttributes.aggregate is True): + # VV: If this component is not aggregating then it **might** be replicating + self.aggregating_components[tuple(producer.location)] = True + break + + # VV: Cannot tell whether this component replicates or not, need to visit its producer + component_locations_to_check.append(producer.location) return False @@ -1880,7 +1939,7 @@ def discover_legacy_references(self) -> typing.Dict[str, typing.List[str]]: def convert_outputreferences_to_datareferences( self, - uid_to_name: typing.Dict[typing.Tuple[str], typing.Tuple[int, str]], + uid_to_name: typing.Dict[typing.Tuple[str, ...], typing.Tuple[int, str]], location: typing.List[typing.Union[str, int]], ): """Utility method to convert OutputReference instances into Legacy DataReferences @@ -1917,17 +1976,18 @@ def convert_outputreferences_to_datareferences( self.flowir["command"]["arguments"] = args # VV: TODO Here we'll need to do something about :copy - I'll figure this out in a future update - for match in pattern_output.finditer(args): - ref = OutputReference.from_str(match.group(0)) - if not ref.method: - raise experiment.model.errors.DSLInvalidFieldError( - location=["components", self.scope.template.signature.name, "command", "arguments"], - underlying_error=ValueError(f"The arguments of {self.scope.location} contain a reference to " - f"the output {match.group(0)} but the OutputReference is partial, it does not " - f"end with a :$method suffix.") - ) - else: - arguments_output.add(match.group(0)) + if isinstance(args, str): + for match in pattern_output.finditer(args): + ref = OutputReference.from_str(match.group(0)) + if not ref.method: + raise experiment.model.errors.DSLInvalidFieldError( + location=["components", self.scope.template.signature.name, "command", "arguments"], + underlying_error=ValueError(f"The arguments of {self.scope.location} contain a reference to " + f"the output {match.group(0)} but the OutputReference is partial, it does not " + f"end with a :$method suffix.") + ) + else: + arguments_output.add(match.group(0)) for idx, (name, value) in enumerate(self.scope.parameters.items()): if not isinstance(value, str): @@ -1960,8 +2020,9 @@ def convert_outputreferences_to_datareferences( ) self.errors.append(err) - for match in pattern_legacy.finditer(args): - arguments_legacy.add(match.group(0)) + if isinstance(args, str): + for match in pattern_legacy.finditer(args): + arguments_legacy.add(match.group(0)) for name, value in self.scope.parameters.items(): if not isinstance(value, str): @@ -2081,9 +2142,13 @@ def replace_parameter_references( except experiment.model.errors.DSLInvalidFieldError as e: self.errors.append(e) except Exception as e: + str_location = "/".join(map(str, self.scope.dsl_location())) + self.errors.append( experiment.model.errors.DSLInvalidFieldError( - self.template_dsl_location + node.location, underlying_error=e + self.template_dsl_location + node.location, + underlying_error=ValueError(f"The component was instantiated at {str_location}. " + f"Error: {e}") ) ) @@ -2210,7 +2275,7 @@ def namespace_to_flowir( override_entrypoint_args=override_entrypoint_args ) - components: typing.Dict[typing.Tuple[str], ComponentFlowIR] = {} + components: typing.Dict[typing.Tuple[str, ...], ComponentFlowIR] = {} errors = [] for location, scope in scopes.scopes.items(): @@ -2236,7 +2301,7 @@ def namespace_to_flowir( raise experiment.model.errors.DSLInvalidError.from_errors(errors) component_names: typing.Dict[str, int] = {} - uid_to_name: typing.Dict[typing.Tuple[str], typing.Tuple[int, str]] = {} + uid_to_name: typing.Dict[typing.Tuple[str, ...], typing.Tuple[int, str]] = {} pattern_name = re.compile(SignatureNamePattern) diff --git a/python/experiment/model/frontends/flowir.py b/python/experiment/model/frontends/flowir.py index f149370..4786edc 100644 --- a/python/experiment/model/frontends/flowir.py +++ b/python/experiment/model/frontends/flowir.py @@ -16,6 +16,7 @@ import pprint import re import traceback +import typing from string import Template from threading import RLock from typing import (Any, Callable, Dict, List, MutableMapping, Optional, Set, @@ -4930,8 +4931,13 @@ def refresh_component_dictionary(self): ) self._component_dictionary[comp_id] = component - def replicate(self, platform=None, ignore_errors=False, top_level_folders=None): - # type: (str, bool, List[str]) -> DictFlowIR + def replicate( + self, + platform: typing.Optional[str]=None, + ignore_errors: bool=False, + top_level_folders: typing.Optional[typing.List[str]]=None + ) -> DictFlowIR: + """Replicates a primitive FlowIRConcrete Arguments: diff --git a/tests/test_dsl.py b/tests/test_dsl.py index 4f5fc4e..d0b923c 100644 --- a/tests/test_dsl.py +++ b/tests/test_dsl.py @@ -1080,8 +1080,10 @@ def test_validate_dsl_with_unknown_params(): assert len(exc.underlying_errors) == 1 print(exc.underlying_errors[0].pretty_error()) assert exc.underlying_errors[0].location == ["components", 0, "command", "arguments"] - assert str(exc.underlying_errors[0].underlying_error) == ('Reference to unknown parameter "hello". ' - 'Known parameters are {}') + assert str(exc.underlying_errors[0].underlying_error) == ( + 'The component was instantiated at workflows/0/execute/0. ' + 'Error: Reference to unknown parameter "hello". Known parameters are {}' + ) def test_unknown_outputreference(): @@ -1521,3 +1523,133 @@ def test_no_replica_names(): 'ctx': {'pattern': '^(stage(?P([0-9]+))\\.)?(?P([A-Za-z0-9._-]*[A-Za-z_-]+))$'}, }, ] + + +def test_replicate_propagate(): + dsl = yaml.safe_load(""" +entrypoint: + entry-instance: product-of-sums + execute: + - target: + args: + N: 3 +workflows: +- signature: + name: product-of-sums + parameters: + - name: N + default: 1 + steps: + generate-random-pairs: generate-random-pairs + nested: nested + calculate-product: calculate-product + execute: + - target: + args: + number_pairs: "%(N)s" + - target: + args: + file: /pairs.yaml:output + - target: + args: + numbers: :output + +- signature: + name: nested + parameters: + - name: file + steps: + sum: sum + execute: + - target: + args: + file: "%(file)s" + +components: +- signature: + name: generate-random-pairs + parameters: + - name: number_pairs + workflowAttributes: + replicate: "%(number_pairs)s" + command: + expandArguments: none + executable: python + arguments: "%(number_pairs)s" +- signature: + name: sum + parameters: + - name: file + command: + expandArguments: none + executable: python + arguments: "%(file)s %(replica)s" +- signature: + name: calculate-product + parameters: + - name: numbers + workflowAttributes: + aggregate: true + command: + expandArguments: none + executable: python + arguments: "%(numbers)s" + """) + + dsl = experiment.model.frontends.dsl.Namespace(**dsl) + + flowir = experiment.model.frontends.dsl.namespace_to_flowir(dsl) + + +def test_replicate_propagate_through_chain(): + dsl = yaml.safe_load(""" +entrypoint: + entry-instance: dummy + execute: + - target: +workflows: +- signature: + name: dummy + steps: + first: replicate + second: plain + third: plain + fourth: plain + execute: + - target: + - target: + args: + message: :output + - target: + args: + message: :output + - target: + args: + message: :output + +components: +- signature: + name: replicate + workflowAttributes: + replicate: 1 + command: + expandArguments: none + executable: echo +- signature: + name: plain + parameters: + - name: message + command: + expandArguments: none + executable: echo + arguments: "%(message)s %(replica)s" + """) + + dsl = experiment.model.frontends.dsl.Namespace(**dsl) + + flowir = experiment.model.frontends.dsl.namespace_to_flowir(dsl) + + raw = flowir.replicate() + + for comp in raw["components"]: + assert comp["name"].endswith("0")