Skip to content

Commit

Permalink
✨ feat(api): support github oauth2
Browse files Browse the repository at this point in the history
  • Loading branch information
centonhuang committed Jun 14, 2024
1 parent 91feaae commit b097ceb
Show file tree
Hide file tree
Showing 10 changed files with 153 additions and 65 deletions.
4 changes: 2 additions & 2 deletions aris_api.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import uvicorn

from src.api import create_app
from src.config import API_PORT
from src.config import API_HOST, API_PORT


def main() -> None:
app = create_app()
uvicorn.run(
app,
host="0.0.0.0",
host=API_HOST,
port=API_PORT,
log_level="critical",
)
Expand Down
4 changes: 4 additions & 0 deletions envs/local/api.env.template
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ LOGGER_ROOT=/path/to/log/
# api port
API_PORT=8080

# oauth2 config
OAUTH2_GITHUB_CLIENT_ID=xxx
OAUTH2_GITHUB_CLIENT_SECRET=xxx

# mysql config
MYSQL_HOST=localhost
MYSQL_PORT=3306
Expand Down
4 changes: 2 additions & 2 deletions src/api/router/v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@

from .key import key_router
from .model import model_router
from .oauth2 import oauth2_router
from .session import session_router
from .user import user_router
from .vectordb import vector_db_router

v1_router = APIRouter(prefix="/v1", tags=["v1"])

v1_router.include_router(user_router)
v1_router.include_router(key_router)
v1_router.include_router(model_router)
v1_router.include_router(vector_db_router)
v1_router.include_router(session_router)
v1_router.include_router(oauth2_router)
7 changes: 7 additions & 0 deletions src/api/router/v1/oauth2/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from fastapi import APIRouter

from .github import github_router

oauth2_router = APIRouter(prefix="/oauth2", tags=["oauth2"])

oauth2_router.include_router(github_router)
59 changes: 59 additions & 0 deletions src/api/router/v1/oauth2/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from datetime import datetime
from typing import Literal

from src.middleware.jwt import encode_token
from src.middleware.mysql import session
from src.middleware.mysql.models import UserSchema


def login(name: str, unique_id: str, avatar: str, platform: Literal["github"]) -> str:
with session() as conn:
if not conn.is_active:
conn.rollback()
conn.close()
else:
conn.commit()

query = conn.query(UserSchema.uid, UserSchema.is_admin).filter(UserSchema.unique_id == unique_id).filter(UserSchema.platform == platform)
result = query.first()

if result:
(uid, is_admin) = result
else:
uid, is_admin = register(name, unique_id, avatar, platform), 0

with session() as conn:
if not conn.is_active:
conn.rollback()
conn.close()
else:
conn.commit()

conn.query(UserSchema).filter(UserSchema.uid == uid).update({UserSchema.last_login: datetime.now()})
conn.flush()
conn.commit()

return encode_token(uid=uid, level=is_admin)


def register(name: str, unique_id: str, avatar: str, platform: Literal["github"]) -> int:
with session() as conn:
if not conn.is_active:
conn.rollback()
conn.close()
else:
conn.commit()

user = UserSchema(
name=name,
unique_id=unique_id,
avatar=avatar,
platform=platform,
)
conn.add(user)
conn.flush()
conn.commit()

uid = user.uid

return uid
63 changes: 63 additions & 0 deletions src/api/router/v1/oauth2/github.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from urllib.parse import urlencode

import requests
from fastapi import APIRouter, HTTPException, status
from fastapi.responses import JSONResponse

from src.api.model.response import StandardResponse
from src.config.env import OAUTH2_GITHUB_CLIENT_ID, OAUTH2_GITHUB_CLIENT_SECRET
from src.config.gbl import OAUTH2_GITHUB_AUTH_URL, OAUTH2_GITHUB_REDIRECT_URL, OAUTH2_GITHUB_TOKEN_URL, OAUTH2_GITHUB_USER_API

from .common import login

github_router = APIRouter(prefix="/github", tags=["oauth2"])


