Skip to content

Commit

Permalink
Merge pull request #74 from socode-marcelo/main
Browse files Browse the repository at this point in the history
Added system prompt features and model management features
  • Loading branch information
ruecat authored Nov 11, 2024
2 parents 8534944 + eb75e15 commit 4884bbd
Show file tree
Hide file tree
Showing 3 changed files with 270 additions and 45 deletions.
98 changes: 92 additions & 6 deletions bot/func/interactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from aiohttp import ClientTimeout
from asyncio import Lock
from functools import wraps
# from bot.run import load_allowed_ids_from_db
from dotenv import load_dotenv
load_dotenv()
token = os.getenv("TOKEN")
Expand All @@ -25,6 +24,71 @@
else:
log_level = logging.getLevelName(log_level_str)
logging.basicConfig(level=log_level)

async def manage_model(action: str, model_name: str):
async with aiohttp.ClientSession() as session:
url = f"http://{ollama_base_url}:{ollama_port}/api/{action}"

if action == "pull":
# Use the exact payload structure from the curl example
data = json.dumps({"name": model_name})
headers = {
'Content-Type': 'application/json'
}
logging.info(f"Pulling model: {model_name}")
logging.info(f"Request URL: {url}")
logging.info(f"Request Payload: {data}")

async with session.post(url, data=data, headers=headers) as response:
logging.info(f"Pull model response status: {response.status}")
response_text = await response.text()
logging.info(f"Pull model response text: {response_text}")
return response
elif action == "delete":
data = json.dumps({"name": model_name})
headers = {
'Content-Type': 'application/json'
}
async with session.delete(url, data=data, headers=headers) as response:
return response
else:
logging.error(f"Unsupported model management action: {action}")
return None

def add_system_prompt(user_id, prompt, is_global):
conn = sqlite3.connect('users.db')
c = conn.cursor()
c.execute("INSERT INTO system_prompts (user_id, prompt, is_global) VALUES (?, ?, ?)",
(user_id, prompt, is_global))
conn.commit()
conn.close()

def get_system_prompts(user_id=None, is_global=None):
conn = sqlite3.connect('users.db')
c = conn.cursor()
query = "SELECT * FROM system_prompts WHERE 1=1"
params = []

if user_id is not None:
query += " AND user_id = ?"
params.append(user_id)

if is_global is not None:
query += " AND is_global = ?"
params.append(is_global)

c.execute(query, params)
prompts = c.fetchall()
conn.close()
return prompts

def delete_ystem_prompt(prompt_id):
conn = sqlite3.connect('users.db')
c = conn.cursor()
c.execute("DELETE FROM system_prompts WHERE id = ?", (prompt_id,))
conn.commit()
conn.close()

async def model_list():
async with aiohttp.ClientSession() as session:
url = f"http://{ollama_base_url}:{ollama_port}/api/tags"
Expand All @@ -34,28 +98,50 @@ async def model_list():
return data["models"]
else:
return []

async def generate(payload: dict, modelname: str, prompt: str):
client_timeout = ClientTimeout(total=int(timeout))
async with aiohttp.ClientSession(timeout=client_timeout) as session:
url = f"http://{ollama_base_url}:{ollama_port}/api/chat"

# Prepare the payload according to Ollama API specification
ollama_payload = {
"model": modelname,
"messages": payload.get("messages", []),
"stream": payload.get("stream", True)
}

try:
async with session.post(url, json=payload) as response:
logging.info(f"Sending request to Ollama API: {url}")
logging.info(f"Payload: {json.dumps(ollama_payload, indent=2)}")

async with session.post(url, json=ollama_payload) as response:
if response.status != 200:
error_text = await response.text()
logging.error(f"API Error: {response.status} - {error_text}")
raise aiohttp.ClientResponseError(
status=response.status, message=response.reason
request_info=response.request_info,
history=response.history,
status=response.status,
message=f"API Error: {error_text}"
)
buffer = b""

buffer = b""
async for chunk in response.content.iter_any():
buffer += chunk
while b"\n" in buffer:
line, buffer = buffer.split(b"\n", 1)
line = line.strip()
if line:
yield json.loads(line)
try:
yield json.loads(line)
except json.JSONDecodeError as e:
logging.error(f"JSON Decode Error: {e}")
logging.error(f"Problematic line: {line}")

except aiohttp.ClientError as e:
print(f"Error during request: {e}")
logging.error(f"Client Error during request: {e}")
raise

def load_allowed_ids_from_db():
conn = sqlite3.connect('users.db')
Expand Down
Loading

0 comments on commit 4884bbd

Please sign in to comment.