Skip to content

Commit

Permalink
fixes #217 (#231)
Browse files Browse the repository at this point in the history
The plot_var functions is altered in two ways:
- Adds an edge if input is used in value of a node that has also a distribution.
- Introduces a red edge if input is used in the value node and the dist node of a variable.
  • Loading branch information
wiep authored Jan 17, 2025
1 parent d1d5b37 commit f5691a2
Showing 1 changed file with 39 additions and 6 deletions.
45 changes: 39 additions & 6 deletions liesel/model/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,18 +173,43 @@ def _draw_edges(graph, axis, pos, is_var):

if is_var:
dist_edges = []
non_dist_edges = []
value_edges = []

for edge in edges:

# find distribution edges
if edge[1].has_dist:
edge_0_output_nodes = set(edge[0].all_output_nodes())
edge_0_nodes = edge[0].nodes
edge_1_input_nodes = set(edge[1].dist_node.all_input_nodes())

if bool(edge_0_output_nodes.union(edge_0_nodes) & edge_1_input_nodes):
dist_edges.append(edge)
else:
non_dist_edges.append(edge)

# find value edges
edge_0_output_nodes = set(edge[0].all_output_nodes())
edge_0_nodes = edge[0].nodes
edge_1_input_nodes = set(edge[1].value_node.all_input_nodes())

if bool(edge_0_output_nodes.union(edge_0_nodes) & edge_1_input_nodes):
value_edges.append(edge)

edges_in_both = set(dist_edges) & set(value_edges)
dist_edges = set(dist_edges) - edges_in_both
value_edges = set(value_edges) - edges_in_both

# assigns value_edges to edges to make it comparible with is_var=False
edges = value_edges

nx.draw_networkx_edges(
graph,
pos,
edgelist=edges_in_both,
edge_color="#FF0000",
arrows=True,
ax=axis,
node_size=500,
)

nx.draw_networkx_edges(
graph,
Expand All @@ -196,8 +221,6 @@ def _draw_edges(graph, axis, pos, is_var):
node_size=500,
)

edges = non_dist_edges

nx.draw_networkx_edges(
graph,
pos,
Expand Down Expand Up @@ -247,12 +270,22 @@ def _add_legend(axis):
[0],
[0],
marker=r"$\rightarrow$",
color="#aaaaaa",
color="#AAAAAA",
label="Used in distribution",
markerfacecolor="k",
markersize=12,
lw=0,
),
Line2D(
[0],
[0],
marker=r"$\rightarrow$",
color="#FF0000",
label="Used in value and distribution",
markerfacecolor="k",
markersize=12,
lw=0,
),
]

axis.legend(handles=legend_elements, loc="best")

0 comments on commit f5691a2

Please sign in to comment.