@github_router.get("/login")
async def github_login() -> StandardResponse:
query = {
"client_id": OAUTH2_GITHUB_CLIENT_ID,
"redirect_uri": OAUTH2_GITHUB_REDIRECT_URL,
"scope": "user",
}
return StandardResponse(
code=1,
status="success",
data={"url": f"{OAUTH2_GITHUB_AUTH_URL}?{urlencode(query)}"},
)


@github_router.get("/callback")
async def github_callback(code: str) -> JSONResponse:
if not code:
raise HTTPException(status_code=400, detail="Code is required")

query = {
"client_id": OAUTH2_GITHUB_CLIENT_ID,
"client_secret": OAUTH2_GITHUB_CLIENT_SECRET,
"code": code,
}
headers = {"Accept": "application/json"}
response = requests.post(OAUTH2_GITHUB_TOKEN_URL, data=query, headers=headers)

if response.status_code != 200:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Failed to retrieve access token")

data = response.json()
access_token = data.get("access_token")
if not access_token:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Access token is missing in response")

# Use the access token to get user info
user_response = requests.get(OAUTH2_GITHUB_USER_API, headers={"Authorization": f"Bearer {access_token}"})

if user_response.status_code != 200:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Failed to retrieve user information")

user_data = user_response.json()
token = login(user_data["login"], user_data["id"], user_data["avatar_url"], "github")
return StandardResponse(
code=1,
status="success",
data={"token": token},
)
59 changes: 0 additions & 59 deletions src/api/router/v1/user.py

This file was deleted.

3 changes: 3 additions & 0 deletions src/config/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
API_HOST = os.environ.get("API_HOST")
API_PORT = int(os.environ.get("API_PORT", "8000"))

OAUTH2_GITHUB_CLIENT_ID = os.environ.get("OAUTH2_GITHUB_CLIENT_ID")
OAUTH2_GITHUB_CLIENT_SECRET = os.environ.get("OAUTH2_GITHUB_CLIENT_SECRET")

MYSQL_DATABASE = os.environ.get("MYSQL_DATABASE")
MYSQL_HOST = os.environ.get("MYSQL_HOST")
MYSQL_PASSWORD = os.environ.get("MYSQL_PASSWORD")
Expand Down
8 changes: 8 additions & 0 deletions src/config/gbl.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from .env import API_HOST, API_PORT

SUPPORT_URL_TYPE = ["arxiv", "git", "render", "recursive"]

SUPPORT_UPLOAD_FILE = [
Expand Down Expand Up @@ -29,3 +31,9 @@
"cbl",
"cob",
]

OAUTH2_GITHUB_AUTH_URL = "https://github.com/login/oauth/authorize"
OAUTH2_GITHUB_TOKEN_URL = "https://github.com/login/oauth/access_token"
OAUTH2_GITHUB_USER_API = "https://api.github.com/user"

OAUTH2_GITHUB_REDIRECT_URL = f"http://{API_HOST}:{API_PORT}/v1/oauth2/github/callback"
7 changes: 5 additions & 2 deletions src/middleware/mysql/models/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,13 @@ class UserSchema(BaseSchema):

__tablename__ = "users"
uid: int = Column(Integer, primary_key=True, autoincrement=True)
unique_id: str = Column(String(100), nullable=False, unique=True)
name: str = Column(String(100), nullable=False)
avatar: str = Column(String(255), nullable=False)
platform: str = Column(String(50), nullable=False)
create_at: datetime = Column(DateTime, default=datetime.now)
update_at: datetime = Column(DateTime, default=datetime.now, onupdate=datetime.now)
delete_at: datetime = Column(DateTime, nullable=True)
user_name: str = Column(String(255), nullable=False)
password: str = Column(String(255), nullable=False)
last_login: datetime = Column(DateTime, nullable=True, default=datetime.now)
is_admin: bool = Column(Boolean, nullable=False, default=False)
ak_num: int = Column(Integer, nullable=False, default=0)

0 comments on commit b097ceb

Please sign in to comment.