diff --git a/liesel/model/viz.py b/liesel/model/viz.py index 1da9897..23e038b 100644 --- a/liesel/model/viz.py +++ b/liesel/model/viz.py @@ -173,9 +173,11 @@ def _draw_edges(graph, axis, pos, is_var): if is_var: dist_edges = [] - non_dist_edges = [] + value_edges = [] for edge in edges: + + # find distribution edges if edge[1].has_dist: edge_0_output_nodes = set(edge[0].all_output_nodes()) edge_0_nodes = edge[0].nodes @@ -183,8 +185,31 @@ def _draw_edges(graph, axis, pos, is_var): if bool(edge_0_output_nodes.union(edge_0_nodes) & edge_1_input_nodes): dist_edges.append(edge) - else: - non_dist_edges.append(edge) + + # find value edges + edge_0_output_nodes = set(edge[0].all_output_nodes()) + edge_0_nodes = edge[0].nodes + edge_1_input_nodes = set(edge[1].value_node.all_input_nodes()) + + if bool(edge_0_output_nodes.union(edge_0_nodes) & edge_1_input_nodes): + value_edges.append(edge) + + edges_in_both = set(dist_edges) & set(value_edges) + dist_edges = set(dist_edges) - edges_in_both + value_edges = set(value_edges) - edges_in_both + + # assigns value_edges to edges to make it comparible with is_var=False + edges = value_edges + + nx.draw_networkx_edges( + graph, + pos, + edgelist=edges_in_both, + edge_color="#FF0000", + arrows=True, + ax=axis, + node_size=500, + ) nx.draw_networkx_edges( graph, @@ -196,8 +221,6 @@ def _draw_edges(graph, axis, pos, is_var): node_size=500, ) - edges = non_dist_edges - nx.draw_networkx_edges( graph, pos, @@ -247,12 +270,22 @@ def _add_legend(axis): [0], [0], marker=r"$\rightarrow$", - color="#aaaaaa", + color="#AAAAAA", label="Used in distribution", markerfacecolor="k", markersize=12, lw=0, ), + Line2D( + [0], + [0], + marker=r"$\rightarrow$", + color="#FF0000", + label="Used in value and distribution", + markerfacecolor="k", + markersize=12, + lw=0, + ), ] axis.legend(handles=legend_elements, loc="best")