diff --git a/semantic_router/index/base.py b/semantic_router/index/base.py
index 46d5abb3..1f4ccaef 100644
--- a/semantic_router/index/base.py
+++ b/semantic_router/index/base.py
@@ -4,7 +4,7 @@ import numpy as np
 from pydantic.v1 import BaseModel
 
-from semantic_router.schema import ConfigParameter
+from semantic_router.schema import ConfigParameter, Utterance
 from semantic_router.route import Route
 from semantic_router.utils.logger import logger
 
 
@@ -40,7 +40,7 @@ def add(
         """
         raise NotImplementedError("This method should be implemented by subclasses.")
 
-    def get_utterances(self) -> List[Tuple]:
+    def get_utterances(self) -> List[Utterance]:
         """Gets a list of route and utterance objects currently stored in the
         index, including additional metadata.
 
@@ -50,7 +50,7 @@ def get_utterances(self) -> List[Tuple]:
         """
         _, metadata = self._get_all(include_metadata=True)
         route_tuples = parse_route_info(metadata=metadata)
-        return route_tuples
+        return [Utterance.from_tuple(x) for x in route_tuples]
 
     def get_routes(self) -> List[Route]:
         """Gets a list of route objects currently stored in the index.
diff --git a/semantic_router/index/local.py b/semantic_router/index/local.py
index 00210613..faf24084 100644
--- a/semantic_router/index/local.py
+++ b/semantic_router/index/local.py
@@ -2,7 +2,7 @@
 
 import numpy as np
 
-from semantic_router.schema import ConfigParameter
+from semantic_router.schema import ConfigParameter, Utterance
 from semantic_router.index.base import BaseIndex
 from semantic_router.linear import similarity_matrix, top_scores
 from semantic_router.utils.logger import logger
@@ -61,7 +61,7 @@ def _sync_index(
         if self.sync is not None:
             logger.error("Sync remove is not implemented for LocalIndex.")
 
-    def get_utterances(self) -> List[Tuple]:
+    def get_utterances(self) -> List[Utterance]:
         """
         Gets a list of route and utterance objects currently stored in the index.
 
@@ -70,7 +70,9 @@ def get_utterances(self) -> List[Tuple]:
         """
         if self.routes is None or self.utterances is None:
             return []
-        return list(zip(self.routes, self.utterances))
+        return [
+            Utterance.from_tuple(x) for x in zip(self.routes, self.utterances)
+        ]
 
     def describe(self) -> Dict:
         return {
diff --git a/semantic_router/layer.py b/semantic_router/layer.py
index 0e2339e4..74b8b50c 100644
--- a/semantic_router/layer.py
+++ b/semantic_router/layer.py
@@ -15,7 +15,7 @@
 from semantic_router.index.local import LocalIndex
 from semantic_router.llms import BaseLLM, OpenAILLM
 from semantic_router.route import Route
-from semantic_router.schema import ConfigParameter, EncoderType, RouteChoice
+from semantic_router.schema import ConfigParameter, EncoderType, RouteChoice, Utterance, UtteranceDiff
 from semantic_router.utils.defaults import EncoderDefault
 from semantic_router.utils.logger import logger
 
@@ -218,31 +218,23 @@ def to_file(self, path: str):
             elif ext in [".yaml", ".yml"]:
                 yaml.safe_dump(self.to_dict(), f)
 
-    def _get_diff(self, other: "LayerConfig") -> List[str]:
-        """Get the difference between two LayerConfigs.
+    def to_utterances(self) -> List[Utterance]:
+        """Convert the routes to a list of Utterance objects.
 
-        :param other: The LayerConfig to compare to.
-        :type other: LayerConfig
-        :return: A list of differences between the two LayerConfigs.
-        :rtype: List[Dict[str, Any]]
+        :return: A list of Utterance objects.
+        :rtype: List[Utterance]
         """
-        # TODO: formalize diffs into likely LayerDiff objects that can then
-        # output different formats as required to enable smarter syncs
-        self_yaml = yaml.dump(self.to_dict())
-        other_yaml = yaml.dump(other.to_dict())
-        differ = Differ()
-        return list(differ.compare(self_yaml.splitlines(), other_yaml.splitlines()))
-
-    def show_diff(self, other: "LayerConfig") -> str:
-        """Show the difference between two LayerConfigs.
-
-        :param other: The LayerConfig to compare to.
-        :type other: LayerConfig
-        :return: A string showing the difference between the two LayerConfigs.
-        :rtype: str
-        """
-        diff = self._get_diff(other)
-        return "\n".join(diff)
+        utterances = []
+        for route in self.routes:
+            utterances.extend([
+                Utterance(
+                    route=route.name,
+                    utterance=x,
+                    function_schemas=route.function_schemas,
+                    metadata=route.metadata
+                ) for x in route.utterances
+            ])
+        return utterances
 
     def add(self, route: Route):
         self.routes.append(route)
@@ -283,6 +275,7 @@ def __init__(
         index: Optional[BaseIndex] = None,  # type: ignore
         top_k: int = 5,
         aggregation: str = "sum",
+        auto_sync: Optional[str] = None,
     ):
         self.index: BaseIndex = index if index is not None else LocalIndex()
         if encoder is None:
@@ -310,14 +303,17 @@ def __init__(
                 f"Unsupported aggregation method chosen: {aggregation}. Choose either 'SUM', 'MEAN', or 'MAX'."
             )
         self.aggregation_method = self._set_aggregation_method(self.aggregation)
+        self.auto_sync = auto_sync
 
         # set route score thresholds if not already set
         for route in self.routes:
             if route.score_threshold is None:
                 route.score_threshold = self.score_threshold
         # if routes list has been passed, we initialize index now
-        if self.index.sync:
+        if self.auto_sync:
             # initialize index now
+            dims = self.encoder.dimensions
+            self.index._init_index(force_create=True, dimensions=dims)
             if len(self.routes) > 0:
                 self._add_and_sync_routes(routes=self.routes)
             else:
@@ -447,6 +443,113 @@ def retrieve_multiple_routes(
 
         return route_choices
 
+    def sync(self, sync_mode: str, force: bool = False) -> List[str]:
+        """Runs a sync of the local routes with the remote index.
+
+        :param sync_mode: The mode to sync the routes with the remote index.
+        :type sync_mode: str
+        :param force: Whether to force the sync even if the local and remote
+            hashes already match. Defaults to False.
+        :type force: bool, optional
+        :return: A list of diffs describing the addressed differences between
+            the local and remote route layers.
+        :rtype: List[str]
+        """
+        if not force and self.is_synced():
+            logger.warning("Local and remote route layers are already synchronized.")
+            # create utterance diff to return, but just using local instance
+            # for speed
+            local_utterances = self.to_config().to_utterances()
+            diff = UtteranceDiff.from_utterances(
+                local_utterances=local_utterances,
+                remote_utterances=local_utterances,
+            )
+            return diff.to_utterance_str()
+        # otherwise we continue with the sync, first creating a diff
+        local_utterances = self.to_config().to_utterances()
+        remote_utterances = self.index.get_utterances()
+        diff = UtteranceDiff.from_utterances(
+            local_utterances=local_utterances,
+            remote_utterances=remote_utterances,
+        )
+        # generate sync strategy
+        sync_strategy = diff.get_sync_strategy(sync_mode=sync_mode)
+        # and execute
+        self._execute_sync_strategy(sync_strategy)
+        return diff.to_utterance_str()
+
+    def _execute_sync_strategy(self, strategy: Dict[str, Dict[str, List[Utterance]]]):
+        """Executes the provided sync strategy, either deleting or upserting
+        routes from the local and remote instances as defined in the strategy.
+
+        :param strategy: The sync strategy to execute.
+        :type strategy: Dict[str, Dict[str, List[Utterance]]]
+        """
+        if strategy["remote"]["delete"]:
+            data_to_delete = {}  # type: ignore
+            for utt_obj in strategy["remote"]["delete"]:
+                data_to_delete.setdefault(
+                    utt_obj.route, []
+                ).append(utt_obj.utterance)
+            self.index._remove_and_sync(data_to_delete)
+        if strategy["remote"]["upsert"]:
+            utterances_text = [utt.utterance for utt in strategy["remote"]["upsert"]]
+            self.index.add(
+                embeddings=self.encoder(utterances_text),
+                routes=[utt.route for utt in strategy["remote"]["upsert"]],
+                utterances=utterances_text,
+                function_schemas=[utt.function_schemas for utt in strategy["remote"]["upsert"]],
+                metadata_list=[utt.metadata for utt in strategy["remote"]["upsert"]],
+            )
+        if strategy["local"]["delete"]:
+            self._local_delete(utterances=strategy["local"]["delete"])
+        if strategy["local"]["upsert"]:
+            self._local_upsert(utterances=strategy["local"]["upsert"])
+        # update hash
+        self._write_hash()
+
+    def _local_upsert(self, utterances: List[Utterance]):
+        """Adds utterances to the RouteLayer, creating any missing routes.
+
+        :param utterances: The utterances to add to the local RouteLayer.
+        :type utterances: List[Utterance]
+        """
+        new_routes = {route.name: route for route in self.routes}
+        for utt_obj in utterances:
+            if utt_obj.route not in new_routes:
+                new_routes[utt_obj.route] = Route(
+                    name=utt_obj.route,
+                    utterances=[utt_obj.utterance],
+                    function_schemas=utt_obj.function_schemas,
+                    metadata=utt_obj.metadata
+                )
+            elif utt_obj.utterance not in new_routes[utt_obj.route].utterances:
+                new_routes[utt_obj.route].utterances.append(utt_obj.utterance)
+        self.routes = list(new_routes.values())
+
+    def _local_delete(self, utterances: List[Utterance]):
+        """Deletes the given utterances from the local RouteLayer.
+
+        :param utterances: The utterances to delete from the local RouteLayer.
+        :type utterances: List[Utterance]
+        """
+        to_delete: Dict[str, set] = {}
+        for utt in utterances:
+            to_delete.setdefault(utt.route, set()).add(utt.utterance)
+        remaining = []
+        for route in self.routes:
+            doomed = to_delete.get(route.name)
+            if doomed:
+                route.utterances = [
+                    utt for utt in route.utterances if utt not in doomed
+                ]
+                if not route.utterances:
+                    # nothing left in this route, so drop it locally
+                    continue
+            remaining.append(route)
+        self.routes = remaining
+
+
     def _retrieve_top_route(
         self, vector: List[float], route_filter: Optional[List[str]] = None
     ) -> Tuple[Optional[Route], List[float]]:
@@ -735,97 +838,27 @@ def get_utterance_diff(self) -> List[str]:
         "route2: utterance4", which do not exist locally.
         """
         # first we get remote and local utterances
-        remote_utterances = [f"{x[0]}: {x[1]}" for x in self.index.get_utterances()]
-        local_routes, local_utterance_arr, _ = self._extract_routes_details(
-            self.routes, include_metadata=False
-        )
-        local_utterances = [
-            f"{x[0]}: {x[1]}" for x in zip(local_routes, local_utterance_arr)
-        ]
-        # sort local and remote utterances
-        local_utterances.sort()
-        remote_utterances.sort()
-        # now get diff
-        differ = Differ()
-        diff = list(differ.compare(local_utterances, remote_utterances))
-        return diff
-
-    def _add_and_sync_routes(self, routes: List[Route]):
-        # get current local hash
-        current_local_hash = self._get_hash()
-        current_remote_hash = self.index._read_hash()
-        if current_remote_hash.value == "":
-            # if remote hash is empty, the index is to be initialized
-            current_remote_hash = current_local_hash
-        # create embeddings for all routes and sync at startup with remote ones based on sync setting
-        local_route_names, local_utterances, local_function_schemas, local_metadata = (
-            self._extract_routes_details(routes, include_metadata=True)
-        )
+        remote_utterances = self.index.get_utterances()
+        local_utterances = self.to_config().to_utterances()
 
-        routes_to_add, routes_to_delete, layer_routes_dict = self.index._sync_index(
-            local_route_names,
-            local_utterances,
-            local_function_schemas,
-            local_metadata,
-            dimensions=self.index.dimensions or len(self.encoder(["dummy"])[0]),
+        diff_obj = UtteranceDiff.from_utterances(
+            local_utterances=local_utterances, remote_utterances=remote_utterances
         )
+        return diff_obj.to_utterance_str()
 
-        data_to_delete = {}  # type: ignore
-        for route, utterance in routes_to_delete:
-            data_to_delete.setdefault(route, []).append(utterance)
-        self.index._remove_and_sync(data_to_delete)
-
-        # Prepare data for addition
-        if routes_to_add:
-            (
-                route_names_to_add,
-                all_utterances_to_add,
-                function_schemas_to_add,
-                metadata_to_add,
-            ) = map(list, zip(*routes_to_add))
-        else:
-            (
-                route_names_to_add,
-                all_utterances_to_add,
-                function_schemas_to_add,
-                metadata_to_add,
-            ) = ([], [], [], [])
-
-        embedded_utterances_to_add = (
-            self.encoder(all_utterances_to_add) if all_utterances_to_add else []
-        )
+    def _add_and_sync_routes(self, routes: List[Route]):
+        # __init__ passes self.routes itself: never extend a list with itself
+        if routes is not self.routes:
+            self.routes.extend(routes)
+        # first we get remote and local utterances
+        remote_utterances = self.index.get_utterances()
+        local_utterances = self.to_config().to_utterances()
 
-        self.index.add(
-            embeddings=embedded_utterances_to_add,
-            routes=route_names_to_add,
-            utterances=all_utterances_to_add,
-            function_schemas=function_schemas_to_add,
-            metadata_list=metadata_to_add,
+        diff_obj = UtteranceDiff.from_utterances(
+            local_utterances=local_utterances, remote_utterances=remote_utterances
         )
-
-        # Update local route layer state
-        self.routes = []
-        for route, data in layer_routes_dict.items():
-            function_schemas = data.get("function_schemas", None)
-            if function_schemas is not None:
-                function_schemas = [function_schemas]
-            self.routes.append(
-                Route(
-                    name=route,
-                    utterances=data.get("utterances", []),
-                    function_schemas=function_schemas,
-                    metadata=data.get("metadata", {}),
-                )
-            )
-        # update hash IF index and local hash were aligned
-        if current_local_hash.value == current_remote_hash.value:
-            self._write_hash()
-        else:
-            logger.warning(
-                "Local and remote route layers were not aligned. Remote hash "
-                "not updated. Use `RouteLayer.get_utterance_diff()` to see "
-                "details."
-            )
+        sync_strategy = diff_obj.get_sync_strategy(sync_mode=self.auto_sync)
+        self._execute_sync_strategy(strategy=sync_strategy)
 
     def _extract_routes_details(
         self, routes: List[Route], include_metadata: bool = False
diff --git a/semantic_router/schema.py b/semantic_router/schema.py
index 8d87f017..634cd25b 100644
--- a/semantic_router/schema.py
+++ b/semantic_router/schema.py
@@ -1,6 +1,7 @@
 from datetime import datetime
+from difflib import Differ
 from enum import Enum
-from typing import List, Optional, Union, Any, Dict
+from typing import List, Optional, Union, Any, Dict, Tuple
 
 from pydantic.v1 import BaseModel, Field
 
@@ -18,6 +19,10 @@ class EncoderType(Enum):
     GOOGLE = "google"
     BEDROCK = "bedrock"
 
+    @staticmethod
+    def to_list():
+        return [encoder.value for encoder in EncoderType]
+
 
 class EncoderInfo(BaseModel):
     name: str
@@ -86,6 +91,247 @@ def to_pinecone(self, dimensions: int):
         }
 
 
+class Utterance(BaseModel):
+    route: str
+    utterance: str
+    function_schemas: Optional[List[Dict]] = None
+    metadata: Optional[Dict] = None
+    diff_tag: str = " "
+
+    @classmethod
+    def from_tuple(cls, tuple_obj: Tuple):
+        """Create an Utterance object from a tuple. The tuple must contain
+        route and utterance as the first two elements. Then optionally
+        function schemas and metadata as the third and fourth elements
+        respectively. If this order is not followed an invalid Utterance
+        object will be returned.
+
+        :param tuple_obj: A tuple containing route, utterance, function schemas and metadata.
+        :type tuple_obj: Tuple
+        :return: An Utterance object.
+        :rtype: Utterance
+        """
+        route, utterance = tuple_obj[0], tuple_obj[1]
+        function_schemas = tuple_obj[2] if len(tuple_obj) > 2 else None
+        metadata = tuple_obj[3] if len(tuple_obj) > 3 else None
+        return cls(
+            route=route,
+            utterance=utterance,
+            function_schemas=function_schemas,
+            metadata=metadata
+        )
+
+    def to_tuple(self):
+        return (
+            self.route,
+            self.utterance,
+            self.function_schemas,
+            self.metadata,
+        )
+
+    def to_str(self, include_metadata: bool = False):
+        if include_metadata:
+            return f"{self.route}: {self.utterance} | {self.function_schemas} | {self.metadata}"
+        return f"{self.route}: {self.utterance}"
+
+    def to_diff_str(self):
+        return f"{self.diff_tag} {self.to_str()}"
+
+
+class SyncMode(Enum):
+    """Synchronization modes for local (route layer) and remote (index)
+    instances.
+    """
+    ERROR = "error"
+    REMOTE = "remote"
+    LOCAL = "local"
+    MERGE_FORCE_REMOTE = "merge-force-remote"
+    MERGE_FORCE_LOCAL = "merge-force-local"
+    MERGE = "merge"
+
+    @staticmethod
+    def to_list() -> List[str]:
+        return [mode.value for mode in SyncMode]
+
+class UtteranceDiff(BaseModel):
+    diff: List[Utterance]
+
+    @classmethod
+    def from_utterances(
+        cls,
+        local_utterances: List[Utterance],
+        remote_utterances: List[Utterance]
+    ):
+        """Build the diff between local and remote utterances.
+
+        :param local_utterances: Utterances held by the local route layer.
+        :param remote_utterances: Utterances stored in the remote index.
+        :return: One entry per utterance, tagged " ", "-" (local only) or
+            "+" (remote only).
+        :rtype: UtteranceDiff
+        """
+        local_utterances_map = {x.to_str(): x for x in local_utterances}
+        remote_utterances_map = {x.to_str(): x for x in remote_utterances}
+        # sort local and remote utterances
+        local_utterances_str = list(local_utterances_map.keys())
+        local_utterances_str.sort()
+        remote_utterances_str = list(remote_utterances_map.keys())
+        remote_utterances_str.sort()
+        # get diff
+        differ = Differ()
+        diff_obj = list(differ.compare(local_utterances_str, remote_utterances_str))
+        # create UtteranceDiff list
+        utterance_diffs = []
+        for line in diff_obj:
+            utterance_diff_tag, utterance_str = line[0], line[2:]
+            if utterance_diff_tag == "?":
+                # Differ emits "? " intraline hint rows for similar lines
+                continue
+            source_map = remote_utterances_map if utterance_diff_tag == "+" else local_utterances_map
+            # copy so the caller's Utterance objects are not mutated
+            utterance_diffs.append(
+                source_map[utterance_str].copy(update={"diff_tag": utterance_diff_tag})
+            )
+        return cls(diff=utterance_diffs)
+
+    def to_utterance_str(self) -> List[str]:
+        """Outputs the utterance diff as a list of diff strings. Returns a list
+        of strings showing what is different in the remote when compared to the
+        local. For example:
+
+        ["  route1: utterance1",
+         "  route1: utterance2",
+         "- route2: utterance3",
+         "- route2: utterance4"]
+
+        Tells us that the remote is missing "route2: utterance3" and "route2:
+        utterance4", which do exist locally. If we see:
+
+        ["  route1: utterance1",
+         "  route1: utterance2",
+         "+ route2: utterance3",
+         "+ route2: utterance4"]
+
+        This diff tells us that the remote has "route2: utterance3" and
+        "route2: utterance4", which do not exist locally.
+        """
+        return [x.to_diff_str() for x in self.diff]
+
+    def get_tag(self, diff_tag: str) -> List[Utterance]:
+        """Get all utterances with a given diff tag.
+
+        :param diff_tag: The diff tag to filter by. Must be one of "+", "-", or
+            " ".
+        :type diff_tag: str
+        :return: A list of Utterance objects.
+        :rtype: List[Utterance]
+        """
+        if diff_tag not in ["+", "-", " "]:
+            raise ValueError("diff_tag must be one of '+', '-', or ' '")
+        return [x for x in self.diff if x.diff_tag == diff_tag]
+
+    def get_sync_strategy(self, sync_mode: str) -> dict:
+        """Generates the optimal synchronization plan for local and remote
+        instances.
+
+        :param sync_mode: The mode to sync the routes with the remote index.
+        :type sync_mode: str
+        :return: A dictionary describing the synchronization strategy.
+        :rtype: dict
+        """
+        if sync_mode not in SyncMode.to_list():
+            raise ValueError(f"sync_mode must be one of {SyncMode.to_list()}")
+        local_only = self.get_tag("-")
+        remote_only = self.get_tag("+")
+        local_and_remote = self.get_tag(" ")
+        if sync_mode == "error":
+            if len(local_only) > 0 or len(remote_only) > 0:
+                raise ValueError(
+                    "There are utterances that exist in the local or remote "
+                    "instance that do not exist in the other instance. Please "
+                    "sync the routes before running this command."
+                )
+            else:
+                return {
+                    "remote": {
+                        "upsert": [],
+                        "delete": []
+                    },
+                    "local": {
+                        "upsert": [],
+                        "delete": []
+                    }
+                }
+        elif sync_mode == "local":
+            return {
+                "remote": {
+                    "upsert": local_only,
+                    "delete": remote_only
+                },
+                "local": {
+                    "upsert": [],
+                    "delete": []
+                }
+            }
+        elif sync_mode == "remote":
+            return {
+                "remote": {
+                    "upsert": [],
+                    "delete": []
+                },
+                "local": {
+                    "upsert": remote_only,
+                    "delete": local_only
+                }
+            }
+        elif sync_mode == "merge-force-remote":
+            # get set of route names that exist in both local and remote
+            routes_in_both = set([utt.route for utt in local_and_remote])
+            # get remote utterances that belong to routes_in_both
+            remote_to_keep = [utt for utt in remote_only if utt.route in routes_in_both]
+            # get remote utterances that do NOT belong to routes_in_both
+            remote_to_delete = [utt for utt in remote_only if utt.route not in routes_in_both]
+            return {
+                "remote": {
+                    "upsert": local_only,
+                    "delete": remote_to_delete
+                },
+                "local": {
+                    "upsert": remote_to_keep,
+                    "delete": []
+                }
+            }
+        elif sync_mode == "merge-force-local":
+            # get set of route names that exist in both local and remote
+            routes_in_both = set([utt.route for utt in local_and_remote])
+            # get local utterances that belong to routes_in_both
+            local_to_keep = [utt for utt in local_only if utt.route in routes_in_both]
+            # get local utterances that do NOT belong to routes_in_both
+            local_to_delete = [utt for utt in local_only if utt.route not in routes_in_both]
+            return {
+                "remote": {
+                    "upsert": local_to_keep,
+                    "delete": []
+                },
+                "local": {
+                    "upsert": remote_only,
+                    "delete": local_to_delete
+                }
+            }
+        elif sync_mode == "merge":
+            return {
+                "remote": {
+                    "upsert": local_only,
+                    "delete": []
+                },
+                "local": {
+                    "upsert": remote_only,
+                    "delete": []
+                }
+            }
+
+
+
 class Metric(Enum):
     COSINE = "cosine"
     DOTPRODUCT = "dotproduct"
diff --git a/tests/unit/test_layer.py b/tests/unit/test_layer.py
index 0f81476b..8d337f2f 100644
--- a/tests/unit/test_layer.py
+++ b/tests/unit/test_layer.py
@@ -238,6 +238,7 @@ def test_initialization_dynamic_route(
         assert route_layer_openai.score_threshold == 0.3
 
     def test_delete_index(self, openai_encoder, routes, index_cls):
+        # TODO merge .delete_index() and .delete_all() and get working
         index = init_index(index_cls)
         route_layer = RouteLayer(encoder=openai_encoder, routes=routes, index=index)
         if index_cls is PineconeIndex:
diff --git a/tests/unit/test_sync.py b/tests/unit/test_sync.py
index 7c89f143..c4877151 100644
--- a/tests/unit/test_sync.py
+++ b/tests/unit/test_sync.py
@@ -237,3 +237,101 @@ def test_utterance_diff(self, openai_encoder, routes, routes_2, index_cls):
             assert "+ Route 2: Bye" in diff
             assert "+ Route 2: Goodbye" in diff
             assert "  Route 2: Hi" in diff
+
+    @pytest.mark.skipif(
+        os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required"
+    )
+    def test_auto_sync_local(self, openai_encoder, routes, routes_2, routes_4, index_cls):
+        if index_cls is PineconeIndex:
+            # TEST LOCAL
+            pinecone_index = init_index(index_cls)
+            route_layer = RouteLayer(
+                encoder=openai_encoder, routes=routes_2, index=pinecone_index,
+                auto_sync="local"
+            )
+            time.sleep(PINECONE_SLEEP)  # allow for index to be populated
+            assert [utt.to_tuple() for utt in route_layer.index.get_utterances()] == [
+                ("Route 1", "Hello", None, {}),
+                ("Route 2", "Hi", None, {}),
+            ], "The routes in the index should match the local routes"
+
+    @pytest.mark.skipif(
+        os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required"
+    )
+    def test_auto_sync_remote(self, openai_encoder, routes, index_cls):
+        if index_cls is PineconeIndex:
+
+            # TEST REMOTE
+            pinecone_index = init_index(index_cls)
+            route_layer = RouteLayer(
+                encoder=openai_encoder, routes=routes, index=pinecone_index,
+                auto_sync="remote"
+            )
+
+            time.sleep(PINECONE_SLEEP)  # allow for index to be populated
+            assert [utt.to_tuple() for utt in route_layer.index.get_utterances()] == [
+                ("Route 1", "Hello", None, {}),
+                ("Route 2", "Hi", None, {}),
+            ], "The routes in the index should match the local routes"
+
+    @pytest.mark.skipif(
+        os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required"
+    )
+    def test_auto_sync_merge_force_remote(self, openai_encoder, routes, index_cls):
+        if index_cls is PineconeIndex:
+            # TEST MERGE FORCE REMOTE
+            pinecone_index = init_index(index_cls)
+            route_layer = RouteLayer(
+                encoder=openai_encoder, routes=routes, index=pinecone_index,
+                auto_sync="merge-force-remote"
+            )
+
+            time.sleep(PINECONE_SLEEP)  # allow for index to be populated
+            assert [utt.to_tuple() for utt in route_layer.index.get_utterances()] == [
+                ("Route 1", "Hello", None, {}),
+                ("Route 2", "Hi", None, {}),
+            ], "The routes in the index should match the local routes"
+
+    @pytest.mark.skipif(
+        os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required"
+    )
+    def test_auto_sync_merge_force_local(self, openai_encoder, routes, index_cls):
+        if index_cls is PineconeIndex:
+            # TEST MERGE FORCE LOCAL
+            pinecone_index = init_index(index_cls)
+            route_layer = RouteLayer(
+                encoder=openai_encoder, routes=routes, index=pinecone_index,
+                auto_sync="merge-force-local"
+            )
+
+            time.sleep(PINECONE_SLEEP)  # allow for index to be populated
+            assert [utt.to_tuple() for utt in route_layer.index.get_utterances()] == [
+                ("Route 1", "Hello", None, {"type": "default"}),
+                ("Route 1", "Hi", None, {"type": "default"}),
+                ("Route 2", "Bye", None, {}),
+                ("Route 2", "Au revoir", None, {}),
+                ("Route 2", "Goodbye", None, {}),
+            ], "The routes in the index should match the local routes"
+
+    @pytest.mark.skipif(
+        os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required"
+    )
+    def test_auto_sync_merge(self, openai_encoder, routes_4, index_cls):
+        if index_cls is PineconeIndex:
+            # TEST MERGE
+            pinecone_index = init_index(index_cls)
+            route_layer = RouteLayer(
+                encoder=openai_encoder, routes=routes_4, index=pinecone_index,
+                auto_sync="merge"
+            )
+
+            time.sleep(PINECONE_SLEEP)  # allow for index to be populated
+            assert [utt.to_tuple() for utt in route_layer.index.get_utterances()] == [
+                ("Route 1", "Hello", None, {"type": "default"}),
+                ("Route 1", "Hi", None, {"type": "default"}),
+                ("Route 1", "Goodbye", None, {"type": "default"}),
+                ("Route 2", "Bye", None, {}),
+                ("Route 2", "Asparagus", None, {}),
+                ("Route 2", "Au revoir", None, {}),
+                ("Route 2", "Goodbye", None, {}),
+            ], "The routes in the index should match the local routes"