Skip to content

Commit

Permalink
type fixes and other numerous fixes
Browse files Browse the repository at this point in the history
- added pyproject.toml
- added github workflow for lint and static analysis
- use ruff and pyright
- add a separate pip dep for dev

Signed-off-by: Anupam Kumar <[email protected]>
  • Loading branch information
kyteinsky committed Feb 23, 2024
1 parent 3f771ee commit 6d5581e
Show file tree
Hide file tree
Showing 19 changed files with 298 additions and 176 deletions.
17 changes: 0 additions & 17 deletions .flake8

This file was deleted.

41 changes: 41 additions & 0 deletions .github/workflows/lint-n-static-analysis.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
name: Lint and Static Analysis

on:
pull_request:
paths:
- main.py
- context_chat_backend/**
push:
branches:
- master
paths:
- main.py
- context_chat_backend/**

jobs:
analysis:
runs-on: ubuntu-latest

name: Lint and Static Analysis

steps:
- name: print pwd
run: |
pwd
- name: Setup python 3.11
uses: actions/setup-python@v5
with:
python-version: '3.11'

- name: Install dev dependencies
run: |
pip install -r reqs.dev
- name: Lint with Ruff
run: |
ruff --output-format=github context_chat_backend main.py
- name: Static analysis with pyright
run: |
pyright context_chat_backend main.py
37 changes: 21 additions & 16 deletions context_chat_backend/chain/ingest/doc_loader.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
from logging import error as log_error
import re
import tempfile
from collections.abc import Callable
from logging import error as log_error
from typing import BinaryIO

from fastapi import UploadFile
from pandas import read_csv, read_excel
from pypandoc import convert_text
from pypdf import PdfReader
from langchain.document_loaders import (
UnstructuredPowerPointLoader,
UnstructuredEmailLoader,
UnstructuredPowerPointLoader,
)
from pandas import read_csv, read_excel
from pypandoc import convert_text
from pypdf import PdfReader


def _temp_file_wrapper(file: BinaryIO, loader: callable, sep: str = '\n') -> str:
def _temp_file_wrapper(file: BinaryIO, loader: Callable, sep: str = '\n') -> str:
raw_bytes = file.read()
tmp = tempfile.NamedTemporaryFile(mode='wb')
tmp.write(raw_bytes)
Expand All @@ -25,7 +26,7 @@ def _temp_file_wrapper(file: BinaryIO, loader: callable, sep: str = '\n') -> str
import os
os.remove(tmp.name)

return sep.join(map(lambda d: d.page_content, docs))
return sep.join(d.page_content for d in docs)


# -- LOADERS -- #
Expand All @@ -40,23 +41,23 @@ def _load_csv(file: BinaryIO) -> str:


def _load_epub(file: BinaryIO) -> str:
return convert_text(file.read(), 'plain', 'epub').strip()
return convert_text(str(file.read()), 'plain', 'epub').strip()


def _load_docx(file: BinaryIO) -> str:
return convert_text(file.read(), 'plain', 'docx').strip()
return convert_text(str(file.read()), 'plain', 'docx').strip()


def _load_ppt_x(file: BinaryIO) -> str:
return _temp_file_wrapper(file, lambda fp: UnstructuredPowerPointLoader(fp).load()).strip()


def _load_rtf(file: BinaryIO) -> str:
return convert_text(file.read(), 'plain', 'rtf').strip()
return convert_text(str(file.read()), 'plain', 'rtf').strip()


def _load_rst(file: BinaryIO) -> str:
return convert_text(file.read(), 'plain', 'rst').strip()
return convert_text(str(file.read()), 'plain', 'rst').strip()


def _load_xml(file: BinaryIO) -> str:
Expand All @@ -70,7 +71,7 @@ def _load_xlsx(file: BinaryIO) -> str:


def _load_odt(file: BinaryIO) -> str:
return convert_text(file.read(), 'plain', 'odt').strip()
return convert_text(str(file.read()), 'plain', 'odt').strip()


def _load_email(file: BinaryIO, ext: str = 'eml') -> str | None:
Expand All @@ -95,7 +96,7 @@ def attachment_partitioner(


def _load_org(file: BinaryIO) -> str:
return convert_text(file.read(), 'plain', 'org').strip()
return convert_text(str(file.read()), 'plain', 'org').strip()


# -- LOADER FUNCTION MAP -- #
Expand Down Expand Up @@ -124,11 +125,15 @@ def decode_source(source: UploadFile) -> str | None:
try:
# .pot files are powerpoint templates but also plain text files,
# so we skip them to prevent decoding errors
if source.headers.get('title').endswith('.pot'):
if source.headers.get('title', '').endswith('.pot'):
return None

mimetype = source.headers.get('type')
if mimetype is None:
return None

if _loader_map.get(source.headers.get('type')):
return _loader_map[source.headers.get('type')](source.file)
if _loader_map.get(mimetype):
return _loader_map[mimetype](source.file)

return source.file.read().decode('utf-8')
except Exception as e:
Expand Down
4 changes: 2 additions & 2 deletions context_chat_backend/chain/ingest/doc_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def get_splitter_for(mimetype: str = 'text/plain') -> TextSplitter:

mt_map = {
'text/markdown': MarkdownTextSplitter(**kwargs),
'application/json': RecursiveCharacterTextSplitter(separators=['{', '}', r'\[', r'\]', ',', ''], **kwargs), # noqa: E501
'application/json': RecursiveCharacterTextSplitter(separators=['{', '}', r'\[', r'\]', ',', ''], **kwargs),
# processed csv, does not contain commas
'text/csv': RecursiveCharacterTextSplitter(separators=['\n', ' ', ''], **kwargs),
# remove end tags for less verbosity, and remove all whitespace outside of tags
Expand All @@ -26,7 +26,7 @@ def get_splitter_for(mimetype: str = 'text/plain') -> TextSplitter:
'application/vnd.ms-excel.sheet.macroEnabled.12': RecursiveCharacterTextSplitter(separators=['\n\n', '\n', ' ', ''], **kwargs), # noqa: E501
}

if mimetype in mt_map.keys():
if mimetype in mt_map:
return mt_map[mimetype]

# all other mimetypes
Expand Down
17 changes: 9 additions & 8 deletions context_chat_backend/chain/ingest/injest.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from logging import error as log_error
import re
from logging import error as log_error

from fastapi.datastructures import UploadFile
from langchain.schema import Document

from ...utils import to_int
from ...vectordb import BaseVectorDB
from .doc_loader import decode_source
from .doc_splitter import get_splitter_for
from .mimetype_list import SUPPORTED_MIMETYPES
from ...utils import to_int
from ...vectordb import BaseVectorDB


def _allowed_file(file: UploadFile) -> bool:
Expand Down Expand Up @@ -51,21 +51,22 @@ def _filter_documents(
.difference(set(existing_objects))
new_sources.update(set(to_delete.keys()))

filtered_documents = [
return [
doc for doc in documents
if doc.metadata.get('source') in new_sources
]

return filtered_documents


def _sources_to_documents(sources: list[UploadFile]) -> list[Document]:
def _sources_to_documents(sources: list[UploadFile]) -> dict[str, list[Document]]:
'''
Converts a list of sources to a dictionary of documents with the user_id as the key.
'''
documents = {}

for source in sources:
user_id = source.headers.get('userId')
if user_id is None:
log_error('userId not found in headers for source: ' + source.filename)
log_error(f'userId not found in headers for source: {source.filename}')
continue

# transform the source to have text data
Expand Down
13 changes: 5 additions & 8 deletions context_chat_backend/chain/one_shot.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,19 @@ def process_query(
ctx_limit: int = 5,
template: str = _LLM_TEMPLATE,
end_separator: str = '',
) -> tuple[str, list]:
) -> tuple[str, set]:
if not use_context:
return llm.predict(query), []
return llm.predict(query), set()

user_client = vectordb.get_user_client(user_id)
if user_client is None:
return llm.predict(query), []
return llm.predict(query), set()

context_docs = user_client.similarity_search(query, k=ctx_limit)
context_text = '\n\n'.join(map(
lambda d: f'{d.metadata.get("title")}\n{d.page_content}',
context_docs,
))
context_text = '\n\n'.join(f'{d.metadata.get("title")}\n{d.page_content}' for d in context_docs)

output = llm.predict(template.format(context=context_text, question=query)) \
.strip().rstrip(end_separator).strip()
unique_sources = list(set(map(lambda d: d.metadata.get('source', ''), context_docs)))
unique_sources = {d.metadata.get('source') for d in context_docs}

return (output, unique_sources)
30 changes: 20 additions & 10 deletions context_chat_backend/config_parser.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
from pprint import pprint
from typing import TypedDict

from ruamel.yaml import YAML

from .models import models
from .vectordb import vector_dbs


class TConfig(TypedDict):
vectordb: tuple[str, dict]
embedding: tuple[str, dict]
llm: tuple[str, dict]


def _first_in_list(
input_dict: dict[str, dict],
supported_list: list[str]
Expand All @@ -21,7 +28,7 @@ def _first_in_list(
return None


def get_config(file_path: str = 'config.yaml') -> dict[str, tuple[str, dict]]:
def get_config(file_path: str = 'config.yaml') -> TConfig:
'''
Get the config from the given file path (relative to the root directory).
'''
Expand All @@ -32,27 +39,30 @@ def get_config(file_path: str = 'config.yaml') -> dict[str, tuple[str, dict]]:
except Exception as e:
raise AssertionError('Error: could not load config from', file_path, 'file') from e

selected_config = {
'vectordb': _first_in_list(config.get('vectordb', {}), vector_dbs),
'embedding': _first_in_list(config.get('embedding', {}), models['embedding']),
'llm': _first_in_list(config.get('llm', {}), models['llm']),
}

if not selected_config['vectordb']:
vectordb = _first_in_list(config.get('vectordb', {}), vector_dbs)
if not vectordb:
raise AssertionError(
f'Error: vectordb should be at least one of {vector_dbs} in the config file'
)

if not selected_config['embedding']:
embedding = _first_in_list(config.get('embedding', {}), models['embedding'])
if not embedding:
raise AssertionError(
f'Error: embedding model should be at least one of {models["embedding"]} in the config file'
)

if not selected_config['llm']:
llm = _first_in_list(config.get('llm', {}), models['llm'])
if not llm:
raise AssertionError(
f'Error: llm model should be at least one of {models["llm"]} in the config file'
)

selected_config: TConfig = {
'vectordb': vectordb,
'embedding': embedding,
'llm': llm,
}

pprint(f'Selected config: {selected_config}')

return selected_config
Loading

0 comments on commit 6d5581e

Please sign in to comment.