Skip to content

Commit

Permalink
🔥 replace ProxyModule with weakref.proxy
Browse files Browse the repository at this point in the history
  • Loading branch information
RF-Tar-Railt committed Oct 1, 2024
1 parent 454e96a commit 05cb10e
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 131 deletions.
20 changes: 11 additions & 9 deletions arclet/entari/plugin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .model import keeping as keeping
from .module import import_plugin
from .module import package as package
from .module import requires as requires
from .service import plugin_service

if TYPE_CHECKING:
Expand All @@ -37,8 +38,8 @@ def load_plugin(path: str, config: dict | None = None, recursive_guard: set[str]
if recursive_guard is None:
recursive_guard = set()
path = path.replace("::", "arclet.entari.plugins.")
if path in plugin_service._submoded:
logger.error(f"plugin {path!r} is already defined as submodule of {plugin_service._submoded[path]!r}")
if path in plugin_service._subplugined:
logger.error(f"plugin {path!r} is already defined as submodule of {plugin_service._subplugined[path]!r}")
return
if path in plugin_service.plugins:
return plugin_service.plugins[path]
Expand Down Expand Up @@ -98,8 +99,8 @@ def metadata(data: PluginMetadata):
def find_plugin(name: str) -> Plugin | None:
if name in plugin_service.plugins:
return plugin_service.plugins[name]
if name in plugin_service._submoded:
return plugin_service.plugins[plugin_service._submoded[name]]
if name in plugin_service._subplugined:
return plugin_service.plugins[plugin_service._subplugined[name]]
return None


Expand All @@ -110,11 +111,12 @@ def find_plugin_by_file(file: str) -> Plugin | None:
return plugin
if plugin.module.__file__ and Path(plugin.module.__file__).parent == path:
return plugin
for submod in plugin.submodules.values():
if submod.__file__ == str(path):
return plugin
if submod.__file__ and Path(submod.__file__).parent == path:
return plugin
for subplug in plugin.subplugins:
if plug := plugin_service.plugins.get(subplug):
if plug.module.__file__ == str(path):
return plugin
if plug.module.__file__ and Path(plug.module.__file__).parent == path:
return plugin
path1 = Path(path)
while path1.parent != path1:
if str(path1) == plugin.module.__file__:
Expand Down
103 changes: 15 additions & 88 deletions arclet/entari/plugin/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,18 @@
from collections.abc import Awaitable
from contextvars import ContextVar
from dataclasses import dataclass, field
import inspect
from pathlib import Path
import sys
from types import ModuleType
from typing import TYPE_CHECKING, Any, Callable, TypeVar
from weakref import finalize, ref
from weakref import finalize, proxy

from arclet.letoderea import BaseAuxiliary, Provider, Publisher, StepOut, system_ctx
from arclet.letoderea.builtin.breakpoint import R
from arclet.letoderea.typing import TTarget
from creart import it
from launart import Launart, Service
from loguru import logger
from satori.client import Account

from .service import plugin_service
Expand Down Expand Up @@ -114,7 +114,7 @@ class Plugin:
id: str
module: ModuleType
dispatchers: dict[str, PluginDispatcher] = field(default_factory=dict)
submodules: dict[str, ModuleType] = field(default_factory=dict)
subplugins: set[str] = field(default_factory=set)
config: dict[str, Any] = field(default_factory=dict)
_metadata: PluginMetadata | None = None
_is_disposed: bool = False
Expand Down Expand Up @@ -170,13 +170,16 @@ def dispose(self):
Path(self.module.__spec__.cached).unlink(missing_ok=True)
sys.modules.pop(self.module.__name__, None)
delattr(self.module, "__plugin__")
for submod in self.submodules.values():
delattr(submod, "__plugin__")
sys.modules.pop(submod.__name__, None)
plugin_service._submoded.pop(submod.__name__, None)
if submod.__spec__ and submod.__spec__.cached:
Path(submod.__spec__.cached).unlink(missing_ok=True)
self.submodules.clear()
for subplug in self.subplugins:
if subplug not in plugin_service.plugins:
continue
logger.debug(f"disposing sub-plugin {subplug} of {self.id}")
try:
plugin_service.plugins[subplug].dispose()
except Exception as e:
logger.error(f"failed to dispose sub-plugin {subplug} caused by {e!r}")
plugin_service.plugins.pop(subplug, None)
self.subplugins.clear()
for disp in self.dispatchers.values():
disp.dispose()
self.dispatchers.clear()
Expand Down Expand Up @@ -208,10 +211,10 @@ def validate(self, func):
)

