diff --git a/context_chat_backend/controller.py b/context_chat_backend/controller.py index 169931a..1399c97 100644 --- a/context_chat_backend/controller.py +++ b/context_chat_backend/controller.py @@ -87,6 +87,10 @@ async def lifespan(app: FastAPI): # sequential prompt processing for in-house LLMs (non-nc_texttotext) llm_lock = threading.Lock() +# lock to update the sources dict currently being processed +index_lock = threading.Lock() +_indexing = {} + # limit the number of concurrent document parsing doc_parse_semaphore = mp.Semaphore(app_config.doc_parser_worker_limit) @@ -286,10 +290,25 @@ def _(userId: str = Body(embed=True)): @app.put('/loadSources') @enabled_guard(app) def _(sources: list[UploadFile]): + global _indexing + if len(sources) == 0: return JSONResponse('No sources provided', 400) for source in sources: + if not value_of(source.filename): + return JSONResponse(f'Invalid source filename for: {source.headers.get("title")}', 400) + + with index_lock: + if source.filename in _indexing: + # this request will be retried by the client + return JSONResponse( + f'Source already being processed: {source.filename}', + 503, + headers={'cc-retry': 'true'}, + ) + _indexing[source.filename] = True + if not ( value_of(source.headers.get('userIds')) and value_of(source.headers.get('title')) @@ -300,13 +319,21 @@ def _(sources: list[UploadFile]): ): return JSONResponse(f'Invaild/missing headers for: {source.filename}', 400) - if not value_of(source.filename): - return JSONResponse(f'Invalid source filename for: {source.headers.get("title")}', 400) - - doc_parse_semaphore.acquire(block=True, timeout=29*60) # ~29 minutes + # wait for 10 minutes before failing the request + semres = doc_parse_semaphore.acquire(block=True, timeout=10*60) + if not semres: + return JSONResponse( + 'Document parser worker limit reached, try again in some time', + 503, + headers={'cc-retry': 'true'} + ) added_sources = exec_in_proc(target=embed_sources, args=(vectordb_loader, app.extra['CONFIG'], sources)) doc_parse_semaphore.release() + for source in sources: + with index_lock: + _indexing.pop(source.filename) + if len(added_sources) != len(sources): print( 'Count of newly loaded sources:', len(added_sources), diff --git a/context_chat_backend/vectordb/pgvector.py b/context_chat_backend/vectordb/pgvector.py index 513c8aa..e2a125f 100644 --- a/context_chat_backend/vectordb/pgvector.py +++ b/context_chat_backend/vectordb/pgvector.py @@ -198,14 +198,18 @@ def decl_update_access(self, user_ids: list[str], source_id: str, session_: orm. session.execute(stmt) session.commit() - access = [ - AccessListStore( - uid=user_id, - source_id=source_id, - ) - for user_id in user_ids - ] - session.add_all(access) + stmt = ( + sa.dialects.postgresql.insert(AccessListStore) + .values([ + { + 'uid': user_id, + 'source_id': source_id, + } + for user_id in user_ids + ]) + .on_conflict_do_nothing(index_elements=['uid', 'source_id']) + ) + session.execute(stmt) session.commit() if session_ is None: