Spaces:
Running on Zero
Running on Zero
File size: 4,923 Bytes
86072ea aff3c6f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 | # Modified from Trackastra (https://github.com/weigertlab/trackastra)
import logging
import time
from types import SimpleNamespace
import networkx as nx
import yaml
try:
import motile
except ModuleNotFoundError:
raise ModuleNotFoundError(
"For tracking with an ILP, please conda install the optional `motile`"
" dependency following https://funkelab.github.io/motile/install.html."
)
logger = logging.getLogger(__name__)
ILP_CONFIGS = {
"gt": SimpleNamespace(
nodeW=0,
nodeC=-10, # take all nodes
edgeW=-1,
edgeC=0,
appearC=0.25,
disappearC=0.5,
splitC=0.25,
),
"deepcell_gt": SimpleNamespace(
nodeW=0,
nodeC=-10, # take all nodes
edgeW=-1,
edgeC=0,
appearC=0.25,
disappearC=0.5,
splitC=1,
),
"deepcell_gt_tuned": SimpleNamespace(
nodeW=0,
nodeC=-10, # take all nodes
edgeW=-1,
edgeC=0,
appearC=0.5,
disappearC=0.5,
splitC=1,
),
"deepcell_res_tuned": SimpleNamespace(
nodeW=0,
nodeC=0.25,
edgeW=-1,
edgeC=-0.25,
appearC=0.25,
disappearC=0.25,
splitC=1.0,
),
}
def track_ilp(
    candidate_graph,
    allow_divisions: bool = True,
    ilp_config: str = "gt",
    params_file: str | None = None,
    **kwargs,
):
    """Solve the tracking ILP on a candidate graph and return the chosen subgraph.

    Args:
        candidate_graph: Graph of detection candidates, handed to
            ``motile.TrackGraph`` with ``frame_attribute="time"``, so nodes
            must carry a ``"time"`` attribute. Node and edge costs are read
            from each element's ``"weight"`` attribute.
        allow_divisions: Whether a track may split into two children.
        ilp_config: Name of a preset in ``ILP_CONFIGS``; ignored when
            ``params_file`` is given.
        params_file: Optional path to a YAML file with custom cost parameters.
        **kwargs: Accepted for call-site compatibility and otherwise ignored.

    Returns:
        A ``networkx.DiGraph`` holding the selected nodes and edges, with their
        attributes copied from the track graph.

    Side effects:
        Prints candidate/solution statistics to stdout.
    """
    track_graph = motile.TrackGraph(candidate_graph, frame_attribute="time")
    solver, _ = solve_full_ilp(
        track_graph,
        allow_divisions=allow_divisions,
        mode=ilp_config,
        params_file=params_file,
    )
    print_solution_stats(solver, track_graph)
    return solution_to_graph(solver, track_graph)
def _load_ilp_params(mode: str, params_file: str | None, allow_divisions: bool):
    """Return ILP cost parameters: from ``params_file`` if given, else preset ``mode``.

    Raises:
        ValueError: If the file does not hold a YAML mapping, lacks a required
            cost key, or (without a file) ``mode`` is not in ``ILP_CONFIGS``.
    """
    if params_file:
        with open(params_file, encoding="utf-8") as f:
            # safe_load: the parameter file is plain data, never arbitrary objects.
            raw = yaml.safe_load(f)
        if not isinstance(raw, dict):
            # An empty file yields None; a scalar or list yields non-mappings.
            raise ValueError(
                f"ILP parameter file {params_file!r} must contain a YAML mapping"
                f" of cost names to values, got {type(raw).__name__}."
            )
        required = ["nodeW", "nodeC", "edgeW", "edgeC", "appearC", "disappearC"]
        if allow_divisions:
            # splitC is only read when divisions are enabled.
            required.append("splitC")
        missing = [key for key in required if key not in raw]
        if missing:
            raise ValueError(
                f"ILP parameter file {params_file!r} is missing required keys"
                f" {missing}."
            )
        params = SimpleNamespace(**raw)
        logger.info("Using ILP parameters %s", params)
        return params

    try:
        params = ILP_CONFIGS[mode]
    except KeyError:
        # `from None`: the KeyError adds nothing beyond this message.
        raise ValueError(
            f"Unknown ILP mode {mode}. Choose from {list(ILP_CONFIGS.keys())} or"
            " supply custom parameters via `params_file` argument."
        ) from None
    logger.info("Using `%s` ILP config.", mode)
    return params


