From bfc6c5660944164d3675e49b1176fa5d0f69ff8b Mon Sep 17 00:00:00 2001 From: emrgnt-cmplxty Date: Thu, 9 Jan 2025 16:17:25 -0800 Subject: [PATCH 1/2] up --- js/sdk/package.json | 2 +- js/sdk/src/v3/clients/users.ts | 20 + py/core/database/users.py | 34 +- py/core/main/api/v3/retrieval_router.py | 4 +- py/core/main/api/v3/users_router.py | 65 ++- py/core/main/services/management_service.py | 456 -------------------- py/core/providers/auth/r2r_auth.py | 65 +-- py/core/providers/crypto/bcrypt.py | 2 + py/core/providers/crypto/nacl.py | 2 + py/nginx.conf | 18 +- py/sdk/v3/users.py | 27 +- 11 files changed, 174 insertions(+), 521 deletions(-) diff --git a/js/sdk/package.json b/js/sdk/package.json index 0c0d2d58e..153279acc 100644 --- a/js/sdk/package.json +++ b/js/sdk/package.json @@ -1,6 +1,6 @@ { "name": "r2r-js", - "version": "0.4.12", + "version": "0.4.13", "description": "", "main": "dist/index.js", "browser": "dist/index.browser.js", diff --git a/js/sdk/src/v3/clients/users.ts b/js/sdk/src/v3/clients/users.ts index 2d24245be..4cf8853d7 100644 --- a/js/sdk/src/v3/clients/users.ts +++ b/js/sdk/src/v3/clients/users.ts @@ -537,4 +537,24 @@ export class UsersClient { async oauthGithubAuthorize(): Promise<{ redirect_url: string }> { return this.client.makeRequest("GET", "users/oauth/github/authorize"); } + + @feature("users.oauthGoogleCallback") + async oauthGoogleCallback(options: { code: string; state: string }): Promise { + return this.client.makeRequest("POST", "users/oauth/google/callback", { + params: { + code: options.code, + state: options.state, + }, + }); + } + + @feature("users.oauthGithubCallback") + async oauthGithubCallback(options: { code: string; state: string }): Promise { + return this.client.makeRequest("POST", "users/oauth/github/callback", { + params: { + code: options.code, + state: options.state, + }, + }); + } } diff --git a/py/core/database/users.py b/py/core/database/users.py index c969f4b7f..68e312bf9 100644 --- a/py/core/database/users.py +++ b/py/core/database/users.py @@ -140,8 +140,6 @@ async def create_tables(self): """ await self.connection_manager.execute_query(check_columns_query) - - async def get_user_by_id(self, id: UUID) -> User: query, _ = ( QueryBuilder(self._get_table_name("users")) @@ -246,7 +244,13 @@ async def get_user_by_email(self, email: str) -> User: ) async def create_user( - self, email: str, password: Optional[str] = None, account_type: Optional[str] = "password", google_id: Optional[str] = None, github_id: Optional[str] = None, is_superuser: bool = False + self, + email: str, + password: Optional[str] = None, + account_type: Optional[str] = "password", + google_id: Optional[str] = None, + github_id: Optional[str] = None, + is_superuser: bool = False, ) -> User: """Create a new user.""" # 1) Check if a user with this email already exists @@ -298,7 +302,8 @@ async def create_user( "limits_overrides": None, "metadata": None, "account_type": account_type, - "hashed_password": hashed_password or "", # Ensure hashed_password is not None + "hashed_password": hashed_password + or "", # Ensure hashed_password is not None # !!WARNING - Upstream checks are required to treat oauth differently from password!! "google_id": google_id, "github_id": github_id, @@ -374,7 +379,6 @@ async def update_user( except R2RException: raise R2RException(status_code=404, message="User not found") - # If the new user.google_id != current_user.google_id, check for duplicates if user.email and (user.email != current_user.email): existing_email_user = await self.get_user_by_email(user.email) @@ -384,10 +388,11 @@ async def update_user( message="That email account is already associated with another user.", ) - # If the new user.google_id != current_user.google_id, check for duplicates if user.google_id and (user.google_id != current_user.google_id): - existing_google_user = await self.get_user_by_google_id(user.google_id) + existing_google_user = await self.get_user_by_google_id( + user.google_id + ) if existing_google_user and existing_google_user.id != user.id: raise R2RException( status_code=400, @@ -396,7 +401,9 @@ async def update_user( # Similarly for GitHub: if user.github_id and (user.github_id != current_user.github_id): - existing_github_user = await self.get_user_by_github_id(user.github_id) + existing_github_user = await self.get_user_by_github_id( + user.github_id + ) if existing_github_user and existing_github_user.id != user.id: raise R2RException( status_code=400, @@ -765,7 +772,6 @@ async def get_users_in_collection( "google_id", "github_id", "COUNT(*) OVER() AS total_entries", - ] ) .where("$1 = ANY(collection_ids)") @@ -1208,7 +1214,6 @@ async def export_to_csv( detail=f"Failed to export data: {str(e)}", ) - async def get_user_by_google_id(self, google_id: str) -> Optional[User]: """Return a User if the google_id is found; otherwise None.""" query, params = ( @@ -1237,7 +1242,9 @@ async def get_user_by_google_id(self, google_id: str) -> Optional[User]: .where("google_id = $1") .build() ) - result = await self.connection_manager.fetchrow_query(query, [google_id]) + result = await self.connection_manager.fetchrow_query( + query, [google_id] + ) if not result: return None @@ -1261,7 +1268,6 @@ async def get_user_by_google_id(self, google_id: str) -> Optional[User]: github_id=result["github_id"], ) - async def get_user_by_github_id(self, github_id: str) -> Optional[User]: """Return a User if the github_id is found; otherwise None.""" query, params = ( @@ -1290,7 +1296,9 @@ async def get_user_by_github_id(self, github_id: str) -> Optional[User]: .where("github_id = $1") .build() ) - result = await self.connection_manager.fetchrow_query(query, [github_id]) + result = await self.connection_manager.fetchrow_query( + query, [github_id] + ) if not result: return None diff --git a/py/core/main/api/v3/retrieval_router.py b/py/core/main/api/v3/retrieval_router.py index 1edde1697..b25375b19 100644 --- a/py/core/main/api/v3/retrieval_router.py +++ b/py/core/main/api/v3/retrieval_router.py @@ -446,7 +446,7 @@ async def stream_generator(): return return StreamingResponse( - stream_generator(), media_type="application/json" + stream_generator(), media_type="text/event-stream" ) # type: ignore else: return response @@ -673,7 +673,7 @@ async def stream_generator(): return return StreamingResponse( - stream_generator(), media_type="application/json" + stream_generator(), media_type="text/event-stream" ) # type: ignore else: return response diff --git a/py/core/main/api/v3/users_router.py b/py/core/main/api/v3/users_router.py index 5effc1583..dc70b11dd 100644 --- a/py/core/main/api/v3/users_router.py +++ b/py/core/main/api/v3/users_router.py @@ -1,14 +1,19 @@ import os import textwrap +import urllib.parse from typing import Optional from uuid import UUID -from fastapi import Body, Depends, Path, Query, Request, HTTPException +import requests +from fastapi import Body, Depends, HTTPException, Path, Query, Request from fastapi.background import BackgroundTasks from fastapi.responses import FileResponse from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm +from google.auth.transport import requests as google_requests + +# missing these lines +from google.oauth2 import id_token from pydantic import EmailStr -import requests from core.base import R2RException from core.base.api.models import ( @@ -27,10 +32,6 @@ from ...abstractions import R2RProviders, R2RServices from .base_router import BaseRouterV3 -# missing these lines -from google.oauth2 import id_token -from google.auth.transport import requests as google_requests -import urllib.parse oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") @@ -139,10 +140,11 @@ def validate_password(password: str) -> bool: return False return True - if not validate_password(password): - raise R2RException( - f"Password must be at least 10 characters long and contain at least one uppercase letter, one lowercase letter, one digit, and one special character from '!@#$%^&*'.", - ) + # if not validate_password(password): + # raise R2RException( + # f"Password must be at least 10 characters long and contain at least one uppercase letter, one lowercase letter, one digit, and one special character from '!@#$%^&*'.", + # 400, + # ) registration_response = await self.services.auth.register( email, password @@ -1763,6 +1765,7 @@ async def get_user_limits( id ) return limits_info + @self.router.get("/users/oauth/google/authorize") async def google_authorize(): """ @@ -1786,7 +1789,9 @@ async def google_authorize(): # In a real app, you might return a RedirectResponse(google_auth_url) @self.router.get("/users/oauth/google/callback") - async def google_callback(request: Request, code: str = Query(...), state: str = Query(...)): + async def google_callback( + code: str = Query(...), state: str = Query(...) + ): """ Google's callback that will receive the `code` and `state`. We then exchange code for tokens, verify, and log the user in. @@ -1802,19 +1807,33 @@ async def google_callback(request: Request, code: str = Query(...), state: str = "grant_type": "authorization_code", }, ).json() - + print("token data = ", token_data) if "error" in token_data: - raise HTTPException(status_code=400, detail=f"Failed to get token: {token_data}") + raise HTTPException( + status_code=400, + detail=f"Failed to get token: {token_data}", + ) + print("verifying...") # 2. Verify the ID token id_token_str = token_data["id_token"] try: # google_auth.transport.requests.Request() is a session for verifying id_info = id_token.verify_oauth2_token( - id_token_str, google_requests.Request(), self.google_client_id + id_token_str, + google_requests.Request(), + self.google_client_id, ) + print("id_info = ", id_info) except ValueError as e: - raise HTTPException(status_code=400, detail=f"Token verification failed: {str(e)}") + raise HTTPException( + status_code=400, + detail=f"Token verification failed: {str(e)}", + ) + + print( + "doing the final call, storing data with our oauth callback handler..." + ) # id_info will contain "sub", "email", etc. google_id = id_info["sub"] @@ -1826,6 +1845,7 @@ async def google_callback(request: Request, code: str = Query(...), state: str = oauth_id=google_id, email=email, ) + print("token_response = ", token_response) # 4. Return tokens or redirect to your front-end # Some people store tokens in a cookie or redirect to a front-end route passing them as a query param. @@ -1850,7 +1870,9 @@ async def github_authorize(): return {"redirect_url": github_auth_url} @self.router.get("/users/oauth/github/callback") - async def github_callback(code: str = Query(...), state: str = Query(...)): + async def github_callback( + code: str = Query(...), state: str = Query(...) + ): """ GitHub callback route to exchange code for an access_token, then fetch user info from GitHub's API, @@ -1870,15 +1892,20 @@ async def github_callback(code: str = Query(...), state: str = Query(...)): ) token_data = token_resp.json() if "error" in token_data: - raise HTTPException(status_code=400, detail=f"Failed to get token: {token_data}") + raise HTTPException( + status_code=400, + detail=f"Failed to get token: {token_data}", + ) access_token = token_data["access_token"] # 2. Use the access_token to fetch user info user_info_resp = requests.get( "https://api.github.com/user", - headers={"Authorization": f"Bearer {access_token}"} + headers={"Authorization": f"Bearer {access_token}"}, ).json() - github_id = str(user_info_resp["id"]) # GitHub user ID is typically an integer + github_id = str( + user_info_resp["id"] + ) # GitHub user ID is typically an integer # fetch email (sometimes you need to call /user/emails endpoint if user sets email private) email = user_info_resp.get("email") diff --git a/py/core/main/services/management_service.py b/py/core/main/services/management_service.py index 634af0d91..53ac481f9 100644 --- a/py/core/main/services/management_service.py +++ b/py/core/main/services/management_service.py @@ -1086,459 +1086,3 @@ async def get_all_user_limits(self, user_id: UUID) -> dict[str, Any]: "usage": usage, } return result - - # async def get_all_user_limits(self, user_id: UUID) -> dict[str, Any]: - # """ - # Return a dictionary containing: - # - The system default limits (from self.config.limits) - # - The user's overrides (from user.limits_overrides) - # - The final 'effective' set of limits after merging (overall) - # - The usage for each relevant limit (per-route usage, etc.) - # """ - # # 1) Fetch the user - # user = await self.providers.database.users_handler.get_user_by_id(user_id) - # user_overrides = user.limits_overrides or {} - - # # 2) Grab system defaults - # system_defaults = { - # "global_per_min": self.config.database.limits.global_per_min, - # "route_per_min": self.config.database.limits.route_per_min, - # "monthly_limit": self.config.database.limits.monthly_limit, - # # Add additional fields if your LimitSettings has them - # } - - # # 3) Build the overall (global) "effective limits" ignoring any specific route - # overall_effective = self.providers.database.limits_handler.determine_effective_limits( - # user, route="" - # ) - - # # 4) Build usage data. We'll do top-level usage for global_per_min/monthly, - # # then do route-by-route usage in a loop. - # usage: dict[str, Any] = {} - # now = datetime.now(timezone.utc) - # one_min_ago = now - timedelta(minutes=1) - - # # (a) Global usage (per-minute) - # global_per_min_used = await self.providers.database.limits_handler._count_requests( - # user_id, route=None, since=one_min_ago - # ) - # # (a2) Global usage (monthly) - i.e. usage across ALL routes - # global_monthly_used = await self.providers.database.limits_handler._count_monthly_requests( - # user_id, route=None - # ) - - # usage["global_per_min"] = { - # "used": global_per_min_used, - # "limit": overall_effective.global_per_min, - # "remaining": ( - # overall_effective.global_per_min - global_per_min_used - # if overall_effective.global_per_min is not None - # else None - # ), - # } - # usage["monthly_limit"] = { - # "used": global_monthly_used, - # "limit": overall_effective.monthly_limit, - # "remaining": ( - # overall_effective.monthly_limit - global_monthly_used - # if overall_effective.monthly_limit is not None - # else None - # ), - # } - - # # (b) Route-level usage. We'll gather all routes from system + user overrides - # system_route_limits = self.config.database.route_limits # dict[str, LimitSettings] - # user_route_overrides = user_overrides.get("route_overrides", {}) - # route_keys = set(system_route_limits.keys()) | set(user_route_overrides.keys()) - - # usage["routes"] = {} - # for route in route_keys: - # # 1) Get the final merged limits for this specific route - # route_effective = self.providers.database.limits_handler.determine_effective_limits( - # user, route - # ) - - # # 2) Count requests for the last minute on this route - # route_per_min_used = await self.providers.database.limits_handler._count_requests( - # user_id, route, one_min_ago - # ) - - # # 3) Count route-specific monthly usage - # route_monthly_used = await self.providers.database.limits_handler._count_monthly_requests( - # user_id, route - # ) - - # usage["routes"][route] = { - # "route_per_min": { - # "used": route_per_min_used, - # "limit": route_effective.route_per_min, - # "remaining": ( - # route_effective.route_per_min - route_per_min_used - # if route_effective.route_per_min is not None - # else None - # ), - # }, - # "monthly_limit": { - # "used": route_monthly_used, - # "limit": route_effective.monthly_limit, - # "remaining": ( - # route_effective.monthly_limit - route_monthly_used - # if route_effective.monthly_limit is not None - # else None - # ), - # }, - # } - - # # 5) Return a structured response - # result = { - # "system_defaults": system_defaults, - # "user_overrides": user_overrides, - # "effective_limits": { - # "global_per_min": overall_effective.global_per_min, - # "route_per_min": overall_effective.route_per_min, - # "monthly_limit": overall_effective.monthly_limit, - # }, - # "usage": usage, - # } - # print('result = ', result) - # return result - - # async def get_all_user_limits(self, user_id: UUID) -> dict[str, Any]: - # """ - # Return a dictionary containing: - # - The system default limits (from self.config.limits) - # - The user's overrides (from user.limits_overrides) - # - The final 'effective' set of limits after merging (overall, i.e. ignoring route-specific for the moment) - # - The usage for each relevant limit (how many requests used, how many remain, etc.) - # including per-route usage. - # """ - # # 1) Fetch the user - # user = await self.providers.database.users_handler.get_user_by_id(user_id) - # user_overrides = user.limits_overrides or {} - - # # 2) Grab system defaults (just a dict version of self.config.database.limits) - # system_defaults = { - # "global_per_min": self.config.database.limits.global_per_min, - # "route_per_min": self.config.database.limits.route_per_min, - # "monthly_limit": self.config.database.limits.monthly_limit, - # # Add additional fields if your LimitSettings has them - # } - - # # 3) Build the overall "effective limits" ignoring any specific route - # # (you can pass an empty string or some sentinel if your code - # # merges route-specific overrides only if route is found). - # overall_effective = self.providers.database.limits_handler.determine_effective_limits( - # user, route="" - # ) - - # # 4) Build usage data. We'll do top-level usage for global_per_min/monthly, - # # then do route-by-route usage in a loop. - # usage: dict[str, Any] = {} - # now = datetime.now(timezone.utc) - # one_min_ago = now - timedelta(minutes=1) - - # # (a) Global usage - # global_per_min_used = await self.providers.database.limits_handler._count_requests( - # user_id, route=None, since=one_min_ago - # ) - # monthly_used = await self.providers.database.limits_handler._count_monthly_requests(user_id) - - # usage["global_per_min"] = { - # "used": global_per_min_used, - # "limit": overall_effective.global_per_min, - # "remaining": ( - # overall_effective.global_per_min - global_per_min_used - # if overall_effective.global_per_min is not None - # else None - # ), - # } - # usage["monthly_limit"] = { - # "used": monthly_used, - # "limit": overall_effective.monthly_limit, - # "remaining": ( - # overall_effective.monthly_limit - monthly_used - # if overall_effective.monthly_limit is not None - # else None - # ), - # } - - # # (b) Route-level usage. We'll gather all routes from system + user overrides - # system_route_limits = self.config.database.route_limits # dict[str, LimitSettings] - # user_route_overrides = user_overrides.get("route_overrides", {}) - # route_keys = set(system_route_limits.keys()) | set(user_route_overrides.keys()) - - # usage["routes"] = {} - # for route in route_keys: - # # 1) Get the final merged limits for this specific route - # route_effective = self.providers.database.limits_handler.determine_effective_limits( - # user, route - # ) - - # # 2) Count requests for the last minute on this route - # route_per_min_used = await self.providers.database.limits_handler._count_requests( - # user_id, route, one_min_ago - # ) - # # 3) For monthly usage, we’re reusing the total monthly usage. If you want per-route monthly usage, - # # you'll need a route-specific monthly counter. For now, we reuse `monthly_used`. - - # usage["routes"][route] = { - # "route_per_min": { - # "used": route_per_min_used, - # "limit": route_effective.route_per_min, - # "remaining": ( - # route_effective.route_per_min - route_per_min_used - # if route_effective.route_per_min is not None - # else None - # ), - # }, - # # If you want route-specific monthly usage, you'd do a route-based query in your limits handler. - # # For now, we just show how the route's monthly limit compares to the user's total monthly usage. - # "monthly_limit": { - # "used": monthly_used, - # "limit": route_effective.monthly_limit, - # "remaining": ( - # route_effective.monthly_limit - monthly_used - # if route_effective.monthly_limit is not None - # else None - # ), - # }, - # } - - # # 5) Return a structured response including: - # # - The system default limits - # # - The raw user_overrides - # # - The final "effective" overall limits (no route) - # # - Usage details for global + monthly + per-route - # result = { - # "system_defaults": system_defaults, - # "user_overrides": user_overrides, - # "effective_limits": { - # "global_per_min": overall_effective.global_per_min, - # "route_per_min": overall_effective.route_per_min, - # "monthly_limit": overall_effective.monthly_limit, - # }, - # "usage": usage, - # } - # print('result = ', result) - # return result - - # # async def get_all_user_limits(self, user_id: UUID) -> dict[str, Any]: - # # """ - # # Return a dictionary containing: - # # - The system default limits (from self.config.limits) - # # - The user's overrides (from user.limits_overrides) - # # - The final 'effective' set of limits after merging - # # - The usage for each relevant limit (how many requests used, how many remain, etc.) - # # """ - # # # 1. Fetch the user to see if they have overrides - # # user = await self.providers.database.users_handler.get_user_by_id( - # # user_id - # # ) - - # # # 2. System defaults - # # system_defaults = { - # # "global_per_min": self.config.database.limits.global_per_min, - # # "route_per_min": self.config.database.limits.route_per_min, - # # "monthly_limit": self.config.database.limits.monthly_limit, - # # # add other fields if your LimitSettings has them - # # } - - # # # 3. Grab user-level overrides - # # # (In your code, user.limits_overrides is a JSON field, e.g. {"global_per_min": 80, "route_overrides": {...}} ) - # # user_overrides = user.limits_overrides or {} - - # # # 4. Build effective limits by merging system_defaults with user_overrides - # # # For simplicity, we only directly handle "global_per_min" and "monthly_limit" at the top level - # # # Then route-specific overrides from user (like user_overrides["route_overrides"]) - # # # overshadow system route limits if they exist. - # # effective_limits = dict(system_defaults) - - # # # If the user added "global_per_min" or "monthly_limit" overrides, override them - # # if user_overrides.get("global_per_min") is not None: - # # effective_limits["global_per_min"] = user_overrides[ - # # "global_per_min" - # # ] - # # if user_overrides.get("monthly_limit") is not None: - # # effective_limits["monthly_limit"] = user_overrides["monthly_limit"] - # # if user_overrides.get("route_per_min") is not None: - # # effective_limits["route_per_min"] = user_overrides["route_per_min"] - - # # # We'll also gather route-level overrides from: - # # # - self.config.route_limits (system route overrides) - # # # - user_overrides["route_overrides"] (user route overrides) - # # # So we can later show usage for each route. - # # system_route_limits = ( - # # self.config.database.route_limits - # # ) # dict[str, LimitSettings] - # # user_route_overrides = user_overrides.get( - # # "route_overrides", {} - # # ) # e.g. { "/api/foo": {...}, ... } - - # # # 5. Build usage data - # # usage = {} - # # # => We'll fill usage["global_per_min"], usage["monthly_limit"], and usage["routes"][route] ... - # # # We'll rely on your PostgresLimitsHandler to do the counting. - - # # # (a) Compute usage for global_per_min (requests in last minute) & monthly_limit - # # now = datetime.now(timezone.utc) - # # one_min_ago = now - timedelta(minutes=1) - - # # # Use your limits_handler to count - # # global_per_min_used = ( - # # await self.providers.database.limits_handler._count_requests( - # # user_id, route=None, since=one_min_ago - # # ) - # # ) - # # monthly_used = await self.providers.database.limits_handler._count_monthly_requests( - # # user_id - # # ) - - # # # The final effective global/min is in `effective_limits["global_per_min"]`, etc. - # # usage["global_per_min"] = { - # # "used": global_per_min_used, - # # "limit": effective_limits["global_per_min"], - # # "remaining": ( - # # effective_limits["global_per_min"] - global_per_min_used - # # if effective_limits["global_per_min"] is not None - # # else None - # # ), - # # } - # # usage["monthly_limit"] = { - # # "used": monthly_used, - # # "limit": effective_limits["monthly_limit"], - # # "remaining": ( - # # effective_limits["monthly_limit"] - monthly_used - # # if effective_limits["monthly_limit"] is not None - # # else None - # # ), - # # } - - # # # (b) Build route-level usage - # # # We'll gather a union of the routes from system_route_limits + user_route_overrides - # # route_keys = set(system_route_limits.keys()) | set( - # # user_route_overrides.keys() - # # ) - # # usage["routes"] = {} - # # for route in route_keys: - # # # 1) System route-limits - # # sys_route_lim = system_route_limits.get(route) # or None - # # route_global_per_min = ( - # # sys_route_lim.global_per_min - # # if sys_route_lim - # # else system_defaults["global_per_min"] - # # ) - # # route_route_per_min = self._count_requests - # # # ( - # # # sys_route_lim.route_per_min - # # # if sys_route_lim - # # # else system_defaults["route_per_min"] - # # # ) - # # route_monthly_limit = ( - # # sys_route_lim.monthly_limit - # # if sys_route_lim - # # else system_defaults["monthly_limit"] - # # ) - - # # # 2) Merge user overrides for that route - # # user_route_cfg = user_route_overrides.get( - # # route, {} - # # ) # e.g. { "route_per_min": 25, "global_per_min": 80, ... } - # # if user_route_cfg.get("global_per_min") is not None: - # # route_global_per_min = user_route_cfg["global_per_min"] - # # if user_route_cfg.get("route_per_min") is not None: - # # route_route_per_min = user_route_cfg["route_per_min"] - # # if user_route_cfg.get("monthly_limit") is not None: - # # route_monthly_limit = user_route_cfg["monthly_limit"] - - # # # Now let's measure usage for this route over the last minute - # # route_per_min_used = ( - # # await self.providers.database.limits_handler._count_requests( - # # user_id, route, one_min_ago - # # ) - # # ) - # # # monthly usage is the same for all routes if there's a global monthly limit, - # # # but if you have route-specific monthly limits, we still want to do a global monthly count. - # # # (You can do something more advanced if you only want route-specific monthly usage, but - # # # your code currently lumps monthly usage by user_id, not by user+route.) - # # # We'll reuse monthly_used from above, so if there's a route-specific monthly limit, - # # # it compares the entire month's usage to that route limit. - # # # If you want only the route's monthly usage, you'd need a new function - # # # e.g. `_count_requests(user_id, route, start_of_month)` in your limits_handler. - - # # usage["routes"][route] = { - # # # The route-level per-minute usage (stuff relevant to route_per_min) - # # "route_per_min": { - # # "used": route_per_min_used, - # # "limit": route_route_per_min, - # # "remaining": ( - # # route_route_per_min - route_per_min_used - # # if route_route_per_min is not None - # # else None - # # ), - # # }, - # # # If you want to represent the "global_per_min" that applies to this route, - # # # you could put that here too if it’s route-specific. - # # # But typically "global_per_min" is for all requests, so usage is the same as above. - # # # The route-specific monthly usage, in your code, is not specifically counted by route, - # # # but if you want to do it the same as route_per_min, you'd do: - # # # route_monthly_used = await self.providers.database.limits_handler._count_requests( - # # # user_id, route, start_of_month - # # # ) - # # # We'll just reuse the global monthly usage to compare to the route's monthly limit: - # # "monthly_limit": { - # # "used": monthly_used, - # # "limit": route_monthly_limit, - # # "remaining": ( - # # route_monthly_limit - monthly_used - # # if route_monthly_limit is not None - # # else None - # # ), - # # }, - # # } - - # # # Return a structured response - # # return { - # # "system_defaults": system_defaults, - # # "user_overrides": user_overrides, - # # "effective_limits": effective_limits, - # # "usage": usage, - # # } - - # # """ - # # Return a dictionary containing: - # # - The system default limits (from config) - # # - The user overrides (from user.limits_overrides) - # # - The final "effective" set of limits after merging - # # """ - # # # 1. Fetch the user to see if they have overrides - # # user = await self.providers.database.users_handler.get_user_by_id( - # # user_id - # # ) - - # # # 2. System defaults (example: from self.config.limits) - # # # Adjust these names as needed based on your actual config - # # system_defaults = { - # # "global_per_min": self.config.database.limits.global_per_min, - # # "route_per_min": self.config.database.limits.route_per_min, - # # "monthly_limit": self.config.database.limits.monthly_limit, - # # # add other fields as needed - # # } - - # # # 3. Grab user-level overrides - # # user_overrides = ( - # # user.limits_overrides or {} - # # ) # In DB, typically a JSON field - - # # # 4. Merge them. "Effective" means the final set of limits after user overrides - # # # overshadow system defaults if present - # # effective_limits = dict(system_defaults) # start with system - # # for k, v in user_overrides.items(): - # # # If your overrides nest like {"global_per_min": X, "route_overrides": {...}}, - # # # you might need more robust merging logic. For simplicity, we do a shallow merge here. - # # effective_limits[k] = v - - # # return { - # # "system_defaults": system_defaults, - # # "user_overrides": user_overrides, - # # "effective_limits": effective_limits, - # # } diff --git a/py/core/providers/auth/r2r_auth.py b/py/core/providers/auth/r2r_auth.py index eceea9eae..5ba68f84d 100644 --- a/py/core/providers/auth/r2r_auth.py +++ b/py/core/providers/auth/r2r_auth.py @@ -207,7 +207,8 @@ async def register( if account_type == "password": if not password: raise R2RException( - status_code=400, message="Password is required for password accounts" + status_code=400, + message="Password is required for password accounts", ) else: if github_id and google_id: @@ -221,7 +222,12 @@ async def register( message="Invalid OAuth specification without GitHub or Google ID", ) new_user = await self.database_provider.users_handler.create_user( - email=email, password=password, is_superuser=is_superuser, account_type=account_type, github_id=github_id, google_id=google_id + email=email, + password=password, + is_superuser=is_superuser, + account_type=account_type, + github_id=github_id, + google_id=google_id, ) default_collection: CollectionResponse = ( await self.database_provider.collections_handler.create_collection( @@ -310,10 +316,12 @@ async def login(self, email: str, password: str) -> dict[str, Token]: ) if user.account_type != "password": - logger.warning(f"Password login not allowed for {user.account_type} accounts: {email}") + logger.warning( + f"Password login not allowed for {user.account_type} accounts: {email}" + ) raise R2RException( status_code=401, - message=f"This account is configured for {user.account_type} login, not password." + message=f"This account is configured for {user.account_type} login, not password.", ) logger.debug(f"User found: {user}") @@ -534,7 +542,9 @@ async def rename_api_key( name=new_name, ) - async def oauth_callback_handler(self, provider: str, oauth_id: str, email: Optional[str]) -> dict[str, Token]: + async def oauth_callback_handler( + self, provider: str, oauth_id: str, email: Optional[str] + ) -> dict[str, Token]: """ Handles a login/registration flow for OAuth providers (e.g., Google or GitHub). :param provider: "google" or "github" @@ -546,24 +556,27 @@ async def oauth_callback_handler(self, provider: str, oauth_id: str, email: Opti # The logic depends on your preference. We'll assume "google" => google_id, etc. try: if provider == "google": - user = await self.database_provider.users_handler.get_user_by_email(email) - # If user found, check if user.google_id matches or is null. If null, update it - if user and not user.google_id: - user.google_id = oauth_id - user.account_type = "google" - await self.database_provider.users_handler.update_user(user) - elif user and user.google_id != oauth_id: - # Edge case: Another user with same email? Or user changed google account? - # Decide how to handle. - pass - - if not user: + print("getting user...") + try: + user = await self.database_provider.users_handler.get_user_by_email( + email + ) + # If user found, check if user.google_id matches or is null. If null, update it + if user and not user.google_id: + raise R2RException( + status_code=401, + message="User already exists and is not linked to Google account", + ) + except: # Create new user - user = await self.database_provider.users_handler.create_user( - email=email or f"{oauth_id}@google_oauth.fake", # fallback - password=None, # no password - account_type="google", - google_id=oauth_id + user = ( + await self.database_provider.users_handler.create_user( + email=email + or f"{oauth_id}@google_oauth.fake", # fallback + password=None, # no password + account_type="google", + google_id=oauth_id, + ) ) elif provider == "github": # Similar approach for GitHub @@ -573,11 +586,15 @@ async def oauth_callback_handler(self, provider: str, oauth_id: str, email: Opti except R2RException: # If no user found or creation fails - raise R2RException(status_code=401, message="Could not create or fetch user") + raise R2RException( + status_code=401, message="Could not create or fetch user" + ) # If user is inactive, etc. if not user.is_active: - raise R2RException(status_code=401, message="User account is inactive") + raise R2RException( + status_code=401, message="User account is inactive" + ) # Possibly mark user as verified if you trust the OAuth provider's email user.is_verified = True diff --git a/py/core/providers/crypto/bcrypt.py b/py/core/providers/crypto/bcrypt.py index 41355118e..9d5d8e09e 100644 --- a/py/core/providers/crypto/bcrypt.py +++ b/py/core/providers/crypto/bcrypt.py @@ -1,4 +1,5 @@ import base64 +import logging import os from abc import ABC from datetime import datetime, timezone @@ -53,6 +54,7 @@ def __init__(self, config: BcryptCryptoConfig): raise ValueError( "BcryptCryptoProvider must be initialized with a BcryptCryptoConfig" ) + logging.info("Initializing BcryptCryptoProvider") super().__init__(config) self.config: BcryptCryptoConfig = config diff --git a/py/core/providers/crypto/nacl.py b/py/core/providers/crypto/nacl.py index 71168b052..6cca949fd 100644 --- a/py/core/providers/crypto/nacl.py +++ b/py/core/providers/crypto/nacl.py @@ -1,5 +1,6 @@ import base64 import json +import logging import os import secrets import string @@ -51,6 +52,7 @@ def __init__(self, config: NaClCryptoConfig): ) super().__init__(config) self.config: NaClCryptoConfig = config + logging.info("Initializing NaClCryptoProvider") # Securely load the secret key for JWT # Priority: config.secret_key > environment variable > default diff --git a/py/nginx.conf b/py/nginx.conf index 0af23d37d..448e0c7bd 100644 --- a/py/nginx.conf +++ b/py/nginx.conf @@ -8,6 +8,7 @@ http { # Required basic settings include /etc/nginx/mime.types; default_type application/octet-stream; + client_max_body_size 100M; # Logging settings log_format main '$remote_addr - $remote_user [$time_local] "$request" ' @@ -36,20 +37,27 @@ http { proxy_send_timeout 300s; proxy_read_timeout 300s; - # Buffer settings - proxy_buffers 8 16k; - proxy_buffer_size 32k; - location / { proxy_pass http://r2r_backend; proxy_http_version 1.1; proxy_set_header Upgrade $http_upgrade; - proxy_set_header Connection 'upgrade'; + proxy_set_header Connection ''; + proxy_set_header X-Accel-Buffering no; proxy_set_header Host $host; proxy_set_header X-Real-IP $remote_addr; proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; proxy_set_header X-Forwarded-Proto $scheme; + proxy_buffers 4 256k; + proxy_busy_buffers_size 512k; + proxy_buffer_size 128k; + + proxy_buffering off; + proxy_cache off; + + add_header Cache-Control no-cache; + add_header X-Accel-Buffering no; + # Retry settings proxy_next_upstream error timeout invalid_header http_500 http_502 http_503 http_504; proxy_next_upstream_tries 3; diff --git a/py/sdk/v3/users.py b/py/sdk/v3/users.py index 8509a525c..1139c534f 100644 --- a/py/sdk/v3/users.py +++ b/py/sdk/v3/users.py @@ -538,7 +538,6 @@ async def get_limits(self) -> dict[str, Any]: version="v3", ) - async def oauth_google_authorize(self) -> dict[str, str]: """ Get Google OAuth 2.0 authorization URL from the server. @@ -560,3 +559,29 @@ async def oauth_github_authorize(self) -> dict[str, str]: "users/oauth/github/authorize", version="v3", ) + + async def oauth_google_callback( + self, code: str, state: str + ) -> dict[str, Any]: + """ + Exchange `code` and `state` with the Google OAuth 2.0 callback route. + """ + return await self.client._make_request( + "GET", + "users/oauth/google/callback", + params={"code": code, "state": state}, + version="v3", + ) + + async def oauth_github_callback( + self, code: str, state: str + ) -> dict[str, Any]: + """ + Exchange `code` and `state` with the GitHub OAuth 2.0 callback route. + """ + return await self.client._make_request( + "GET", + "users/oauth/github/callback", + params={"code": code, "state": state}, + version="v3", + ) From e2bb519338159dd0161fd9c0e441f90ed0ac82ee Mon Sep 17 00:00:00 2001 From: emrgnt-cmplxty Date: Thu, 9 Jan 2025 16:20:29 -0800 Subject: [PATCH 2/2] rm ugly prints --- py/core/main/api/v3/users_router.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/py/core/main/api/v3/users_router.py b/py/core/main/api/v3/users_router.py index 7983fb291..b0e02b476 100644 --- a/py/core/main/api/v3/users_router.py +++ b/py/core/main/api/v3/users_router.py @@ -1807,14 +1807,12 @@ async def google_callback( "grant_type": "authorization_code", }, ).json() - print("token data = ", token_data) if "error" in token_data: raise HTTPException( status_code=400, detail=f"Failed to get token: {token_data}", ) - print("verifying...") # 2. Verify the ID token id_token_str = token_data["id_token"] try: @@ -1824,7 +1822,6 @@ async def google_callback( google_requests.Request(), self.google_client_id, ) - print("id_info = ", id_info) except ValueError as e: raise HTTPException( status_code=400,