Skip to content

Commit

Permalink
💬 merge check_mod and check_module
Browse files Browse the repository at this point in the history
  • Loading branch information
RF-Tar-Railt committed Sep 30, 2024
1 parent 49c4d7e commit 8b19fef
Showing 1 changed file with 22 additions and 24 deletions.
46 changes: 22 additions & 24 deletions arclet/entari/plugin/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,7 @@ def package(*names: str):
_SUBMODULE_WAITLIST.setdefault(plugin.module.__name__, set()).update(names)


def _check_mod(name, package=None):
module = import_plugin(name, package)
if not module:
raise ModuleNotFoundError(f"module {name!r} not found")
if hasattr(module, "__plugin__"):
if not package:
if name != module.__plugin__.id:
service._referents[name].add(module.__plugin__.id)
return module.__plugin__.proxy()
return module.__plugin__.subproxy(f"{package}{name}")
return module


def _check_import(name: str, plugin_name: str):
def __entari_import__(name: str, plugin_name: str, ensure_plugin: bool = False):
if name in service.plugins:
plug = service.plugins[name]
if plugin_name != plug.id:
Expand All @@ -46,6 +33,17 @@ def _check_import(name: str, plugin_name: str):
if plugin_name != mod.__plugin__.id:
service._referents[mod.__plugin__.id].add(plugin_name)
return mod.__plugin__.subproxy(name)
if ensure_plugin:
module = import_plugin(name, plugin_name)
if not module:
raise ModuleNotFoundError(f"module {name!r} not found")
if hasattr(module, "__plugin__"):
if not plugin_name:
if name != module.__plugin__.id:
service._referents[name].add(module.__plugin__.id)
return module.__plugin__.proxy()
return module.__plugin__.subproxy(f"{plugin_name}{name}")
return module
return __import__(name, fromlist=["__path__"])


Expand All @@ -67,7 +65,7 @@ def source_to_code(self, data, path, *, _optimize=-1): # type: ignore
if body.level == 0:
if len(body.names) == 1 and body.names[0].name == "*":
new = ast.parse(
f"__mod = __check_import({body.module!r}, {self.name!r});"
f"__mod = __entari_import__({body.module!r}, {self.name!r});"
f"__mod_all = getattr(__mod, '__all__', dir(__mod));"
"globals().update("
"{name: getattr(__mod, name) for name in __mod_all if not name.startswith('__')}"
Expand All @@ -76,7 +74,7 @@ def source_to_code(self, data, path, *, _optimize=-1): # type: ignore
)
else:
new = ast.parse(
f"__mod = __check_import({body.module!r}, {self.name!r});"
f"__mod = __entari_import__({body.module!r}, {self.name!r});"
f"{';'.join(f'{alias.asname or alias.name} = __mod.{alias.name}' for alias in body.names)};"
f"del __mod"
)
Expand All @@ -91,7 +89,8 @@ def source_to_code(self, data, path, *, _optimize=-1): # type: ignore
else:
new = ast.parse(
";".join(
f"{alias.asname or alias.name}=__check_mod('{relative}{alias.name}', {self.name!r})"
f"{alias.asname or alias.name}="
f"__entari_import__('{relative}{alias.name}', {self.name!r}, True)"
for alias in body.names
)
)
Expand All @@ -103,7 +102,7 @@ def source_to_code(self, data, path, *, _optimize=-1): # type: ignore
relative = "." * body.level
if len(body.names) == 1 and body.names[0].name == "*":
new = ast.parse(
f"__mod = __check_mod('{relative}{body.module}', {self.name!r});"
f"__mod = __entari_import__('{relative}{body.module}', {self.name!r}, True);"
f"__mod_all = getattr(__mod, '__all__', dir(__mod));"
"globals().update("
"{name: getattr(__mod, name) for name in __mod_all if not name.startswith('__')}"
Expand All @@ -112,7 +111,7 @@ def source_to_code(self, data, path, *, _optimize=-1): # type: ignore
)
else:
new = ast.parse(
f"__mod = __check_mod('{relative}{body.module}', {self.name!r});"
f"__mod = __entari_import__('{relative}{body.module}', {self.name!r}, True);"
f"{';'.join(f'{alias.asname or alias.name} = __mod.{alias.name}' for alias in body.names)};"
f"del __mod"
)
Expand All @@ -125,7 +124,7 @@ def source_to_code(self, data, path, *, _optimize=-1): # type: ignore
new = ast.parse(
",".join(aliases)
+ "="
+ ",".join(f"__check_import({alias.name!r}, {self.name!r})" for alias in body.names)
+ ",".join(f"__entari_import__({alias.name!r}, {self.name!r})" for alias in body.names)
)
for node in ast.walk(new):
node.lineno = body.lineno # type: ignore
Expand All @@ -134,6 +133,7 @@ def source_to_code(self, data, path, *, _optimize=-1): # type: ignore
else:
bodys.append(body)
nodes.body = bodys
print(ast.unparse(nodes))
return _bootstrap._call_with_frames_removed(compile, nodes, path, "exec", dont_inherit=True, optimize=_optimize)

def create_module(self, spec) -> Optional[ModuleType]:
Expand All @@ -147,8 +147,7 @@ def exec_module(self, module: ModuleType) -> None:
if module.__name__ == plugin.module.__name__: # from . import xxxx
return
setattr(module, "__plugin__", plugin)
setattr(module, "__check_mod", _check_mod)
setattr(module, "__check_import", _check_import)
setattr(module, "__entari_import__", __entari_import__)
try:
super().exec_module(module)
except Exception:
Expand All @@ -165,8 +164,7 @@ def exec_module(self, module: ModuleType) -> None:
# create plugin before executing
plugin = Plugin(module.__name__, module)
setattr(module, "__plugin__", plugin)
setattr(module, "__check_mod", _check_mod)
setattr(module, "__check_import", _check_import)
setattr(module, "__entari_import__", __entari_import__)

# enter plugin context
_plugin_token = _current_plugin.set(plugin)
Expand Down

0 comments on commit 8b19fef

Please sign in to comment.