Spaces:
Running on Zero
Running on Zero
| # Modified from Trackastra (https://github.com/weigertlab/trackastra) | |
| import logging | |
| import time | |
| from types import SimpleNamespace | |
| import networkx as nx | |
| import yaml | |
| try: | |
| import motile | |
| except ModuleNotFoundError: | |
| raise ModuleNotFoundError( | |
| "For tracking with an ILP, please conda install the optional `motile`" | |
| " dependency following https://funkelab.github.io/motile/install.html." | |
| ) | |
| logger = logging.getLogger(__name__) | |
# Named cost presets for the ILP. `solve_full_ilp` looks one up by its `mode`
# argument when no parameter file is supplied.
#
# Column order of every row in `_ILP_CONFIG_ROWS`.
_ILP_CONFIG_FIELDS = (
    "nodeW",
    "nodeC",
    "edgeW",
    "edgeC",
    "appearC",
    "disappearC",
    "splitC",
)
_ILP_CONFIG_ROWS = (
    # nodeC=-10 in the three `*gt*` presets: take all nodes
    ("gt", (0, -10, -1, 0, 0.25, 0.5, 0.25)),
    ("deepcell_gt", (0, -10, -1, 0, 0.25, 0.5, 1)),
    ("deepcell_gt_tuned", (0, -10, -1, 0, 0.5, 0.5, 1)),
    ("deepcell_res_tuned", (0, 0.25, -1, -0.25, 0.25, 0.25, 1.0)),
)
ILP_CONFIGS = {
    preset_name: SimpleNamespace(**dict(zip(_ILP_CONFIG_FIELDS, preset_values)))
    for preset_name, preset_values in _ILP_CONFIG_ROWS
}
def track_ilp(
    candidate_graph,
    allow_divisions: bool = True,
    ilp_config: str = "gt",
    params_file: str | None = None,
    **kwargs,
):
    """Solve the tracking ILP on a candidate graph and return the selected graph.

    Args:
        candidate_graph: Graph wrapped in a ``motile.TrackGraph``; the frame
            of each node is read from its ``"time"`` attribute.
        allow_divisions: Forwarded to ``solve_full_ilp``.
        ilp_config: Name of the cost preset, forwarded to ``solve_full_ilp``
            as ``mode``.
        params_file: Optional path to a parameter file, forwarded to
            ``solve_full_ilp``.
        **kwargs: Accepted but not used by this function.

    Returns:
        The graph built by ``solution_to_graph`` from the solved ILP.
    """
    track_graph = motile.TrackGraph(candidate_graph, frame_attribute="time")

    # The second return value (the costs that were applied) is not needed here.
    solved, _ = solve_full_ilp(
        track_graph,
        allow_divisions=allow_divisions,
        mode=ilp_config,
        params_file=params_file,
    )

    print_solution_stats(solved, track_graph)
    return solution_to_graph(solved, track_graph)
