Skip to content

Commit

Permalink
Support ollama model
Browse files Browse the repository at this point in the history
Signed-off-by: Hemslo Wang <[email protected]>
  • Loading branch information
hemslo committed Feb 6, 2024
1 parent 22e1f84 commit 0acd577
Show file tree
Hide file tree
Showing 9 changed files with 112 additions and 46 deletions.
11 changes: 10 additions & 1 deletion .env.example
Original file line number Diff line number Diff line change
@@ -1,2 +1,11 @@
OPENAI_API_KEY=
AUTH_TOKEN=
CHAT_PROVIDER=
EMBEDDING_DIM=
EMBEDDING_PROVIDER=
OLLAMA_CHAT_MODEL=
OLLAMA_EMBEDDING_MODEL=
OLLAMA_URL=
OPENAI_API_KEY=
OPENAI_CHAT_MODEL=
OPENAI_EMBEDDING_MODEL=
REDIS_URL=
17 changes: 10 additions & 7 deletions app/chains/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,23 @@
from langchain_core.runnables import (
RunnableParallel,
)
from langchain_openai import ChatOpenAI
from langchain.prompts.prompt import PromptTemplate
from langserve import CustomUserType
from pydantic import Field

from app import config
from app.dependencies.llm import get_llm
from app.dependencies.redis import get_redis

llm = ChatOpenAI(
model=config.OPENAI_CHAT_MODEL,
temperature=0,
)
llm = get_llm()

retriever = get_redis().as_retriever(search_type="mmr")
retriever = get_redis().as_retriever(
search_type="mmr",
search_kwargs={
"fetch_k": 20,
"k": 3,
"lambda_mult": 0.5,
},
)


