Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP Add solution exporting #29

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions motile/track_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import logging
from typing import TYPE_CHECKING, Any, Hashable

from .variables import EdgeSelected, NodeSelected

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
Expand Down Expand Up @@ -114,6 +116,21 @@ def nodes_by_frame(self, t: int) -> list[Hashable]:
return []
return self._nodes_by_frame[t]

def mark_solution(self, solver, solution_attribute="selected"):
node_selected = solver.get_variables(NodeSelected)
for node in self.nodes:
if solver.solution[node_selected[node]] > 0.5:
self.nodes[node]["selected"] = 1
else:
self.nodes[node]["selected"] = 0

edge_selected = solver.get_variables(EdgeSelected)
for edge in self.edges:
if solver.solution[edge_selected[edge]] > 0.5:
self.edges[edge]["selected"] = 1
else:
self.edges[edge]["selected"] = 0

def _update_metadata(self) -> None:
if not self._graph_changed:
return
Expand Down
112 changes: 111 additions & 1 deletion motile/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import networkx as nx

from .track_graph import TrackGraph
from motile.track_graph import TrackGraph


def get_tracks(
Expand Down Expand Up @@ -55,3 +55,113 @@ def get_tracks(
)
for g in nx.weakly_connected_components(graph)
]


def get_networkx_graph(
graph: TrackGraph,
require_selected: bool = False,
selected_attribute: str = "selected",
) -> list[TrackGraph]:
"""Return the physical directed graph (no hyperedges Flo!) as networkx.DiGraph

Args:
graph (:class:`TrackGraph`):

The track graph.

require_selected (``bool``):

If ``True``, consider only edges that have a selected_attribute
attribute that is set to ``True``.

selected_attribute (``str``):

Only used if require_selected=True. Determines the attribute
name to check if an edge is selected. Default value is
'selected'.

Returns:

networkx.DiGraph
"""

if require_selected:
selected_edges = [
e
for e in graph.edges
if (
selected_attribute in graph.edges[e]
and graph.edges[e][selected_attribute]
)
]

# TODO edge_subgraph will miss nodes with in- and out-deg 0
graph = graph.nx_graph.edge_subgraph(selected_edges)
else:
graph = graph.nx_graph

return graph


def create_toy_example_graph():
cells = [
{"id": 0, "t": 0, "x": 1, "score": 0.8, "gt": 1},
{"id": 1, "t": 0, "x": 25, "score": 0.1},
{"id": 2, "t": 1, "x": 0, "score": 0.3, "gt": 1},
{"id": 3, "t": 1, "x": 26, "score": 0.4},
{"id": 4, "t": 2, "x": 2, "score": 0.6, "gt": 1},
{"id": 5, "t": 2, "x": 24, "score": 0.3, "gt": 0},
{"id": 6, "t": 2, "x": 35, "score": 0.7},
]

edges = [
{"source": 0, "target": 2, "score": 0.9, "gt": 1},
{"source": 1, "target": 3, "score": 0.9},
{"source": 0, "target": 3, "score": 0.5},
{"source": 1, "target": 2, "score": 0.5},
{"source": 2, "target": 4, "score": 0.7, "gt": 1},
{"source": 3, "target": 5, "score": 0.7},
{"source": 2, "target": 5, "score": 0.3, "gt": 0},
{"source": 3, "target": 4, "score": 0.3},
{"source": 3, "target": 6, "score": 0.8},
]
graph = nx.DiGraph()
graph.add_nodes_from([(cell["id"], cell) for cell in cells])
graph.add_edges_from([(edge["source"], edge["target"], edge) for edge in edges])
return TrackGraph(graph)


if __name__ == "__main__":
tg = create_toy_example_graph()

# toy solver
from motile import Solver
from motile.constraints import MaxChildren, MaxParents
from motile.costs import Appear, EdgeSelection, NodeSelection

solver = Solver(tg)

# tell it how to compute costs for selecting nodes and edges
solver.add_costs(NodeSelection(weight=-1.0, attribute="score"))
solver.add_costs(EdgeSelection(weight=-1.0, attribute="score"))

# add a small penalty to start a new track
solver.add_costs(Appear(constant=1.0))

# add constraints on the solution (no splits, no merges)
solver.add_constraints(MaxParents(1))
solver.add_constraints(MaxChildren(1))

# solve
solution = solver.solve()

# mark solution with attribute
tg.mark_solution(solver, solution_attribute="selected")

full_graph = get_networkx_graph(tg)
print(full_graph)

sol_graph = get_networkx_graph(
tg, require_selected=True, selected_attribute="selected"
)
print(sol_graph)