Skip to content

Commit

Permalink
refactor: update PineconeIndex to use aiohttp for async requests and …
Browse files Browse the repository at this point in the history
…improve error handling

- Replaced async_client with aiohttp.ClientSession for making asynchronous HTTP requests.
- Added headers for API requests to enhance security and compatibility.
- Simplified error handling by removing checks for async_client initialization.
- Improved code readability and maintainability by consolidating request logic.
  • Loading branch information
italianconcerto committed Jan 23, 2025
1 parent fbf5a64 commit c8e7a62
Showing 1 changed file with 81 additions and 57 deletions.
138 changes: 81 additions & 57 deletions semantic_router/index/pinecone.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,10 +153,14 @@ def __init__(
raise ValueError("Pinecone API key is required.")

self.client = self._initialize_client(api_key=self.api_key)
if init_async_index:
self.async_client = self._initialize_async_client(api_key=self.api_key)
else:
self.async_client = None

self.api_key = api_key
self.headers = {
"Api-Key": self.api_key,
"Content-Type": "application/json",
"X-Pinecone-API-Version": "2024-07",
"User-Agent": "source_tag=semanticrouter",
}
# try initializing index
self.index = self._init_index()

Expand Down Expand Up @@ -659,8 +663,8 @@ async def aquery(
:rtype: Tuple[np.ndarray, List[str]]
:raises ValueError: If the index is not populated.
"""
if self.async_client is None or self.host == "":
raise ValueError("Async client or host are not initialized.")
if self.host == "":
raise ValueError("Host is not initialized.")
query_vector_list = vector.tolist()
if route_filter is not None:
filter_query = {"sr_route": {"$in": route_filter}}
Expand Down Expand Up @@ -693,8 +697,8 @@ async def aget_routes(self) -> list[tuple]:
:return: A list of (route_name, utterance) objects.
:rtype: List[Tuple]
"""
if self.async_client is None or self.host == "":
raise ValueError("Async client or host are not initialized.")
if self.host == "":
raise ValueError("Host is not initialized.")

return await self._async_get_routes()

Expand Down Expand Up @@ -722,15 +726,21 @@ async def _async_query(
}
if self.host == "":
raise ValueError("self.host is not initialized.")
async with self.async_client.post(
f"https://{self.host}/query",
json=params,
) as response:
return await response.json(content_type=None)
async with aiohttp.ClientSession() as session:
async with session.post(
f"https://{self.host}/query",
json=params,
headers=self.headers,
) as response:
return await response.json(content_type=None)

async def _async_list_indexes(self):
async with self.async_client.get(f"{self.base_url}/indexes") as response:
return await response.json(content_type=None)
async with aiohttp.ClientSession() as session:
async with session.get(
f"{self.base_url}/indexes",
headers=self.headers,
) as response:
return await response.json(content_type=None)

async def _async_upsert(
self,
Expand All @@ -741,12 +751,14 @@ async def _async_upsert(
"vectors": vectors,
"namespace": namespace,
}
async with self.async_client.post(
f"https://{self.host}/vectors/upsert",
json=params,
) as response:
res = await response.json(content_type=None)
return res
async with aiohttp.ClientSession() as session:
async with session.post(
f"https://{self.host}/vectors/upsert",
json=params,
headers=self.headers,
) as response:
res = await response.json(content_type=None)
return res

async def _async_create_index(
self,
Expand All @@ -762,26 +774,34 @@ async def _async_create_index(
"metric": metric,
"spec": {"serverless": {"cloud": cloud, "region": region}},
}
async with self.async_client.post(
f"{self.base_url}/indexes",
json=params,
) as response:
return await response.json(content_type=None)
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/indexes",
json=params,
headers=self.headers,
) as response:
return await response.json(content_type=None)

async def _async_delete(self, ids: list[str], namespace: str = ""):
params = {
"ids": ids,
"namespace": namespace,
}
async with self.async_client.post(
f"https://{self.host}/vectors/delete",
json=params,
) as response:
return await response.json(content_type=None)
async with aiohttp.ClientSession() as session:
async with session.post(
f"https://{self.host}/vectors/delete",
json=params,
headers=self.headers,
) as response:
return await response.json(content_type=None)

async def _async_describe_index(self, name: str):
async with self.async_client.get(f"{self.base_url}/indexes/{name}") as response:
return await response.json(content_type=None)
async with aiohttp.ClientSession() as session:
async with session.get(
f"{self.base_url}/indexes/{name}",
headers=self.headers,
) as response:
return await response.json(content_type=None)

async def _async_get_all(
self, prefix: Optional[str] = None, include_metadata: bool = False
Expand Down Expand Up @@ -819,13 +839,16 @@ async def _async_get_all(
if next_page_token:
params["paginationToken"] = next_page_token

async with self.async_client.get(
list_url, params=params, headers={"Api-Key": self.api_key}
) as response:
if response.status != 200:
error_text = await response.text()
logger.error(f"Error fetching vectors: {error_text}")
break
async with aiohttp.ClientSession() as session:
async with session.get(
list_url,
params=params,
headers=self.headers,
) as response:
if response.status != 200:
error_text = await response.text()
logger.error(f"Error fetching vectors: {error_text}")
break

response_data = await response.json(content_type=None)

Expand Down Expand Up @@ -877,23 +900,24 @@ async def _async_fetch_metadata(
"Api-Key": self.api_key,
}

async with self.async_client.get(
url, params=params, headers=headers
) as response:
if response.status != 200:
error_text = await response.text()
logger.error(f"Error fetching metadata: {error_text}")
return {}

try:
response_data = await response.json(content_type=None)
except Exception as e:
logger.warning(f"No metadata found for vector {vector_id}: {e}")
return {}

return (
response_data.get("vectors", {}).get(vector_id, {}).get("metadata", {})
)
async with aiohttp.ClientSession() as session:
async with session.get(url, params=params, headers=headers) as response:
if response.status != 200:
error_text = await response.text()
logger.error(f"Error fetching metadata: {error_text}")
return {}

try:
response_data = await response.json(content_type=None)
except Exception as e:
logger.warning(f"No metadata found for vector {vector_id}: {e}")
return {}

return (
response_data.get("vectors", {})
.get(vector_id, {})
.get("metadata", {})
)

def __len__(self):
namespace_stats = self.index.describe_index_stats()["namespaces"].get(
Expand Down

0 comments on commit c8e7a62

Please sign in to comment.