REPHRASE_TEMPLATE = """\
Expand Down
25 changes: 13 additions & 12 deletions app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,20 @@

load_dotenv()

# AUTH_TOKEN is mandatory: fail fast at import time when it is unusable.
# A blank value (e.g. `AUTH_TOKEN=` copied unchanged from .env.example) is
# rejected as well -- an empty string is not None, so an `is None` check
# alone would let the application start with an empty token.
AUTH_TOKEN = os.environ.get("AUTH_TOKEN")
if not AUTH_TOKEN:
    raise ValueError("AUTH_TOKEN is not set in the environment variables")

# Provider / model settings, each overridable through the environment.
#
# `os.environ.get(name) or default` is used instead of
# `os.environ.get(name, default)` on purpose: a variable that is present but
# blank (as in an unedited copy of .env.example) yields "" from the two-argument
# form, which would bypass the default and, for EMBEDDING_DIM, crash in int("").
CHAT_PROVIDER = os.environ.get("CHAT_PROVIDER") or "openai"
EMBEDDING_DIM = int(os.environ.get("EMBEDDING_DIM") or 1536)
EMBEDDING_PROVIDER = os.environ.get("EMBEDDING_PROVIDER") or "openai"
INDEX_NAME = "document"
# schema.yaml is looked up in the same directory as this module.
INDEX_SCHEMA_PATH = Path(os.path.dirname(__file__)) / "schema.yaml"
OLLAMA_CHAT_MODEL = os.environ.get("OLLAMA_CHAT_MODEL") or "llama2"
OLLAMA_EMBEDDING_MODEL = os.environ.get("OLLAMA_EMBEDDING_MODEL") or "llama2"
OLLAMA_URL = os.environ.get("OLLAMA_URL") or "http://localhost:11434"
OPENAI_CHAT_MODEL = os.environ.get("OPENAI_CHAT_MODEL") or "gpt-3.5-turbo-1106"
OPENAI_EMBEDDING_MODEL = (
    os.environ.get("OPENAI_EMBEDDING_MODEL") or "text-embedding-3-small"
)

OPENAI_CHAT_MODEL = os.environ.get("OPENAI_CHAT_MODEL", "gpt-3.5-turbo-1106")

INDEX_SCHEMA_PATH = Path(os.path.dirname(__file__)) / "schema.yaml"

REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379/")

INDEX_NAME = "document"

AUTH_TOKEN = os.environ.get("AUTH_TOKEN")

if AUTH_TOKEN is None:
raise ValueError("AUTH_TOKEN is not set in the environment variables")
29 changes: 29 additions & 0 deletions app/dependencies/embeddings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from langchain_core.embeddings import Embeddings

from app import config


def _get_embeddings() -> Embeddings:
    """Build the embedding model selected by ``config.EMBEDDING_PROVIDER``.

    Each provider's integration package is imported inside its own branch,
    so only the package for the selected provider gets imported.

    Returns:
        An ``OpenAIEmbeddings`` instance for ``"openai"`` or an
        ``OllamaEmbeddings`` instance for ``"ollama"``.

    Raises:
        ValueError: If the configured provider is neither of the above.
    """
    provider = config.EMBEDDING_PROVIDER

    if provider == "openai":
        from langchain_openai import OpenAIEmbeddings

        return OpenAIEmbeddings(model=config.OPENAI_EMBEDDING_MODEL)

    if provider == "ollama":
        from langchain_community.embeddings import OllamaEmbeddings

        ollama_options = {
            "model": config.OLLAMA_EMBEDDING_MODEL,
            "base_url": config.OLLAMA_URL,
        }
        return OllamaEmbeddings(**ollama_options)

    # No known provider matched.
    raise ValueError(f"Unknown embedding provider: {config.EMBEDDING_PROVIDER}")


# Created once, when this module is first imported; get_embeddings() hands
# out this same object on every call.
embeddings = _get_embeddings()


def get_embeddings() -> Embeddings:
    """Return the module-level embeddings instance."""
    return embeddings
30 changes: 30 additions & 0 deletions app/dependencies/llm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from langchain_core.language_models import BaseChatModel

from app import config


def _get_llm() -> BaseChatModel:
match config.CHAT_PROVIDER:
case "openai":
from langchain_openai import ChatOpenAI

return ChatOpenAI(
model=config.OPENAI_CHAT_MODEL,
temperature=0,
)
case "ollama":
from langchain_community.chat_models import ChatOllama

return ChatOllama(
model=config.OLLAMA_CHAT_MODEL,
base_url=config.OLLAMA_URL,
)
case _:
raise ValueError(f"Unknown chat provider: {config.CHAT_PROVIDER}")


# Created once, when this module is first imported; get_llm() hands out this
# same object on every call.
llm = _get_llm()


def get_llm() -> BaseChatModel:
    """Return the module-level chat model instance."""
    return llm
11 changes: 3 additions & 8 deletions app/dependencies/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,19 @@
from typing import Annotated

from fastapi import Depends
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_community.vectorstores.redis import Redis

from app import config


embeddings = OpenAIEmbeddings(
model=config.OPENAI_EMBEDDING_MODEL,
)
from app.dependencies.embeddings import get_embeddings

rds = Redis(
redis_url=os.getenv("REDIS_URL", "redis://localhost:6379/"),
index_name=config.INDEX_NAME,
embedding=embeddings,
embedding=get_embeddings(),
index_schema=config.INDEX_SCHEMA_PATH,
)

rds._create_index_if_not_exist()
rds._create_index_if_not_exist(config.EMBEDDING_DIM)


def get_redis() -> Redis:
Expand Down
6 changes: 1 addition & 5 deletions app/schema.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,8 @@ text:
weight: 1.0
withsuffixtrie: false
vector:
- algorithm: HNSW
- algorithm: FLAT
datatype: FLOAT32
dims: 1536
distance_metric: COSINE
ef_construction: 200
ef_runtime: 10
epsilon: 0.01
m: 16
name: content_vector
12 changes: 12 additions & 0 deletions docker-compose.redis.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
services:
redis:
image: redis/redis-stack:6.2.6-v11
ports:
- "6379:6379"
- "8001:8001"
environment:
REDIS_ARGS: --save 60 1000 --appendonly yes
volumes:
- redis-data:/data
volumes:
redis-data:
17 changes: 4 additions & 13 deletions docker-compose.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
version: '3.8'
include:
- docker-compose.redis.yaml
services:
app:
build: .
Expand All @@ -8,15 +9,5 @@ services:
REDIS_URL: redis://redis:6379/
ports:
- "8000:8000"
redis:
image: redis/redis-stack:6.2.6-v11
ports:
- "6379:6379"
- "8001:8001"
environment:
REDIS_ARGS: --save 60 1000 --appendonly yes
volumes:
- redis-data:/data

volumes:
redis-data:
depends_on:
- redis

0 comments on commit 0acd577

Please sign in to comment.