def solve_full_ilp(
    graph,
    allow_divisions: bool,
    mode: str,
    params_file: str | None,
):
    """Build a motile solver for ``graph``, add costs and constraints, and solve.

    The cost parameters are read from ``params_file`` (a YAML mapping) when
    one is given; otherwise the ``ILP_CONFIGS`` preset named ``mode`` is used.

    Args:
        graph: Graph passed to ``motile.Solver``.
        allow_divisions: If true, a split cost is added and each node may have
            up to two children; otherwise at most one child.
        mode: Key into ``ILP_CONFIGS``; ignored when ``params_file`` is given.
        params_file: Optional path to a YAML file holding the parameters
            ``nodeW``, ``nodeC``, ``edgeW``, ``edgeC``, ``appearC``,
            ``disappearC`` and, when divisions are allowed, ``splitC``.

    Returns:
        A tuple ``(solver, used_costs)``: the solver after ``solve()`` has been
        called, and a dict mapping each applied cost parameter to its value.

    Raises:
        ValueError: If ``mode`` is not a key of ``ILP_CONFIGS`` (when no
            ``params_file`` is given), or if ``params_file`` does not hold a
            mapping with string keys that contains every required parameter.
    """
    # `splitC` is only read when divisions are allowed, so it is only
    # required in that case.
    required = ["nodeW", "nodeC", "edgeW", "edgeC", "appearC", "disappearC"]
    if allow_divisions:
        required.append("splitC")

    # Resolve and validate the parameters before building the solver, so a
    # bad configuration fails early with a clear message.
    if params_file:
        # Decode as UTF-8 rather than with the platform's locale encoding.
        with open(params_file, encoding="utf-8") as f:
            loaded = yaml.safe_load(f)
        # An empty file loads as None; lists and scalars are also possible.
        if not isinstance(loaded, dict) or not all(
            isinstance(key, str) for key in loaded
        ):
            raise ValueError(
                f"ILP parameter file {params_file} must contain a mapping of"
                f" parameter names to values, got {type(loaded).__name__}."
            )
        missing = [name for name in required if name not in loaded]
        if missing:
            raise ValueError(
                f"ILP parameter file {params_file} is missing required"
                f" parameters {missing}."
            )
        p = SimpleNamespace(**loaded)
        logger.info(f"Using ILP parameters {p}")
    else:
        try:
            p = ILP_CONFIGS[mode]
        except KeyError:
            # The KeyError carries no extra information; drop it from the
            # traceback.
            raise ValueError(
                f"Unknown ILP mode {mode}. Choose from {list(ILP_CONFIGS.keys())} or"
                " supply custom parameters via `params_file` argument."
            ) from None
        logger.info(f"Using `{mode}` ILP config.")

    solver = motile.Solver(graph)

    # Add costs
    # NODES
    solver.add_cost(
        motile.costs.NodeSelection(weight=p.nodeW, constant=p.nodeC, attribute="weight")
    )
    # EDGES
    solver.add_cost(
        motile.costs.EdgeSelection(weight=p.edgeW, constant=p.edgeC, attribute="weight")
    )
    # APPEAR
    solver.add_cost(motile.costs.Appear(constant=p.appearC))
    # DISAPPEAR
    solver.add_cost(motile.costs.Disappear(constant=p.disappearC))
    # DIVISION
    if allow_divisions:
        solver.add_cost(motile.costs.Split(constant=p.splitC))

    # Exactly the parameters applied above, in the order they were applied.
    used_costs = {name: getattr(p, name) for name in required}

    # Add constraints
    solver.add_constraint(motile.constraints.MaxParents(1))
    solver.add_constraint(motile.constraints.MaxChildren(2 if allow_divisions else 1))

    solver.solve()
    return solver, used_costs
def solution_to_graph(solver, base_graph):
    """Copy the nodes and edges selected by the solver into a new ``nx.DiGraph``.

    A node or edge counts as selected when the solution value of its
    ``NodeSelected`` / ``EdgeSelected`` variable is above 0.5. Attributes are
    taken from ``base_graph``.

    Args:
        solver: Solver exposing ``get_variables`` and ``solution``.
        base_graph: Graph whose ``nodes`` and ``edges`` supply the attributes.

    Returns:
        A new ``nx.DiGraph`` holding the selected nodes and edges.
    """
    selected_graph = nx.DiGraph()
    node_vars = solver.get_variables(motile.variables.NodeSelected)
    edge_vars = solver.get_variables(motile.variables.EdgeSelected)

    # Nodes first, then edges.
    for node_id, var_index in node_vars.items():
        if not (solver.solution[var_index] > 0.5):
            continue
        selected_graph.add_node(node_id, **base_graph.nodes[node_id])

    for edge_id, var_index in edge_vars.items():
        if not (solver.solution[var_index] > 0.5):
            continue
        selected_graph.add_edge(*edge_id, **base_graph.edges[edge_id])

    return selected_graph
def print_solution_stats(solver, graph, gt_graph=None):
    """Print node and edge counts for the candidate graph and the solution.

    Args:
        solver: Solver exposing ``get_variables`` and ``solution``.
        graph: Candidate graph the solver was run on.
        gt_graph: Optional ground truth graph; its counts are printed when it
            is truthy.
    """
    time.sleep(0.1)  # to wait for ilpy prints

    print(
        f"\nCandidate graph\t\t{len(graph.nodes):3} nodes\t{len(graph.edges):3} edges"
    )
    if gt_graph:
        print(
            f"Ground truth graph\t{len(gt_graph.nodes):3}"
            f" nodes\t{len(gt_graph.edges):3} edges"
        )

    node_vars = solver.get_variables(motile.variables.NodeSelected)
    edge_vars = solver.get_variables(motile.variables.EdgeSelected)

    # Count the candidates whose indicator value in the solution is above 0.5.
    nodes = sum(1 for node in graph.nodes if solver.solution[node_vars[node]] > 0.5)
    edges = sum(
        1 for u, v in graph.edges if solver.solution[edge_vars[(u, v)]] > 0.5
    )

    print(f"Solution graph\t\t{nodes:3} nodes\t{edges:3} edges")