Skip to content

Commit

Permalink
Merge pull request #221 from medema-group/hotfix/dense-AP
Browse files Browse the repository at this point in the history
Hotfix/dense ap
  • Loading branch information
nlouwen authored Dec 12, 2024
2 parents 7b7b637 + 5633326 commit de8e34c
Show file tree
Hide file tree
Showing 14 changed files with 95 additions and 36 deletions.
2 changes: 1 addition & 1 deletion big_scape/benchmarking/benchmark_data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def load_computed_bs2_labels(self, data_path: Path) -> None:
run_time = datetime(1, 1, 1)
for dt_str in run_times:
date, time = dt_str.split("_")
day, month, year = map(int, date.split("-"))
year, month, day = map(int, date.split("-"))
hour, minute, second = map(int, time.split("-"))
current_dt = datetime(year, month, day, hour, minute, second)
if current_dt > run_time:
Expand Down
4 changes: 4 additions & 0 deletions big_scape/cli/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ class BigscapeConfig:

# CLUSTER
PREFERENCE: float = 0.0
DENSITY: float = 0.85
DENSE_PREFERENCE: float = -5.0

# TREE
TOP_FREQS: int = 3
Expand Down Expand Up @@ -194,6 +196,8 @@ def parse_config(config_file_path: Path, log_path: Optional[Path] = None) -> Non

# CLUSTER
BigscapeConfig.PREFERENCE = config["PREFERENCE"]
BigscapeConfig.DENSITY = config["DENSITY"]
BigscapeConfig.DENSE_PREFERENCE = config["DENSE_PREFERENCE"]

# TREE
BigscapeConfig.TOP_FREQS = config["TOP_FREQS"]
Expand Down
6 changes: 5 additions & 1 deletion big_scape/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,14 @@ EXTEND_GAP_SCORE: -2
# as a percentage of total domains present in the compared record.
EXTEND_MAX_MATCH_PERC: 0.1

# CLUSTER
# GCF Calling
# Internal parameter of the Affinity Propagation clustering algorithm, governs the number
# of families created. Higher preference will result in more families and vice versa.
PREFERENCE: 0.0
# Connected component density threshold (incl.) and Affinity Propagation preference to be used
# on dense connected components.
DENSITY: 0.85
DENSE_PREFERENCE: -5.0

# GCF TREE
# The number of common domains (present in the exemplar BGC record) used to
Expand Down
66 changes: 55 additions & 11 deletions big_scape/network/families.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

