kaveh commited on
Commit
effe3d2
·
1 Parent(s): 31a6a59

added for inference on standalone bright fields

Browse files
Files changed (1) hide show
  1. utils/inference.py +180 -0
utils/inference.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Inference helpers: run single-cell model on batches, plot samples, save all predictions with metrics.
3
+ """
4
+ import os
5
+ import torch
6
+ import numpy as np
7
+ import matplotlib.pyplot as plt
8
+
9
+ from utils import config
10
+ from utils.metrics import detect_tanh_output_model, convert_tanh_to_sigmoid_range
11
+
12
+
13
+ def run_batch_singlecell(bf_batch, generator, device, substrate, norm_params, config_path=None):
14
+ """Run single-cell generator on a batch of 1-channel BF images; returns CPU predictions [0,1]."""
15
+ from models.s2f_model import create_settings_channels
16
+ meta = {"substrate": [substrate] * bf_batch.shape[0]}
17
+ settings_ch = create_settings_channels(meta, norm_params, device, bf_batch.shape, config_path=config_path)
18
+ inp = torch.cat([bf_batch.to(device, dtype=torch.float32), settings_ch], dim=1)
19
+ with torch.no_grad():
20
+ pred = generator(inp)
21
+ if detect_tanh_output_model(generator):
22
+ pred = convert_tanh_to_sigmoid_range(pred)
23
+ return pred.cpu()
24
+
25
+
26
+ def force_sum_and_stats(heatmap_np):
27
+ """Return (force_scaled, pixel_sum, max, mean) for a prediction heatmap."""
28
+ pixel_sum = float(np.sum(heatmap_np))
29
+ force_scaled = pixel_sum * config.SCALE_FACTOR_FORCE
30
+ return force_scaled, pixel_sum, float(np.max(heatmap_np)), float(np.mean(heatmap_np))
31
+
32
+
33
+ def plot_inference_samples(loader, generator, n_samples=3, device=None, substrate='fibroblasts_PDMS',
34
+ normalization_params=None, config_path=None):
35
+ """Plot BF | Prediction for first n_samples from loader (no ground truth)."""
36
+ from utils.substrate_settings import compute_settings_normalization
37
+ device = device or next(generator.parameters()).device
38
+ generator.eval()
39
+ norm = normalization_params or compute_settings_normalization(config_path=config_path)
40
+ bf_list = []
41
+ it = iter(loader)
42
+ while len(bf_list) < n_samples:
43
+ try:
44
+ batch = next(it)
45
+ except StopIteration:
46
+ break
47
+ for i in range(batch.shape[0]):
48
+ if len(bf_list) >= n_samples:
49
+ break
50
+ bf_list.append(batch[i])
51
+ n = len(bf_list)
52
+ if n == 0:
53
+ print("No samples in loader.")
54
+ return
55
+ bf_batch = torch.stack(bf_list)
56
+ pred = run_batch_singlecell(bf_batch, generator, device, substrate, norm, config_path)
57
+ fig, axes = plt.subplots(n, 2, figsize=(8, 4 * n))
58
+ if n == 1:
59
+ axes = axes.reshape(1, -1)
60
+ for i in range(n):
61
+ axes[i, 0].imshow(bf_list[i].squeeze().numpy(), cmap='gray')
62
+ axes[i, 0].set_title('Bright field')
63
+ axes[i, 0].axis('off')
64
+ axes[i, 1].imshow(pred[i].squeeze().numpy(), cmap='jet', vmin=0, vmax=1)
65
+ axes[i, 1].set_title('Prediction')
66
+ axes[i, 1].axis('off')
67
+ plt.tight_layout()
68
+ plt.show()
69
+
70
+
71
+ def save_all_predictions(loader, generator, save_dir, device=None, substrate='fibroblasts_PDMS',
72
+ normalization_params=None, config_path=None):
73
+ """Run inference on full loader; save each image as BF | Prediction with force/pixel stats."""
74
+ from utils.substrate_settings import compute_settings_normalization
75
+ device = device or next(generator.parameters()).device
76
+ generator.eval()
77
+ norm = normalization_params or compute_settings_normalization(config_path=config_path)
78
+ paths = loader.dataset.paths
79
+ os.makedirs(save_dir, exist_ok=True)
80
+ idx = 0
81
+ with torch.no_grad():
82
+ for batch in loader:
83
+ pred = run_batch_singlecell(batch, generator, device, substrate, norm, config_path)
84
+ for i in range(pred.shape[0]):
85
+ bf_np = batch[i].squeeze().numpy()
86
+ pred_np = pred[i].squeeze().numpy()
87
+ force_scaled, pixel_sum, hm_max, hm_mean = force_sum_and_stats(pred_np)
88
+ fig, axes = plt.subplots(1, 2, figsize=(10, 5))
89
+ axes[0].imshow(bf_np, cmap='gray')
90
+ axes[0].set_title('Bright field')
91
+ axes[0].axis('off')
92
+ axes[1].imshow(pred_np, cmap='jet', vmin=0, vmax=1)
93
+ axes[1].set_title('Prediction')
94
+ axes[1].axis('off')
95
+ stats = (f'Force (scaled): {force_scaled:.4f} | Pixel sum: {pixel_sum:.2f} | '
96
+ f'Max: {hm_max:.4f} | Mean: {hm_mean:.4f}')
97
+ fig.suptitle(stats, fontsize=10)
98
+ plt.tight_layout()
99
+ name = os.path.splitext(os.path.basename(paths[idx]))[0] + '_pred.png'
100
+ plt.savefig(os.path.join(save_dir, name), dpi=150, bbox_inches='tight')
101
+ plt.close()
102
+ idx += 1
103
+ print(f"Saved {idx} predictions to {save_dir}")
104
+
105
+
106
+ def run_batch_spheroid(bf_batch, generator, device):
107
+ """Run spheroid generator on a batch of 1-channel BF images (no settings); returns CPU predictions [0,1]."""
108
+ inp = bf_batch.to(device, dtype=torch.float32)
109
+ with torch.no_grad():
110
+ pred = generator(inp)
111
+ if detect_tanh_output_model(generator):
112
+ pred = convert_tanh_to_sigmoid_range(pred)
113
+ return pred.cpu()
114
+
115
+
116
+ def plot_inference_samples_spheroid(loader, generator, n_samples=3, device=None):
117
+ """Plot BF | Prediction for first n_samples from loader (spheroid, no ground truth)."""
118
+ device = device or next(generator.parameters()).device
119
+ generator.eval()
120
+ bf_list = []
121
+ it = iter(loader)
122
+ while len(bf_list) < n_samples:
123
+ try:
124
+ batch = next(it)
125
+ except StopIteration:
126
+ break
127
+ for i in range(batch.shape[0]):
128
+ if len(bf_list) >= n_samples:
129
+ break
130
+ bf_list.append(batch[i])
131
+ n = len(bf_list)
132
+ if n == 0:
133
+ print("No samples in loader.")
134
+ return
135
+ bf_batch = torch.stack(bf_list)
136
+ pred = run_batch_spheroid(bf_batch, generator, device)
137
+ fig, axes = plt.subplots(n, 2, figsize=(8, 4 * n))
138
+ if n == 1:
139
+ axes = axes.reshape(1, -1)
140
+ for i in range(n):
141
+ axes[i, 0].imshow(bf_list[i].squeeze().numpy(), cmap='gray')
142
+ axes[i, 0].set_title('Bright field')
143
+ axes[i, 0].axis('off')
144
+ axes[i, 1].imshow(pred[i].squeeze().numpy(), cmap='jet', vmin=0, vmax=1)
145
+ axes[i, 1].set_title('Prediction')
146
+ axes[i, 1].axis('off')
147
+ plt.tight_layout()
148
+ plt.show()
149
+
150
+
151
+ def save_all_predictions_spheroid(loader, generator, save_dir, device=None):
152
+ """Run spheroid inference on full loader; save each image as BF | Prediction with force/pixel stats."""
153
+ device = device or next(generator.parameters()).device
154
+ generator.eval()
155
+ paths = loader.dataset.paths
156
+ os.makedirs(save_dir, exist_ok=True)
157
+ idx = 0
158
+ with torch.no_grad():
159
+ for batch in loader:
160
+ pred = run_batch_spheroid(batch, generator, device)
161
+ for i in range(pred.shape[0]):
162
+ bf_np = batch[i].squeeze().numpy()
163
+ pred_np = pred[i].squeeze().numpy()
164
+ force_scaled, pixel_sum, hm_max, hm_mean = force_sum_and_stats(pred_np)
165
+ fig, axes = plt.subplots(1, 2, figsize=(10, 5))
166
+ axes[0].imshow(bf_np, cmap='gray')
167
+ axes[0].set_title('Bright field')
168
+ axes[0].axis('off')
169
+ axes[1].imshow(pred_np, cmap='jet', vmin=0, vmax=1)
170
+ axes[1].set_title('Prediction')
171
+ axes[1].axis('off')
172
+ stats = (f'Force (scaled): {force_scaled:.4f} | Pixel sum: {pixel_sum:.2f} | '
173
+ f'Max: {hm_max:.4f} | Mean: {hm_mean:.4f}')
174
+ fig.suptitle(stats, fontsize=10)
175
+ plt.tight_layout()
176
+ name = os.path.splitext(os.path.basename(paths[idx]))[0] + '_pred.png'
177
+ plt.savefig(os.path.join(save_dir, name), dpi=150, bbox_inches='tight')
178
+ plt.close()
179
+ idx += 1
180
+ print(f"Saved {idx} predictions to {save_dir}")