From cde486a7ebbeb5e1968048b8269f21bfcf400e09 Mon Sep 17 00:00:00 2001
From: Anupam Kumar
Date: Wed, 28 Feb 2024 20:41:00 +0530
Subject: [PATCH] add support for scoped context in query

Signed-off-by: Anupam Kumar
---
 context_chat_backend/chain/__init__.py    |  4 +-
 context_chat_backend/chain/one_shot.py    | 60 +++++++++++++++---
 context_chat_backend/controller.py        | 74 +++++++++++++++++++++--
 context_chat_backend/vectordb/__init__.py |  4 +-
 context_chat_backend/vectordb/base.py     | 21 +++++++
 context_chat_backend/vectordb/chroma.py   | 29 +++++++--
 context_chat_backend/vectordb/weaviate.py | 38 +++++++++---
 7 files changed, 203 insertions(+), 27 deletions(-)

diff --git a/context_chat_backend/chain/__init__.py b/context_chat_backend/chain/__init__.py
index 3ce46b6..433bbef 100644
--- a/context_chat_backend/chain/__init__.py
+++ b/context_chat_backend/chain/__init__.py
@@ -1,7 +1,9 @@
 from .ingest import embed_sources
-from .one_shot import process_query
+from .one_shot import ScopeType, process_query, process_scoped_query
 
 __all__ = [
+	'ScopeType',
 	'embed_sources',
 	'process_query',
+	'process_scoped_query',
 ]
diff --git a/context_chat_backend/chain/one_shot.py b/context_chat_backend/chain/one_shot.py
index b46a372..004f972 100644
--- a/context_chat_backend/chain/one_shot.py
+++ b/context_chat_backend/chain/one_shot.py
@@ -1,3 +1,6 @@
+from enum import Enum
+from logging import error as log_error
+
 from langchain.llms.base import LLM
 
 from ..vectordb import BaseVectorDB
@@ -9,6 +12,11 @@
 '''
 
 
+class ScopeType(Enum):
+	PROVIDER = 'provider'
+	SOURCE = 'source'
+
+
 def process_query(
 	user_id: str,
 	vectordb: BaseVectorDB,
@@ -16,21 +24,59 @@ def process_query(
 	query: str,
 	use_context: bool = True,
 	ctx_limit: int = 5,
-	template: str = _LLM_TEMPLATE,
+	ctx_filter: dict | None = None,
+	template: str | None = None,
 	end_separator: str = '',
-) -> tuple[str, set]:
+) -> tuple[str, list[str]]:
 	if not use_context:
-		return llm.predict(query), set()
+		return llm.predict(query), []
 
 	user_client = vectordb.get_user_client(user_id)
 	if user_client is None:
-		return llm.predict(query), set()
+		return llm.predict(query), []
+
+	if ctx_filter is not None:
+		context_docs = user_client.similarity_search(query, k=ctx_limit, filter=ctx_filter)
+	else:
+		context_docs = user_client.similarity_search(query, k=ctx_limit)
 
-	context_docs = user_client.similarity_search(query, k=ctx_limit)
 	context_text = '\n\n'.join(f'{d.metadata.get("title")}\n{d.page_content}' for d in context_docs)
 
-	output = llm.predict(template.format(context=context_text, question=query)) \
+	output = llm.predict((template or _LLM_TEMPLATE).format(context=context_text, question=query)) \
 		.strip().rstrip(end_separator).strip()
 
-	unique_sources = {sources for d in context_docs if (sources := d.metadata.get('source'))}
+	unique_sources: list[str] = list({source for d in context_docs if (source := d.metadata.get('source'))})
 
 	return (output, unique_sources)
+
+
+def process_scoped_query(
+	user_id: str,
+	vectordb: BaseVectorDB,
+	llm: LLM,
+	query: str,
+	scope_type: ScopeType,
+	scope_list: list[str],
+	ctx_limit: int = 5,
+	template: str | None = None,
+	end_separator: str = '',
+) -> tuple[str, list[str]]:
+	ctx_filter = vectordb.get_metadata_filter([{
+		'metadata_key': scope_type.value,
+		'values': scope_list,
+	}])
+
+	if ctx_filter is None:
+		log_error(f'Error: could not get filter for (\nscope type: {scope_type}\n\
+scope list: {scope_list})\n\nproceeding with an unscoped query')
+
+	return process_query(
+		user_id=user_id,
+		vectordb=vectordb,
+		llm=llm,
+		query=query,
+		use_context=True,
+		ctx_limit=ctx_limit,
+		ctx_filter=ctx_filter,
+		template=template,
+		end_separator=end_separator,
+	)
diff --git a/context_chat_backend/controller.py b/context_chat_backend/controller.py
index c4946bb..0b112d8 100644
--- a/context_chat_backend/controller.py
+++ b/context_chat_backend/controller.py
@@ -1,11 +1,12 @@
 from os import getenv
-from typing import Annotated
+from typing import Annotated, Any
 
 from dotenv import load_dotenv
 from fastapi import BackgroundTasks, Body, FastAPI, Request, UploadFile
 from langchain.llms.base import LLM
+from pydantic import BaseModel, FieldValidationInfo, field_validator
 
-from .chain import embed_sources, process_query
+from .chain import ScopeType, embed_sources, process_query, process_scoped_query
 from .download import download_all_models
 from .ocs_utils import AppAPIAuthMiddleware
 from .utils import JSONResponse, enabled_guard, update_progress, value_of
@@ -190,6 +191,15 @@ def _(userId: str, query: str, useContext: bool = True, ctxLimit: int = 5):
 	if db is None:
 		return JSONResponse('Error: VectorDB not initialised', 500)
 
+	if value_of(userId) is None:
+		return JSONResponse('Empty User ID', 400)
+
+	if value_of(query) is None:
+		return JSONResponse('Empty query', 400)
+
+	if ctxLimit < 1:
+		return JSONResponse('Invalid context chunk limit', 400)
+
 	template = app.extra.get('LLM_TEMPLATE')
 	end_separator = app.extra.get('LLM_END_SEPARATOR', '')
 
@@ -200,14 +210,66 @@ def _(userId: str, query: str, useContext: bool = True, ctxLimit: int = 5):
 		query=query,
 		use_context=useContext,
 		ctx_limit=ctxLimit,
+		template=template,
 		end_separator=end_separator,
-		**({'template': template} if template else {}),
 	)
 
-	if output is None:
-		return JSONResponse('Error: check if the model specified supports the query type', 500)
+	return JSONResponse({
+		'output': output,
+		'sources': sources,
+	})
+
+
+class ScopedQuery(BaseModel):
+	userId: str
+	query: str
+	scopeType: ScopeType
+	scopeList: list[str]
+	ctxLimit: int = 5
+
+	@field_validator('userId', 'query', 'scopeList', 'ctxLimit')
+	@classmethod
+	def check_empty_values(cls, value: Any, info: FieldValidationInfo):
+		if value_of(value) is None:
+			raise ValueError('Empty value for field', info.field_name)
+
+		return value
+
+	@field_validator('ctxLimit')
+	@classmethod
+	def at_least_one_context(cls, v: int):
+		if v < 1:
+			raise ValueError('Invalid context chunk limit')
+
+		return v
+
+@app.post('/scopedQuery')
+@enabled_guard(app)
+def _(scopedQuery: ScopedQuery):
+	llm: LLM | None = app.extra.get('LLM_MODEL')
+	if llm is None:
+		return JSONResponse('Error: LLM not initialised', 500)
+
+	db: BaseVectorDB | None = app.extra.get('VECTOR_DB')
+	if db is None:
+		return JSONResponse('Error: VectorDB not initialised', 500)
+
+	template = app.extra.get('LLM_TEMPLATE')
+	end_separator = app.extra.get('LLM_END_SEPARATOR', '')
+
+	(output, sources) = process_scoped_query(
+		user_id=scopedQuery.userId,
+		vectordb=db,
+		llm=llm,
+		query=scopedQuery.query,
+		ctx_limit=scopedQuery.ctxLimit,
+		template=template,
+		end_separator=end_separator,
+		scope_type=scopedQuery.scopeType,
+		scope_list=scopedQuery.scopeList,
+	)
 
 	return JSONResponse({
 		'output': output,
-		'sources': list(sources),
+		'sources': sources,
 	})
diff --git a/context_chat_backend/vectordb/__init__.py b/context_chat_backend/vectordb/__init__.py
index 4b1a3cd..874f604 100644
--- a/context_chat_backend/vectordb/__init__.py
+++ b/context_chat_backend/vectordb/__init__.py
@@ -1,10 +1,10 @@
 from importlib import import_module
 
-from .base import BaseVectorDB
+from .base import BaseVectorDB, MetadataFilter
 
 vector_dbs = ['weaviate', 'chroma']
 
-__all__ = ['get_vector_db', 'vector_dbs', 'BaseVectorDB', 'COLLECTION_NAME']
+__all__ = ['get_vector_db', 'vector_dbs', 'BaseVectorDB', 'COLLECTION_NAME', 'MetadataFilter']
 
 
 # class name/index name is capitalized (user1 => User1) maybe because it is a class name,
diff --git a/context_chat_backend/vectordb/base.py b/context_chat_backend/vectordb/base.py
index 1437dfc..2035ff7 100644
--- a/context_chat_backend/vectordb/base.py
+++ b/context_chat_backend/vectordb/base.py
@@ -14,6 +14,11 @@ class TSearchObject(TypedDict):
 TSearchDict = dict[str, TSearchObject]
 
 
+class MetadataFilter(TypedDict):
+	metadata_key: str
+	values: list[str]
+
+
 class BaseVectorDB(ABC):
 	client: Any = None
 	embedding: Any = None
@@ -59,6 +64,22 @@ def setup_schema(self, user_id: str) -> None:
 			None
 		'''
 
+	@abstractmethod
+	def get_metadata_filter(self, filters: list[MetadataFilter]) -> dict | None:
+		'''
+		Returns the metadata filter for the given filters.
+
+		Args
+		----
+		filters: list[MetadataFilter]
+			List of metadata filters.
+
+		Returns
+		-------
+		dict
+			Metadata filter dictionary.
+		'''
+
 	@abstractmethod
 	def get_objects_from_metadata(
 		self,
diff --git a/context_chat_backend/vectordb/chroma.py b/context_chat_backend/vectordb/chroma.py
index 30f7a6f..d19a42f 100644
--- a/context_chat_backend/vectordb/chroma.py
+++ b/context_chat_backend/vectordb/chroma.py
@@ -1,14 +1,14 @@
 from logging import error as log_error
 from os import getenv
 
-from chromadb import Client, Where
+from chromadb import Client
 from chromadb.config import Settings
 from dotenv import load_dotenv
 from langchain.schema.embeddings import Embeddings
 from langchain.vectorstores import Chroma, VectorStore
 
 from . import COLLECTION_NAME
-from .base import BaseVectorDB, TSearchDict
+from .base import BaseVectorDB, MetadataFilter, TSearchDict
 
 load_dotenv()
 
@@ -59,6 +59,19 @@ def get_user_client(
 			embedding_function=em,
 		)
 
+	def get_metadata_filter(self, filters: list[MetadataFilter]) -> dict | None:
+		if len(filters) == 0:
+			return None
+
+		if len(filters) == 1:
+			return { filters[0]['metadata_key']: { '$in': filters[0]['values'] } }
+
+		return {
+			'$or': [{
+				f['metadata_key']: { '$in': f['values'] }
+			} for f in filters]
+		}
+
 	def get_objects_from_metadata(
 		self,
 		user_id: str,
@@ -72,10 +85,18 @@
 
 		self.setup_schema(user_id)
 
-		if len(values) == 0:
+		try:
+			data_filter = self.get_metadata_filter([{
+				'metadata_key': metadata_key,
+				'values': values,
+			}])
+		except KeyError as e:
+			# todo: info instead of error
+			log_error(f'Error: Chromadb filter error: {e}')
 			return {}
 
-		data_filter: Where = { metadata_key: { '$in': values } }  # type: ignore
+		if data_filter is None:
+			return {}
 
 		try:
 			results = self.client.get_collection(COLLECTION_NAME(user_id)).get(
diff --git a/context_chat_backend/vectordb/weaviate.py b/context_chat_backend/vectordb/weaviate.py
index 718b26d..7ca4a03 100644
--- a/context_chat_backend/vectordb/weaviate.py
+++ b/context_chat_backend/vectordb/weaviate.py
@@ -8,7 +8,7 @@
 
 from ..utils import value_of
 from . import COLLECTION_NAME
-from .base import BaseVectorDB, TSearchDict
+from .base import BaseVectorDB, MetadataFilter, TSearchDict
 
 load_dotenv()
 
@@ -125,6 +125,26 @@ def get_user_client(
 
 		return weaviate_obj
 
+	def get_metadata_filter(self, filters: list[MetadataFilter]) -> dict | None:
+		if len(filters) == 0:
+			return None
+
+		if len(filters) == 1:
+			return {
+				'path': filters[0]['metadata_key'],
+				'operator': 'ContainsAny',
+				'valueTextList': filters[0]['values'],
+			}
+
+		return {
+			'operator': 'Or',
+			'operands': [{
+				'path': f['metadata_key'],
+				'operator': 'ContainsAny',
+				'valueTextList': f['values'],
+			} for f in filters]
+		}
+
 	def get_objects_from_metadata(
 		self,
 		user_id: str,
@@ -138,14 +158,18 @@
 
 		self.setup_schema(user_id)
 
-		if len(values) == 0:
+		try:
+			data_filter = self.get_metadata_filter([{
+				'metadata_key': metadata_key,
+				'values': values,
+			}])
+		except KeyError as e:
+			# todo: info instead of error
+			log_error(f'Error: Weaviate filter error: {e}')
 			return {}
 
-		data_filter = {
-			'path': [metadata_key],
-			'operator': 'ContainsAny',
-			'valueTextList': values,
-		}
+		if data_filter is None:
+			return {}
 
 		results = self.client.query \
 			.get(COLLECTION_NAME(user_id), [metadata_key, 'modified']) \
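
For reference, here is a minimal sketch of the backend-specific filter shapes the two get_metadata_filter implementations above are expected to produce for the same MetadataFilter input. The 'provider' key and the example values are illustrative assumptions, not part of the patch:

# Illustrative only: expected outputs of get_metadata_filter() for a
# single-entry scope, mirroring the Chroma and Weaviate code above.
scope = [{'metadata_key': 'provider', 'values': ['files', 'calendar']}]

# Chroma builds a `where` dict using the $in operator
# (multiple entries would be wrapped in an '$or' list):
chroma_filter = {'provider': {'$in': ['files', 'calendar']}}

# Weaviate builds a GraphQL-style where filter using ContainsAny
# (multiple entries would be wrapped in an 'Or' operand list):
weaviate_filter = {
	'path': 'provider',
	'operator': 'ContainsAny',
	'valueTextList': ['files', 'calendar'],
}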
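
And a hedged sketch of how the new /scopedQuery endpoint introduced in controller.py might be exercised against a running instance. The base URL, port, user id, and payload values are assumptions for illustration; scopeType must be one of the ScopeType values added in this patch ('provider' or 'source'), and the response is expected to carry 'output' and 'sources' keys:

import requests  # assumed available; any HTTP client would do

# Hypothetical request against a locally running backend; adjust the URL,
# port and AppAPI authentication to the actual deployment.
payload = {
	'userId': 'admin',
	'query': 'Summarise the documents shared with me this week',
	'scopeType': 'provider',
	'scopeList': ['files'],
	'ctxLimit': 5,
}

resp = requests.post('http://localhost:10034/scopedQuery', json=payload)
resp.raise_for_status()

result = resp.json()
print(result['output'])   # the LLM answer, generated from scoped context only
print(result['sources'])  # list of source ids that matched the scope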