WIP: litellm vectorizer
TODOs:
- tests for all integrations
- fine-tune configuration parameters
- fine-tune huggingface (cold-start?)
JamesGuthrie committed Jan 10, 2025
1 parent d1d60a9 commit 118c54b
Showing 23 changed files with 3,521 additions and 92 deletions.
26 changes: 26 additions & 0 deletions projects/extension/sql/idempotent/008-embedding.sql
@@ -74,6 +74,30 @@ $func$ language plpgsql immutable security invoker
set search_path to pg_catalog, pg_temp
;

-------------------------------------------------------------------------------
-- embedding_litellm
create or replace function ai.embedding_litellm
( model pg_catalog.text
, dimensions pg_catalog.int4
, api_key_name pg_catalog.text default null
, extra_options pg_catalog.jsonb default null
) returns pg_catalog.jsonb
as $func$
begin
return json_object
( 'implementation': 'litellm'
, 'config_type': 'embedding'
, 'model': model
, 'dimensions': dimensions
, 'api_key_name': api_key_name
, 'extra_options': extra_options
absent on null
);
end
$func$ language plpgsql immutable security invoker
set search_path to pg_catalog, pg_temp
;

-------------------------------------------------------------------------------
-- _validate_embedding
create or replace function ai._validate_embedding(config pg_catalog.jsonb) returns void
@@ -98,6 +122,8 @@ begin
-- ok
when 'voyageai' then
-- ok
when 'litellm' then
-- ok
else
if _implementation is null then
raise exception 'embedding implementation not specified';
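
For context, a minimal smoke test of the new constructor, in the psycopg style of test_litellm.py further down (the connection string, model name, and key name here are illustrative assumptions, not part of this commit):

import psycopg

# Hypothetical check that ai.embedding_litellm builds the expected config;
# psycopg 3 deserializes the returned jsonb into a Python dict.
with psycopg.connect("postgres://postgres@localhost:5432/test") as con:
    with con.cursor() as cur:
        cur.execute(
            """
            select ai.embedding_litellm
            ( %s
            , %s
            , api_key_name => %s
            )
            """,
            ("mistral/mistral-embed", 1024, "MISTRAL_API_KEY"),
        )
        config = cur.fetchone()[0]
        assert config["implementation"] == "litellm"
        assert config["config_type"] == "embedding"
        assert config["dimensions"] == 1024
        # extra_options was null, so ABSENT ON NULL drops the key entirely
        assert "extra_options" not in config
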
1 change: 1 addition & 0 deletions projects/extension/tests/contents/output16.expected
@@ -20,6 +20,7 @@ CREATE EXTENSION
function ai.create_vectorizer(regclass,name,jsonb,jsonb,jsonb,jsonb,jsonb,jsonb,name,name,name,name,name,name,name[],boolean)
function ai.disable_vectorizer_schedule(integer)
function ai.drop_vectorizer(integer,boolean)
function ai.embedding_litellm(text,integer,text,jsonb)
function ai.embedding_ollama(text,integer,text,jsonb,text)
function ai.embedding_openai(text,integer,text,text)
function ai.embedding_voyageai(text,integer,text,text)
3 changes: 2 additions & 1 deletion projects/extension/tests/contents/output17.expected
@@ -20,6 +20,7 @@ CREATE EXTENSION
function ai.create_vectorizer(regclass,name,jsonb,jsonb,jsonb,jsonb,jsonb,jsonb,name,name,name,name,name,name,name[],boolean)
function ai.disable_vectorizer_schedule(integer)
function ai.drop_vectorizer(integer,boolean)
function ai.embedding_litellm(text,integer,text,jsonb)
function ai.embedding_ollama(text,integer,text,jsonb,text)
function ai.embedding_openai(text,integer,text,text)
function ai.embedding_voyageai(text,integer,text,text)
@@ -110,7 +111,7 @@ CREATE EXTENSION
type ai.vectorizer_status[]
view ai.secret_permissions
view ai.vectorizer_status
(106 rows)
(107 rows)

Table "ai._secret_permissions"
Column | Type | Collation | Nullable | Default | Storage | Compression | Stats target | Description
6 changes: 5 additions & 1 deletion projects/extension/tests/privileges/function.expected
@@ -164,6 +164,10 @@
f | bob | execute | no | ai | drop_vectorizer(vectorizer_id integer, drop_all boolean)
f | fred | execute | no | ai | drop_vectorizer(vectorizer_id integer, drop_all boolean)
f | jill | execute | YES | ai | drop_vectorizer(vectorizer_id integer, drop_all boolean)
f | alice | execute | YES | ai | embedding_litellm(model text, dimensions integer, api_key_name text, extra_options jsonb)
f | bob | execute | no | ai | embedding_litellm(model text, dimensions integer, api_key_name text, extra_options jsonb)
f | fred | execute | no | ai | embedding_litellm(model text, dimensions integer, api_key_name text, extra_options jsonb)
f | jill | execute | YES | ai | embedding_litellm(model text, dimensions integer, api_key_name text, extra_options jsonb)
f | alice | execute | YES | ai | embedding_ollama(model text, dimensions integer, base_url text, options jsonb, keep_alive text)
f | bob | execute | no | ai | embedding_ollama(model text, dimensions integer, base_url text, options jsonb, keep_alive text)
f | fred | execute | no | ai | embedding_ollama(model text, dimensions integer, base_url text, options jsonb, keep_alive text)
@@ -336,5 +340,5 @@
f | bob | execute | no | ai | voyageai_embed(model text, input_texts text[], input_type text, api_key text, api_key_name text)
f | fred | execute | no | ai | voyageai_embed(model text, input_texts text[], input_type text, api_key text, api_key_name text)
f | jill | execute | YES | ai | voyageai_embed(model text, input_texts text[], input_type text, api_key text, api_key_name text)
(336 rows)
(340 rows)

26 changes: 18 additions & 8 deletions projects/extension/tests/test_litellm.py
@@ -1,8 +1,9 @@
import os
from typing import List
from typing import List, Any

import psycopg
import pytest
import json


@pytest.fixture()
@@ -19,14 +20,16 @@ def cur() -> psycopg.Cursor:
"dimensions": 1536,
"api_key_name": "OPENAI_API_KEY",
"exception": "OpenAIException - The api_key client option must be set",
"extra_options": {},
"input_types": [],
},
{
"id": "voyage",
"name": "voyage/voyage-3-lite",
"dimensions": 512,
"api_key_name": "VOYAGE_API_KEY",
"exception": "VoyageException - The api_key client option must be set",
"exception": """VoyageException - {"detail":"Provided API key is invalid."}""",
"extra_options": {},
"input_types": ["query", "document"],
},
{
@@ -35,6 +38,7 @@ def cur() -> psycopg.Cursor:
"dimensions": 1024,
"api_key_name": "COHERE_API_KEY",
"exception": """CohereException - {"message":"no api key supplied"}""",
"extra_options": {},
"input_types": [
"search_query",
"search_document",
@@ -48,6 +52,7 @@ def cur() -> psycopg.Cursor:
"dimensions": 1024,
"api_key_name": "MISTRAL_API_KEY",
"exception": "MistralException - The api_key client option must be set",
"extra_options": {},
"input_types": [],
},
{
@@ -56,6 +61,7 @@ def cur() -> psycopg.Cursor:
"dimensions": 768,
"api_key_name": "HUGGINGFACE_API_KEY",
"exception": """HuggingfaceException - {"error":"Please log in or use a HF access token"}""",
"extra_options": {"wait_for_model": True},
"input_types": [],
},
]
@@ -87,12 +93,12 @@ def test_litellm_embed_fails_without_secret(


@pytest.mark.parametrize(
"name,dimensions,api_key_name",
model_keys("name", "dimensions", "api_key_name"),
"name,dimensions,api_key_name,extra_options",
model_keys("name", "dimensions", "api_key_name", "extra_options"),
ids=ids,
)
def test_litellm_embed_with_api_key_via_guc(
cur: psycopg.Cursor, name: str, dimensions: int, api_key_name: str
cur: psycopg.Cursor, name: str, dimensions: int, api_key_name: str, extra_options: dict[Any, Any]
):
api_key_value = os.getenv(api_key_name)
if api_key_value is None:
@@ -109,25 +115,27 @@ def test_litellm_embed_with_api_key_via_guc(
( %s
, 'hello world'
, api_key_name => %s
, extra_options => %s::jsonb
)
)
""",
(
name,
api_key_name,
json.dumps(extra_options)
),
)
actual = cur.fetchone()[0]
assert actual == dimensions


@pytest.mark.parametrize(
"name,dimensions,api_key_name",
model_keys("name", "dimensions", "api_key_name"),
"name,dimensions,api_key_name,extra_options",
model_keys("name", "dimensions", "api_key_name", "extra_options"),
ids=ids,
)
def test_litellm_embed(
cur: psycopg.Cursor, name: str, dimensions: int, api_key_name: str
cur: psycopg.Cursor, name: str, dimensions: int, api_key_name: str, extra_options: dict[Any, Any]
):
api_key_value = os.getenv(api_key_name)
if api_key_value is None:
@@ -140,12 +148,14 @@ def test_litellm_embed(
( %s
, 'hello world'
, api_key=>%s
, extra_options => %s::jsonb
)
)
""",
(
name,
api_key_value,
json.dumps(extra_options)
),
)
actual = cur.fetchone()[0]
6 changes: 5 additions & 1 deletion projects/pgai/pgai/cli.py
@@ -131,7 +131,11 @@ def get_vectorizer(db_url: str, vectorizer_id: int) -> Vectorizer:
return vectorizer


def run_vectorizer(db_url: str, vectorizer: Vectorizer, concurrency: int) -> None:
def run_vectorizer(
db_url: str,
vectorizer: Vectorizer,
concurrency: int,
) -> None:
async def run_workers(
db_url: str, vectorizer: Vectorizer, concurrency: int
) -> list[int]:
1 change: 1 addition & 0 deletions projects/pgai/pgai/vectorizer/embedders/__init__.py
@@ -1,3 +1,4 @@
from .litellm import LiteLLM as LiteLLM
from .ollama import Ollama as Ollama
from .openai import OpenAI as OpenAI
from .voyageai import VoyageAI as VoyageAI
131 changes: 131 additions & 0 deletions projects/pgai/pgai/vectorizer/embedders/litellm.py
@@ -0,0 +1,131 @@
from collections.abc import Sequence
from functools import cached_property
from typing import Any, Literal

import litellm
from litellm import EmbeddingResponse as LiteLLMEmbeddingResponse, InMemoryCache as LiteLLMInMemoryCache # type: ignore
from pydantic import BaseModel
from typing_extensions import override

from ..embeddings import (
BatchApiCaller,
Embedder,
EmbeddingResponse,
EmbeddingVector,
StringDocument,
Usage,
logger,
)


# TODO: remove this when this issue is fixed upstream: https://github.com/BerriAI/litellm/issues/7667
# Note: we did consider building an event-loop-aware in-memory cache, but the
# additional complexity doesn't seem to be worth it.
class NoopCache(LiteLLMInMemoryCache):
"""
A no-op cache.
This class exists because litellm's internals cause http clients to be
re-used across different event loops. The httpx client does not like this,
which causes exceptions to be thrown. Note: Not all http clients throw
exceptions, so we are being overly cautious with this approach.
"""

def __init__(self):
super().__init__()

@override
def get_cache(self, key: Any, **kwargs: Any):
return None

@override
def set_cache(self, key: Any, value: Any, **kwargs: Any):
pass


litellm.in_memory_llm_clients_cache = NoopCache()
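
# For reference (illustrative, not part of this commit): with the no-op
# cache installed, every lookup misses, so litellm builds a fresh HTTP
# client per request and concurrent event loops never share one:
#
#   litellm.in_memory_llm_clients_cache.set_cache("k", object())
#   assert litellm.in_memory_llm_clients_cache.get_cache("k") is None
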


class UnknownProviderError(Exception):
pass


class LiteLLM(BaseModel, Embedder):
"""
Embedder that uses LiteLLM to embed documents into vector representations.
Attributes:
implementation (Literal["litellm"]): The literal identifier for this
implementation.
model (str): The name of the embedding model.
api_key_name (str): The API key name.
extra_options (dict): Additional litellm-specific options
"""

implementation: Literal["litellm"]
model: str
api_key_name: str | None = None
extra_options: dict[str, Any] = {}

@override
async def embed(self, documents: list[str]) -> Sequence[EmbeddingVector]:
"""
Embeds a list of documents into vectors using LiteLLM.
Args:
documents (list[str]): A list of documents to be embedded.
Returns:
Sequence[EmbeddingVector | ChunkEmbeddingError]: The embeddings or
errors for each document.
"""
await logger.adebug(f"Chunks produced: {len(documents)}")
return await self._batcher.batch_chunks_and_embed(documents)

@cached_property
def _batcher(self) -> BatchApiCaller[StringDocument]:
return BatchApiCaller(self._max_chunks_per_batch(), self.call_embed_api)

@override
def _max_chunks_per_batch(self) -> int:
print(f"model: {self.model}")
_, custom_llm_provider, _, _ = litellm.get_llm_provider(self.model) # type: ignore
match custom_llm_provider:
case "cohere":
return 96 # see https://docs.cohere.com/v1/reference/embed#request.body.texts
case "openai":
return 2048 # see https://platform.openai.com/docs/api-reference/embeddings/create
case "azure":
return 1024 # TODO: unknown
case "bedrock":
return 2048 # TODO: unknown
case "huggingface":
return 1024 # TODO: unknown
case "mistral":
return 1024 # TODO: unknown
case "vertex":
return 1024 # TODO: unknown
case "voyage":
return 128 # see https://docs.voyageai.com/reference/embeddings-api
case _:
raise UnknownProviderError(custom_llm_provider)

async def call_embed_api(self, documents: str | list[str]) -> EmbeddingResponse:
# Without `suppress_debug_info`, LiteLLM writes the following into stdout:
# Provider List: https://docs.litellm.ai/docs/providers
# This is useless and confusing, so we suppress it.
litellm.suppress_debug_info = True
response: LiteLLMEmbeddingResponse = await litellm.aembedding( # type: ignore
model=self.model, input=documents, **self.extra_options
)
usage = (
Usage(
prompt_tokens=response.usage.prompt_tokens,
total_tokens=response.usage.total_tokens,
)
if response.usage is not None
else Usage(prompt_tokens=0, total_tokens=0)
)
return EmbeddingResponse(
embeddings=[d["embedding"] for d in response["data"]], usage=usage
)
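
A minimal sketch of driving the embedder directly, outside a vectorizer (the model string and extra_options are illustrative; litellm resolves credentials from the provider's environment variable, here COHERE_API_KEY):

import asyncio

from pgai.vectorizer.embedders import LiteLLM

# Illustrative configuration; any provider/model string recognized by
# litellm.get_llm_provider() works, per the match statement above.
# Assumes COHERE_API_KEY is set in the environment.
embedder = LiteLLM(
    implementation="litellm",
    model="cohere/embed-english-v3.0",
    extra_options={"input_type": "search_document"},
)

# embed() batches the chunks (96 per request for cohere) and returns
# one vector per input document.
vectors = asyncio.run(embedder.embed(["hello world"]))
print(len(vectors[0]))  # 1024 for this model, per the test matrix above
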
4 changes: 2 additions & 2 deletions projects/pgai/pgai/vectorizer/vectorizer.py
Original file line number Diff line number Diff line change
@@ -22,7 +22,7 @@
LangChainCharacterTextSplitter,
LangChainRecursiveCharacterTextSplitter,
)
from .embedders import Ollama, OpenAI, VoyageAI
from .embedders import LiteLLM, Ollama, OpenAI, VoyageAI
from .embeddings import ChunkEmbeddingError
from .formatting import ChunkValue, PythonTemplate
from .processing import ProcessingDefault
@@ -75,7 +75,7 @@ class Config:
"""

version: str
embedding: OpenAI | Ollama | VoyageAI
embedding: OpenAI | Ollama | VoyageAI | LiteLLM
processing: ProcessingDefault
chunking: (
LangChainCharacterTextSplitter | LangChainRecursiveCharacterTextSplitter
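
Why extending the plain union suffices: each embedder carries a distinct implementation literal, so pydantic selects the right class when validating a config blob. A self-contained sketch of the pattern with simplified stand-in models (not the real Config):

from typing import Literal

from pydantic import BaseModel


class OpenAI(BaseModel):
    implementation: Literal["openai"]
    model: str


class LiteLLM(BaseModel):
    implementation: Literal["litellm"]
    model: str


class Config(BaseModel):
    embedding: OpenAI | LiteLLM


cfg = Config.model_validate(
    {"embedding": {"implementation": "litellm", "model": "mistral/mistral-embed"}}
)
assert isinstance(cfg.embedding, LiteLLM)
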
3 changes: 2 additions & 1 deletion projects/pgai/pyproject.toml
@@ -25,6 +25,7 @@ dependencies = [
"datadog_lambda>=6.9,<7.0",
"pytimeparse>=1.1,<2.0",
"voyageai>=0.3.1,<0.4.0",
"litellm>=1.57.5,<1.58.0",
]
classifiers = [
"License :: OSI Approved :: PostgreSQL License",
@@ -112,7 +113,7 @@ dev-dependencies = [
"ruff==0.6.9",
"pytest==8.3.2",
"python-dotenv==1.0.1",
"vcrpy==6.0.1",
"vcrpy==7.0.0",
"pyright==1.1.385",
"psycopg[binary]==3.2.1",
"testcontainers==4.8.1",