From 6f57d0f0b622a584d5c8c3d938d0921b8f2166e6 Mon Sep 17 00:00:00 2001 From: Katy Brown Date: Thu, 28 Nov 2024 12:14:02 +0000 Subject: [PATCH] rearranged files and adjusted unit tests to match --- plot_phylo/{post_draw.py => amend_tree.py} | 0 plot_phylo/draw_tree.py | 63 ++++++++-------- plot_phylo/plot_phylo.py | 58 +++++++-------- setup.py | 2 - tests/test_plot_phylo.py | 85 +++++++++++++--------- tests/test_plot_phylo_data.py | 9 ++- 6 files changed, 117 insertions(+), 100 deletions(-) rename plot_phylo/{post_draw.py => amend_tree.py} (100%) mode change 100644 => 100755 plot_phylo/draw_tree.py diff --git a/plot_phylo/post_draw.py b/plot_phylo/amend_tree.py similarity index 100% rename from plot_phylo/post_draw.py rename to plot_phylo/amend_tree.py diff --git a/plot_phylo/draw_tree.py b/plot_phylo/draw_tree.py old mode 100644 new mode 100755 index 4613b65..61bc6fe --- a/plot_phylo/draw_tree.py +++ b/plot_phylo/draw_tree.py @@ -1,37 +1,7 @@ +#!/usr/bin/env python3 import numpy as np -def collapse_nodes(tree, collapse_list, collapse_names): - cD = dict(zip(collapse_list, collapse_names)) - collapseD = dict() - for string in collapse_list: - keeps = set() - collapsed = set() - done = set() - ddD = dict() - for node in tree.traverse(): - x = 0 - L = list(node.get_leaves()) - dd = [] - for leaf in L: - if leaf.name.endswith(string) and leaf not in done: - dd.append(leaf.dist) - x += 1 - if x == len(L) or (len(L) == 1 and leaf not in done): - keeps.add(L[0].name) - done = done | set(L) - if x > 1: - collapsed.add(L[0]) - ddD[L[0]] = np.mean(dd) - tree.prune(keeps) - for leaf in tree.get_leaves(): - if leaf in collapsed: - leaf.dist = ddD[leaf] - leaf.name = 'COLLAPSE|%s' % (leaf.name) - collapseD[leaf.name] = cD[string] - return (tree, collapseD) - - def draw_tree(tree, ax, x=0, y=0, @@ -306,3 +276,34 @@ def draw_tree(tree, ax, va='center', fontsize=appearance['font_size']-2) return (y, ym, ps) + + +def collapse_nodes(tree, collapse_list, collapse_names): + cD = dict(zip(collapse_list, collapse_names)) + collapseD = dict() + for string in collapse_list: + keeps = set() + collapsed = set() + done = set() + ddD = dict() + for node in tree.traverse(): + x = 0 + L = list(node.get_leaves()) + dd = [] + for leaf in L: + if leaf.name.endswith(string) and leaf not in done: + dd.append(leaf.dist) + x += 1 + if x == len(L) or (len(L) == 1 and leaf not in done): + keeps.add(L[0].name) + done = done | set(L) + if x > 1: + collapsed.add(L[0]) + ddD[L[0]] = np.mean(dd) + tree.prune(keeps) + for leaf in tree.get_leaves(): + if leaf in collapsed: + leaf.dist = ddD[leaf] + leaf.name = 'COLLAPSE|%s' % (leaf.name) + collapseD[leaf.name] = cD[string] + return (tree, collapseD) \ No newline at end of file diff --git a/plot_phylo/plot_phylo.py b/plot_phylo/plot_phylo.py index abe92ac..03f18cb 100755 --- a/plot_phylo/plot_phylo.py +++ b/plot_phylo/plot_phylo.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 import ete3 -import draw_tree -import post_draw +from . import draw_tree +from . import amend_tree def plot_phylo(tree, ax, @@ -163,42 +163,42 @@ def plot_phylo(tree, ax, # internally when the function is called recursively but # are not needed by the user - _, _, ps = draw_tree(T, ax, - x=xpos, - y=-ypos-height, - x0=xpos, - ps=[], - height=height, - width=width, - depth=maxdist, - align_tips=align_tips, - rev_align_tips=rev_align_tips, - branch_lengths=branch_lengths, - reverse=reverse, - appearance=appearance, - collapse=collapse, - collapseD=collapseD) + _, _, ps = draw_tree.draw_tree(T, ax, + x=xpos, + y=-ypos-height, + x0=xpos, + ps=[], + height=height, + width=width, + depth=maxdist, + align_tips=align_tips, + rev_align_tips=rev_align_tips, + branch_lengths=branch_lengths, + reverse=reverse, + appearance=appearance, + collapse=collapse, + collapseD=collapseD) if rev_align_tips: - ps = post_draw.reverse_align(ax, ps, reverse) + ps = amend_tree.reverse_align(ax, ps, reverse) # Hide axis if not show_axis: ax.set_axis_off() if scale_bar and branch_lengths: if not reverse: - post_draw.draw_scale_bar(ax, width, height, maxdist, xpos, ypos, - scale_bar_width=scale_bar_width, - appearance=appearance) + amend_tree.draw_scale_bar(ax, width, height, maxdist, xpos, ypos, + scale_bar_width=scale_bar_width, + appearance=appearance) else: - post_draw.draw_scale_bar(ax, width, height, maxdist, -xpos, ypos, - scale_bar_width=scale_bar_width, - appearance=appearance) + amend_tree.draw_scale_bar(ax, width, height, maxdist, -xpos, ypos, + scale_bar_width=scale_bar_width, + appearance=appearance) textobj = [p[1] for p in ps] if auto_ax: - textobj, ax = post_draw.auto_axis(ax, textobj, - xpos, ypos, - width, height, maxdist, - scale_bar, branch_lengths) - boxes = post_draw.get_boxes(ax, textobj) + textobj, ax = amend_tree.auto_axis(ax, textobj, + xpos, ypos, + width, height, maxdist, + scale_bar, branch_lengths) + boxes = amend_tree.get_boxes(ax, textobj) return (boxes) diff --git a/setup.py b/setup.py index 81d4556..d5dbe00 100644 --- a/setup.py +++ b/setup.py @@ -24,9 +24,7 @@ long_description_content_type="text/markdown", url="https://github.com/KatyBrown/plot_phylo", packages=setuptools.find_packages(), - package_dir={'plot_phylo':'plot_phylo'}, install_requires=['matplotlib', 'ete3'], - scripts=['plot_phylo/plot_phylo.py'], classifiers=[ "Programming Language :: Python :: 3", "License :: OSI Approved :: MIT License", diff --git a/tests/test_plot_phylo.py b/tests/test_plot_phylo.py index 40678b6..344dc8d 100644 --- a/tests/test_plot_phylo.py +++ b/tests/test_plot_phylo.py @@ -2,6 +2,8 @@ import matplotlib.pyplot as plt import matplotlib import plot_phylo +from plot_phylo import draw_tree +from plot_phylo import amend_tree import pytest from test_plot_phylo_data import (test_plot_phylo_vars, test_plot_phylo_list, @@ -20,7 +22,6 @@ import os import shutil import numpy as np -matplotlib.use('Agg') def compare_images(f1, f2, tol): @@ -63,7 +64,10 @@ def test_plot_phylo_params(xpos, line_width, bold, expected_figure, - ID, tree, ylim): + ID, tree, ylim, + collapse, + collapse_names, + auto_ax): tree_stem = tree.split("/")[-1].split(".")[0] @@ -90,7 +94,10 @@ def test_plot_phylo_params(xpos, font_size=font_size, line_col=line_col, line_width=line_width, - bold=bold) + bold=bold, + collapse=collapse, + collapse_names=collapse_names, + auto_ax=auto_ax) try: os.mkdir("test_temp") except FileExistsError: @@ -128,7 +135,9 @@ def test_draw_tree_params(x, expected, ID, tree, - ylim): + ylim, + collapse, + collapseD): try: T = ete3.Tree(tree) @@ -150,19 +159,19 @@ def test_draw_tree_params(x, if nam not in appearance['col_dict']: appearance['col_dict'][nam] = 'black' - test_obj = plot_phylo.draw_tree(tree=T, ax=a, - x=x, - y=y, - x0=x0, - ps=ps, - height=ylim-1, - width=width, - depth=depth, - align_tips=align_tips, - rev_align_tips=rev_align_tips, - branch_lengths=branch_lengths, - reverse=reverse, - appearance=appearance) + test_obj = draw_tree.draw_tree(tree=T, ax=a, + x=x, + y=y, + x0=x0, + ps=ps, + height=ylim-1, + width=width, + depth=depth, + align_tips=align_tips, + rev_align_tips=rev_align_tips, + branch_lengths=branch_lengths, + reverse=reverse, + appearance=appearance) plt.close() ytest = round(test_obj[0], 2) y2test = round(test_obj[1], 2) @@ -203,7 +212,9 @@ def test_reverse_align_params(x, expected, ID, tree, - ylim): + ylim, + collapse, + collapseD): try: T = ete3.Tree(tree) except ete3.parser.newick.NewickError: @@ -224,22 +235,22 @@ def test_reverse_align_params(x, if nam not in appearance['col_dict']: appearance['col_dict'][nam] = 'black' - _, _, ps = plot_phylo.draw_tree(tree=T, ax=a, - x=x, - y=y, - x0=x0, - ps=[], - height=ylim-1, - width=width, - depth=depth, - align_tips=True, - rev_align_tips=True, - branch_lengths=branch_lengths, - reverse=reverse, - appearance=appearance) - plt.close() - reverse = plot_phylo.reverse_align(a, ps, True) + _, _, ps = draw_tree.draw_tree(tree=T, ax=a, + x=x, + y=y, + x0=x0, + ps=[], + height=ylim-1, + width=width, + depth=depth, + align_tips=True, + rev_align_tips=True, + branch_lengths=branch_lengths, + reverse=reverse, + appearance=appearance) + reverse = amend_tree.reverse_align(a, ps, True) + plt.close() e0 = expected.replace(".pickle", "_%s.pickle" % tree_stem) test_dat = [] @@ -271,7 +282,7 @@ def test_reverse_align_params(x, test_get_boxes_texts, test_get_boxes_results))) def test_get_boxes(ax, texts, expected_result): - boxes = plot_phylo.get_boxes(ax, texts) + boxes = amend_tree.get_boxes(ax, texts) bclean = dict() for b, vals in boxes.items(): bclean[b] = dict() @@ -303,7 +314,10 @@ def test_bad_tree(xpos, line_width, bold, expected_figure, - ID, tree, ylim): + ID, tree, ylim, + collapse, + collapse_names, + auto_ax): f = plt.figure(figsize=(10, 20)) a = f.add_subplot(111) a.set_xlim(-10, 20) @@ -330,3 +344,4 @@ def test_bad_tree(xpos, line_col=line_col, line_width=line_width, bold=bold) + plt.close() diff --git a/tests/test_plot_phylo_data.py b/tests/test_plot_phylo_data.py index 51bdb8a..b51f16e 100644 --- a/tests/test_plot_phylo_data.py +++ b/tests/test_plot_phylo_data.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 import plot_phylo +from plot_phylo import draw_tree import copy # Default parameters values @@ -16,9 +17,9 @@ varis.remove('height') variD.pop('height') -defaults_draw_tree = plot_phylo.draw_tree.__defaults__ -ac_draw_tree = plot_phylo.draw_tree.__code__.co_argcount -varis_draw_tree = list(plot_phylo.draw_tree.__code__.co_varnames[ +defaults_draw_tree = draw_tree.draw_tree.__defaults__ +ac_draw_tree = draw_tree.draw_tree.__code__.co_argcount +varis_draw_tree = list(draw_tree.draw_tree.__code__.co_varnames[ ac_draw_tree-len(defaults_draw_tree):ac_draw_tree]) varis_draw_tree.remove('height') @@ -94,6 +95,7 @@ 'line_col', 'line_width', 'show_support', 'bold']: curr_dict['appearance'][var] = curr_dict[var] curr_dict['depth'] = [5, 5, 5] + curr_dict['collapseD'] = dict() curr_dict.update(test) pass_vals = [curr_dict[v] for v in varis_draw_tree] pass_vals.append('tests/test_objects/%s.pickle' % testnam) @@ -116,6 +118,7 @@ for var in ['col_dict', 'label_dict', 'font_size', 'line_col', 'line_width', 'show_support', 'bold']: curr_dict['appearance'][var] = curr_dict[var] + curr_dict['collapseD'] = dict() curr_dict.update(test) curr_dict['depth'] = [5, 5, 5] curr_dict.pop('rev_align_tips')