Skip to content

Commit

Permalink
Allow for setting nan for embedding models that do not support embedding dimensions (#1792)
Browse files Browse the repository at this point in the history
  • Loading branch information
NolanTrem authored Jan 9, 2025
1 parent 1cff015 commit 5ee250b
Show file tree
Hide file tree
Showing 8 changed files with 88 additions and 18 deletions.
2 changes: 1 addition & 1 deletion py/core/base/providers/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
class EmbeddingConfig(ProviderConfig):
provider: str
base_model: str
base_dimension: int
base_dimension: int | float
rerank_model: Optional[str] = None
rerank_url: Optional[str] = None
batch_size: int = 1
Expand Down
2 changes: 2 additions & 0 deletions py/core/base/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
TextSplitter,
_decorate_vector_type,
_get_str_estimation_output,
_get_vector_column_str,
decrement_version,
deep_update,
format_search_results_for_llm,
Expand Down Expand Up @@ -39,5 +40,6 @@
"validate_uuid",
"deep_update",
"_decorate_vector_type",
"_get_vector_column_str",
"_get_str_estimation_output",
]
47 changes: 38 additions & 9 deletions py/core/database/chunks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import copy
import json
import logging
import math
import time
import uuid
from typing import Any, Optional, TypedDict
Expand Down Expand Up @@ -122,13 +123,18 @@ async def create_tables(self):
else f"vec_binary bit({self.dimension}),"
)

if self.dimension > 0:
vector_col = f"vec vector({self.dimension})"
else:
vector_col = "vec vector"

query = f"""
CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresChunksHandler.TABLE_NAME)} (
id UUID PRIMARY KEY,
document_id UUID,
owner_id UUID,
collection_ids UUID[],
vec vector({self.dimension}),
{vector_col},
{binary_col}
text TEXT,
metadata JSONB,
Expand All @@ -149,11 +155,15 @@ async def upsert(self, entry: VectorEntry) -> None:
"""
# Check the quantization type to determine which columns to use
if self.quantization_type == VectorQuantizationType.INT1:
bit_dim = (
"" if math.isnan(self.dimension) else f"({self.dimension})"
)

# For quantized vectors, use vec_binary column
query = f"""
INSERT INTO {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
(id, document_id, owner_id, collection_ids, vec, vec_binary, text, metadata)
VALUES ($1, $2, $3, $4, $5, $6::bit({self.dimension}), $7, $8)
VALUES ($1, $2, $3, $4, $5, $6::bit({bit_dim}), $7, $8)
ON CONFLICT (id) DO UPDATE SET
document_id = EXCLUDED.document_id,
owner_id = EXCLUDED.owner_id,
Expand Down Expand Up @@ -212,11 +222,15 @@ async def upsert_entries(self, entries: list[VectorEntry]) -> None:
Matches the table schema where vec_binary column only exists for INT1 quantization.
"""
if self.quantization_type == VectorQuantizationType.INT1:
bit_dim = (
"" if math.isnan(self.dimension) else f"({self.dimension})"
)

