Skip to content

Commit

Permalink
Merge pull request open-webui#772 from jannikstdl/choose-embedding-model
Browse files Browse the repository at this point in the history
feat: choose embedding model when using docker
  • Loading branch information
tjbck authored Feb 19, 2024
2 parents 792fe98 + 7c127c3 commit c391692
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 17 deletions.
23 changes: 19 additions & 4 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,24 @@ ENV WEBUI_SECRET_KEY ""
ENV SCARF_NO_ANALYTICS true
ENV DO_NOT_TRACK true

#Whisper TTS Settings
######## Preloaded models ########
# whisper TTS Settings
ENV WHISPER_MODEL="base"
ENV WHISPER_MODEL_DIR="/app/backend/data/cache/whisper/models"

# RAG Embedding Model Settings
# any sentence transformer model; models to use can be found at https://huggingface.co/models?library=sentence-transformers
# Leaderboard: https://huggingface.co/spaces/mteb/leaderboard
# for better persormance and multilangauge support use "intfloat/multilingual-e5-large" (~2.5GB) or "intfloat/multilingual-e5-base" (~1.5GB)
# IMPORTANT: If you change the default model (all-MiniLM-L6-v2) and vice versa, you aren't able to use RAG Chat with your previous documents loaded in the WebUI! You need to re-embed them.
ENV RAG_EMBEDDING_MODEL="all-MiniLM-L6-v2"
# device type for whisper tts and ebbeding models - "cpu" (default), "cuda" (nvidia gpu and CUDA required) or "mps" (apple silicon) - choosing this right can lead to better performance
ENV RAG_EMBEDDING_MODEL_DEVICE_TYPE="cpu"
ENV RAG_EMBEDDING_MODEL_DIR="/app/backend/data/cache/embedding/models"
ENV SENTENCE_TRANSFORMERS_HOME $RAG_EMBEDDING_MODEL_DIR

######## Preloaded models ########

WORKDIR /app/backend

