File size: 7,860 Bytes
effe3d2 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 | """
Inference helpers: run the single-cell and spheroid generators on batches, plot sample predictions, and save all predictions with force/pixel statistics.
"""
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from utils import config
from utils.metrics import detect_tanh_output_model, convert_tanh_to_sigmoid_range
def run_batch_singlecell(bf_batch, generator, device, substrate, norm_params, config_path=None):
    """Run the single-cell generator on a batch of bright-field images.

    The bright-field batch is cast to float32 on ``device`` and concatenated
    along dim 1 with the channels returned by ``create_settings_channels``
    (built with the same ``substrate`` repeated for every item in the batch).
    The forward pass runs with autograd disabled.

    Args:
        bf_batch: Bright-field batch tensor; its first dimension is the batch size.
        generator: Callable model taking the concatenated input.
        device: Device the input and settings channels are placed on.
        substrate: Substrate name applied to every sample in the batch.
        norm_params: Normalization parameters forwarded to ``create_settings_channels``.
        config_path: Optional config path forwarded to ``create_settings_channels``.

    Returns:
        The prediction tensor on CPU. When ``detect_tanh_output_model`` reports a
        tanh-output generator, the prediction is first passed through
        ``convert_tanh_to_sigmoid_range`` (intended to give values in [0, 1]).
    """
    from models.s2f_model import create_settings_channels

    batch_size = bf_batch.shape[0]
    metadata = {"substrate": [substrate for _ in range(batch_size)]}
    settings_channels = create_settings_channels(
        metadata, norm_params, device, bf_batch.shape, config_path=config_path
    )
    bf_on_device = bf_batch.to(device, dtype=torch.float32)
    model_input = torch.cat([bf_on_device, settings_channels], dim=1)

    with torch.no_grad():
        output = generator(model_input)

    # Only tanh-output generators need rescaling; others are returned as-is.
    if not detect_tanh_output_model(generator):
        return output.cpu()
    return convert_tanh_to_sigmoid_range(output).cpu()
def force_sum_and_stats(heatmap_np):
    """Summarize a prediction heatmap.

    Args:
        heatmap_np: Array-like heatmap accepted by ``np.sum``/``np.max``/``np.mean``.

    Returns:
        Tuple ``(force_scaled, pixel_sum, max, mean)`` of Python floats, where
        ``force_scaled`` is ``pixel_sum * config.SCALE_FACTOR_FORCE``.
    """
    total = float(np.sum(heatmap_np))
    scaled_force = total * config.SCALE_FACTOR_FORCE
    peak = float(np.max(heatmap_np))
    average = float(np.mean(heatmap_np))
    return scaled_force, total, peak, average
def plot_inference_samples(loader, generator, n_samples=3, device=None, substrate='fibroblasts_PDMS',
                           normalization_params=None, config_path=None):
    """Show bright-field inputs next to single-cell predictions (no ground truth).

    Takes the first ``n_samples`` images the loader yields (fewer if the loader
    runs out), runs them through ``run_batch_singlecell`` as one batch, and
    displays an ``n x 2`` grid: bright field on the left, prediction on the
    right (``jet`` colormap, color range fixed to 0..1).

    Args:
        loader: Iterable of image batches indexable along the first dimension.
        generator: Model to run; switched to eval mode by this function.
        n_samples: Maximum number of images to plot.
        device: Inference device; when falsy, the device of the generator's
            first parameter is used.
        substrate: Substrate name forwarded to ``run_batch_singlecell``.
        normalization_params: Settings normalization; when falsy it is obtained
            from ``compute_settings_normalization(config_path=config_path)``.
        config_path: Optional config path forwarded to the helpers above.

    Returns:
        None. Prints a message and returns early when no sample was collected.
    """
    from utils.substrate_settings import compute_settings_normalization

    if not device:
        device = next(generator.parameters()).device
    generator.eval()
    if normalization_params:
        norm = normalization_params
    else:
        norm = compute_settings_normalization(config_path=config_path)

    # Pull batches lazily so no more of the loader is consumed than needed.
    samples = []
    batch_stream = iter(loader)
    exhausted = object()
    while len(samples) < n_samples:
        batch = next(batch_stream, exhausted)
        if batch is exhausted:
            break
        for image in batch:
            if len(samples) >= n_samples:
                break
            samples.append(image)

    if not samples:
        print("No samples in loader.")
        return

    count = len(samples)
    stacked = torch.stack(samples)
    pred = run_batch_singlecell(stacked, generator, device, substrate, norm, config_path)

    # squeeze=False keeps ``axes`` two-dimensional even for a single row.
    fig, axes = plt.subplots(count, 2, figsize=(8, 4 * count), squeeze=False)
    for row, bf_image in enumerate(samples):
        bf_ax, pred_ax = axes[row]

        bf_ax.imshow(bf_image.squeeze().numpy(), cmap='gray')
        bf_ax.set_title('Bright field')
        bf_ax.axis('off')

        pred_ax.imshow(pred[row].squeeze().numpy(), cmap='jet', vmin=0, vmax=1)
        pred_ax.set_title('Prediction')
        pred_ax.axis('off')

    plt.tight_layout()
    plt.show()
