Skip to content

Commit

Permalink
❇️ detect config changes
Browse files Browse the repository at this point in the history
  • Loading branch information
RF-Tar-Railt committed Dec 5, 2024
1 parent 503e76e commit e8b04ff
Show file tree
Hide file tree
Showing 11 changed files with 216 additions and 52 deletions.
104 changes: 95 additions & 9 deletions arclet/entari/builtins/auto_reload.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@
from pathlib import Path
from typing import Union

from arclet.letoderea import es
from launart import Launart, Service, any_completed
from launart.status import Phase
from watchfiles import PythonFilter, awatch

from arclet.entari import Plugin, dispose_plugin, load_plugin, metadata
from arclet.entari.config import EntariConfig
from arclet.entari.event.config import ConfigReload
from arclet.entari.logger import log
from arclet.entari.plugin import find_plugin_by_file
from arclet.entari.plugin import find_plugin, find_plugin_by_file

metadata(
"AutoReload",
Expand All @@ -31,44 +34,127 @@ def required(self) -> set[str]:
def stages(self) -> set[Phase]:
return {"blocking"}

def __init__(self, dirs: list[Union[str, Path]]):
def __init__(self, dirs: list[Union[str, Path]], is_watch_config: bool):
self.dirs = dirs
self.is_watch_config = is_watch_config
self.fail = {}
super().__init__()

async def watch(self):
async for event in awatch(*self.dirs, watch_filter=PythonFilter()):
for change in event:
if plugin := find_plugin_by_file(change[1]):
logger("INFO", f"Detected change in {plugin.id}, reloading...")
if plugin.id == "arclet.entari.builtins.auto_reload":
logger("DEBUG", f"Detected change in <blue>{plugin.id!r}</blue>, ignored")
continue
logger("INFO", f"Detected change in <blue>{plugin.id!r}</blue>, reloading...")
pid = plugin.id
del plugin
dispose_plugin(pid)
if plugin := load_plugin(pid):
logger("INFO", f"Reloaded {plugin.id}")
logger("INFO", f"Reloaded <blue>{plugin.id!r}</blue>")
del plugin
else:
logger("ERROR", f"Failed to reload {pid}")
logger("ERROR", f"Failed to reload <blue>{pid!r}</blue>")
self.fail[change[1]] = pid
elif change[1] in self.fail:
logger("INFO", f"Detected change in {change[1]} which failed to reload, retrying...")
logger("INFO", f"Detected change in {change[1]!r} which failed to reload, retrying...")
if plugin := load_plugin(self.fail[change[1]]):
logger("INFO", f"Reloaded {plugin.id}")
logger("INFO", f"Reloaded <blue>{plugin.id!r}</blue>")
del self.fail[change[1]]
else:
logger("ERROR", f"Failed to reload {self.fail[change[1]]}")
logger("ERROR", f"Failed to reload <blue>{self.fail[change[1]]!r}</blue>")

async def watch_config(self):
file = EntariConfig.instance.path
async for event in awatch(file.resolve().absolute().parent, recursive=False):
for change in event:
if change[0].name != "modified":
continue
if Path(change[1]).resolve().name != file.name:
continue
if not self.is_watch_config:
continue
logger("INFO", f"Detected change in {str(file)!r}, reloading config...")

old_basic = EntariConfig.instance.basic.copy()
old_plugin = EntariConfig.instance.plugin.copy()
EntariConfig.instance.reload()
for key in old_basic:
if key in EntariConfig.instance.basic and old_basic[key] != EntariConfig.instance.basic[key]:
logger(
"DEBUG",
f"Basic config <y>{key!r}</y> changed from <r>{old_basic[key]!r}</r> "
f"to <g>{EntariConfig.instance.basic[key]!r}</g>",
)
await es.publish(ConfigReload("basic", key, EntariConfig.instance.basic[key]))
for plugin_name in old_plugin:
pid = plugin_name.replace("::", "arclet.entari.builtins.")
if (
plugin_name not in EntariConfig.instance.plugin
or EntariConfig.instance.plugin[plugin_name] is False
):
dispose_plugin(pid)
logger("INFO", f"Disposed plugin <blue>{pid!r}</blue>")
continue
if old_plugin[plugin_name] != EntariConfig.instance.plugin[plugin_name]:
logger(
"DEBUG",
f"Plugin <y>{plugin_name!r}</y> config changed from <r>{old_plugin[plugin_name]!r}</r> "
f"to <g>{EntariConfig.instance.plugin[plugin_name]!r}</g>",
)
res = await es.post(
ConfigReload("plugin", plugin_name, EntariConfig.instance.plugin[plugin_name])
)
if res and res.value:
logger("DEBUG", f"Plugin <y>{pid!r}</y> config change handled by itself.")
continue
if plugin := find_plugin(pid):
logger("INFO", f"Detected <blue>{pid!r}</blue>'s config change, reloading...")
plugin_file = str(plugin.module.__file__)
dispose_plugin(plugin_name)
if plugin := load_plugin(plugin_name):
logger("INFO", f"Reloaded <blue>{plugin.id!r}</blue>")
del plugin
else:
logger("ERROR", f"Failed to reload <blue>{plugin_name!r}</blue>")
self.fail[plugin_file] = pid
else:
logger("INFO", f"Detected <blue>{pid!r}</blue> appended, loading...")
load_plugin(plugin_name)
if new := (set(EntariConfig.instance.plugin) - set(old_plugin)):
for plugin_name in new:
load_plugin(plugin_name)

async def launch(self, manager: Launart):
async with self.stage("blocking"):
watch_task = asyncio.create_task(self.watch())
sigexit_task = asyncio.create_task(manager.status.wait_for_sigexit())
watch_config_task = asyncio.create_task(self.watch_config())
done, pending = await any_completed(
sigexit_task,
watch_task,
watch_config_task,
)
if sigexit_task in done:
watch_task.cancel()
watch_config_task.cancel()
self.fail.clear()


plug = Plugin.current()
plug.service(Watcher(plug.config.get("watch_dirs", ["."])))
watch_dirs = plug.config.get("watch_dirs", ["."])
watch_config = plug.config.get("watch_config", False)

plug.service(serv := Watcher(watch_dirs, watch_config))


@plug.use(ConfigReload)
def handle_config_reload(event: ConfigReload):
if event.scope != "plugin":
return
if event.key not in ("::auto_reload", "arclet.entari.builtins.auto_reload"):
return
serv.dirs = event.value.get("watch_dirs", ["."])
serv.is_watch_config = event.value.get("watch_config", False)
return True
37 changes: 25 additions & 12 deletions arclet/entari/command/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import asyncio
from typing import Any, Callable, Optional, TypeVar, Union, cast, overload
from typing import Callable, Optional, TypeVar, Union, cast, overload

from arclet.alconna import Alconna, Arg, Args, Arparma, CommandMeta, Namespace, command_manager, config, output_manager
from arclet.alconna.tools.construct import AlconnaString, alconna_from_format
from arclet.alconna.typing import TAValue
from arclet.letoderea import BaseAuxiliary, Provider, Publisher, Scope, Subscriber, es
from arclet.letoderea.handler import depend_handler
from arclet.letoderea.provider import ProviderFactory
from arclet.letoderea.typing import TTarget
from nepattern import DirectPattern
from satori.element import At, Text
from tarina.string import split
Expand All @@ -21,7 +22,6 @@
from .provider import AlconnaProviderFactory, AlconnaSuppiler, MessageJudger, get_cmd

T = TypeVar("T")
TCallable = TypeVar("TCallable", bound=Callable[..., Any])


class EntariCommands:
Expand All @@ -30,7 +30,7 @@ class EntariCommands:
def __init__(self, need_tome: bool = False, remove_tome: bool = True):
self.trie: CharTrie[Subscriber] = CharTrie()
self.publisher = Publisher("entari.command", MessageCreatedEvent)
self.publisher.providers.append(AlconnaProviderFactory())
self.publisher.bind(AlconnaProviderFactory())
self.need_tome = need_tome
self.remove_tome = remove_tome
config.namespaces["Entari"] = Namespace(
Expand All @@ -39,7 +39,7 @@ def __init__(self, need_tome: bool = False, remove_tome: bool = True):
converter=lambda x: MessageChain(x),
)

@self.publisher.register(auxiliaries=[MessageJudger()])
@es.on(MessageCreatedEvent, auxiliaries=[MessageJudger()])
async def listener(event: MessageCreatedEvent):
msg = str(event.content.exclude(At)).lstrip()
if not msg:
Expand Down Expand Up @@ -108,7 +108,7 @@ def command(
providers: Optional[list[Union[Provider, type[Provider], ProviderFactory, type[ProviderFactory]]]] = None,
):
class Command(AlconnaString):
def __call__(_cmd_self, func: TCallable) -> TCallable:
def __call__(_cmd_self, func: TTarget[T]) -> Subscriber[T]:
return self.on(_cmd_self.build(), need_tome, remove_tome, auxiliaries, providers)(func)

return Command(command, help_text)
Expand All @@ -121,7 +121,7 @@ def on(
remove_tome: bool = True,
auxiliaries: Optional[list[BaseAuxiliary]] = None,
providers: Optional[list[Union[Provider, type[Provider], ProviderFactory, type[ProviderFactory]]]] = None,
) -> Callable[[TCallable], TCallable]: ...
) -> Callable[[TTarget[T]], Subscriber[T]]: ...

