Skip to content

Commit

Permalink
feat: add sync lock
Browse files Browse the repository at this point in the history
  • Loading branch information
jamescalam committed Dec 13, 2024
1 parent a0f6192 commit 4830499
Show file tree
Hide file tree
Showing 5 changed files with 195 additions and 30 deletions.
81 changes: 72 additions & 9 deletions semantic_router/index/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from datetime import datetime
import time
from typing import Any, List, Optional, Tuple, Union, Dict
import json

Expand Down Expand Up @@ -157,26 +159,87 @@ def delete_index(self):
logger.warning("This method should be implemented by subclasses.")
self.index = None

def _read_hash(self) -> ConfigParameter:
"""
Read the hash of the previously written index.
def _read_config(self, field: str, scope: str | None = None) -> ConfigParameter:
"""Read a config parameter from the index.
This method should be implemented by subclasses.
:param field: The field to read.
:type field: str
:param scope: The scope to read.
:type scope: str | None
:return: The config parameter that was read.
:rtype: ConfigParameter
"""
logger.warning("This method should be implemented by subclasses.")
return ConfigParameter(
field="sr_hash",
field=field,
value="",
namespace="",
scope=scope,
)

def _write_config(self, config: ConfigParameter):
def _read_hash(self) -> ConfigParameter:
"""Read the hash of the previously written index.
:return: The config parameter that was read.
:rtype: ConfigParameter
"""
Write a config parameter to the index.
return self._read_config(field="sr_hash")

This method should be implemented by subclasses.
def _write_config(self, config: ConfigParameter) -> ConfigParameter:
"""Write a config parameter to the index.
:param config: The config parameter to write.
:type config: ConfigParameter
:return: The config parameter that was written.
:rtype: ConfigParameter
"""
logger.warning("This method should be implemented by subclasses.")
return config

def lock(
self,
value: bool,
wait: int = 0,
scope: str | None = None
) -> ConfigParameter:
"""Lock/unlock the index for a given scope (if applicable). If index
already locked/unlocked, raises ValueError.
:param scope: The scope to lock.
:type scope: str | None
:param wait: The number of seconds to wait for the index to be unlocked, if
set to 0, will raise an error if index is already locked/unlocked.
:type wait: int
:return: The config parameter that was locked.
:rtype: ConfigParameter
"""
start_time = datetime.now()
while True:
if self._is_locked(scope=scope) != value:
# in this case, we can set the lock value
break
if (datetime.now() - start_time).total_seconds() > wait:
# wait for 2.5 seconds before checking again
time.sleep(2.5)
else:
raise ValueError(f"Index is already {'locked' if value else 'unlocked'}.")
lock_param = ConfigParameter(
field="sr_lock",
value=str(value),
scope=scope,
)
self._write_config(lock_param)
return lock_param

def _is_locked(self, scope: str | None = None) -> bool:
"""Check if the index is locked for a given scope (if applicable).
:param scope: The scope to check.
:type scope: str | None
:return: True if the index is locked, False otherwise.
:rtype: bool
"""
lock_config = self._read_config(field="sr_lock", scope=scope)
return bool(lock_config.value)

def _get_all(self, prefix: Optional[str] = None, include_metadata: bool = False):
"""
Expand Down
35 changes: 20 additions & 15 deletions semantic_router/index/pinecone.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,39 +405,43 @@ def query(
route_names = [result["metadata"]["sr_route"] for result in results["matches"]]
return np.array(scores), route_names

def _read_hash(self) -> ConfigParameter:
def _read_config(self, field: str, scope: str | None = None) -> ConfigParameter:
scope = scope or self.namespace
if self.index is None:
return ConfigParameter(
field="sr_hash",
field=field,
value="",
namespace=self.namespace,
scope=scope,
)
hash_id = f"sr_hash#{self.namespace}"
hash_record = self.index.fetch(
ids=[hash_id],
config_id = f"{field}#{scope}"
config_record = self.index.fetch(
ids=[config_id],
namespace="sr_config",
)
if hash_record["vectors"]:
if config_record["vectors"]:
return ConfigParameter(
field="sr_hash",
value=hash_record["vectors"][hash_id]["metadata"]["value"],
created_at=hash_record["vectors"][hash_id]["metadata"]["created_at"],
namespace=self.namespace,
field=field,
value=config_record["vectors"][config_id]["metadata"]["value"],
created_at=config_record["vectors"][config_id]["metadata"][
"created_at"
],
scope=scope,
)
else:
logger.warning("Configuration for hash parameter not found in index.")
logger.warning(f"Configuration for {field} parameter not found in index.")
return ConfigParameter(
field="sr_hash",
field=field,
value="",
namespace=self.namespace,
scope=scope,
)

def _write_config(self, config: ConfigParameter) -> None:
def _write_config(self, config: ConfigParameter) -> ConfigParameter:
"""Method to write a config parameter to the remote Pinecone index.
:param config: The config parameter to write to the index.
:type config: ConfigParameter
"""
config.scope = config.scope or self.namespace
if self.index is None:
raise ValueError("Index has not been initialized.")
if self.dimensions is None:
Expand All @@ -446,6 +450,7 @@ def _write_config(self, config: ConfigParameter) -> None:
vectors=[config.to_pinecone(dimensions=self.dimensions)],
namespace="sr_config",
)
return config

async def aquery(
self,
Expand Down
19 changes: 17 additions & 2 deletions semantic_router/routers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,14 +543,18 @@ async def _async_retrieve_top_route(
route = self.check_for_matching_routes(top_class)
return route, top_class_scores

def sync(self, sync_mode: str, force: bool = False) -> List[str]:
def sync(self, sync_mode: str, force: bool = False, wait: int = 0) -> List[str]:
"""Runs a sync of the local routes with the remote index.
:param sync_mode: The mode to sync the routes with the remote index.
:type sync_mode: str
:param force: Whether to force the sync even if the local and remote
hashes already match. Defaults to False.
:type force: bool, optional
:param wait: The number of seconds to wait for the index to be unlocked
before proceeding with the sync. If set to 0, will raise an error if
index is already locked/unlocked.
:type wait: int
:return: A list of diffs describing the addressed differences between
the local and remote route layers.
:rtype: List[str]
Expand All @@ -565,7 +569,9 @@ def sync(self, sync_mode: str, force: bool = False) -> List[str]:
remote_utterances=local_utterances,
)
return diff.to_utterance_str()
# otherwise we continue with the sync, first creating a diff
# otherwise we continue with the sync, first locking the index
_ = self.index.lock(value=True, wait=wait)
# first creating a diff
local_utterances = self.to_config().to_utterances()
remote_utterances = self.index.get_utterances()
diff = UtteranceDiff.from_utterances(
Expand All @@ -576,6 +582,8 @@ def sync(self, sync_mode: str, force: bool = False) -> List[str]:
sync_strategy = diff.get_sync_strategy(sync_mode=sync_mode)
# and execute
self._execute_sync_strategy(sync_strategy)
# unlock index after sync
_ = self.index.lock(value=False)
return diff.to_utterance_str()

def _execute_sync_strategy(self, strategy: Dict[str, Dict[str, List[Utterance]]]):
Expand Down Expand Up @@ -781,6 +789,9 @@ def delete(self, route_name: str):
:param route_name: the name of the route to be deleted
:type str:
"""
# ensure index is not locked
if self.index._is_locked():
raise ValueError("Index is locked. Cannot delete route.")
current_local_hash = self._get_hash()
current_remote_hash = self.index._read_hash()
if current_remote_hash.value == "":
Expand Down Expand Up @@ -829,9 +840,13 @@ def _get_hash(self) -> ConfigParameter:
return config.get_hash()

def _write_hash(self) -> ConfigParameter:
# lock index before writing
_ = self.index.lock(value=True)
config = self.to_config()
hash_config = config.get_hash()
self.index._write_config(config=hash_config)
# unlock index after writing
_ = self.index.lock(value=False)
return hash_config

def is_synced(self) -> bool:
Expand Down
7 changes: 3 additions & 4 deletions semantic_router/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,11 @@ def __str__(self):
class ConfigParameter(BaseModel):
field: str
value: str
namespace: Optional[str] = None
created_at: str = Field(default_factory=lambda: datetime.utcnow().isoformat())
scope: Optional[str] = None
created_at: str = Field(default_factory=lambda: datetime.now(datetime.UTC).isoformat())

def to_pinecone(self, dimensions: int):
if self.namespace is None:
namespace = ""
namespace = self.scope or ""
return {
"id": f"{self.field}#{namespace}",
"values": [0.1] * dimensions,
Expand Down
83 changes: 83 additions & 0 deletions tests/unit/test_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,3 +487,86 @@ def test_auto_sync_merge(self, openai_encoder, routes, routes_2, index_cls):

# clear index
route_layer.index.index.delete(namespace="", delete_all=True)

@pytest.mark.skipif(
os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required"
)
def test_sync_lock_prevents_concurrent_sync(self, openai_encoder, routes, index_cls):
"""Test that sync lock prevents concurrent synchronization operations"""
index = init_index(index_cls)
route_layer = SemanticRouter(
encoder=openai_encoder,
routes=routes,
index=index,
auto_sync=None,
)

# Acquire sync lock
route_layer.index.lock(value=True)

# Attempt to sync while lock is held should raise exception
with pytest.raises(RuntimeError, match="Sync operation already in progress"):
route_layer.sync("local")

# Release lock
route_layer.index.lock(value=False)

# Should succeed after lock is released
route_layer.sync("local")
if index_cls is PineconeIndex:
time.sleep(PINECONE_SLEEP)
assert route_layer.is_synced()

@pytest.mark.skipif(
os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required"
)
def test_sync_lock_auto_releases(self, openai_encoder, routes, index_cls):
"""Test that sync lock is automatically released after sync operations"""
index = init_index(index_cls)
route_layer = SemanticRouter(
encoder=openai_encoder,
routes=routes,
index=index,
auto_sync=None,
)

# Initial sync should acquire and release lock
route_layer.sync("local")
if index_cls is PineconeIndex:
time.sleep(PINECONE_SLEEP)

# Lock should be released, allowing another sync
route_layer.sync("local") # Should not raise exception
if index_cls is PineconeIndex:
time.sleep(PINECONE_SLEEP)
assert route_layer.is_synced()

@pytest.mark.skipif(
os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required"
)
def test_sync_lock_releases_on_error(self, openai_encoder, routes, index_cls):
"""Test that sync lock is released even if sync operation fails"""
index = init_index(index_cls)
route_layer = SemanticRouter(
encoder=openai_encoder,
routes=routes,
index=index,
auto_sync=None,
)

# Force an error during sync by temporarily breaking the index
original_sync = route_layer.index.sync
route_layer.index.sync = lambda *args, **kwargs: (_ for _ in ()).throw(Exception("Forced sync error"))

# Sync should fail but release the lock
with pytest.raises(Exception, match="Forced sync error"):
route_layer.sync("local")

# Restore original sync method
route_layer.index.sync = original_sync

# Should be able to sync again since lock was released
route_layer.sync("local")
if index_cls is PineconeIndex:
time.sleep(PINECONE_SLEEP)
assert route_layer.is_synced()

0 comments on commit 4830499

Please sign in to comment.