feat: optimized sparse embedding interface
jamescalam committed Nov 27, 2024
1 parent d5f4703 commit 805807f
Showing 14 changed files with 109 additions and 87 deletions.
4 changes: 2 additions & 2 deletions Makefile
@@ -7,8 +7,8 @@ lint: PYTHON_FILES=.
lint_diff: PYTHON_FILES=$(shell git diff --name-only --diff-filter=d main | grep -E '\.py$$')

lint lint_diff:
poetry run black --target-version py39 -l 88 $(PYTHON_FILES) --check
poetry run ruff .
poetry run black --target-version py311 -l 88 $(PYTHON_FILES) --check
poetry run ruff check .
poetry run mypy $(PYTHON_FILES)

test:
4 changes: 1 addition & 3 deletions docs/encoders/aurelio-bm25.ipynb
@@ -153,9 +153,7 @@
" \"Enter OpenAI API Key: \"\n",
")\n",
"\n",
"encoder = OpenAIEncoder(\n",
" name=\"text-embedding-3-small\", score_threshold=0.3\n",
")"
"encoder = OpenAIEncoder(name=\"text-embedding-3-small\", score_threshold=0.3)"
]
},
{
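A minimal sketch of the cell reformatted above, assuming the standard `OpenAIEncoder` import from `semantic_router.encoders` and that an OpenAI API key is available; the sample query is illustrative:

```python
import os
from getpass import getpass

from semantic_router.encoders import OpenAIEncoder

# Prompt for the key only if it is not already in the environment.
os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY") or getpass(
    "Enter OpenAI API Key: "
)

encoder = OpenAIEncoder(name="text-embedding-3-small", score_threshold=0.3)
dense_vector = encoder(["how's the weather today?"])[0]  # a list of floats
```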
4 changes: 1 addition & 3 deletions docs/examples/hybrid-router.ipynb
@@ -155,9 +155,7 @@
"from semantic_router.routers import HybridRouter\n",
"\n",
"router = HybridRouter(\n",
" encoder=dense_encoder,\n",
" sparse_encoder=sparse_encoder,\n",
" routes=routes\n",
" encoder=dense_encoder, sparse_encoder=sparse_encoder, routes=routes\n",
")"
]
},
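For context, a hedged sketch of how the pieces in this cell might be assembled end to end; the sparse-encoder import path mirrors the `semantic_router/encoders/aurelio.py` file changed in this commit, while the example routes and the API-key requirements are assumptions, not part of this diff:

```python
from semantic_router import Route
from semantic_router.routers import HybridRouter
from semantic_router.encoders import OpenAIEncoder
from semantic_router.encoders.aurelio import AurelioSparseEncoder  # path assumed from this commit

# Example routes: the politics utterance appears in the notebook outputs further
# down this page, the chitchat route is purely illustrative.
routes = [
    Route(
        name="politics",
        utterances=["why don't you tell me about your political opinions"],
    ),
    Route(name="chitchat", utterances=["how's the weather today?"]),
]

# Requires OPENAI_API_KEY and AURELIO_API_KEY in the environment.
dense_encoder = OpenAIEncoder(name="text-embedding-3-small", score_threshold=0.3)
sparse_encoder = AurelioSparseEncoder(name="bm25")

router = HybridRouter(
    encoder=dense_encoder, sparse_encoder=sparse_encoder, routes=routes
)
```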
54 changes: 26 additions & 28 deletions docs/examples/pinecone-hybrid.ipynb
@@ -53,7 +53,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 1,
"metadata": {},
"outputs": [
{
@@ -90,7 +90,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@@ -119,7 +119,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
@@ -143,7 +143,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
@@ -153,9 +153,7 @@
" \"Enter OpenAI API Key: \"\n",
")\n",
"\n",
"encoder = OpenAIEncoder(\n",
" name=\"text-embedding-3-small\", score_threshold=0.3\n",
")"
"encoder = OpenAIEncoder(name=\"text-embedding-3-small\", score_threshold=0.3)"
]
},
{
@@ -167,16 +165,16 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-11-26 22:34:54 - pinecone_plugin_interface.logging - INFO - discover_namespace_packages.py:12 - discover_subpackages() - Discovering subpackages in _NamespacePath(['/Users/jamesbriggs/Library/Caches/pypoetry/virtualenvs/semantic-router-C1zr4a78-py3.12/lib/python3.12/site-packages/pinecone_plugins'])\n",
"2024-11-26 22:34:54 - pinecone_plugin_interface.logging - INFO - discover_plugins.py:9 - discover_plugins() - Looking for plugins in pinecone_plugins.inference\n",
"2024-11-26 22:34:54 - pinecone_plugin_interface.logging - INFO - installation.py:10 - install_plugins() - Installing plugin inference into Pinecone\n"
"2024-11-27 15:41:32 - pinecone_plugin_interface.logging - INFO - discover_namespace_packages.py:12 - discover_subpackages() - Discovering subpackages in _NamespacePath(['/Users/jamesbriggs/Library/Caches/pypoetry/virtualenvs/semantic-router-C1zr4a78-py3.12/lib/python3.12/site-packages/pinecone_plugins'])\n",
"2024-11-27 15:41:32 - pinecone_plugin_interface.logging - INFO - discover_plugins.py:9 - discover_plugins() - Looking for plugins in pinecone_plugins.inference\n",
"2024-11-27 15:41:32 - pinecone_plugin_interface.logging - INFO - installation.py:10 - install_plugins() - Installing plugin inference into Pinecone\n"
]
}
],
@@ -203,7 +201,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
@@ -226,16 +224,16 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
"False"
]
},
"execution_count": 8,
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
@@ -253,7 +251,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 8,
"metadata": {},
"outputs": [
{
@@ -272,7 +270,7 @@
" \" politics: why don't you tell me about your political opinions\"]"
]
},
"execution_count": 9,
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
@@ -290,7 +288,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 9,
"metadata": {},
"outputs": [
{
@@ -309,7 +307,7 @@
" Utterance(route='politics', utterance=\"why don't you tell me about your political opinions\", function_schemas=None, metadata={}, diff_tag=' ')]"
]
},
"execution_count": 10,
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
@@ -327,7 +325,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
@@ -349,7 +347,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 11,
"metadata": {},
"outputs": [
{
@@ -358,7 +356,7 @@
"True"
]
},
"execution_count": 12,
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
@@ -369,7 +367,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 12,
"metadata": {},
"outputs": [
{
@@ -388,7 +386,7 @@
" \" politics: why don't you tell me about your political opinions\"]"
]
},
"execution_count": 13,
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
@@ -406,14 +404,14 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-11-26 22:35:56 - httpx - INFO - _client.py:1013 - _send_single_request() - HTTP Request: POST https://api.openai.com/v1/embeddings \"HTTP/1.1 200 OK\"\n"
"2024-11-27 15:42:03 - httpx - INFO - _client.py:1013 - _send_single_request() - HTTP Request: POST https://api.openai.com/v1/embeddings \"HTTP/1.1 200 OK\"\n"
]
},
{
@@ -422,7 +420,7 @@
"RouteChoice(name=None, function_call=None, similarity_score=None)"
]
},
"execution_count": 15,
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
@@ -440,7 +438,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"2024-11-26 22:35:20 - httpx - INFO - _client.py:1013 - _send_single_request() - HTTP Request: POST https://api.openai.com/v1/embeddings \"HTTP/1.1 200 OK\"\n"
"2024-11-27 15:42:06 - httpx - INFO - _client.py:1013 - _send_single_request() - HTTP Request: POST https://api.openai.com/v1/embeddings \"HTTP/1.1 200 OK\"\n"
]
},
{
13 changes: 8 additions & 5 deletions semantic_router/encoders/aurelio.py
@@ -12,16 +12,17 @@ class AurelioSparseEncoder(SparseEncoder):
model: Optional[Any] = None
idx_mapping: Optional[Dict[int, int]] = None
client: AurelioClient = Field(default_factory=AurelioClient, exclude=True)
async_client: AsyncAurelioClient = Field(default_factory=AsyncAurelioClient, exclude=True)
async_client: AsyncAurelioClient = Field(
default_factory=AsyncAurelioClient, exclude=True
)
type: str = "sparse"

def __init__(
self,
name: str = "bm25",
score_threshold: float = 1.0,
api_key: Optional[str] = None,
):
super().__init__(name=name, score_threshold=score_threshold)
super().__init__(name=name)
if api_key is None:
api_key = os.getenv("AURELIO_API_KEY")
if api_key is None:
@@ -33,9 +34,11 @@ def __call__(self, docs: list[str]) -> list[SparseEmbedding]:
res: EmbeddingResponse = self.client.embedding(input=docs, model=self.name)
embeds = [SparseEmbedding.from_aurelio(r.embedding) for r in res.data]
return embeds

async def acall(self, docs: list[str]) -> list[SparseEmbedding]:
res: EmbeddingResponse = await self.async_client.embedding(input=docs, model=self.name)
res: EmbeddingResponse = await self.async_client.embedding(
input=docs, model=self.name
)
embeds = [SparseEmbedding.from_aurelio(r.embedding) for r in res.data]
return embeds

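A short usage sketch of the encoder changed above, showing the synchronous `__call__` and asynchronous `acall` paths as declared in this diff; the import path and the `AURELIO_API_KEY` requirement are assumptions:

```python
import asyncio

from semantic_router.encoders.aurelio import AurelioSparseEncoder  # path assumed

# Requires AURELIO_API_KEY in the environment (see __init__ above).
encoder = AurelioSparseEncoder(name="bm25")

# Synchronous path: __call__ returns a list of SparseEmbedding objects.
sparse_embeds = encoder(["semantic routing with sparse vectors"])


# Asynchronous path: acall mirrors __call__ but goes through AsyncAurelioClient.
async def embed_async() -> list:
    return await encoder.acall(["semantic routing with sparse vectors"])


sparse_embeds_async = asyncio.run(embed_async())
```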
2 changes: 1 addition & 1 deletion semantic_router/encoders/base.py
@@ -35,4 +35,4 @@ def __call__(self, docs: List[str]) -> List[SparseEmbedding]:
raise NotImplementedError("Subclasses must implement this method")

def acall(self, docs: List[str]) -> Coroutine[Any, Any, List[SparseEmbedding]]:
raise NotImplementedError("Subclasses must implement this method")
raise NotImplementedError("Subclasses must implement this method")
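The base class above only declares the sparse interface, so a custom encoder is expected to implement `__call__` and `acall`. A toy sketch of such a subclass follows; the import paths and the `SparseEmbedding.from_dict` constructor are assumptions for illustration only (this diff only shows `SparseEmbedding.from_aurelio`):

```python
from typing import List

from semantic_router.encoders.base import SparseEncoder  # module path assumed
from semantic_router.schema import SparseEmbedding  # import location assumed


class WordCountSparseEncoder(SparseEncoder):
    """Toy encoder: one sparse dimension per hashed token, weighted by count."""

    type: str = "sparse"

    def __call__(self, docs: List[str]) -> List[SparseEmbedding]:
        embeds = []
        for doc in docs:
            counts: dict[int, float] = {}
            for token in doc.lower().split():
                idx = hash(token) % 30_000  # toy hashing-trick vocabulary
                counts[idx] = counts.get(idx, 0.0) + 1.0
            # Building a SparseEmbedding from a plain dict is an assumption here.
            embeds.append(SparseEmbedding.from_dict(counts))
        return embeds

    async def acall(self, docs: List[str]) -> List[SparseEmbedding]:
        # No real async work in this toy example, so reuse the sync path.
        return self(docs)


encoder = WordCountSparseEncoder(name="word-count")
embeds = encoder(["hybrid routing with sparse vectors"])
```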
7 changes: 3 additions & 4 deletions semantic_router/encoders/tfidf.py
@@ -14,9 +14,8 @@ class TfidfEncoder(SparseEncoder):
idf: ndarray = np.array([])
word_index: Dict = {}

def __init__(self, name: str = "tfidf", score_threshold: float = 0.82):
# TODO default score_threshold not thoroughly tested, should optimize
super().__init__(name=name, score_threshold=score_threshold)
def __init__(self, name: str = "tfidf"):
super().__init__(name=name)
self.word_index = {}
self.idf = np.array([])

@@ -29,7 +28,7 @@ def __call__(self, docs: List[str]) -> List[List[float]]:
docs = [self._preprocess(doc) for doc in docs]
tf = self._compute_tf(docs)
tfidf = tf * self.idf
return tfidf.tolist()
return tfidf

def fit(self, routes: List[Route]):
docs = []
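A conceptual numpy-only sketch of the `tf * idf` product that `__call__` now returns as an array rather than a list; the tokenization and IDF formula here are illustrative, not necessarily the library's exact ones:

```python
import numpy as np

docs = ["tell me about politics", "tell me a joke"]
vocab = sorted({word for doc in docs for word in doc.split()})
word_index = {word: i for i, word in enumerate(vocab)}

# Term frequencies: one row per document, one column per vocabulary word.
tf = np.zeros((len(docs), len(vocab)))
for row, doc in enumerate(docs):
    for word in doc.split():
        tf[row, word_index[word]] += 1

# Inverse document frequency: rarer words get larger weights.
df = (tf > 0).sum(axis=0)
idf = np.log(len(docs) / df) + 1

# Same elementwise product that TfidfEncoder.__call__ computes above.
tfidf = tf * idf
```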
12 changes: 8 additions & 4 deletions semantic_router/index/hybrid_local.py
@@ -66,15 +66,19 @@ def describe(self) -> Dict:
"dimensions": self.index.shape[1] if self.index is not None else 0,
"vectors": self.index.shape[0] if self.index is not None else 0,
}

def _sparse_dot_product(self, vec_a: dict[int, float], vec_b: dict[int, float]) -> float:

def _sparse_dot_product(
self, vec_a: dict[int, float], vec_b: dict[int, float]
) -> float:
# switch vecs to ensure first is smallest for more efficiency
if len(vec_a) > len(vec_b):
vec_a, vec_b = vec_b, vec_a
return sum(vec_a[i] * vec_b.get(i, 0) for i in vec_a)

def _sparse_index_dot_product(self, vec_a: dict[int, float]) -> list[float]:
dot_products = [self._sparse_dot_product(vec_a, vec_b) for vec_b in self.sparse_index]
dot_products = [
self._sparse_dot_product(vec_a, vec_b) for vec_b in self.sparse_index
]
return dot_products

def query(
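A standalone sketch of the dict-based sparse dot product reformatted above: iterate over the smaller vector and look up matching indices in the larger one, so the cost scales with the smaller number of non-zero entries:

```python
def sparse_dot_product(vec_a: dict[int, float], vec_b: dict[int, float]) -> float:
    # Iterate over the smaller dict so lookups happen in the larger one.
    if len(vec_a) > len(vec_b):
        vec_a, vec_b = vec_b, vec_a
    return sum(value * vec_b.get(index, 0.0) for index, value in vec_a.items())


# Only index 3 overlaps, so the score is 0.5 * 2.0 = 1.0.
print(sparse_dot_product({3: 0.5, 7: 1.2}, {1: 0.4, 3: 2.0}))
```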
4 changes: 3 additions & 1 deletion semantic_router/index/pinecone.py
@@ -652,7 +652,9 @@ async def _async_fetch_metadata(self, vector_id: str) -> dict:
)

def __len__(self):
namespace_stats = self.index.describe_index_stats()["namespaces"].get(self.namespace)
namespace_stats = self.index.describe_index_stats()["namespaces"].get(
self.namespace
)
if namespace_stats:
return namespace_stats["vector_count"]
else:
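A sketch of the namespace-count fallback reformatted above, using a stubbed stats payload so it runs without a Pinecone client; the payload shape follows the keys referenced in the diff, and returning 0 for a missing namespace is an assumption since the `else` branch is truncated here:

```python
def vector_count(stats: dict, namespace: str) -> int:
    # Mirrors __len__ above: look up this namespace in the index stats.
    namespace_stats = stats["namespaces"].get(namespace)
    if namespace_stats:
        return namespace_stats["vector_count"]
    return 0  # assumed fallback; the else branch is not shown in this diff


stats = {"namespaces": {"": {"vector_count": 10}}}
print(vector_count(stats, ""))         # 10
print(vector_count(stats, "missing"))  # 0, namespace not created yet
```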
5 changes: 2 additions & 3 deletions semantic_router/routers/base.py
@@ -4,7 +4,7 @@
import random
import hashlib
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from pydantic.v1 import BaseModel, Field, validator
from pydantic.v1 import BaseModel, Field

import numpy as np
import yaml # type: ignore
@@ -380,8 +380,7 @@ def _set_index(self, index: Optional[BaseIndex]):
self.index = index

def _init_index_state(self):
"""Initializes an index (where required) and runs auto_sync if active.
"""
"""Initializes an index (where required) and runs auto_sync if active."""
# initialize index now, check if we need dimensions
if self.index.dimensions is None:
dims = len(self.encoder(["test"])[0])
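A minimal sketch of the dimension probe inside `_init_index_state` above: when the index has no configured dimensions, embed a throwaway string and measure the result. `DenseEncoderLike`, `resolve_dimensions`, and `fake_encoder` are hypothetical stand-ins, not library names:

```python
from typing import Callable, List, Optional

# Hypothetical stand-in for any dense encoder callable.
DenseEncoderLike = Callable[[List[str]], List[List[float]]]


def resolve_dimensions(encoder: DenseEncoderLike, configured: Optional[int]) -> int:
    if configured is not None:
        return configured
    # Embed a throwaway string and measure the vector length.
    return len(encoder(["test"])[0])


def fake_encoder(docs: List[str]) -> List[List[float]]:
    return [[0.0] * 5 for _ in docs]


print(resolve_dimensions(fake_encoder, None))  # 5
```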
8 changes: 5 additions & 3 deletions semantic_router/routers/hybrid.py
@@ -60,7 +60,7 @@ def __init__(
# run initialize index now if auto sync is active
if self.auto_sync:
self._init_index_state()

def _set_sparse_encoder(self, sparse_encoder: Optional[DenseEncoder]):
if sparse_encoder is None:
logger.warning("No sparse_encoder provided. Using default BM25Encoder.")
@@ -126,7 +126,7 @@ def __call__(
vector=np.array(vector) if isinstance(vector, list) else vector,
top_k=self.top_k,
route_filter=route_filter,
sparse_vector=sparse_vector[0]
sparse_vector=sparse_vector[0],
)
top_class, top_class_scores = self._semantic_classify(
list(zip(scores, route_names))
@@ -142,7 +142,9 @@ def _convex_scaling(self, dense: np.ndarray, sparse: list[dict[int, float]]):
scaled_dense = np.array(dense) * self.alpha
scaled_sparse = []
for sparse_dict in sparse:
scaled_sparse.append({k: v * (1 - self.alpha) for k, v in sparse_dict.items()})
scaled_sparse.append(
{k: v * (1 - self.alpha) for k, v in sparse_dict.items()}
)
return scaled_dense, scaled_sparse

def _set_aggregation_method(self, aggregation: str = "sum"):
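A standalone sketch of the convex scaling reformatted above: dense scores are weighted by `alpha` and sparse scores by `1 - alpha`, so a single knob trades off the two retrieval signals; the `alpha` value here is only an example, not necessarily the router's default:

```python
import numpy as np


def convex_scaling(
    dense: np.ndarray, sparse: list[dict[int, float]], alpha: float = 0.3
):
    # alpha weights the dense scores, (1 - alpha) weights the sparse scores.
    scaled_dense = np.array(dense) * alpha
    scaled_sparse = [
        {k: v * (1 - alpha) for k, v in sparse_dict.items()} for sparse_dict in sparse
    ]
    return scaled_dense, scaled_sparse


scaled_dense, scaled_sparse = convex_scaling(
    np.array([[0.1, 0.9]]), [{3: 0.5, 7: 1.2}], alpha=0.3
)
```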