@overload
def on(
Expand All @@ -134,7 +134,7 @@ def on(
*,
args: Optional[dict[str, Union[TAValue, Args, Arg]]] = None,
meta: Optional[CommandMeta] = None,
) -> Callable[[TCallable], TCallable]: ...
) -> Callable[[TTarget[T]], Subscriber[T]]: ...

def on(
self,
Expand All @@ -146,11 +146,11 @@ def on(
*,
args: Optional[dict[str, Union[TAValue, Args, Arg]]] = None,
meta: Optional[CommandMeta] = None,
) -> Callable[[TCallable], TCallable]:
) -> Callable[[TTarget[T]], Subscriber[T]]:
auxiliaries = auxiliaries or []
providers = providers or []

def wrapper(func: TCallable) -> TCallable:
def wrapper(func: TTarget[T]) -> Subscriber[T]:
if isinstance(command, str):
mapping = {arg.name: arg.value for arg in Args.from_callable(func)[0]}
mapping.update(args or {}) # type: ignore
Expand All @@ -165,6 +165,11 @@ def wrapper(func: TCallable) -> TCallable:
target = self.publisher.register(auxiliaries=auxiliaries, providers=providers)(func)
self.publisher.remove_subscriber(target)
self.trie[key] = target

def _remove(_):
self.trie.pop(key, None) # type: ignore

