Skip to content

Commit

Permalink
added async postgres to inference (#1961)
Browse files Browse the repository at this point in the history
  • Loading branch information
yk authored Mar 7, 2023
1 parent feae209 commit cd16d9c
Show file tree
Hide file tree
Showing 13 changed files with 335 additions and 260 deletions.
2 changes: 1 addition & 1 deletion inference/server/alembic.ini
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ version_path_separator = os # Use os.pathsep. Default configuration used for ne
# output_encoding = utf-8

# sqlalchemy.url = postgresql://<username>:<password>@<host>/<database_name>
sqlalchemy.url = postgresql://postgres:postgres@localhost:5432/postgres
sqlalchemy.url = postgresql+asyncpg://postgres:postgres@localhost:5432/postgres

[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
Expand Down
39 changes: 30 additions & 9 deletions inference/server/alembic/env.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import asyncio
from logging.config import fileConfig

import sqlmodel
from alembic import context
from loguru import logger
from oasst_inference_server import models # noqa: F401
from sqlalchemy import engine_from_config, pool
from sqlalchemy import engine_from_config, pool, text
from sqlalchemy.ext.asyncio import AsyncEngine

# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
Expand Down Expand Up @@ -50,7 +53,16 @@ def run_migrations_offline() -> None:
context.run_migrations()


def run_migrations_online() -> None:
def do_run_migrations(connection):
context.configure(connection=connection, target_metadata=target_metadata)

with context.begin_transaction():
context.get_context()._ensure_version_table()
connection.execute(text("LOCK TABLE alembic_version IN ACCESS EXCLUSIVE MODE"))
context.run_migrations()


async def run_async_migrations() -> None:
"""Run migrations in 'online' mode.
In this scenario we need to create an Engine
Expand All @@ -61,18 +73,27 @@ def run_migrations_online() -> None:
config.get_section(config.config_ini_section),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
future=True,
)

with connectable.connect() as connection:
context.configure(connection=connection, target_metadata=target_metadata)
connectable = AsyncEngine(connectable)

logger.info(f"Running migrations on {connectable.url}")

with context.begin_transaction():
context.get_context()._ensure_version_table()
connection.execute("LOCK TABLE alembic_version IN ACCESS EXCLUSIVE MODE")
context.run_migrations()
async with connectable.connect() as connection:
logger.info("Connected to database")
await connection.run_sync(do_run_migrations)
logger.info("Migrations complete")
logger.info("Disconnecting from database")
await connectable.dispose()
logger.info("Disconnected from database")


if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()
connection = config.attributes.get("connection", None)
if connection is None:
asyncio.run(run_async_migrations())
else:
do_run_migrations(connection)
99 changes: 51 additions & 48 deletions inference/server/main.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
import time
from pathlib import Path
import asyncio
import signal
import sys

import aiohttp
import alembic.command
import alembic.config
import fastapi
import sqlmodel
from fastapi import Depends, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from loguru import logger
from oasst_inference_server import auth, client_handler, deps, models, worker_handler
from oasst_inference_server import auth, client_handler, database, deps, models, worker_handler
from oasst_inference_server.schemas import chat as chat_schema
from oasst_inference_server.schemas import worker as worker_schema
from oasst_inference_server.settings import settings
Expand Down Expand Up @@ -67,19 +66,23 @@ def get_root_token(token: str = Depends(get_bearer_token)) -> str:
)


def terminate_server(signum, frame):
logger.info(f"Signal {signum}. Terminating server...")
sys.exit(0)


@app.on_event("startup")
def alembic_upgrade():
async def alembic_upgrade():
signal.signal(signal.SIGINT, terminate_server)
if not settings.update_alembic:
logger.info("Skipping alembic upgrade on startup (update_alembic is False)")
return
logger.info("Attempting to upgrade alembic on startup")
retry = 0
while True:
try:
alembic_ini_path = Path(__file__).parent / "alembic.ini"
alembic_cfg = alembic.config.Config(str(alembic_ini_path))
alembic_cfg.set_main_option("sqlalchemy.url", settings.database_uri)
alembic.command.upgrade(alembic_cfg, "head")
async with database.make_engine().begin() as conn:
await conn.run_sync(database.alembic_upgrade)
logger.info("Successfully upgraded alembic on startup")
break
except Exception:
Expand All @@ -90,28 +93,26 @@ def alembic_upgrade():

timeout = settings.alembic_retry_timeout * 2**retry
logger.warning(f"Retrying alembic upgrade in {timeout} seconds")
time.sleep(timeout)
await asyncio.sleep(timeout)
signal.signal(signal.SIGINT, signal.SIG_DFL)


@app.on_event("startup")
def maybe_add_debug_api_keys():
async def maybe_add_debug_api_keys():
if not settings.debug_api_keys:
logger.info("No debug API keys configured, skipping")
return
try:
logger.info("Adding debug API keys")
with deps.manual_create_session() as session:
async with deps.manual_create_session() as session:
for api_key in settings.debug_api_keys:
logger.info(f"Checking if debug API key {api_key} exists")
if (
session.exec(
sqlmodel.select(models.DbWorker).where(models.DbWorker.api_key == api_key)
).one_or_none()
is None
):
await session.exec(sqlmodel.select(models.DbWorker).where(models.DbWorker.api_key == api_key))
).one_or_none() is None:
logger.info(f"Adding debug API key {api_key}")
session.add(models.DbWorker(api_key=api_key, name="Debug API Key"))
session.commit()
await session.commit()
else:
logger.info(f"Debug API key {api_key} already exists")
except Exception:
Expand All @@ -129,7 +130,7 @@ async def login_discord():
@app.get("/auth/callback/discord", response_model=protocol.Token)
async def callback_discord(
code: str,
db: sqlmodel.Session = Depends(deps.create_session),
db: database.AsyncSession = Depends(deps.create_session),
):
redirect_uri = f"{settings.api_root}/auth/callback/discord"

Expand Down Expand Up @@ -166,15 +167,15 @@ async def callback_discord(
raise HTTPException(status_code=400, detail="Invalid user info response from Discord")

# Try to find a user in our DB linked to the Discord user
user: models.DbUser = query_user_by_provider_id(db, discord_id=discord_id)
user: models.DbUser = await query_user_by_provider_id(db, discord_id=discord_id)

# Create if no user exists
if not user:
user = models.DbUser(provider="discord", provider_account_id=discord_id, display_name=discord_username)

db.add(user)
db.commit()
db.refresh(user)
await db.commit()
await db.refresh(user)

# Discord account is authenticated and linked to a user; create JWT
access_token = auth.create_access_token({"user_id": user.id})
Expand All @@ -188,7 +189,7 @@ async def list_chats(
) -> chat_schema.ListChatsResponse:
"""Lists all chats."""
logger.info("Listing all chats.")
chats = ucr.get_chats()
chats = await ucr.get_chats()
chats_list = [chat.to_list_read() for chat in chats]
return chat_schema.ListChatsResponse(chats=chats_list)

Expand All @@ -200,7 +201,7 @@ async def create_chat(
) -> chat_schema.ChatListRead:
"""Allows a client to create a new chat."""
logger.info(f"Received {request=}")
chat = ucr.create_chat()
chat = await ucr.create_chat()
return chat.to_list_read()


Expand All @@ -210,7 +211,7 @@ async def get_chat(
ucr: UserChatRepository = Depends(deps.create_user_chat_repository),
) -> chat_schema.ChatRead:
"""Allows a client to get the current state of a chat."""
chat = ucr.get_chat_by_id(id)
chat = await ucr.get_chat_by_id(id)
return chat.to_read()


Expand All @@ -225,45 +226,45 @@ async def get_chat(


@app.put("/worker")
def create_worker(
async def create_worker(
request: worker_schema.CreateWorkerRequest,
root_token: str = Depends(get_root_token),
session: sqlmodel.Session = Depends(deps.create_session),
):
session: database.AsyncSession = Depends(deps.create_session),
) -> worker_schema.WorkerRead:
"""Allows a client to register a worker."""
worker = models.DbWorker(name=request.name)
session.add(worker)
session.commit()
session.refresh(worker)
return worker
await session.commit()
await session.refresh(worker)
return worker_schema.WorkerRead.from_orm(worker)


@app.get("/worker")
def list_workers(
async def list_workers(
root_token: str = Depends(get_root_token),
session: sqlmodel.Session = Depends(deps.create_session),
):
session: database.AsyncSession = Depends(deps.create_session),
) -> list[worker_schema.WorkerRead]:
"""Lists all workers."""
workers = session.exec(sqlmodel.select(models.DbWorker)).all()
return list(workers)
workers = (await session.exec(sqlmodel.select(models.DbWorker))).all()
return [worker_schema.WorkerRead.from_orm(worker) for worker in workers]


@app.delete("/worker/{worker_id}")
def delete_worker(
async def delete_worker(
worker_id: str,
root_token: str = Depends(get_root_token),
session: sqlmodel.Session = Depends(deps.create_session),
session: database.AsyncSession = Depends(deps.create_session),
):
"""Deletes a worker."""
worker = session.get(models.DbWorker, worker_id)
worker = await session.get(models.DbWorker, worker_id)
session.delete(worker)
session.commit()
await session.commit()
return fastapi.Response(status_code=200)


def query_user_by_provider_id(db: sqlmodel.Session, discord_id: str | None = None) -> models.DbUser | None:
async def query_user_by_provider_id(db: database.AsyncSession, discord_id: str | None = None) -> models.DbUser | None:
"""Returns the user associated with a given provider ID if any."""
user_qry = db.query(models.DbUser)
user_qry = sqlmodel.select(models.DbUser)

if discord_id:
user_qry = user_qry.filter(models.DbUser.provider == "discord").filter(
Expand All @@ -273,12 +274,12 @@ def query_user_by_provider_id(db: sqlmodel.Session, discord_id: str | None = Non
else:
return None

user: models.DbUser = user_qry.first()
user: models.DbUser = (await db.exec(user_qry)).first()
return user


@app.get("/auth/login/debug")
async def login_debug(username: str, db: sqlmodel.Session = Depends(deps.create_session)):
async def login_debug(username: str, db: database.AsyncSession = Depends(deps.create_session)):
"""Login using a debug username, which the system will accept unconditionally."""

if not settings.allow_debug_auth:
Expand All @@ -288,14 +289,16 @@ async def login_debug(username: str, db: sqlmodel.Session = Depends(deps.create_
raise HTTPException(status_code=400, detail="Username is required")

# Try to find the user
user: models.DbUser = db.exec(sqlmodel.select(models.DbUser).where(models.DbUser.id == username)).one_or_none()
user: models.DbUser = (
await db.exec(sqlmodel.select(models.DbUser).where(models.DbUser.id == username))
).one_or_none()

if user is None:
logger.info(f"Creating new debug user {username=}")
user = models.DbUser(id=username, display_name=username, provider="debug", provider_account_id=username)
db.add(user)
db.commit()
db.refresh(user)
await db.commit()
await db.refresh(user)

# Discord account is authenticated and linked to a user; create JWT
access_token = auth.create_access_token({"user_id": user.id})
Expand Down
Loading

0 comments on commit cd16d9c

Please sign in to comment.