diff --git a/src/aiida_workgraph/collection.py b/src/aiida_workgraph/collection.py index e73dbf59..6ed8454b 100644 --- a/src/aiida_workgraph/collection.py +++ b/src/aiida_workgraph/collection.py @@ -1,8 +1,10 @@ +from typing import Any, Callable, Optional, Union + +from aiida.engine import ProcessBuilder from node_graph.collection import ( NodeCollection, PropertyCollection, ) -from typing import Any, Callable, Optional, Union class TaskCollection(NodeCollection): @@ -14,14 +16,13 @@ def _new( **kwargs: Any ) -> Any: from aiida_workgraph.decorator import ( - build_task_from_callable, build_pythonjob_task, build_shelljob_task, + build_task_from_callable, build_task_from_workgraph, ) from aiida_workgraph.workgraph import WorkGraph - # build the task on the fly if the identifier is a callable if callable(identifier): identifier = build_task_from_callable(identifier) if isinstance(identifier, str) and identifier.upper() == "PYTHONJOB": @@ -41,6 +42,10 @@ def _new( return task if isinstance(identifier, WorkGraph): identifier = build_task_from_workgraph(identifier) + if isinstance(identifier, ProcessBuilder): + from aiida_workgraph.utils import get_dict_from_builder + kwargs = {**kwargs, **get_dict_from_builder(identifier)} + identifier = build_task_from_callable(identifier.process_class) return super()._new(identifier, name, uuid, **kwargs) @@ -58,3 +63,17 @@ def _new( identifier = build_property_from_AiiDA(identifier) # Call the original new method return super()._new(identifier, name, **kwargs) + + +# Backup + # print(identifier, type(identifier)) + # name = name or '' + # breakpoint() + # return super()._new(, name, **builder_dict) + # print("ProcessBuilder passed") + # task = build_task_from_builder(identifier) + # return task + # print(f"identifier: {identifier} ({type(identifier)})") + # print(f"name: {name}") + # print(f"uuid: {uuid}") + # print(f"kwargs: {kwargs}") \ No newline at end of file diff --git a/src/aiida_workgraph/decorator.py b/src/aiida_workgraph/decorator.py index 263d4445..fc2ef7b6 100644 --- a/src/aiida_workgraph/decorator.py +++ b/src/aiida_workgraph/decorator.py @@ -2,9 +2,9 @@ from typing import Any, Callable, Dict, List, Optional, Union, Tuple from aiida_workgraph.utils import get_executor -from aiida.engine import calcfunction, workfunction, CalcJob, WorkChain +from aiida.engine import calcfunction, workfunction, CalcJob, WorkChain, ProcessBuilder from aiida_workgraph.task import Task -from aiida_workgraph.utils import build_callable, validate_task_inout +from aiida_workgraph.utils import build_callable, validate_task_inout, get_dict_from_builder import inspect from aiida_workgraph.config import builtin_inputs, builtin_outputs, task_types from aiida_workgraph.orm.mapping import type_mapping @@ -608,3 +608,109 @@ def __call__(self, *args, **kwargs): task = TaskDecoratorCollection() + +# Backup +# def build_task_from_builder( +# builder: ProcessBuilder, +# # tdata: Dict[str, Any], +# inputs: Optional[List[str]] = None, +# outputs: Optional[List[str]] = None, +# ) -> Task: +# # ) -> Tuple[Task, Dict[str, Any]]: + +# """Build task from an aiida-core ProcessBuilder.""" +# from aiida.orm.utils.serialize import serialize +# from aiida_workgraph.utils.inspect_aiida_components import ( +# get_task_data_from_aiida_component, +# ) +# from aiida.engine.utils import instantiate_process +# from aiida.manage import get_manager +# from aiida.plugins import CalculationFactory + +# process_class = builder.process_class +# builder_inputs = builder._inputs() +# # metadata = builder_inputs.pop('metadata') +# # monitors = builder_inputs.pop('monitors') +# # code = builder_inputs.pop('code') + +# # process_node = process_class(**builder_inputs) + + +# # def build_task_from_callable( +# # executor: Callable, +# # inputs: Optional[List[str | dict]] = None, +# # outputs: Optional[List[str | dict]] = None, +# # inputs = inputs or {} +# # breakpoint() +# task = build_task_from_callable( +# executor=process_class, +# inputs=inputs, +# outputs=outputs, +# ) +# print(f"EXECUTOR: {process_class}") +# print(f"INPUTS: {inputs}") +# print(f"OUTPUTS: {outputs}") +# print(f"TASK: {task}") + +# # manager = get_manager() +# # runner = manager.get_runner() +# # breakpoint() + +# # process_class = CalculationFactory(entry_point_name) +# # process = instantiate_process(runner, process_class, **builder_inputs) +# # print(issubclass(process_class, CalcJob)) +# # raise SystemExit +# # breakpoint() + +# # tdata = { +# # "metadata": {"task_type": "builder"}, # Or maybe use node_class, node_type +# # "callable": process_class, +# # } +# # _, tdata = build_task_from_AiiDA(tdata) + +# # breakpoint() +# # elif is_process_function(process): +# # process_class = process.process_class # type: ignore[attr-defined] +# # elif inspect.isclass(process) and issubclass(process, Process): +# # process_class = process +# # else: +# # raise ValueError(f'invalid process {type(process)}, needs to be Process or ProcessBuilder') + +# # process = process_class(runner=runner, inputs=inputs) + +# # tdata = {"metadata": {"task_type": ""}} +# # inputs = [] +# # outputs = [] +# # group_outputs = [] + +# # process_class = builder._process_class + +# # builder_dict = get_dict_from_builder(builder) + +# # tdata = get_task_data_from_aiida_component( +# # tdata=tdata, inputs=inputs, outputs=outputs +# # ) + + +# # task_decorated = build_task_from_callable( +# # func_decorated, +# # inputs=kwargs.get("inputs", []), +# # outputs=kwargs.get("outputs", []), +# # ) + +# # executor = get_executor(self.get_executor())[0] +# # builder = executor.get_builder_from_protocol(*args, **kwargs) +# # TODO: Instantiate AiiDA object from the builder, and pass that, rather than having to manually construct the tdata +# # TODO: here again +# # data = get_dict_from_builder(builder) + +# # # data.pop('identifier') +# # data = {**data, **tdata} + +# # data['identifier'] = 'a' +# # # tdata["identifier"] = wg.name +# # task = Task.from_dict(data=data) +# return task + # print(f"TDATA: {tdata}") + # print(f"INPUTS: {inputs}") + # print(f"OUTPTUS: {outputs}") \ No newline at end of file