Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP litellm integration #320

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ on: [pull_request, workflow_dispatch]
permissions:
contents: read

concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true

jobs:
build-and-test-extension:
services:
Expand Down
35 changes: 35 additions & 0 deletions projects/extension/ai/litellm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import litellm
from typing import Optional, Generator


def embed(
    model: str,
    input: list[str],
    api_key: str,
    user: Optional[str] = None,
    dimensions: Optional[int] = None,
    timeout: Optional[int] = None,
    api_base: Optional[str] = None,
    api_version: Optional[str] = None,
    api_type: Optional[str] = None,
    organization: Optional[str] = None,
    **kwargs,
) -> Generator[tuple[int, list[float]], None, None]:
    """Yield ``(index, embedding)`` pairs for ``input`` via ``litellm.embedding``.

    Args:
        model: Model identifier forwarded to litellm.
        input: Texts to embed.
        api_key: Key forwarded to litellm as ``api_key``.
        user, dimensions, timeout, api_base, api_version, api_type:
            Forwarded unchanged to ``litellm.embedding`` (``None`` included).
        organization: When not ``None``, assigned to ``litellm.organization``
            before the request is made.
        **kwargs: Any further keyword arguments, forwarded unchanged.

    Yields:
        ``(idx, embedding)`` where ``idx`` is the position of the item in the
        response's ``data`` sequence and ``embedding`` is that item's
        ``"embedding"`` value. Nothing is yielded when the response has no
        ``data`` attribute.
    """
    if organization is not None:
        # NOTE(review): this sets an attribute on the litellm module rather
        # than passing a per-call argument, so the value stays set after this
        # call returns — confirm that is intended.
        litellm.organization = organization

    response = litellm.embedding(
        model=model,
        input=input,
        user=user,
        dimensions=dimensions,
        timeout=timeout,
        api_type=api_type,
        api_key=api_key,
        api_base=api_base,
        api_version=api_version,
        **kwargs,
    )

    # A response without ``data`` produces an empty generator.
    if hasattr(response, "data"):
        for idx, item in enumerate(response["data"]):
            yield idx, item["embedding"]
20 changes: 7 additions & 13 deletions projects/extension/ai/secrets.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,27 +24,21 @@ def remove_secret_from_cache(sd_cache: dict[str, str], secret_name: str):

def get_secret(
plpy,
secret: Optional[str],
secret_name: Optional[str],
secret_name_default: str,
sd_cache: Optional[dict[str, str]],
) -> str:
secret: Optional[str] = None,
secret_name: Optional[str] = None,
secret_name_default: Optional[str] = None,
sd_cache: Optional[dict[str, str]] = None,
) -> str | None:
if secret is not None:
return secret

if secret_name is None:
secret_name = secret_name_default

if secret_name is None or secret_name == "":
plpy.error("secret_name is required")

secret = reveal_secret(plpy, secret_name, sd_cache)
if secret is None:
plpy.error(f"missing {secret_name} secret")
# This line should never be reached, but it's here to make the type checker happy.
return ""
return None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there existing code that depends on this condition throwing an error?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know. I still need to do a full audit of all of the callers.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've now audited the callers - from what I can tell none of our code depends on this throwing an error.

The motivation for this change is that with litellm there is no sensible "default secret name" that we could configure because every provider has different conventions for secret naming. This motivated the change on line 29 to allow for secret_name_default to be None.


return secret
return reveal_secret(plpy, secret_name, sd_cache)


def check_secret_permissions(plpy, secret_name: str) -> bool:
Expand Down
4 changes: 3 additions & 1 deletion projects/extension/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
openai==1.44.0
openai==1.56.0
tiktoken==0.7.0
ollama==0.4.5
anthropic==0.29.0
cohere==5.5.8
backoff==2.2.1
voyageai==0.3.1
datasets==3.1.0
litellm==1.55.4
google-cloud-aiplatform==1.74.0 # required for Vertex AI (litellm does not pull this in itself; reason unknown)
26 changes: 26 additions & 0 deletions projects/extension/sql/idempotent/008-embedding.sql
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,30 @@ $func$ language plpgsql immutable security invoker
set search_path to pg_catalog, pg_temp
;

-------------------------------------------------------------------------------
-- embedding_litellm
-- Builds the embedding configuration object for the 'litellm' implementation.
-- model and dimensions have no default; api_key_name and extra_options
-- default to null.
create or replace function ai.embedding_litellm
( model pg_catalog.text
, dimensions pg_catalog.int4
, api_key_name pg_catalog.text default null
, extra_options pg_catalog.jsonb default null
) returns pg_catalog.jsonb
as $func$
begin
-- 'implementation' and 'config_type' are fixed literals; the remaining keys
-- mirror the function arguments. "absent on null" omits any key whose value
-- is null, so arguments left at their null default do not appear in the result.
return json_object
( 'implementation': 'litellm'
, 'config_type': 'embedding'
, 'model': model
, 'dimensions': dimensions
, 'api_key_name': api_key_name
, 'extra_options': extra_options
absent on null
);
end
$func$ language plpgsql immutable security invoker
set search_path to pg_catalog, pg_temp
;

