Skip to content

Commit

Permalink
feat: continued refactoring for sync features
Browse files Browse the repository at this point in the history
  • Loading branch information
jamescalam committed Nov 12, 2024
1 parent e27f444 commit 4054692
Show file tree
Hide file tree
Showing 6 changed files with 471 additions and 118 deletions.
6 changes: 3 additions & 3 deletions semantic_router/index/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np
from pydantic.v1 import BaseModel

from semantic_router.schema import ConfigParameter
from semantic_router.schema import ConfigParameter, Utterance
from semantic_router.route import Route
from semantic_router.utils.logger import logger

Expand Down Expand Up @@ -40,7 +40,7 @@ def add(
"""
raise NotImplementedError("This method should be implemented by subclasses.")

def get_utterances(self) -> List[Tuple]:
def get_utterances(self) -> List[Utterance]:
"""Gets a list of route and utterance objects currently stored in the
index, including additional metadata.
Expand All @@ -50,7 +50,7 @@ def get_utterances(self) -> List[Tuple]:
"""
_, metadata = self._get_all(include_metadata=True)
route_tuples = parse_route_info(metadata=metadata)
return route_tuples
return [Utterance.from_tuple(x) for x in route_tuples]

def get_routes(self) -> List[Route]:
"""Gets a list of route objects currently stored in the index.
Expand Down
8 changes: 5 additions & 3 deletions semantic_router/index/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy as np

from semantic_router.schema import ConfigParameter
from semantic_router.schema import ConfigParameter, Utterance
from semantic_router.index.base import BaseIndex
from semantic_router.linear import similarity_matrix, top_scores
from semantic_router.utils.logger import logger
Expand Down Expand Up @@ -61,7 +61,7 @@ def _sync_index(
if self.sync is not None:
logger.error("Sync remove is not implemented for LocalIndex.")

def get_utterances(self) -> List[Tuple]:
def get_utterances(self) -> List[Utterance]:
"""
Gets a list of route and utterance objects currently stored in the index.
Expand All @@ -70,7 +70,9 @@ def get_utterances(self) -> List[Tuple]:
"""
if self.routes is None or self.utterances is None:
return []
return list(zip(self.routes, self.utterances))
return [
Utterance.from_tuple(x) for x in zip(self.routes, self.utterances)
]

def describe(self) -> Dict:
return {
Expand Down
242 changes: 131 additions & 111 deletions semantic_router/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from semantic_router.index.local import LocalIndex
from semantic_router.llms import BaseLLM, OpenAILLM
from semantic_router.route import Route
from semantic_router.schema import ConfigParameter, EncoderType, RouteChoice
from semantic_router.schema import ConfigParameter, EncoderType, RouteChoice, Utterance, UtteranceDiff
from semantic_router.utils.defaults import EncoderDefault
from semantic_router.utils.logger import logger

Expand Down Expand Up @@ -218,31 +218,23 @@ def to_file(self, path: str):
elif ext in [".yaml", ".yml"]:
yaml.safe_dump(self.to_dict(), f)

def _get_diff(self, other: "LayerConfig") -> List[str]:
"""Get the difference between two LayerConfigs.
def to_utterances(self) -> List[Utterance]:
"""Convert the routes to a list of Utterance objects.
:param other: The LayerConfig to compare to.
:type other: LayerConfig
:return: A list of differences between the two LayerConfigs.
:rtype: List[Dict[str, Any]]
:return: A list of Utterance objects.
:rtype: List[Utterance]
"""
# TODO: formalize diffs into likely LayerDiff objects that can then
# output different formats as required to enable smarter syncs
self_yaml = yaml.dump(self.to_dict())
other_yaml = yaml.dump(other.to_dict())
differ = Differ()
return list(differ.compare(self_yaml.splitlines(), other_yaml.splitlines()))

def show_diff(self, other: "LayerConfig") -> str:
"""Show the difference between two LayerConfigs.
:param other: The LayerConfig to compare to.
:type other: LayerConfig
:return: A string showing the difference between the two LayerConfigs.
:rtype: str
"""
diff = self._get_diff(other)
return "\n".join(diff)
utterances = []
for route in self.routes:
utterances.extend([
Utterance(
route=route.name,
utterance=x,
function_schemas=route.function_schemas,
metadata=route.metadata
) for x in route.utterances
])
return utterances

def add(self, route: Route):
self.routes.append(route)
Expand Down Expand Up @@ -283,6 +275,7 @@ def __init__(
index: Optional[BaseIndex] = None, # type: ignore
top_k: int = 5,
aggregation: str = "sum",
auto_sync: Optional[str] = None,
):
self.index: BaseIndex = index if index is not None else LocalIndex()
if encoder is None:
Expand Down Expand Up @@ -310,14 +303,17 @@ def __init__(
f"Unsupported aggregation method chosen: {aggregation}. Choose either 'SUM', 'MEAN', or 'MAX'."
)
self.aggregation_method = self._set_aggregation_method(self.aggregation)
self.auto_sync = auto_sync

# set route score thresholds if not already set
for route in self.routes:
if route.score_threshold is None:
route.score_threshold = self.score_threshold
# if routes list has been passed, we initialize index now
if self.index.sync:
if self.auto_sync:
# initialize index now
dims = self.encoder.dimensions
self.index._init_index(force_create=True, dimensions=dims)
if len(self.routes) > 0:
self._add_and_sync_routes(routes=self.routes)
else:
Expand Down Expand Up @@ -447,6 +443,100 @@ def retrieve_multiple_routes(

return route_choices

def sync(self, sync_mode: str, force: bool = False) -> List[str]:
"""Runs a sync of the local routes with the remote index.
:param sync_mode: The mode to sync the routes with the remote index.
:type sync_mode: str
:param force: Whether to force the sync even if the local and remote
hashes already match. Defaults to False.
:type force: bool, optional
:return: A list of diffs describing the addressed differences between
the local and remote route layers.
:rtype: List[str]
"""
if not force and self.is_synced():
logger.warning("Local and remote route layers are already synchronized.")
# create utterance diff to return, but just using local instance
# for speed
local_utterances = self.to_config().to_utterances()
diff = UtteranceDiff.from_utterances(
local_utterances=local_utterances,
remote_utterances=local_utterances,
)
return diff.to_utterance_str()
# otherwise we continue with the sync, first creating a diff
local_utterances = self.to_config().to_utterances()
remote_utterances = self.index.get_utterances()
diff = UtteranceDiff.from_utterances(
local_utterances=local_utterances,
remote_utterances=remote_utterances,
)
# generate sync strategy
sync_strategy = diff.to_sync_strategy()
# and execute
self._execute_sync_strategy(sync_strategy)
return diff.to_utterance_str()

def _execute_sync_strategy(self, strategy: Dict[str, Dict[str, List[Utterance]]]):
"""Executes the provided sync strategy, either deleting or upserting
routes from the local and remote instances as defined in the strategy.
:param strategy: The sync strategy to execute.
:type strategy: Dict[str, Dict[str, List[Utterance]]]
"""
if strategy["remote"]["delete"]:
data_to_delete = {} # type: ignore
for utt_obj in strategy["remote"]["delete"]:
data_to_delete.setdefault(
utt_obj.route, []
).append(utt_obj.utterance)
self.index._remove_and_sync(data_to_delete)
if strategy["remote"]["upsert"]:
utterances_text = [utt.utterance for utt in strategy["remote"]["upsert"]]
self.index.add(
embeddings=self.encoder(utterances_text),
routes=[utt.route for utt in strategy["remote"]["upsert"]],
utterances=utterances_text,
function_schemas=[utt.function_schemas for utt in strategy["remote"]["upsert"]],
metadata_list=[utt.metadata for utt in strategy["remote"]["upsert"]],
)
if strategy["local"]["delete"]:
self._local_delete(utterances=strategy["local"]["delete"])
if strategy["local"]["upsert"]:
self._local_upsert(utterances=strategy["local"]["upsert"])
# update hash
self._write_hash()

def _local_upsert(self, utterances: List[Utterance]):
"""Adds new routes to the RouteLayer.
:param utterances: The utterances to add to the local RouteLayer.
:type utterances: List[Utterance]
"""
new_routes = {}
for utt_obj in utterances:
if utt_obj.route not in new_routes.keys():
new_routes[utt_obj.route] = Route(
name=utt_obj.route,
utterances=[utt_obj.utterance],
function_schemas=utt_obj.function_schemas,
metadata=utt_obj.metadata
)
else:
new_routes[utt_obj.route].utterances.append(utt_obj.utterance)
self.routes.extend(list(new_routes.values()))

def _local_delete(self, utterances: List[Utterance]):
"""Deletes routes from the local RouteLayer.
:param utterances: The utterances to delete from the local RouteLayer.
:type utterances: List[Utterance]
"""
route_names = set([utt.route for utt in utterances])
self.routes = [route for route in self.routes if route.name not in route_names]


def _retrieve_top_route(
self, vector: List[float], route_filter: Optional[List[str]] = None
) -> Tuple[Optional[Route], List[float]]:
Expand Down Expand Up @@ -735,97 +825,27 @@ def get_utterance_diff(self) -> List[str]:
"route2: utterance4", which do not exist locally.
"""
# first we get remote and local utterances
remote_utterances = [f"{x[0]}: {x[1]}" for x in self.index.get_utterances()]
local_routes, local_utterance_arr, _ = self._extract_routes_details(
self.routes, include_metadata=False
)
local_utterances = [
f"{x[0]}: {x[1]}" for x in zip(local_routes, local_utterance_arr)
]
# sort local and remote utterances
local_utterances.sort()
remote_utterances.sort()
# now get diff
differ = Differ()
diff = list(differ.compare(local_utterances, remote_utterances))
return diff

def _add_and_sync_routes(self, routes: List[Route]):
# get current local hash
current_local_hash = self._get_hash()
current_remote_hash = self.index._read_hash()
if current_remote_hash.value == "":
# if remote hash is empty, the index is to be initialized
current_remote_hash = current_local_hash
# create embeddings for all routes and sync at startup with remote ones based on sync setting
local_route_names, local_utterances, local_function_schemas, local_metadata = (
self._extract_routes_details(routes, include_metadata=True)
)
remote_utterances = self.index.get_utterances()
local_utterances = self.to_config().to_utterances()

routes_to_add, routes_to_delete, layer_routes_dict = self.index._sync_index(
local_route_names,
local_utterances,
local_function_schemas,
local_metadata,
dimensions=self.index.dimensions or len(self.encoder(["dummy"])[0]),
diff_obj = UtteranceDiff.from_utterances(
local_utterances=local_utterances, remote_utterances=remote_utterances
)
return diff_obj.to_utterance_str()

data_to_delete = {} # type: ignore
for route, utterance in routes_to_delete:
data_to_delete.setdefault(route, []).append(utterance)
self.index._remove_and_sync(data_to_delete)

# Prepare data for addition
if routes_to_add:
(
route_names_to_add,
all_utterances_to_add,
function_schemas_to_add,
metadata_to_add,
) = map(list, zip(*routes_to_add))
else:
(
route_names_to_add,
all_utterances_to_add,
function_schemas_to_add,
metadata_to_add,
) = ([], [], [], [])

embedded_utterances_to_add = (
self.encoder(all_utterances_to_add) if all_utterances_to_add else []
)
def _add_and_sync_routes(self, routes: List[Route]):
self.routes.extend(routes)
# first we get remote and local utterances
remote_utterances = self.index.get_utterances()
local_utterances = self.to_config().to_utterances()

self.index.add(
embeddings=embedded_utterances_to_add,
routes=route_names_to_add,
utterances=all_utterances_to_add,
function_schemas=function_schemas_to_add,
metadata_list=metadata_to_add,
diff_obj = UtteranceDiff.from_utterances(
local_utterances=local_utterances, remote_utterances=remote_utterances
)

# Update local route layer state
self.routes = []
for route, data in layer_routes_dict.items():
function_schemas = data.get("function_schemas", None)
if function_schemas is not None:
function_schemas = [function_schemas]
self.routes.append(
Route(
name=route,
utterances=data.get("utterances", []),
function_schemas=function_schemas,
metadata=data.get("metadata", {}),
)
)
# update hash IF index and local hash were aligned
if current_local_hash.value == current_remote_hash.value:
self._write_hash()
else:
logger.warning(
"Local and remote route layers were not aligned. Remote hash "
"not updated. Use `RouteLayer.get_utterance_diff()` to see "
"details."
)
sync_strategy = diff_obj.get_sync_strategy(sync_mode=self.auto_sync)
self._execute_sync_strategy(strategy=sync_strategy)
# update remote hash
self._write_hash()

def _extract_routes_details(
self, routes: List[Route], include_metadata: bool = False
Expand Down
Loading

0 comments on commit 4054692

Please sign in to comment.