Skip to content

Commit

Permalink
fix: metadata search for provider
Browse files Browse the repository at this point in the history
Signed-off-by: Anupam Kumar <[email protected]>
  • Loading branch information
kyteinsky committed Feb 22, 2024
1 parent d834353 commit 3f771ee
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 75 deletions.
3 changes: 2 additions & 1 deletion context_chat_backend/chain/ingest/injest.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def _sources_to_documents(sources: list[UploadFile]) -> list[Document]:
'title': source.headers.get('title'),
'type': source.headers.get('type'),
'modified': source.headers.get('modified'),
'provider': source.headers.get('provider'),
}

document = Document(page_content=content, metadata=metadata)
Expand Down Expand Up @@ -158,7 +159,7 @@ def embed_sources(
# either not a file or a file that is allowed
sources_filtered = [
source for source in sources
if not source.filename.startswith('file: ')
if (source.filename is not None and not source.filename.startswith('file: '))
or _allowed_file(source)
]

Expand Down
59 changes: 20 additions & 39 deletions context_chat_backend/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,20 +57,21 @@ def _(userId: str):
# TODO: for testing, remove later
@app.get('/search')
@enabled_guard(app)
def _(userId: str, keyword: str):
from chromadb import ClientAPI
from .vectordb import COLLECTION_NAME
def _(userId: str, sourceNames: str):
sourceNames: list[str] = [source.strip() for source in sourceNames.split(',') if source.strip() != '']

if len(sourceNames) == 0:
return JSONResponse('No sources provided', 400)

db: BaseVectorDB = app.extra.get('VECTOR_DB')
client: ClientAPI = db.client
db.setup_schema(userId)

return JSONResponse(
client.get_collection(COLLECTION_NAME(userId)).get(
where_document={'$contains': [{'source': keyword}]},
include=['metadatas'],
)
)
if db is None:
return JSONResponse('Error: VectorDB not initialised', 500)

source_objs = db.get_objects_from_metadata(userId, 'source', sourceNames)
sources = list(map(lambda s: s.get('id'), source_objs.values()))

return JSONResponse({ 'sources': sources })


@app.put('/enabled')
Expand Down Expand Up @@ -110,47 +111,26 @@ def _(userId: Annotated[str, Body()], sourceNames: Annotated[list[str], Body()])
if db is None:
return JSONResponse('Error: VectorDB not initialised', 500)

source_objs = db.get_objects_from_metadata(userId, 'source', sourceNames)
res = db.delete_by_ids(userId, [
source.get('id')
for source in source_objs.values()
if value_of(source.get('id') is not None)
])

# NOTE: None returned in `delete_by_ids` should have meant an error but it didn't in the case of
# weaviate maybe because of the way weaviate wrapper is implemented (langchain's api does not take
# class name as input, which will be required in future versions of weaviate)
if res is None:
print('Deletion query returned "None". This can happen in Weaviate even if the deletion was \
successful, therefore not considered an error for now.')
res = db.delete(userId, 'source', sourceNames)

if res is False:
return JSONResponse('Error: VectorDB delete failed, check vectordb logs for more info.', 400)

return JSONResponse('All valid sources deleted')


@app.post('/deleteMatchingSources')
@app.post('/deleteSourcesByProvider')
@enabled_guard(app)
def _(userId: Annotated[str, Body()], keyword: Annotated[str, Body()]):
def _(userId: Annotated[str, Body()], providerKey: Annotated[str, Body()]):
if value_of(providerKey) is None:
return JSONResponse('Invalid provider key provided', 400)

db: BaseVectorDB = app.extra.get('VECTOR_DB')

if db is None:
return JSONResponse('Error: VectorDB not initialised', 500)

objs = db.get_objects_from_metadata(userId, 'source', [keyword], True)
res = db.delete_by_ids(userId, [
obj.get('id')
for obj in objs.values()
if value_of(obj.get('id') is not None)
])

# NOTE: None returned in `delete_by_ids` should have meant an error but it didn't in the case of
# weaviate maybe because of the way weaviate wrapper is implemented (langchain's api does not take
# class name as input, which will be required in future versions of weaviate)
if res is None:
print('Deletion query returned "None". This can happen in Weaviate even if the deletion was \
successful, therefore not considered an error for now.')
res = db.delete(userId, 'provider', [providerKey])

if res is False:
return JSONResponse('Error: VectorDB delete failed, check vectordb logs for more info.', 400)
Expand All @@ -169,6 +149,7 @@ def _(sources: list[UploadFile]):
value_of(source.headers.get('userId'))
and value_of(source.headers.get('type'))
and value_of(source.headers.get('modified'))
and value_of(source.headers.get('provider'))
for source in sources]
):
return JSONResponse('Invaild/missing headers', 400)
Expand Down
61 changes: 52 additions & 9 deletions context_chat_backend/vectordb/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from langchain.schema.embeddings import Embeddings
from langchain.vectorstores import VectorStore