def proxy(self):
return _ProxyModule(self.id)
return proxy(self.module)

def subproxy(self, sub_id: str):
return _ProxyModule(self.id, sub_id)
return proxy(plugin_service.plugins[sub_id].module)

def service(self, serv: Service | type[Service]):
if isinstance(serv, type):
Expand Down Expand Up @@ -246,79 +249,3 @@ def keeping(id_: str, obj: T, dispose: Callable[[T], None] | None = None) -> T:
else:
obj = plugin_service._keep_values[plug.id][id_].obj # type: ignore
return obj


class _ProxyModule(ModuleType):

def __get_module(self) -> ModuleType:
mod = self.__origin()
if not mod:
raise NameError(f"Plugin {self.__plugin_id!r} is not loaded")
return mod

def __init__(self, plugin_id: str, sub_id: str | None = None) -> None:
self.__plugin_id = plugin_id
self.__sub_id = sub_id
if self.__plugin_id not in plugin_service.plugins:
raise NameError(f"Plugin {self.__plugin_id!r} is not loaded")
if self.__sub_id:
self.__origin = ref(plugin_service.plugins[self.__plugin_id].submodules[self.__sub_id])
else:
self.__origin = ref(plugin_service.plugins[self.__plugin_id].module)
super().__init__(self.__get_module().__name__)
self.__doc__ = self.__get_module().__doc__
self.__file__ = self.__get_module().__file__
self.__loader__ = self.__get_module().__loader__
self.__package__ = self.__get_module().__package__
if path := getattr(self.__get_module(), "__path__", None):
self.__path__ = path
self.__spec__ = self.__get_module().__spec__

def __repr__(self):
if self.__sub_id:
return f"<ProxyModule {self.__sub_id!r}>"
return f"<ProxyModule {self.__plugin_id!r}>"

@property
def __dict__(self) -> dict[str, Any]:
return self.__get_module().__dict__

def __getattr__(self, name: str):
if name in (
"_ProxyModule__plugin_id",
"_ProxyModule__sub_id",
"_ProxyModule__origin",
"__name__",
"__doc__",
"__file__",
"__loader__",
"__package__",
"__path__",
"__spec__",
):
return super().__getattribute__(name)
if self.__plugin_id not in plugin_service.plugins:
raise NameError(f"Plugin {self.__plugin_id!r} is not loaded")
if plug := inspect.currentframe().f_back.f_globals.get("__plugin__"): # type: ignore
if plug.id != self.__plugin_id:
plugin_service._referents[self.__plugin_id].add(plug.id)
elif plug := inspect.currentframe().f_back.f_back.f_globals.get("__plugin__"): # type: ignore
if plug.id != self.__plugin_id:
plugin_service._referents[self.__plugin_id].add(plug.id)
return getattr(self.__get_module(), name)

def __setattr__(self, name: str, value):
if name in (
"_ProxyModule__plugin_id",
"_ProxyModule__sub_id",
"_ProxyModule__origin",
"__name__",
"__doc__",
"__file__",
"__loader__",
"__package__",
"__path__",
"__spec__",
):
return super().__setattr__(name, value)
setattr(self.__get_module(), name, value)
91 changes: 62 additions & 29 deletions arclet/entari/plugin/module.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import ast
from collections.abc import Sequence
from importlib import _bootstrap # type: ignore
from importlib import import_module
from importlib.abc import MetaPathFinder
from importlib.machinery import PathFinder, SourceFileLoader
from importlib.machinery import ExtensionFileLoader, PathFinder, SourceFileLoader
from importlib.util import module_from_spec, resolve_name
import sys
from types import ModuleType
Expand All @@ -12,6 +13,7 @@
from .service import plugin_service

