File size: 8,001 Bytes
2b9ff22 fbe1a2c 2b9ff22 fbe1a2c 2b9ff22 fbe1a2c 2b9ff22 fbe1a2c 2b9ff22 fbe1a2c 2b9ff22 fbe1a2c 2b9ff22 fbe1a2c 2b9ff22 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 | """
Core inference logic for S2F (Shape2Force).
Predicts force maps from bright field microscopy images.
"""
import os
import sys
import cv2
import torch
import numpy as np
# Ensure S2F is in path when running from project root or S2F
S2F_ROOT = os.path.dirname(os.path.abspath(__file__))
if S2F_ROOT not in sys.path:
sys.path.insert(0, S2F_ROOT)
from models.s2f_model import create_s2f_model
from utils.substrate_settings import get_settings_of_category, compute_settings_normalization
from utils import config
def load_image(filepath, target_size=1024):
"""Load and preprocess a bright field image."""
img = cv2.imread(filepath, cv2.IMREAD_GRAYSCALE)
if img is None:
raise ValueError(f"Could not load image: {filepath}")
if isinstance(target_size, int):
target_size = (target_size, target_size)
img = cv2.resize(img, target_size)
img = img.astype(np.float32) / 255.0
return img
def sum_force_map(force_map):
"""Compute cell force as sum of pixel values scaled by SCALE_FACTOR_FORCE."""
if isinstance(force_map, np.ndarray):
force_map = torch.from_numpy(force_map.astype(np.float32))
if force_map.dim() == 2:
force_map = force_map.unsqueeze(0).unsqueeze(0) # [1, 1, H, W]
elif force_map.dim() == 3:
force_map = force_map.unsqueeze(0) # [1, 1, H, W]
# force_map: [B, 1, H, W], sum over spatial dims (2, 3)
return torch.sum(force_map, dim=(2, 3)) * config.SCALE_FACTOR_FORCE
def create_settings_channels_single(substrate_name, device, height, width, config_path=None,
substrate_config=None):
"""
Create settings channels for a single image (single-cell mode).
Args:
substrate_name: Substrate name (used if substrate_config is None)
device: torch device
height, width: spatial dimensions
config_path: Path to substrate config JSON
substrate_config: Optional dict with 'pixelsize' and 'young'. If provided, overrides substrate_name.
"""
norm_params = compute_settings_normalization(config_path=config_path)
if substrate_config is not None and 'pixelsize' in substrate_config and 'young' in substrate_config:
settings = substrate_config
else:
settings = get_settings_of_category(substrate_name, config_path=config_path)
pmin, pmax = norm_params['pixelsize']['min'], norm_params['pixelsize']['max']
ymin, ymax = norm_params['young']['min'], norm_params['young']['max']
pixelsize_norm = (settings['pixelsize'] - pmin) / (pmax - pmin) if pmax > pmin else 0.5
young_norm = (settings['young'] - ymin) / (ymax - ymin) if ymax > ymin else 0.5
pixelsize_norm = max(0.0, min(1.0, pixelsize_norm))
young_norm = max(0.0, min(1.0, young_norm))
pixelsize_ch = torch.full(
(1, 1, height, width), pixelsize_norm, device=device, dtype=torch.float32
)
young_ch = torch.full(
(1, 1, height, width), young_norm, device=device, dtype=torch.float32
)
return torch.cat([pixelsize_ch, young_ch], dim=1)
class S2FPredictor:
"""
Shape2Force predictor for single-cell or spheroid force map prediction.
"""
def __init__(self, model_type="single_cell", checkpoint_path=None, ckp_folder=None, device=None):
"""
Args:
model_type: "single_cell" or "spheroid"
checkpoint_path: Path to .pth checkpoint (relative to ckp_folder or absolute)
ckp_folder: Folder containing checkpoints (default: S2F/ckp)
device: "cuda" or "cpu" (auto-detected if None)
"""
self.model_type = model_type
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
ckp_base = os.path.join(S2F_ROOT, "ckp")
if not os.path.isdir(ckp_base):
project_root = os.path.dirname(S2F_ROOT)
if os.path.isdir(os.path.join(project_root, "ckp")):
ckp_base = os.path.join(project_root, "ckp")
subfolder = "single_cell" if model_type == "single_cell" else "spheroid"
ckp_dir = ckp_folder if ckp_folder else os.path.join(ckp_base, subfolder)
if not os.path.isdir(ckp_dir):
ckp_dir = ckp_base # fallback if subfolders not used
in_channels = 3 if model_type == "single_cell" else 1
s2f_model_type = "s2f" if model_type == "single_cell" else "s2f_spheroid"
generator, _ = create_s2f_model(in_channels=in_channels, model_type=s2f_model_type)
self.generator = generator
if checkpoint_path:
full_path = checkpoint_path
if not os.path.isabs(checkpoint_path):
full_path = os.path.join(ckp_dir, checkpoint_path)
if not os.path.exists(full_path):
full_path = os.path.join(ckp_base, checkpoint_path) # try base folder
if not os.path.exists(full_path):
raise FileNotFoundError(f"Checkpoint not found: {full_path}")
if model_type == "single_cell":
self.generator.load_checkpoint_with_expansion(full_path, strict=True)
else:
checkpoint = torch.load(full_path, map_location="cpu", weights_only=False)
state = checkpoint.get("generator_state_dict") or checkpoint.get("model_state_dict") or checkpoint
self.generator.load_state_dict(state, strict=True)
if hasattr(self.generator, "set_output_mode"):
self.generator.set_output_mode(use_tanh=False) # sigmoid [0,1] for inference
self.generator = self.generator.to(self.device)
self.generator.eval()
self.norm_params = compute_settings_normalization() if model_type == "single_cell" else None
self._use_tanh_output = model_type == "single_cell" # single_cell uses tanh, spheroid uses sigmoid
self.config_path = os.path.join(S2F_ROOT, "config", "substrate_settings.json")
def predict(self, image_path=None, image_array=None, substrate="fibroblasts_PDMS",
substrate_config=None):
"""
Run prediction on an image.
Args:
image_path: Path to bright field image (tif, png, jpg)
image_array: numpy array (H, W) or (H, W, C) in [0, 255] or [0, 1]
substrate: Substrate name for single-cell mode (used if substrate_config is None)
substrate_config: Optional dict with 'pixelsize' and 'young'. Overrides substrate lookup.
Returns:
heatmap: numpy array (1024, 1024) in [0, 1]
force: scalar cell force (sum of heatmap * SCALE_FACTOR_FORCE)
pixel_sum: raw sum of all pixel values in heatmap
"""
if image_path is not None:
img = load_image(image_path)
elif image_array is not None:
img = np.asarray(image_array, dtype=np.float32)
if img.ndim == 3:
img = img[:, :, 0] if img.shape[-1] >= 1 else img
if img.max() > 1.0:
img = img / 255.0
img = cv2.resize(img, (1024, 1024))
else:
raise ValueError("Provide image_path or image_array")
x = torch.from_numpy(img).float().unsqueeze(0).unsqueeze(0).to(self.device) # [1,1,H,W]
if self.model_type == "single_cell" and self.norm_params is not None:
settings_ch = create_settings_channels_single(
substrate, self.device, x.shape[2], x.shape[3],
config_path=self.config_path, substrate_config=substrate_config
)
x = torch.cat([x, settings_ch], dim=1) # [1,3,H,W]
with torch.no_grad():
pred = self.generator(x)
if self._use_tanh_output:
pred = (pred + 1.0) / 2.0 # Tanh [-1,1] to [0, 1]
# else: spheroid already outputs sigmoid [0, 1]
heatmap = pred[0, 0].cpu().numpy()
force = sum_force_map(pred).item()
pixel_sum = float(np.sum(heatmap))
return heatmap, force, pixel_sum
|