From 7dad501b01cf558a4229a18d457de525d63d3e03 Mon Sep 17 00:00:00 2001 From: SothanaV Date: Thu, 25 Apr 2024 16:51:10 +0700 Subject: [PATCH] [FIX] bug find_nearest_neighbours --- ampligraph/discovery/discovery.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ampligraph/discovery/discovery.py b/ampligraph/discovery/discovery.py index 9ff7f711..928bda87 100644 --- a/ampligraph/discovery/discovery.py +++ b/ampligraph/discovery/discovery.py @@ -1227,8 +1227,8 @@ def find_nearest_neighbours(kge_model, entities, n_neighbors=10, entities_subset all_neighbors_emb = kge_model.get_embeddings(entities_subset) all_neighbors = entities_subset else: - all_neighbors_emb = kge_model.trained_model_params[0] - all_neighbors = list(kge_model.ent_to_idx.keys()) + all_neighbors = list(kge_model.data_indexer.backend.get_all_entities()) + all_neighbors_emb = kge_model.get_embeddings(all_neighbors) assert n_neighbors < len(all_neighbors), 'n_neighbors must be less than the number of entities being fit!' knn_model = NearestNeighbors(n_neighbors=n_neighbors, metric=metric).fit(all_neighbors_emb)