_SUBMODULE_WAITLIST: dict[str, set[str]] = {}
_ENSURE_IS_PLUGIN: set[str] = set()


def package(*names: str):
Expand All @@ -21,6 +23,11 @@ def package(*names: str):
_SUBMODULE_WAITLIST.setdefault(plugin.module.__name__, set()).update(names)


def requires(*names: str):
"""手动指定哪些模块是插件"""
_ENSURE_IS_PLUGIN.update(names)


def __entari_import__(name: str, plugin_name: str, ensure_plugin: bool = False):
if name in plugin_service.plugins:
plug = plugin_service.plugins[name]
Expand All @@ -33,18 +40,33 @@ def __entari_import__(name: str, plugin_name: str, ensure_plugin: bool = False):
if plugin_name != mod.__plugin__.id:
plugin_service._referents[mod.__plugin__.id].add(plugin_name)
return mod.__plugin__.subproxy(name)
return __import__(name, fromlist=["__path__"])
if name in _ENSURE_IS_PLUGIN:
mod = import_plugin(name)
if mod:
if plugin_name != mod.__plugin__.id:
plugin_service._referents[mod.__plugin__.id].add(plugin_name)
return mod.__plugin__.proxy()
return __import__(name, fromlist=["__path__"])
if ensure_plugin:
module = import_plugin(name, plugin_name)
if not module:
raise ModuleNotFoundError(f"module {name!r} not found")
if hasattr(module, "__plugin__"):
if not plugin_name:
if name != module.__plugin__.id:
plugin_service._referents[name].add(module.__plugin__.id)
return module.__plugin__.proxy()
if plugin_name != module.__plugin__.id:
plugin_service._referents[module.__plugin__.id].add(plugin_name)
return module.__plugin__.subproxy(f"{plugin_name}{name}")
return module
return __import__(name, fromlist=["__path__"])
if not name.startswith("."):
return __import__(name, fromlist=["__path__"])
return import_module(f".{name}", plugin_name)


def getattr_or_import(module, name):
try:
return getattr(module, name)
except AttributeError:
return __entari_import__(f".{name}", module.__name__, True)