from ..utils import value_of


class BaseVectorDB(ABC):
client = None
Expand Down Expand Up @@ -56,7 +58,6 @@ def get_objects_from_metadata(
user_id: str,
metadata_key: str,
values: List[str],
contains: bool = False,
) -> dict:
'''
Get all objects with the given metadata key and values.
Expand All @@ -70,9 +71,6 @@ def get_objects_from_metadata(
Metadata key to get.
values: List[str]
List of metadata names to get.
contains: bool
If True, gets all objects that contain any of the given values,
otherwise gets all objects that have the given values.
Returns
-------
Expand All @@ -89,7 +87,7 @@ def get_objects_from_metadata(
}
'''

def delete_by_ids(self, user_id: str, ids: list[str]) -> Optional[bool]:
def delete_by_ids(self, user_id: str, ids: list[str]) -> bool:
'''
Deletes all documents with the given ids for the given user.
Expand All @@ -102,9 +100,9 @@ def delete_by_ids(self, user_id: str, ids: list[str]) -> Optional[bool]:
Returns
-------
Optional[bool]
Optional[bool]: True if deletion is successful,
False otherwise, None if not implemented.
bool
True if deletion is successful,
False otherwise
'''
if len(ids) == 0:
return True
Expand All @@ -113,4 +111,49 @@ def delete_by_ids(self, user_id: str, ids: list[str]) -> Optional[bool]:
if user_client is None:
return False

return user_client.delete(ids)
res = user_client.delete(ids)

# NOTE: None should have meant an error but it didn't in the case of
# weaviate maybe because of the way weaviate wrapper is implemented (langchain's api does not take
# class name as input, which will be required in future versions of weaviate)
if res is None:
print('Deletion query returned "None". This can happen in Weaviate even if the deletion was \
successful, therefore not considered an error for now.')
return True

return res

def delete(self, user_id: str, metadata_key: str, values: list[str]) -> bool:
'''
Deletes all documents with the matching values for the given metadata key.
Args
----
user_id: str
User ID from whose database to delete the documents.
metadata_key: str
Metadata key to delete by.
values: list[str]
List of metadata values to match.
Returns
-------
bool
True if deletion is successful,
False otherwise
'''
if len(values) == 0:
return True

user_client = self.get_user_client(user_id)
if user_client is None:
return False

objs = self.get_objects_from_metadata(user_id, metadata_key, values)
ids = [
obj.get('id')
for obj in objs.values()
if value_of(obj.get('id') is not None)
]

return self.delete_by_ids(user_id, ids)
12 changes: 1 addition & 11 deletions context_chat_backend/vectordb/chroma.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ def get_objects_from_metadata(
user_id: str,
metadata_key: str,
values: List[str],
contains: bool = False,
) -> dict:
# NOTE: the limit of objects returned is not known, maybe it would be better to set one manually

Expand All @@ -77,16 +76,7 @@ def get_objects_from_metadata(
if len(values) == 0:
return {}

if len(values) == 1:
if contains:
data_filter = { metadata_key: { '$in': values[0] } }
else:
data_filter = { metadata_key: values[0] }
else:
if contains:
data_filter = {'$or': [{ metadata_key: { '$in': val } } for val in values]}
else:
data_filter = {'$or': [{ metadata_key: val } for val in values]}
data_filter = { metadata_key: { '$in': values } }

try:
results = self.client.get_collection(COLLECTION_NAME(user_id)).get(
Expand Down
21 changes: 6 additions & 15 deletions context_chat_backend/vectordb/weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@
'description': 'Last modified time of the file',
'name': 'modified',
},
{
'dataType': ['text'],
'description': 'The provider of the source',
'name': 'provider',
}
],
# TODO: optimisation for large number of objects
'vectorIndexType': 'hnsw',
Expand Down Expand Up @@ -126,31 +131,17 @@ def get_objects_from_metadata(
user_id: str,
metadata_key: str,
values: List[str],
contains: bool = False,
) -> dict:
# NOTE: the limit of objects returned is not known, maybe it would be better to set one manually

if not self.client:
raise Exception('Error: Weaviate client not initialised')

if not self.client.schema.exists(COLLECTION_NAME(user_id)):
self.setup_schema(user_id)
self.setup_schema(user_id)

if len(values) == 0:
return {}

# todo
if len(values) == 1:
if contains:
data_filter = { metadata_key: { '$in': values[0] } }
else:
data_filter = { metadata_key: values[0] }
else:
if contains:
data_filter = {'$or': [{ metadata_key: { '$in': val } } for val in values]}
else:
data_filter = {'$or': [{ metadata_key: val } for val in values]}

data_filter = {
'path': [metadata_key],
'operator': 'ContainsAny',
Expand Down

0 comments on commit 3f771ee

Please sign in to comment.