Skip to content

Commit

Permalink
feat: new sparse embedding support and abstractions
Browse files Browse the repository at this point in the history
  • Loading branch information
jamescalam committed Nov 26, 2024
1 parent 09cc05f commit 87d6ae4
Show file tree
Hide file tree
Showing 7 changed files with 245 additions and 272 deletions.
196 changes: 115 additions & 81 deletions docs/examples/pinecone-hybrid.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,9 @@
"name": "stderr",
"output_type": "stream",
"text": [
"2024-11-24 19:41:05 - pinecone_plugin_interface.logging - INFO - discover_namespace_packages.py:12 - discover_subpackages() - Discovering subpackages in _NamespacePath(['/Users/jamesbriggs/Library/Caches/pypoetry/virtualenvs/semantic-router-C1zr4a78-py3.12/lib/python3.12/site-packages/pinecone_plugins'])\n",
"2024-11-24 19:41:05 - pinecone_plugin_interface.logging - INFO - discover_plugins.py:9 - discover_plugins() - Looking for plugins in pinecone_plugins.inference\n",
"2024-11-24 19:41:05 - pinecone_plugin_interface.logging - INFO - installation.py:10 - install_plugins() - Installing plugin inference into Pinecone\n"
"2024-11-26 22:34:54 - pinecone_plugin_interface.logging - INFO - discover_namespace_packages.py:12 - discover_subpackages() - Discovering subpackages in _NamespacePath(['/Users/jamesbriggs/Library/Caches/pypoetry/virtualenvs/semantic-router-C1zr4a78-py3.12/lib/python3.12/site-packages/pinecone_plugins'])\n",
"2024-11-26 22:34:54 - pinecone_plugin_interface.logging - INFO - discover_plugins.py:9 - discover_plugins() - Looking for plugins in pinecone_plugins.inference\n",
"2024-11-26 22:34:54 - pinecone_plugin_interface.logging - INFO - installation.py:10 - install_plugins() - Installing plugin inference into Pinecone\n"
]
}
],
Expand Down Expand Up @@ -205,29 +205,7 @@
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-11-24 19:41:15 - httpx - INFO - _client.py:1013 - _send_single_request() - HTTP Request: POST https://api.openai.com/v1/embeddings \"HTTP/1.1 200 OK\"\n",
"2024-11-24 19:41:17 - semantic_router.utils.logger - WARNING - pinecone.py:247 - add() - TEMP | add:\n",
"politics: isn't politics the best thing ever\n",
"politics: why don't you tell me about your political opinions\n",
"politics: don't you just love the president\n",
"politics: don't you just hate the president\n",
"politics: they're going to destroy this country!\n",
"politics: they will save the country!\n",
"2024-11-24 19:41:17 - httpx - INFO - _client.py:1013 - _send_single_request() - HTTP Request: POST https://api.openai.com/v1/embeddings \"HTTP/1.1 200 OK\"\n",
"2024-11-24 19:41:18 - semantic_router.utils.logger - WARNING - pinecone.py:247 - add() - TEMP | add:\n",
"chitchat: how's the weather today?\n",
"chitchat: how are things going?\n",
"chitchat: lovely weather today\n",
"chitchat: the weather is horrendous\n",
"chitchat: let's go to the chippy\n"
]
}
],
"outputs": [],
"source": [
"from semantic_router.routers import HybridRouter\n",
"\n",
Expand All @@ -248,23 +226,16 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-11-24 19:42:06 - semantic_router.utils.logger - WARNING - pinecone.py:424 - _read_hash() - Configuration for hash parameter not found in index.\n"
]
},
{
"data": {
"text/plain": [
"False"
"True"
]
},
"execution_count": 9,
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -282,26 +253,26 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['- chitchat: how are things going?',\n",
" \"- chitchat: how's the weather today?\",\n",
" \"- chitchat: let's go to the chippy\",\n",
" '- chitchat: lovely weather today',\n",
" '- chitchat: the weather is horrendous',\n",
" \"- politics: don't you just hate the president\",\n",
" \"- politics: don't you just love the president\",\n",
" \"- politics: isn't politics the best thing ever\",\n",
" '- politics: they will save the country!',\n",
" \"- politics: they're going to destroy this country!\",\n",
" \"- politics: why don't you tell me about your political opinions\"]"
"[' chitchat: how are things going?',\n",
" \" chitchat: how's the weather today?\",\n",
" \" chitchat: let's go to the chippy\",\n",
" ' chitchat: lovely weather today',\n",
" ' chitchat: the weather is horrendous',\n",
" \" politics: don't you just hate the president\",\n",
" \" politics: don't you just love the president\",\n",
" \" politics: isn't politics the best thing ever\",\n",
" ' politics: they will save the country!',\n",
" \" politics: they're going to destroy this country!\",\n",
" \" politics: why don't you tell me about your political opinions\"]"
]
},
"execution_count": 10,
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -319,16 +290,26 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[]"
"[Utterance(route='chitchat', utterance='how are things going?', function_schemas=None, metadata={}, diff_tag=' '),\n",
" Utterance(route='chitchat', utterance=\"how's the weather today?\", function_schemas=None, metadata={}, diff_tag=' '),\n",
" Utterance(route='chitchat', utterance='the weather is horrendous', function_schemas=None, metadata={}, diff_tag=' '),\n",
" Utterance(route='chitchat', utterance='lovely weather today', function_schemas=None, metadata={}, diff_tag=' '),\n",
" Utterance(route='chitchat', utterance=\"let's go to the chippy\", function_schemas=None, metadata={}, diff_tag=' '),\n",
" Utterance(route='politics', utterance=\"don't you just hate the president\", function_schemas=None, metadata={}, diff_tag=' '),\n",
" Utterance(route='politics', utterance=\"don't you just love the president\", function_schemas=None, metadata={}, diff_tag=' '),\n",
" Utterance(route='politics', utterance=\"they're going to destroy this country!\", function_schemas=None, metadata={}, diff_tag=' '),\n",
" Utterance(route='politics', utterance='they will save the country!', function_schemas=None, metadata={}, diff_tag=' '),\n",
" Utterance(route='politics', utterance=\"isn't politics the best thing ever\", function_schemas=None, metadata={}, diff_tag=' '),\n",
" Utterance(route='politics', utterance=\"why don't you tell me about your political opinions\", function_schemas=None, metadata={}, diff_tag=' ')]"
]
},
"execution_count": 12,
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -346,31 +327,9 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-11-24 19:48:29 - httpx - INFO - _client.py:1013 - _send_single_request() - HTTP Request: POST https://api.openai.com/v1/embeddings \"HTTP/1.1 200 OK\"\n",
"2024-11-24 19:48:31 - semantic_router.utils.logger - WARNING - pinecone.py:247 - add() - TEMP | add:\n",
"politics: isn't politics the best thing ever\n",
"politics: why don't you tell me about your political opinions\n",
"politics: don't you just love the president\n",
"politics: don't you just hate the president\n",
"politics: they're going to destroy this country!\n",
"politics: they will save the country!\n",
"2024-11-24 19:48:31 - httpx - INFO - _client.py:1013 - _send_single_request() - HTTP Request: POST https://api.openai.com/v1/embeddings \"HTTP/1.1 200 OK\"\n",
"2024-11-24 19:48:32 - semantic_router.utils.logger - WARNING - pinecone.py:247 - add() - TEMP | add:\n",
"chitchat: how's the weather today?\n",
"chitchat: how are things going?\n",
"chitchat: lovely weather today\n",
"chitchat: the weather is horrendous\n",
"chitchat: let's go to the chippy\n"
]
}
],
"outputs": [],
"source": [
"router = HybridRouter(\n",
" encoder=encoder,\n",
Expand All @@ -390,16 +349,16 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"False"
"True"
]
},
"execution_count": 16,
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -408,6 +367,36 @@
"router.is_synced()"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[' chitchat: how are things going?',\n",
" \" chitchat: how's the weather today?\",\n",
" \" chitchat: let's go to the chippy\",\n",
" ' chitchat: lovely weather today',\n",
" ' chitchat: the weather is horrendous',\n",
" \" politics: don't you just hate the president\",\n",
" \" politics: don't you just love the president\",\n",
" \" politics: isn't politics the best thing ever\",\n",
" ' politics: they will save the country!',\n",
" \" politics: they're going to destroy this country!\",\n",
" \" politics: why don't you tell me about your political opinions\"]"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"router.get_utterance_diff()"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand All @@ -417,9 +406,54 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 15,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-11-26 22:35:56 - httpx - INFO - _client.py:1013 - _send_single_request() - HTTP Request: POST https://api.openai.com/v1/embeddings \"HTTP/1.1 200 OK\"\n"
]
},
{
"data": {
"text/plain": [
"RouteChoice(name=None, function_call=None, similarity_score=None)"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"router(\"it's raining cats and dogs today\")"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-11-26 22:35:20 - httpx - INFO - _client.py:1013 - _send_single_request() - HTTP Request: POST https://api.openai.com/v1/embeddings \"HTTP/1.1 200 OK\"\n"
]
},
{
"data": {
"text/plain": [
"RouteChoice(name=None, function_call=None, similarity_score=None)"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"router(\"I'm interested in learning about llama 2\")"
]
Expand Down
17 changes: 7 additions & 10 deletions semantic_router/encoders/aurelio.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from aurelio_sdk import AurelioClient, AsyncAurelioClient, EmbeddingResponse

from semantic_router.encoders.base import BaseEncoder
from semantic_router.schema import SparseEmbedding


class AurelioSparseEncoder(BaseEncoder):
Expand All @@ -28,19 +29,15 @@ def __init__(
self.client = AurelioClient(api_key=api_key)
self.async_client = AsyncAurelioClient(api_key=api_key)

def __call__(self, docs: list[str]) -> list[dict[int, float]]:
def __call__(self, docs: list[str]) -> list[SparseEmbedding]:
res: EmbeddingResponse = self.client.embedding(input=docs, model=self.name)
embeds = [r.embedding.model_dump() for r in res.data]
# convert sparse vector to {index: value} format
sparse_dicts = [{i: v for i, v in zip(e["indices"], e["values"])} for e in embeds]
return sparse_dicts
embeds = [SparseEmbedding.from_aurelio(r.embedding) for r in res.data]
return embeds

async def acall(self, docs: list[str]) -> list[dict[int, float]]:
async def acall(self, docs: list[str]) -> list[SparseEmbedding]:
res: EmbeddingResponse = await self.async_client.embedding(input=docs, model=self.name)
embeds = [r.embedding.model_dump() for r in res.data]
# convert sparse vector to {index: value} format
sparse_dicts = [{i: v for i, v in zip(e["indices"], e["values"])} for e in embeds]
return sparse_dicts
embeds = [SparseEmbedding.from_aurelio(r.embedding) for r in res.data]
return embeds

def fit(self, docs: List[str]):
raise NotImplementedError("AurelioSparseEncoder does not support fit.")
27 changes: 14 additions & 13 deletions semantic_router/index/pinecone.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from pydantic.v1 import BaseModel, Field

from semantic_router.index.base import BaseIndex
from semantic_router.schema import ConfigParameter
from semantic_router.schema import ConfigParameter, SparseEmbedding
from semantic_router.utils.logger import logger


Expand Down Expand Up @@ -243,8 +243,6 @@ def add(
sparse_embeddings: Optional[List[dict[int, float]]] = None,
):
"""Add vectors to Pinecone in batches."""
temp = "\n".join([f"{x[0]}: {x[1]}" for x in zip(routes, utterances)])
logger.warning("TEMP | add:\n" + temp)
if self.index is None:
self.dimensions = self.dimensions or len(embeddings[0])
self.index = self._init_index(force_create=True)
Expand Down Expand Up @@ -272,10 +270,6 @@ def add(
self._batch_upsert(batch)

def _remove_and_sync(self, routes_to_delete: dict):
temp = "\n".join(
[f"{route}: {utterances}" for route, utterances in routes_to_delete.items()]
)
logger.warning("TEMP | _remove_and_sync:\n" + temp)
for route, utterances in routes_to_delete.items():
remote_routes = self._get_routes_with_ids(route_name=route)
ids_to_delete = [
Expand Down Expand Up @@ -364,6 +358,7 @@ def query(
vector: np.ndarray,
top_k: int = 5,
route_filter: Optional[List[str]] = None,
sparse_vector: dict[int, float] | SparseEmbedding | None = None,
**kwargs: Any,
) -> Tuple[np.ndarray, List[str]]:
"""Search the index for the query vector and return the top_k results.
Expand All @@ -374,10 +369,10 @@ def query(
:type top_k: int, optional
:param route_filter: A list of route names to filter the search results, defaults to None.
:type route_filter: Optional[List[str]], optional
:param sparse_vector: An optional sparse vector to include in the query.
:type sparse_vector: Optional[SparseEmbedding]
:param kwargs: Additional keyword arguments for the query, including sparse_vector.
:type kwargs: Any
:keyword sparse_vector: An optional sparse vector to include in the query.
:type sparse_vector: Optional[dict]
:return: A tuple containing an array of scores and a list of route names.
:rtype: Tuple[np.ndarray, List[str]]
:raises ValueError: If the index is not populated.
Expand All @@ -389,9 +384,13 @@ def query(
filter_query = {"sr_route": {"$in": route_filter}}
else:
filter_query = None
if sparse_vector is not None:
if isinstance(sparse_vector, dict):
sparse_vector = SparseEmbedding.from_dict(sparse_vector)
sparse_vector = sparse_vector.to_pinecone()
results = self.index.query(
vector=[query_vector_list],
sparse_vector=kwargs.get("sparse_vector", None),
sparse_vector=sparse_vector,
top_k=top_k,
filter=filter_query,
include_metadata=True,
Expand Down Expand Up @@ -653,6 +652,8 @@ async def _async_fetch_metadata(self, vector_id: str) -> dict:
)

def __len__(self):
return self.index.describe_index_stats()["namespaces"][self.namespace][
"vector_count"
]
namespace_stats = self.index.describe_index_stats()["namespaces"].get(self.namespace)
if namespace_stats:
return namespace_stats["vector_count"]
else:
return 0
Loading

0 comments on commit 87d6ae4

Please sign in to comment.