-------------------------------------------------------------------------------
-- _validate_embedding
create or replace function ai._validate_embedding(config pg_catalog.jsonb) returns void
Expand All @@ -98,6 +122,8 @@ begin
-- ok
when 'voyageai' then
-- ok
when 'litellm' then
-- ok
else
if _implementation is null then
raise exception 'embedding implementation not specified';
Expand Down
58 changes: 58 additions & 0 deletions projects/extension/sql/idempotent/017-litellm.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
-------------------------------------------------------------------------------
-- litellm_embed
-- generate an embedding from a text value
create or replace function ai.litellm_embed
( model text
, input_text text
, api_key text default null
, api_key_name text default null
, extra_options jsonb default null
) returns @extschema:vector@.vector
as $python$
    #ADD-PYTHON-LIB-DIR
    import ai.litellm
    import ai.secrets

    # extra_options, when given, is decoded from JSON and its top-level keys
    # are forwarded as keyword arguments to ai.litellm.embed.
    options = {}
    if extra_options is not None:
        import json
        options = {k: v for k, v in json.loads(extra_options).items()}

    # No default secret name is passed (None) for litellm.
    api_key_resolved = ai.secrets.get_secret(plpy, api_key, api_key_name, None, SD)

    # A single text is wrapped in a one-element list; the embedding of the
    # first result is returned. If nothing is yielded the function falls
    # through and returns None (SQL null).
    for tup in ai.litellm.embed(model, [input_text], api_key=api_key_resolved, **options):
        return tup[1]
$python$
language plpython3u immutable parallel safe security invoker
set search_path to pg_catalog, pg_temp
;

-------------------------------------------------------------------------------
-- litellm_embed
-- generate embeddings from an array of text values
create or replace function ai.litellm_embed
( model text
, input_texts text[]
, api_key text default null
, api_key_name text default null
, extra_options jsonb default null
) returns table
( "index" int
, embedding @extschema:vector@.vector
)
as $python$
    #ADD-PYTHON-LIB-DIR
    import ai.litellm
    import ai.secrets

    # extra_options, when given, is decoded from JSON and its top-level keys
    # are forwarded as keyword arguments to ai.litellm.embed.
    # The options are intentionally not written to the server log: they are
    # caller-supplied provider settings and may contain credentials.
    options = {}
    if extra_options is not None:
        import json
        options = {k: v for k, v in json.loads(extra_options).items()}

    # No default secret name is passed (None) for litellm.
    api_key_resolved = ai.secrets.get_secret(plpy, api_key, api_key_name, None, SD)

    # Each yielded tuple is (index, embedding), matching the returned table's columns.
    for tup in ai.litellm.embed(model, input_texts, api_key=api_key_resolved, **options):
        yield tup
$python$
language plpython3u immutable parallel safe security invoker
set search_path to pg_catalog, pg_temp
;
5 changes: 4 additions & 1 deletion projects/extension/tests/contents/output16.expected
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ CREATE EXTENSION
function ai.create_vectorizer(regclass,name,jsonb,jsonb,jsonb,jsonb,jsonb,jsonb,name,name,name,name,name,name,name[],boolean)
function ai.disable_vectorizer_schedule(integer)
function ai.drop_vectorizer(integer,boolean)
function ai.embedding_litellm(text,integer,text,jsonb)
function ai.embedding_ollama(text,integer,text,jsonb,text)
function ai.embedding_openai(text,integer,text,text)
function ai.embedding_voyageai(text,integer,text,text)
Expand All @@ -34,6 +35,8 @@ CREATE EXTENSION
function ai.indexing_diskann(integer,text,integer,integer,double precision,integer,integer,boolean)
function ai.indexing_hnsw(integer,text,integer,integer,boolean)
function ai.indexing_none()
function ai.litellm_embed(text,text,text,text,jsonb)
function ai.litellm_embed(text,text[],text,text,jsonb)
function ai.load_dataset_multi_txn(text,text,text,name,name,text,jsonb,integer,integer,integer,jsonb)
function ai.load_dataset(text,text,text,name,name,text,jsonb,integer,integer,jsonb)
function ai.ollama_chat_complete(text,jsonb,text,text,jsonb)
Expand Down Expand Up @@ -94,7 +97,7 @@ CREATE EXTENSION
table ai.vectorizer_errors
view ai.secret_permissions
view ai.vectorizer_status
(90 rows)
(93 rows)

