diff --git a/.env.example b/.env.example index 3637778..9f6bcf8 100644 --- a/.env.example +++ b/.env.example @@ -1,6 +1,8 @@ BOT_TOKEN=... PB_EMAIL=... PB_PASSWORD=... +LAVALINK_URL=... GOOGLE_API_KEY=... LAVALINK_PASSWORD=... +OLLAMA_URL=... TZ=... diff --git a/bot/main.py b/bot/main.py index a61582f..0743b1b 100644 --- a/bot/main.py +++ b/bot/main.py @@ -10,6 +10,7 @@ from asyncio import sleep from datetime import UTC, datetime, time, timedelta from random import choice, randint +from typing import List from zoneinfo import ZoneInfo import coloredlogs @@ -20,6 +21,7 @@ from discord.commands import option from discord.ext import commands, tasks from dotenv import load_dotenv +from ollama import AsyncClient from pocketbase import PocketBaseError # type: ignore from pb import PB, pb_login from ui.message import StoreMessage @@ -35,6 +37,8 @@ TZ = ZoneInfo(os.getenv("TZ") or "Europe/Amsterdam") +ollama = AsyncClient(host=os.getenv("OLLAMA_URL") or "http://ai:11434") + class LogFilter(logging.Filter): def filter(self, record: logging.LogRecord): @@ -99,13 +103,31 @@ def filter(self, record: logging.LogRecord): logger.info("Found %s upcoming holidays", len(holidays)) +async def download_ai_models(models: List[str]): + """Download the AI models from the Ollama server.""" + downloaded_models = await ollama.list() + for model in models.copy(): + if any( + m["name"].replace(":latest", "") == model + for m in downloaded_models["models"] + ): + models.remove(model) + if not models: + return + logger.debug("Downloading %s AI models (%s)", len(models), ", ".join(models)) + for model in models: + logger.debug("Downloading %s...", model) + await ollama.pull(model=model) + + async def connect_nodes(): """Connect to our Lavalink nodes.""" await bert.wait_until_ready() nodes = [ wavelink.Node( - uri="http://lavalink:2333", password=os.getenv("LAVALINK_PASSWORD") + uri=os.getenv("LAVALINK_URL") or "http://lavalink:2333", + password=os.getenv("LAVALINK_PASSWORD"), ) ] await 
wavelink.Pool.connect(nodes=nodes, client=bert) @@ -209,6 +231,102 @@ async def on_message(message: discord.Message): await message.delete() return + if message.channel.name == "bert-ai": + available_models = [ + model["name"].split(":")[0] for model in (await ollama.list())["models"] + ] + model = "llama2-uncensored" + + if message.content == "bert clear": + await message.channel.send("Understood. ||bert-ignore||") + return + + if message.content.startswith("bert model"): + if len(message.content.split(" ")) == 2: + history = await message.channel.history( + limit=100, + before=message.created_at, + after=message.created_at - timedelta(minutes=10), + ).flatten() + history.reverse() + for msg in history: + if "bert-ignore" not in msg.content and not msg.author.bot: + if msg.content == "bert clear": + break + if ( + msg.content.startswith("bert model ") + and msg.content.split(" ")[2] in available_models + ): + model = msg.content.split(" ")[2] + break + await message.channel.send(f"Current model is {model}. ||bert-ignore||") + else: + if (model := message.content.split(" ")[2]) in available_models: + await message.channel.send(f"Model set to {model}. ||bert-ignore||") + else: + await message.channel.send( + f"Model {model} is not available. 
||bert-ignore||" + ) + return + + async with message.channel.typing(): + images = [] + for sticker in message.stickers: + if sticker.format.name in ("png", "apng"): + images.append(await sticker.read()) + for attachment in message.attachments: + if attachment.content_type and attachment.content_type.startswith("image"): + images.append(await attachment.read()) + + messages = [] + history = await message.channel.history( + limit=100, + before=message.created_at, + after=message.created_at - timedelta(minutes=10), + oldest_first=True, + ).flatten() + for msg in history: + if "bert-ignore" not in msg.content: + if msg.author.bot: + if msg.author == bert.user: + messages.append( + {"role": "assistant", "content": msg.content} + ) + elif msg.content == "bert clear": + messages.clear() + elif msg.content.startswith("bert model "): + if msg.content.split(" ")[2] in available_models: + model = msg.content.split(" ")[2] + else: + msg_images = [] + for sticker in msg.stickers: + if sticker.format.name in ("png", "apng"): + msg_images.append(await sticker.read()) + for attachment in msg.attachments: + if attachment.content_type and attachment.content_type.startswith("image"): + msg_images.append(await attachment.read()) + + messages.append( + {"role": "user", "content": msg.content, "images": msg_images} + ) + messages.append( + {"role": "user", "content": message.content, "images": images} + ) + + ai_reply = await ollama.chat( + "llava" if images else model, messages=messages + ) + + if response := ai_reply["message"]["content"]: + if len(response) > 2000: + await message.channel.send( + "_The response is too long to send in one message_" + ) + else: + await message.channel.send(response) + else: + await message.channel.send("_No response from AI_") + @bert.event async def on_member_join(member: discord.Member): @@ -450,6 +568,39 @@ async def delete(interaction: discord.Interaction, key: str): ) + +async def autocomplete_models(ctx: discord.AutocompleteContext): + """Autocomplete the AI models from the Ollama server.""" + models = (await
ollama.list())["models"] + return [ + discord.OptionChoice(model["name"].split(":")[0]) + for model in models + if ctx.value in model["name"].split(":")[0] + ] + + +@bert.slash_command( + integration_types={ + discord.IntegrationType.guild_install, + discord.IntegrationType.user_install, + } +) +@option("prompt", description="The prompt to give to the AI") +@option("model", description="The model to use", autocomplete=autocomplete_models) +async def ai(interaction: discord.Interaction, prompt: str, model: str = "llama3.2"): + """Bert AI Technologies Ltd.""" + await interaction.response.defer() + ai_response = await ollama.generate(model, prompt) + if response := ai_response["response"]: + if len(response) > 2000: + await interaction.followup.send( + "_The response is too long to send in one message_" + ) + else: + await interaction.followup.send(response) + else: + await interaction.followup.send("_No response from AI_") + + @bert.slash_command(integration_types={discord.IntegrationType.user_install}) async def everythingisawesome(interaction: discord.Interaction): """Everything is AWESOME""" @@ -988,6 +1139,7 @@ async def main(): except PocketBaseError: logger.critical("Failed to login to Pocketbase") sys.exit(111) # Exit code 111: Connection refused + await download_ai_models(["llama2-uncensored", "llava"]) async with bert: await bert.start(os.getenv("BOT_TOKEN")) diff --git a/bot/requirements.txt b/bot/requirements.txt index 00464b9..d111233 100644 --- a/bot/requirements.txt +++ b/bot/requirements.txt @@ -6,3 +6,4 @@ pocketbase-async==0.11.0 coloredlogs==15.0.1 feedparser==6.0.11 PyNaCl==1.5.0 +ollama==0.3.3 diff --git a/docker-compose.yml b/docker-compose.yml index 11a78ad..7f84bed 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -11,6 +11,8 @@ services: condition: service_healthy lavalink: condition: service_healthy + ai: + condition: service_started restart: unless-stopped website: @@ -67,5 +69,20 @@ services: - ./bot/sounds:/opt/Lavalink/sounds 
restart: unless-stopped + ai: + container_name: bert-ai + image: ollama/ollama + volumes: + - ai:/root/.ollama + deploy: + resources: + reservations: + devices: + - driver: nvidia + count: 1 + capabilities: [gpu] + restart: unless-stopped + volumes: pocketbase: + ai: