
Commit

add support for scoped context in query
Signed-off-by: Anupam Kumar <[email protected]>
kyteinsky committed Feb 29, 2024
1 parent 6941176 commit cde486a
Showing 7 changed files with 203 additions and 27 deletions.
4 changes: 3 additions & 1 deletion context_chat_backend/chain/__init__.py
@@ -1,7 +1,9 @@
from .ingest import embed_sources
from .one_shot import process_query
from .one_shot import ScopeType, process_query, process_scoped_query

__all__ = [
'ScopeType',
'embed_sources',
'process_query',
'process_scoped_query',
]
60 changes: 53 additions & 7 deletions context_chat_backend/chain/one_shot.py
@@ -1,3 +1,6 @@
from enum import Enum
from logging import error as log_error

from langchain.llms.base import LLM

from ..vectordb import BaseVectorDB
@@ -9,28 +12,71 @@
'''


class ScopeType(Enum):
PROVIDER = 'provider'
SOURCE = 'source'


def process_query(
user_id: str,
vectordb: BaseVectorDB,
llm: LLM,
query: str,
use_context: bool = True,
ctx_limit: int = 5,
template: str = _LLM_TEMPLATE,
ctx_filter: dict | None = None,
template: str | None = None,
end_separator: str = '',
) -> tuple[str, set]:
) -> tuple[str, list[str]]:
if not use_context:
return llm.predict(query), set()
return llm.predict(query), []

user_client = vectordb.get_user_client(user_id)
if user_client is None:
return llm.predict(query), set()
return llm.predict(query), []

if ctx_filter is not None:
context_docs = user_client.similarity_search(query, k=ctx_limit, filter=ctx_filter)
else:
context_docs = user_client.similarity_search(query, k=ctx_limit)

context_docs = user_client.similarity_search(query, k=ctx_limit)
context_text = '\n\n'.join(f'{d.metadata.get("title")}\n{d.page_content}' for d in context_docs)

output = llm.predict(template.format(context=context_text, question=query)) \
output = llm.predict((template or _LLM_TEMPLATE).format(context=context_text, question=query)) \
.strip().rstrip(end_separator).strip()
unique_sources = {sources for d in context_docs if (sources := d.metadata.get('source'))}
unique_sources: list[str] = list({source for d in context_docs if (source := d.metadata.get('source'))})

return (output, unique_sources)
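For illustration, a minimal sketch of calling process_query directly with the new ctx_filter parameter, assuming a Chroma-backed store (so the filter is a Chroma where-clause) and already-initialised vectordb and llm objects; the 'provider' key and the 'files__default' value are invented:

# `vectordb` and `llm` are assumed to be existing BaseVectorDB and LLM instances.
answer, sources = process_query(
    user_id='user1',
    vectordb=vectordb,
    llm=llm,
    query='When is the next release planned?',
    ctx_limit=5,
    # Chroma-style filter: only consider documents whose 'provider' metadata
    # is one of the listed values (normally built via get_metadata_filter).
    ctx_filter={'provider': {'$in': ['files__default']}},
)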


def process_scoped_query(
user_id: str,
vectordb: BaseVectorDB,
llm: LLM,
query: str,
scope_type: ScopeType,
scope_list: list[str],
ctx_limit: int = 5,
template: str | None = None,
end_separator: str = '',
) -> tuple[str, list[str]]:
ctx_filter = vectordb.get_metadata_filter([{
'metadata_key': scope_type.value,
'values': scope_list,
}])

if ctx_filter is None:
log_error(f'Error: could not get filter for (\nscope type: {scope_type}\n\
scope list: {scope_list}\n)\n\nproceeding with an unscoped query')

return process_query(
user_id=user_id,
vectordb=vectordb,
llm=llm,
query=query,
use_context=True,
ctx_limit=ctx_limit,
ctx_filter=ctx_filter,
template=template,
end_separator=end_separator,
)
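A usage sketch for the new function, again assuming an initialised vector DB and LLM (in the app these come from app.extra['VECTOR_DB'] and app.extra['LLM_MODEL']); the source ids are invented. When the backend cannot build a filter for the requested scope, the error is logged and the call falls back to an unscoped query:

from context_chat_backend.chain import ScopeType, process_scoped_query

answer, sources = process_scoped_query(
    user_id='user1',
    vectordb=vectordb,   # assumed existing BaseVectorDB instance
    llm=llm,             # assumed existing LLM instance
    query='Summarise the meeting notes',
    scope_type=ScopeType.SOURCE,
    scope_list=['files__default: 1234', 'files__default: 5678'],
    ctx_limit=5,
)
print(answer)    # LLM output
print(sources)   # unique source ids of the retrieved chunks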
74 changes: 68 additions & 6 deletions context_chat_backend/controller.py
@@ -1,11 +1,12 @@
from os import getenv
from typing import Annotated
from typing import Annotated, Any

from dotenv import load_dotenv
from fastapi import BackgroundTasks, Body, FastAPI, Request, UploadFile
from langchain.llms.base import LLM
from pydantic import BaseModel, FieldValidationInfo, field_validator

from .chain import embed_sources, process_query
from .chain import ScopeType, embed_sources, process_query, process_scoped_query
from .download import download_all_models
from .ocs_utils import AppAPIAuthMiddleware
from .utils import JSONResponse, enabled_guard, update_progress, value_of
@@ -190,6 +191,15 @@ def _(userId: str, query: str, useContext: bool = True, ctxLimit: int = 5):
if db is None:
return JSONResponse('Error: VectorDB not initialised', 500)

if value_of(userId) is None:
return JSONResponse('Empty User ID', 400)

if value_of(query) is None:
return JSONResponse('Empty query', 400)

if ctxLimit < 1:
return JSONResponse('Invalid context chunk limit', 400)

template = app.extra.get('LLM_TEMPLATE')
end_separator = app.extra.get('LLM_END_SEPARATOR', '')

@@ -200,14 +210,66 @@ def _(userId: str, query: str, useContext: bool = True, ctxLimit: int = 5):
query=query,
use_context=useContext,
ctx_limit=ctxLimit,
template=template,
end_separator=end_separator,
**({'template': template} if template else {}),
)

if output is None:
return JSONResponse('Error: check if the model specified supports the query type', 500)
return JSONResponse({
'output': output,
'sources': sources,
})


class ScopedQuery(BaseModel):
userId: str
query: str
scopeType: ScopeType
scopeList: list[str]
ctxLimit: int = 5

@field_validator('userId', 'query', 'scopeList', 'ctxLimit')
@classmethod
def check_empty_values(cls, value: Any, info: FieldValidationInfo):
if value_of(value) is None:
raise ValueError('Empty value for field', info.field_name)

return value

@field_validator('ctxLimit')
@classmethod
def at_least_one_context(cls, v: int):
if v < 1:
raise ValueError('Invalid context chunk limit')

return v
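A small sketch of how these validators behave, assuming value_of treats empty strings and empty lists as empty (the exact semantics live in utils.value_of); ScopedQuery and ScopeType are the classes defined in this file:

from pydantic import ValidationError

# Passes validation; the scopeType string is coerced to the ScopeType enum.
ScopedQuery(
    userId='user1',
    query='hello',
    scopeType='provider',
    scopeList=['files__default'],   # invented provider id
)

# An empty scopeList is rejected by check_empty_values, and a ctxLimit
# below 1 would be rejected by at_least_one_context.
try:
    ScopedQuery(userId='user1', query='hello', scopeType='source', scopeList=[])
except ValidationError as exc:
    print(exc)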

@app.post('/scopedQuery')
@enabled_guard(app)
def _(scopedQuery: ScopedQuery):
llm: LLM | None = app.extra.get('LLM_MODEL')
if llm is None:
return JSONResponse('Error: LLM not initialised', 500)

db: BaseVectorDB | None = app.extra.get('VECTOR_DB')
if db is None:
return JSONResponse('Error: VectorDB not initialised', 500)

template = app.extra.get('LLM_TEMPLATE')
end_separator = app.extra.get('LLM_END_SEPARATOR', '')

(output, sources) = process_scoped_query(
user_id=scopedQuery.userId,
vectordb=db,
llm=llm,
query=scopedQuery.query,
ctx_limit=scopedQuery.ctxLimit,
template=template,
end_separator=end_separator,
scope_type=scopedQuery.scopeType,
scope_list=scopedQuery.scopeList,
)

return JSONResponse({
'output': output,
'sources': list(sources),
'sources': sources,
})
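For example, the new endpoint could be exercised roughly as below; the host and port are assumptions, and a real deployment would also need the AppAPI authentication headers enforced by AppAPIAuthMiddleware:

import requests  # any HTTP client works; requests is used only for illustration

resp = requests.post(
    'http://localhost:10034/scopedQuery',   # assumed host and port
    json={
        'userId': 'user1',
        'query': 'What was decided about the release date?',
        'scopeType': 'provider',             # or 'source'
        'scopeList': ['files__default'],     # invented provider id
        'ctxLimit': 5,
    },
)
print(resp.json())  # {'output': '...', 'sources': [...]} on success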
4 changes: 2 additions & 2 deletions context_chat_backend/vectordb/__init__.py
@@ -1,10 +1,10 @@
from importlib import import_module

from .base import BaseVectorDB
from .base import BaseVectorDB, MetadataFilter

vector_dbs = ['weaviate', 'chroma']

__all__ = ['get_vector_db', 'vector_dbs', 'BaseVectorDB', 'COLLECTION_NAME']
__all__ = ['get_vector_db', 'vector_dbs', 'BaseVectorDB', 'COLLECTION_NAME', 'MetadataFilter']


# class name/index name is capitalized (user1 => User1) maybe because it is a class name,
21 changes: 21 additions & 0 deletions context_chat_backend/vectordb/base.py
@@ -14,6 +14,11 @@ class TSearchObject(TypedDict):
TSearchDict = dict[str, TSearchObject]


class MetadataFilter(TypedDict):
metadata_key: str
values: list[str]


class BaseVectorDB(ABC):
client: Any = None
embedding: Any = None
@@ -59,6 +64,22 @@ def setup_schema(self, user_id: str) -> None:
None
'''

@abstractmethod
def get_metadata_filter(self, filters: list[MetadataFilter]) -> dict | None:
'''
Returns the metadata filter for the given filters.
Args
----
filters: list[MetadataFilter]
List of metadata filters.
Returns
-------
dict | None
Metadata filter dictionary, or None if no filters are given.
'''
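In other words, a MetadataFilter is plain data saying "this metadata key must match one of these values"; each backend's get_metadata_filter turns a list of them into its own where-clause (see the Chroma and Weaviate implementations below). A tiny sketch with invented values:

from context_chat_backend.vectordb import MetadataFilter

scope: MetadataFilter = {
    'metadata_key': 'provider',
    'values': ['files__default'],
}
# db.get_metadata_filter([scope]) returns a backend-specific filter dict,
# or None when the list of filters is empty.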

@abstractmethod
def get_objects_from_metadata(
self,
29 changes: 25 additions & 4 deletions context_chat_backend/vectordb/chroma.py
@@ -1,14 +1,14 @@
from logging import error as log_error
from os import getenv

from chromadb import Client, Where
from chromadb import Client
from chromadb.config import Settings
from dotenv import load_dotenv
from langchain.schema.embeddings import Embeddings
from langchain.vectorstores import Chroma, VectorStore

from . import COLLECTION_NAME
from .base import BaseVectorDB, TSearchDict
from .base import BaseVectorDB, MetadataFilter, TSearchDict

load_dotenv()

@@ -59,6 +59,19 @@ def get_user_client(
embedding_function=em,
)

def get_metadata_filter(self, filters: list[MetadataFilter]) -> dict | None:
if len(filters) == 0:
return None

if len(filters) == 1:
return { filters[0]['metadata_key']: { '$in': filters[0]['values'] } }

return {
'$or': [{
f['metadata_key']: { '$in': f['values'] }
} for f in filters]
}

def get_objects_from_metadata(
self,
user_id: str,
@@ -72,10 +85,18 @@ def get_objects_from_metadata(

self.setup_schema(user_id)

if len(values) == 0:
try:
data_filter = self.get_metadata_filter([{
'metadata_key': metadata_key,
'values': values,
}])
except KeyError as e:
# todo: info instead of error
log_error(f'Error: Chromadb filter error: {e}')
return {}

data_filter: Where = { metadata_key: { '$in': values } } # type: ignore
if data_filter is None:
return {}

try:
results = self.client.get_collection(COLLECTION_NAME(user_id)).get(
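To make the Chroma filter shape concrete, a sketch with invented keys and values of what get_metadata_filter returns for one filter versus several (the resulting dicts are usable as Chroma where-clauses):

single = [{'metadata_key': 'provider', 'values': ['files__default']}]
# -> {'provider': {'$in': ['files__default']}}

several = [
    {'metadata_key': 'provider', 'values': ['files__default']},
    {'metadata_key': 'source', 'values': ['files__default: 1234']},
]
# -> {
#        '$or': [
#            {'provider': {'$in': ['files__default']}},
#            {'source': {'$in': ['files__default: 1234']}},
#        ]
#    }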
38 changes: 31 additions & 7 deletions context_chat_backend/vectordb/weaviate.py
@@ -8,7 +8,7 @@

from ..utils import value_of
from . import COLLECTION_NAME
from .base import BaseVectorDB, TSearchDict
from .base import BaseVectorDB, MetadataFilter, TSearchDict

load_dotenv()

@@ -125,6 +125,26 @@ def get_user_client(

return weaviate_obj

def get_metadata_filter(self, filters: list[MetadataFilter]) -> dict | None:
if len(filters) == 0:
return None

if len(filters) == 1:
return {
'path': filters[0]['metadata_key'],
'operator': 'ContainsAny',
'valueTextList': filters[0]['values'],
}

return {
'operator': 'Or',
'operands': [{
'path': f['metadata_key'],
'operator': 'ContainsAny',
'valueTextList': f['values'],
} for f in filters]
}

def get_objects_from_metadata(
self,
user_id: str,
@@ -138,14 +158,18 @@ def get_objects_from_metadata(

self.setup_schema(user_id)

if len(values) == 0:
try:
data_filter = self.get_metadata_filter([{
'metadata_key': metadata_key,
'values': values,
}])
except KeyError as e:
# todo: info instead of error
log_error(f'Error: Weaviate filter error: {e}')
return {}

data_filter = {
'path': [metadata_key],
'operator': 'ContainsAny',
'valueTextList': values,
}
if data_filter is None:
return {}

results = self.client.query \
.get(COLLECTION_NAME(user_id), [metadata_key, 'modified']) \
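For the Weaviate backend, the same two example filters produce a GraphQL-style where-filter like the sketch below (keys and values are invented; the with_where call is an assumption about how a caller would apply it with the v3 Python client):

where_filter = {
    'operator': 'Or',
    'operands': [
        {'path': 'provider', 'operator': 'ContainsAny', 'valueTextList': ['files__default']},
        {'path': 'source', 'operator': 'ContainsAny', 'valueTextList': ['files__default: 1234']},
    ],
}
# Assumed usage:
# client.query.get(COLLECTION_NAME(user_id), ['source', 'modified']).with_where(where_filter).do()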