# For quantized vectors, use vec_binary column
query = f"""
INSERT INTO {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
(id, document_id, owner_id, collection_ids, vec, vec_binary, text, metadata)
VALUES ($1, $2, $3, $4, $5, $6::bit({self.dimension}), $7, $8)
VALUES ($1, $2, $3, $4, $5, $6::bit({bit_dim}), $7, $8)
ON CONFLICT (id) DO UPDATE SET
document_id = EXCLUDED.document_id,
owner_id = EXCLUDED.owner_id,
Expand Down Expand Up @@ -313,7 +327,10 @@ async def semantic_search(
)

# Use binary column and binary-specific distance measures for first stage
stage1_distance = f"{table_name}.vec_binary {binary_search_measure_repr} $1::bit({self.dimension})"
bit_dim = (
"" if math.isnan(self.dimension) else f"({self.dimension})"
)
stage1_distance = f"{table_name}.vec_binary {binary_search_measure_repr} $1::bit{bit_dim}"
stage1_param = binary_query

cols.append(
Expand All @@ -331,6 +348,10 @@ async def semantic_search(
search_settings.filters, params, mode="where_clause"
)

vector_dim = (
"" if math.isnan(self.dimension) else f"({self.dimension})"
)

# First stage: Get candidates using binary search
query = f"""
WITH candidates AS (
Expand All @@ -350,7 +371,7 @@ async def semantic_search(
collection_ids,
text,
{"metadata," if search_settings.include_metadatas else ""}
(vec <=> ${len(params) + 4}::vector({self.dimension})) as distance
(vec <=> ${len(params) + 4}::vector{vector_dim}) as distance
FROM candidates
ORDER BY distance
LIMIT ${len(params) + 3}
Expand All @@ -367,7 +388,10 @@ async def semantic_search(

else:
# Standard float vector handling
distance_calc = f"{table_name}.vec {search_settings.chunk_settings.index_measure.pgvector_repr} $1::vector({self.dimension})"
vector_dim = (
"" if math.isnan(self.dimension) else f"({self.dimension})"
)
distance_calc = f"{table_name}.vec {search_settings.chunk_settings.index_measure.pgvector_repr} $1::vector{vector_dim}"
query_param = str(query_vector)

if search_settings.include_scores:
Expand Down Expand Up @@ -1048,19 +1072,24 @@ async def get_semantic_neighbors(
similarity_threshold: float = 0.5,
) -> list[dict[str, Any]]:
table_name = self._get_table_name(PostgresChunksHandler.TABLE_NAME)
vector_dim = (
"" if math.isnan(self.dimension) else f"({self.dimension})"
)

query = f"""
WITH target_vector AS (
SELECT vec FROM {table_name}
SELECT vec::vector{vector_dim} FROM {table_name}
WHERE document_id = $1 AND id = $2
)
SELECT t.id, t.text, t.metadata, t.document_id, (t.vec <=> tv.vec) AS similarity
SELECT t.id, t.text, t.metadata, t.document_id, (t.vec::vector{vector_dim} <=> tv.vec) AS similarity
FROM {table_name} t, target_vector tv
WHERE (t.vec <=> tv.vec) >= $3
WHERE (t.vec::vector{vector_dim} <=> tv.vec) >= $3
AND t.document_id = $1
AND t.id != $2
ORDER BY similarity ASC
LIMIT $4
"""

results = await self.connection_manager.fetch_query(
query,
(str(document_id), str(id), similarity_threshold, limit),
Expand Down
15 changes: 13 additions & 2 deletions py/core/database/documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import csv
import json
import logging
import math
import tempfile
from typing import IO, Any, Optional
from uuid import UUID
Expand Down Expand Up @@ -43,6 +44,12 @@ async def create_tables(self):
logger.info(
f"Creating table, if not exists: {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}"
)

vector_dim = (
"" if math.isnan(self.dimension) else f"({self.dimension})"
)
vector_type = f"vector{vector_dim}"

try:
query = f"""
CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)} (
Expand All @@ -53,7 +60,7 @@ async def create_tables(self):
metadata JSONB,
title TEXT,
summary TEXT NULL,
summary_embedding vector({self.dimension}) NULL,
summary_embedding {vector_type} NULL,
version TEXT,
size_in_bytes INT,
ingestion_status TEXT DEFAULT 'pending',
Expand Down Expand Up @@ -511,6 +518,10 @@ async def semantic_document_search(
where_clauses = ["summary_embedding IS NOT NULL"]
params: list[str | int | bytes] = [str(query_embedding)]

vector_dim = (
"" if math.isnan(self.dimension) else f"({self.dimension})"
)

if search_settings.filters:
filter_condition, params = apply_filters(
search_settings.filters, params, mode="condition_only"
Expand All @@ -537,7 +548,7 @@ async def semantic_document_search(
updated_at,
summary,
summary_embedding,
(summary_embedding <=> $1::vector({self.dimension})) as semantic_distance
(summary_embedding <=> $1::vector({vector_dim})) as semantic_distance
FROM {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
WHERE {where_clause}
ORDER BY semantic_distance ASC
Expand Down
15 changes: 9 additions & 6 deletions py/core/database/graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import datetime
import json
import logging
import math
import os
import tempfile
import time
Expand Down Expand Up @@ -32,6 +33,7 @@
from core.base.utils import (
_decorate_vector_type,
_get_str_estimation_output,
_get_vector_column_str,
llm_cost_per_million_tokens,
)

Expand Down Expand Up @@ -75,8 +77,8 @@ def _get_parent_constraint(self, store_type: StoreType) -> str:

async def create_tables(self) -> None:
"""Create separate tables for graph and document entities."""
vector_column_str = _decorate_vector_type(
f"({self.dimension})", self.quantization_type
vector_column_str = _get_vector_column_str(
self.dimension, self.quantization_type
)

for store_type in StoreType:
Expand Down Expand Up @@ -527,9 +529,10 @@ async def create_tables(self) -> None:
for store_type in StoreType:
table_name = self._get_relationship_table_for_store(store_type)
parent_constraint = self._get_parent_constraint(store_type)
vector_column_str = _decorate_vector_type(
f"({self.dimension})", self.quantization_type
vector_column_str = _get_vector_column_str(
self.dimension, self.quantization_type
)

QUERY = f"""
CREATE TABLE IF NOT EXISTS {self._get_table_name(table_name)} (
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
Expand Down Expand Up @@ -1011,8 +1014,8 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
self.quantization_type: VectorQuantizationType = kwargs.get("quantization_type") # type: ignore

async def create_tables(self) -> None:
vector_column_str = _decorate_vector_type(
f"({self.dimension})", self.quantization_type
vector_column_str = _get_vector_column_str(
self.dimension, self.quantization_type
)

query = f"""
Expand Down
5 changes: 5 additions & 0 deletions py/core/providers/embeddings/litellm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import math
import os
from copy import copy
from typing import Any
Expand Down Expand Up @@ -73,6 +74,10 @@ async def _execute_task(self, task: dict[str, Any]) -> list[list[float]]:
texts = task["texts"]
kwargs = self._get_embedding_kwargs(**task.get("kwargs", {}))

if "dimensions" in kwargs and math.isnan(kwargs["dimensions"]):
kwargs.pop("dimensions")
logger.warning("Dropping nan dimensions from kwargs")

try:
response = await self.litellm_aembedding(
input=texts,
Expand Down
2 changes: 2 additions & 0 deletions py/shared/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .base_utils import (
_decorate_vector_type,
_get_str_estimation_output,
_get_vector_column_str,
decrement_version,
deep_update,
format_search_results_for_llm,
Expand Down Expand Up @@ -42,5 +43,6 @@
"TextSplitter",
# Vector utils
"_decorate_vector_type",
"_get_vector_column_str",
"_get_str_estimation_output",
]
18 changes: 18 additions & 0 deletions py/shared/utils/base_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import json
import logging
import math
from copy import deepcopy
from datetime import datetime
from typing import (
Expand Down Expand Up @@ -300,6 +301,23 @@ def _decorate_vector_type(
return f"{quantization_type.db_type}{input_str}"


def _get_vector_column_str(
    dimension: int | float, quantization_type: VectorQuantizationType
) -> str:
    """Build the SQL column-type string for a vector of the given dimension.

    A NaN or non-positive ``dimension`` signals an embedding model whose
    output size cannot be specified; the dimension suffix is then omitted
    so Postgres accepts vectors of any length.
    """
    # NaN must be tested before the comparison: NaN <= 0 is always False.
    has_valid_dim = not math.isnan(dimension) and dimension > 0
    suffix = f"({dimension})" if has_valid_dim else ""
    return _decorate_vector_type(suffix, quantization_type)


def _get_str_estimation_output(x: tuple[Any, Any]) -> str:
if isinstance(x[0], int) and isinstance(x[1], int):
return " - ".join(map(str, x))
Expand Down

0 comments on commit 5ee250b

Please sign in to comment.