def save_all_predictions(loader, generator, save_dir, device=None, substrate='fibroblasts_PDMS',
                         normalization_params=None, config_path=None):
    """Run single-cell inference over the whole loader and save one PNG per image.

    Each PNG shows the bright-field input next to the predicted heatmap
    (``jet`` colormap, color range fixed to 0..1) with the values returned by
    ``force_sum_and_stats`` in the figure title. Files are written to
    ``save_dir`` as ``<stem>_pred.png``, where ``<stem>`` is the basename of
    ``loader.dataset.paths[k]`` without extension and ``k`` counts images in
    the order the loader yields them.

    NOTE(review): the k-th yielded image is paired with ``paths[k]``, so this
    assumes the loader iterates the dataset in order (no shuffling) — confirm
    at call sites. Inputs whose paths share a basename overwrite each other.

    Args:
        loader: Iterable of image batches; must expose ``loader.dataset.paths``.
        generator: Model to run; switched to eval mode by this function.
        save_dir: Output directory; created if it does not exist.
        device: Inference device; when falsy, the device of the generator's
            first parameter is used.
        substrate: Substrate name forwarded to ``run_batch_singlecell``.
        normalization_params: Settings normalization; when falsy it is obtained
            from ``compute_settings_normalization(config_path=config_path)``.
        config_path: Optional config path forwarded to the helpers above.

    Returns:
        None. Prints the number of saved predictions when done.
    """
    from utils.substrate_settings import compute_settings_normalization

    device = device or next(generator.parameters()).device
    generator.eval()
    norm = normalization_params or compute_settings_normalization(config_path=config_path)
    paths = loader.dataset.paths
    os.makedirs(save_dir, exist_ok=True)

    idx = 0
    for batch in loader:
        # run_batch_singlecell performs the forward pass under torch.no_grad()
        # and returns CPU tensors, so no extra no_grad context is needed here.
        pred = run_batch_singlecell(batch, generator, device, substrate, norm, config_path)
        # Make the inputs safe for .numpy() even if the loader yields tensors
        # that live on another device or are attached to a graph.
        bf_cpu = batch.detach().cpu()

        for i in range(pred.shape[0]):
            bf_np = bf_cpu[i].squeeze().numpy()
            pred_np = pred[i].squeeze().numpy()
            force_scaled, pixel_sum, hm_max, hm_mean = force_sum_and_stats(pred_np)
            stats = (f'Force (scaled): {force_scaled:.4f} | Pixel sum: {pixel_sum:.2f} | '
                     f'Max: {hm_max:.4f} | Mean: {hm_mean:.4f}')
            name = os.path.splitext(os.path.basename(paths[idx]))[0] + '_pred.png'

            fig, axes = plt.subplots(1, 2, figsize=(10, 5))
            try:
                axes[0].imshow(bf_np, cmap='gray')
                axes[0].set_title('Bright field')
                axes[0].axis('off')
                axes[1].imshow(pred_np, cmap='jet', vmin=0, vmax=1)
                axes[1].set_title('Prediction')
                axes[1].axis('off')
                fig.suptitle(stats, fontsize=10)
                fig.tight_layout()
                fig.savefig(os.path.join(save_dir, name), dpi=150, bbox_inches='tight')
            finally:
                # Close this exact figure even when drawing/saving fails, so
                # figures never pile up across a long loader.
                plt.close(fig)
            idx += 1

    print(f"Saved {idx} predictions to {save_dir}")
def run_batch_spheroid(bf_batch, generator, device):
    """Run the spheroid generator on a batch of bright-field images.

    Unlike the single-cell variant, no settings channels are appended: the
    batch is cast to float32 on ``device`` and fed to the generator directly,
    with autograd disabled.

    Args:
        bf_batch: Bright-field batch tensor.
        generator: Callable model taking the bright-field batch.
        device: Device the input is placed on.

    Returns:
        The prediction tensor on CPU. When ``detect_tanh_output_model`` reports a
        tanh-output generator, the prediction is first passed through
        ``convert_tanh_to_sigmoid_range`` (intended to give values in [0, 1]).
    """
    model_input = bf_batch.to(device, dtype=torch.float32)

    with torch.no_grad():
        raw_output = generator(model_input)

    needs_rescale = detect_tanh_output_model(generator)
    result = convert_tanh_to_sigmoid_range(raw_output) if needs_rescale else raw_output
    return result.cpu()
