# Shape2Force / S2FApp / models / s2f_model.py
# Upstream Hugging Face commit de14db1 (kaveh): "added spheroid model"
"""
S2F (Shape2Force) model for force map prediction (inference only).
Supports single-cell and spheroid modes.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from .blocks import ResidualBlock
from .cbam import CBAM
from utils import config
from utils.substrate_settings import (
get_settings_of_category,
compute_settings_normalization,
load_substrate_config,
)
def normalize_settings(substrate_name, normalization_params, config=None, config_path=None):
    """
    Min-max normalize the pixel size and Young's modulus for a substrate.

    Args:
        substrate_name (str): Name of the substrate category to look up.
        normalization_params (dict): Per-setting bounds, e.g.
            {'pixelsize': {'min': ..., 'max': ...}, 'young': {'min': ..., 'max': ...}}.
        config (optional): Pre-loaded substrate configuration forwarded to
            ``get_settings_of_category``.
        config_path (optional): Path to a substrate configuration file, used by
            ``get_settings_of_category`` when ``config`` is not given.

    Returns:
        tuple: (normalized_pixelsize, normalized_young). Values fall in [0, 1]
        when the raw setting lies inside its configured min/max range.
    """
    settings = get_settings_of_category(substrate_name, config=config, config_path=config_path)

    def _min_max(value, bounds):
        # Linear rescale to [0, 1]; out-of-range inputs map outside [0, 1].
        return (value - bounds['min']) / (bounds['max'] - bounds['min'])

    pixelsize_norm = _min_max(settings['pixelsize'], normalization_params['pixelsize'])
    young_norm = _min_max(settings['young'], normalization_params['young'])
    return pixelsize_norm, young_norm
def create_settings_channels(metadata, normalization_params, device, image_shape, config_path=None):
    """
    Build per-sample constant settings channels for a batch of images.

    Args:
        metadata (dict): Batch metadata; ``metadata['substrate'][i]`` names the
            substrate of sample ``i``.
        normalization_params (dict): Min/max bounds used to normalize settings.
        device: Device on which the channel tensors are allocated.
        image_shape (tuple): Input image shape as (B, C, H, W).
        config_path (optional): Substrate configuration path forwarded to
            ``normalize_settings``.

    Returns:
        torch.Tensor: [B, 2, H, W] tensor; channel 0 carries the normalized
        pixel size and channel 1 the normalized Young's modulus, broadcast
        over the spatial plane of each sample.
    """
    batch_size, _, height, width = image_shape

    pixelsize_channel = torch.zeros(batch_size, 1, height, width, device=device)
    young_channel = torch.zeros(batch_size, 1, height, width, device=device)

    for idx in range(batch_size):
        pixel_norm, young_norm = normalize_settings(
            metadata['substrate'][idx], normalization_params, config_path=config_path
        )
        # Each scalar setting fills the whole spatial plane of its sample.
        pixelsize_channel[idx, 0] = pixel_norm
        young_channel[idx, 0] = young_norm

    return torch.cat((pixelsize_channel, young_channel), dim=1)  # [B, 2, H, W]
class GlobalContextModule(nn.Module):
    """Global context module for capturing cell shape information.

    Combines three residual branches: a globally-gated depthwise-separable
    branch, a dilated multi-scale pyramid, and the identity input.
    Submodule attribute names are kept for checkpoint compatibility.
    """

    def __init__(self, in_channels):
        super().__init__()
        reduced = in_channels // 4

        # Squeeze-and-excitation style channel gate from the pooled context.
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.global_conv = nn.Sequential(
            nn.Conv2d(in_channels, reduced, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(reduced, in_channels, 1),
            nn.Sigmoid()
        )

        # Depthwise 3x3 followed by a pointwise mix (separable convolution).
        self.large_kernel = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, 3, padding=1, groups=in_channels),
            nn.Conv2d(in_channels, in_channels, 1),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True)
        )

        # Four dilated branches (rates 1/2/4/8), each producing in_channels//4 maps.
        self.multi_scale = nn.ModuleList(
            nn.Conv2d(in_channels, reduced, 3, padding=rate, dilation=rate)
            for rate in (1, 2, 4, 8)
        )
        self.fusion = nn.Conv2d(in_channels, in_channels, 1)

    def forward(self, x):
        # Channel re-weighting derived from the globally pooled features.
        gate = self.global_conv(self.global_pool(x))
        local = self.large_kernel(x) * gate
        # Concatenate the dilated branches and fuse back to the input width.
        pyramid = self.fusion(torch.cat([branch(x) for branch in self.multi_scale], dim=1))
        return x + local + pyramid
class HierarchicalAttention(nn.Module):
    """Hierarchical attention combining spatial and channel attention.

    Applies a spatial mask and a channel gate to the input, then re-weights
    that attended tensor with a cross-attention gate; the result is added
    back to the input (residual). Attribute names are checkpoint-stable.
    """

    def __init__(self, channels):
        super().__init__()
        # Single-channel spatial saliency map.
        self.spatial_att = nn.Sequential(
            nn.Conv2d(channels, channels // 8, 1),
            nn.Conv2d(channels // 8, 1, 3, padding=1),
            nn.Sigmoid()
        )
        # Per-channel gate from globally pooled features.
        self.channel_att = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, channels // 16, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels // 16, channels, 1),
            nn.Sigmoid()
        )
        # Second-stage gate computed on the already-attended features.
        self.cross_att = nn.Sequential(
            nn.Conv2d(channels, channels // 4, 1),
            nn.BatchNorm2d(channels // 4),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels // 4, channels, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        weighted = x * self.spatial_att(x) * self.channel_att(x)
        return x + weighted * self.cross_att(weighted)
class EnhancedAttentionGate(nn.Module):
    """Enhanced attention gate with global context.

    Standard additive attention gate (projections of the gating signal ``g``
    and the skip features ``x`` summed into a mask) augmented by a global
    scalar gate pooled from ``x``. Handles resolution mismatches between
    ``g`` and ``x`` via bilinear interpolation. Attribute names match the
    checkpoint layout.
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()
        # 1x1 projections of gate and skip features into a shared space.
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1),
            nn.BatchNorm2d(F_int)
        )
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1),
            nn.BatchNorm2d(F_int)
        )
        # Deeper-than-usual mask head (extra 3x3 stage before the 1x1).
        self.psi = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.Conv2d(F_int, F_int // 2, kernel_size=3, padding=1),
            nn.BatchNorm2d(F_int // 2),
            nn.ReLU(inplace=True),
            nn.Conv2d(F_int // 2, 1, kernel_size=1),
            nn.Sigmoid()
        )
        # Per-sample scalar gate from globally pooled skip features.
        self.global_context = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(F_l, F_int // 4, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(F_int // 4, 1, 1),
            nn.Sigmoid()
        )

    def forward(self, g, x):
        g_proj = self.W_g(g)
        x_proj = self.W_x(x)
        # Align the gate to the skip resolution before the additive merge.
        if g_proj.shape[2:] != x_proj.shape[2:]:
            g_proj = F.interpolate(g_proj, size=x_proj.shape[2:], mode='bilinear', align_corners=False)
        mask = self.psi(g_proj + x_proj) * self.global_context(x)
        if mask.shape[2:] != x.shape[2:]:
            mask = F.interpolate(mask, size=x.shape[2:], mode='bilinear', align_corners=False)
        return x * mask
class S2FGenerator(nn.Module):
    """
    S2F (Shape2Force) model: U-Net generator for force map prediction.
    Supports substrate-specific settings as additional input channels.

    Layout: an optional 3-level multi-scale input pyramid, a 4-stage encoder
    (plain + dilated convolution blocks), an attention-enriched bridge, and a
    4-stage decoder with attention gates and per-stage refinement. The final
    Tanh keeps predictions in [-1, 1]. Attribute names define the checkpoint
    state-dict layout, so do not rename submodules.
    """
    def __init__(self,
                 in_channels=1,
                 out_channels=1,
                 img_size=1024,
                 bridge_type='cbam',
                 use_multi_scale_input=True):
        """
        Args:
            in_channels (int): Number of input image channels.
            out_channels (int): Number of predicted force-map channels.
            img_size (int): Nominal square input size (stored; not enforced here).
            bridge_type (str): 'cbam' inserts a CBAM block into the bridge;
                any other value builds the bridge without CBAM.
            use_multi_scale_input (bool): Enable the 3-scale input pyramid.
        """
        super().__init__()
        self.img_size = img_size
        self.bridge_type = bridge_type
        self.use_multi_scale_input = use_multi_scale_input
        if self.use_multi_scale_input:
            # Parallel branches at full, 1/2 and 1/4 resolution; each yields
            # 32 feature maps that forward() upsamples back and concatenates.
            self.scale_pyramid = nn.ModuleList([
                nn.Conv2d(in_channels, 32, 3, padding=1),
                nn.Sequential(
                    nn.AvgPool2d(2, stride=2),
                    nn.Conv2d(in_channels, 32, 3, padding=1)
                ),
                nn.Sequential(
                    nn.AvgPool2d(4, stride=4),
                    nn.Conv2d(in_channels, 32, 3, padding=1)
                )
            ])
            # 3 scales x 32 channels = 96 fused channels -> 64.
            self.initial_conv = nn.Conv2d(96, 64, 1)
        else:
            self.initial_conv = nn.Conv2d(in_channels, 64, 3, padding=1)

        def enhanced_conv_block(in_c, out_c, use_attention=True):
            # Conv -> BN -> ReLU -> residual block, optionally followed by
            # hierarchical (spatial + channel) attention.
            layers = [
                nn.Conv2d(in_c, out_c, 3, padding=1),
                nn.BatchNorm2d(out_c),
                nn.ReLU(inplace=True),
                ResidualBlock(out_c, out_c)
            ]
            if use_attention:
                layers.append(HierarchicalAttention(out_c))
            return nn.Sequential(*layers)

        def dilated_conv_block(in_c, out_c, use_global_context=False):
            # Dilation 2 widens the receptive field without further downsampling.
            layers = [
                nn.Conv2d(in_c, out_c, 3, padding=2, dilation=2),
                nn.BatchNorm2d(out_c),
                nn.ReLU(inplace=True),
                ResidualBlock(out_c, out_c)
            ]
            if use_global_context:
                layers.append(GlobalContextModule(out_c))
            return nn.Sequential(*layers)

        # Encoder: 64 -> 128 -> 256 -> 512 channels, halving resolution per stage.
        self.encoder1 = enhanced_conv_block(64, 64, use_attention=False)
        self.pool1 = nn.MaxPool2d(2)
        self.encoder2 = enhanced_conv_block(64, 128, use_attention=True)
        self.pool2 = nn.MaxPool2d(2)
        self.encoder3 = dilated_conv_block(128, 256, use_global_context=True)
        self.pool3 = nn.MaxPool2d(2)
        self.encoder4 = dilated_conv_block(256, 512, use_global_context=True)
        self.pool4 = nn.MaxPool2d(2)

        # Bridge at 1/16 resolution; the 'cbam' variant adds a CBAM block.
        if bridge_type == 'cbam':
            self.bridge = nn.Sequential(
                dilated_conv_block(512, 1024, use_global_context=True),
                CBAM(1024),
                GlobalContextModule(1024),
                HierarchicalAttention(1024)
            )
        else:
            self.bridge = nn.Sequential(
                dilated_conv_block(512, 1024, use_global_context=True),
                GlobalContextModule(1024),
                HierarchicalAttention(1024)
            )

        # Attention gates over the encoder skip connections (one per stage).
        self.att4 = EnhancedAttentionGate(512, 512, 256)
        self.att3 = EnhancedAttentionGate(256, 256, 128)
        self.att2 = EnhancedAttentionGate(128, 128, 64)
        self.att1 = EnhancedAttentionGate(64, 64, 32)

        # Decoder: transpose-conv upsampling, concat with gated skip (hence the
        # doubled input width of each dec block), then per-stage refinement.
        self.up4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.dec4 = enhanced_conv_block(1024, 512, use_attention=True)
        self.refine4 = HierarchicalAttention(512)
        self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec3 = enhanced_conv_block(512, 256, use_attention=True)
        self.refine3 = HierarchicalAttention(256)
        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec2 = enhanced_conv_block(256, 128, use_attention=True)
        self.refine2 = HierarchicalAttention(128)
        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec1 = enhanced_conv_block(128, 64, use_attention=True)
        self.refine1 = HierarchicalAttention(64)

        # Output head; Tanh bounds the force map to [-1, 1].
        self.final_conv = nn.Sequential(
            nn.Conv2d(64, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, out_channels, 1),
            nn.Tanh()
        )

    def forward(self, x):
        """Run the generator; returns a [B, out_channels, H, W] map in [-1, 1]."""
        if self.use_multi_scale_input:
            scale_features = []
            for i, scale_conv in enumerate(self.scale_pyramid):
                if i == 0:
                    # Full-resolution branch needs no upsampling.
                    scale_features.append(scale_conv(x))
                else:
                    # Downsampled branches are brought back to input resolution.
                    scale_out = scale_conv(x)
                    scale_out = F.interpolate(scale_out, size=x.shape[2:], mode='bilinear', align_corners=False)
                    scale_features.append(scale_out)
            fused = torch.cat(scale_features, dim=1)
            initial_features = self.initial_conv(fused)
        else:
            initial_features = self.initial_conv(x)

        # Encoder path (pooling between stages).
        e1 = self.encoder1(initial_features)
        e2 = self.encoder2(self.pool1(e1))
        e3 = self.encoder3(self.pool2(e2))
        e4 = self.encoder4(self.pool3(e3))
        b = self.bridge(self.pool4(e4))

        # Decoder path: upsample, gate the matching encoder skip, fuse, refine.
        g4 = self.up4(b)
        x4 = self.att4(g4, e4)
        d4 = self.dec4(torch.cat([g4, x4], dim=1))
        d4 = self.refine4(d4)
        g3 = self.up3(d4)
        x3 = self.att3(g3, e3)
        d3 = self.dec3(torch.cat([g3, x3], dim=1))
        d3 = self.refine3(d3)
        g2 = self.up2(d3)
        x2 = self.att2(g2, e2)
        d2 = self.dec2(torch.cat([g2, x2], dim=1))
        d2 = self.refine2(d2)
        g1 = self.up1(d2)
        x1 = self.att1(g1, e1)
        d1 = self.dec1(torch.cat([g1, x1], dim=1))
        d1 = self.refine1(d1)
        out = self.final_conv(d1)
        return out

    def load_checkpoint_with_expansion(self, checkpoint_path, strict=False):
        """Load checkpoint and expand from 1-channel to 3-channel if needed.

        Compares the input-conv weight shapes in the checkpoint against this
        model's; on mismatch the state dict is expanded before loading.
        Returns the full checkpoint dict (must contain 'generator_state_dict').

        NOTE(review): ``weights_only=False`` unpickles arbitrary objects —
        only load trusted checkpoint files.
        """
        checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
        generator_state = checkpoint['generator_state_dict']
        needs_expansion = False
        # Multi-scale models expose the input conv as scale_pyramid.0;
        # single-scale models expose it as initial_conv.
        if 'scale_pyramid.0.weight' in generator_state:
            old_shape = generator_state['scale_pyramid.0.weight'].shape
            current_shape = self.scale_pyramid[0].weight.shape
            if old_shape[1] != current_shape[1]:
                needs_expansion = True
        elif 'initial_conv.weight' in generator_state:
            old_shape = generator_state['initial_conv.weight'].shape
            current_shape = self.initial_conv.weight.shape
            if old_shape[1] != current_shape[1]:
                needs_expansion = True
        if needs_expansion:
            generator_state = self._expand_generator_state(generator_state)
        self.load_state_dict(generator_state, strict=strict)
        return checkpoint

    def _expand_generator_state(self, generator_state):
        """Expand generator state dict from 1-channel to 3-channel input.

        Copies the old weights into input channel 0 and leaves the extra
        input channels zero-initialized.

        NOTE(review): the target shapes (32/64 filters, 3 input channels) are
        hard-coded to the default architecture — this assumes a 1->3 channel
        expansion of the stock model; confirm before reusing with other widths.
        """
        expanded_state = generator_state.copy()
        if 'scale_pyramid.0.weight' in generator_state:
            # Branch 0 is a bare Conv2d; branches 1 and 2 nest the conv at
            # index 1 of a Sequential (after the AvgPool2d).
            for i in range(3):
                key = f'scale_pyramid.{i}.weight' if i == 0 else f'scale_pyramid.{i}.1.weight'
                if key in generator_state:
                    old_weight = generator_state[key]
                    new_weight = torch.zeros(32, 3, 3, 3)
                    new_weight[:, 0:1, :, :] = old_weight
                    expanded_state[key] = new_weight
        elif 'initial_conv.weight' in generator_state:
            old_weight = generator_state['initial_conv.weight']
            new_weight = torch.zeros(64, 3, 3, 3)
            new_weight[:, 0:1, :, :] = old_weight
            expanded_state['initial_conv.weight'] = new_weight
        return expanded_state
class SpheroidAttentionGate(nn.Module):
    """Attention Gate from ForceNet2WithAttention (s2f_spheroid). Checkpoint-compatible for ckp_spheroid_*.pth.

    Classic additive attention gate: 1x1 projections of the gating signal and
    the skip features are summed, squashed to a single-channel sigmoid mask,
    and used to scale the skip features. Attribute names are checkpoint-stable.
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1),
            nn.BatchNorm2d(F_int)
        )
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1),
            nn.BatchNorm2d(F_int)
        )
        self.psi = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.Conv2d(F_int, 1, kernel_size=1),
            nn.Sigmoid()
        )

    def forward(self, g, x):
        # g and x must share spatial size; the mask broadcasts over channels.
        mask = self.psi(self.W_g(g) + self.W_x(x))
        return x * mask
class S2FSpheroidGenerator(nn.Module):
    """
    S2F model tuned for spheroid data.

    A 4-level attention U-Net with a dilated bridge. Output activation is
    selectable: tanh [-1, 1] by default (``use_tanh_output=True``), sigmoid
    [0, 1] when switched via ``set_output_mode(False)`` for inference.
    Attribute names match the ckp_spheroid_*.pth checkpoint layout.
    """
    def __init__(self, in_channels=1, out_channels=1, predict_numbers=False, img_size=1024, use_tanh_output=True):
        """
        Args:
            in_channels (int): Number of input image channels.
            out_channels (int): Number of predicted force-map channels.
            predict_numbers (bool): Stored flag; not used in this class's code
                (presumably consumed by training code — TODO confirm).
            img_size (int): Square size the output is resized to in forward().
            use_tanh_output (bool): Select tanh (True) or sigmoid (False) output.
        """
        super(S2FSpheroidGenerator, self).__init__()
        self.predict_numbers = predict_numbers
        self.img_size = img_size
        self.use_tanh_output = use_tanh_output

        def conv_block(in_c, out_c):
            # Conv -> BN -> ReLU -> residual block.
            return nn.Sequential(
                nn.Conv2d(in_c, out_c, 3, padding=1),
                nn.BatchNorm2d(out_c),
                nn.ReLU(inplace=True),
                ResidualBlock(out_c, out_c)
            )

        # Encoder: 32 -> 64 -> 128 -> 256 channels, halving resolution per stage.
        self.encoder1 = conv_block(in_channels, 32)
        self.pool1 = nn.MaxPool2d(2)
        self.encoder2 = conv_block(32, 64)
        self.pool2 = nn.MaxPool2d(2)
        self.encoder3 = conv_block(64, 128)
        self.pool3 = nn.MaxPool2d(2)
        self.encoder4 = conv_block(128, 256)
        self.pool4 = nn.MaxPool2d(2)

        # Bridge at 1/16 resolution; dilation 2 widens the receptive field.
        self.bridge = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=3, padding=2, dilation=2),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            ResidualBlock(512, 512)
        )

        # Attention gates over the encoder skip connections.
        self.att3 = SpheroidAttentionGate(256, 256, 128)
        self.att2 = SpheroidAttentionGate(128, 128, 64)
        self.att1 = SpheroidAttentionGate(64, 64, 32)

        # Decoder: transpose-conv upsampling; each dec block takes the
        # concatenation of the upsampled features and the gated skip.
        self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec3 = conv_block(512, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec2 = conv_block(256, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec1 = conv_block(128, 64)
        self.up0 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.dec0 = conv_block(64, 32)

        # 1x1 prediction head (activation applied in forward()).
        self.pred_conv = nn.Conv2d(32, out_channels, kernel_size=1)

    def forward(self, x):
        """Run the generator; returns a [B, out_channels, img_size, img_size] map."""
        e1 = self.encoder1(x)
        e2 = self.encoder2(self.pool1(e1))
        e3 = self.encoder3(self.pool2(e2))
        e4 = self.encoder4(self.pool3(e3))
        b = self.bridge(self.pool4(e4))
        # Decoder: note the gate/skip pairing is shifted by one level relative
        # to the gate names (att3 gates e4, att2 gates e3, att1 gates e2);
        # this matches the trained checkpoints, so keep it as-is.
        g3 = self.up3(b)
        x3 = self.att3(g3, e4)
        d3 = self.dec3(torch.cat([g3, x3], dim=1))
        g2 = self.up2(d3)
        x2 = self.att2(g2, e3)
        d2 = self.dec2(torch.cat([g2, x2], dim=1))
        g1 = self.up1(d2)
        x1 = self.att1(g1, e2)
        d1 = self.dec1(torch.cat([g1, x1], dim=1))
        # Top level uses the raw e1 skip without an attention gate.
        g0 = self.up0(d1)
        d0 = self.dec0(torch.cat([g0, e1], dim=1))
        out = self.pred_conv(d0)
        # Always resize to the configured square output size.
        out_resized = F.interpolate(out, size=(self.img_size, self.img_size), mode='bilinear', align_corners=False)
        if self.use_tanh_output:
            return torch.tanh(out_resized)
        else:
            return torch.sigmoid(out_resized)

    def set_output_mode(self, use_tanh=True):
        """Set output activation: tanh [-1,1] for training, sigmoid [0,1] for inference."""
        self.use_tanh_output = use_tanh
class PatchGANDiscriminator(nn.Module):
    """PatchGAN Discriminator (included for create_s2f_model compatibility).

    Standard pix2pix-style patch discriminator with an extra channel-attention
    gate before the output convolution. Attribute names and the order of
    ``self.layers`` are kept for checkpoint compatibility.
    """

    def __init__(self, in_channels=2, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
        super().__init__()
        # With InstanceNorm (no learned shift by default) the convs keep a bias.
        use_bias = norm_layer == nn.InstanceNorm2d
        # Channel multipliers 1, 2, 4, ... capped at 8.
        mults = [min(2 ** i, 8) for i in range(n_layers + 1)]

        self.initial_conv = nn.Sequential(
            nn.Conv2d(in_channels, ndf, kernel_size=4, stride=2, padding=1, bias=use_bias),
            nn.LeakyReLU(0.2, inplace=True)
        )

        def normed_block(c_in, c_out, stride):
            # Conv -> norm -> LeakyReLU stage of the discriminator trunk.
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=stride, padding=1, bias=use_bias),
                norm_layer(c_out),
                nn.LeakyReLU(0.2, inplace=True)
            )

        self.layers = nn.ModuleList(
            normed_block(ndf * mults[n - 1], ndf * mults[n], 2)
            for n in range(1, n_layers)
        )
        # Final trunk stage runs at stride 1 to widen the receptive field.
        self.layers.append(normed_block(ndf * mults[n_layers - 1], ndf * mults[n_layers], 1))

        top = ndf * mults[n_layers]
        self.output_conv = nn.Conv2d(top, 1, kernel_size=4, stride=1, padding=1)
        # Channel-attention gate applied to the top features before scoring.
        self.attention = nn.Sequential(
            nn.Conv2d(top, top // 4, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(top // 4, top, 1),
            nn.Sigmoid()
        )

    def forward(self, input):
        features = self.initial_conv(input)
        for block in self.layers:
            features = block(features)
        features = features * self.attention(features)
        return self.output_conv(features)
def create_s2f_model(
    in_channels=1,
    out_channels=1,
    img_size=1024,
    bridge_type='cbam',
    use_multi_scale_input=True,
    ndf=64,
    n_layers=3,
    model_type='s2f',
):
    """Create S2F model with generator and discriminator.

    Args:
        in_channels / out_channels: Generator input and output channel counts.
        img_size: Target square image size.
        bridge_type: Bridge variant for the single-cell generator.
        use_multi_scale_input: Enable the input pyramid (single-cell only).
        ndf / n_layers: Discriminator width and depth.
        model_type: 's2f' for single-cell, 's2f_spheroid' for spheroid.

    Returns:
        tuple: (generator, discriminator).

    Raises:
        ValueError: If ``model_type`` is not recognized.
    """
    if model_type == 's2f_spheroid':
        # The spheroid generator has a fixed architecture; bridge and pyramid
        # options apply only to the single-cell model.
        generator = S2FSpheroidGenerator(
            in_channels=in_channels,
            out_channels=out_channels,
            img_size=img_size,
        )
    elif model_type == 's2f':
        generator = S2FGenerator(
            in_channels=in_channels,
            out_channels=out_channels,
            img_size=img_size,
            bridge_type=bridge_type,
            use_multi_scale_input=use_multi_scale_input,
        )
    else:
        raise ValueError(f"Invalid model type: {model_type}")

    # The discriminator scores the input image concatenated with the force
    # map, hence in_channels + out_channels.
    discriminator = PatchGANDiscriminator(
        in_channels=in_channels + out_channels,
        ndf=ndf,
        n_layers=n_layers,
    )
    return generator, discriminator