class PluginLoader(SourceFileLoader):
Expand All @@ -58,11 +80,20 @@ def source_to_code(self, data, path, *, _optimize=-1): # type: ignore
The 'data' argument can be any object type that compile() supports.
"""
nodes = ast.parse(data, type_comments=True)
try:
nodes = ast.parse(data, type_comments=True)
except SyntaxError:
return _bootstrap._call_with_frames_removed(
compile, data, path, "exec", dont_inherit=True, optimize=_optimize
)
bodys = []
for body in nodes.body:
if isinstance(body, ast.ImportFrom):
if body.level == 0:
if body.module == "__future__":
print(1)
bodys.append(body)
continue
if len(body.names) == 1 and body.names[0].name == "*":
new = ast.parse(
f"__mod = __entari_import__({body.module!r}, {self.name!r});"
Expand All @@ -75,8 +106,11 @@ def source_to_code(self, data, path, *, _optimize=-1): # type: ignore
else:
new = ast.parse(
f"__mod = __entari_import__({body.module!r}, {self.name!r});"
f"{';'.join(f'{alias.asname or alias.name} = __mod.{alias.name}' for alias in body.names)};"
f"del __mod"
+ ";".join(
f"{alias.asname or alias.name} = __getattr_or_import__(__mod, {alias.name!r})"
for alias in body.names
)
+ ";del __mod"
)
for node in ast.walk(new):
node.lineno = body.lineno # type: ignore
Expand All @@ -90,7 +124,7 @@ def source_to_code(self, data, path, *, _optimize=-1): # type: ignore
new = ast.parse(
";".join(
f"{alias.asname or alias.name}="
f"__entari_import__('{relative}{alias.name}', {self.name!r}, True)"
f"__entari_import__('{relative}{alias.name}', {self.name!r}, {body.level == 1})"
for alias in body.names
)
)
Expand All @@ -102,7 +136,7 @@ def source_to_code(self, data, path, *, _optimize=-1): # type: ignore
relative = "." * body.level
if len(body.names) == 1 and body.names[0].name == "*":
new = ast.parse(
f"__mod = __entari_import__('{relative}{body.module}', {self.name!r}, True);"
f"__mod = __entari_import__('{relative}{body.module}', {self.name!r}, {body.level == 1});"
f"__mod_all = getattr(__mod, '__all__', dir(__mod));"
"globals().update("
"{name: getattr(__mod, name) for name in __mod_all if not name.startswith('__')}"
Expand All @@ -111,9 +145,12 @@ def source_to_code(self, data, path, *, _optimize=-1): # type: ignore
)
else:
new = ast.parse(
f"__mod = __entari_import__('{relative}{body.module}', {self.name!r}, True);"
f"{';'.join(f'{alias.asname or alias.name} = __mod.{alias.name}' for alias in body.names)};"
f"del __mod"
f"__mod = __entari_import__('{relative}{body.module}', {self.name!r}, {body.level == 1});"
+ ";".join(
f"{alias.asname or alias.name} = __getattr_or_import__(__mod, {alias.name!r})"
for alias in body.names
)
+ ";del __mod"
)
for node in ast.walk(new):
node.lineno = body.lineno # type: ignore
Expand Down Expand Up @@ -143,19 +180,8 @@ def create_module(self, spec) -> Optional[ModuleType]:

def exec_module(self, module: ModuleType, config: Optional[dict[str, str]] = None) -> None:
if plugin := plugin_service.plugins.get(self.parent_plugin_id) if self.parent_plugin_id else None:
if module.__name__ == plugin.module.__name__: # from . import xxxx
return
setattr(module, "__plugin__", plugin)
setattr(module, "__entari_import__", __entari_import__)
try:
super().exec_module(module)
except Exception:
delattr(module, "__plugin__")
raise
else:
plugin.submodules[module.__name__] = module
plugin_service._submoded[module.__name__] = plugin.id
return
plugin.subplugins.add(module.__name__)
plugin_service._subplugined[module.__name__] = plugin.id

if self.loaded:
return
Expand All @@ -164,6 +190,7 @@ def exec_module(self, module: ModuleType, config: Optional[dict[str, str]] = Non
plugin = Plugin(module.__name__, module, config=config or {})
setattr(module, "__plugin__", plugin)
setattr(module, "__entari_import__", __entari_import__)
setattr(module, "__getattr_or_import__", getattr_or_import)

# enter plugin context
_plugin_token = _current_plugin.set(plugin)
Expand Down Expand Up @@ -198,6 +225,8 @@ def find_spec(
module_origin = module_spec.origin
if not module_origin:
return
if isinstance(module_spec.loader, ExtensionFileLoader):
return
if plug := _current_plugin.get(None):
if plug.module.__spec__ and plug.module.__spec__.origin == module_spec.origin:
return plug.module.__spec__
Expand All @@ -212,8 +241,11 @@ def find_spec(
if module_spec.name in plugin_service.plugins:
module_spec.loader = PluginLoader(fullname, module_origin)
return module_spec
if module_spec.name in plugin_service._submoded:
module_spec.loader = PluginLoader(fullname, module_origin, plugin_service._submoded[module_spec.name])
if module_spec.name in plugin_service._subplugined:
module_spec.loader = PluginLoader(fullname, module_origin, plugin_service._subplugined[module_spec.name])
return module_spec
if module_spec.parent and module_spec.parent in plugin_service.plugins:
module_spec.loader = PluginLoader(fullname, module_origin, module_spec.parent)
return module_spec
return

Expand All @@ -224,6 +256,7 @@ def find_spec(name, package=None):
if parent_name:
if parent_name in plugin_service.plugins:
parent = plugin_service.plugins[parent_name].module

else:
parent = __import__(parent_name, fromlist=["__path__"])
try:
Expand Down
Loading

0 comments on commit 05cb10e

Please sign in to comment.