def plot_inference_samples_spheroid(loader, generator, n_samples=3, device=None):
    """Show bright-field inputs next to spheroid predictions (no ground truth).

    Takes the first ``n_samples`` images the loader yields (fewer if the loader
    runs out), runs them through ``run_batch_spheroid`` as one batch, and
    displays an ``n x 2`` grid: bright field on the left, prediction on the
    right (``jet`` colormap, color range fixed to 0..1).

    Args:
        loader: Iterable of image batches indexable along the first dimension.
        generator: Model to run; switched to eval mode by this function.
        n_samples: Maximum number of images to plot.
        device: Inference device; when falsy, the device of the generator's
            first parameter is used.

    Returns:
        None. Prints a message and returns early when no sample was collected.
    """
    if not device:
        device = next(generator.parameters()).device
    generator.eval()

    # Pull batches lazily so no more of the loader is consumed than needed.
    collected = []
    batch_stream = iter(loader)
    exhausted = object()
    while len(collected) < n_samples:
        batch = next(batch_stream, exhausted)
        if batch is exhausted:
            break
        for image in batch:
            if len(collected) >= n_samples:
                break
            collected.append(image)

    if not collected:
        print("No samples in loader.")
        return

    count = len(collected)
    pred = run_batch_spheroid(torch.stack(collected), generator, device)

    # squeeze=False keeps ``axes`` two-dimensional even for a single row.
    fig, axes = plt.subplots(count, 2, figsize=(8, 4 * count), squeeze=False)
    for row, bf_image in enumerate(collected):
        left, right = axes[row]

        left.imshow(bf_image.squeeze().numpy(), cmap='gray')
        left.set_title('Bright field')
        left.axis('off')

        right.imshow(pred[row].squeeze().numpy(), cmap='jet', vmin=0, vmax=1)
        right.set_title('Prediction')
        right.axis('off')

    plt.tight_layout()
    plt.show()
def save_all_predictions_spheroid(loader, generator, save_dir, device=None):
    """Run spheroid inference over the whole loader and save one PNG per image.

    Each PNG shows the bright-field input next to the predicted heatmap
    (``jet`` colormap, color range fixed to 0..1) with the values returned by
    ``force_sum_and_stats`` in the figure title. Files are written to
    ``save_dir`` as ``<stem>_pred.png``, where ``<stem>`` is the basename of
    ``loader.dataset.paths[k]`` without extension and ``k`` counts images in
    the order the loader yields them.

    NOTE(review): the k-th yielded image is paired with ``paths[k]``, so this
    assumes the loader iterates the dataset in order (no shuffling) — confirm
    at call sites. Inputs whose paths share a basename overwrite each other.

    Args:
        loader: Iterable of image batches; must expose ``loader.dataset.paths``.
        generator: Model to run; switched to eval mode by this function.
        save_dir: Output directory; created if it does not exist.
        device: Inference device; when falsy, the device of the generator's
            first parameter is used.

    Returns:
        None. Prints the number of saved predictions when done.
    """
    device = device or next(generator.parameters()).device
    generator.eval()
    paths = loader.dataset.paths
    os.makedirs(save_dir, exist_ok=True)

    idx = 0
    for batch in loader:
        # run_batch_spheroid performs the forward pass under torch.no_grad()
        # and returns CPU tensors, so no extra no_grad context is needed here.
        pred = run_batch_spheroid(batch, generator, device)
        # Make the inputs safe for .numpy() even if the loader yields tensors
        # that live on another device or are attached to a graph.
        bf_cpu = batch.detach().cpu()

        for i in range(pred.shape[0]):
            bf_np = bf_cpu[i].squeeze().numpy()
            pred_np = pred[i].squeeze().numpy()
            force_scaled, pixel_sum, hm_max, hm_mean = force_sum_and_stats(pred_np)
            stats = (f'Force (scaled): {force_scaled:.4f} | Pixel sum: {pixel_sum:.2f} | '
                     f'Max: {hm_max:.4f} | Mean: {hm_mean:.4f}')
            name = os.path.splitext(os.path.basename(paths[idx]))[0] + '_pred.png'

            fig, axes = plt.subplots(1, 2, figsize=(10, 5))
            try:
                axes[0].imshow(bf_np, cmap='gray')
                axes[0].set_title('Bright field')
                axes[0].axis('off')
                axes[1].imshow(pred_np, cmap='jet', vmin=0, vmax=1)
                axes[1].set_title('Prediction')
                axes[1].axis('off')
                fig.suptitle(stats, fontsize=10)
                fig.tight_layout()
                fig.savefig(os.path.join(save_dir, name), dpi=150, bbox_inches='tight')
            finally:
                # Close this exact figure even when drawing/saving fails, so
                # figures never pile up across a long loader.
                plt.close(fig)
            idx += 1

    print(f"Saved {idx} predictions to {save_dir}")
|