# from python
import sys
from typing import Callable
from typing import Callable, Optional
import warnings
import numpy as np
import networkx
Expand Down Expand Up @@ -50,7 +50,45 @@ def generate_families(

similarity_matrix, node_ids = edge_list_to_sim_matrix(connected_component)

labels, centers = aff_sim_matrix(similarity_matrix)
if get_cc_density(connected_component) >= BigscapeConfig.DENSITY:
# if a connected component is highly connected, no (or less) splitting is needed
# run affinity propagation with a lower preference to find the best family center
labels, centers = aff_sim_matrix(
similarity_matrix, BigscapeConfig.DENSE_PREFERENCE
)
else:
labels, centers = aff_sim_matrix(similarity_matrix)

# If affinity propagation did not converge, no centers are returned.
# to show them in the network anyways, merge them into one arbitrary family
if len(centers) == 0:
center = node_ids[0]

if DB.metadata is None:
raise RuntimeError("DB metadata is None!")

gbk_table = DB.metadata.tables["gbk"]
record_table = DB.metadata.tables["bgc_record"]
center_data = DB.execute(
select(
gbk_table.c.path,
record_table.c.record_type,
record_table.c.record_number,
)
.join(record_table, record_table.c.gbk_id == gbk_table.c.id)
.where(record_table.c.id == center)
).fetchone()

if center_data is None:
raise RuntimeError("Family center not found in database: %s", center)

c_path, c_type, c_number = center_data
logging.warning(
"Affinity Propagation did not converge, records in this connected component "
"have been merged into one arbitrary family with center: %s",
"_".join(map(str, [c_path.split("/")[-1], c_type, c_number])),
)
return [(rec_id, center, cutoff, bin_label, run_id) for rec_id in node_ids]

for idx, label in enumerate(labels):
label = int(label)
Expand Down Expand Up @@ -85,15 +123,17 @@ def get_cc_edge_weight_std(connected_component) -> float:
return edge_std


def get_cc_connectivity(connected_component) -> float:
"""calculates the connectivity of a connected component
def get_cc_density(
connected_component: list[tuple[int, int, float, float, float, float, int]]
) -> float:
"""calculates the density of a connected component: nr edges / nr of possible edges
Args:
connected_component (list[tuple[int, int, float, float, float, float, str]]):
connected_component (list[tuple[int, int, float, float, float, float, int]]):
connected component in the form of a list of edges
Returns:
float: connectivity of the connected component
float: density of the connected component
"""

nr_edges = len(connected_component)
Expand All @@ -102,10 +142,10 @@ def get_cc_connectivity(connected_component) -> float:
nodes_b = [edge[1] for edge in connected_component]
nr_nodes = len(set(nodes_a + nodes_b))

cc_connectivity = nr_edges / (nr_nodes * (nr_nodes - 1) / 2)
cc_connectivity = round(cc_connectivity, 2)
cc_density = nr_edges / (nr_nodes * (nr_nodes - 1) / 2)
cc_density = round(cc_density, 2)

return cc_connectivity
return cc_density


def test_centrality(connected_component, node_fraction) -> tuple[bool, list[int]]:
Expand Down Expand Up @@ -148,27 +188,31 @@ def test_centrality(connected_component, node_fraction) -> tuple[bool, list[int]
return False, sorted_between_bentrality_nodes


def aff_sim_matrix(matrix, preference: Optional[float] = None):
    """Execute affinity propagation on a __similarity__ matrix

    Note: a similarity matrix. Not a distance matrix.

    Args:
        matrix (numpy.array[numpy.array]): similarity matrix in numpy array of
            array format.
        preference (float, optional): Affinity Propagation preference. When None,
            falls back to the configured BigscapeConfig.PREFERENCE. Lower values
            produce fewer, larger families; higher values produce more families.

    Returns:
        tuple[list[int], list[int]]: list of labels and list of cluster center ids
    """
    # thanks numpy but we sort of know what we're doing
    warnings.filterwarnings(action="ignore", category=ConvergenceWarning)

    # default to the globally configured preference; callers pass an explicit
    # (lower) preference for dense connected components to reduce splitting
    if preference is None:
        preference = BigscapeConfig.PREFERENCE

    af_results = AffinityPropagation(
        damping=0.90,
        max_iter=1000,
        convergence_iter=200,
        affinity="precomputed",
        preference=preference,
    ).fit(matrix)

    return af_results.labels_, af_results.cluster_centers_indices_
Expand Down
43 changes: 25 additions & 18 deletions big_scape/output/legacy_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,7 @@ def write_record_annotations_file(run, cutoff, all_bgc_records) -> None:
"GBK",
"Record_Type",
"Record_Number",
"Full_Name",
"Class",
"Category",
"Organism",
Expand All @@ -403,6 +404,7 @@ def write_record_annotations_file(run, cutoff, all_bgc_records) -> None:
Path(gbk_path).stem,
record_type,
str(record_number),
f"{Path(gbk_path).name}_{record_type}_{record_number}",
product,
record_categories[rec_id],
organism,
Expand Down Expand Up @@ -473,7 +475,9 @@ def write_clustering_file(run, cutoff, pair_generator) -> None:
record_data = DB.execute(select_statement).fetchall()

with open(clustering_file_path, "w") as clustering_file:
header = "\t".join(["GBK", "Record_Type", "Record_Number", "CC", "Family"])
header = "\t".join(
["GBK", "Record_Type", "Record_Number", "CC", "Family", "Full_Name"]
)
clustering_file.write(header + "\n")

for record in record_data:
Expand All @@ -486,6 +490,7 @@ def write_clustering_file(run, cutoff, pair_generator) -> None:
str(record_number),
str(cc_number),
f"FAM_{family:0>5}",
f"{Path(gbk_path).name}_{record_type}_{record_number}",
]
)
clustering_file.write(row + "\n")
Expand Down Expand Up @@ -568,22 +573,22 @@ def write_network_file(
cutoff (float, optional): distance cutoff for returned edges. Defaults to None.
"""
if weight is None:
legacy_weights = [
"PKSI",
"PKSother",
"NRPS",
"RiPPs",
"saccharides",
"terpene",
"PKS-NRP_Hybrids",
"other",
]
incl_weights = ["mix"]

if not run["mix"]:
incl_weights.remove("mix")
if run["legacy_weights"]:
incl_weights.extend(legacy_weights)
incl_weights.extend(
[
"PKSI",
"PKSother",
"NRPS",
"RiPPs",
"saccharides",
"terpene",
"PKS-NRP_Hybrids",
"other",
]
)
if not run["mix"]:
incl_weights.remove("mix")
else:
incl_weights = [weight]

Expand Down Expand Up @@ -650,9 +655,9 @@ def write_network_file(

with open(output_path, "w") as network_file:
header = (
"GBK_a\tRecord_Type_a\tRecord_Number_a\tORF_coords_a\tGBK_b\t"
"Record_Type_b\tRecord_Number_b\tORF_coords_b\tdistance\tjaccard\tadjacency\t"
"dss\tweights\taligmnent_mode\textend_strategy\n"
"GBK_a\tRecord_Type_a\tRecord_Number_a\tFull_Name_a\tORF_coords_a\tGBK_b\t"
"Record_Type_b\tRecord_Number_b\tFull_Name_b\tORF_coords_b\tdistance\t"
"jaccard\tadjacency\tdss\tweights\taligmnent_mode\textend_strategy\n"
)

network_file.write(header)
Expand Down Expand Up @@ -681,10 +686,12 @@ def write_network_file(
Path(gbk_path_a).stem,
record_type_a,
str(record_number_a),
f"{Path(gbk_path_a).name}_{record_type_a}_{record_number_a}",
f"{ext_a_start}:{ext_a_stop}",
Path(gbk_path_b).stem,
record_type_b,
str(record_number_b),
f"{Path(gbk_path_b).name}_{record_type_b}_{record_number_b}",
f"{ext_b_start}:{ext_b_stop}",
f"{distance:.2f}",
f"{jaccard:.2f}",
Expand Down
10 changes: 5 additions & 5 deletions test/network/test_family.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,19 +133,19 @@ def test_get_cc_edge_weight_std(self):

self.assertEqual(expected_std, actual_std)

def test_get_cc_density(self):
    """Tests whether the density of a connected component is correctly
    calculated
    """
    adj_list = TestAffinityPropagation.gen_edge_list()

    # 8 nodes, 15 edges
    # 15 / (8 * (8 - 1) / 2) = 0.54
    expected_density = 0.54

    actual_density = bs_families.get_cc_density(adj_list)

    self.assertEqual(expected_density, actual_density)

def test_test_centrality(self):
"""Tests whether the test_centrality function correctly identifies a network
Expand Down

0 comments on commit de8e34c

Please sign in to comment.