diff --git a/cfdraw/app/app.py b/cfdraw/app/app.py index f418275b..aafe8be1 100644 --- a/cfdraw/app/app.py +++ b/cfdraw/app/app.py @@ -5,18 +5,17 @@ from aiohttp import ClientSession from fastapi import FastAPI from contextlib import asynccontextmanager -from cftool.misc import print_info -from cftool.misc import random_hash from fastapi.middleware import cors from cfdraw import constants -from cfdraw.utils import console from cfdraw.config import get_config from cfdraw.app.schema import IApp from cfdraw.app.endpoints import * +from cfdraw.core.toolkit import console from cfdraw.schema.plugins import IPlugin from cfdraw.plugins.factory import Plugins from cfdraw.plugins.factory import PluginFactory +from cfdraw.core.toolkit.misc import random_hash async def ping() -> str: @@ -31,11 +30,8 @@ def __init__(self, notification: Optional[str] = None) -> None: async def lifespan(api: FastAPI) -> AsyncGenerator: # startup - def info(msg: str) -> None: - print_info(msg) - - info(f"๐Ÿš€ Starting Backend Server at {self.config.api_url} ...") - info("๐Ÿ”จ Compiling Plugins & Endpoints...") + console.log(f"๐Ÿš€ Starting Backend Server at {self.config.api_url} ...") + console.log("๐Ÿ”จ Compiling Plugins & Endpoints...") tplugin_with_notification: List[Type[IPlugin]] = [] for tplugin in self.plugins.values(): tplugin.hash = self.hash @@ -44,7 +40,7 @@ def info(msg: str) -> None: tplugin_with_notification.append(tplugin) if tplugin_with_notification or notification is not None: console.rule("") - info(f"๐Ÿ“ฃ Notifications:") + console.log(f"๐Ÿ“ฃ Notifications:") if notification is not None: console.rule(f"[bold green][ GLOBAL ]") console.print(notification) @@ -56,8 +52,8 @@ def info(msg: str) -> None: for endpoint in self.endpoints: await endpoint.on_startup() upload_root_path = self.config.upload_root_path - info(f"๐Ÿ”” Your files will be saved to '{upload_root_path}'") - info("๐ŸŽ‰ Backend Server is Ready!") + console.log(f"๐Ÿ”” Your files will be saved to '{upload_root_path}'") + console.log("๐ŸŽ‰ Backend Server is Ready!") yield diff --git a/cfdraw/app/endpoints/project.py b/cfdraw/app/endpoints/project.py index b05b5c78..17eb15c0 100644 --- a/cfdraw/app/endpoints/project.py +++ b/cfdraw/app/endpoints/project.py @@ -8,11 +8,11 @@ from typing import Any from typing import List from pydantic import BaseModel -from cftool.web import raise_err -from cftool.web import get_responses -from cftool.misc import get_err_msg from cfdraw.parsers import noli +from cfdraw.core.toolkit.web import raise_err +from cfdraw.core.toolkit.web import get_responses +from cfdraw.core.toolkit.misc import get_err_msg from cfdraw.app.endpoints.base import IEndpoint diff --git a/cfdraw/app/endpoints/queue.py b/cfdraw/app/endpoints/queue.py index 2abb6491..628c9e75 100644 --- a/cfdraw/app/endpoints/queue.py +++ b/cfdraw/app/endpoints/queue.py @@ -2,26 +2,109 @@ import logging from typing import Dict +from typing import List from typing import Tuple +from typing import Generic +from typing import Iterator from typing import Optional -from cftool.misc import get_err_msg -from cftool.misc import print_error -from cftool.misc import random_hash -from cftool.misc import print_warning -from cftool.data_structures import Item -from cftool.data_structures import QueuesInQueue from cfdraw.app.schema import IRequestQueue from cfdraw.app.schema import IRequestQueueData +from cfdraw.core.toolkit import console from cfdraw.utils.misc import offload from cfdraw.schema.plugins import ISend from cfdraw.schema.plugins import SocketStatus from 
cfdraw.schema.plugins import ISocketMessage +from cfdraw.core.toolkit.misc import get_err_msg +from cfdraw.core.toolkit.misc import random_hash +from cfdraw.core.toolkit.data_structures import Item +from cfdraw.core.toolkit.data_structures import Bundle +from cfdraw.core.toolkit.data_structures import TItemData DEBUG = False +class QueuesInQueue(Generic[TItemData]): + def __init__(self, *, no_mapping: bool = True) -> None: + self._cursor = 0 + self._queues: Bundle[Bundle[TItemData]] = Bundle(no_mapping=no_mapping) + + def __iter__(self) -> Iterator[Item[Bundle[TItemData]]]: + return iter(self._queues) + + @property + def is_empty(self) -> bool: + return self.num_items == 0 + + @property + def num_queues(self) -> int: + return len(self._queues) + + @property + def num_items(self) -> int: + return sum(len(q.data) for q in self._queues) + + def get(self, queue_id: str) -> Optional[Item[Bundle[TItemData]]]: + return self._queues.get(queue_id) + + def push(self, queue_id: str, item: Item[TItemData]) -> None: + queue_item = self._queues.get(queue_id) + if queue_item is not None: + queue = queue_item.data + else: + queue = Bundle() + self._queues.push(Item(queue_id, queue)) + queue.push(item) + + def next(self) -> Tuple[Optional[str], Optional[Item[TItemData]]]: + if self._queues.is_empty: + return None, None + self._cursor %= len(self._queues) + queue = self._queues.get_index(self._cursor) + item = queue.data.first + if item is None: + self._queues.remove(queue.key) + return self.next() + self._cursor += 1 + return queue.key, item + + def remove(self, queue_id: str, item_key: str) -> None: + queue_item = self._queues.get(queue_id) + if queue_item is None: + return + queue_item.data.remove(item_key) + if queue_item.data.is_empty: + self._queues.remove(queue_id) + + def get_pending(self, item_key: str) -> Optional[List[Item[TItemData]]]: + if self._queues.is_empty: + return None + layer = 0 + searched = False + pending: List[Item[TItemData]] = [] + finished_searching = [False] * len(self._queues) + + init = (self._cursor + len(self._queues) - 1) % len(self._queues) + cursor = init + while not all(finished_searching): + if not finished_searching[cursor]: + queue = self._queues.get_index(cursor) + if layer >= len(queue.data): + finished_searching[cursor] = True + else: + item = queue.data.get_index(layer) + if item.key == item_key: + searched = True + break + pending.append(item) + cursor = (cursor + 1) % len(self._queues) + if cursor == init: + layer += 1 + + return pending if searched else None + + class RequestQueue(IRequestQueue): def __init__(self) -> None: self._queues = QueuesInQueue[IRequestQueueData]() @@ -79,11 +162,11 @@ async def wait(self, user_id: str, uid: str) -> None: # So here we simply warn instead of raise. 
queue_item = self._queues.get(user_id) if queue_item is None: - print_warning("cannot find user request queue after submitted") + console.warn("cannot find user request queue after submitted") return request_item = queue_item.data.get(uid) if request_item is None: - print_warning("cannot find request item after submitted") + console.warn("cannot find request item after submitted") return await self._broadcast_pending() asyncio.create_task(self.run()) @@ -144,7 +227,7 @@ async def _broadcast_pending(self) -> None: except Exception: logging.exception(f"{prefix} failed to send message '{message}'") if not success: - print_error(f"Failed to send following message: {message}") + console.error(f"Failed to send following message: {message}") async def _broadcast_working(self, uid: str) -> bool: sender_pack = self._senders.get(uid) @@ -167,7 +250,7 @@ async def _broadcast_working(self, uid: str) -> bool: except Exception: logging.exception(f"{prefix} failed to send message '{message}'") if not success: - print_error(f"Failed to send following message: {message}") + console.error(f"Failed to send following message: {message}") return success async def _broadcast_exception(self, uid: str, message: str) -> bool: @@ -184,7 +267,7 @@ async def _broadcast_exception(self, uid: str, message: str) -> bool: except Exception: logging.exception(f"{prefix} failed to send message '{message}'") if not success: - print_error(f"Failed to send following message: {message}") + console.error(f"Failed to send following message: {message}") return success diff --git a/cfdraw/app/endpoints/upload.py b/cfdraw/app/endpoints/upload.py index e2e8dd66..400d8ff7 100644 --- a/cfdraw/app/endpoints/upload.py +++ b/cfdraw/app/endpoints/upload.py @@ -12,9 +12,6 @@ from fastapi import UploadFile from pydantic import BaseModel from PIL.PngImagePlugin import PngInfo -from cftool.web import get_responses -from cftool.web import get_image_response_kwargs -from cftool.misc import get_err_msg from cfdraw import constants from cfdraw.app.schema import IApp @@ -23,6 +20,9 @@ from cfdraw.utils.server import get_svg_response from cfdraw.utils.server import get_image_response from cfdraw.app.endpoints.base import IEndpoint +from cfdraw.core.toolkit.web import get_responses +from cfdraw.core.toolkit.web import get_image_response_kwargs +from cfdraw.core.toolkit.misc import get_err_msg class ImageDataModel(BaseModel): diff --git a/cfdraw/app/endpoints/websocket.py b/cfdraw/app/endpoints/websocket.py index 96470e59..77fe0863 100644 --- a/cfdraw/app/endpoints/websocket.py +++ b/cfdraw/app/endpoints/websocket.py @@ -4,18 +4,18 @@ from fastapi import WebSocket from fastapi import WebSocketDisconnect -from cftool.misc import get_err_msg -from cftool.misc import print_error from starlette.websockets import WebSocketState from cfdraw import constants from cfdraw.app.schema import IApp from cfdraw.app.schema import IRequestQueueData +from cfdraw.core.toolkit import console from cfdraw.utils.misc import offload from cfdraw.schema.plugins import ElapsedTimes from cfdraw.schema.plugins import ISocketRequest from cfdraw.schema.plugins import ISocketMessage from cfdraw.app.endpoints.base import IEndpoint +from cfdraw.core.toolkit.misc import get_err_msg def add_websocket(app: IApp) -> None: @@ -72,7 +72,7 @@ async def send_message(data: ISocketMessage) -> bool: ) exception = ISocketMessage.make_exception(data.hash, message) if not await send_message(exception): - print_error(f"[websocket.loop] {message}") + console.error(f"\[websocket.loop] {message}") 
except WebSocketDisconnect: break except Exception as e: diff --git a/cfdraw/cli.py b/cfdraw/cli.py index f0c1c9cf..1a3983bc 100644 --- a/cfdraw/cli.py +++ b/cfdraw/cli.py @@ -5,14 +5,13 @@ import pkg_resources from pathlib import Path -from cftool.misc import print_info from cfdraw import constants from cfdraw.utils import exec -from cfdraw.utils import console from cfdraw.utils import processes from cfdraw.utils import prerequisites from cfdraw.config import get_config +from cfdraw.core.toolkit import console from cfdraw.utils.template import set_init_codes from cfdraw.utils.template import TemplateType @@ -52,7 +51,7 @@ def run( pkg_resources.require(requirements) except Exception as err: console.rule("๐Ÿ“ฆ Installing Requirements") - print_info(f"Reason : {err}") + console.log(f"Reason : {err}") enclosed = lambda s: f'"{s}"' requirements_string = " ".join(map(enclosed, requirements)) cmd = f"{sys.executable} -m pip install {requirements_string}" @@ -90,11 +89,11 @@ def run( backend_fn(module, log_level=log_level) finally: console.rule("[bold]Shutting down") - print_info("Killing frontend") + console.log("Killing frontend") processes.kill_process_on_port(frontend_port) - print_info("Killing backend") + console.log("Killing backend") processes.kill_process_on_port(backend_port) - print_info("Done") + console.log("Done") @cli.command() diff --git a/cfdraw/core/__init__.py b/cfdraw/core/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/cfdraw/core/flow/__init__.py b/cfdraw/core/flow/__init__.py new file mode 100644 index 00000000..d33dfbd9 --- /dev/null +++ b/cfdraw/core/flow/__init__.py @@ -0,0 +1,5 @@ +from .core import * +from .nodes import * +from .server import * +from .utils import * +from .docs import * diff --git a/cfdraw/core/flow/core.py b/cfdraw/core/flow/core.py new file mode 100644 index 00000000..c7b35939 --- /dev/null +++ b/cfdraw/core/flow/core.py @@ -0,0 +1,952 @@ +import json +import time +import asyncio + +from abc import abstractmethod +from abc import ABCMeta +from typing import Any +from typing import Set +from typing import Dict +from typing import List +from typing import Type +from typing import Union +from typing import TypeVar +from typing import Callable +from typing import Optional +from pydantic import Field +from pydantic import BaseModel +from dataclasses import field +from dataclasses import asdict +from dataclasses import dataclass + +from ..toolkit import console +from ..toolkit.web import get_err_msg +from ..toolkit.misc import offload +from ..toolkit.misc import random_hash +from ..toolkit.misc import register_core +from ..toolkit.misc import shallow_copy_dict +from ..toolkit.misc import JsonPack +from ..toolkit.misc import ISerializableDataClass +from ..toolkit.types import TPath +from ..toolkit.data_structures import Item +from ..toolkit.data_structures import Bundle + + +TNode = TypeVar("TNode", bound="Node") +TTNode = TypeVar("TTNode", bound=Type["Node"]) +nodes: Dict[str, Type["Node"]] = {} +_shared_pool: Dict[str, Any] = {} +warmed_up_records: Dict[str, bool] = {} + +UNDEFINED_PLACEHOLDER = "$undefined$" +EXCEPTION_MESSAGE_KEY = "$exception$" +ALL_LATENCIES_KEY = "$all_latencies$" + +LOOP_NODE = "common.loop" +GATHER_NODE = "common.gather" +WORKFLOW_NODE = "common.workflow" +WORKFLOW_ENDPOINT_NAME = "workflow" + + +def to_hierarchies(hierarchy: Union[str, List[str]]) -> List[str]: + if isinstance(hierarchy, list): + return hierarchy + return hierarchy.split(".") + + +def extract_from(data: Any, hierarchy: Union[str, 
List[str]]) -> Any: + hierarchies = to_hierarchies(hierarchy) + for h in hierarchies: + if isinstance(data, list): + try: + ih = int(h) + except: + msg = f"current value is list, but '{h}' is not int" + raise ValueError(msg) + data = data[ih] + elif isinstance(data, dict): + data = data[h] + else: + raise ValueError( + f"hierarchy '{h}' is required but current value type " + f"is '{type(data)}' ({data})" + ) + return data + + +def inject_leaf_data(d: Any, hierarchies: List[str], v: Any, *, verbose: bool) -> None: + h = hierarchies.pop(0) + is_leaf = len(hierarchies) == 0 + if isinstance(d, list): + try: + ih = int(h) + except: + raise ValueError(f"current value is list, but '{h}' is not int") + if len(d) <= ih: + if verbose: + replace_msg = "target value" if is_leaf else "an empty `dict`" + console.warn( + "current data is a list but its length is not enough, " + f"corresponding index ({h}) will be set to {replace_msg}, " + "and other elements will be set to `undefined`" + ) + d.extend([UNDEFINED_PLACEHOLDER] * (ih - len(d) + 1)) + if is_leaf: + d[ih] = v + else: + if d[ih] == UNDEFINED_PLACEHOLDER: + console.warn("filling `undefined` value with an empty `dict`") + d[ih] = {} + inject_leaf_data(d[ih], hierarchies, v, verbose=verbose) + elif isinstance(d, dict): + if is_leaf: + d[h] = v + else: + if h not in d: + if verbose: + console.warn( + "current data is a dict but it does not have the " + f" corresponding key ('{h}'), it will be set to " + "an empty `dict`" + ) + d[h] = {} + inject_leaf_data(d[h], hierarchies, v, verbose=verbose) + else: + raise ValueError( + f"hierarchy '{h}' is required but current value type " + f"is '{type(d)}' ({d})" + ) + + +async def warmup(t_node: Type["Node"], verbose: bool) -> None: + warmed_up_key = t_node.__identifier__ + if not warmed_up_records.get(warmed_up_key, False): + if verbose: + console.debug(f"warming up node '{warmed_up_key}'") + await t_node.warmup() + warmed_up_records[warmed_up_key] = True + + +@dataclass +class Injection: + """ + A dataclass that represents an injection to the current node. + + Attributes + ---------- + src_key : str + The key of the dependent node. + src_hierarchy : str | list[str] | None + The 'src_hierarchy' of the dependent node's results that the current node depends on. + - `src_hierarchy` can be very complex: + - use `int` as `list` index, and `str` as `dict` key. + - use list / `.` to represent nested structure. + - for example, you can use `["a", "0", "b"]` or `a.0.b` to indicate `results["a"][0]["b"]`. + - If `None`, all results of the dependent node will be used. + dst_hierarchy : str | list[str] + The 'dst_hierarchy' of the current node's `data`. + - `dst_hierarchy` can be very complex: + - use `int` as `list` index, and `str` as `dict` key. + - use list / `.` to represent nested structure. + - for example, if you want to inject to `data["a"][0]["b"]`, you can use either + `["a", "0", "b"]` or `a.0.b` as the `dst_hierarchy`. + + """ + + src_key: str + src_hierarchy: Optional[Union[str, List[str]]] + dst_hierarchy: Union[str, List[str]] + + def to_model(self) -> "InjectionModel": + return InjectionModel( + src_key=self.src_key, + src_hierarchy=self.src_hierarchy, + dst_hierarchy=self.dst_hierarchy, + ) + + +@dataclass +class LoopBackInjection: + """ + A dataclass that represents a loop back injection to the current node. + + > This is the same as `Injection`, except the `src_key` will always be the + key of the previous node in the loop. 
+ """ + + src_hierarchy: Optional[Union[str, List[str]]] + dst_hierarchy: Union[str, List[str]] + + def to_model(self) -> "LoopBackInjectionModel": + return LoopBackInjectionModel( + src_hierarchy=self.src_hierarchy, + dst_hierarchy=self.dst_hierarchy, + ) + + +@dataclass +class Schema: + """ + A class that represents a Schema of a node. + + Implement `get_schema` method and return a `Schema` instance for your nodes + can help us auto-generate UIs, APIs and documents. + + Attributes + ---------- + input_model : Optional[Type[BaseModel]] + The input data model of the node. + > If your inputs are not JSON serializable, you can use `input_names` instead. + output_model : Optional[Type[BaseModel]] + The output data model of the node. + > If your outputs are not JSON serializable, you can use either `api_output_model` + or `output_names` instead. + api_output_model : Optional[Type[BaseModel]] + The API response data model of the node. + > This is helpful when your outputs are not JSON serializable, and you implement + the `get_api_response` method to convert the outputs to API responses. + > In this case, `api_output_model` should be the data model of the results returned + by `get_api_response`. + input_names : Optional[List[str]] + The names of the inputs of the node. + > This is helpful if you want to make things simple. + > Please make sure that the input `data` of the node has exactly the same keys as `input_names`. + output_names : Optional[List[str]] + The names of the outputs of the node. + > This is helpful if you want to make things simple. + > Please make sure that the output `results` of the node has exactly the same keys as `output_names`. + description : Optional[str] + A description of the node. + > This will be displayed in the auto-generated UIs / documents. + + """ + + input_model: Optional[Type[BaseModel]] = None + output_model: Optional[Type[BaseModel]] = None + api_output_model: Optional[Type[BaseModel]] = None + input_names: Optional[List[str]] = None + output_names: Optional[List[str]] = None + description: Optional[str] = None + + +class Hook: + @classmethod + async def initialize(cls, shared_pool: Dict[str, Any]) -> None: + pass + + @classmethod + async def cleanup(cls, shared_pool: Dict[str, Any]) -> None: + pass + + +@dataclass +class Node(ISerializableDataClass["Node"], metaclass=ABCMeta): + """ + A Node class that represents a node in a workflow. + + This class is abstract and should be subclassed. + + Attributes + ---------- + key : str, optional + The key of the node, should be unique with respect to the workflow. + data : Any, optional + The data associated with the node. + injections : List[Injection], optional + A list of injections of the node. + offload : bool, optional + A flag indicating whether the node should be offloaded. + lock_key : str, optional + The lock key of the node. + executing : bool, optional + A runtime attribute indicating whether the node is currently executing. + + Methods + ------- + async execute() -> Any + Abstract method that should return the results. + + @classmethod + get_schema() -> Optional[Schema] + Optional method that returns the schema of the node. + Implement this method can help us auto-generate UIs, APIs and documents. + @classmethod + async warmup() -> None + Optional method that will be called: + - only once. + - before the server starts, if under API mode. + Implement this method to do heavy initializations (e.g. loading AI models). 
+ async initialize(flow: Flow) -> None + Optional method that will be called everytime before the execution. + async get_api_response(results: Dict[str, Any]) -> Any + Optional method that returns the API response of the node from its 'raw' results. + Implement this method to handle complex API responses (e.g. `PIL.Image`). + async cleanup() -> None + Optional method that will be called everytime after the execution. + + """ + + key: Optional[str] = None + data: Dict[str, Any] = field(default_factory=dict) + injections: List[Injection] = field(default_factory=list) + offload: bool = False + lock_key: Optional[str] = None + # runtime attribute, should not be touched and will not be serialized + executing: bool = False + + # optional + + @classmethod + def get_schema(cls) -> Optional[Schema]: + return None + + @classmethod + def get_hooks(cls) -> List[Type[Hook]]: + return [] + + @classmethod + async def warmup(cls) -> None: + """ + This is used to warmup the node, and will be called: + - only once. + - before the server starts, if under API mode. + + > The main difference between `warmup` and `initialize` is that `warmup` will be + called only once, while `initialize` will be called everytime the node is executed. + > So you can do some heavy initializations here (e.g. loading AI models). + """ + + async def initialize(self, flow: "Flow") -> None: + """Will be called everytime before the execution.""" + + for hook in self.get_hooks(): + await hook.initialize(_shared_pool) + + async def get_api_response(self, results: Dict[str, Any]) -> Any: + return results + + async def cleanup(self) -> None: + """Will be called everytime after the execution.""" + + for hook in self.get_hooks(): + await hook.cleanup(_shared_pool) + + # api + + def depend_on(self, src_key: str) -> None: + """ + This can be used if this Node does not directly depend on `src_key` Node, + but you want this Node to wait for `src_key` Node to finish before starting. 
+ """ + + tag = f"$depend_{random_hash()[:4]}" + self.injections.append(Injection(src_key, None, tag)) + + def to_model(self) -> "NodeModel": + if self.key is None: + raise ValueError("node key cannot be None") + return NodeModel( + key=self.key, + type=self.__identifier__, + data=shallow_copy_dict(self.data), + injections=[injection.to_model() for injection in self.injections], + offload=self.offload, + lock_key=self.lock_key, + ) + + # abstract + + @abstractmethod + async def execute(self) -> Any: + pass + + # internal + + @classmethod + def register(cls, name: str, **kwargs: Any) -> Callable[[TTNode], TTNode]: # type: ignore + def before(cls_: TTNode) -> None: + if name == WORKFLOW_ENDPOINT_NAME: + raise RuntimeError( + "`workflow` is a reserved name, please use another name " + f"when registering node '{cls_.__name__}'" + ) + cls_.__identifier__ = name + if custom_before is not None: + custom_before(cls_) + + custom_before = kwargs.pop("before_register", None) + kwargs.setdefault("allow_duplicate", False) + kwargs["before_register"] = before + return register_core(name, cls.d, **kwargs) # type: ignore + + @property + def shared_pool(self) -> Dict[str, Any]: + return _shared_pool + + def asdict(self) -> Dict[str, Any]: + return dict( + key=self.key, + data=shallow_copy_dict(self.data), + injections=[asdict(injection) for injection in self.injections], + offload=self.offload, + lock_key=self.lock_key, + ) + + def to_item(self: TNode) -> Item[TNode]: + if self.key is None: + raise ValueError("node key cannot be None") + return Item(self.key, self) + + def to_pack(self) -> JsonPack: + return JsonPack(type=self.__identifier__, info=self.to_info()) + + def from_info(self, info: Dict[str, Any]) -> "Node": + super().from_info(info) + if self.key is None: + raise ValueError("node key cannot be None") + if "." 
in self.key: + raise ValueError("node key cannot contain '.'") + return self + + def check_inputs(self) -> None: + if not isinstance(self.data, dict): + raise ValueError( + f"input `data` ({self.data}) of node " + f"'{self.key}' ({self.__class__.__name__}) should be a `dict`" + ) + schema = self.get_schema() + if schema is None: + return + if schema.input_model is not None: + try: + narrowed = schema.input_model(**self.data) + self.data = narrowed.model_dump() + except Exception as err: + msg = f"input data ({self.data}) does not match the schema model ({schema.input_model})" + raise ValueError(msg) from err + elif schema.input_names is not None: + data_inputs = set(self.data.keys()) + schema_inputs = set(schema.input_names) + if data_inputs != schema_inputs: + msg = f"input data ({self.data}) does not match the schema names ({schema.input_names})" + raise ValueError(msg) + + def check_injections(self) -> None: + history: Dict[str, Injection] = {} + for injection in self.injections: + dst_hierarchy_key = str(injection.dst_hierarchy) + existing = history.get(dst_hierarchy_key) + if existing is not None: + raise ValueError( + f"`dst_hierarchy` of current injection ({injection}) is duplicated " + f"with previous injection ({existing})" + ) + history[dst_hierarchy_key] = injection + + def fetch_injections(self, results: Dict[str, Any], verbose: bool = True) -> None: + for injection in self.injections: + src_key = injection.src_key + src_out = results.get(src_key) + if src_out is None: + raise ValueError(f"cannot find cache for '{src_key}'") + if injection.src_hierarchy is not None: + src_out = extract_from(src_out, injection.src_hierarchy) + dst_hierarchies = to_hierarchies(injection.dst_hierarchy) + inject_leaf_data(self.data, dst_hierarchies, src_out, verbose=verbose) + + def check_undefined(self) -> None: + def check(data: Any) -> None: + if isinstance(data, list): + for item in data: + check(item) + elif isinstance(data, dict): + for v in data.values(): + check(v) + elif data == UNDEFINED_PLACEHOLDER: + raise ValueError(f"undefined value found in '{self.data}'") + + check(self.data) + + def check_results(self, results: Dict[str, Any]) -> Dict[str, Any]: + if not isinstance(results, dict): + raise ValueError( + f"output results ({results}) of " + f"node '{self.key}' ({self.__class__.__name__}) should be a `dict`" + ) + schema = self.get_schema() + if schema is None: + return results + if schema.output_model is not None: + try: + narrowed = schema.output_model(**results) + return narrowed.model_dump() + except Exception as err: + msg = f"output data ({results}) does not match the schema model ({schema.output_model})" + raise ValueError(msg) from err + if schema.output_names is not None: + node_outputs = set(results.keys()) + schema_outputs = set(schema.output_names) + if node_outputs != schema_outputs: + msg = f"output data ({results}) does not match the schema names ({schema.output_names})" + raise ValueError(msg) + return results + + def check_api_results(self, results: Dict[str, Any]) -> Dict[str, Any]: + schema = self.get_schema() + if schema is None: + return results + if schema.api_output_model is not None: + try: + narrowed = schema.api_output_model(**results) + return narrowed.model_dump() + except Exception as err: + msg = f"API response ({results}) does not match the schema model ({schema.api_output_model})" + raise ValueError(msg) from err + return results + + +class Flow(Bundle[Node]): + """ + A Flow class that represents a workflow. 
+ + Attributes + ---------- + edges : Dict[str, List[Edge]] + The dependencies of the workflow. + - The key is the destination node key. + - The value is a list of edges that indicates the dependencies + of the destination node. + latest_latencies : Dict[str, Dict[str, float]] + The latest latencies of the workflow. + + Methods + ------- + push(node: Node) -> Flow: + Pushes a node into the workflow. + loop(n: int, node: Node, loop_back_injections: List[LoopBackInjection], ...) -> str: + Loops the given `node` for `n` times, this is useful when you want to perform: + > iterative tasks on the same node, in which case `loop_back_injections` should be set. + > the same task for `n` times, in which case `loop_back_injections` should be `None`. + In this case, the `node` should have some randomness inside, and what you are doing + is kind of like 'ensemble' or 'mixture of experts'. + gather(*targets: str) -> str: + Gathers targets into a single node, and returns the key of the node. + to_json() -> Dict[str, Any]: + Converts the workflow to a JSON object. + from_json(cls, data: Dict[str, Any]) -> Flow: + Creates a workflow from a JSON object. + dump(path: TPath) -> None: + Dumps the workflow to a (JSON) file. + load(cls, path: TPath) -> Flow: + Loads a workflow from a (JSON) file. + get_reachable(target: str) -> Set[str]: + Gets the reachable nodes from a target. + run(...) -> None: + Runs a single node in the workflow. + execute(...) -> Dict[str, Any]: + Executes the entire workflow. + + """ + + def __init__(self, *, no_mapping: bool = False) -> None: + super().__init__(no_mapping=no_mapping) + self.latest_latencies: Dict[str, Dict[str, float]] = {} + + def __str__(self) -> str: + body = ",\n ".join(str(item.data) for item in self) + return f"""Flow([ + {body} +])""" + + __repr__ = __str__ + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Flow): + return False + return self.to_json() == other.to_json() + + @property + def shared_pool(self) -> Dict[str, Any]: + return _shared_pool + + def push(self, node: Node) -> "Flow": # type: ignore + if node.key is None: + raise ValueError("node key cannot be None") + return super().push(node.to_item()) + + def loop( + self, + node: Node, + loop_values: Optional[Dict[str, List[Any]]] = None, + loop_back_injections: Optional[List[LoopBackInjection]] = None, + *, + loop_injections: Optional[List[Injection]] = None, + extract_hierarchy: Optional[str] = None, + verbose: bool = False, + ) -> str: + loop_key = f"$loop_{node.key}_{random_hash()[:4]}" + modified_injections: List[Injection] = [] + for injection in node.injections: + modified_dst_hierarchy: Union[str, List[str]] + if isinstance(injection.dst_hierarchy, str): + modified_dst_hierarchy = f"base_data.{injection.dst_hierarchy}" + else: + modified_dst_hierarchy = ["base_data"] + injection.dst_hierarchy + modified_injections.append( + Injection( + injection.src_key, + injection.src_hierarchy, + modified_dst_hierarchy, + ) + ) + if loop_injections is not None: + modified_injections.extend(loop_injections) + self.push( + Node.make( + LOOP_NODE, + dict( + key=loop_key, + data=dict( + base_node=node.__identifier__, + base_data=shallow_copy_dict(node.data), + loop_values=shallow_copy_dict(loop_values or {}), + loop_back_injections=( + None + if loop_back_injections is None + else list(map(asdict, loop_back_injections)) + ), + extract_hierarchy=extract_hierarchy, + verbose=verbose, + ), + injections=modified_injections, + offload=node.offload, + lock_key=node.lock_key, + ), + ) + ) + return 
loop_key + + def gather(self, *targets: str) -> str: + gather_key = f"$gather_{random_hash()[:4]}" + injections = [Injection(k, None, k) for k in targets] + self.push( + Node.make( + GATHER_NODE, + dict(key=gather_key, injections=injections), + ) + ) + return gather_key + + def to_json(self) -> List[Dict[str, Any]]: + return [item.data.to_pack().asdict() for item in self] + + @classmethod + def from_json(cls, data: List[Dict[str, Any]]) -> "Flow": + workflow = cls() + for pack in data: + workflow.push(Node.from_pack(pack)) + return workflow + + def to_model( + self, + *, + target: str, + intermediate: Optional[List[str]] = None, + return_if_exception: bool = False, + verbose: bool = False, + ) -> "WorkflowModel": + return WorkflowModel( + target=target, + intermediate=intermediate, + nodes=[item.data.to_model() for item in self], + return_if_exception=return_if_exception, + verbose=verbose, + ) + + def copy(self) -> "Flow": + return Flow.from_json(self.to_json()) + + def dump(self, path: TPath) -> None: + with open(path, "w") as f: + json.dump(self.to_json(), f, indent=2) + + @classmethod + def load(cls, path: TPath) -> "Flow": + with open(path, "r") as f: + return cls.from_json(json.load(f)) + + def get_reachable(self, target: str) -> Set[str]: + def dfs(key: str, is_target: bool) -> None: + if not is_target and key == target: + raise ValueError(f"cyclic dependency detected when dfs from '{target}'") + if key in reachable: + return + reachable.add(key) + item = self.get(key) + if item is None: + raise ValueError( + f"cannot find node '{key}', which is declared as a dependency, " + f"in the workflow ({self})" + ) + node = item.data + for injection in node.injections: + dfs(injection.src_key, False) + + reachable: Set[str] = set() + dfs(target, True) + return reachable + + async def run( + self, + item: Item[Node], + api_results: Dict[str, Any], + all_results: Dict[str, Any], + return_api_response: bool, + verbose: bool, + all_latencies: Dict[str, Dict[str, float]], + ) -> None: + if item.key in all_results: + return + start_t = time.time() + while not all(i.src_key in all_results for i in item.data.injections): + await asyncio.sleep(0) + if item.data.lock_key is not None: + while not all( + not other.data.executing or other.data.lock_key != item.data.lock_key + for other in self + ): + await asyncio.sleep(0) + item.data.executing = True + t0 = time.time() + node: Node = item.data.copy() + node.fetch_injections(all_results, verbose) + node.check_undefined() + node.check_inputs() + t1 = time.time() + if verbose: + console.debug(f"executing node '{item.key}'") + if not node.offload: + results = await node.execute() + else: + results = await offload(node.execute()) + results = node.check_results(results) + all_results[item.key] = results + if return_api_response: + results = await node.get_api_response(results) + results = node.check_api_results(results) + api_results[item.key] = results + t2 = time.time() + item.data.executing = False + all_latencies[item.key] = dict( + pending=t0 - start_t, + inject=t1 - t0, + execute=t2 - t1, + latency=t2 - t0, + ) + if verbose: + console.debug(f"finished executing node '{item.key}'") + + async def execute( + self, + target: str, + intermediate: Optional[List[str]] = None, + *, + return_api_response: bool = False, + return_if_exception: bool = False, + verbose: bool = False, + ) -> Dict[str, Any]: + """ + Executes the workflow ending at the `target` node. 
+ + Parameters + ---------- + target : str + The key of the target node which the execution will end at. + intermediate : List[str], optional + A list of intermediate nodes that will be returned. Default is `None`. + - Only useful when `return_api_response` is `True`. + - If `None`, no intermediate nodes will be returned. + return_if_exception : bool, optional + If `True`, the function will return even if an exception occurs. Default is `False`. + return_api_response : bool, optional + If `True`, the function will: + - Only return the results of the `target` node & the `intermediate` nodes. + - Call `get_api_response` on the results to get the final API response. + verbose : bool, optional + If `True`, the function will print detailed logs. Default is `False`. + + Returns + ------- + dict + A dictionary containing the results of the execution. + - If `return_api_response` is `True`, only outputs of the `target` node can be accessed + (via `results[target]`). + - Otherwise, outputs of all nodes can be accessed (via `results[key]`, where `key` is + the key of the node). + - If an exception occurs during the execution, the dictionary will contain + a key 'EXCEPTION_MESSAGE_KEY' with the error message as the value. + + """ + + api_results: Dict[str, Any] = {} + all_results: Dict[str, Any] = {} + extra_results: Dict[str, Any] = {} + all_latencies: Dict[str, Dict[str, float]] = {} + if intermediate is None: + intermediate = [] + reachable_nodes: List[Node] = [] + try: + workflow = self.copy() + if target not in workflow: + raise ValueError(f"cannot find target '{target}' in the workflow") + reachable = workflow.get_reachable(target) + reachable_nodes = [item.data for item in workflow if item.key in reachable] + for node in reachable_nodes: + node.check_injections() + await warmup(node.__class__, verbose) + for node in reachable_nodes: + if verbose: + console.debug(f"initializing node '{node.key}'") + await node.initialize(workflow) + await asyncio.gather( + *( + workflow.run( + item, + api_results, + all_results, + return_api_response + and (item.key == target or item.key in intermediate), + verbose, + all_latencies, + ) + for item in workflow + if item.key in reachable + ) + ) + extra_results[EXCEPTION_MESSAGE_KEY] = None + except Exception as err: + if not return_if_exception: + raise + err_msg = get_err_msg(err) + extra_results[EXCEPTION_MESSAGE_KEY] = err_msg + if verbose: + console.error(err_msg) + finally: + for node in reachable_nodes: + if verbose: + console.debug(f"cleaning up node '{node.key}'") + try: + await node.cleanup() + except Exception as err: + msg = f"error occurred when cleaning up node '{node.key}': {get_err_msg(err)}" + console.error(msg) + self.latest_latencies = all_latencies + extra_results[ALL_LATENCIES_KEY] = all_latencies + final_results = api_results if return_api_response else all_results + final_results.update(extra_results) + return final_results + + +Node.d = nodes # type: ignore + + +class SrcKey(BaseModel): + src_key: str = Field(..., description="The key of the dependent node.") + + +class LoopBackInjectionModel(BaseModel): + """Data model of `LoopBackInjection`""" + + src_hierarchy: Optional[Union[str, List[str]]] = Field( + ..., + description="""The 'src_hierarchy' of the dependent node's results that the current node depends on. +- `src_hierarchy` can be very complex: + - use `int` as `list` index, and `str` as `dict` key. + - use list / `.` to represent nested structure. 
+ - for example, you can use `["a", "0", "b"]` or `a.0.b` to indicate `results["a"][0]["b"]`. +- If `None`, all results of the dependent node will be used.""", + ) + dst_hierarchy: Union[str, List[str]] = Field( + ..., + description="""The 'dst_hierarchy' of the current node's `data`. +- `dst_hierarchy` can be very complex: + - use `int` as `list` index, and `str` as `dict` key. + - use list / `.` to represent nested structure. + - for example, if you want to inject to `data["a"][0]["b"]`, you can use either `["a", "0", "b"]` or `a.0.b` as the `dst_hierarchy`.""", + ) + + +class InjectionModel(LoopBackInjectionModel, SrcKey): + pass + + +class NodeModel(BaseModel): + key: str = Field( + ..., + description="The key of the node, should be unique with respect to the workflow.", + ) + type: str = Field( + ..., + description="The type of the node, should be the one when registered.", + ) + data: Dict[str, Any] = Field( + default_factory=dict, + description="The data associated with the node.", + ) + injections: List[InjectionModel] = Field( + default_factory=list, + description="A list of injections of the node.", + ) + offload: bool = Field( + False, + description="A flag indicating whether the node should be offloaded.", + ) + lock_key: Optional[str] = Field(None, description="The lock key of the node.") + + +class WorkflowModel(BaseModel): + target: str = Field(..., description="The target output node of the workflow.") + intermediate: Optional[List[str]] = Field( + None, + description="The intermediate nodes that you want to get outputs from.", + ) + nodes: List[NodeModel] = Field(..., description="A list of nodes in the workflow.") + return_if_exception: bool = Field( + False, + description="Whether to return partial results if exception occurs.", + ) + verbose: bool = Field(False, description="Whether to print debug logs.") + + def get_workflow(self) -> Flow: + workflow_json = [] + for node in self.model_dump()["nodes"]: + node_json = dict(type=node.pop("type"), info=node) + workflow_json.append(node_json) + return Flow.from_json(workflow_json) + + async def run(self, *, return_api_response: bool = False) -> Dict[str, Any]: + return await self.get_workflow().execute( + self.target, + self.intermediate, + return_api_response=return_api_response, + return_if_exception=self.return_if_exception, + verbose=self.verbose, + ) + + +__all__ = [ + "UNDEFINED_PLACEHOLDER", + "EXCEPTION_MESSAGE_KEY", + "ALL_LATENCIES_KEY", + "Injection", + "LoopBackInjection", + "Schema", + "Node", + "Flow", + "BaseModel", + "SrcKey", + "LoopBackInjectionModel", + "InjectionModel", + "NodeModel", + "WorkflowModel", +] diff --git a/cfdraw/core/flow/docs.py b/cfdraw/core/flow/docs.py new file mode 100644 index 00000000..5c694fec --- /dev/null +++ b/cfdraw/core/flow/docs.py @@ -0,0 +1,385 @@ +import json +import inspect + +from typing import List +from typing import Type +from typing import Tuple +from typing import Optional +from pathlib import Path +from dataclasses import dataclass + +from .. 
import flow as cflow +from ..toolkit import console + + +RAG_SEPARATOR = "__RAG__" +UNDEFINED_PLACEHOLDER = "*Undefined*" + + +@dataclass +class Document: + name: str + source_codes: str + description: str = UNDEFINED_PLACEHOLDER + input_docs: str = UNDEFINED_PLACEHOLDER + output_docs: str = UNDEFINED_PLACEHOLDER + api_ouput_docs: Optional[str] = None + rag: bool = False + + @property + def markdown(self) -> str: + if not self.rag: + eof = "" + title = f"## {self.name}" + else: + eof = f"\n{RAG_SEPARATOR}\n" + title = f"## Supported Node - {self.name}" + return f"""{title} + +### Description + +{self.description} + +### Inputs + +{self.input_docs} + +### Functional Outputs + +{self.output_docs} + +### API Outputs + +{self.api_ouput_docs or "*Same as the functional outputs.*"} + +### Source Codes + +```python +{self.source_codes}``` +{eof} +""" + + +def strip_source(source: str, identifier: str) -> str: + source = source.strip() + id_index = source.index(identifier) + source = source[:id_index] + return source.strip() + + +def fetch_doc_sources(t_base: type) -> List[str]: + sources = [] + for sub in cflow.__dict__.values(): + if not inspect.isclass(sub): + continue + if issubclass(sub, t_base) and sub is not t_base: + sources.append( + f"""### `{sub.__name__}` + +```python +{inspect.getsource(sub).replace("`", "'")}``` +""" + ) + return sources + + +def genearte_document(t_node: Type[cflow.Node], rag: bool) -> Optional[Document]: + schema = t_node.get_schema() + source = inspect.getsource(t_node) + document = Document(name=t_node.__name__, source_codes=source, rag=rag) + if schema is None: + return document + if schema.input_model is not None: + document.input_docs = f"""```python +{inspect.getsource(schema.input_model).replace("`", "'")}```""" + elif schema.input_names is not None: + input_strings = [f"- {name}\n" for name in schema.input_names] + document.input_docs = f"""'{document.name}' has following inputs: +{''.join(input_strings)[:-1]}""" + if schema.output_model is not None: + document.output_docs = f"""```python +{inspect.getsource(schema.output_model).replace("`", "'")}```""" + elif schema.output_names is not None: + output_strings = [f"- {name}\n" for name in schema.output_names] + document.output_docs = f"""'{document.name}' has following outputs: +{''.join(output_strings)[:-1]}""" + if schema.api_output_model is not None: + document.api_ouput_docs = f"""```python +{inspect.getsource(schema.api_output_model).replace("`", "'")}```""" + if schema.description is not None: + document.description = schema.description + return document + + +def generate_documents(output: str, *, rag: bool = False) -> None: + def get_example_title_and_eof(path: Path, title_prefix: str) -> Tuple[str, str]: + relative = str(path.relative_to(root)) + if not rag: + eof = "\n" + title = f"### `{relative}`" + else: + eof = f"\n{RAG_SEPARATOR}\n" + title = f"### {title_prefix}`{relative}`" + return title, eof + + def get_code_example(code: Path) -> str: + title, eof = get_example_title_and_eof(code, "Coding Example - ") + return f"{title}\n\n```python\n{code.read_text()}```{eof}" + + def get_json_example(workflow: Path) -> str: + title, eof = get_example_title_and_eof(workflow, "Workflow JSON Example - ") + with open(workflow, "r") as f: + w_json = json.load(f) + w_description = w_json.pop("$description", "*Description is not provided.*") + return f"""{title} + +{w_description} + +```json +{json.dumps(w_json, indent=2, ensure_ascii=False)} +```{eof}""" + + if not output.endswith(".md"): + raise 
ValueError(f"`dst` should be a markdown file, '{output}' found") + console.rule("Generating Documents") + t_nodes = cflow.use_all_t_nodes() + gen_doc = lambda t_node: genearte_document(t_node, rag) + documents: List[Document] = list(filter(bool, map(gen_doc, t_nodes))) # type: ignore + root = Path(__file__).parent.parent + examples_dir = root / "examples" + workflows_dir = examples_dir / "workflows" + code_snippets = examples_dir.glob("*.py") + workflow_jsons = workflows_dir.rglob("*.json") + code_example_docs = "\n".join(map(get_code_example, code_snippets))[:-1] + workflow_example_docs = "\n".join(map(get_json_example, workflow_jsons))[:-1] + workflow_model_source = inspect.getsource(cflow.WorkflowModel).replace("`", "'") + workflow_model_source = strip_source(workflow_model_source, "def") + workflow_execute_source = inspect.getsource(cflow.Flow.execute).replace("`", "'") + workflow_execute_source = strip_source(workflow_execute_source, "api_results: ") + workflow_execute_split = workflow_execute_source.split("\n") + workflow_execute_split[1:] = [line[4:] for line in workflow_execute_split[1:]] + workflow_execute_source = "\n".join(workflow_execute_split) + enum_docs = "\n".join(fetch_doc_sources(cflow.DocEnum))[:-1] + data_model_docs = "\n".join(fetch_doc_sources(cflow.DocModel))[:-1] + sep = f"\n{RAG_SEPARATOR}\n" if rag else "" + generated = f"""# `carefree-workflow` Documentation + +Here are some design principles of `carefree-workflow`: +- `workflow` is a `DAG` (directed acyclic graph) of `nodes`. +- `workflow` is actually constructed by a set of `nodes` with `injections` defined. + - `injections` indicate the dependencies between `nodes`. +- Every `node`, as well as the `workflow` itself, can be used in both `functional` and `API` ways. +- Every `node` should take `dict` as inputs and return `dict` as outputs. + +And below will be the detailed documents of: +- Installation. +- The general introduction of `node` and `workflow`. +- All the nodes supported in `carefree-workflow`. + +We'll also include some examples at the end to help you understand how to use `carefree-workflow` in your projects. + +# Installation + +`carefree-workflow` requires Python 3.8 or higher. + +```bash +pip install carefree-workflow +``` + +or + +```bash +git clone https://github.com/carefree0910/carefree-workflow.git +cd carefree-workflow +pip install -e . +``` +{sep} +# Node + +Every `node` in `carefree-workflow` should inherit from `cflow.Node`: + +```python +{strip_source(inspect.getsource(cflow.Node).replace("`", "'"), "# optional")} +``` + +It looks complicated, but `node` can actually be simply understood as a `function`, except: +- It can be used in an `API` way **automatically**, as long as it implements the `get_schema` method. +- It can be used in a `workflow`, which means: + - Its input(s) can be the output(s) from other `node`. + - Its output(s) can be the input(s) of other `node`. + +The second feature is achieved by `injections`, which is represented by: +- `Injection`, if used in a `functional` way. +- `InjectionModel`, if used in an `API` way. + +`InjectionModel` will be introduced later ([API usage](#api-usage)), and here is the definition of `Injection`: + +```python +{inspect.getsource(cflow.Injection).replace("`", "'")}``` + +> Example of how to use `Injection` will also be introduced later ([Functional usage](#functional-usage)). 
+ +## Example + +Here's an example of how to define a custom `node`: + +```python +@Node.register("hello") +class HelloNode(Node): + async def execute(self): + name = self.data["name"] + return {'''{"name": name, "greeting": f"Hello, {name}!"}'''} +``` + +In the above example, we defined a `node` named `hello`, which takes a `name` as input and returns the `name` itself and a `greeting` as outputs. + +To make it 'collectable' by the automated system, we can implement the `get_schema` method: + +```python +class HelloInput(BaseModel): + name: str + + +class HelloOutput(BaseModel): + name: str + greeting: str + + +@Node.register("hello") +class HelloNode(Node): + @classmethod + def get_schema(cls): + return Schema(HelloInput, HelloOutput) + + async def execute(self): + name = self.data["name"] + return {'''{"name": name, "greeting": f"Hello, {name}!"}'''} +``` + +This will help us automatically generate the API endpoint as well as the documentation. +{sep} +# Workflow + +```python +{strip_source(inspect.getsource(cflow.Flow).replace("`", "'"), "def __init__")} +``` + +The key method used by users will be the `execute` method, which is defined as: + +```python +{workflow_execute_source} +``` + +## Functional usage + +A typical procedure of using `workflow` in a `functional` way is as follows: +- Define your custom `nodes` by inheriting from `cflow.Node` (if needed). +- Define your `workflow` by using `cflow.Flow` in a chainable way (`cflow.Flow().push(...).push(...)`). + - Use `flow.gather(...)` if you have multiple targets. +- Call `await workflow.execute(...)` to execute the `workflow` with the given inputs. + +Here is a simple example: + +```python +import asyncio + +from cflow import * + +@Node.register("hello") +class HelloNode(Node): + async def execute(self): + name = self.data["name"] + return {'''{"name": name, "greeting": f"Hello, {name}!"}'''} + +async def main(): + flow = ( + Flow() + .push(HelloNode("A", dict(name="foo"))) + .push(HelloNode("B", dict(name="bar"))) + .push( + EchoNode( + "Echo", + dict(messages=[None, None, "Hello, World!"]), + injections=[ + Injection("A", "name", "messages.0"), + Injection("B", "greeting", "messages.1"), + ], + ) + ) + ) + await flow.execute("Echo") + +if __name__ == "__main__": + asyncio.run(main()) +``` + +Running the above codes will yield something like: + +```text +[17:30:27] foo + Hello, bar! + Hello, World! +``` + +> More examples can be found at the end of this document ([Coding Examples](#coding-examples)). + +## API usage + +Here are some important input data models when you want to use `workflow` in an `API` way. + +> Examples can be found at the end of this document ([Workflow JSON Examples](#workflow-json-examples)). 
+ +### `InjectionModel` + +```python +{inspect.getsource(cflow.SrcKey).replace("`", "'")}``` + +```python +{inspect.getsource(cflow.LoopBackInjectionModel).replace("`", "'")}``` + +```python +{inspect.getsource(cflow.InjectionModel).replace("`", "'")}``` + +### `NodeModel` + +```python +{inspect.getsource(cflow.NodeModel).replace("`", "'")}``` + +### `WorkflowModel` + +```python +{workflow_model_source} +``` +{sep} +# Schema + +## Common Enums + +{enum_docs} + +## Common Data Models + +{data_model_docs} +{sep} +# Supported Nodes + +{''.join([document.markdown for document in documents])[:-1]} + +# Examples + +## Coding Examples + +{code_example_docs} + +## Workflow JSON Examples + +{workflow_example_docs} +""" + with open(output, "w") as f: + f.write(generated) + console.log(f"generated documents saved to '{output}'!") + + +__all__ = [ + "generate_documents", +] diff --git a/cfdraw/core/flow/nodes/__init__.py b/cfdraw/core/flow/nodes/__init__.py new file mode 100644 index 00000000..767bae04 --- /dev/null +++ b/cfdraw/core/flow/nodes/__init__.py @@ -0,0 +1,2 @@ +from .schema import * +from .common import * diff --git a/cfdraw/core/flow/nodes/common.py b/cfdraw/core/flow/nodes/common.py new file mode 100644 index 00000000..494ed3f0 --- /dev/null +++ b/cfdraw/core/flow/nodes/common.py @@ -0,0 +1,383 @@ +# Common Nodes + +import json +import shutil +import asyncio + +from PIL import Image +from typing import Any +from typing import Dict +from typing import List +from typing import Union +from typing import Optional +from pathlib import Path +from pydantic import Field +from pydantic import BaseModel +from dataclasses import dataclass + +from .schema import TImage +from .schema import ImageModel +from .schema import IImageNode +from .schema import EmptyOutput +from .schema import IWithImageNode +from ..core import LOOP_NODE +from ..core import GATHER_NODE +from ..core import WORKFLOW_NODE +from ..core import extract_from +from ..core import inject_leaf_data +from ..core import Node +from ..core import Flow +from ..core import Schema +from ..core import Injection +from ..core import WorkflowModel +from ..core import LoopBackInjectionModel +from ...toolkit import console +from ...toolkit.misc import shallow_copy_dict + + +# functional nodes + + +class LoopInput(BaseModel): + base_node: str = Field(..., description="The node to be looped.") + base_data: Dict[str, Any] = Field(default_factory=dict, description="Base data.") + loop_values: Dict[str, List[Any]] = Field( + ..., + description="""The values to be looped. +> - The keys should be the 'target hierarchy' of the `data` +> - The values should be a list of values to be looped & injectedinto the 'target hierarchy'. +> - All values should have the same length. 
+ +For example, if you want to loop `data["a"]` with values `[1, 2]`, and loop `data["b"][0]["c"]` with values `[3, 4]`, you can use: +```python +{ + "a": [1, 2], + "b.0.c": [3, 4], +} +``` +""", + ) + loop_back_injections: Optional[List[LoopBackInjectionModel]] = Field( + None, + description="The loop back injections.\n" + "> - If this is set, the results from the previous step in the loop will be " + "injected into the current node's `data`.\n" + "> - If `None`, no injection will be performed, and all nodes will be " + "executed in parallel.", + ) + extract_hierarchy: Optional[str] = Field( + None, + description="The hierarchy of the results to be extracted.\n" + "> - If `None`, all results will be preserved.", + ) + verbose: bool = Field(False, description="Whether to print debug logs.") + + +class LoopOutput(BaseModel): + results: List[Any] = Field(..., description="The results of the loop.") + + +@Node.register(LOOP_NODE) +class LoopNode(Node): + @classmethod + def get_schema(cls) -> Schema: + return Schema( + LoopInput, + LoopOutput, + description="A node that represents a loop of another node.", + ) + + async def execute(self) -> Dict[str, List[Dict[str, Any]]]: + t_node = Node.get(self.data["base_node"]) + if t_node is None: + raise ValueError(f"node `{self.data['base_node']}` is not defined") + base_data = self.data["base_data"] + loop_values = self.data["loop_values"] + loop_back_injections = self.data["loop_back_injections"] + loop_keys = list(loop_values) + lengths = [len(loop_values[k]) for k in loop_keys] + if len(set(lengths)) != 1: + raise ValueError( + "all loop values should have the same length, " + f"but lengths are {lengths}" + ) + n = lengths[0] + flow = Flow() + verbose = self.data["verbose"] + for i in range(n): + i_data = shallow_copy_dict(base_data) + for k in loop_keys: + v = loop_values[k][i] + inject_leaf_data(i_data, k.split("."), v, verbose=verbose) + if loop_back_injections is None or i == 0: + i_injections = [] + else: + i_injections = list(map(shallow_copy_dict, loop_back_injections)) + i_injections = [Injection(str(i - 1), **d) for d in i_injections] + flow.push(t_node(str(i), i_data, i_injections)) + target = flow.gather(*map(str, range(n))) + results = await flow.execute(target, verbose=self.data["verbose"]) + extracted = [results[str(i)] for i in range(n)] + extract_hierarchy = self.data["extract_hierarchy"] + if extract_hierarchy is not None: + extracted = [extract_from(rs, extract_hierarchy) for rs in extracted] + return {"results": extracted} + + +@Node.register(GATHER_NODE) +class GatherNode(Node): + flow: Optional[Flow] = None + + @classmethod + def get_schema(cls) -> Schema: + return Schema( + description="A node that is used to gather other nodes' results.\n" + "> - This is useful when you have multiple targets to collect results from.\n" + "> - If you are programming in Python, you can use `flow.gather` to make things easier.", + ) + + async def initialize(self, flow: Flow) -> None: + await super().initialize(flow) + self.flow = flow + + async def get_api_response(self, results: Dict[str, Any]) -> Dict[str, Any]: + if self.flow is None: + console.warn( + "`flow` is not provided for `GatherNode`, raw results will be returned " + "and `get_api_response` might not work as expected" + ) + return results + keys = list(results) + node_items = [self.flow.get(k) for k in keys] + if any(item is None for item in node_items): + raise ValueError( + "internal error: some nodes are not found when getting api response: " + f"{[k for k, n in zip(keys, 
node_items) if n is None]}" + ) + nodes = [item.data for item in node_items] # type: ignore + tasks = [node.get_api_response(results[k]) for k, node in zip(keys, nodes)] + converted = await asyncio.gather(*tasks) + return {k: v for k, v in zip(keys, converted)} + + async def execute(self) -> Dict[str, Any]: + return self.data + + def from_info(self, info: Dict[str, Any]) -> "GatherNode": + super().from_info(info) + for injection in self.injections: + if injection.src_hierarchy is not None: + raise ValueError( + "`GatherNode` should always use `src_hierarchy=None` " + f"for injections, but `{injection}` is found" + ) + if injection.src_key != injection.dst_hierarchy: + raise ValueError( + "`GatherNode` should always use `src_key=dst_hierarchy` " + f"for injections, but `{injection}` is found" + ) + return self + + def copy(self) -> "GatherNode": + copied = super().copy() + copied.flow = self.flow + return copied + + +@Node.register(WORKFLOW_NODE) +class WorkflowNode(Node): + @classmethod + def get_schema(cls) -> Schema: + return Schema( + input_model=WorkflowModel, + description="A node that represents a workflow", + ) + + async def execute(self) -> Dict[str, Any]: + return await WorkflowModel(**self.data).run(return_api_response=False) + + +# common nodes + + +class ParametersModel(BaseModel): + params: Dict[str, Any] = Field(default_factory=dict, description="The parameters.") + + +@Node.register("common.parameters") +class ParametersNode(Node): + @classmethod + def get_schema(cls) -> Schema: + return Schema( + ParametersModel, + ParametersModel, + description="Setup parameters.\n" + "> - This is often used in a pre-defined workflow JSON to decide " + "which parameters to be exposed to the user.\n" + "> - See [examples](https://github.com/carefree0910/carefree-workflow/tree/main/examples/workflows) for reference.", + ) + + async def execute(self) -> Dict[str, Any]: + return self.data + + +class EchoModel(BaseModel): + messages: Union[str, List[str]] + + +@Node.register("common.echo") +class EchoNode(Node): + @classmethod + def get_schema(cls) -> Schema: + return Schema(EchoModel, EchoModel, description="Echo the given message(s).") + + async def execute(self) -> Dict[str, Union[str, List[str]]]: + messages = self.data["messages"] + if isinstance(messages, str): + messages = [messages] + for message in messages: + console.log(message) + return self.data + + +def pad_parent(path: str, parent: Optional[str]) -> Path: + if parent is None: + return Path(path) + return Path(parent) / path + + +class CopyInput(BaseModel): + src: str = Field(..., description="The source path.") + dst: str = Field(..., description="The destination path.") + parent: Optional[str] = Field(None, description="The parent directory of `dst`.") + + +class CopyOutput(BaseModel): + dst: str = Field(..., description="The destination path with parent directory.") + + +@Node.register("common.copy") +class CopyNode(Node): + @classmethod + def get_schema(cls) -> Schema: + return Schema( + CopyInput, + CopyOutput, + description="Copy a file from `src` to `dst`.", + ) + + async def execute(self) -> Dict[str, str]: + src = Path(self.data["src"]) + dst = pad_parent(self.data["dst"], self.data["parent"]) + dst.parent.mkdir(parents=True, exist_ok=True) + shutil.copyfile(src, dst) + return {"dst": str(dst)} + + +class SaveJsonInput(BaseModel): + data: Any = Field(..., description="The data to be saved.") + path: str = Field(..., description="The path to save the data.") + parent: Optional[str] = Field(None, description="The parent 
directory for saving.") + + +@Node.register("common.save_json") +class SaveJsonNode(Node): + @classmethod + def get_schema(cls) -> Schema: + return Schema( + SaveJsonInput, + output_model=CopyOutput, + description="Save the given data to a JSON file.", + ) + + async def execute(self) -> dict: + path = pad_parent(self.data["path"], self.data["parent"]) + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("w") as f: + json.dump(self.data["data"], f) + return {"dst": str(path)} + + +@dataclass +@Node.register("common.download_image") +class DownloadImageNode(IImageNode): + offload: bool = True + + @classmethod + def get_schema(cls) -> Schema: + schema = super().get_schema() + schema.description = "Download an image from the given url." + return schema + + async def execute(self) -> Dict[str, Image.Image]: + image = await self.get_image_from("url") + return {"image": image} + + +class SaveImageInput(ImageModel): + path: str = Field("debug.png", description="The path to save the image.") + parent: Optional[str] = Field(None, description="The parent directory for saving.") + + +@Node.register("debug.save_image") +class SaveImageNode(IWithImageNode): + @classmethod + def get_schema(cls) -> Schema: + return Schema( + SaveImageInput, + output_model=EmptyOutput, + description="Save an image from the given url to disk, mainly for debugging.", + ) + + async def execute(self) -> dict: + image = await self.get_image_from("url") + image.save(pad_parent(self.data["path"], self.data["parent"])) + return {} + + +class SaveImagesInput(BaseModel): + urls: List[TImage] = Field(..., description="The urls of the images.") + prefix: str = Field("debug", description="The prefix to save the images.") + parent: Optional[str] = Field(None, description="The parent directory for saving.") + + +@Node.register("debug.save_images") +class SaveImagesNode(IWithImageNode): + @classmethod + def get_schema(cls) -> Schema: + return Schema( + SaveImagesInput, + output_model=EmptyOutput, + description="Save images from the given urls to disk, mainly for debugging.", + ) + + async def execute(self) -> dict: + tasks = [self.fetch_image(str(v), v) for v in self.data["urls"]] + images = await asyncio.gather(*tasks) + prefix = self.data["prefix"] + for i, image in enumerate(images): + image.save(pad_parent(f"{prefix}_{i}.png", self.data["parent"])) + return {} + + +# common node utils + + +def to_endpoint(name: str) -> str: + split = name.split(".") + return f"/{'/'.join(split)}" + + +__all__ = [ + "LoopBackInjectionModel", + "LoopNode", + "GatherNode", + "WorkflowNode", + "ParametersNode", + "EchoNode", + "CopyNode", + "SaveJsonNode", + "DownloadImageNode", + "SaveImageNode", + "SaveImagesNode", + "to_endpoint", +] diff --git a/cfdraw/core/flow/nodes/schema.py b/cfdraw/core/flow/nodes/schema.py new file mode 100644 index 00000000..3c9a9d90 --- /dev/null +++ b/cfdraw/core/flow/nodes/schema.py @@ -0,0 +1,236 @@ +import json +import websockets + +from PIL import Image +from enum import Enum +from typing import Any +from typing import Dict +from typing import List +from typing import Type +from typing import Union +from typing import Optional +from typing import Protocol +from aiohttp import ClientSession +from pydantic import Field +from pydantic import BaseModel +from dataclasses import dataclass +from pydantic_core import core_schema + +from ..core import Hook +from ..core import Node +from ..core import Schema +from ...toolkit.cv import to_base64 +from ...toolkit.web import download_raw_with_retry +from ...toolkit.web 
import download_image_with_retry + + +HTTP_SESSION_KEY = "$http_session$" + + +# enums / data models + + +class DocEnum(Enum): + """A class that tells use to include it in the documentation""" + + +class DocModel(BaseModel): + """A class that tells use to include it in the documentation""" + + +class TextModel(DocModel): + text: str = Field(..., description="The text.") + + +class ImageField(Image.Image): + @classmethod + def __get_pydantic_core_schema__(cls, *args: Any) -> core_schema.CoreSchema: + return core_schema.with_info_plain_validator_function(cls.validate) + + @classmethod + def validate(cls, v: Any, info: core_schema.ValidationInfo) -> Image.Image: + if isinstance(v, Image.Image): + return v + raise ValueError("Value must be a PIL Image") + + +TImage = Union[str, ImageField] + + +class ImageModel(DocModel): + url: TImage = Field(..., description="The url / PIL.Image instance of the image.") + + +class ImageAPIOuput(DocModel): + image: str = Field(..., description="The base64 encoded image.") + + +class EmptyOutput(BaseModel): + pass + + +# hooks / node interfaces + + +class HttpSessionHook(Hook): + @classmethod + async def initialize(cls, shared_pool: Dict[str, Any]) -> None: + if HTTP_SESSION_KEY not in shared_pool: + shared_pool[HTTP_SESSION_KEY] = ClientSession() + + @classmethod + async def cleanup(cls, shared_pool: Dict[str, Any]) -> None: + http_session = shared_pool.pop(HTTP_SESSION_KEY, None) + if http_session is not None: + if not isinstance(http_session, ClientSession): + raise TypeError(f"invalid http session type: {type(http_session)}") + await http_session.close() + + +class IWithHttpSessionNode(Node): + """ + node interface which requires `ClientSession` in the `shared_pool`. + + Notes + ----- + - This interface provides `http_session` to get the `ClientSession` from the `shared_pool`. + - This interface provides `download_raw` and `download_image` to download data from the internet. + + """ + + @classmethod + def get_hooks(cls) -> List[Type[Hook]]: + return [HttpSessionHook] + + @property + def http_session(self) -> ClientSession: + session = self.shared_pool.get(HTTP_SESSION_KEY) + if session is None: + raise ValueError( + f"`{HTTP_SESSION_KEY}` should be provided in the `shared_pool` " + f"for `{self.__class__.__name__}`" + ) + if not isinstance(session, ClientSession): + raise TypeError(f"invalid http session type: {type(session)}") + return session + + async def download_raw(self, url: str) -> bytes: + return await download_raw_with_retry(self.http_session, url) + + async def download_image(self, url: str) -> Image.Image: + return await download_image_with_retry(self.http_session, url) + + +class IHandleMessage(Protocol): + async def __call__(self, raw_message: websockets.Data) -> bool: + """ + Handle the message and return whether to terminate the websocket. + + Parameters + ---------- + raw_message: websockets.Data + The raw message. + + Returns + ------- + bool + Whether to terminate the websocket (return `True` to terminate). 
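+
+        Examples
+        -------
+        A minimal, illustrative handler that terminates the websocket once a
+        message whose JSON payload contains `"finished": true` arrives; the
+        payload format here is an assumption, not part of this protocol:
+
+        >>> import json
+        >>> async def handler(raw_message: websockets.Data) -> bool:
+        ...     return bool(json.loads(raw_message).get("finished", False))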
+ """ + + +class IWithWebsocketNode(Node): + async def connect( + self, + url: str, + *, + handler: IHandleMessage, + send_data: Optional[Dict[str, Any]] = None, + headers: Optional[websockets.HeadersLike] = None, + **kwargs: Any, + ) -> None: + kwargs["extra_headers"] = headers + async with websockets.connect(url, **kwargs) as websocket: + if send_data is not None: + await websocket.send(json.dumps(send_data)) + async for raw_message in websocket: + if await handler(raw_message): + break + + +class IWithImageNode(IWithHttpSessionNode): + """ + node interface which (may) have image(s) as input. This is helpful for crafting image processing nodes. + + Notes + ----- + - This interface provides `get_image_from` to get the image from the given field. + - If the field is a url, it will be downloaded. + - The field can also be `PIL.Image` directly, since it might be injected by + other nodes. + + """ + + async def fetch_image(self, tag: str, image: TImage) -> Image.Image: + if isinstance(image, str): + image = await self.download_image(image) + elif not isinstance(image, Image.Image): + raise ValueError(f"`{tag}` should be a `PIL.Image` or a url") + return image + + async def get_image_from(self, field: str) -> Image.Image: + image = self.data[field] + return await self.fetch_image(field, image) + + +class IImageNode(IWithImageNode): + """ + Image node interface. This is helpful for crafting image processing nodes. + + Notes + ----- + - This interface assumes the output to be like `{"image": PIL.Image}`. + + """ + + @classmethod + def get_schema(cls) -> Schema: + return Schema( + ImageModel, + api_output_model=ImageAPIOuput, + output_names=["image"], + ) + + async def get_api_response(self, results: Dict[str, Image.Image]) -> Dict[str, str]: + return {"image": to_base64(results["image"])} + + +@dataclass +class ICUDANode(Node): + """ + CUDA node interface. This is helpful when creating nodes for modern AI models. + + Notes + ----- + - CUDA executions should be 'offloaded' to avoid blocking other async executions. + - CUDA executions should be 'locked' to avoid CUDA issues. 
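+
+    Examples
+    -------
+    A hypothetical subclass sketch; the `"demo.cuda_infer"` identifier and the
+    body of `execute` are made up for illustration only:
+
+    >>> @Node.register("demo.cuda_infer")
+    ... class CudaInferNode(ICUDANode):
+    ...     @classmethod
+    ...     def get_schema(cls) -> Schema:
+    ...         return Schema(description="Run a model on the CUDA device.")
+    ...
+    ...     async def execute(self) -> dict:
+    ...         # heavy GPU work would go here; `offload=True` and
+    ...         # `lock_key="$cuda$"` are inherited from `ICUDANode`,
+    ...         # so the execution is offloaded & locked by the framework
+    ...         return self.data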
+ """ + + offload: bool = True + lock_key: str = "$cuda$" + + +__all__ = [ + "DocEnum", + "DocModel", + "TextModel", + "TImage", + "ImageModel", + "ImageAPIOuput", + "EmptyOutput", + "HttpSessionHook", + "IWithHttpSessionNode", + "IWithWebsocketNode", + "IWithImageNode", + "IImageNode", + "ICUDANode", +] diff --git a/cfdraw/core/flow/server.py b/cfdraw/core/flow/server.py new file mode 100644 index 00000000..2e92a432 --- /dev/null +++ b/cfdraw/core/flow/server.py @@ -0,0 +1,158 @@ +import re +import asyncio + +from typing import Any +from typing import Dict +from typing import List +from typing import Type +from typing import Optional +from fastapi import FastAPI +from pydantic import create_model +from pydantic import Field +from pydantic import BaseModel + +from .core import WORKFLOW_ENDPOINT_NAME +from .core import warmup +from .core import Node +from .core import Flow +from .core import WorkflowModel +from .core import InjectionModel +from .nodes.common import to_endpoint +from ..parameters import OPT +from ..toolkit.web import raise_err +from ..toolkit.web import get_responses +from ..toolkit.misc import random_hash + + +def parse_endpoint(t_node: Type[Node]) -> str: + return to_endpoint(t_node.__identifier__) + + +def parse_input_model(t_node: Type[Node]) -> Optional[Type[BaseModel]]: + schema = t_node.get_schema() + if schema is None: + return None + if schema.input_model is not None: + return schema.input_model + if schema.input_names is not None: + return create_model( # type: ignore + f"{t_node.__name__}Input", + **{name: (Any, ...) for name in schema.input_names}, + ) + return None + + +def parse_output_model(t_node: Type[Node]) -> Optional[Type[BaseModel]]: + schema = t_node.get_schema() + if schema is None: + return None + if schema.api_output_model is not None: + return schema.api_output_model + if schema.output_model is not None: + return schema.output_model + if schema.output_names is not None: + return create_model( # type: ignore + f"{t_node.__name__}Output", + **{name: (Any, ...) 
for name in schema.output_names}, + ) + return None + + +def parse_description(t_node: Type[Node]) -> Optional[str]: + schema = t_node.get_schema() + if schema is None: + return None + return schema.description + + +def use_all_t_nodes() -> List[Type[Node]]: + return list(t_node for t_node in Node.d.values() if issubclass(t_node, Node)) # type: ignore + + +def register_api(app: FastAPI, t_node: Type[Node], focus: str) -> None: + endpoint = parse_endpoint(t_node) + if not re.search(focus, endpoint): + return + input_model = parse_input_model(t_node) + output_model = parse_output_model(t_node) + description = parse_description(t_node) + if input_model is None or output_model is None: + return None + names = t_node.__identifier__.split(".") + names[0] = f"[{names[0]}]" + name = "_".join(names) + asyncio.run(warmup(t_node, True)) + + @app.post( + endpoint, + name=name, + responses=get_responses(output_model), + description=description, + ) + async def _(data: input_model) -> output_model: # type: ignore + try: + key = random_hash() + flow = Flow().push(t_node(key, data.model_dump())) # type: ignore + results = await flow.execute(key, return_api_response=True) + return output_model(**results[key]) + except Exception as err: + raise_err(err) + + +def register_nodes_api(app: FastAPI) -> None: + focus = OPT.flow_opt["focus"] + for t_node in use_all_t_nodes(): + register_api(app, t_node, focus) + + +def register_workflow_api(app: FastAPI) -> None: + @app.post(f"/{WORKFLOW_ENDPOINT_NAME}") + async def workflow(data: WorkflowModel) -> Dict[str, Any]: + try: + return await data.run(return_api_response=True) + except Exception as err: + raise_err(err) + return {} + + +class ServerStatus(BaseModel): + num_nodes: int = Field( + ..., + description="The number of registered nodes in the environment.\n" + "> - Notice that this may be different from the number of nodes " + "which are exposed as API, because some nodes may not have " + "`get_schema` method implemented.\n" + "> - However, all nodes can be used in the `workflow` API, no matter " + "whether they have `get_schema` method implemented or not.", + ) + + +def register_server_api(app: FastAPI) -> None: + @app.get("/server_status", responses=get_responses(ServerStatus)) + async def server_status() -> ServerStatus: + return ServerStatus(num_nodes=len(use_all_t_nodes())) + + +class API: + def __init__(self) -> None: + self.app = FastAPI() + + def initialize(self) -> None: + register_server_api(self.app) + register_nodes_api(self.app) + register_workflow_api(self.app) + + +api = API() + + +__all__ = [ + "parse_endpoint", + "parse_input_model", + "parse_output_model", + "use_all_t_nodes", + "register_api", + "register_nodes_api", + "register_workflow_api", + "API", +] diff --git a/cfdraw/core/flow/utils.py b/cfdraw/core/flow/utils.py new file mode 100644 index 00000000..773c84c6 --- /dev/null +++ b/cfdraw/core/flow/utils.py @@ -0,0 +1,209 @@ +import io + +import networkx as nx +import matplotlib.pyplot as plt +import matplotlib.patches as mpatches + +from PIL import Image +from typing import Set +from typing import Dict +from typing import List +from typing import Tuple +from typing import Optional +from typing import NamedTuple +from dataclasses import asdict + +from .core import Node +from .core import Flow +from .core import NodeModel +from .core import WorkflowModel +from .core import InjectionModel +from ..toolkit.misc import truncate_string_to_length +from ..toolkit.data_structures import Item + + +class ToposortResult(NamedTuple): + in_edges: 
Dict[str, Set[str]] + hierarchy: List[List[Item[Node]]] + edge_labels: Dict[Tuple[str, str], str] + reachable: Set[str] + + +def toposort(workflow: Flow) -> ToposortResult: + in_edges: Dict[str, Set[str]] = {item.key: set() for item in workflow} + out_degrees = {item.key: 0 for item in workflow} + edge_labels: Dict[Tuple[str, str], str] = {} + for item in workflow: + visited = set() + dst_key = item.key + for injection in item.data.injections: + in_edges[injection.src_key].add(dst_key) + if injection.src_key not in visited: + visited.add(injection.src_key) + out_degrees[dst_key] += 1 + label_key = (dst_key, injection.src_key) + edge_label = str(injection.dst_hierarchy) + existing_label = edge_labels.get(label_key) + if existing_label is None: + edge_labels[label_key] = edge_label + else: + edge_labels[label_key] = f"{existing_label}, {edge_label}" + for k, v in edge_labels.items(): + edge_labels[k] = truncate_string_to_length(v, 25) + + ready = [k for k, v in out_degrees.items() if v == 0] + result = [] + while ready: + layer = ready.copy() + result.append(layer) + ready.clear() + for dep in layer: + for node in in_edges[dep]: + out_degrees[node] -= 1 + if out_degrees[node] == 0: + ready.append(node) + + if len(workflow) != sum(map(len, result)): + raise ValueError("cyclic dependency detected") + + hierarchy = [list(map(workflow.get, layer)) for layer in result] + reachable = {item.key for item in workflow} + return ToposortResult(in_edges, hierarchy, edge_labels, reachable) # type: ignore + + +def get_dependency_path(workflow: Flow, target: str) -> ToposortResult: + reachable = workflow.get_reachable(target) + in_edges, raw_hierarchy, edge_labels, _ = toposort(workflow) + hierarchy = [] + for raw_layer in raw_hierarchy: + layer = [] + for item in raw_layer: + if item.key in reachable: + layer.append(item) + if layer: + hierarchy.append(layer) + return ToposortResult(in_edges, hierarchy, edge_labels, reachable) + + +def render_workflow( + workflow: Flow, + *, + target: Optional[str] = None, + figsize: Optional[Tuple[int, int]] = None, + fig_w_ratio: int = 4, + fig_h_ratio: int = 3, + dpi: int = 200, + node_size: int = 2000, + node_shape: str = "s", + node_color: str = "lightblue", + layout: str = "multipartite_layout", +) -> Image.Image: + # setup graph + G = nx.DiGraph() + if target is None: + target = workflow.last.key # type: ignore + in_edges, hierarchy, edge_labels, _ = get_dependency_path(workflow, target) + # setup plt + if figsize is None and layout == "multipartite_layout": + fig_w = max(fig_w_ratio * len(hierarchy), 8) + fig_h = fig_h_ratio * max(map(len, hierarchy)) + figsize = (fig_w, fig_h) + plt.figure(figsize=figsize, dpi=dpi) + box = plt.gca().get_position() + plt.gca().set_position([box.x0, box.y0, box.width * 0.8, box.height]) + # map key to indices + key2idx: Dict[str, int] = {} + for layer in hierarchy: + for node in layer: + key2idx[node.key] = len(key2idx) + # add nodes + for i, layer in enumerate(hierarchy): + for node in layer: + G.add_node(key2idx[node.key], subset=f"layer_{i}") + # add edges + for dep, links in in_edges.items(): + for link in links: + if dep not in key2idx or link not in key2idx: + continue + label = edge_labels[(link, dep)] + G.add_edge(key2idx[dep], key2idx[link], label=label) + # calculate positions + layout_fn = getattr(nx, layout, None) + if layout_fn is None: + raise ValueError(f"unknown layout: {layout}") + pos = layout_fn(G) + # draw the nodes + nodes_styles = dict( + node_size=node_size, + node_shape=node_shape, + node_color=node_color, + 
) + nx.draw_networkx_nodes(G, pos, **nodes_styles) + node_labels_styles = dict( + font_size=18, + ) + nx.draw_networkx_labels(G, pos, **node_labels_styles) + # draw the edges + nx_edge_labels = nx.get_edge_attributes(G, "label") + nx.draw_networkx_edges( + G, + pos, + arrows=True, + arrowstyle="-|>", + arrowsize=16, + node_size=nodes_styles["node_size"], + node_shape=nodes_styles["node_shape"], + ) + nx.draw_networkx_edge_labels(G, pos, edge_labels=nx_edge_labels) + # draw captions + patches = [ + mpatches.Patch(color=node_color, label=f"{idx}: {key}") + for key, idx in key2idx.items() + ] + plt.legend(handles=patches, bbox_to_anchor=(1, 0.5), loc="center left") + # render + plt.axis("off") + buf = io.BytesIO() + plt.savefig(buf, format="png") + buf.seek(0) + return Image.open(buf) + + +def to_data_model( + flow: Flow, + *, + target: str, + intermediate: Optional[List[str]] = None, + return_if_exception: bool = False, + verbose: bool = False, +) -> WorkflowModel: + nodes: List[NodeModel] = [] + for node_item in flow: + node = node_item.data + if node.key is None: + raise ValueError(f"node key cannot be None ({node})") + nodes.append( + NodeModel( + key=node.key, + type=node.__identifier__, + data=node.data, + injections=[InjectionModel(**asdict(d)) for d in node.injections], + offload=node.offload, + lock_key=node.lock_key, + ) + ) + return WorkflowModel( + target=target, + intermediate=intermediate, + nodes=nodes, + return_if_exception=return_if_exception, + verbose=verbose, + ) + + +__all__ = [ + "toposort", + "get_dependency_path", + "render_workflow", + "to_data_model", +] diff --git a/cfdraw/core/parameters.py b/cfdraw/core/parameters.py new file mode 100644 index 00000000..2aba44fd --- /dev/null +++ b/cfdraw/core/parameters.py @@ -0,0 +1,55 @@ +from typing import Any +from typing import Dict +from pathlib import Path + +from .toolkit.misc import update_dict +from .toolkit.misc import OPTBase + + +class OPTClass(OPTBase): + flow_opt: Dict[str, Any] + learn_opt: Dict[str, Any] + + @property + def env_key(self) -> str: + return "CFCORE_ENV" + + @property + def defaults(self) -> Dict[str, Any]: + user_dir = Path.home() + return dict( + flow_opt=dict(focus="", verbose=True), + learn_opt=dict( + cache_dir=user_dir / ".cache" / "carefree-core" / "learn", + data_cache_dir=user_dir / ".cache" / "carefree-core" / "learn" / "data", + external_dir=user_dir + / ".cache" + / "carefree-core" + / "learn" + / "external", + meta_settings={}, + ), + ) + + def update_from_env(self) -> None: + super().update_from_env() + defaults = self.defaults + flow_opt = self._opt["flow_opt"] + learn_opt = self._opt["learn_opt"] + if isinstance(flow_opt, dict): + self._opt["flow_opt"] = update_dict(flow_opt, defaults["flow_opt"]) + if isinstance(learn_opt, dict): + updated = update_dict(learn_opt, defaults["learn_opt"]) + if "cache_dir" in updated: + updated["cache_dir"] = Path(updated["cache_dir"]) + if "external_dir" in updated: + updated["external_dir"] = Path(updated["external_dir"]) + self._opt["learn_opt"] = updated + + +OPT = OPTClass() + + +__all__ = [ + "OPT", +] diff --git a/cfdraw/core/toolkit/__init__.py b/cfdraw/core/toolkit/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/cfdraw/core/toolkit/array.py b/cfdraw/core/toolkit/array.py new file mode 100644 index 00000000..a4e0d39c --- /dev/null +++ b/cfdraw/core/toolkit/array.py @@ -0,0 +1,835 @@ +import math + +from typing import Any +from typing import Dict +from typing import List +from typing import Tuple +from typing import Union 
+from typing import Callable +from typing import Optional +from typing import NamedTuple +from typing import TYPE_CHECKING +from collections import Counter + +from .misc import to_path +from .misc import random_hash +from .misc import get_file_size +from .misc import timeit +from .types import TPath +from .types import TArray +from .types import arr_type +from .types import tensor_dict_type + +if TYPE_CHECKING: + import torch + import numpy as np + + +def is_int(arr: arr_type) -> bool: + import torch + import numpy as np + + if isinstance(arr, (np.ndarray, np.number)): + return np.issubdtype(arr.dtype, np.integer) + return not torch.is_floating_point(arr) and not torch.is_complex(arr) + + +def is_float(arr: arr_type) -> bool: + import torch + import numpy as np + + if isinstance(arr, (np.ndarray, np.number)): + return np.issubdtype(arr.dtype, np.floating) + return torch.is_floating_point(arr) + + +def is_string(arr: arr_type) -> bool: + import numpy as np + + if isinstance(arr, (np.ndarray, np.character)): + return np.issubdtype(arr.dtype, str) + return False + + +def is_real_numeric(arr: arr_type) -> bool: + return is_float(arr) or is_int(arr) + + +def sigmoid(arr: TArray) -> TArray: + import torch + import numpy as np + + if isinstance(arr, np.ndarray): + return 1.0 / (1.0 + np.exp(-arr)) + return torch.sigmoid(arr) + + +def softmax(arr: TArray) -> TArray: + import numpy as np + import torch.nn.functional as F + + if isinstance(arr, np.ndarray): + logits = arr - np.max(arr, axis=1, keepdims=True) + exp = np.exp(logits) + return exp / exp.sum(1, keepdims=True) + return F.softmax(arr, dim=1) + + +def l2_normalize(arr: TArray) -> TArray: + import numpy as np + + if isinstance(arr, np.ndarray): + return arr / np.linalg.norm(arr, axis=-1, keepdims=True) + return arr / arr.norm(dim=-1, keepdim=True) # type: ignore + + +def normalize( + arr: TArray, + *, + global_norm: bool = True, + return_stats: bool = False, + eps: float = 1.0e-8, +) -> Union[TArray, Tuple[TArray, Dict[str, Any]]]: + import torch + import numpy as np + + if global_norm: + arr_mean, arr_std = arr.mean().item(), arr.std().item() + arr_std = max(eps, arr_std) + out = (arr - arr_mean) / arr_std + if not return_stats: + return out + return out, dict(mean=arr_mean, std=arr_std) + if isinstance(arr, np.ndarray): + arr_mean, arr_std = arr.mean(axis=0), arr.std(axis=0) + std = np.maximum(eps, arr_std) + else: + arr_mean, arr_std = arr.mean(dim=0), arr.std(dim=0) # type: ignore + std = torch.clip(arr_std, min=eps) + out = (arr - arr_mean) / std + if not return_stats: + return out + return out, dict(mean=arr_mean.tolist(), std=std.tolist()) + + +def normalize_from(arr: TArray, stats: Dict[str, Any]) -> TArray: + mean, std = stats["mean"], stats["std"] + return (arr - mean) / std + + +def recover_normalize_from(arr: TArray, stats: Dict[str, Any]) -> TArray: + mean, std = stats["mean"], stats["std"] + return arr * std + mean + + +def min_max_normalize( + arr: TArray, + *, + global_norm: bool = True, + return_stats: bool = False, + eps: float = 1.0e-8, +) -> Union[TArray, Tuple[TArray, Dict[str, Any]]]: + import torch + import numpy as np + + if global_norm: + arr_min, arr_max = arr.min().item(), arr.max().item() + diff = max(eps, arr_max - arr_min) + out = (arr - arr_min) / diff + if not return_stats: + return out + return out, dict(min=arr_min, diff=diff) + if isinstance(arr, np.ndarray): + arr_min, arr_max = arr.min(axis=0), arr.max(axis=0) + diff = np.maximum(eps, arr_max - arr_min) + else: + arr_min, arr_max = arr.min(dim=0).values, 
arr.max(dim=0).values # type: ignore + diff = torch.clip(arr_max - arr_min, min=eps) + out = (arr - arr_min) / diff + if not return_stats: + return out + return out, dict(min=arr_min.tolist(), diff=diff.tolist()) + + +def min_max_normalize_from(arr: TArray, stats: Dict[str, Any]) -> TArray: + arr_min, diff = stats["min"], stats["diff"] + return (arr - arr_min) / diff + + +def recover_min_max_normalize_from(arr: TArray, stats: Dict[str, Any]) -> TArray: + arr_min, diff = stats["min"], stats["diff"] + return arr * diff + arr_min + + +def quantile_normalize( + arr: TArray, + *, + q: float = 0.01, + global_norm: bool = True, + return_stats: bool = False, + eps: float = 1.0e-8, +) -> Union[TArray, Tuple[TArray, Dict[str, Any]]]: + import torch + import numpy as np + + # quantiles + if isinstance(arr, np.ndarray): + kw = {"axis": 0} + quantile_fn = np.quantile + else: + kw = {"dim": 0} + quantile_fn = torch.quantile + if global_norm: + arr_min = quantile_fn(arr, q) + arr_max = quantile_fn(arr, 1.0 - q) + else: + arr_min = quantile_fn(arr, q, **kw) # type: ignore + arr_max = quantile_fn(arr, 1.0 - q, **kw) # type: ignore + # diff + if global_norm: + diff = max(eps, arr_max - arr_min) + else: + if isinstance(arr, np.ndarray): + diff = np.maximum(eps, arr_max - arr_min) + else: + diff = torch.clip(arr_max - arr_min, min=eps) + arr = arr.clip(arr_min, arr_max) + out = (arr - arr_min) / diff + if not return_stats: + return out + if not global_norm: + arr_min = arr_min.item() + diff = diff.item() + else: + arr_min = arr_min.tolist() + diff = diff.tolist() + return out, dict(min=arr_min, diff=diff) + + +def quantile_normalize_from(arr: TArray, stats: Dict[str, Any]) -> TArray: + arr_min, diff = stats["min"], stats["diff"] + return (arr - arr_min) / diff + + +def recover_quantile_normalize_from(arr: TArray, stats: Dict[str, Any]) -> TArray: + arr_min, diff = stats["min"], stats["diff"] + return arr * diff + arr_min + + +def clip_normalize(arr: TArray) -> TArray: + import torch + import numpy as np + + fn = np if isinstance(arr, np.ndarray) else torch + if arr.dtype == fn.uint8: + return arr + return fn.clip(arr, 0.0, 1.0) + + +# will return at least 2d +def squeeze(arr: TArray) -> TArray: + n = arr.shape[0] + arr = arr.squeeze() # type: ignore + if n == 1: + arr = arr[None, ...] 
# type: ignore + return arr + + +def to_standard(arr: "np.ndarray") -> "np.ndarray": + import numpy as np + + if is_int(arr): + arr = arr.astype(np.int64) + elif is_float(arr): + arr = arr.astype(np.float32) + return arr + + +def to_torch(arr: "np.ndarray") -> "torch.Tensor": + import torch + + return torch.from_numpy(to_standard(arr)) + + +def to_numpy(tensor: "torch.Tensor") -> "np.ndarray": + return tensor.detach().cpu().numpy() + + +def to_device( + batch: tensor_dict_type, + device: Optional["torch.device"], + **kwargs: Any, +) -> tensor_dict_type: + import torch + + def to(v: Any) -> Any: + if isinstance(v, torch.Tensor): + return v.to(device, **kwargs) + if isinstance(v, dict): + return {vk: to(vv) for vk, vv in v.items()} + if isinstance(v, list): + return [to(vv) for vv in v] + return v + + if device is None: + return batch + return {k: to(v) for k, v in batch.items()} + + +def iou(logits: TArray, labels: TArray) -> TArray: + import numpy as np + + is_numpy = isinstance(logits, np.ndarray) + num_classes = logits.shape[1] + if num_classes == 1: + heat_map = sigmoid(logits) + elif num_classes == 2: + heat_map = softmax(logits)[:, [1]] # type: ignore + else: + raise ValueError("`IOU` only supports binary situations") + intersect = heat_map * labels + union = heat_map + labels - intersect + kwargs = {"axis" if is_numpy else "dim": tuple(range(1, len(intersect.shape)))} + return intersect.sum(**kwargs) / union.sum(**kwargs) + + +def corr( + predictions: TArray, + target: TArray, + weights: Optional[TArray] = None, + *, + get_diagonal: bool = False, + eps: float = 1.0e-8, +) -> TArray: + import torch + import numpy as np + + is_numpy = isinstance(predictions, np.ndarray) + keepdim_kw: Dict[str, Any] = {"keepdims" if is_numpy else "keepdim": True} + norm_fn = np.linalg.norm if is_numpy else torch.norm + matmul_fn = np.matmul if is_numpy else torch.matmul + sqrt_fn = np.sqrt if is_numpy else torch.sqrt + transpose_fn = np.transpose if is_numpy else torch.t + + w_sum = 0.0 if weights is None else weights.sum().item() + if weights is None: + mean = predictions.mean(0, **keepdim_kw) + else: + mean = (predictions * weights).sum(0, **keepdim_kw) / w_sum + vp = predictions - mean + if weights is None: + kw = keepdim_kw.copy() + kw["axis" if is_numpy else "dim"] = 0 + vp_norm = norm_fn(vp, 2, **kw) + else: + vp_norm = sqrt_fn((weights * (vp**2)).sum(0, **keepdim_kw)) + if predictions is target: + vp_norm_t = transpose_fn(vp_norm) + if weights is None: + mat = matmul_fn(transpose_fn(vp), vp) / (vp_norm * vp_norm_t) + else: + mat = matmul_fn(transpose_fn(weights * vp), vp) / (vp_norm * vp_norm_t) + else: + if weights is None: + target_mean = target.mean(0, **keepdim_kw) + else: + target_mean = (target * weights).sum(0, **keepdim_kw) / w_sum + vt = transpose_fn(target - target_mean) + if weights is None: + kw = keepdim_kw.copy() + kw["axis" if is_numpy else "dim"] = 1 + vt_norm = norm_fn(vt, 2, **kw) + else: + vt_norm = sqrt_fn((transpose_fn(weights) * (vt**2)).sum(1, **keepdim_kw)) + if weights is None: + mat = matmul_fn(vt, vp) / (vp_norm * vt_norm + eps) + else: + mat = matmul_fn(vt, weights * vp) / (vp_norm * vt_norm + eps) + if not get_diagonal: + return mat + if mat.shape[0] != mat.shape[1]: + raise ValueError( + "`get_diagonal` is set to True but the correlation matrix " + "is not a squared matrix, which is an invalid condition" + ) + return np.diag(mat) if is_numpy else mat.diag() + + +def get_one_hot(feature: Union[list, "np.ndarray"], dim: int) -> "np.ndarray": + """ + Get one-hot 
representation. + + Parameters + ---------- + feature : array-like, source data of one-hot representation. + dim : int, dimension of the one-hot representation. + + Returns + ------- + one_hot : np.ndarray, one-hot representation of `feature` + + """ + + import numpy as np + + one_hot = np.zeros([len(feature), dim], np.int64) + one_hot[range(len(one_hot)), np.asarray(feature, np.int64).ravel()] = 1 + return one_hot + + +def get_indices_from_another( + base: "np.ndarray", + segment: "np.ndarray", + *, + already_sorted: bool = False, +) -> "np.ndarray": + """ + Get `segment` elements' indices in `base`. + + Warnings + ---------- + All elements in segment should appear in base to ensure validity. + + Parameters + ---------- + base : np.ndarray, base array. + segment : np.ndarray, segment array. + already_sorted : bool, whether `base` is already sorted. + + Returns + ------- + indices : np.ndarray, positions where elements in `segment` appear in `base` + + Examples + ------- + >>> import numpy as np + >>> base, segment = np.arange(100), np.random.permutation(100)[:10] + >>> assert np.allclose(get_indices_from_another(base, segment), segment) + + """ + + import numpy as np + + if already_sorted: + return np.searchsorted(base, segment) + base_sorted_args = np.argsort(base) + positions = np.searchsorted(base[base_sorted_args], segment) + return base_sorted_args[positions] + + +class UniqueIndices(NamedTuple): + """ + unique : np.ndarray, unique values of the given array (`arr`). + unique_cnt : np.ndarray, counts of each unique value. + sorting_indices : np.ndarray, indices which can (stably) sort the given + array by its value. + split_arr : np.ndarray, array which can split the `sorting_indices` + to make sure that. Each portion of the split + indices belong & only belong to one of the + unique values. + """ + + unique: "np.ndarray" + unique_cnt: "np.ndarray" + sorting_indices: "np.ndarray" + split_arr: "np.ndarray" + + @property + def split_indices(self) -> List["np.ndarray"]: + import numpy as np + + return np.split(self.sorting_indices, self.split_arr) + + +def get_unique_indices(arr: "np.ndarray") -> UniqueIndices: + """ + Get indices for unique values of an array. + + Parameters + ---------- + arr : np.ndarray, target array which we wish to find indices of each unique value. + + Returns + ------- + UniqueIndices + + Examples + ------- + >>> import numpy as np + >>> arr = np.array([1, 2, 3, 2, 4, 1, 0, 1], np.int64) + >>> # UniqueIndices( + >>> # unique = array([0, 1, 2, 3, 4], dtype=int64), + >>> # unique_cnt = array([1, 3, 2, 1, 1], dtype=int64), + >>> # sorting_indices = array([6, 0, 5, 7, 1, 3, 2, 4], dtype=int64), + >>> # split_arr = array([1, 4, 6, 7], dtype=int64)) + >>> # split_indices = [array([6], dtype=int64), array([0, 5, 7], dtype=int64), + >>> # array([1, 3], dtype=int64), array([2], dtype=int64), + >>> # array([4], dtype=int64)] + >>> print(get_unique_indices(arr)) + + """ + + import numpy as np + + unique, unique_inv, unique_cnt = np.unique( + arr, + return_inverse=True, + return_counts=True, + ) + sorting_indices, split_arr = ( + np.argsort(unique_inv, kind="mergesort"), + np.cumsum(unique_cnt)[:-1], + ) + return UniqueIndices(unique, unique_cnt, sorting_indices, split_arr) + + +def get_counter_from_arr(arr: "np.ndarray") -> Counter: + """ + Get `Counter` of an array. + + Parameters + ---------- + arr : np.ndarray, target array which we wish to get `Counter` from. 
+ + Returns + ------- + Counter + + Examples + ------- + >>> import numpy as np + >>> arr = np.array([1, 2, 3, 2, 4, 1, 0, 1], np.int64) + >>> # Counter({1: 3, 2: 2, 0: 1, 3: 1, 4: 1}) + >>> print(get_counter_from_arr(arr)) + + """ + + import numpy as np + + return Counter(dict(zip(*np.unique(arr, return_counts=True)))) + + +def allclose(*arrays: "np.ndarray", **kwargs: Any) -> bool: + """ + Perform `np.allclose` to `arrays` one by one. + + Parameters + ---------- + arrays : np.ndarray, target arrays. + **kwargs : keyword arguments which will be passed into `np.allclose`. + + Returns + ------- + allclose : bool + + """ + + import numpy as np + + for i, arr in enumerate(arrays[:-1]): + if not np.allclose(arr, arrays[i + 1], **kwargs): + return False + return True + + +class StrideArray: + def __init__( + self, + arr: "np.ndarray", + *, + copy: bool = False, + writable: Optional[bool] = None, + ): + self.arr = arr + self.shape = arr.shape + self.num_dim = len(self.shape) + self.strides = arr.strides + self.copy = copy + if writable is None: + writable = copy + self.writable = writable + + def __str__(self) -> str: + return self.arr.__str__() + + def __repr__(self) -> str: + return self.arr.__repr__() + + def _construct( + self, + shapes: Tuple[int, ...], + strides: Tuple[int, ...], + ) -> "np.ndarray": + from numpy.lib.stride_tricks import as_strided + + arr = self.arr.copy() if self.copy else self.arr + return as_strided( + arr, + shape=shapes, + strides=strides, + writeable=self.writable, + ) + + @staticmethod + def _get_output_dim(in_dim: int, window: int, stride: int) -> int: + return (in_dim - window) // stride + 1 + + def roll(self, window: int, *, axis: int, stride: int = 1) -> "np.ndarray": + while axis < 0: + axis += self.num_dim + target_dim = self.shape[axis] + rolled_dim = self._get_output_dim(target_dim, window, stride) + if rolled_dim <= 0: + msg = f"window ({window}) is too large for target dimension ({target_dim})" + raise ValueError(msg) + # shapes + rolled_shapes = tuple(self.shape[:axis]) + (rolled_dim, window) + if axis < self.num_dim - 1: + rolled_shapes = rolled_shapes + self.shape[axis + 1 :] + # strides + previous_strides = tuple(self.strides[:axis]) + target_stride = (self.strides[axis] * stride,) + latter_strides = tuple(self.strides[axis:]) + rolled_strides = previous_strides + target_stride + latter_strides + # construct + return self._construct(rolled_shapes, rolled_strides) + + def patch( + self, + patch_w: int, + patch_h: Optional[int] = None, + *, + h_stride: int = 1, + w_stride: int = 1, + h_axis: int = -2, + ) -> "np.ndarray": + if self.num_dim < 2: + raise ValueError("`patch` requires input with at least 2d") + while h_axis < 0: + h_axis += self.num_dim + w_axis = h_axis + 1 + if patch_h is None: + patch_h = patch_w + h_shape, w_shape = self.shape[h_axis], self.shape[w_axis] + if h_shape < patch_h: + msg = f"patch_h ({patch_h}) is too large for target dimension ({h_shape})" + raise ValueError(msg) + if w_shape < patch_w: + msg = f"patch_w ({patch_w}) is too large for target dimension ({w_shape})" + raise ValueError(msg) + # shapes + patched_h_dim = self._get_output_dim(h_shape, patch_h, h_stride) + patched_w_dim = self._get_output_dim(w_shape, patch_w, w_stride) + patched_dim: Tuple[int, ...] 
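+        # the resulting strided view has shape
+        # (*self.shape[:h_axis], patched_h_dim, patched_w_dim, patch_h, patch_w, *self.shape[w_axis + 1:]):
+        # the first two new axes index the patch grid, the last two index pixels
+        # inside each patch; `as_strided` builds this view by adjusting strides
+        # only (the base array is copied first only when `copy=True`)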
+ patched_dim = (patched_h_dim, patched_w_dim) + patched_dim = patched_dim + (patch_h, patch_w) + patched_shapes = tuple(self.shape[:h_axis]) + patched_dim + if w_axis < self.num_dim - 1: + patched_shapes = patched_shapes + self.shape[w_axis + 1 :] + # strides + arr_h_stride, arr_w_stride = self.strides[h_axis], self.strides[w_axis] + previous_strides = tuple(self.strides[:h_axis]) + target_stride: Tuple[int, ...] + target_stride = (arr_h_stride * h_stride, arr_w_stride * w_stride) + target_stride = target_stride + (arr_h_stride, arr_w_stride) + latter_strides = tuple(self.strides[w_axis + 1 :]) + patched_strides = previous_strides + target_stride + latter_strides + # construct + return self._construct(patched_shapes, patched_strides) + + def repeat(self, k: int, axis: int = -1) -> "np.ndarray": + while axis < 0: + axis += self.num_dim + target_dim = self.shape[axis] + if target_dim != 1: + raise ValueError("`repeat` can only be applied on axis with dim == 1") + # shapes + repeated_shapes = tuple(self.shape[:axis]) + (k,) + if axis < self.num_dim - 1: + repeated_shapes = repeated_shapes + self.shape[axis + 1 :] + # strides + previous_strides = tuple(self.strides[:axis]) + target_stride = (0,) + latter_strides = tuple(self.strides[axis + 1 :]) + repeated_strides = previous_strides + target_stride + latter_strides + # construct + return self._construct(repeated_shapes, repeated_strides) + + +class SharedArray: + value: "np.ndarray" + + def __init__( + self, + name: str, + dtype: Union[type, "np.dtype"], + shape: Union[List[int], Tuple[int, ...]], + *, + create: bool = True, + data: Optional["np.ndarray"] = None, + ): + import numpy as np + from multiprocessing.shared_memory import SharedMemory + + self.name = name + self.dtype = dtype + self.shape = shape + if create: + d_size = np.dtype(dtype).itemsize * np.prod(shape).item() + self._shm = SharedMemory(name, create=True, size=int(round(d_size))) + else: + if data is not None: + raise ValueError("`data` should not be provided when `create` is False") + self._shm = SharedMemory(name) + self.value = np.ndarray(shape=shape, dtype=dtype, buffer=self._shm.buf) + if data is not None: + self.value[:] = data[:] + + def close(self) -> None: + self._shm.close() + + def destroy(self) -> None: + self._shm.close() + self._shm.unlink() + + @classmethod + def from_data(cls, data: "np.ndarray") -> "SharedArray": + return cls(random_hash()[:16], data.dtype, data.shape, data=data) + + +def to_labels(logits: "np.ndarray", threshold: Optional[float] = None) -> "np.ndarray": + # binary classification + if logits.shape[-1] == 2: + logits = logits[..., [1]] - logits[..., [0]] + if logits.shape[-1] == 1: + if threshold is None: + threshold = 0.5 + logit_threshold = math.log(threshold / (1.0 - threshold)) + return (logits > logit_threshold).astype(int) + return logits.argmax(1)[..., None] + + +def get_full_logits(logits: "np.ndarray") -> "np.ndarray": + import numpy as np + + # binary classification + if logits.shape[-1] == 1: + logits = np.concatenate([-logits, logits], axis=-1) + return logits + + +def make_grid(arr: arr_type, n_row: Optional[int] = None) -> "torch.Tensor": + import torchvision + import numpy as np + + if isinstance(arr, np.ndarray): + arr = to_torch(arr) + if n_row is None: + n_row = math.ceil(math.sqrt(len(arr))) + return torchvision.utils.make_grid(arr, n_row) + + +class NpSafeSerializer: + size_file = "size.txt" + array_file = "array.npy" + + @classmethod + def save( + cls, + folder: TPath, + data: Union["np.ndarray", Callable[[], 
"np.ndarray"]], + *, + verbose: bool = True, + ) -> None: + import numpy as np + from filelock import FileLock + + folder = to_path(folder) + with FileLock(folder / "NpSafeSerializer.lock", timeout=30000): + if cls.try_load(folder, no_load=True) is None: + folder.mkdir(parents=True, exist_ok=True) + array_path = folder / cls.array_file + with timeit(f"save '{folder}'", enabled=verbose): + if not isinstance(data, np.ndarray): + data = data() + np.save(array_path, data) + with (folder / cls.size_file).open("w") as f: + f.write(str(get_file_size(array_path))) + + @classmethod + def load(cls, folder: TPath, *, mmap_mode: Optional[str] = None) -> "np.ndarray": + import numpy as np + + return np.load(to_path(folder) / cls.array_file, mmap_mode=mmap_mode) # type: ignore + + @classmethod + def try_load( + cls, + folder: TPath, + *, + mmap_mode: Optional[str] = None, + no_load: bool = False, + **kwargs: Any, + ) -> Optional["np.ndarray"]: + import numpy as np + + folder = to_path(folder) + array_path = folder / cls.array_file + if not array_path.exists(): + return None + size_path = folder / cls.size_file + if not size_path.is_file(): + return None + with (folder / cls.size_file).open("r") as f: + try: + size = int(f.read().strip()) + except ValueError: + return None + if size != get_file_size(array_path): + return None + if no_load: + return np.zeros(0) + return np.load(array_path, mmap_mode=mmap_mode, **kwargs) # type: ignore + + @classmethod + def load_with( + cls, + folder: TPath, + init_fn: Callable[[], "np.ndarray"], + *, + mmap_mode: Optional[str] = None, + no_load: bool = False, + verbose: bool = True, + **kwargs: Any, + ) -> "np.ndarray": + """ + This method uses `FileLock` to ensure that only one rank will save the array. + > It will also ensure that all ranks will load the array after it's saved. The load + procedure will even be executed on the rank that saved the array in order to make + use of the `mmap_mode` feature at the first time. 
+ """ + + load_func = lambda: cls.try_load( + folder, + mmap_mode=mmap_mode, + no_load=no_load, + **kwargs, + ) + array = load_func() + if array is None: + folder = to_path(folder) + folder.mkdir(parents=True, exist_ok=True) + cls.save(folder, init_fn, verbose=verbose) + array = load_func() + if array is None: + raise RuntimeError(f"failed to load array from '{folder}'") + return array + + @classmethod + def cleanup(cls, folder: TPath) -> None: + from filelock import FileLock + + folder = to_path(folder) + with FileLock(folder / "NpSafeSerializer.lock", timeout=30000): + (folder / cls.array_file).unlink(missing_ok=True) + (folder / cls.size_file).unlink(missing_ok=True) diff --git a/cfdraw/core/toolkit/console.py b/cfdraw/core/toolkit/console.py new file mode 100644 index 00000000..b53c048e --- /dev/null +++ b/cfdraw/core/toolkit/console.py @@ -0,0 +1,68 @@ +from typing import Any +from typing import List +from typing import Optional +from typing import TYPE_CHECKING +from functools import lru_cache + +if TYPE_CHECKING: + from rich.status import Status + from rich.console import Console + + +@lru_cache +def get_console() -> "Console": + from rich.console import Console + + return Console() + + +def log(msg: str, *args: Any, _stack_offset: int = 2, **kwargs: Any) -> None: + get_console().log(msg, *args, _stack_offset=_stack_offset, **kwargs) + + +def debug(msg: str, *args: Any, prefix: str = "", **kwargs: Any) -> None: + kwargs.setdefault("_stack_offset", 3) + log(f"[grey42]{prefix}{msg}[/grey42]", *args, **kwargs) + + +def warn(msg: str, *args: Any, prefix: str = "Warning: ", **kwargs: Any) -> None: + kwargs.setdefault("_stack_offset", 3) + log(f"[yellow]{prefix}{msg}[/yellow]", *args, **kwargs) + + +def deprecated(msg: str, *args: Any, **kwargs: Any) -> None: + kwargs.setdefault("_stack_offset", 4) + warn(msg, *args, prefix="DeprecationWarning: ", **kwargs) + + +def error(msg: str, *args: Any, prefix: str = "Error: ", **kwargs: Any) -> None: + kwargs.setdefault("_stack_offset", 3) + log(f"[red]{prefix}{msg}[/red]", *args, **kwargs) + + +def print(msg: str, *args: Any, **kwargs: Any) -> None: + get_console().print(msg, *args, **kwargs) + + +def rule(title: str, **kwargs: Any) -> None: + get_console().rule(title, **kwargs) + + +def ask( + question: str, + choices: Optional[List[str]] = None, + *, + default: Optional[str] = None, + **kwargs: Any, +) -> str: + from rich.prompt import Prompt + + kwargs = kwargs.copy() + kwargs["choices"] = choices + if default is not None: + kwargs["default"] = default + return Prompt.ask(question, **kwargs) + + +def status(msg: str, **kwargs: Any) -> "Status": + return get_console().status(msg, **kwargs) diff --git a/cfdraw/core/toolkit/constants.py b/cfdraw/core/toolkit/constants.py new file mode 100644 index 00000000..3ade28e5 --- /dev/null +++ b/cfdraw/core/toolkit/constants.py @@ -0,0 +1,2 @@ +TIME_FORMAT = "%Y-%m-%d_%H-%M-%S-%f" +WEB_ERR_CODE = 406 diff --git a/cfdraw/core/toolkit/cv.py b/cfdraw/core/toolkit/cv.py new file mode 100644 index 00000000..4832def1 --- /dev/null +++ b/cfdraw/core/toolkit/cv.py @@ -0,0 +1,284 @@ +import math +import base64 + +from io import BytesIO +from typing import Tuple +from typing import Union +from typing import Optional +from typing import NamedTuple +from typing import TYPE_CHECKING +from dataclasses import dataclass + +from .array import to_torch +from .types import TArray +from .types import arr_type +from .geometry import is_close +from .geometry import Matrix2D +from .geometry import Matrix2DProperties + +if 
TYPE_CHECKING: + from PIL import Image + from numpy import ndarray + from PIL.Image import Image as TImage + + +class ReadImageResponse(NamedTuple): + image: "ndarray" + alpha: Optional["ndarray"] + original: "TImage" + anchored: "TImage" + to_masked: Optional["TImage"] + original_size: Tuple[int, int] + anchored_size: Tuple[int, int] + + +def to_rgb(image: "TImage", color: Tuple[int, int, int] = (255, 255, 255)) -> "TImage": + from PIL import Image + + if image.mode == "CMYK": + return image.convert("RGB") + split = image.split() + if len(split) < 4: + return image.convert("RGB") + background = Image.new("RGB", image.size, color) + background.paste(image, mask=split[3]) + return background + + +def to_uint8(normalized_img: TArray) -> TArray: + import torch + import numpy as np + + if isinstance(normalized_img, np.ndarray): + return (np.clip(normalized_img * 255.0, 0.0, 255.0)).astype(np.uint8) # type: ignore + return torch.clamp(normalized_img * 255.0, 0.0, 255.0).to(torch.uint8) + + +def to_alpha_channel(image: "TImage") -> "TImage": + if image.mode == "RGBA": + return image.split()[3] + return image.convert("L") + + +def np_to_bytes(img_arr: "ndarray") -> bytes: + import numpy as np + from PIL import Image + + if img_arr.dtype != np.uint8: + img_arr = to_uint8(img_arr) + bytes_io = BytesIO() + Image.fromarray(img_arr).save(bytes_io, format="PNG") + return bytes_io.getvalue() + + +def restrict_wh(w: int, h: int, max_wh: int) -> Tuple[int, int]: + max_original_wh = max(w, h) + if max_original_wh <= max_wh: + return w, h + wh_ratio = w / h + if wh_ratio >= 1: + return max_wh, round(max_wh / wh_ratio) + return round(max_wh * wh_ratio), max_wh + + +def get_suitable_size(n: int, anchor: int) -> int: + if n <= anchor: + return anchor + mod = n % anchor + return n - mod + int(mod > 0.5 * anchor) * anchor + + +def read_image( + image: Union[str, "TImage"], + max_wh: Optional[int], + *, + anchor: Optional[int], + to_gray: bool = False, + to_mask: bool = False, + resample: "Image.Resampling" = "auto", + normalize: bool = True, + to_torch_fmt: bool = True, +) -> ReadImageResponse: + import numpy as np + from PIL import Image + + if isinstance(image, str): + image = Image.open(image) + alpha = None + original = image + if image.mode == "RGBA": + alpha = image.split()[3] + if not to_mask and not to_gray: + image = to_rgb(image) + else: + if to_mask and to_gray: + raise ValueError("`to_mask` & `to_gray` should not be True simultaneously") + if to_mask and image.mode == "RGBA": + image = alpha + else: + image = image.convert("L") + original_w, original_h = image.size + to_masked = image if to_mask else None + if max_wh is None: + w, h = original_w, original_h + else: + w, h = restrict_wh(original_w, original_h, max_wh) + if anchor is not None: + w, h = map(get_suitable_size, (w, h), (anchor, anchor)) + if w != original_w or h != original_h: + if resample == "auto": + resample = Image.Resampling.LANCZOS + image = image.resize((w, h), resample=resample) + anchored = image + anchored_size = w, h + image = np.array(image) + if normalize: + image = image.astype(np.float32) / 255.0 + if alpha is not None: + alpha = np.array(alpha)[None, None] + if normalize: + alpha = alpha.astype(np.float32) / 255.0 + if to_torch_fmt: + if to_mask or to_gray: + image = image[None, None] + else: + image = image[None].transpose(0, 3, 1, 2) + return ReadImageResponse( + image, + alpha, + original, + anchored, + to_masked, + (original_w, original_h), + anchored_size, + ) + + +def save_images(arr: arr_type, path: str, n_row: 
Optional[int] = None) -> None: + import torchvision + import numpy as np + + if isinstance(arr, np.ndarray): + arr = to_torch(arr) + if n_row is None: + n_row = math.ceil(math.sqrt(len(arr))) + torchvision.utils.save_image(arr, path, normalize=True, nrow=n_row) + + +def to_base64(image: "TImage") -> str: + buffered = BytesIO() + image.save(buffered, format="PNG") + img_str = base64.b64encode(buffered.getvalue()).decode() + return f"data:image/png;base64,{img_str}" + + +def from_base64(base64_string: str) -> "TImage": + from PIL import Image + + base64_string = base64_string.split("base64,")[1] + return Image.open(BytesIO(base64.b64decode(base64_string))) + + +@dataclass +class ImageBox: + l: int + t: int + r: int + b: int + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ImageBox): + return False + return all(map(is_close, self.tuple, other.tuple)) + + @property + def w(self) -> int: + return self.r - self.l + + @property + def h(self) -> int: + return self.b - self.t + + @property + def wh_ratio(self) -> float: + return self.w / self.h + + @property + def tuple(self) -> Tuple[int, int, int, int]: + return self.l, self.t, self.r, self.b + + @property + def matrix(self) -> Matrix2D: + return Matrix2D.from_properties( + Matrix2DProperties(x=self.l, y=self.t, w=self.w, h=self.h) + ) + + def copy(self) -> "ImageBox": + return ImageBox(*self.tuple) + + def crop(self, image: TArray) -> TArray: + return image[self.t : self.b + 1, self.l : self.r + 1] # type: ignore + + def pad( + self, + padding: int, + *, + w: Optional[int] = None, + h: Optional[int] = None, + ) -> "ImageBox": + l, t, r, b = self.tuple + l = max(0, l - padding) + r += padding + if w is not None: + r = min(r, w) + t = max(0, t - padding) + b += padding + if h is not None: + b = min(b, h) + return ImageBox(l, t, r, b) + + def to_square( + self, + *, + w: Optional[int] = None, + h: Optional[int] = None, + expand: bool = True, + ) -> "ImageBox": + l, t, r, b = self.tuple + bw, bh = r - l, b - t + diff = abs(bw - bh) + if diff == 0: + return self.copy() + if expand: + if bw > bh: + t = max(0, t - diff // 2) + b = t + bw + if h is not None: + b = min(b, h) + else: + l = max(0, l - diff // 2) + r = l + bh + if w is not None: + r = min(r, w) + else: + if bw > bh: + l += diff // 2 + r = l + bh + if w is not None: + r = min(r, w) + else: + t += diff // 2 + b = t + bw + if h is not None: + b = min(b, h) + return ImageBox(l, t, r, b) + + @classmethod + def from_mask(cls, uint8_mask: "ndarray", threshold: int = 0) -> "ImageBox": + import numpy as np + + ys, xs = np.where(uint8_mask > threshold) + ys, xs = np.where(uint8_mask) + if len(ys) == 0: + return cls(0, 0, 0, 0) + return cls(xs.min().item(), ys.min().item(), xs.max().item(), ys.max().item()) diff --git a/cfdraw/core/toolkit/data_structures.py b/cfdraw/core/toolkit/data_structures.py new file mode 100644 index 00000000..ee47fe7a --- /dev/null +++ b/cfdraw/core/toolkit/data_structures.py @@ -0,0 +1,295 @@ +import gc + +from typing import Any +from typing import Dict +from typing import List +from typing import Type +from typing import Tuple +from typing import Generic +from typing import TypeVar +from typing import Callable +from typing import Iterator +from typing import Optional +from datetime import datetime + +from . 
import console +from .misc import sort_dict_by_value +from .constants import TIME_FORMAT + + +TTypes = TypeVar("TTypes") +TBundle = TypeVar("TBundle", bound="Bundle") +TItemData = TypeVar("TItemData") +TPoolItem = TypeVar("TPoolItem", bound="IPoolItem") +PItemInit = Callable[[], TPoolItem] + + +class Item(Generic[TItemData]): + def __init__(self, key: str, data: TItemData) -> None: + self.key = key + self.data = data + + +class Bundle(Generic[TItemData]): + def __init__(self, *, no_mapping: bool = False) -> None: + """ + * use mapping is fast at the cost of doubled memory. + * for the `queue` use case, mapping is not needed because all operations + focus on the first item. + + Details + ------- + * no_mapping = False + * get : O(1) + * push : O(1) + * remove : O(1) (if not found) / O(n) + * no_mapping = True + * get : O(n) + * push : O(1) + * remove : O(n) + * `queue` (both cases, so use no_mapping = True to save memory) + * get : O(1) + * push : O(1) + * remove : O(1) + """ + + self._items: List[Item[TItemData]] = [] + self._mapping: Optional[Dict[str, Item[TItemData]]] = None if no_mapping else {} + + def __len__(self) -> int: + return len(self._items) + + def __iter__(self) -> Iterator[Item[TItemData]]: + return iter(self._items) + + def __contains__(self, key: str) -> bool: + return self.get(key) is not None + + @property + def first(self) -> Optional[Item[TItemData]]: + if self.is_empty: + return None + return self._items[0] + + @property + def last(self) -> Optional[Item[TItemData]]: + if self.is_empty: + return None + return self._items[-1] + + @property + def is_empty(self) -> bool: + return not self._items + + def get(self, key: str) -> Optional[Item[TItemData]]: + if self._mapping is not None: + return self._mapping.get(key) + for item in self._items: + if key == item.key: + return item + return None + + def get_index(self, index: int) -> Item[TItemData]: + return self._items[index] + + def push(self: TBundle, item: Item[TItemData]) -> TBundle: + if self.get(item.key) is not None: + raise ValueError(f"item '{item.key}' already exists") + self._items.append(item) + if self._mapping is not None: + self._mapping[item.key] = item + return self + + def remove(self, key: str) -> Optional[Item[TItemData]]: + if self._mapping is None: + for i, item in enumerate(self._items): + if key == item.key: + self._items.pop(i) + return item + return None + item = self._mapping.pop(key, None) # type: ignore + if item is not None: + for i, _item in enumerate(self._items): + if key == _item.key: + self._items.pop(i) + break + return item + + +class Types(Generic[TTypes]): + def __init__(self) -> None: + self._types: Dict[str, Type[TTypes]] = {} + + def __iter__(self) -> Iterator[str]: + return iter(self._types) + + def __setitem__(self, key: str, value: Type[TTypes]) -> None: + self._types[key] = value + + def make(self, key: str, *args: Any, **kwargs: Any) -> Optional[TTypes]: + t = self._types.get(key) + return None if t is None else t(*args, **kwargs) + + def items(self) -> Iterator[Tuple[str, Type[TTypes]]]: + return self._types.items() # type: ignore + + def values(self) -> Iterator[Type[TTypes]]: + return self._types.values() # type: ignore + + +class IPoolItem: + """ + Life cycle of a pool item: + + (without context) init -> collect + (with context) init -> (everytime) load -> (everytime) unload -> collect + + """ + + def load(self, **kwargs: Any) -> None: + """Will be called everytime the pool loads the item with context""" + + def unload(self) -> None: + """Will be called everytime the pool 
finishes using the item with context""" + + def collect(self) -> None: + """Will be called when the pool removes the item""" + + +class PoolItemContext: + def __init__(self, item: Any, **kwargs: Any) -> None: + self.item = item + self.kwargs = kwargs + + def __enter__(self) -> Any: + load_fn = getattr(self.item, "load", None) + if load_fn is not None: + load_fn(**self.kwargs) + return self.item + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + unload_fn = getattr(self.item, "unload", None) + if unload_fn is not None: + unload_fn() + + +class PoolItemManager(Generic[TPoolItem]): + _item: Optional[TPoolItem] + + def __init__( + self, + init_fn: PItemInit, + *, + init: bool = False, + force_keep: bool = False, + ): + self.init_fn = init_fn + self.use_time = datetime.now() + self.force_keep = force_keep + self._item = init_fn() if init or force_keep else None + + @property + def ready(self) -> bool: + return self._item is not None + + def get(self) -> TPoolItem: + self.use_time = datetime.now() + if self._item is None: + self._item = self.init_fn() + return self._item + + def use(self, **kwargs: Any) -> PoolItemContext: + self.use_time = datetime.now() + if self._item is None: + self._item = self.init_fn() + return PoolItemContext(self._item, **kwargs) + + def collect(self) -> None: + collect_fn = getattr(self._item, "collect", None) + if collect_fn is not None: + collect_fn() + del self._item + self._item = None + gc.collect() + + +class Pool(Generic[TPoolItem]): + t_manager = PoolItemManager + + pool: Dict[str, PoolItemManager[TPoolItem]] + + # set `limit` to negative values to indicate 'no limit' + def __init__(self, limit: int = -1, *, allow_duplicate: bool = False): + self.pool = {} + self.limit = limit + self.allow_duplicate = allow_duplicate + if limit == 0: + raise ValueError( + "limit should either be negative " + "(which indicates 'no limit') or be positive" + ) + + def __contains__(self, key: str) -> bool: + return key in self.pool + + @property + def activated(self) -> Dict[str, PoolItemManager[TPoolItem]]: + return {k: m for k, m in self.pool.items() if m.ready and not m.force_keep} + + def register(self, key: str, init_fn: PItemInit, **kwargs: Any) -> None: + """ + Register a new item to the pool. + + This method will create a new item manager and store it in the pool. + > `kwargs` will be passed to the item manager's constructor. + """ + + if key in self.pool: + if self.allow_duplicate: + return + raise ValueError(f"key '{key}' already exists") + init = self.limit < 0 or len(self.activated) < self.limit + manager: PoolItemManager = self.t_manager(init_fn, init=init, **kwargs) + self.pool[key] = manager + + def get(self, key: str) -> TPoolItem: + """ + Get a registered item from the pool without context. + + - If `limit` is reached, this method will try to remove the 'earliest' item. + """ + + return self._fetch(key).get() + + def use(self, key: str, **kwargs: Any) -> PoolItemContext: + """ + Use a registered item from the pool with context. + + - If `limit` is reached, this method will try to remove the 'earliest' item. + > `kwargs` will be passed to the item's `load` method, if it exists. + """ + + return self._fetch(key).use(**kwargs) + + def _fetch(self, key: str) -> PoolItemManager: + """ + Fetch the item manager from the pool. + + - If `limit` is reached, this method will try to remove the 'earliest' item. 
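+        > The 'earliest' item is the activated manager (i.e. `ready` and not
+        > `force_keep`) with the oldest `use_time`; its `collect` method is
+        > called to make room before the target manager is returned.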
+ """ + + target = self.pool.get(key) + if target is None: + raise ValueError(f"key '{key}' does not exist") + if not target.ready: + # need to remove earliest item before using the target + use_times = {k: m.use_time for k, m in self.activated.items()} + earliest_key = list(sort_dict_by_value(use_times).keys())[0] + earliest = self.pool[earliest_key] + earliest.collect() + get_time_str = lambda m: datetime.strftime(m.use_time, TIME_FORMAT) + console.log( + f"'{earliest_key}' (last updated: {get_time_str(earliest)}) is collected " + f"to make room for '{key}' (last updated: {get_time_str(target)})" + ) + return target diff --git a/cfdraw/core/toolkit/geometry.py b/cfdraw/core/toolkit/geometry.py new file mode 100644 index 00000000..111bc2fb --- /dev/null +++ b/cfdraw/core/toolkit/geometry.py @@ -0,0 +1,633 @@ +import math + +from enum import Enum +from typing import List +from typing import Tuple +from typing import Union +from typing import TypeVar +from typing import Optional +from typing import TYPE_CHECKING +from pydantic import BaseModel +from dataclasses import dataclass + +if TYPE_CHECKING: + import numpy as np + + +class PivotType(str, Enum): + LT = "lt" + TOP = "top" + RT = "rt" + LEFT = "left" + CENTER = "center" + RIGHT = "right" + LB = "lb" + BOTTOM = "bottom" + RB = "rb" + + +# start from left-top, clockwise, plus the center point +outer_pivots: List[PivotType] = [ + PivotType.LT, + PivotType.TOP, + PivotType.RT, + PivotType.RIGHT, + PivotType.RB, + PivotType.BOTTOM, + PivotType.LB, + PivotType.LEFT, +] +all_pivots: List[PivotType] = outer_pivots + [PivotType.CENTER] +# start from left-top, clockwise, four corner points +corner_pivots: List[PivotType] = [ + PivotType.LT, + PivotType.RT, + PivotType.RB, + PivotType.LB, +] +edge_pivots: List[PivotType] = [ + PivotType.TOP, + PivotType.RIGHT, + PivotType.BOTTOM, + PivotType.LEFT, +] +mid_pivots: List[PivotType] = edge_pivots + [PivotType.CENTER] + + +def is_close(a: float, b: float, *, atol: float = 1.0e-6, rtol: float = 1.0e-4) -> bool: + diff = abs(a - b) + a = max(a, 1.0e-8) + b = max(b, 1.0e-8) + if diff >= atol or abs(diff / a) >= rtol or abs(diff / b) >= rtol: + return False + return True + + +@dataclass +class Point: + x: float + y: float + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Point): + return False + return all(map(is_close, self.tuple, other.tuple)) + + def __add__(self, other: "Point") -> "Point": + return Point(self.x + other.x, self.y + other.y) + + def __sub__(self, other: "Point") -> "Point": + return Point(self.x - other.x, self.y - other.y) + + def __rmatmul__(self, other: "Matrix2D") -> "Point": + x, y = self.x, self.y + a, b, c, d, e, f = other.tuple + return Point(x=a * x + c * y + e, y=b * x + d * y + f) + + @property + def tuple(self) -> Tuple[float, float]: + return self.x, self.y + + @property + def theta(self) -> float: + return math.atan2(self.y, self.x) + + def rotate(self, theta: float) -> "Point": + l = math.sqrt(self.x**2 + self.y**2) + theta += self.theta + return Point(l * math.cos(theta), l * math.sin(theta)) + + def inside(self, box: "Matrix2D") -> bool: + x, y = (box.inverse @ self).tuple + return 0 <= x <= 1 and 0 <= y <= 1 + + @classmethod + def origin(cls) -> "Point": + return cls(x=0.0, y=0.0) + + +@dataclass +class Line: + start: Point + end: Point + + def intersect(self, other: "Line", extendable: bool = False) -> Optional[Point]: + x1, y1 = self.start.tuple + x2, y2 = self.end.tuple + x3, y3 = other.start.tuple + x4, y4 = other.end.tuple + x13 = x1 - x3 + 
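+        # standard parametric segment-segment intersection: the hit point is
+        #     start + uA * (end - start)   on this line, and
+        #     start + uB * (end - start)   on `other`;
+        # it only counts when both uA and uB lie in [0, 1], unless
+        # `extendable` is True (both segments are treated as infinite lines)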
x21 = x2 - x1 + x43 = x4 - x3 + y13 = y1 - y3 + y21 = y2 - y1 + y43 = y4 - y3 + denom = y43 * x21 - x43 * y21 + if is_close(denom, 0): + return None + uA = (x43 * y13 - y43 * x13) / denom + uB = (x21 * y13 - y21 * x13) / denom + if extendable or (0 <= uA <= 1 and 0 <= uB <= 1): + return Point(x1 + uA * (x2 - x1), y1 + uA * (y2 - y1)) + return None + + def distance_to(self, target_line: "Line") -> float: + x1, y1 = self.start.tuple + x2, y2 = self.end.tuple + x4, y4 = target_line.end.tuple + dy = y1 - y2 or 10e-10 + k = (x1 - x2) / dy + d = (k * (y2 - y4) + x4 - x2) / math.sqrt(1 + k**2) + return d + + +class Matrix2DProperties(BaseModel): + x: float + y: float + w: float + h: float + theta: float = 0.0 + skew_x: float = 0.0 + skew_y: float = 0.0 + + +TMatMul = TypeVar("TMatMul", bound=Union[Point, "Matrix2D"]) + + +class ExpandType(str, Enum): + IOU = "iou" + FIX_W = "fix_w" + FIX_H = "fix_h" + + +@dataclass +class Box: + left: float + top: float + right: float + bottom: float + + +class Matrix2D(BaseModel): + a: float + b: float + c: float + d: float + e: float + f: float + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Matrix2D): + return False + return all(map(is_close, self.tuple, other.tuple)) + + def __matmul__(self, other: TMatMul) -> TMatMul: + if isinstance(other, Point): + return other.__rmatmul__(self) # type: ignore + if isinstance(other, Matrix2D): + a1, b1, c1, d1, e1, f1 = self.tuple + a2, b2, c2, d2, e2, f2 = other.tuple + return Matrix2D( # type: ignore + a=a1 * a2 + c1 * b2, + b=b1 * a2 + d1 * b2, + c=a1 * c2 + c1 * d2, + d=b1 * c2 + d1 * d2, + e=a1 * e2 + c1 * f2 + e1, + f=b1 * e2 + d1 * f2 + f1, + ) + msg = f"unsupported operand type(s) for @: 'Matrix2D' and '{type(other)}'" + raise TypeError(msg) + + @property + def tuple(self) -> Tuple[float, float, float, float, float, float]: + return self.a, self.b, self.c, self.d, self.e, self.f + + @property + def x(self) -> float: + return self.e + + @property + def y(self) -> float: + return self.f + + @property + def position(self) -> Point: + return Point(self.e, self.f) + + @property + def w(self) -> float: + return math.sqrt(self.a**2 + self.b**2) + + @property + def h(self) -> float: + return (self.a * self.d - self.b * self.c) / max(self.w, 1.0e-12) + + @property + def wh(self) -> Tuple[float, float]: + w = self.w + h = (self.a * self.d - self.b * self.c) / max(w, 1.0e-12) + return w, h + + @property + def abs_wh(self) -> Tuple[float, float]: + w, h = self.wh + return w, abs(h) + + @property + def wh_ratio(self) -> float: + w, h = self.wh + h_sign = math.copysign(1.0, h) + return w / max(abs(h), 1.0e-12) * h_sign + + @property + def abs_wh_ratio(self) -> float: + w, h = self.abs_wh + return w / max(h, 1.0e-12) + + @property + def area(self) -> float: + w, h = self.abs_wh + return w * h + + @property + def theta(self) -> float: + return -math.atan2(self.b, self.a) + + @property + def shear(self) -> float: + a, b, c, d = self.a, self.b, self.c, self.d + return math.atan2(a * c + b * d, a**2 + b**2) + + @property + def translation(self) -> Point: + return Point(self.e, self.f) + + @property + def determinant(self) -> float: + return self.a * self.d - self.b * self.c + + @property + def matrix(self) -> "np.ndarray": + import numpy as np + + return np.array([[self.a, self.c, self.e], [self.b, self.d, self.f]]) + + @property + def inverse(self) -> "Matrix2D": + a, b, c, d, e, f = self.tuple + ad = a * d + bc = b * c + return Matrix2D( + a=d / (ad - bc), + b=b / (bc - ad), + c=c / (bc - ad), + d=a / (ad - 
bc), + e=(d * e - c * f) / (bc - ad), + f=(b * e - a * f) / (ad - bc), + ) + + @property + def lt(self) -> Point: + return Point(self.e, self.f) + + @property + def top(self) -> Point: + return Point(0.5 * self.a + self.e, 0.5 * self.b + self.f) + + @property + def rt(self) -> Point: + return Point(self.a + self.e, self.b + self.f) + + @property + def right(self) -> Point: + return self @ Point(1.0, 0.5) + + @property + def rb(self) -> Point: + return self @ Point(1.0, 1.0) + + @property + def bottom(self) -> Point: + return self @ Point(0.5, 1.0) + + @property + def lb(self) -> Point: + return Point(self.c + self.e, self.d + self.f) + + @property + def left(self) -> Point: + return Point(0.5 * self.c + self.e, 0.5 * self.d + self.f) + + @property + def center(self) -> Point: + return self @ Point(0.5, 0.5) + + def pivot(self, pivot: PivotType) -> Point: + return getattr(self, pivot.value) + + # lt -> rt -> rb -> lb + @property + def corner_points(self) -> List[Point]: + return [self.pivot(pivot) for pivot in corner_pivots] + + # top -> right -> bottom -> left -> center + @property + def mid_points(self) -> List[Point]: + return [self.pivot(pivot) for pivot in mid_pivots] + + # lt -> top -> rt -> right -> rb -> bottom -> lb -> left -> center + @property + def all_points(self) -> List[Point]: + return [self.pivot(pivot) for pivot in all_pivots] + + # top -> right -> bottom -> left + @property + def edges(self) -> List[Line]: + corners = self.corner_points + return [Line(corner, corners[(i + 1) % 4]) for i, corner in enumerate(corners)] + + @property + def outer_most(self) -> Box: + import numpy as np + + corner_points = self.corner_points + xs = np.array([point.x for point in corner_points]) + ys = np.array([point.y for point in corner_points]) + left, right = xs.min().item(), xs.max().item() + top, bottom = ys.min().item(), ys.max().item() + return Box(left, top, right, bottom) + + @property + def bounding(self) -> "Matrix2D": + box = self.outer_most + return Matrix2D.from_properties( + Matrix2DProperties( + x=box.left, + y=box.top, + w=box.right - box.left, + h=box.bottom - box.top, + ) + ) + + @property + def css_property(self) -> str: + return f"matrix({self.a},{self.b},{self.c},{self.d},{self.e},{self.f})" + + @property + def no_move(self) -> "Matrix2D": + return Matrix2D(a=self.a, b=self.b, c=self.c, d=self.d, e=0.0, f=0.0) + + @property + def no_skew(self) -> "Matrix2D": + return self @ Matrix2D.skew_matrix(-self.shear, 0.0, Point.origin()) + + @property + def no_scale(self) -> "Matrix2D": + a, b, c, d, e, f = self.tuple + w, h = self.wh + return Matrix2D(a=a / w, b=b / w, c=c / h, d=d / h, e=e, f=f) + + @property + def no_scale_but_flip(self) -> "Matrix2D": + a, b, c, d, e, f = self.tuple + w, h = self.abs_wh + return Matrix2D(a=a / w, b=b / w, c=c / h, d=d / h, e=e, f=f) + + @property + def no_rotation(self) -> "Matrix2D": + return self.rotate(-self.theta, self.translation) + + @property + def no_move_scale_but_flip(self) -> "Matrix2D": + return self.no_scale_but_flip.no_move + + def scale( + self, + scale: float, + scale_y: Optional[float] = None, + center: Optional[Point] = None, + ) -> "Matrix2D": + if scale_y is None: + scale_y = scale + if center is None: + return Matrix2D( + a=self.a * scale, + b=self.b * scale, + c=self.c * scale_y, + d=self.d * scale_y, + e=self.e, + f=self.f, + ) + return Matrix2D.scale_matrix(scale, scale_y, center=center) @ self + + def scale_to( + self, + scale: float, + scale_y: Optional[float] = None, + center: Optional[Point] = None, + ) -> 
"Matrix2D": + if scale_y is None: + scale_y = scale + w, h = self.wh + return self.scale(scale / w, scale_y / h, center=center) + + def flip( + self, + flip_x: bool, + flip_y: bool, + center: Optional[Point] = None, + ) -> "Matrix2D": + return Matrix2D.flip_matrix(flip_x, flip_y, center) @ self + + def rotate(self, theta: float, center: Optional[Point] = None) -> "Matrix2D": + if math.isclose(theta, 0.0): + return self.model_copy() + return Matrix2D.rotation_matrix(theta, center) @ self + + def rotate_to(self, theta: float, center: Optional[Point] = None) -> "Matrix2D": + return self.rotate(theta - self.theta, center) + + def move(self, point: Point) -> "Matrix2D": + a, b, c, d, e, f = self.tuple + return Matrix2D(a=a, b=b, c=c, d=d, e=point.x + e, f=point.y + f) + + def move_to(self, point: Point) -> "Matrix2D": + a, b, c, d, _, _ = self.tuple + return Matrix2D(a=a, b=b, c=c, d=d, e=point.x, f=point.y) + + def set_w(self, w: float) -> "Matrix2D": + properties = self.decompose() + properties.w = w + return Matrix2D.from_properties(properties) + + def set_h(self, h: float) -> "Matrix2D": + properties = self.decompose() + properties.h = h + return Matrix2D.from_properties(properties) + + def set_wh(self, w: float, h: float) -> "Matrix2D": + properties = self.decompose() + properties.w = w + properties.h = h + return Matrix2D.from_properties(properties) + + def set_wh_ratio( + self, + wh_ratio: float, + *, + type: ExpandType = ExpandType.IOU, + pivot: PivotType = PivotType.CENTER, + ) -> "Matrix2D": + o_pivot = self.pivot(pivot) + w, h = self.wh + abs_h = abs(h) + h_sign = math.copysign(1.0, h) + if type == "fix_w": + new_w = w + new_h = h_sign * w / wh_ratio + elif type == "fix_h": + new_w = abs_h * wh_ratio + new_h = h + else: + area = w * abs_h + new_w = math.sqrt(area * wh_ratio) + new_h = h_sign * area / new_w + bbox = self.set_wh(new_w, new_h) + delta = o_pivot - bbox.pivot(pivot) + return bbox.move(delta) + + def decompose(self) -> Matrix2DProperties: + w, h = self.wh + a, b, c, d, e, f = self.tuple + return Matrix2DProperties( + x=e, + y=f, + w=w, + h=h, + theta=self.theta, + skew_x=math.atan2(a * c + b * d, w**2), + ) + + @classmethod + def skew_matrix( + cls, + skew_x: float, + skew_y: float, + center: Optional[Point] = None, + ) -> "Matrix2D": + center = center or Point.origin() + tx = math.tan(skew_x) + ty = math.tan(skew_y) + return cls(a=1, b=ty, c=tx, d=1, e=-tx * center.y, f=-ty * center.x) + + @classmethod + def scale_matrix( + cls, + w: float, + h: float, + center: Optional[Point] = None, + ) -> "Matrix2D": + center = center or Point.origin() + return cls(a=w, b=0, c=0, d=h, e=center.x * (1 - w), f=center.y * (1 - h)) + + @classmethod + def rotation_matrix( + cls, + theta: float, + center: Optional[Point] = None, + ) -> "Matrix2D": + center = center or Point.origin() + sin = math.sin(theta) + cos = math.cos(theta) + return cls( + a=cos, + b=-sin, + c=sin, + d=cos, + e=(1.0 - cos) * center.x - center.y * sin, + f=(1.0 - cos) * center.y + center.x * sin, + ) + + @classmethod + def move_matrix(cls, x: float, y: float) -> "Matrix2D": + return cls(a=1, b=0, c=0, d=1, e=x, f=y) + + @classmethod + def flip_matrix( + self, + flip_x: bool, + flip_y: bool, + center: Optional[Point] = None, + ) -> "Matrix2D": + fx = -1 if flip_x else 1 + fy = -1 if flip_y else 1 + return Matrix2D.scale_matrix(fx, fy, center) + + @classmethod + def identical(cls) -> "Matrix2D": + return cls(a=1, b=0, c=0, d=1, e=0, f=0) + + @classmethod + def from_properties(cls, properties: Matrix2DProperties) -> 
"Matrix2D": + return ( + cls.move_matrix(properties.x, properties.y) + @ cls.rotation_matrix(properties.theta) + @ cls.scale_matrix(properties.w, properties.h) + @ cls.skew_matrix(properties.skew_x, properties.skew_y) + ) + + @classmethod + def from_css_property(cls, css_property: str) -> "Matrix2D": + css_property = css_property.replace("matrix(", "").replace(")", "") + a, b, c, d, e, f = [float(x.strip()) for x in css_property.split(",")] + return cls(a=a, b=b, c=c, d=d, e=e, f=f) + + @classmethod + def get_bounding_of(cls, bboxes: List["Matrix2D"]) -> "Matrix2D": + if not bboxes: + return cls.identical() + boxes = [bbox.outer_most for bbox in bboxes] + lx = min(box.left for box in boxes) + rx = max(box.right for box in boxes) + ty = min(box.top for box in boxes) + by = max(box.bottom for box in boxes) + return Matrix2D.from_properties( + Matrix2DProperties(x=lx, y=ty, w=rx - lx, h=by - ty) + ) + + +class HitTest: + @staticmethod + def line_line(a: Line, b: Line) -> bool: + return a.intersect(b) is not None + + @staticmethod + def line_box(a: Line, b: Matrix2D) -> bool: + edges = b.edges + for edge in edges: + if HitTest.line_line(a, edge): + return True + return False + + @staticmethod + def box_box(a: Matrix2D, b: Matrix2D) -> bool: + b_edges = b.edges + for b_edge in b_edges: + if HitTest.line_box(b_edge, a): + return True + if a.position.inside(b): + return True + if b.position.inside(a): + return True + return False + + +__all__ = [ + "PivotType", + "ExpandType", + "Point", + "Line", + "Box", + "Matrix2D", + "HitTest", +] diff --git a/cfdraw/core/toolkit/misc.py b/cfdraw/core/toolkit/misc.py new file mode 100644 index 00000000..19ef5202 --- /dev/null +++ b/cfdraw/core/toolkit/misc.py @@ -0,0 +1,1497 @@ +import os +import sys +import json +import math +import time +import random +import shutil +import decimal +import inspect +import hashlib +import operator +import unicodedata + +from abc import abstractmethod +from abc import ABC +from abc import ABCMeta +from typing import Any +from typing import Set +from typing import Dict +from typing import List +from typing import Type +from typing import Tuple +from typing import Union +from typing import Generic +from typing import TypeVar +from typing import Callable +from typing import Iterable +from typing import Optional +from typing import Protocol +from typing import Coroutine +from typing import NamedTuple +from typing import TYPE_CHECKING +from typing import ContextManager +from pathlib import Path +from argparse import Namespace +from datetime import datetime +from datetime import timedelta +from functools import reduce +from collections import OrderedDict +from dataclasses import asdict +from dataclasses import fields +from dataclasses import dataclass +from dataclasses import is_dataclass +from dataclasses import Field + +from . import console +from .types import TPath +from .types import TConfig +from .types import arr_type +from .types import np_dict_type +from .constants import TIME_FORMAT + +if TYPE_CHECKING: + from accelerate import InitProcessGroupKwargs + + +# torch distributed utils + + +@dataclass +class DDPInfo: + """ + A dataclass for storing Distributed Data Parallel (DDP) information. + + Attributes + ---------- + rank : int + The rank of the current process in the DDP group. + world_size : int + The total number of processes in the DDP group. + local_rank : int + The rank of the current process within its machine. 
+ + """ + + rank: int + world_size: int + local_rank: int + + +def get_ddp_info() -> Optional[DDPInfo]: + """ + Get DDP information from the environment variables. + + Returns + ------- + Optional[DDPInfo] + The DDP information if the relevant environment variables are set, otherwise None. + + Examples + -------- + >>> get_ddp_info() + + """ + + if "RANK" in os.environ and "WORLD_SIZE" in os.environ: + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + local_rank = int(os.environ["LOCAL_RANK"]) + return DDPInfo(rank, world_size, local_rank) + return None + + +def is_ddp() -> bool: + return get_ddp_info() is not None + + +def is_rank_0() -> bool: + """ + Check if the rank is 0. + + Returns + ------- + bool + True if the rank is 0 or DDP is not being used, False otherwise. + + Examples + -------- + >>> is_rank_0() + True + + """ + + ddp_info = get_ddp_info() + return ddp_info is None or ddp_info.rank == 0 + + +def is_local_rank_0() -> bool: + """ + Check if the local rank is 0. + + Returns + ------- + bool + True if the local rank is 0 or DDP is not being used, False otherwise. + + Examples + -------- + >>> is_local_rank_0() + True + + """ + + ddp_info = get_ddp_info() + return ddp_info is None or ddp_info.local_rank == 0 + + +def get_world_size() -> int: + """ + Get the world size. + + Returns + ------- + int + The world size if DDP is being used, otherwise 1. + + Examples + -------- + >>> get_world_size() + 1 + + """ + + ddp_info = get_ddp_info() + return 1 if ddp_info is None else ddp_info.world_size + + +def is_dist_initialized() -> bool: + import torch.distributed as dist + + return dist.is_initialized() + + +def is_fsdp() -> bool: + from accelerate import PartialState + from accelerate import DistributedType + + if not is_dist_initialized(): + return False + return PartialState().distributed_type == DistributedType.FSDP + + +def init_process_group(*, cpu: bool, handler: "InitProcessGroupKwargs") -> None: + from accelerate import PartialState + + PartialState(cpu, **handler.to_kwargs()) + + +def wait_for_everyone() -> None: + from accelerate.utils import wait_for_everyone as accelerate_wait + + if is_dist_initialized(): + accelerate_wait() + + +# util functions + + +T = TypeVar("T") +FAny = TypeVar("FAny", bound=Callable[..., Any]) +FNone = TypeVar("FNone", bound=Callable[..., None]) +T_co = TypeVar("T_co", covariant=True) +TDict = TypeVar("TDict", bound="dict") +TRetryResponse = TypeVar("TRetryResponse") +TFutureResponse = TypeVar("TFutureResponse") + + +class Fn(Protocol[T_co]): + def __call__(self, *args: Any, **kwargs: Any) -> T_co: + pass + + +def walk( + root: str, + hierarchy_callback: Callable[[List[str], str], None], + filter_extensions: Optional[Set[str]] = None, +) -> None: + from tqdm import tqdm + + walked = list(os.walk(root)) + for folder, _, files in tqdm(walked, desc="folders", position=0, mininterval=1): + for file in tqdm(files, desc="files", position=1, leave=False, mininterval=1): + if filter_extensions is not None: + if not any(file.endswith(ext) for ext in filter_extensions): + continue + hierarchy = folder.split(os.path.sep) + [file] + hierarchy_callback(hierarchy, os.path.join(folder, file)) + + +def parse_config(config: TConfig) -> Dict[str, Any]: + if config is None: + return {} + if isinstance(config, (str, Path)): + with open(config, "r") as f: + return json.load(f) + return shallow_copy_dict(config) + + +def check_requires(fn: Any, name: str, strict: bool = True) -> bool: + if isinstance(fn, type): + fn = fn.__init__ # type: 
ignore + signature = inspect.signature(fn) + for k, param in signature.parameters.items(): + if not strict and param.kind is inspect.Parameter.VAR_KEYWORD: + return True + if k == name: + if param.kind is inspect.Parameter.VAR_POSITIONAL: + return False + return True + return False + + +def get_requirements(fn: Any) -> List[str]: + remove_first = False + if isinstance(fn, type): + fn = fn.__init__ # type: ignore + remove_first = True # remove `self` + requirements = [] + signature = inspect.signature(fn) + for k, param in signature.parameters.items(): + if param.kind is inspect.Parameter.VAR_KEYWORD: + continue + if param.kind is inspect.Parameter.VAR_POSITIONAL: + continue + if param.default is not inspect.Parameter.empty: + continue + requirements.append(k) + if remove_first: + requirements = requirements[1:] + return requirements + + +def filter_kw( + fn: Callable, + kwargs: Dict[str, Any], + *, + strict: bool = False, +) -> Dict[str, Any]: + kw = {} + for k, v in kwargs.items(): + if check_requires(fn, k, strict): + kw[k] = v + return kw + + +def safe_execute(fn: Fn[T], kw: Dict[str, Any], *, strict: bool = False) -> T: + return fn(**filter_kw(fn, kw, strict=strict)) + + +def safe_instantiate(cls: Type[T], kw: Dict[str, Any], *, strict: bool = False) -> T: + return cls(**filter_kw(cls, kw, strict=strict)) + + +def get_num_positional_args(fn: Callable) -> Union[int, float]: + signature = inspect.signature(fn) + counter = 0 + for param in signature.parameters.values(): + if param.kind is inspect.Parameter.VAR_POSITIONAL: + return math.inf + if param.kind is inspect.Parameter.POSITIONAL_ONLY: + counter += 1 + elif param.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD: + counter += 1 + return counter + + +def prepare_workspace_from( + workspace: str, + *, + timeout: timedelta = timedelta(30), + make: bool = True, +) -> str: + current_time = datetime.now() + if os.path.isdir(workspace): + for stuff in os.listdir(workspace): + if not os.path.isdir(os.path.join(workspace, stuff)): + continue + try: + stuff_time = datetime.strptime(stuff, TIME_FORMAT) + stuff_delta = current_time - stuff_time + if stuff_delta > timeout: + console.warn(f"{stuff} will be removed (already {stuff_delta} ago)") + shutil.rmtree(os.path.join(workspace, stuff)) + except: + pass + workspace = os.path.join(workspace, current_time.strftime(TIME_FORMAT)) + if make: + os.makedirs(workspace) + return workspace + + +def get_sub_workspaces(root: TPath) -> List[Path]: + root = to_path(root) + if not root.is_dir(): + return [] + all_workspaces = [] + for stuff in root.iterdir(): + if not stuff.is_dir(): + continue + try: + datetime.strptime(stuff.name, TIME_FORMAT) + all_workspaces.append(stuff) + except: + pass + return all_workspaces + + +def get_sorted_workspaces(root: TPath) -> List[Path]: + sub_workspaces = get_sub_workspaces(root) + return sorted(sub_workspaces, key=lambda x: datetime.strptime(x.name, TIME_FORMAT)) + + +def get_latest_workspace(root: TPath) -> Optional[Path]: + sorted_workspaces = get_sorted_workspaces(root) + if not sorted_workspaces: + return None + return sorted_workspaces[-1] + + +def sort_dict_by_value(d: Dict[Any, Any], *, reverse: bool = False) -> OrderedDict: + sorted_items = sorted([(v, k) for k, v in d.items()], reverse=reverse) + return OrderedDict({item[1]: item[0] for item in sorted_items}) + + +def parse_args(args: Any) -> Namespace: + return Namespace(**{k: None if not v else v for k, v in args.__dict__.items()}) + + +def get_arguments( + *, + num_back: int = 0, + pop_class_attributes: bool 
= True, +) -> Dict[str, Any]: + frame = inspect.currentframe() + if frame is None: + raise ValueError("`get_arguments` should be called in a frame") + frame = frame.f_back + for i in range(num_back): + if frame is None: + raise ValueError(f"`get_arguments` failed at {i}th frame backword") + frame = frame.f_back + if frame is None: + raise ValueError(f"`get_arguments` failed at {num_back}th frame backword") + arguments = inspect.getargvalues(frame)[-1] + if pop_class_attributes: + arguments.pop("self", None) + arguments.pop("__class__", None) + return arguments + + +def timestamp(*, simplify: bool = False, ensure_different: bool = False) -> str: + """ + Return current timestamp. + + Parameters + ---------- + simplify : bool. If True, format will be simplified to 'year-month-day'. + ensure_different : bool. If True, format will include millisecond. + + Returns + ------- + timestamp : str + + """ + + now = datetime.now() + if simplify: + return now.strftime(TIME_FORMAT[:8]) + if ensure_different: + time.sleep(1.0e-6) # ensure different by sleep 1 tick + return now.strftime(TIME_FORMAT) + return now.strftime(TIME_FORMAT[:-3]) + + +def prod(iterable: Iterable) -> float: + """Return cumulative production of an iterable.""" + + return float(reduce(operator.mul, iterable, 1)) + + +def hash_code(code: str) -> str: + """Return hash code for a string.""" + + return hashlib.md5(code.encode()).hexdigest() + + +def hash_dict(d: Dict[str, Any], *, static_keys: bool = False) -> str: + """ + Return a consistent hash code for an arbitrary dict. + * `static_keys` is used to control whether to include dict keys in the hash code. + Default is False, which means the hash code will be consistent even if the dict + has different keys but same values. + """ + + def _hash(_d: Dict[str, Any]) -> str: + sorted_keys = sorted(_d) + hashes = [] + for k in sorted_keys: + v = _d[k] + if not static_keys: + hashes.append(str(k)) + if isinstance(v, dict): + hashes.append(_hash(v)) + elif isinstance(v, set): + hashes.append(hash_code(str(sorted(v)))) + else: + hashes.append(hash_code(str(v))) + return hash_code("".join(hashes)) + + return _hash(d) + + +def hash_str_dict( + d: Dict[str, str], + *, + key_order: Optional[List[str]] = None, + static_keys: bool = False, +) -> str: + """A specific fast path for `hash_dict` when all values are strings.""" + + if key_order is None: + key_order = sorted(d) + if static_keys: + return hash_code("$?^^?$".join([d[k] for k in key_order])) + return hash_code("$?^^?$".join([f"{k}$?%%?${d[k]}" for k in key_order])) + + +def random_hash() -> str: + return hash_code(str(random.random())) + + +def prefix_dict(d: TDict, prefix: str) -> TDict: + """Prefix every key in dict `d` with `prefix`.""" + + return {f"{prefix}_{k}": v for k, v in d.items()} # type: ignore + + +def shallow_copy_dict(d: TDict) -> TDict: + def _copy(d_: T) -> T: + if isinstance(d_, list): + return [_copy(item) for item in d_] # type: ignore + if isinstance(d_, dict): + return {k: _copy(v) for k, v in d_.items()} # type: ignore + return d_ + + return _copy(d) + + +def update_dict(src_dict: dict, tgt_dict: dict) -> dict: + """ + Update tgt_dict with src_dict. + * Notice that changes will happen only on keys which src_dict holds. 
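+    * Nested dicts are merged recursively; non-dict values from `src_dict`
+      simply overwrite the corresponding values in `tgt_dict`.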
+ + Parameters + ---------- + src_dict : dict + tgt_dict : dict + + Returns + ------- + tgt_dict : dict + + """ + + for k, v in src_dict.items(): + tgt_v = tgt_dict.get(k) + if tgt_v is None: + tgt_dict[k] = v + elif not isinstance(v, dict): + tgt_dict[k] = v + else: + update_dict(v, tgt_v) + return tgt_dict + + +def fix_float_to_length(num: float, length: int) -> str: + """Change a float number to string format with fixed length.""" + + ctx = decimal.Context() + ctx.prec = 2 * length + d = ctx.create_decimal(repr(num)) + str_num = format(d, "f").lower() + if str_num == "nan": + return f"{str_num:^{length}s}" + idx = str_num.find(".") + if idx == -1: + diff = length - len(str_num) + if diff <= 0: + return str_num + if diff == 1: + return f"{str_num}." + return f"{str_num}.{'0' * (diff - 1)}" + length = max(length, idx) + return str_num[:length].ljust(length, "0") + + +def truncate_string_to_length(string: str, length: int) -> str: + """Truncate a string to make sure its length not exceeding a given length.""" + + if len(string) <= length: + return string + half_length = int(0.5 * length) - 1 + head = string[:half_length] + tail = string[-half_length:] + return f"{head}{'.' * (length - 2 * half_length)}{tail}" + + +def grouped(iterable: Iterable, n: int, *, keep_tail: bool = False) -> List[tuple]: + """Group an iterable every `n` elements.""" + + if not keep_tail: + return list(zip(*[iter(iterable)] * n)) + with batch_manager(iterable, batch_size=n, max_batch_size=n) as manager: + return [tuple(batch) for batch in manager] + + +def grouped_into(iterable: Iterable, n: int) -> List[tuple]: + """Group an iterable into `n` groups.""" + + elements = list(iterable) + num_elements = len(elements) + num_elem_per_group = int(math.ceil(num_elements / n)) + results: List[tuple] = [] + split_idx = num_elements + n - n * num_elem_per_group + start = 0 + for _ in range(split_idx): + end = start + num_elem_per_group + results.append(tuple(elements[start:end])) + start = end + for _ in range(split_idx, n): + end = start + num_elem_per_group - 1 + results.append(tuple(elements[start:end])) + start = end + return results + + +def is_numeric(s: Any) -> bool: + """Check whether `s` is a number.""" + + try: + s = float(s) + return True + except (TypeError, ValueError): + try: + unicodedata.numeric(s) + return True + except (TypeError, ValueError): + return False + + +def register_core( + name: str, + global_dict: Dict[str, T], + *, + allow_duplicate: bool = False, + before_register: Optional[Callable] = None, + after_register: Optional[Callable] = None, +) -> Callable[[T], T]: + def _register(cls: T) -> T: + if before_register is not None: + before_register(cls) + registered = global_dict.get(name) + if registered is not None and not allow_duplicate: + console.warn( + f"'{name}' has already registered " + f"in the given global dict ({global_dict})" + ) + return cls + global_dict[name] = cls + if after_register is not None: + after_register(cls) + return cls + + return _register + + +def get_err_msg(err: Exception) -> str: + return " | ".join(map(repr, sys.exc_info()[:2] + (str(err),))) + + +async def retry( + fn: Callable[[], Coroutine[None, None, TRetryResponse]], + num_retry: Optional[int] = None, + *, + health_check: Optional[Callable[[TRetryResponse], bool]] = None, + error_verbose_fn: Optional[Callable[[TRetryResponse], None]] = None, +) -> TRetryResponse: + counter = 0 + if num_retry is None: + num_retry = 1 + while counter < num_retry: + try: + res = await fn() + if health_check is None or 
health_check(res): + if counter > 0: + console.log(f"succeeded after {counter} retries") + return res + if error_verbose_fn is not None: + error_verbose_fn(res) + else: + raise ValueError("response did not pass health check") + except Exception as e: + console.warn(f"{e}, retrying ({counter + 1})") + finally: + counter += 1 + raise ValueError(f"failed after {num_retry} retries") + + +async def offload(future: Coroutine[Any, Any, TFutureResponse]) -> TFutureResponse: + import asyncio + from concurrent.futures import ThreadPoolExecutor + + loop = asyncio.get_event_loop() + with ThreadPoolExecutor() as executor: + return await loop.run_in_executor( + executor, + lambda new_loop, f: new_loop.run_until_complete(f), + asyncio.new_event_loop(), + future, + ) + + +def compress(absolute_folder: TPath, remove_original: bool = True) -> None: + shutil.make_archive(str(absolute_folder), "zip", absolute_folder) + if remove_original: + shutil.rmtree(absolute_folder) + + +def to_path(path: TPath) -> Path: + if isinstance(path, Path): + return path + return Path(path) + + +class FileInfo(NamedTuple): + """ + Represents information about a (remote) file, often generated by + `check_available` / `get_file_info`. + + Attributes + ---------- + sha : str + The sha of the (remote) file. + st_size : int + The size of the (remote) file in bytes. + download_url : Optional[str] + The download url of the (remote) file, if available. + + """ + + sha: str + st_size: int + download_url: Optional[str] = None + + +def get_file_size(path: TPath) -> int: + """ + Get the size of a file. + + Parameters + ---------- + path : TPath + The path of the file. + + Returns + ------- + int + The size of the file in bytes. + + Examples + -------- + >>> get_file_size(Path("...")) + + """ + + return to_path(path).stat().st_size + + +def get_file_info(path: TPath) -> FileInfo: + """ + Get the information of a file. + + Parameters + ---------- + path : TPath + The path of the file. + + Returns + ------- + FileInfo + The FileInfo object containing information about the file. + + Examples + -------- + >>> get_file_info(Path("...")) + + """ + + path = to_path(path) + with path.open("rb") as f: + sha = hashlib.sha256(f.read()).hexdigest() + return FileInfo(sha, get_file_size(path)) + + +def check_sha_with(path: TPath, tgt_sha: str) -> bool: + """ + Check if the SHA256 hash of a file matches a target hash. + + Parameters + ---------- + path : TPath + The path of the file. + tgt_sha : str + The target SHA256 hash to compare with. + + Returns + ------- + bool + True if the file's hash matches the target hash, False otherwise. 
+ + Examples + -------- + >>> check_sha_with(Path("..."), "...") + + """ + + return get_file_info(path).sha == tgt_sha + + +def to_set(inp: Any) -> Set: + if isinstance(inp, set): + return inp + if isinstance(inp, (list, tuple, dict)): + return set(inp) + return {inp} + + +def wait_for_everyone_at_end(fn: FAny) -> FAny: + def _wrapper(*args: Any, **kwargs: Any) -> Any: + result = fn(*args, **kwargs) + wait_for_everyone() + return result + + return _wrapper # type: ignore + + +def only_execute_on_rank0(fn: FNone) -> FNone: + @wait_for_everyone_at_end + def _wrapper(*args: Any, **kwargs: Any) -> None: + if is_rank_0(): + fn(*args, **kwargs) + + return _wrapper # type: ignore + + +def only_execute_on_local_rank0(fn: FNone) -> FNone: + @wait_for_everyone_at_end + def _wrapper(*args: Any, **kwargs: Any) -> None: + if is_local_rank_0(): + fn(*args, **kwargs) + + return _wrapper # type: ignore + + +def get_memory_size(obj: Any, seen: Optional[Set] = None) -> int: + import numpy as np + + try: + from pandas import Index + from pandas import DataFrame + except ImportError: + Index = DataFrame = None + + if seen is None: + seen = set() + obj_id = id(obj) + if obj_id in seen: + return 0 + + if isinstance(obj, np.ndarray): + return obj.nbytes + if Index is not None and isinstance(obj, Index): + return obj.memory_usage(deep=True) + if DataFrame is not None and isinstance(obj, DataFrame): + return obj.memory_usage(deep=True).sum() + + size = sys.getsizeof(obj) + seen.add(obj_id) + + if isinstance(obj, dict): + for k, v in obj.items(): + size += get_memory_size(k, seen) + size += get_memory_size(v, seen) + elif hasattr(obj, "__dict__"): + size += get_memory_size(obj.__dict__, seen) + elif hasattr(obj, "__iter__") and not isinstance(obj, (str, bytes, bytearray)): + for i_obj in obj: + size += get_memory_size(i_obj, seen) + + return size + + +def get_memory_mb(obj: Any) -> float: + return get_memory_size(obj) / 1024 / 1024 + + +# util modules + + +TRegister = TypeVar("TRegister", bound="WithRegister", covariant=True) +TTRegister = TypeVar("TTRegister", bound=Type["WithRegister"]) +T_s = TypeVar("T_s", bound="ISerializable", covariant=True) +T_sd = TypeVar("T_sd", bound="ISerializableDataClass", covariant=True) +TSerializable = TypeVar("TSerializable", bound="ISerializable", covariant=True) +T_sa = TypeVar("T_sa", bound="ISerializableArrays", covariant=True) +TSArrays = TypeVar("TSArrays", bound="ISerializableArrays", covariant=True) +TSDataClass = TypeVar("TSDataClass", bound="ISerializableDataClass", covariant=True) +TDataClass = TypeVar("TDataClass", bound="DataClassBase") + + +class DataClassBase: + """ + To use this base class, you should not only inherit from `DataClassBase`, + but also decorate your class with `@dataclass`. 
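+
+    Examples
+    --------
+    >>> # `MyConfig` is purely illustrative and not part of the toolkit
+    >>> @dataclass
+    ... class MyConfig(DataClassBase):
+    ...     lr: float = 1.0e-3
+    ...     optimizer: str = "adamw"
+    >>> MyConfig().asdict()
+    {'lr': 0.001, 'optimizer': 'adamw'}
+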
+ """ + + @property + def fields(self) -> List[Field]: + return fields(self) # type: ignore + + @property + def field_names(self) -> List[str]: + return [f.name for f in self.fields] + + @property + def attributes(self) -> List[Any]: + return [getattr(self, name) for name in self.field_names] + + def asdict(self) -> Dict[str, Any]: + def _to_item(ins: Any) -> Any: + if isinstance(ins, DataClassBase): + return ins.asdict() + if isinstance(ins, dict): + return {k: _to_item(v) for k, v in ins.items()} + if isinstance(ins, list): + return [_to_item(item) for item in ins] + if is_dataclass(ins): + return asdict(ins) + return ins + + return {k: _to_item(v) for k, v in zip(self.field_names, self.attributes)} + + def as_modified_dict( + self, + *, + focuses: Optional[Union[str, List[str]]] = None, + excludes: Optional[Union[str, List[str]]] = None, + ) -> Dict[str, Any]: + cls = self.__class__ + requirements = set(get_requirements(cls)) + d = {k: getattr(self, k) for k in requirements} + defaults = cls(**d) + excludes_set = to_set(excludes) + if focuses is not None: + focus_set = to_set(focuses) + for k in self.field_names: + if k not in focus_set: + excludes_set.add(k) + modified_dict = { + k: getattr(self, k) + for k in self.field_names + if k not in excludes_set + and (k in requirements or getattr(self, k) != getattr(defaults, k)) + } + modified_dict = { + k: v.as_modified_dict() if isinstance(v, DataClassBase) else v + for k, v in modified_dict.items() + } + return modified_dict + + def copy(self: TDataClass) -> TDataClass: + return self.__class__.construct(self.asdict()) + + def update_with(self: TDataClass, other: TDataClass) -> TDataClass: + d = update_dict(other.asdict(), self.asdict()) + updated = self.__class__.construct(d) + for field_name in self.field_names: + setattr(self, field_name, getattr(updated, field_name)) + return self + + def to_hash( + self, + *, + focuses: Optional[Union[str, List[str]]] = None, + excludes: Optional[Union[str, List[str]]] = None, + ) -> str: + return hash_dict(self.as_modified_dict(focuses=focuses, excludes=excludes)) + + @classmethod + def construct(cls: Type[TDataClass], d: Dict[str, Any]) -> TDataClass: + def _construct(t: Type, d: Dict[str, Any]) -> Any: + instance = safe_instantiate(t, d) + if not is_dataclass(instance): + return instance + for field in fields(instance): + if is_dataclass(field.type): + setattr( + instance, + field.name, + _construct(field.type, getattr(instance, field.name)), + ) + continue + t_origin = getattr(field.type, "__origin__", None) + if t_origin is None: + continue + if t_origin is list and hasattr(field.type, "__args__"): + t_value = field.type.__args__[0] + if is_dataclass(t_value): + setattr( + instance, + field.name, + [ + _construct(t_value, item) + for item in getattr(instance, field.name) + ], + ) + continue + if t_origin is dict and hasattr(field.type, "__args__"): + t_value = field.type.__args__[1] + if is_dataclass(t_value): + setattr( + instance, + field.name, + { + k: _construct(t_value, v) + for k, v in getattr(instance, field.name).items() + }, + ) + continue + return instance + + return _construct(cls, d) + + +class WithRegister(Generic[TRegister]): + d: Dict[str, Type[TRegister]] + __identifier__: str + + @classmethod + def get(cls: Type[TRegister], name: str) -> Type[TRegister]: + return cls.d[name] + + @classmethod + def has(cls, name: str) -> bool: + return name in cls.d + + @classmethod + def make( + cls: Type[TRegister], + name: str, + config: Dict[str, Any], + *, + ensure_safe: bool = False, + ) -> 
TRegister: + base = cls.get(name) + if not ensure_safe: + return base(**config) # type: ignore + return safe_instantiate(base, config) + + @classmethod + def make_multiple( + cls: Type[TRegister], + names: Union[str, List[str]], + configs: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]] = None, + *, + ensure_safe: bool = False, + ) -> List[TRegister]: + if configs is None: + configs = {} + if isinstance(names, str): + assert isinstance(configs, dict) + return cls.make(names, configs, ensure_safe=ensure_safe) # type: ignore + if not isinstance(configs, list): + configs = [configs.get(name, {}) for name in names] + return [ + cls.make(name, shallow_copy_dict(config), ensure_safe=ensure_safe) + for name, config in zip(names, configs) + ] + + @classmethod + def register( + cls, + name: str, + *, + allow_duplicate: bool = False, + ) -> Callable[[TTRegister], TTRegister]: + def before(cls_: TTRegister) -> None: + cls_.__identifier__ = name + + return register_core( # type: ignore + name, + cls.d, + allow_duplicate=allow_duplicate, + before_register=before, + ) + + @classmethod + def check_subclass(cls, name: str) -> bool: + return issubclass(cls.d[name], cls) + + +@dataclass +class JsonPack(DataClassBase): + type: str + info: Dict[str, Any] + + +class ISerializable( + Generic[TSerializable], + WithRegister[TSerializable], + metaclass=ABCMeta, +): + # abstract + + @abstractmethod + def to_info(self) -> Dict[str, Any]: + pass + + @abstractmethod + def from_info(self: T_s, info: Dict[str, Any]) -> T_s: + pass + + # optional callbacks + + def after_load(self) -> None: + pass + + # api + + def to_pack(self) -> JsonPack: + return JsonPack(self.__identifier__, self.to_info()) + + @classmethod + def from_pack(cls: Type[TSerializable], pack: Dict[str, Any]) -> TSerializable: + obj: TSerializable = cls.make(pack["type"], {}) + obj.from_info(pack["info"]) + obj.after_load() + return obj + + def to_json(self) -> str: + return json.dumps(self.to_pack().asdict()) + + @classmethod + def from_json(cls: Type[TSerializable], json_string: str) -> TSerializable: + return cls.from_pack(json.loads(json_string)) + + def copy(self: T_s) -> T_s: + copied = self.__class__() + copied.from_info(shallow_copy_dict(self.to_info())) + return copied + + +class ISerializableArrays( + Generic[TSArrays], + ISerializable[TSArrays], + metaclass=ABCMeta, +): + @abstractmethod + def to_npd(self) -> np_dict_type: + pass + + @abstractmethod + def from_npd(self: T_sa, npd: np_dict_type) -> T_sa: + pass + + def copy(self: T_sa) -> T_sa: + copied = super().copy() + copied.from_npd(shallow_copy_dict(self.to_npd())) + return copied + + +class ISerializableDataClass( # type: ignore + Generic[TSDataClass], + DataClassBase, + ISerializable[TSDataClass], +): + def to_info(self) -> Dict[str, Any]: + return self.asdict() + + def from_info(self: T_sd, info: Dict[str, Any]) -> T_sd: + new = self.__class__.construct(info) + self.update_with(new) + return self + + +class Serializer: + id_file: str = "id.txt" + info_file: str = "info.json" + npd_folder: str = "npd" + + @classmethod + def save_info( + cls, + folder: TPath, + *, + info: Optional[Dict[str, Any]] = None, + serializable: Optional[ISerializable] = None, + ) -> None: + folder = to_path(folder) + folder.mkdir(parents=True, exist_ok=True) + if info is None and serializable is None: + raise ValueError("either `info` or `serializable` should be provided") + if info is None: + info = serializable.to_info() # type: ignore + with (folder / cls.info_file).open("w") as f: + json.dump(info, f) + 
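+    # on-disk layout produced by `Serializer.save(folder, serializable)`:
+    #     <folder>/id.txt    -> `serializable.__identifier__`
+    #     <folder>/info.json -> `serializable.to_info()`
+    #     <folder>/npd/*.npy -> `serializable.to_npd()` (`ISerializableArrays` only)
+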
+ @classmethod + def load_info(cls, folder: TPath) -> Dict[str, Any]: + return cls.try_load_info(folder, strict=True) # type: ignore + + @classmethod + def try_load_info( + cls, + folder: TPath, + *, + strict: bool = False, + ) -> Optional[Dict[str, Any]]: + folder = to_path(folder) + info_path = folder / cls.info_file + if not info_path.is_file(): + if not strict: + return None + raise ValueError(f"'{info_path}' does not exist") + with info_path.open("r") as f: + info = json.load(f) + return info + + @classmethod + def save_npd( + cls, + folder: TPath, + *, + npd: Optional[np_dict_type] = None, + serializable: Optional[ISerializableArrays] = None, + ) -> None: + import numpy as np + + folder = to_path(folder) + folder.mkdir(parents=True, exist_ok=True) + if npd is None and serializable is None: + raise ValueError("either `npd` or `serializable` should be provided") + if npd is None: + npd = serializable.to_npd() # type: ignore + npd_folder = folder / cls.npd_folder + npd_folder.mkdir(exist_ok=True) + for k, v in npd.items(): + np.save(npd_folder / f"{k}.npy", v) + + @classmethod + def load_npd(cls, folder: TPath) -> np_dict_type: + import numpy as np + + folder = to_path(folder) + folder.mkdir(parents=True, exist_ok=True) + npd_folder = folder / cls.npd_folder + if not npd_folder.is_dir(): + raise ValueError(f"'{npd_folder}' does not exist") + npd = {} + for file in npd_folder.iterdir(): + npd[file.stem] = np.load(npd_folder / file) + return npd + + @classmethod + def save( + cls, + folder: TPath, + serializable: ISerializable, + *, + save_npd: bool = True, + ) -> None: + folder = to_path(folder) + cls.save_info(folder, serializable=serializable) + if save_npd and isinstance(serializable, ISerializableArrays): + cls.save_npd(folder, serializable=serializable) + with (folder / cls.id_file).open("w") as f: + f.write(serializable.__identifier__) + + @classmethod + def load( + cls, + folder: TPath, + base: Type[TSerializable], + *, + swap_id: Optional[str] = None, + swap_info: Optional[Dict[str, Any]] = None, + load_npd: bool = True, + ) -> TSerializable: + serializable = cls.load_empty(folder, base, swap_id=swap_id) + serializable.from_info(swap_info or cls.load_info(folder)) + if load_npd and isinstance(serializable, ISerializableArrays): + serializable.from_npd(cls.load_npd(folder)) + serializable.after_load() + return serializable + + @classmethod + def load_empty( + cls, + folder: TPath, + base: Type[TSerializable], + *, + swap_id: Optional[str] = None, + ) -> TSerializable: + if swap_id is not None: + s_type = swap_id + else: + folder = to_path(folder) + id_path = folder / cls.id_file + if not id_path.is_file(): + raise ValueError(f"cannot find '{id_path}'") + with id_path.open("r") as f: + s_type = f.read().strip() + return base.make(s_type, {}) + + +class Incrementer: + """ + Util class which can calculate running mean & running std efficiently. + + Parameters + ---------- + window_size : {int, None}, window size of running statistics. + * If None, then all history records will be used for calculation. 
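+    * Statistics are kept as O(1) running sums; once more than `window_size`
+      records have been seen, the oldest value is popped from the window and
+      subtracted from the sums.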
+ + Examples + ---------- + >>> incrementer = Incrementer(window_size=5) + >>> for i in range(10): + >>> incrementer.update(i) + >>> if i >= 4: + >>> print(incrementer.mean) # will print 2.0, 3.0, ..., 6.0, 7.0 + + """ + + def __init__(self, window_size: Optional[int] = None): + if window_size is not None: + if not isinstance(window_size, int): + msg = f"window size should be integer, {type(window_size)} found" + raise ValueError(msg) + if window_size < 2: + msg = f"window size should be at least 2, {window_size} found" + raise ValueError(msg) + self.previous: List[float] = [] + self.num_record = 0.0 + self.window_size = window_size + self.running_sum = self.running_square_sum = 0.0 + + @property + def mean(self) -> float: + return self.running_sum / self.num_record + + @property + def std(self) -> float: + return math.sqrt( + max( + 0.0, + self.running_square_sum / self.num_record - self.mean**2, + ) + ) + + def update(self, new_value: float) -> None: + self.num_record += 1 + self.running_sum += new_value + self.running_square_sum += new_value**2 + if self.window_size is not None: + self.previous.append(new_value) + if self.num_record == self.window_size + 1: + self.num_record -= 1 + previous = self.previous.pop(0) + self.running_sum -= previous + self.running_square_sum -= previous**2 + + +class OPTBase(ABC): + def __init__(self) -> None: + self._opt = self.defaults + self.update_from_env() + + def __getattr__(self, __name: str) -> Any: + return self._opt[__name] + + # abstract + + @property + @abstractmethod + def env_key(self) -> str: + pass + + @property + @abstractmethod + def defaults(self) -> Dict[str, Any]: + pass + + # optional callbacks + + def update_from_env(self) -> None: + env_opt_json = os.environ.get(self.env_key) + if env_opt_json is not None: + update_dict(json.loads(env_opt_json), self._opt) + + # api + + def opt_context(self, increment: Dict[str, Any]) -> ContextManager: + class _: + def __init__(self) -> None: + self._increment = increment + self._backup = shallow_copy_dict(instance._opt) + + def __enter__(self) -> None: + update_dict(self._increment, instance._opt) + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + instance._opt = self._backup + + instance = self + return _() + + def opt_env_context(self, increment: Dict[str, Any]) -> ContextManager: + class _: + def __init__(self) -> None: + self._increment = increment + self._backup = os.environ.get(instance.env_key) + + def __enter__(self) -> None: + os.environ[instance.env_key] = json.dumps(self._increment) + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + if self._backup is None: + del os.environ[instance.env_key] + else: + os.environ[instance.env_key] = self._backup + + instance = self + return _() + + +# contexts + + +class timeit: + """ + Timing context manager. + + Examples + -------- + >>> with timeit("something"): + >>> # do something here + >>> # will print "> [ info ] timing for something : x.xxxx" + + """ + + t: float + + def __init__(self, message: str, *, precision: int = 6, enabled: bool = True): + self.p = precision + self.message = message + self.enabled = enabled + + def __enter__(self) -> None: + if self.enabled: + self.t = time.time() + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + if self.enabled: + console.log( + f"timing for {self.message:^16s} : " + f"{time.time() - self.t:{self.p}.{self.p-2}f}", + _stack_offset=3, + ) + + +class batch_manager: + """ + Process data in batch. 
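+
+    All `inputs` are sliced with the same `[start:end]` window on every step,
+    so corresponding entries remain aligned across the yielded batches.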
+ + Parameters + ---------- + inputs : tuple(np.ndarray), auxiliary array inputs. + num_elem : {int, float}, indicates how many elements will be processed in a batch. + > `element` here means every single entry of the `inputs`. + batch_size : int, indicates the batch_size; if None, batch_size will be + calculated by `num_elem`. + + Examples + -------- + >>> with batch_manager(np.arange(5), np.arange(1, 6), batch_size=2) as manager: + >>> for arr, tensor in manager: + >>> print(arr, tensor) + >>> # Will print: + >>> # [0 1], [1 2] + >>> # [2 3], [3 4] + >>> # [4] , [5] + + """ + + start: int + end: int + + def __init__( + self, + *inputs: arr_type, + num_elem: Union[int, float] = 1e6, + batch_size: Optional[int] = None, + max_batch_size: int = 1024, + ): + if not inputs: + raise ValueError("inputs should be provided in general_batch_manager") + input_lengths = list(map(len, inputs)) + self.num_samples, self.inputs = input_lengths[0], inputs + assert_msg = "inputs should be of same length" + assert all(length == self.num_samples for length in input_lengths), assert_msg + if batch_size is not None: + self.batch_size = batch_size + else: + self.batch_size = int( + int(num_elem) / sum(map(lambda arr: prod(arr.shape[1:]), inputs)) + ) + self.batch_size = min(max_batch_size, min(self.num_samples, self.batch_size)) + self.num_epoch = int(self.num_samples / self.batch_size) + self.num_epoch += int(self.num_epoch * self.batch_size < self.num_samples) + + def __enter__(self) -> "batch_manager": + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + pass + + def __iter__(self) -> "batch_manager": + self.start, self.end = 0, self.batch_size + return self + + def __next__(self) -> Union[Tuple[arr_type, ...], arr_type]: + if self.start >= self.num_samples: + raise StopIteration + batched_data = tuple( + map( + lambda arr: arr[self.start : self.end], + self.inputs, + ) + ) + self.start, self.end = self.end, self.end + self.batch_size + if len(batched_data) == 1: + return batched_data[0] + return batched_data + + def __len__(self) -> int: + return self.num_epoch diff --git a/cfdraw/core/toolkit/pipeline.py b/cfdraw/core/toolkit/pipeline.py new file mode 100644 index 00000000..882a0edc --- /dev/null +++ b/cfdraw/core/toolkit/pipeline.py @@ -0,0 +1,212 @@ +import shutil + +from abc import abstractmethod +from abc import ABCMeta +from typing import Any +from typing import Dict +from typing import List +from typing import Type +from typing import Union +from typing import Generic +from typing import Mapping +from typing import TypeVar +from typing import Optional +from typing import ContextManager +from pathlib import Path +from zipfile import ZipFile +from tempfile import mkdtemp + +from .misc import to_path +from .misc import shallow_copy_dict +from .misc import WithRegister +from .misc import ISerializable +from .misc import ISerializableDataClass +from .types import TPath + + +TB = TypeVar("TB", bound="IBlock") +TBlock = TypeVar("TBlock", bound="IBlock") +TConfig = TypeVar("TConfig", bound="ISerializableDataClass") +T_p = TypeVar("T_p", bound="IPipeline", covariant=True) +TPipeline = TypeVar("TPipeline", bound="IPipeline") + +pipelines: Dict[str, Type["IPipeline"]] = {} +pipeline_blocks: Dict[str, Type["IBlock"]] = {} + + +def get_folder(folder: TPath, *, force_new: bool = False) -> ContextManager: + class _: + tmp_folder: Optional[Path] + + def __init__(self) -> None: + self.tmp_folder = None + + def __enter__(self) -> Path: + folder = to_path(folder_input) + if 
folder.is_dir(): + if not force_new: + return folder + self.tmp_folder = Path(mkdtemp()) + shutil.copytree(folder, self.tmp_folder, dirs_exist_ok=True) + return self.tmp_folder + path = Path(f"{folder}.zip") + if not path.is_file(): + raise ValueError(f"neither '{folder}' nor '{path}' exists") + self.tmp_folder = Path(mkdtemp()) + with ZipFile(path, "r") as ref: + ref.extractall(self.tmp_folder) + return self.tmp_folder + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + if self.tmp_folder is not None: + shutil.rmtree(self.tmp_folder) + + folder_input = folder + return _() + + +def get_req_choices(req: "IBlock") -> List[str]: + return [r.strip() for r in req.__identifier__.split("|")] + + +def check_requirement(block: "IBlock", previous: Mapping[str, "IBlock"]) -> None: + for req in block.requirements: + choices = get_req_choices(req) # type: ignore + if all(c != "none" and c not in previous for c in choices): + raise ValueError( + f"'{block.__identifier__}' requires '{req}', " + "but none is provided in the previous blocks" + ) + + +class IBlock(Generic[TBlock], WithRegister["IBlock"], metaclass=ABCMeta): + d = pipeline_blocks + + """ + This property should be injected by the `IPipeline`. + > In runtime (i.e. executing the `run` method), this property will represent ALL `IBlock`s used in the `IPipeline`. + """ + previous: Dict[str, TBlock] + + @abstractmethod + def build(self, config: Any) -> None: + """This method can modify the `config` inplace, which will affect the following blocks""" + + @property + def requirements(self) -> List[Type[TBlock]]: + return [] + + def try_get_previous(self, block: Union[str, Type[TB]]) -> Optional[TB]: + if not isinstance(block, str): + block = block.__identifier__ + return self.previous.get(block) # type: ignore + + def get_previous(self, block: Union[str, Type[TB]]) -> TB: + b = self.try_get_previous(block) + if b is None: + raise ValueError(f"cannot find '{block}' in `previous`") + return b + + +class IPipeline( + Generic[TBlock, TConfig, TPipeline], + ISerializable[TPipeline], + metaclass=ABCMeta, +): + d = pipelines # type: ignore + + config: TConfig + blocks: List[TBlock] + + def __init__(self) -> None: + self.blocks = [] + + # abstract + + @classmethod + @abstractmethod + def init(cls: Type[TPipeline], config: TConfig) -> TPipeline: + pass + + @property + @abstractmethod + def config_base(self) -> Type[TConfig]: + pass + + @property + @abstractmethod + def block_base(self) -> Type[TBlock]: + pass + + # inheritance + + def to_info(self) -> Dict[str, Any]: + return dict( + blocks=[ + ( + b.to_pack().asdict() # type: ignore + if isinstance(b, ISerializable) + else b.__identifier__ + ) + for b in self.blocks + ], + config=self.config.to_pack().asdict(), + ) + + def from_info(self: T_p, info: Dict[str, Any]) -> T_p: + self.config = self.config_base.from_pack(info["config"]) + block_base = self.block_base + blocks: List[TBlock] = [] + for block in info["blocks"]: + blocks.append( + block_base.from_pack(block) # type: ignore + if issubclass(block_base, ISerializable) + else block_base.make(block, {}) + ) + self.build(*blocks) + return self + + # optional callbacks + + def before_block_build(self, block: TBlock) -> None: + pass + + def after_block_build(self, block: TBlock) -> None: + pass + + # api + + @property + def block_mappings(self) -> Dict[str, TBlock]: + return {b.__identifier__: b for b in self.blocks} + + def try_get_block(self, block: Union[str, Type[TB]]) -> Optional[TB]: + if not isinstance(block, str): + block = 
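`get_folder` yields either the directory itself or a temporary extraction of a sibling `.zip`, and removes any temporary copy on exit. A minimal usage sketch, with a hypothetical "my_pipeline" path, could look like this:

from cfdraw.core.toolkit.pipeline import get_folder

# "my_pipeline" may exist as a directory, or only "my_pipeline.zip" may be on disk;
# either way the context yields a readable folder.
with get_folder("my_pipeline") as folder:
    print(sorted(path.name for path in folder.iterdir()))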
block.__identifier__ + return self.block_mappings.get(block) # type: ignore + + def get_block(self, block: Union[str, Type[TB]]) -> TB: + b = self.try_get_block(block) + if b is None: + raise ValueError(f"cannot find '{block}' in `previous`") + return b + + def build(self, *blocks: TBlock) -> None: + previous: Dict[str, TBlock] = self.block_mappings + for block in blocks: + check_requirement(block, previous) + block.previous = shallow_copy_dict(previous) + self.before_block_build(block) + block.build(self.config) + self.after_block_build(block) + previous[block.__identifier__] = block + self.blocks.append(block) + + +__all__ = [ + "IBlock", + "IPipeline", + "TPipeline", + "get_folder", + "get_req_choices", +] diff --git a/cfdraw/core/toolkit/types.py b/cfdraw/core/toolkit/types.py new file mode 100644 index 00000000..e914b82c --- /dev/null +++ b/cfdraw/core/toolkit/types.py @@ -0,0 +1,21 @@ +from typing import Any +from typing import Dict +from typing import Tuple +from typing import Union +from typing import TypeVar +from typing import Optional +from typing import TYPE_CHECKING +from pathlib import Path + +if TYPE_CHECKING: + import torch + import numpy as np + +arr_type = Union["np.ndarray", "torch.Tensor"] +TArray = TypeVar("TArray", bound=arr_type) +np_dict_type = Dict[str, Union["np.ndarray", Any]] +tensor_dict_type = Dict[str, Union["torch.Tensor", Any]] + +TPath = Union[str, Path] +TConfig = Optional[Union[TPath, Dict[str, Any]]] +TNumberPair = Optional[Union[int, Tuple[int, int]]] diff --git a/cfdraw/core/toolkit/web.py b/cfdraw/core/toolkit/web.py new file mode 100644 index 00000000..4e3637ba --- /dev/null +++ b/cfdraw/core/toolkit/web.py @@ -0,0 +1,184 @@ +import json +import time +import logging + +from io import BytesIO +from typing import Any +from typing import Dict +from typing import Type +from typing import TypeVar +from typing import Callable +from typing import Optional +from typing import Awaitable +from typing import TYPE_CHECKING +from pydantic import BaseModel +from pydantic import ConfigDict + +from .misc import get_err_msg +from .constants import WEB_ERR_CODE + +if TYPE_CHECKING: + from PIL import Image + from aiohttp import ClientSession + + +TResponse = TypeVar("TResponse") + + +class RuntimeError(BaseModel): + detail: str + + model_config = ConfigDict( + json_schema_extra={"example": {"detail": "RuntimeError occurred."}} + ) + + +def get_ip() -> str: + import socket + + return socket.gethostbyname(socket.gethostname()) + + +def get_responses( + success_model: Type[BaseModel], + *, + json_example: Optional[Dict[str, Any]] = None, +) -> Dict[int, Dict[str, Type]]: + success_response: Dict[str, Any] = {"model": success_model} + if json_example is not None: + content = success_response["content"] = {} + json_field = content["application/json"] = {} + json_field["example"] = json_example + return { + 200: success_response, + WEB_ERR_CODE: {"model": RuntimeError}, + } + + +def get_image_response_kwargs() -> Dict[str, Any]: + from fastapi import Response + + example = "\\x89PNG\\r\\n\\x1a\\n\\x00\\x00\\x00\\rIHDR\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x08\\x00\\x00\\x00\\x00:~\\x9bU\\x00\\x00\\x00\\nIDATx\\x9cc`\\x00\\x00\\x00\\x02\\x00\\x01H\\xaf\\xa4q\\x00\\x00\\x00\\x00IEND\\xaeB`\\x82" + responses = { + 200: {"content": {"image/png": {"example": example}}}, + WEB_ERR_CODE: {"model": RuntimeError}, + } + description = """ +Bytes of the output image. ++ When using `requests` in `Python`, you can get the `bytes` with `res.content`. 
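To illustrate how `IPipeline.build` wires blocks together, the sketch below registers two hypothetical blocks; the second declares the first as a requirement (enforced by `check_requirement`) and reaches back to it via `get_previous`. The block names, the `workspace` attribute on the config and the behavior of the vendored `WithRegister.register` decorator are assumptions for the example only.

from typing import Any, List, Type

from cfdraw.core.toolkit.pipeline import IBlock


@IBlock.register("prepare")
class PrepareBlock(IBlock):
    def build(self, config: Any) -> None:
        # blocks may read (or mutate) the shared pipeline config here
        self.workspace = getattr(config, "workspace", ".")


@IBlock.register("train")
class TrainBlock(IBlock):
    @property
    def requirements(self) -> List[Type[IBlock]]:
        return [PrepareBlock]  # checked when the pipeline builds

    def build(self, config: Any) -> None:
        prepare = self.get_previous(PrepareBlock)
        print("training inside", prepare.workspace)

Inside an `IPipeline` subclass, `self.build(PrepareBlock(), TrainBlock())` would then inject `previous` into each block and call their `build` methods in that order.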
++ When using `fetch` in `JavaScript`, you can get the `Blob` with `await res.blob()`. +""" + return dict( + responses=responses, + response_class=Response(content=b""), + response_description=description, + ) + + +def raise_err(err: Exception) -> None: + from fastapi import HTTPException + + logging.exception(err) + raise HTTPException(status_code=WEB_ERR_CODE, detail=get_err_msg(err)) + + +async def get(url: str, session: "ClientSession", **kwargs: Any) -> bytes: + async with session.get(url, **kwargs) as response: + return await response.read() + + +async def post( + url: str, + json: Dict[str, Any], + session: "ClientSession", + **kwargs: Any, +) -> Dict[str, Any]: + async with session.post(url, json=json, **kwargs) as response: + return await response.json() + + +def log_endpoint(endpoint: str, data: BaseModel) -> None: + msg = f"{endpoint} endpoint entered with kwargs : {json.dumps(data.model_dump(), ensure_ascii=False)}" + logging.debug(msg) + + +def log_times(endpoint: str, times: Dict[str, float]) -> None: + times["__total__"] = sum(times.values()) + logging.debug(f"elapsed time of endpoint {endpoint} : {json.dumps(times)}") + + +async def download_raw(session: "ClientSession", url: str, **kw: Any) -> bytes: + try: + return await get(url, session, **kw) + except Exception: + import requests + + return requests.get(url, **kw).content + + +async def download_image( + session: "ClientSession", url: str, **kw: Any +) -> "Image.Image": + from PIL import Image + from PIL import ImageOps + + raw_data = None + try: + raw_data = await download_raw(session, url, **kw) + image = Image.open(BytesIO(raw_data)) + try: + image = ImageOps.exif_transpose(image) + finally: + return image + except Exception as err: + if raw_data is None: + msg = f"raw | None | err | {err}" + else: + try: + msg = raw_data.decode("utf-8") + except: + msg = f"raw | {raw_data[:20]!r} | err | {err}" + raise ValueError(msg) + + +async def retry_with( + download_fn: Callable[["ClientSession", str], Awaitable[TResponse]], + session: "ClientSession", + url: str, + retry: int = 3, + interval: int = 1, + **kw: Any, +) -> TResponse: + msg = "" + for i in range(retry): + try: + res = await download_fn(session, url, **kw) + if i > 0: + logging.warning(f"succeeded after {i} retries") + return res + except Exception as err: + msg = str(err) + time.sleep(interval) + raise ValueError(f"{msg}\n(After {retry} retries)") + + +async def download_raw_with_retry( + session: "ClientSession", + url: str, + *, + retry: int = 3, + interval: int = 1, + **kw: Any, +) -> bytes: + return await retry_with(download_raw, session, url, retry, interval, **kw) + + +async def download_image_with_retry( + session: "ClientSession", + url: str, + *, + retry: int = 3, + interval: int = 1, + **kw: Any, +) -> "Image.Image": + return await retry_with(download_image, session, url, retry, interval, **kw) diff --git a/cfdraw/parsers/noli.py b/cfdraw/parsers/noli.py index 6741e0d5..e1c93ee9 100644 --- a/cfdraw/parsers/noli.py +++ b/cfdraw/parsers/noli.py @@ -7,11 +7,12 @@ from typing import Generator from pydantic import Field from pydantic import BaseModel -from cftool.geometry import Line -from cftool.geometry import Point -from cftool.geometry import HitTest -from cftool.geometry import Matrix2D -from cftool.geometry import PivotType + +from cfdraw.core.toolkit.geometry import Line +from cfdraw.core.toolkit.geometry import Point +from cfdraw.core.toolkit.geometry import HitTest +from cfdraw.core.toolkit.geometry import Matrix2D +from cfdraw.core.toolkit.geometry 
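The download helpers are meant to be composed with `retry_with`; for instance, fetching an image with up to three attempts could look roughly like the following (the URL is a placeholder):

import asyncio

from aiohttp import ClientSession

from cfdraw.core.toolkit.web import download_image_with_retry


async def main() -> None:
    async with ClientSession() as session:
        image = await download_image_with_retry(session, "https://example.com/sample.png")
        print(image.size)


asyncio.run(main())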
import PivotType class Lang(str, Enum): diff --git a/cfdraw/plugins/base.py b/cfdraw/plugins/base.py index 6037c4f2..905805f8 100644 --- a/cfdraw/plugins/base.py +++ b/cfdraw/plugins/base.py @@ -5,8 +5,6 @@ from typing import Dict from typing import List from typing import Optional -from cftool.misc import shallow_copy_dict -from cftool.data_structures import Workflow from cfdraw import constants from cfdraw.utils import server @@ -16,6 +14,7 @@ from cfdraw.parsers.noli import SingleNodeType from cfdraw.app.endpoints.upload import ImageUploader from cfdraw.app.endpoints.upload import FetchImageModel +from cfdraw.core.toolkit.misc import shallow_copy_dict class ISocketPlugin(IPlugin, metaclass=ABCMeta): diff --git a/cfdraw/plugins/factory.py b/cfdraw/plugins/factory.py index eb84f7ea..3a86baeb 100644 --- a/cfdraw/plugins/factory.py +++ b/cfdraw/plugins/factory.py @@ -1,11 +1,11 @@ from typing import Type from typing import Callable from typing import NamedTuple -from cftool.data_structures import Types from cfdraw.schema.plugins import IPlugin from cfdraw.schema.plugins import IPluginSettings from cfdraw.schema.plugins import IPluginGroupInfo +from cfdraw.core.toolkit.data_structures import Types TPlugin = Type[IPlugin] diff --git a/cfdraw/schema/fields.py b/cfdraw/schema/fields.py index d07c519b..29680227 100644 --- a/cfdraw/schema/fields.py +++ b/cfdraw/schema/fields.py @@ -9,7 +9,6 @@ from typing import Union from typing import Optional from pathlib import Path -from pydantic import Extra from pydantic import Field from pydantic import BaseModel @@ -62,7 +61,7 @@ class IBaseField(BaseModel): ) class Config: - extra = Extra.forbid + extra = "forbid" smart_union = True diff --git a/cfdraw/schema/plugins.py b/cfdraw/schema/plugins.py index ed14a0bb..ba507427 100644 --- a/cfdraw/schema/plugins.py +++ b/cfdraw/schema/plugins.py @@ -1,23 +1,30 @@ +import io import json import time +import networkx as nx +import matplotlib.pyplot as plt +import matplotlib.patches as mpatches + from abc import abstractmethod from abc import ABC from PIL import Image from enum import Enum from typing import Any +from typing import Set from typing import Dict from typing import List from typing import Type +from typing import Tuple from typing import Union from typing import TypeVar from typing import Callable from typing import Optional from typing import Coroutine +from typing import NamedTuple from aiohttp import ClientSession from pydantic import Field from pydantic import BaseModel -from cftool.data_structures import Workflow from cfdraw import constants from cfdraw.schema.fields import IFieldDefinition @@ -30,12 +37,217 @@ from cfdraw.parsers.noli import NodeConstraintRules from cfdraw.parsers.chakra import IChakra from cfdraw.parsers.chakra import TextAlign +from cfdraw.core.toolkit.data_structures import Item +from cfdraw.core.toolkit.data_structures import Bundle TPluginModel = TypeVar("TPluginModel") ISend = Callable[["ISocketMessage"], Coroutine[Any, Any, bool]] +class InjectionPack(BaseModel): + index: Optional[int] + field: str + + +class WorkNode(BaseModel): + key: str = Field( + ..., + description="Key of the node, should be identical within the same workflow", + ) + endpoint: str = Field(..., description="Algorithm endpoint of the node") + injections: Dict[str, Union[InjectionPack, List[InjectionPack]]] = Field( + ..., + description=( + "Injection map, maps 'key' from other `WorkNode` (A) to 'index' of A's results & " + "'field' of the algorithm's field. 
In runtime, we'll collect "
+            "the (list of) results from the dependencies (other `WorkNode`) and "
+            "inject the specific result (based on 'index') to the algorithm's field.\n"
+            "> If external caches are provided, the 'key' could be the key of the external cache.\n"
+            "> Hierarchy injection is also supported, you just need to set 'field' to:\n"
+            ">> `a.b.c` to inject the result to data['a']['b']['c']\n"
+            ">> `a.0.b` to inject the first result to data['a'][0]['b']\n"
+        ),
+    )
+    data: Dict[str, Any] = Field(..., description="Algorithm's data")
+
+    def to_item(self) -> Item["WorkNode"]:
+        return Item(self.key, self)
+
+
+class ToposortResult(NamedTuple):
+    in_edges: Dict[str, Set[str]]
+    hierarchy: List[List[Item[WorkNode]]]
+    edge_labels: Dict[Tuple[str, str], str]
+
+
+class Workflow(Bundle[WorkNode]):
+    def copy(self) -> "Workflow":
+        return Workflow.from_json(self.to_json())
+
+    def push(self, node: WorkNode) -> "Workflow":  # type: ignore
+        return super().push(node.to_item())
+
+    def toposort(self) -> ToposortResult:
+        in_edges: Dict[str, Set[str]] = {item.key: set() for item in self}
+        out_degrees = {item.key: 0 for item in self}
+        edge_labels: Dict[Tuple[str, str], str] = {}
+        for item in self:
+            for dep, packs in item.data.injections.items():
+                in_edges[dep].add(item.key)
+                out_degrees[item.key] += 1
+                if not isinstance(packs, list):
+                    packs = [packs]
+                for pack in packs:
+                    label_key = (item.key, dep)
+                    existing_label = edge_labels.get(label_key)
+                    if existing_label is None:
+                        edge_labels[label_key] = pack.field
+                    else:
+                        edge_labels[label_key] = f"{existing_label}, {pack.field}"
+
+        ready = [k for k, v in out_degrees.items() if v == 0]
+        result = []
+        while ready:
+            layer = ready.copy()
+            result.append(layer)
+            ready.clear()
+            for dep in layer:
+                for node in in_edges[dep]:
+                    out_degrees[node] -= 1
+                    if out_degrees[node] == 0:
+                        ready.append(node)
+
+        if len(self) != sum(map(len, result)):
+            raise ValueError("cyclic dependency detected")
+
+        hierarchy = [list(map(self.get, layer)) for layer in result]
+        return ToposortResult(in_edges, hierarchy, edge_labels)  # type: ignore
+
+    def get_dependency_path(self, target: str) -> ToposortResult:
+        def dfs(key: str) -> None:
+            if key in reachable:
+                return
+            reachable.add(key)
+            for dep_key in self.get(key).data.injections:  # type: ignore
+                dfs(dep_key)
+
+        reachable: Set[str] = set()
+        dfs(target)
+        in_edges, raw_hierarchy, edge_labels = self.toposort()
+        hierarchy = []
+        for raw_layer in raw_hierarchy:
+            layer = []
+            for item in raw_layer:
+                if item.key in reachable:
+                    layer.append(item)
+            if layer:
+                hierarchy.append(layer)
+        return ToposortResult(in_edges, hierarchy, edge_labels)
+
+    def to_json(self) -> List[Dict[str, Any]]:
+        return [node.data.model_dump() for node in self]
+
+    @classmethod
+    def from_json(cls, data: List[Dict[str, Any]]) -> "Workflow":
+        workflow = cls()
+        for json in data:
+            workflow.push(WorkNode(**json))
+        return workflow
+
+    def inject_caches(self, caches: Dict[str, Any]) -> "Workflow":
+        for k in caches:
+            self.push(WorkNode(key=k, endpoint="", injections={}, data={}))
+        return self
+
+    def render(
+        self,
+        *,
+        target: Optional[str] = None,
+        caches: Optional[Dict[str, Any]] = None,
+        fig_w_ratio: int = 4,
+        fig_h_ratio: int = 3,
+        dpi: int = 200,
+        node_size: int = 2000,
+        node_shape: str = "s",
+        node_color: str = "lightblue",
+        layout: str = "multipartite_layout",
+    ) -> Image.Image:
+        if Image is None:
+            raise ValueError("PIL is required for `render`")
+        # setup workflow
+        workflow = self.copy()
+        if caches 
is not None: + workflow.inject_caches(caches) + # setup graph + G = nx.DiGraph() + if target is None: + target = self.last.key # type: ignore + in_edges, hierarchy, edge_labels = workflow.get_dependency_path(target) + # setup plt + fig_w = max(fig_w_ratio * len(hierarchy), 8) + fig_h = fig_h_ratio * max(map(len, hierarchy)) + plt.figure(figsize=(fig_w, fig_h), dpi=dpi) + box = plt.gca().get_position() + plt.gca().set_position([box.x0, box.y0, box.width * 0.8, box.height]) + # map key to indices + key2idx: Dict[str, int] = {} + for layer in hierarchy: + for node in layer: + key2idx[node.key] = len(key2idx) + # add nodes + for i, layer in enumerate(hierarchy): + for node in layer: + G.add_node(key2idx[node.key], subset=f"layer_{i}") + # add edges + for dep, links in in_edges.items(): + for link in links: + if dep not in key2idx or link not in key2idx: + continue + label = edge_labels[(link, dep)] + G.add_edge(key2idx[dep], key2idx[link], label=label) + # calculate positions + layout_fn = getattr(nx, layout, None) + if layout_fn is None: + raise ValueError(f"unknown layout: {layout}") + pos = layout_fn(G) + # draw the nodes + nodes_styles = dict( + node_size=node_size, + node_shape=node_shape, + node_color=node_color, + ) + nx.draw_networkx_nodes(G, pos, **nodes_styles) + node_labels_styles = dict( + font_size=18, + ) + nx.draw_networkx_labels(G, pos, **node_labels_styles) + # draw the edges + nx_edge_labels = nx.get_edge_attributes(G, "label") + nx.draw_networkx_edges( + G, + pos, + arrows=True, + arrowstyle="-|>", + arrowsize=16, + node_size=nodes_styles["node_size"], + node_shape=nodes_styles["node_shape"], + ) + nx.draw_networkx_edge_labels(G, pos, edge_labels=nx_edge_labels) + # draw captions + patches = [ + mpatches.Patch(color=node_color, label=f"{idx}: {key}") + for key, idx in key2idx.items() + ] + plt.legend(handles=patches, bbox_to_anchor=(1, 0.5), loc="center left") + # render + plt.axis("off") + buf = io.BytesIO() + plt.savefig(buf, format="png") + buf.seek(0) + return Image.open(buf) + + class PluginType(str, Enum): """ These types should align with the `allPythonPlugins` locates at @@ -740,6 +952,7 @@ class IChatPluginInfo(IWorkflowPluginInfo): "ISend", "PluginType", "ReactPluginType", + "Workflow", # general "hash_identifier", "IPluginInfo", diff --git a/cfdraw/schema/settings.py b/cfdraw/schema/settings.py index 433d395e..8d1d39b8 100644 --- a/cfdraw/schema/settings.py +++ b/cfdraw/schema/settings.py @@ -4,12 +4,12 @@ from typing import Optional from pydantic import Field from pydantic import BaseModel -from cftool.misc import random_hash from cfdraw.parsers import noli from cfdraw.schema.plugins import hash_identifier from cfdraw.schema.plugins import ILogoSettings from cfdraw.schema.plugins import ReactPluginType +from cfdraw.core.toolkit.misc import random_hash class BoardOptions(BaseModel): diff --git a/cfdraw/utils/console.py b/cfdraw/utils/console.py deleted file mode 100644 index db9fa83d..00000000 --- a/cfdraw/utils/console.py +++ /dev/null @@ -1,43 +0,0 @@ -from typing import Any -from typing import Dict -from typing import List -from typing import Optional - -from rich.prompt import Prompt -from rich.status import Status -from rich.console import Console - - -_console = Console() - - -def deprecate(msg: str) -> None: - _console.print(f"[yellow]DeprecationWarning: {msg}[/yellow]") - - -def log(msg: str) -> None: - _console.log(msg) - - -def print(msg: str) -> None: - _console.print(msg) - - -def rule(title: str) -> None: - _console.rule(title) - - -def ask( - 
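Put together, `Workflow`, `WorkNode` and `InjectionPack` describe a small dependency graph. A toy sketch of how `toposort` layers it, with made-up endpoints, might be:

from cfdraw.schema.plugins import InjectionPack, WorkNode, Workflow

# two-node workflow: "txt2img" runs first, "upscale" consumes its first result
workflow = Workflow()
workflow.push(WorkNode(key="txt2img", endpoint="/txt2img", injections={}, data={"text": "a cat"}))
workflow.push(
    WorkNode(
        key="upscale",
        endpoint="/sr",
        injections={"txt2img": InjectionPack(index=0, field="url")},
        data={},
    )
)
in_edges, hierarchy, edge_labels = workflow.toposort()
print([[item.key for item in layer] for layer in hierarchy])  # [['txt2img'], ['upscale']]
# workflow.render() would draw the same dependency graph as a PIL image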
question: str, - choices: Optional[List[str]] = None, - *, - default: Optional[str] = None, -) -> str: - kw: Dict[str, Any] = dict(choices=choices) - if default is not None: - kw["default"] = default - return Prompt.ask(question, **kw) # type: ignore - - -def status(msg: str) -> Status: - return _console.status(msg) diff --git a/cfdraw/utils/exec.py b/cfdraw/utils/exec.py index 9679a39a..a46273d9 100644 --- a/cfdraw/utils/exec.py +++ b/cfdraw/utils/exec.py @@ -2,12 +2,10 @@ import uvicorn import subprocess -from cftool.misc import print_info - from cfdraw import constants -from cfdraw.utils import console from cfdraw.utils import prerequisites from cfdraw.config import get_config +from cfdraw.core.toolkit import console def setup_frontend() -> None: @@ -32,14 +30,14 @@ def run_frontend(host: bool) -> None: stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT, ) - print_info(f"๐Ÿ‘Œ Your app will be ready at {get_config().frontend_url} soon...") + console.log(f"๐Ÿ‘Œ Your app will be ready at {get_config().frontend_url} soon...") def run_frontend_prod() -> None: setup_frontend() config = get_config() if config.use_unified: - print_info(f"๐Ÿ‘€ Your app codes are being compiled, please wait for a while...") + console.log(f"๐Ÿ‘€ Your app codes are being compiled, please wait for a while...") subprocess.run( [prerequisites.get_yarn(), "build"], cwd=constants.WEB_ROOT, @@ -51,7 +49,7 @@ def run_frontend_prod() -> None: cwd=constants.WEB_ROOT, env=os.environ, ) - print_info( + console.log( f"๐Ÿ‘€ Your app codes are being compiled, " "please wait until a bunch of urls appear..." ) @@ -81,5 +79,5 @@ def run_backend_prod(module: str, *, log_level: constants.LogLevel) -> None: console.rule("[bold green]Launching Production Backend") config = get_config() if config.use_unified: - print_info(f"๐Ÿ‘Œ Your app will be ready at {config.api_url} soon...") + console.log(f"๐Ÿ‘Œ Your app will be ready at {config.api_url} soon...") run_backend(module, log_level=log_level, verbose=False) diff --git a/cfdraw/utils/misc.py b/cfdraw/utils/misc.py index 6ba7a420..0e6eb023 100644 --- a/cfdraw/utils/misc.py +++ b/cfdraw/utils/misc.py @@ -6,11 +6,10 @@ from typing import TypeVar from typing import Callable from typing import Coroutine -from cftool.misc import get_err_msg -from cftool.misc import print_error -from cftool.misc import print_warning from concurrent.futures import ThreadPoolExecutor +from cfdraw.core.toolkit import console + TFutureResponse = TypeVar("TFutureResponse") @@ -19,7 +18,7 @@ def deprecated(message: str) -> Callable[[type], type]: def _deprecated(cls: type) -> type: def init(self: Any, *args: Any, **kwargs: Any) -> None: if not cls._warned_deprecation: # type: ignore - print_warning(f"{cls.__name__} is deprecated, {message}") + console.warn(f"{cls.__name__} is deprecated, {message}") cls._warned_deprecation = True # type: ignore original_init(self, *args, **kwargs) @@ -62,7 +61,7 @@ def _run() -> None: if success: event.set() else: - print_error("[offload_run] Failed to execute future") + console.error("\[offload_run] Failed to execute future") except Exception: logging.exception("[offload_run] failed to execute future") diff --git a/cfdraw/utils/prerequisites.py b/cfdraw/utils/prerequisites.py index beff9000..0844aeb8 100644 --- a/cfdraw/utils/prerequisites.py +++ b/cfdraw/utils/prerequisites.py @@ -5,7 +5,7 @@ from typing import Dict from cfdraw import constants -from cfdraw.utils import console +from cfdraw.core.toolkit import console def get_yarn() -> str: diff --git a/cfdraw/utils/processes.py 
b/cfdraw/utils/processes.py index 875403d2..b97247b9 100644 --- a/cfdraw/utils/processes.py +++ b/cfdraw/utils/processes.py @@ -6,7 +6,7 @@ from typing import Optional -from cfdraw.utils import console +from cfdraw.core.toolkit import console def kill(pid: int) -> None: diff --git a/cfdraw/utils/server.py b/cfdraw/utils/server.py index 2dc43fe9..a8ca1408 100644 --- a/cfdraw/utils/server.py +++ b/cfdraw/utils/server.py @@ -9,10 +9,10 @@ from typing import Union from fastapi import Response from PIL.PngImagePlugin import PngInfo -from cftool.cv import to_rgb -from cftool.cv import np_to_bytes -from cftool.web import raise_err -from cftool.misc import random_hash +from cfdraw.core.toolkit.cv import to_rgb +from cfdraw.core.toolkit.cv import np_to_bytes +from cfdraw.core.toolkit.web import raise_err +from cfdraw.core.toolkit.misc import random_hash from cfdraw.config import get_config diff --git a/cfdraw/utils/template.py b/cfdraw/utils/template.py index cd44a341..925d5165 100644 --- a/cfdraw/utils/template.py +++ b/cfdraw/utils/template.py @@ -1,9 +1,8 @@ from enum import Enum from pathlib import Path -from cftool.misc import print_info from cfdraw import constants -from cfdraw.utils import console +from cfdraw.core.toolkit import console IMAGE_APP_TEMPLATE = f""" @@ -152,6 +151,6 @@ def set_init_codes(folder: Path, template: TemplateType) -> None: return with config_path.open("w") as f: f.write(CONFIG_TEMPLATE) - print_info(f"App can be modified at {app_path}") - print_info(f"Config can be modified at {config_path}") + console.log(f"App can be modified at {app_path}") + console.log(f"Config can be modified at {config_path}") console.rule(f"[bold green]๐ŸŽ‰ You can launch the app with `cfdraw run` now! ๐ŸŽ‰") diff --git a/examples/caption_and_diffusion/advanced.py b/examples/caption_and_diffusion/advanced.py index 2ed8d903..17a874d6 100644 --- a/examples/caption_and_diffusion/advanced.py +++ b/examples/caption_and_diffusion/advanced.py @@ -3,7 +3,7 @@ from typing import List from cfdraw import * -from cftool.misc import shallow_copy_dict +from cfdraw.core.toolkit.misc import shallow_copy_dict @cache_resource diff --git a/examples/carefree_creator/app.py b/examples/carefree_creator/app.py index 16eab753..a0350fd2 100644 --- a/examples/carefree_creator/app.py +++ b/examples/carefree_creator/app.py @@ -10,12 +10,12 @@ from typing import Optional from pathlib import Path from pydantic import BaseModel -from cftool.misc import shallow_copy_dict from cfcreator.common import InpaintingMode from cflearn.misc.toolkit import new_seed from cfcreator.sdks.apis import ALL_LATENCIES_KEY from cfdraw import * +from cfdraw.core.toolkit.misc import shallow_copy_dict from utils import * from fields import * diff --git a/examples/carefree_creator/fields.py b/examples/carefree_creator/fields.py index 9613707b..cb89b0df 100644 --- a/examples/carefree_creator/fields.py +++ b/examples/carefree_creator/fields.py @@ -368,7 +368,9 @@ url=IImageField( default="", label=I18N(zh="ๅˆๅง‹ๅ›พ", en="Init Image"), - tooltip=I18N(zh="ๅฏ้€‰้กน๏ผŒไธ้€‰ไนŸๆฒก้—ฎ้ข˜", en="This is optional, you can leave it blank"), + tooltip=I18N( + zh="ๅฏ้€‰้กน๏ผŒไธ้€‰ไนŸๆฒก้—ฎ้ข˜", en="This is optional, you can leave it blank" + ), ), fidelity=fidelity, max_wh=max_wh_field, @@ -386,12 +388,16 @@ url=IImageField( default="", label=I18N(zh="ๅŽŸๅ›พ", en="Image"), - tooltip=I18N(zh="ๆƒณ่ฆ่ฟ›่กŒ้ฃŽๆ ผ่žๅˆ็š„ๅŽŸๅ›พ", en="The original image to be harmonized"), + tooltip=I18N( + zh="ๆƒณ่ฆ่ฟ›่กŒ้ฃŽๆ ผ่žๅˆ็š„ๅŽŸๅ›พ", en="The original image to be 
harmonized" + ), ), mask_url=IImageField( default="", label=I18N(zh="ๅ‰ๆ™ฏ", en="Foreground"), - tooltip=I18N(zh="ๆƒณ่ฆ่ฟ›่กŒ้ฃŽๆ ผ่žๅˆ็š„ๅ‰ๆ™ฏๅŒบๅŸŸ", en="The foreground area to be harmonized"), + tooltip=I18N( + zh="ๆƒณ่ฆ่ฟ›่กŒ้ฃŽๆ ผ่žๅˆ็š„ๅ‰ๆ™ฏๅŒบๅŸŸ", en="The foreground area to be harmonized" + ), ), strength=INumberField( default=1.0, diff --git a/examples/carefree_creator/utils.py b/examples/carefree_creator/utils.py index eb575610..b0177076 100644 --- a/examples/carefree_creator/utils.py +++ b/examples/carefree_creator/utils.py @@ -1,16 +1,17 @@ from typing import Any from typing import Dict -from cfdraw import cache_resource -from cfdraw import INodeData from collections import defaultdict -from cftool.misc import random_hash -from cftool.data_structures import WorkNode -from cftool.data_structures import Workflow -from cftool.data_structures import InjectionPack from cfcreator.workflow import * from cfcreator.endpoints import * from cfcreator.sdks.apis import * +from cfdraw import cache_resource +from cfdraw import INodeData +from cfdraw.schema.plugins import WorkNode +from cfdraw.schema.plugins import Workflow +from cfdraw.schema.plugins import InjectionPack +from cfdraw.core.toolkit.misc import random_hash + DATA_MODEL_KEY = "$data_model" diff --git a/examples/stable_diffusion/advanced.py b/examples/stable_diffusion/advanced.py index d596becd..5a141076 100644 --- a/examples/stable_diffusion/advanced.py +++ b/examples/stable_diffusion/advanced.py @@ -3,7 +3,7 @@ from typing import List from cfdraw import * -from cftool.misc import shallow_copy_dict +from cfdraw.core.toolkit.misc import shallow_copy_dict @cache_resource diff --git a/examples/stable_diffusion_controlnet/advanced.py b/examples/stable_diffusion_controlnet/advanced.py index 92c27047..aae8a31a 100644 --- a/examples/stable_diffusion_controlnet/advanced.py +++ b/examples/stable_diffusion_controlnet/advanced.py @@ -3,7 +3,7 @@ from typing import List from cfdraw import * -from cftool.misc import shallow_copy_dict +from cfdraw.core.toolkit.misc import shallow_copy_dict @cache_resource diff --git a/examples/stable_diffusion_inpainting/advanced.py b/examples/stable_diffusion_inpainting/advanced.py index 2cf9ce07..7db18b74 100644 --- a/examples/stable_diffusion_inpainting/advanced.py +++ b/examples/stable_diffusion_inpainting/advanced.py @@ -3,7 +3,7 @@ from typing import List from cfdraw import * -from cftool.misc import shallow_copy_dict +from cfdraw.core.toolkit.misc import shallow_copy_dict @cache_resource diff --git a/setup.py b/setup.py index af087961..a95e5a69 100644 --- a/setup.py +++ b/setup.py @@ -15,20 +15,30 @@ entry_points={"console_scripts": ["cfdraw = cfdraw.cli:cli"]}, install_requires=[ "rich", + "tqdm", + "regex", "typer", - "fastapi>=0.95.1", - "gunicorn", - "pydantic<2.0.0", + "future", + "pathos", + "pillow", + "psutil", + "aiohttp", "uvicorn", - "websockets", + "filelock", + "networkx", + "aiofiles", + "gunicorn", + "requests", "watchdog", + "matplotlib", + "websockets", + "safetensors", "python-multipart", - "carefree-toolkit>=0.3.6.3", - "pillow", - "aiohttp", + "numpy>=1.22.3", + "fastapi>=0.95.1", + "pydantic>=2.0.0", + "websockets>=12.0", "charset-normalizer==2.1.0", - "aiofiles", - "regex", ], author="carefree0910", author_email="syameimaru.saki@gmail.com", diff --git a/tests/test_timer.py b/tests/test_timer.py index fad75842..6308f085 100644 --- a/tests/test_timer.py +++ b/tests/test_timer.py @@ -1,7 +1,7 @@ from datetime import datetime -from cftool.constants import TIME_FORMAT from 
cfdraw import * +from cfdraw.core.toolkit.constants import TIME_FORMAT class TimerPlugin(ITextAreaPlugin):