Skip to content

Commit

Permalink
fix: local index tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jamescalam committed Nov 16, 2024
1 parent b522b8f commit b3cc85d
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 11 deletions.
2 changes: 2 additions & 0 deletions semantic_router/index/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,8 @@ def delete_index(self):
Deletes the index, effectively clearing it and setting it to None.
"""
self.index = None
self.routes = None
self.utterances = None

def _get_indices_for_route(self, route_name: str):
"""Gets an array of indices for a specific route."""
Expand Down
39 changes: 28 additions & 11 deletions tests/unit/test_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from platform import python_version


PINECONE_SLEEP = 3
PINECONE_SLEEP = 6


def mock_encoder_call(utterances):
Expand Down Expand Up @@ -203,16 +203,15 @@ def test_initialization(self, routes, openai_encoder, index_cls, encoder_cls):
index = init_index(index_cls)
route_layer = RouteLayer(
encoder=encoder_cls(), routes=routes, index=index,
auto_sync="local" if index_cls is PineconeIndex else None,
top_k=10,
auto_sync="local", top_k=10,
)
if index_cls is PineconeIndex:
time.sleep(PINECONE_SLEEP) # allow for index to be populated

assert openai_encoder.score_threshold == 0.3
assert route_layer.score_threshold == 0.3
assert route_layer.top_k == 10
assert len(route_layer.index) if route_layer.index is not None else 0 == 5
assert len(route_layer.index) == 5
assert (
len(set(route_layer._get_route_names()))
if route_layer._get_route_names() is not None
Expand Down Expand Up @@ -292,7 +291,7 @@ def test_list_route_names(self, openai_encoder, routes, index_cls):
index = init_index(index_cls)
route_layer = RouteLayer(
encoder=openai_encoder, routes=routes, index=index,
auto_sync="local" if index_cls is PineconeIndex else None,
auto_sync="local",
)
if index_cls is PineconeIndex:
time.sleep(PINECONE_SLEEP) # allow for index to be populated
Expand Down Expand Up @@ -735,7 +734,10 @@ def test_with_unrecognized_route(self, openai_encoder, routes, index_cls):

def test_retrieve_with_text(self, openai_encoder, routes, index_cls):
index = init_index(index_cls)
route_layer = RouteLayer(encoder=openai_encoder, routes=routes, index=index)
route_layer = RouteLayer(
encoder=openai_encoder, routes=routes, index=index,
auto_sync="local",
)
text = "Hello"
results = route_layer.retrieve_multiple_routes(text=text)
assert len(results) >= 1, "Expected at least one result"
Expand All @@ -745,7 +747,10 @@ def test_retrieve_with_text(self, openai_encoder, routes, index_cls):

def test_retrieve_with_vector(self, openai_encoder, routes, index_cls):
index = init_index(index_cls)
route_layer = RouteLayer(encoder=openai_encoder, routes=routes, index=index)
route_layer = RouteLayer(
encoder=openai_encoder, routes=routes, index=index,
auto_sync="local",
)
vector = [0.1, 0.2, 0.3]
results = route_layer.retrieve_multiple_routes(vector=vector)
assert len(results) >= 1, "Expected at least one result"
Expand All @@ -755,13 +760,19 @@ def test_retrieve_with_vector(self, openai_encoder, routes, index_cls):

def test_retrieve_without_text_or_vector(self, openai_encoder, routes, index_cls):
index = init_index(index_cls)
route_layer = RouteLayer(encoder=openai_encoder, routes=routes, index=index)
route_layer = RouteLayer(
encoder=openai_encoder, routes=routes, index=index,
auto_sync="local",
)
with pytest.raises(ValueError, match="Either text or vector must be provided"):
route_layer.retrieve_multiple_routes()

def test_retrieve_no_matches(self, openai_encoder, routes, index_cls):
index = init_index(index_cls)
route_layer = RouteLayer(encoder=openai_encoder, routes=routes, index=index)
route_layer = RouteLayer(
encoder=openai_encoder, routes=routes, index=index,
auto_sync="local",
)
text = "Asparagus"
results = route_layer.retrieve_multiple_routes(text=text)
assert len(results) == 0, f"Expected no results, but got {len(results)}"
Expand Down Expand Up @@ -859,14 +870,20 @@ def test_update_utterances_not_implemented(self, openai_encoder, routes, index_c

class TestLayerFit:
def test_eval(self, openai_encoder, routes, test_data):
route_layer = RouteLayer(encoder=openai_encoder, routes=routes)
route_layer = RouteLayer(
encoder=openai_encoder, routes=routes,
auto_sync="local",
)
# unpack test data
X, y = zip(*test_data)
# evaluate
route_layer.evaluate(X=X, y=y, batch_size=int(len(test_data) / 5))

def test_fit(self, openai_encoder, routes, test_data):
route_layer = RouteLayer(encoder=openai_encoder, routes=routes)
route_layer = RouteLayer(
encoder=openai_encoder, routes=routes,
auto_sync="local",
)
# unpack test data
X, y = zip(*test_data)
route_layer.fit(X=X, y=y, batch_size=int(len(test_data) / 5))
Expand Down

0 comments on commit b3cc85d

Please sign in to comment.