From 4ad6cd40023da776427a065033920260181407ef Mon Sep 17 00:00:00 2001
From: Sam Horsfield <s.horsfield19@imperial.ac.uk>
Date: Fri, 25 Oct 2024 10:49:32 +0100
Subject: [PATCH] Fixes KeyError during refinding

---
 panaroo_runner/find_missing.py     |  8 +++---
 panaroo_runner/generate_network.py | 16 ++++++-----
 src/graph.cpp                      | 43 +++++++++++++-----------------
 src/graph.h                        |  2 +-
 4 files changed, 32 insertions(+), 37 deletions(-)

diff --git a/panaroo_runner/find_missing.py b/panaroo_runner/find_missing.py
index 79db128..85036f9 100644
--- a/panaroo_runner/find_missing.py
+++ b/panaroo_runner/find_missing.py
@@ -136,8 +136,6 @@ def search_graph(search_pair,
     graph_existing_shm = shared_memory.SharedMemory(name=graph_shd_arr_tup.name)
     graph_shd_arr = np.ndarray(graph_shd_arr_tup.shape, dtype=graph_shd_arr_tup.dtype, buffer=graph_existing_shm.buf)
 
-    # sort items to preserve order
-    conflicts = {k: v for k, v in sorted(dicts["conflicts"].items(), key=lambda item: item[0])}
     node_search_dict = dicts["searches"]
 
     node_locs = {}
@@ -145,11 +143,11 @@ def search_graph(search_pair,
     # keep track of regions already with genes to avoid re-traversal
     to_avoid = set()
 
-    # mask regions that already have genes
-    for node, ORF_ID in conflicts.items():
+    # mask regions that already have genes, iterating over nodes in sorted key order
+    for node in sorted(dicts["conflicts"].keys()):
         
         # read in ORF information
-        ORF_info = ORF_map[ORF_ID]
+        ORF_info = ORF_map[dicts["conflicts"][node]]
 
         # determine sequence overlap of ORFs
         for i, node_coords in enumerate(ORF_info[1]):
diff --git a/panaroo_runner/generate_network.py b/panaroo_runner/generate_network.py
index 184d2fa..48b1dbc 100644
--- a/panaroo_runner/generate_network.py
+++ b/panaroo_runner/generate_network.py
@@ -7,7 +7,7 @@
 def generate_network(DBG, overlap, ORF_file_paths, Edge_file_paths, cluster_file):
     # read in cluster_dict
     # TODO save pair here that holds ORFs removed for low scores after centroid scored
-    cluster_dict, ORFs_to_remove = ggCaller_cpp.read_cluster_file(cluster_file)
+    cluster_dict, ORFs_present = ggCaller_cpp.read_cluster_file(cluster_file)
 
     # associate sequences with their clusters
     seq_to_cluster = {}
@@ -36,7 +36,7 @@ def generate_network(DBG, overlap, ORF_file_paths, Edge_file_paths, cluster_file
                 pan_centroid_ID = str(colour_ID) + "_0_" + str(ORF_ID)
 
                 # make sure ORF wasn't removed after centroid scored
-                if pan_centroid_ID in ORFs_to_remove:
+                if ORF_ID not in ORFs_present[colour_ID]:
                     continue
 
                 # add information to cluster_centroid_data
@@ -57,12 +57,12 @@ def generate_network(DBG, overlap, ORF_file_paths, Edge_file_paths, cluster_file
                     pan_ORF_id = str(genome_id) + "_0_" + str(local_id)
 
                     # make sure ORF wasn't removed after centroid scored
-                    if pan_ORF_id in ORFs_to_remove:
+                    if local_id not in ORFs_present[genome_id]:
                         continue
 
                     # only hold lengths of genes that are not in a cluster
-                    if ORF_ID_str in ORF_length_map:
-                        del ORF_length_map[ORF_ID_str]
+                    if pan_ORF_id in ORF_length_map:
+                        del ORF_length_map[pan_ORF_id]
 
                     # index sequences to clusters and the number of edges they have
                     seq_to_cluster[pan_ORF_id] = [cluster_id, 0]
@@ -85,6 +85,9 @@ def generate_network(DBG, overlap, ORF_file_paths, Edge_file_paths, cluster_file
 
             pan_ORF_id = str(genome_id) + "_0_" + str(local_id)
 
+            if local_id not in ORFs_present[genome_id]:
+                continue
+
             if pan_ORF_id in ORF_length_map:
                 new_centroid = False
                 length, hash = ORF_length_map[pan_ORF_id]
@@ -127,7 +130,8 @@ def generate_network(DBG, overlap, ORF_file_paths, Edge_file_paths, cluster_file
 
 
     # clear cluster_dict
-    cluster_dict.clear()
+    del cluster_dict
+    del ORFs_present
 
     # determine paralogs if required
     paralogs = set()
diff --git a/src/graph.cpp b/src/graph.cpp
index 136429c..b6729c8 100644
--- a/src/graph.cpp
+++ b/src/graph.cpp
@@ -363,7 +363,7 @@ std::pair<std::map<size_t, std::string>, std::map<size_t, std::string>> Graph::f
     cout << endl;
 
     // keep track of all genes that are low scoring
-    std::unordered_set<std::string> ORFs_to_remove;
+    std::unordered_map<size_t, std::unordered_set<int>> ORFs_present;
 
     // generate clusters if required
     if (clustering || !no_filter)
@@ -512,8 +512,6 @@ std::pair<std::map<size_t, std::string>, std::map<size_t, std::string>> Graph::f
                     ia >> ORF_map;
                 }
 
-                std::unordered_set<std::string> ORFs_to_remove_private;
-
                 // remove all low scoring ORFs if present in colour
                 const auto& removal = to_remove.find(colour_ID);
                 if (removal != to_remove.end())
@@ -550,7 +548,6 @@ std::pair<std::map<size_t, std::string>, std::map<size_t, std::string>> Graph::f
                         if (std::get<4>(ORF_info) < minimum_ORF_score)
                         {
                             ORF_map.erase(ORF_ID.first);
-                            ORFs_to_remove_private.insert(ORF_ID_str);
                         }
                     }
                 }
@@ -567,7 +564,6 @@ std::pair<std::map<size_t, std::string>, std::map<size_t, std::string>> Graph::f
                 #pragma omp critical
                 {
                     bar2.update();
-                    ORFs_to_remove.insert(ORFs_to_remove_private.begin(), ORFs_to_remove_private.end());
                 }
             }
         }
@@ -606,7 +602,7 @@ std::pair<std::map<size_t, std::string>, std::map<size_t, std::string>> Graph::f
                 ia >> ORF_map;
             }
 