target._dispose = _remove
else:
auxiliaries.insert(
0, AlconnaSuppiler(command, need_tome or self.need_tome, remove_tome or self.remove_tome)
Expand All @@ -173,16 +178,24 @@ def wrapper(func: TCallable) -> TCallable:
self.publisher.remove_subscriber(target)
if not isinstance(command.command, str):
raise TypeError("Command name must be a string.")
keys = []
if not command.prefixes:
self.trie[command.command] = target
keys.append(command.command)
elif not all(isinstance(i, str) for i in command.prefixes):
raise TypeError("Command prefixes must be a list of string.")
else:
self.publisher.remove_subscriber(target)
for prefix in cast(list[str], command.prefixes):
self.trie[prefix + command.command] = target
keys.append(prefix + command.command)

def _remove(_):
for key in keys:
self.trie.pop(key, None) # type: ignore

target._dispose = _remove
command.reset_namespace(self.__namespace__)
return func
return target

return wrapper

Expand All @@ -200,7 +213,7 @@ def config_commands(need_tome: bool = False, remove_tome: bool = True):


async def execute(message: Union[str, MessageChain]):
res = await es.post(CommandExecute(message), CommandExecute.__disp_name__)
res = await es.post(CommandExecute(message))
if res:
return res.value

Expand Down
71 changes: 56 additions & 15 deletions arclet/entari/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from .command import _commands
from .config import EntariConfig
from .event.config import ConfigReload
from .event.protocol import MessageCreatedEvent, event_parse
from .event.send import SendResponse
from .logger import log
Expand Down Expand Up @@ -50,6 +51,29 @@ async def __call__(self, context: Contexts):
global_providers.extend([ApiProtocolProvider(), SessionProvider(), AccountProvider()])


def record():
@es.on(MessageCreatedEvent, priority=0)
async def log_msg(event: MessageCreatedEvent):
log.message.info(
f"[{event.channel.name or event.channel.id}] "
f"{event.member.nick if event.member else (event.user.name or event.user.id)}"
f"({event.user.id}) -> {event.message.content!r}"
)

@es.use(SendResponse.__publisher__)
async def log_send(event: SendResponse):
if event.session:
log.message.info(f"[{event.session.channel.name or event.session.channel.id}] <- {event.message!r}")
else:
log.message.info(f"[{event.channel}] <- {event.message!r}")

def dispose():
log_msg.dispose()
log_send.dispose()

return dispose


class Entari(App):
id = "entari.service"

Expand Down Expand Up @@ -88,6 +112,7 @@ def __init__(
if not hasattr(EntariConfig, "instance"):
EntariConfig.load()
log.set_level(log_level)
log.core.opt(colors=True).debug(f"Log level set to <y><c>{log_level}</c></y>")
for plug in EntariConfig.instance.plugin:
load_plugin(plug)
self.ignore_self_message = ignore_self_message
Expand All @@ -97,21 +122,37 @@ def __init__(
self._ref_tasks = set()

if record_message:

@self.on_message(priority=0)
async def log_msg(event: MessageCreatedEvent):
log.message.info(
f"[{event.channel.name or event.channel.id}] "
f"{event.member.nick if event.member else (event.user.name or event.user.id)}"
f"({event.user.id}) -> {event.message.content!r}"
)

@es.use(SendResponse.__publisher__)
async def log_send(event: SendResponse):
if event.session:
log.message.info(f"[{event.session.channel.name or event.session.channel.id}] <- {event.message!r}")
else:
log.message.info(f"[{event.channel}] <- {event.message!r}")
dispose = record()
else:
dispose = None

@es.on(ConfigReload)
def reset_self(scope, key, value):
nonlocal dispose
if scope != "basic":
return
if key == "log_level":
log.set_level(value)
log.core.opt(colors=True).debug(f"Log level set to <y><c>{value}</c></y>")
elif key == "ignore_self_message":
self.ignore_self_message = value
elif key == "record_message":
if value and not dispose:
dispose = record()
elif not value and dispose:
dispose()
dispose = None
elif key == "network":
for conn in self.connections:
it(Launart).remove_component(conn)
self.connections.clear()
for conf in value:
if conf["type"] in ("websocket", "websockets", "ws"):
self.apply(WebsocketsInfo(**{k: v for k, v in conf.items() if k != "type"}))
elif conf["type"] in ("webhook", "wh", "http"):
self.apply(WebhookInfo(**{k: v for k, v in conf.items() if k != "type"}))
for conn in self.connections:
it(Launart).add_component(conn)

def on(
self,
Expand Down
4 changes: 2 additions & 2 deletions arclet/entari/event/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@

from typing import Callable, TypeVar

from ..plugin import dispatch

TE = TypeVar("TE", bound="BasedEvent")


class BasedEvent:
@classmethod
def dispatch(cls: type[TE], predicate: Callable[[TE], bool] | None = None, name: str | None = None):
from ..plugin import dispatch

name = name or getattr(cls, "__publisher__", None)
return dispatch(cls, predicate=predicate, name=name) # type: ignore
Loading

0 comments on commit e8b04ff

Please sign in to comment.