diff --git a/semantic_router/index/base.py b/semantic_router/index/base.py index 884106c0..fee3c4e2 100644 --- a/semantic_router/index/base.py +++ b/semantic_router/index/base.py @@ -1,3 +1,5 @@ +from datetime import datetime +import time from typing import Any, List, Optional, Tuple, Union, Dict import json @@ -157,26 +159,87 @@ def delete_index(self): logger.warning("This method should be implemented by subclasses.") self.index = None - def _read_hash(self) -> ConfigParameter: - """ - Read the hash of the previously written index. + def _read_config(self, field: str, scope: str | None = None) -> ConfigParameter: + """Read a config parameter from the index. - This method should be implemented by subclasses. + :param field: The field to read. + :type field: str + :param scope: The scope to read. + :type scope: str | None + :return: The config parameter that was read. + :rtype: ConfigParameter """ logger.warning("This method should be implemented by subclasses.") return ConfigParameter( - field="sr_hash", + field=field, value="", - namespace="", + scope=scope, ) - def _write_config(self, config: ConfigParameter): + def _read_hash(self) -> ConfigParameter: + """Read the hash of the previously written index. + + :return: The config parameter that was read. + :rtype: ConfigParameter """ - Write a config parameter to the index. + return self._read_config(field="sr_hash") - This method should be implemented by subclasses. + def _write_config(self, config: ConfigParameter) -> ConfigParameter: + """Write a config parameter to the index. + + :param config: The config parameter to write. + :type config: ConfigParameter + :return: The config parameter that was written. + :rtype: ConfigParameter """ logger.warning("This method should be implemented by subclasses.") + return config + + def lock( + self, + value: bool, + wait: int = 0, + scope: str | None = None + ) -> ConfigParameter: + """Lock/unlock the index for a given scope (if applicable). If index + already locked/unlocked, raises ValueError. + + :param scope: The scope to lock. + :type scope: str | None + :param wait: The number of seconds to wait for the index to be unlocked, if + set to 0, will raise an error if index is already locked/unlocked. + :type wait: int + :return: The config parameter that was locked. + :rtype: ConfigParameter + """ + start_time = datetime.now() + while True: + if self._is_locked(scope=scope) != value: + # in this case, we can set the lock value + break + if (datetime.now() - start_time).total_seconds() > wait: + # wait for 2.5 seconds before checking again + time.sleep(2.5) + else: + raise ValueError(f"Index is already {'locked' if value else 'unlocked'}.") + lock_param = ConfigParameter( + field="sr_lock", + value=str(value), + scope=scope, + ) + self._write_config(lock_param) + return lock_param + + def _is_locked(self, scope: str | None = None) -> bool: + """Check if the index is locked for a given scope (if applicable). + + :param scope: The scope to check. + :type scope: str | None + :return: True if the index is locked, False otherwise. + :rtype: bool + """ + lock_config = self._read_config(field="sr_lock", scope=scope) + return bool(lock_config.value) def _get_all(self, prefix: Optional[str] = None, include_metadata: bool = False): """ diff --git a/semantic_router/index/pinecone.py b/semantic_router/index/pinecone.py index b4ba144e..469a4141 100644 --- a/semantic_router/index/pinecone.py +++ b/semantic_router/index/pinecone.py @@ -405,39 +405,43 @@ def query( route_names = [result["metadata"]["sr_route"] for result in results["matches"]] return np.array(scores), route_names - def _read_hash(self) -> ConfigParameter: + def _read_config(self, field: str, scope: str | None = None) -> ConfigParameter: + scope = scope or self.namespace if self.index is None: return ConfigParameter( - field="sr_hash", + field=field, value="", - namespace=self.namespace, + scope=scope, ) - hash_id = f"sr_hash#{self.namespace}" - hash_record = self.index.fetch( - ids=[hash_id], + config_id = f"{field}#{scope}" + config_record = self.index.fetch( + ids=[config_id], namespace="sr_config", ) - if hash_record["vectors"]: + if config_record["vectors"]: return ConfigParameter( - field="sr_hash", - value=hash_record["vectors"][hash_id]["metadata"]["value"], - created_at=hash_record["vectors"][hash_id]["metadata"]["created_at"], - namespace=self.namespace, + field=field, + value=config_record["vectors"][config_id]["metadata"]["value"], + created_at=config_record["vectors"][config_id]["metadata"][ + "created_at" + ], + scope=scope, ) else: - logger.warning("Configuration for hash parameter not found in index.") + logger.warning(f"Configuration for {field} parameter not found in index.") return ConfigParameter( - field="sr_hash", + field=field, value="", - namespace=self.namespace, + scope=scope, ) - def _write_config(self, config: ConfigParameter) -> None: + def _write_config(self, config: ConfigParameter) -> ConfigParameter: """Method to write a config parameter to the remote Pinecone index. :param config: The config parameter to write to the index. :type config: ConfigParameter """ + config.scope = config.scope or self.namespace if self.index is None: raise ValueError("Index has not been initialized.") if self.dimensions is None: @@ -446,6 +450,7 @@ def _write_config(self, config: ConfigParameter) -> None: vectors=[config.to_pinecone(dimensions=self.dimensions)], namespace="sr_config", ) + return config async def aquery( self, diff --git a/semantic_router/routers/base.py b/semantic_router/routers/base.py index 3628fda8..347c5db9 100644 --- a/semantic_router/routers/base.py +++ b/semantic_router/routers/base.py @@ -543,7 +543,7 @@ async def _async_retrieve_top_route( route = self.check_for_matching_routes(top_class) return route, top_class_scores - def sync(self, sync_mode: str, force: bool = False) -> List[str]: + def sync(self, sync_mode: str, force: bool = False, wait: int = 0) -> List[str]: """Runs a sync of the local routes with the remote index. :param sync_mode: The mode to sync the routes with the remote index. @@ -551,6 +551,10 @@ def sync(self, sync_mode: str, force: bool = False) -> List[str]: :param force: Whether to force the sync even if the local and remote hashes already match. Defaults to False. :type force: bool, optional + :param wait: The number of seconds to wait for the index to be unlocked + before proceeding with the sync. If set to 0, will raise an error if + index is already locked/unlocked. + :type wait: int :return: A list of diffs describing the addressed differences between the local and remote route layers. :rtype: List[str] @@ -565,7 +569,9 @@ def sync(self, sync_mode: str, force: bool = False) -> List[str]: remote_utterances=local_utterances, ) return diff.to_utterance_str() - # otherwise we continue with the sync, first creating a diff + # otherwise we continue with the sync, first locking the index + _ = self.index.lock(value=True, wait=wait) + # first creating a diff local_utterances = self.to_config().to_utterances() remote_utterances = self.index.get_utterances() diff = UtteranceDiff.from_utterances( @@ -576,6 +582,8 @@ def sync(self, sync_mode: str, force: bool = False) -> List[str]: sync_strategy = diff.get_sync_strategy(sync_mode=sync_mode) # and execute self._execute_sync_strategy(sync_strategy) + # unlock index after sync + _ = self.index.lock(value=False) return diff.to_utterance_str() def _execute_sync_strategy(self, strategy: Dict[str, Dict[str, List[Utterance]]]): @@ -781,6 +789,9 @@ def delete(self, route_name: str): :param route_name: the name of the route to be deleted :type str: """ + # ensure index is not locked + if self.index._is_locked(): + raise ValueError("Index is locked. Cannot delete route.") current_local_hash = self._get_hash() current_remote_hash = self.index._read_hash() if current_remote_hash.value == "": @@ -829,9 +840,13 @@ def _get_hash(self) -> ConfigParameter: return config.get_hash() def _write_hash(self) -> ConfigParameter: + # lock index before writing + _ = self.index.lock(value=True) config = self.to_config() hash_config = config.get_hash() self.index._write_config(config=hash_config) + # unlock index after writing + _ = self.index.lock(value=False) return hash_config def is_synced(self) -> bool: diff --git a/semantic_router/schema.py b/semantic_router/schema.py index 2a94b355..d4f89aa8 100644 --- a/semantic_router/schema.py +++ b/semantic_router/schema.py @@ -62,12 +62,11 @@ def __str__(self): class ConfigParameter(BaseModel): field: str value: str - namespace: Optional[str] = None - created_at: str = Field(default_factory=lambda: datetime.utcnow().isoformat()) + scope: Optional[str] = None + created_at: str = Field(default_factory=lambda: datetime.now(datetime.UTC).isoformat()) def to_pinecone(self, dimensions: int): - if self.namespace is None: - namespace = "" + namespace = self.scope or "" return { "id": f"{self.field}#{namespace}", "values": [0.1] * dimensions, diff --git a/tests/unit/test_sync.py b/tests/unit/test_sync.py index 8e73de34..a4e4539c 100644 --- a/tests/unit/test_sync.py +++ b/tests/unit/test_sync.py @@ -487,3 +487,86 @@ def test_auto_sync_merge(self, openai_encoder, routes, routes_2, index_cls): # clear index route_layer.index.index.delete(namespace="", delete_all=True) + + @pytest.mark.skipif( + os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required" + ) + def test_sync_lock_prevents_concurrent_sync(self, openai_encoder, routes, index_cls): + """Test that sync lock prevents concurrent synchronization operations""" + index = init_index(index_cls) + route_layer = SemanticRouter( + encoder=openai_encoder, + routes=routes, + index=index, + auto_sync=None, + ) + + # Acquire sync lock + route_layer.index.lock(value=True) + + # Attempt to sync while lock is held should raise exception + with pytest.raises(RuntimeError, match="Sync operation already in progress"): + route_layer.sync("local") + + # Release lock + route_layer.index.lock(value=False) + + # Should succeed after lock is released + route_layer.sync("local") + if index_cls is PineconeIndex: + time.sleep(PINECONE_SLEEP) + assert route_layer.is_synced() + + @pytest.mark.skipif( + os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required" + ) + def test_sync_lock_auto_releases(self, openai_encoder, routes, index_cls): + """Test that sync lock is automatically released after sync operations""" + index = init_index(index_cls) + route_layer = SemanticRouter( + encoder=openai_encoder, + routes=routes, + index=index, + auto_sync=None, + ) + + # Initial sync should acquire and release lock + route_layer.sync("local") + if index_cls is PineconeIndex: + time.sleep(PINECONE_SLEEP) + + # Lock should be released, allowing another sync + route_layer.sync("local") # Should not raise exception + if index_cls is PineconeIndex: + time.sleep(PINECONE_SLEEP) + assert route_layer.is_synced() + + @pytest.mark.skipif( + os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required" + ) + def test_sync_lock_releases_on_error(self, openai_encoder, routes, index_cls): + """Test that sync lock is released even if sync operation fails""" + index = init_index(index_cls) + route_layer = SemanticRouter( + encoder=openai_encoder, + routes=routes, + index=index, + auto_sync=None, + ) + + # Force an error during sync by temporarily breaking the index + original_sync = route_layer.index.sync + route_layer.index.sync = lambda *args, **kwargs: (_ for _ in ()).throw(Exception("Forced sync error")) + + # Sync should fail but release the lock + with pytest.raises(Exception, match="Forced sync error"): + route_layer.sync("local") + + # Restore original sync method + route_layer.index.sync = original_sync + + # Should be able to sync again since lock was released + route_layer.sync("local") + if index_cls is PineconeIndex: + time.sleep(PINECONE_SLEEP) + assert route_layer.is_synced()