Table "ai._secret_permissions"
Column | Type | Collation | Nullable | Default | Storage | Compression | Stats target | Description
Expand Down
5 changes: 4 additions & 1 deletion projects/extension/tests/contents/output17.expected
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ CREATE EXTENSION
function ai.create_vectorizer(regclass,name,jsonb,jsonb,jsonb,jsonb,jsonb,jsonb,name,name,name,name,name,name,name[],boolean)
function ai.disable_vectorizer_schedule(integer)
function ai.drop_vectorizer(integer,boolean)
function ai.embedding_litellm(text,integer,text,jsonb)
function ai.embedding_ollama(text,integer,text,jsonb,text)
function ai.embedding_openai(text,integer,text,text)
function ai.embedding_voyageai(text,integer,text,text)
Expand All @@ -34,6 +35,8 @@ CREATE EXTENSION
function ai.indexing_diskann(integer,text,integer,integer,double precision,integer,integer,boolean)
function ai.indexing_hnsw(integer,text,integer,integer,boolean)
function ai.indexing_none()
function ai.litellm_embed(text,text,text,text,jsonb)
function ai.litellm_embed(text,text[],text,text,jsonb)
function ai.load_dataset_multi_txn(text,text,text,name,name,text,jsonb,integer,integer,integer,jsonb)
function ai.load_dataset(text,text,text,name,name,text,jsonb,integer,integer,jsonb)
function ai.ollama_chat_complete(text,jsonb,text,text,jsonb)
Expand Down Expand Up @@ -108,7 +111,7 @@ CREATE EXTENSION
type ai.vectorizer_status[]
view ai.secret_permissions
view ai.vectorizer_status
(104 rows)
(107 rows)

Table "ai._secret_permissions"
Column | Type | Collation | Nullable | Default | Storage | Compression | Stats target | Description
Expand Down
14 changes: 13 additions & 1 deletion projects/extension/tests/privileges/function.expected
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,10 @@
f | bob | execute | no | ai | drop_vectorizer(vectorizer_id integer, drop_all boolean)
f | fred | execute | no | ai | drop_vectorizer(vectorizer_id integer, drop_all boolean)
f | jill | execute | YES | ai | drop_vectorizer(vectorizer_id integer, drop_all boolean)
f | alice | execute | YES | ai | embedding_litellm(model text, dimensions integer, api_key_name text, extra_options jsonb)
f | bob | execute | no | ai | embedding_litellm(model text, dimensions integer, api_key_name text, extra_options jsonb)
f | fred | execute | no | ai | embedding_litellm(model text, dimensions integer, api_key_name text, extra_options jsonb)
f | jill | execute | YES | ai | embedding_litellm(model text, dimensions integer, api_key_name text, extra_options jsonb)
f | alice | execute | YES | ai | embedding_ollama(model text, dimensions integer, base_url text, options jsonb, keep_alive text)
f | bob | execute | no | ai | embedding_ollama(model text, dimensions integer, base_url text, options jsonb, keep_alive text)
f | fred | execute | no | ai | embedding_ollama(model text, dimensions integer, base_url text, options jsonb, keep_alive text)
Expand Down Expand Up @@ -220,6 +224,14 @@
f | bob | execute | no | ai | indexing_none()
f | fred | execute | no | ai | indexing_none()
f | jill | execute | YES | ai | indexing_none()
f | alice | execute | YES | ai | litellm_embed(model text, input_text text, api_key text, api_key_name text, extra_options jsonb)
f | bob | execute | no | ai | litellm_embed(model text, input_text text, api_key text, api_key_name text, extra_options jsonb)
f | fred | execute | no | ai | litellm_embed(model text, input_text text, api_key text, api_key_name text, extra_options jsonb)
f | jill | execute | YES | ai | litellm_embed(model text, input_text text, api_key text, api_key_name text, extra_options jsonb)
f | alice | execute | YES | ai | litellm_embed(model text, input_texts text[], api_key text, api_key_name text, extra_options jsonb)
f | bob | execute | no | ai | litellm_embed(model text, input_texts text[], api_key text, api_key_name text, extra_options jsonb)
f | fred | execute | no | ai | litellm_embed(model text, input_texts text[], api_key text, api_key_name text, extra_options jsonb)
f | jill | execute | YES | ai | litellm_embed(model text, input_texts text[], api_key text, api_key_name text, extra_options jsonb)
f | alice | execute | YES | ai | load_dataset(name text, config_name text, split text, schema_name name, table_name name, if_table_exists text, field_types jsonb, batch_size integer, max_batches integer, kwargs jsonb)
f | bob | execute | no | ai | load_dataset(name text, config_name text, split text, schema_name name, table_name name, if_table_exists text, field_types jsonb, batch_size integer, max_batches integer, kwargs jsonb)
f | fred | execute | no | ai | load_dataset(name text, config_name text, split text, schema_name name, table_name name, if_table_exists text, field_types jsonb, batch_size integer, max_batches integer, kwargs jsonb)
Expand Down Expand Up @@ -328,5 +340,5 @@
f | bob | execute | no | ai | voyageai_embed(model text, input_texts text[], input_type text, api_key text, api_key_name text)
f | fred | execute | no | ai | voyageai_embed(model text, input_texts text[], input_type text, api_key text, api_key_name text)
f | jill | execute | YES | ai | voyageai_embed(model text, input_texts text[], input_type text, api_key text, api_key_name text)
(328 rows)
(340 rows)

Loading
Loading