# install python dependencies
Expand All @@ -48,9 +62,10 @@ RUN apt-get update \
&& apt-get install -y pandoc netcat-openbsd \
&& rm -rf /var/lib/apt/lists/*

# RUN python -c "from sentence_transformers import SentenceTransformer; model = SentenceTransformer('all-MiniLM-L6-v2')"
RUN python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"

# preload embedding model
RUN python -c "import os; from chromadb.utils import embedding_functions; sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=os.environ['RAG_EMBEDDING_MODEL'], device=os.environ['RAG_EMBEDDING_MODEL_DEVICE_TYPE'])"
# preload tts model
RUN python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='auto', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"

# copy embedding weight from build
RUN mkdir -p /root/.cache/chroma/onnx_models/all-MiniLM-L6-v2
Expand Down
2 changes: 1 addition & 1 deletion backend/apps/audio/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def transcribe(

model = WhisperModel(
WHISPER_MODEL,
device="cpu",
device="auto",
compute_type="int8",
download_root=WHISPER_MODEL_DIR,
)
Expand Down
72 changes: 61 additions & 11 deletions backend/apps/rag/main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from fastapi import (
FastAPI,
Request,
Depends,
HTTPException,
status,
Expand All @@ -14,7 +13,8 @@
from pathlib import Path
from typing import List

# from chromadb.utils import embedding_functions
from sentence_transformers import SentenceTransformer
from chromadb.utils import embedding_functions

from langchain_community.document_loaders import (
WebBaseLoader,
Expand All @@ -30,16 +30,12 @@
UnstructuredExcelLoader,
)
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import RetrievalQA
from langchain_community.vectorstores import Chroma


from pydantic import BaseModel
from typing import Optional
import mimetypes
import uuid
import json
import time


from apps.web.models.documents import (
Expand All @@ -58,23 +54,37 @@
from config import (
UPLOAD_DIR,
DOCS_DIR,
EMBED_MODEL,
RAG_EMBEDDING_MODEL,
RAG_EMBEDDING_MODEL_DEVICE_TYPE,
CHROMA_CLIENT,
CHUNK_SIZE,
CHUNK_OVERLAP,
RAG_TEMPLATE,
)

from constants import ERROR_MESSAGES

# EMBEDDING_FUNC = embedding_functions.SentenceTransformerEmbeddingFunction(
# model_name=EMBED_MODEL
# )
#
# if RAG_EMBEDDING_MODEL:
# sentence_transformer_ef = SentenceTransformer(
# model_name_or_path=RAG_EMBEDDING_MODEL,
# cache_folder=RAG_EMBEDDING_MODEL_DIR,
# device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
# )


app = FastAPI()

app.state.CHUNK_SIZE = CHUNK_SIZE
app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
app.state.RAG_TEMPLATE = RAG_TEMPLATE
app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
app.state.sentence_transformer_ef = (
embedding_functions.SentenceTransformerEmbeddingFunction(
model_name=app.state.RAG_EMBEDDING_MODEL,
device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
)
)


origins = ["*"]
Expand Down Expand Up @@ -106,7 +116,10 @@ def store_data_in_vector_db(data, collection_name) -> bool:
metadatas = [doc.metadata for doc in docs]

try:
collection = CHROMA_CLIENT.create_collection(name=collection_name)
collection = CHROMA_CLIENT.create_collection(
name=collection_name,
embedding_function=app.state.sentence_transformer_ef,
)

collection.add(
documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts]
Expand All @@ -126,6 +139,38 @@ async def get_status():
"status": True,
"chunk_size": app.state.CHUNK_SIZE,
"chunk_overlap": app.state.CHUNK_OVERLAP,
"template": app.state.RAG_TEMPLATE,
"embedding_model": app.state.RAG_EMBEDDING_MODEL,
}


@app.get("/embedding/model")
async def get_embedding_model(user=Depends(get_admin_user)):
return {
"status": True,
"embedding_model": app.state.RAG_EMBEDDING_MODEL,
}


class EmbeddingModelUpdateForm(BaseModel):
embedding_model: str


@app.post("/embedding/model/update")
async def update_embedding_model(
form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
):
app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
app.state.sentence_transformer_ef = (
embedding_functions.SentenceTransformerEmbeddingFunction(
model_name=app.state.RAG_EMBEDDING_MODEL,
device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
)
)

return {
"status": True,
"embedding_model": app.state.RAG_EMBEDDING_MODEL,
}


Expand Down Expand Up @@ -190,8 +235,10 @@ def query_doc(
user=Depends(get_current_user),
):
try:
# if you use docker use the model from the environment variable
collection = CHROMA_CLIENT.get_collection(
name=form_data.collection_name,
embedding_function=app.state.sentence_transformer_ef,
)
result = collection.query(query_texts=[form_data.query], n_results=form_data.k)
return result
Expand Down Expand Up @@ -263,9 +310,12 @@ def query_collection(

for collection_name in form_data.collection_names:
try:
# if you use docker use the model from the environment variable
collection = CHROMA_CLIENT.get_collection(
name=collection_name,
embedding_function=app.state.sentence_transformer_ef,
)

result = collection.query(
query_texts=[form_data.query], n_results=form_data.k
)
Expand Down
7 changes: 6 additions & 1 deletion backend/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,12 @@
####################################

CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db"
EMBED_MODEL = "all-MiniLM-L6-v2"
# this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (all-MiniLM-L6-v2)
RAG_EMBEDDING_MODEL = os.environ.get("RAG_EMBEDDING_MODEL", "all-MiniLM-L6-v2")
# device type ebbeding models - "cpu" (default), "cuda" (nvidia gpu required) or "mps" (apple silicon) - choosing this right can lead to better performance
RAG_EMBEDDING_MODEL_DEVICE_TYPE = os.environ.get(
"RAG_EMBEDDING_MODEL_DEVICE_TYPE", "cpu"
)
CHROMA_CLIENT = chromadb.PersistentClient(
path=CHROMA_DATA_PATH,
settings=Settings(allow_reset=True, anonymized_telemetry=False),
Expand Down

0 comments on commit c391692

Please sign in to comment.