def solve_full_ilp(
    graph,
    allow_divisions: bool,
    mode: str,
    params_file: str | None,
):
    """Build and solve the tracking ILP on a motile track graph.

    Cost parameters come from the YAML file ``params_file`` when it is given,
    otherwise from the named preset ``ILP_CONFIGS[mode]``.

    Args:
        graph: ``motile.TrackGraph`` of candidates. Node and edge selection
            costs use each element's ``"weight"`` attribute.
        allow_divisions: If True, add a split cost and allow up to two
            children per node; otherwise at most one child.
        mode: Name of a preset in ``ILP_CONFIGS`` (ignored if ``params_file``).
        params_file: Optional path to a YAML mapping with keys ``nodeW``,
            ``nodeC``, ``edgeW``, ``edgeC``, ``appearC``, ``disappearC``, plus
            ``splitC`` when ``allow_divisions`` is True.

    Returns:
        ``(solver, used_costs)``: the solved ``motile.Solver`` and a dict of
        the cost values that were actually applied (``splitC`` only present
        when divisions are allowed).

    Raises:
        ValueError: Unknown ``mode`` (without ``params_file``), or a parameter
            file that is not a mapping or lacks a required key.
    """
    # Validate parameters before building the solver so bad input fails fast.
    p = _load_ilp_params(mode, params_file, allow_divisions)

    solver = motile.Solver(graph)
    used_costs = {}

    # NODES
    solver.add_cost(
        motile.costs.NodeSelection(weight=p.nodeW, constant=p.nodeC, attribute="weight")
    )
    used_costs["nodeW"] = p.nodeW
    used_costs["nodeC"] = p.nodeC

    # EDGES
    solver.add_cost(
        motile.costs.EdgeSelection(weight=p.edgeW, constant=p.edgeC, attribute="weight")
    )
    used_costs["edgeW"] = p.edgeW
    used_costs["edgeC"] = p.edgeC

    # APPEAR / DISAPPEAR
    solver.add_cost(motile.costs.Appear(constant=p.appearC))
    used_costs["appearC"] = p.appearC
    solver.add_cost(motile.costs.Disappear(constant=p.disappearC))
    used_costs["disappearC"] = p.disappearC

    # DIVISION
    if allow_divisions:
        solver.add_cost(motile.costs.Split(constant=p.splitC))
        used_costs["splitC"] = p.splitC

    # Each node has at most one parent, and at most two children only when
    # divisions are allowed.
    solver.add_constraint(motile.constraints.MaxParents(1))
    solver.add_constraint(motile.constraints.MaxChildren(2 if allow_divisions else 1))

    solver.solve()
    return solver, used_costs
def solution_to_graph(solver, base_graph):
    """Materialize the solver's selection as a new ``networkx.DiGraph``.

    Nodes and edges whose selection indicator in ``solver.solution`` exceeds
    0.5 are copied, together with their attributes, from ``base_graph``
    (the ``motile.TrackGraph`` the solver was built on).

    Args:
        solver: A solved ``motile.Solver``.
        base_graph: Track graph supplying node and edge attribute dicts.

    Returns:
        A new ``networkx.DiGraph`` containing only the selected elements.
    """
    node_vars = solver.get_variables(motile.variables.NodeSelected)
    edge_vars = solver.get_variables(motile.variables.EdgeSelected)

    def _chosen(indicators):
        # Keys whose indicator variable is "on" in the solution.
        return [key for key, var_index in indicators.items() if solver.solution[var_index] > 0.5]

    selected = nx.DiGraph()
    for node_id in _chosen(node_vars):
        selected.add_node(node_id, **base_graph.nodes[node_id])
    for edge_key in _chosen(edge_vars):
        selected.add_edge(*edge_key, **base_graph.edges[edge_key])
    return selected
def print_solution_stats(solver, graph, gt_graph=None):
    """Print node/edge counts for the candidate graph, optional GT, and solution.

    Args:
        solver: A solved ``motile.Solver``.
        graph: The candidate graph the solver was built on. Its ``edges``
            iterate as ``(u, v)`` pairs.
        gt_graph: Optional ground-truth graph; its counts are printed whenever
            it is provided, even if it is empty.
    """
    # Brief pause so output printed by the ILP backend (ilpy) appears first.
    time.sleep(0.1)
    print(
        f"\nCandidate graph\t\t{len(graph.nodes):3} nodes\t{len(graph.edges):3} edges"
    )
    # Explicit None check: an empty networkx graph is falsy and would
    # otherwise be skipped silently.
    if gt_graph is not None:
        print(
            f"Ground truth graph\t{len(gt_graph.nodes):3}"
            f" nodes\t{len(gt_graph.edges):3} edges"
        )

    node_selected = solver.get_variables(motile.variables.NodeSelected)
    edge_selected = solver.get_variables(motile.variables.EdgeSelected)
    solution = solver.solution
    # Indicators are compared against 0.5, presumably to tolerate non-exact
    # (e.g. 0.9999) solver values.
    n_nodes = sum(1 for node in graph.nodes if solution[node_selected[node]] > 0.5)
    n_edges = sum(
        1 for u, v in graph.edges if solution[edge_selected[(u, v)]] > 0.5
    )
    print(f"Solution graph\t\t{n_nodes:3} nodes\t{n_edges:3} edges")
|