WIP: litellm vectorizer
TODOs:
- tests for all integrations
- fine-tune configuration parameters
- fine-tune huggingface (cold-start?)
JamesGuthrie committed Dec 20, 2024
1 parent 18f48b0 commit 366ae89
Showing 7 changed files with 374 additions and 8 deletions.
26 changes: 26 additions & 0 deletions projects/extension/sql/idempotent/008-embedding.sql
@@ -74,6 +74,30 @@ $func$ language plpgsql immutable security invoker
set search_path to pg_catalog, pg_temp
;

-------------------------------------------------------------------------------
-- embedding_litellm
create or replace function ai.embedding_litellm
( model pg_catalog.text
, dimensions pg_catalog.int4
, api_key_name pg_catalog.text default null
, extra_options pg_catalog.jsonb default null
) returns pg_catalog.jsonb
as $func$
begin
    return json_object
    ( 'implementation': 'litellm'
    , 'config_type': 'embedding'
    , 'model': model
    , 'dimensions': dimensions
    , 'api_key_name': api_key_name
    , 'extra_options': extra_options
    absent on null
    );
end
$func$ language plpgsql immutable security invoker
set search_path to pg_catalog, pg_temp
;
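
For reference, a hypothetical invocation of the new function (the Cohere model and key name below are illustrative, not part of this commit). Because of `absent on null`, unset keys such as extra_options are dropped from the result:

-- illustrative only: any LiteLLM "provider/model" string works here
select ai.embedding_litellm
( 'cohere/embed-english-v3.0'
, 1024
, api_key_name => 'COHERE_API_KEY'
);
-- returns jsonb like:
-- {"model": "cohere/embed-english-v3.0", "dimensions": 1024,
--  "config_type": "embedding", "api_key_name": "COHERE_API_KEY",
--  "implementation": "litellm"}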

-------------------------------------------------------------------------------
-- _validate_embedding
create or replace function ai._validate_embedding(config pg_catalog.jsonb) returns void
@@ -98,6 +122,8 @@ begin
            -- ok
        when 'voyageai' then
            -- ok
        when 'litellm' then
            -- ok
        else
            if _implementation is null then
                raise exception 'embedding implementation not specified';
21 changes: 18 additions & 3 deletions projects/pgai/pgai/cli.py
@@ -6,6 +6,7 @@
import signal
import sys
import time
from asyncio import AbstractEventLoop
from collections.abc import Sequence
from typing import Any

@@ -130,7 +131,12 @@ def get_vectorizer(db_url: str, vectorizer_id: int) -> Vectorizer:
    return vectorizer


def run_vectorizer(db_url: str, vectorizer: Vectorizer, concurrency: int) -> None:
def run_vectorizer(
    db_url: str,
    vectorizer: Vectorizer,
    concurrency: int,
    event_loop: AbstractEventLoop | None = None,
) -> None:
    async def run_workers(
        db_url: str, vectorizer: Vectorizer, concurrency: int
    ) -> list[int]:
@@ -140,7 +146,12 @@ async def run_workers(
        ]
        return await asyncio.gather(*tasks)

    results = asyncio.run(run_workers(db_url, vectorizer, concurrency))
    if event_loop is None:
        results = asyncio.run(run_workers(db_url, vectorizer, concurrency))
    else:
        results = event_loop.run_until_complete(
            run_workers(db_url, vectorizer, concurrency)
        )
    items = sum(results)
    log.info("finished processing vectorizer", items=items, vectorizer_id=vectorizer.id)

@@ -268,6 +279,8 @@ def vectorizer_worker(
        # --once implies --exit-on-error
        exit_on_error = True

    event_loop = asyncio.new_event_loop()

    while True:
        try:
            if not can_connect or pgai_version is None:
@@ -302,7 +315,9 @@ try:
            try:
                vectorizer = get_vectorizer(db_url, vectorizer_id)
                log.info("running vectorizer", vectorizer_id=vectorizer_id)
                run_vectorizer(db_url, vectorizer, concurrency)
                run_vectorizer(
                    db_url, vectorizer, concurrency, event_loop=event_loop
                )
            except (VectorizerNotFoundError, ApiKeyNotFoundError) as e:
                log.error(
                    f"error getting vectorizer: {type(e).__name__}: {str(e)} "
1 change: 1 addition & 0 deletions projects/pgai/pgai/vectorizer/embedders/__init__.py
@@ -1,3 +1,4 @@
from .litellm import LiteLLM as LiteLLM
from .ollama import Ollama as Ollama
from .openai import OpenAI as OpenAI
from .voyageai import VoyageAI as VoyageAI
103 changes: 103 additions & 0 deletions projects/pgai/pgai/vectorizer/embedders/litellm.py
@@ -0,0 +1,103 @@
from collections.abc import Sequence
from functools import cached_property
from typing import Any, Literal

import litellm
from litellm import EmbeddingResponse as LiteLLMEmbeddingResponse  # type: ignore
from pydantic import BaseModel
from typing_extensions import override

from ..embeddings import (
    BatchApiCaller,
    Embedder,
    EmbeddingResponse,
    EmbeddingVector,
    StringDocument,
    Usage,
    logger,
)


class UnknownProviderError(Exception):
    pass


class LiteLLM(BaseModel, Embedder):
    """
    Embedder that uses LiteLLM to embed documents into vector representations.

    Attributes:
        implementation (Literal["litellm"]): The literal identifier for this
            implementation.
        model (str): The name of the embedding model.
        api_key_name (str): The API key name.
        extra_options (dict): Additional LiteLLM-specific options.
    """

    implementation: Literal["litellm"]
    model: str
    api_key_name: str | None = None
    extra_options: dict[str, Any] = {}

    @override
    async def embed(self, documents: list[str]) -> Sequence[EmbeddingVector]:
        """
        Embeds a list of documents into vectors using LiteLLM.

        Args:
            documents (list[str]): A list of documents to be embedded.

        Returns:
            Sequence[EmbeddingVector]: The embedding for each document.
        """
        await logger.adebug(f"Chunks produced: {len(documents)}")
        return await self._batcher.batch_chunks_and_embed(documents)

    @cached_property
    def _batcher(self) -> BatchApiCaller[StringDocument]:
        return BatchApiCaller(self._max_chunks_per_batch(), self.call_embed_api)

    @override
    def _max_chunks_per_batch(self) -> int:
        _, custom_llm_provider, _, _ = litellm.get_llm_provider(self.model)  # type: ignore
        match custom_llm_provider:
            case "cohere":
                return 96  # see https://docs.cohere.com/v1/reference/embed#request.body.texts
            case "openai":
                return 2048  # see https://platform.openai.com/docs/api-reference/embeddings/create
            case "azure":
                return 1024  # TODO: unknown
            case "bedrock":
                return 2048  # TODO: unknown
            case "huggingface":
                return 1024  # TODO: unknown
            case "mistral":
                return 1024  # TODO: unknown
            case "vertex":
                return 1024  # TODO: unknown
            case "voyage":
                return 128  # see https://docs.voyageai.com/reference/embeddings-api
            case _:
                raise UnknownProviderError(custom_llm_provider)

    async def call_embed_api(self, documents: str | list[str]) -> EmbeddingResponse:
        # Without `suppress_debug_info`, LiteLLM writes the following to stdout:
        #   Provider List: https://docs.litellm.ai/docs/providers
        # This is useless and confusing, so we suppress it.
        litellm.suppress_debug_info = True
        response: LiteLLMEmbeddingResponse = await litellm.aembedding(
            model=self.model, input=documents, **self.extra_options
        )  # type: ignore
        usage = (
            Usage(
                prompt_tokens=response.usage.prompt_tokens,
                total_tokens=response.usage.total_tokens,
            )
            if response.usage is not None
            else Usage(prompt_tokens=0, total_tokens=0)
        )
        return EmbeddingResponse(
            embeddings=[d["embedding"] for d in response["data"]], usage=usage
        )
4 changes: 2 additions & 2 deletions projects/pgai/pgai/vectorizer/vectorizer.py
@@ -22,7 +22,7 @@
    LangChainCharacterTextSplitter,
    LangChainRecursiveCharacterTextSplitter,
)
from .embedders import Ollama, OpenAI, VoyageAI
from .embedders import LiteLLM, Ollama, OpenAI, VoyageAI
from .embeddings import ChunkEmbeddingError
from .formatting import ChunkValue, PythonTemplate
from .processing import ProcessingDefault
@@ -75,7 +75,7 @@ class Config:
"""

version: str
embedding: OpenAI | Ollama | VoyageAI
embedding: OpenAI | Ollama | VoyageAI | LiteLLM
processing: ProcessingDefault
chunking: (
LangChainCharacterTextSplitter | LangChainRecursiveCharacterTextSplitter
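Because every embedder carries a distinct Literal implementation field, pydantic can pick the right union member when parsing the jsonb config emitted by ai.embedding_litellm. A standalone sketch with simplified stand-in models (assuming pydantic v2; these are not the real pgai classes):

from typing import Any, Literal

from pydantic import BaseModel

# Simplified stand-ins for the real embedder models: each union member
# is distinguished by its Literal `implementation` field.
class OpenAI(BaseModel):
    implementation: Literal["openai"]
    model: str

class LiteLLM(BaseModel):
    implementation: Literal["litellm"]
    model: str
    extra_options: dict[str, Any] = {}

class Config(BaseModel):
    embedding: OpenAI | LiteLLM

cfg = Config.model_validate(
    {"embedding": {"implementation": "litellm", "model": "openai/text-embedding-3-small"}}
)
assert isinstance(cfg.embedding, LiteLLM)  # resolved via the Literal field
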
1 change: 1 addition & 0 deletions projects/pgai/pyproject.toml
@@ -25,6 +25,7 @@ dependencies = [
"datadog_lambda>=6.9,<7.0",
"pytimeparse>=1.1,<2.0",
"voyageai>=0.3.1,<0.4.0",
"litellm>=1.55.4,<1.56.0",
]
classifiers = [
    "License :: OSI Approved :: PostgreSQL License",