wip: db schema change
Signed-off-by: Anupam Kumar <[email protected]>
kyteinsky committed Dec 2, 2024
1 parent 9dc37d7 commit 57265d7
Showing 12 changed files with 557 additions and 400 deletions.
15 changes: 3 additions & 12 deletions context_chat_backend/chain/context.py
@@ -22,29 +22,20 @@ def get_context_docs(
    scope_type: ScopeType | None = None,
    scope_list: list[str] | None = None,
) -> list[Document]:
    user_client = vectordb.get_user_client(user_id)

    # unscoped search
    if not scope_type:
        return user_client.similarity_search(query, k=ctx_limit)
        return vectordb.doc_search(user_id, query, ctx_limit)

    if not scope_list:
        raise ContextException('Error: scope list must be provided and not empty if scope type is provided')

    ctx_filter = vectordb.get_metadata_filter([{
        'metadata_key': scope_type.value,
        'values': scope_list,
    }])

    if ctx_filter is None:
        raise ContextException(f'Error: could not get filter for \nscope type: {scope_type}\nscope list: {scope_list}')

    return user_client.similarity_search(query, k=ctx_limit, filter=ctx_filter)
    return vectordb.doc_search(user_id, query, ctx_limit, scope_type, scope_list)


def get_context_chunks(context_docs: list[Document]) -> list[str]:
    context_chunks = []
    for doc in context_docs:
        # todo: just the filename perhaps?
        if title := doc.metadata.get('title'):
            context_chunks.append(title)
        context_chunks.append(doc.page_content)
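For orientation, a minimal sketch of the doc_search entry point that the new call sites above assume, inferred only from how it is invoked in this hunk (the real method lives in context_chat_backend/vectordb and is not part of this diff; the ScopeType members are assumptions):

from abc import ABC, abstractmethod
from enum import Enum

from langchain.schema import Document


class ScopeType(Enum):
    # assumed members; the real enum is defined elsewhere in the backend
    PROVIDER = 'provider'
    SOURCE = 'source'


class BaseVectorDB(ABC):
    @abstractmethod
    def doc_search(
        self,
        user_id: str,
        query: str,
        k: int,
        scope_type: ScopeType | None = None,
        scope_list: list[str] | None = None,
    ) -> list[Document]:
        '''Similarity search over the user's documents, optionally restricted to a scope.'''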
12 changes: 11 additions & 1 deletion context_chat_backend/chain/ingest/__init__.py
@@ -1,3 +1,13 @@
from langchain.schema import Document
from pydantic import BaseModel

from .injest import embed_sources

__all__ = [ 'embed_sources' ]
__all__ = [ 'embed_sources', 'InDocument' ]

class InDocument(BaseModel):
    documents: list[Document]  # the split documents of the same source
    userIds: list[str]
    source_id: str
    provider: str
    modified: int
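A minimal usage sketch of the new InDocument container, assuming chunks that have already been split from one source (all values below are illustrative, not taken from this commit):

from langchain.schema import Document

from context_chat_backend.chain.ingest import InDocument

# hypothetical chunks produced by the text splitter for a single source
chunks = [
    Document(page_content='First chunk of the note', metadata={'title': 'notes.md', 'type': 'text/markdown'}),
    Document(page_content='Second chunk of the note', metadata={'title': 'notes.md', 'type': 'text/markdown'}),
]

indoc = InDocument(
    documents=chunks,                 # the split documents of the same source
    userIds=['admin', 'alice'],       # users who may query these chunks
    source_id='files__default: 123',  # hypothetical source id (the uploaded file's name)
    provider='files',                 # hypothetical provider name
    modified=1733097600,              # unix timestamp of the last modification
)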
17 changes: 0 additions & 17 deletions context_chat_backend/chain/ingest/delete.py

This file was deleted.

9 changes: 4 additions & 5 deletions context_chat_backend/chain/ingest/doc_splitter.py
@@ -1,10 +1,9 @@
from langchain.text_splitter import (
    MarkdownTextSplitter,
    RecursiveCharacterTextSplitter,
    TextSplitter,
)
from functools import lru_cache

from langchain.text_splitter import MarkdownTextSplitter, RecursiveCharacterTextSplitter, TextSplitter


@lru_cache(maxsize=32)
def get_splitter_for(chunk_size: int, mimetype: str = 'text/plain') -> TextSplitter:
    kwargs = {
        'chunk_size': chunk_size,
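A short illustration of what the added lru_cache decorator buys: calls with the same (chunk_size, mimetype) pair now reuse one splitter instance instead of constructing a new one per source. The function body below is a simplified stand-in, not the real implementation:

from functools import lru_cache

from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter


@lru_cache(maxsize=32)
def get_splitter_for(chunk_size: int, mimetype: str = 'text/plain') -> TextSplitter:
    # placeholder body; the real function picks a splitter class based on the mimetype
    return RecursiveCharacterTextSplitter(chunk_size=chunk_size)


a = get_splitter_for(2000, 'text/plain')
b = get_splitter_for(2000, 'text/plain')
assert a is b  # same cached instance; a new splitter is built only for an unseen argument pair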
159 changes: 51 additions & 108 deletions context_chat_backend/chain/ingest/injest.py
@@ -1,24 +1,23 @@
import re
from logging import error as log_error

from fastapi.datastructures import UploadFile
from langchain.schema import Document

from ...config_parser import TConfig
from ...dyn_loader import VectorDBLoader
from ...utils import not_none, to_int
from ...vectordb import BaseVectorDB
from ...utils import is_valid_source_id, to_int
from ...vectordb import BaseVectorDB, DbException
from . import InDocument
from .doc_loader import decode_source
from .doc_splitter import get_splitter_for
from .mimetype_list import SUPPORTED_MIMETYPES


def _allowed_file(file: UploadFile) -> bool:
    return file.headers.get('type', default='') in SUPPORTED_MIMETYPES
    return file.headers['type'] in SUPPORTED_MIMETYPES


def _filter_sources(
    user_id: str,
    vectordb: BaseVectorDB,
    sources: list[UploadFile]
) -> list[UploadFile]:
@@ -31,160 +31,104 @@ def _filter_sources(
    ------
    DbException
    '''
    to_delete = {}

    input_sources = {}
    for source in sources:
        if not not_none(source.filename) or not not_none(source.headers.get('modified')):
            continue
        input_sources[source.filename] = source.headers.get('modified')

    existing_objects = vectordb.get_objects_from_metadata(
        user_id,
        'source',
        list(input_sources.keys())
    )

    for source, existing_meta in existing_objects.items():
        # recently modified files are re-embedded
        if to_int(input_sources.get(source)) > to_int(existing_meta.get('modified')):
            to_delete[source] = existing_meta.get('id')

    # delete old sources
    vectordb.delete_by_ids(user_id, list(to_delete.values()))

    # sources not already in the vectordb + the ones that were deleted
    new_sources = set(input_sources.keys()) \
        .difference(set(existing_objects))
    new_sources.update(set(to_delete.keys()))
    try:
        new_sources = vectordb.sources_to_embed(sources)
    except Exception as e:
        raise DbException('Error: Vectordb sources_to_embed error') from e

    return [
        source for source in sources
        if source.filename in new_sources
    ]


def _sources_to_documents(sources: list[UploadFile]) -> dict[str, list[Document]]:
    '''
    Converts a list of sources to a dictionary of documents with the user_id as the key.
    '''
    documents = {}
def _sources_to_indocuments(config: TConfig, sources: list[UploadFile]) -> list[InDocument]:
    indocuments = []

    for source in sources:
        print('processing source:', source.filename, flush=True)
        user_id = source.headers.get('userId')
        if user_id is None:
            log_error(f'userId not found in headers for source: {source.filename}')
            continue

        # transform the source to have text data
        content = decode_source(source)

        if content is None or content == '':
            print('decoded empty source:', source.filename, flush=True)
            continue

        # replace more than two newlines with two newlines (also blank spaces, more than 4)
        content = re.sub(r'((\r)?\n){3,}', '\n\n', content)
        # NOTE: do not use this with all docs when programming files are added
        content = re.sub(r'(\s){5,}', r'\g<1>', content)
        # filter out null bytes
        content = content.replace('\0', '')

        if content is None or content == '':
            print('decoded empty source after cleanup:', source.filename, flush=True)
            continue

        print('decoded non empty source:', source.filename, flush=True)

        metadata = {
            'source': source.filename,
            'title': source.headers.get('title'),
            'type': source.headers.get('type'),
            'modified': source.headers.get('modified'),
            'provider': source.headers.get('provider'),
            'title': source.headers['title'],
            'type': source.headers['type'],
        }
        doc = Document(page_content=content, metadata=metadata)

        document = Document(page_content=content, metadata=metadata)

        if documents.get(user_id) is not None:
            documents[user_id].append(document)
        else:
            documents[user_id] = [document]

    return documents
        splitter = get_splitter_for(config.embedding_chunk_size, source.headers['type'])
        split_docs = splitter.split_documents([doc])

        indocuments.append(InDocument(
            documents=split_docs,
            userIds=source.headers['userIds'].split(','),
            source_id=source.filename, # pyright: ignore[reportArgumentType]
            provider=source.headers['provider'],
            modified=to_int(source.headers['modified']),
        ))

def _bucket_by_type(documents: list[Document]) -> dict[str, list[Document]]:
    bucketed_documents = {}

    for doc in documents:
        doc_type = doc.metadata.get('type')

        if bucketed_documents.get(doc_type) is not None:
            bucketed_documents[doc_type].append(doc)
        else:
            bucketed_documents[doc_type] = [doc]

    return bucketed_documents
    return indocuments


def _process_sources(
    vectordb: BaseVectorDB,
    config: TConfig,
    sources: list[UploadFile],
) -> bool:
    filtered_sources = _filter_sources(sources[0].headers['userId'], vectordb, sources)
) -> list[str]:
    '''
    Processes the sources and adds them to the vectordb.
    Returns the list of source ids that were successfully added.
    '''
    filtered_sources = _filter_sources(vectordb, sources)

    if len(filtered_sources) == 0:
        # no new sources to embed
        print('Filtered all sources, nothing to embed', flush=True)
        return True
        return []

    print('Filtered sources:', [source.filename for source in filtered_sources], flush=True)
    ddocuments: dict[str, list[Document]] = _sources_to_documents(filtered_sources)
    indocuments = _sources_to_indocuments(config, filtered_sources)

    print('Converted sources to documents')

    if len(ddocuments.keys()) == 0:
    if len(indocuments) == 0:
        # document(s) were empty, not an error
        print('All documents were found empty after being processed', flush=True)
        return True

    success = True

    for user_id, documents in ddocuments.items():
        split_documents: list[Document] = []

        type_bucketed_docs = _bucket_by_type(documents)

        for _type, _docs in type_bucketed_docs.items():
            text_splitter = get_splitter_for(config.embedding_chunk_size, _type)
            split_docs = text_splitter.split_documents(_docs)
            split_documents.extend(split_docs)

        # replace more than two newlines with two newlines (also blank spaces, more than 4)
        for doc in split_documents:
            doc.page_content = re.sub(r'((\r)?\n){3,}', '\n\n', doc.page_content)
            # NOTE: do not use this with all docs when programming files are added
            doc.page_content = re.sub(r'(\s){5,}', r'\g<1>', doc.page_content)
            # filter out null bytes
            doc.page_content = doc.page_content.replace('\0', '')

        # filter out empty documents
        split_documents = list(filter(lambda doc: doc.page_content != '', split_documents))

        print('split documents count:', len(split_documents), flush=True)

        if len(split_documents) == 0:
            continue

        user_client = vectordb.get_user_client(user_id)
        doc_ids = user_client.add_documents(split_documents)

        print('Added documents to vectordb', flush=True)
        # does not do per document error checking
        success &= len(split_documents) == len(doc_ids)
        return []

    return success
    added_sources = vectordb.add_indocuments(indocuments)
    print('Added documents to vectordb', flush=True)
    return added_sources


def embed_sources(
    vectordb_loader: VectorDBLoader,
    config: TConfig,
    sources: list[UploadFile],
) -> bool:
) -> list[str]:
    # either not a file or a file that is allowed
    sources_filtered = [
        source for source in sources
        if (source.filename is not None and not source.filename.startswith('files__default: '))
        if is_valid_source_id(source.filename) # pyright: ignore[reportArgumentType]
        or _allowed_file(source)
    ]

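The new ingestion path leans on two vectordb-side entry points that are only called, not defined, in this file: sources_to_embed (in _filter_sources) and add_indocuments (in _process_sources). Complementing the doc_search sketch after the context.py diff above, here is a minimal sketch of the interface those call sites assume, purely as orientation and not taken from this commit:

from abc import ABC, abstractmethod

from fastapi.datastructures import UploadFile

from context_chat_backend.chain.ingest import InDocument


class BaseVectorDB(ABC):
    @abstractmethod
    def sources_to_embed(self, sources: list[UploadFile]) -> list[str]:
        '''Return the source ids (filenames) that are new or were modified since they were last embedded.'''

    @abstractmethod
    def add_indocuments(self, indocuments: list[InDocument]) -> list[str]:
        '''Embed and store the given documents, returning the source ids that were added successfully.'''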