-            std::unordered_set<std::string> ORFs_to_remove_private;
+            std::unordered_set<int> ORFs_present_private;
 
             // get whether colour is reference or not
             bool is_ref = ((bool)_RefSet[colour_ID]) ? true : false;
@@ -661,18 +657,13 @@ std::pair<std::map<size_t, std::string>, std::map<size_t, std::string>> Graph::f
                             {
                                 // simplify ORF_info
                                 simplify_ORFNodeVector(ORF_map[ORF_ID], overlap);
-                                // issue is here, genes in cluster_map are lost if are too low scoring
                                 gene_map[ORF_ID] = std::move(ORF_map[ORF_ID]);
+
+                                // keep track of genes that are present
+                                ORFs_present_private.insert(ORF_ID);
                             }
                         }
                     }
-
-                    // go over remaining ORFs and add to ORFs_to_remove
-                    for (const auto& ORF_entry : ORF_map)
-                    {
-                        std::string ORF_ID_str = std::to_string(colour_ID) + "_" + std::to_string(ORF_entry.first);
-                        ORFs_to_remove_private.insert(ORF_ID_str);
-                    }
                 } else
                 {
                     // return unfiltered genes
@@ -682,6 +673,7 @@ std::pair<std::map<size_t, std::string>, std::map<size_t, std::string>> Graph::f
                         simplify_ORFNodeVector(entry.second, overlap);
                         gene_map[entry.first] = std::move(entry.second);
                         gene_paths.push_back({entry.first});
+                        ORFs_present_private.insert(entry.first);
                     }
                 }
             } else
@@ -693,6 +685,7 @@ std::pair<std::map<size_t, std::string>, std::map<size_t, std::string>> Graph::f
                     simplify_ORFNodeVector(entry.second, overlap);
                     gene_map[entry.first] = std::move(entry.second);
                     gene_paths.push_back({entry.first});
+                    ORFs_present_private.insert(entry.first);
                 }
             }
 
@@ -789,18 +782,18 @@ std::pair<std::map<size_t, std::string>, std::map<size_t, std::string>> Graph::f
             #pragma omp critical
             {
                 bar.update();
-                ORFs_to_remove.insert(ORFs_to_remove_private.begin(), ORFs_to_remove_private.end());
+                ORFs_present[colour_ID] = std::move(ORFs_present_private);
             }
         }
     }
 
-    // write ORFs_to_remove
+    // write ORFs_present
     {
-        std::ofstream ofs(cluster_file + ".rem");
+        std::ofstream ofs(cluster_file + ".pres");
         boost::archive::text_oarchive oa(ofs);
         // write class instance to archive
 
-        oa << ORFs_to_remove;
+        oa << ORFs_present;
     }
 
     // add line for progress bar
@@ -927,24 +920,24 @@ void Graph::_index_graph (const std::vector<std::string>& stop_codons_for,
     _stop_freq= stop_codon_freq;
 }
 
-std::pair<ORFClusterMap, std::unordered_set<std::string>> read_cluster_file(const std::string& cluster_file)
+std::pair<ORFClusterMap, std::unordered_map<size_t, std::unordered_set<int>>> read_cluster_file(const std::string& cluster_file)
 {
-    ORFClusterMap cluster_pair;
-    std::unordered_set<std::string> ORFs_to_remove;
+    ORFClusterMap cluster_map;
+    std::unordered_map<size_t, std::unordered_set<int>> ORFs_present;
 
     {
         std::ifstream ifs(cluster_file);
         boost::archive::text_iarchive ia(ifs);
-        ia >> cluster_pair;
+        ia >> cluster_map;
     }
 
     {
-        std::ifstream ifs(cluster_file + ".rem");
+        std::ifstream ifs(cluster_file + ".pres");
         boost::archive::text_iarchive ia(ifs);
-        ia >> ORFs_to_remove;
+        ia >> ORFs_present;
     }
 
-    return std::make_pair(cluster_pair, ORFs_to_remove);
+    return std::make_pair(cluster_map, ORFs_present);
 }
 
 ORFNodeMap read_ORF_file(const std::string& ORF_file)
diff --git a/src/graph.h b/src/graph.h
index 1c142af..7379022 100644
--- a/src/graph.h
+++ b/src/graph.h
@@ -218,7 +218,7 @@ std::pair<RefindMap, bool> refind_gene(const size_t& colour_ID,
     boost::dynamic_bitset<> _RefSet;
 };
 
-std::pair<ORFClusterMap, std::unordered_set<std::string>> read_cluster_file(const std::string& cluster_file);
+std::pair<ORFClusterMap, std::unordered_map<size_t, std::unordered_set<int>>> read_cluster_file(const std::string& cluster_file);
 
 ORFNodeMap read_ORF_file(const std::string& ORF_file);