Spaces:
Running
Running
| """ | |
| S2F (Shape2Force) model for force map prediction (inference only). | |
| Supports single-cell and spheroid modes. | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from .blocks import ResidualBlock | |
| from .cbam import CBAM | |
| from utils import config | |
| from utils.substrate_settings import ( | |
| get_settings_of_category, | |
| compute_settings_normalization, | |
| load_substrate_config, | |
| ) | |
def _min_max_normalize(value, bounds):
    """Scale ``value`` with min-max normalization using ``bounds['min']`` / ``bounds['max']``.

    Values inside the range map to [0, 1]; values outside it are not clipped.
    If the range is degenerate (max == min, i.e. every substrate shares the
    same value), 0.0 is returned instead of dividing by zero.
    """
    lower = bounds['min']
    span = bounds['max'] - lower
    if span == 0:
        return 0.0
    return (value - lower) / span


def normalize_settings(substrate_name, normalization_params, config=None, config_path=None):
    """
    Normalize settings for a given substrate.

    Args:
        substrate_name (str): Name of the substrate.
        normalization_params (dict): Normalization parameters with the layout
            ``{'pixelsize': {'min': ..., 'max': ...}, 'young': {'min': ..., 'max': ...}}``.
        config (optional): Already-loaded substrate config forwarded to
            ``get_settings_of_category``. Note: inside this function the name
            shadows the module-level ``config`` import.
        config_path (str, optional): Config file path forwarded to
            ``get_settings_of_category``.

    Returns:
        tuple: (normalized_pixelsize, normalized_young). See
        ``_min_max_normalize`` for the handling of degenerate ranges.
    """
    settings = get_settings_of_category(substrate_name, config=config, config_path=config_path)
    pixelsize_norm = _min_max_normalize(settings['pixelsize'], normalization_params['pixelsize'])
    young_norm = _min_max_normalize(settings['young'], normalization_params['young'])
    return pixelsize_norm, young_norm
def create_settings_channels(metadata, normalization_params, device, image_shape, config_path=None, config=None):
    """
    Create constant-valued settings channels for a batch of images.

    For every sample, the substrate's normalized settings (from
    ``normalize_settings``) fill an entire H x W plane.

    Args:
        metadata (dict): Batch metadata; ``metadata['substrate'][i]`` is the
            substrate name (str) of sample ``i``.
        normalization_params (dict): Normalization parameters forwarded to
            ``normalize_settings``.
        device: Device to create tensors on.
        image_shape (tuple): Shape of input images (B, C, H, W); C is ignored.
        config_path (str, optional): Forwarded to ``normalize_settings``.
        config (optional): Already-loaded substrate config forwarded to
            ``normalize_settings``; defaults to None, as before.

    Returns:
        torch.Tensor: Float32 tensor [B, 2, H, W] on ``device``, where
        channel 0 is the normalized pixelsize and channel 1 the normalized
        young value.
    """
    batch_size, _, height, width = image_shape
    settings_channels = torch.zeros(batch_size, 2, height, width, device=device)
    substrates = metadata['substrate']
    # A batch typically repeats a few substrates; look each one up only once
    # instead of re-querying the substrate config for every sample.
    normalized_by_substrate = {}
    for i in range(batch_size):
        substrate = substrates[i]
        if substrate not in normalized_by_substrate:
            normalized_by_substrate[substrate] = normalize_settings(
                substrate, normalization_params, config=config, config_path=config_path
            )
        pixelsize_norm, young_norm = normalized_by_substrate[substrate]
        # Broadcast the scalar over the whole H x W plane of each channel.
        settings_channels[i, 0] = pixelsize_norm
        settings_channels[i, 1] = young_norm
    return settings_channels
class GlobalContextModule(nn.Module):
    """Global context module for capturing cell shape information.

    Adds two residual terms to the input:

    * a depthwise 3x3 + pointwise 1x1 conv branch, re-weighted per channel by
      a sigmoid gate computed from globally average-pooled input;
    * four parallel 3x3 convs with dilations 1, 2, 4 and 8 (padding equals
      dilation, so H x W is preserved), concatenated and fused by a 1x1 conv.

    The output has the same shape as the input.

    Args:
        in_channels (int): Number of input (and output) channels.
    """
    def __init__(self, in_channels):
        super().__init__()
        # Width of the gate bottleneck and of each dilated branch. Clamped to
        # >= 1 so inputs narrower than 4 channels still build valid convs.
        reduced = max(in_channels // 4, 1)
        dilations = (1, 2, 4, 8)
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.global_conv = nn.Sequential(
            nn.Conv2d(in_channels, reduced, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(reduced, in_channels, 1),
            nn.Sigmoid()
        )
        # NOTE: despite its name this is a 3x3 depthwise conv followed by a
        # 1x1 pointwise conv; the attribute name is kept for checkpoint keys.
        self.large_kernel = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, 3, padding=1, groups=in_channels),
            nn.Conv2d(in_channels, in_channels, 1),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True)
        )
        self.multi_scale = nn.ModuleList([
            nn.Conv2d(in_channels, reduced, 3, padding=d, dilation=d) for d in dilations
        ])
        # The concatenated branches carry len(dilations) * reduced channels,
        # which equals in_channels only when in_channels is divisible by 4.
        self.fusion = nn.Conv2d(reduced * len(dilations), in_channels, 1)

    def forward(self, x):
        global_weight = self.global_conv(self.global_pool(x))  # (B, C, 1, 1)
        large_features = self.large_kernel(x)
        multi_scale_out = self.fusion(torch.cat([conv(x) for conv in self.multi_scale], dim=1))
        return x + (large_features * global_weight) + multi_scale_out
class HierarchicalAttention(nn.Module):
    """Hierarchical attention combining spatial and channel attention.

    The input is scaled by a single-channel spatial sigmoid map and a
    per-channel sigmoid vector. The result is gated once more by a
    pointwise-conv branch and added back to the input as a residual.
    The output has the same shape as the input.

    Args:
        channels (int): Number of input (and output) channels.
    """
    def __init__(self, channels):
        super().__init__()
        # Bottleneck widths, clamped to >= 1 so narrow inputs (< 16 channels)
        # do not build zero-width convolutions.
        spatial_hidden = max(channels // 8, 1)
        channel_hidden = max(channels // 16, 1)
        cross_hidden = max(channels // 4, 1)
        self.spatial_att = nn.Sequential(
            nn.Conv2d(channels, spatial_hidden, 1),
            nn.Conv2d(spatial_hidden, 1, 3, padding=1),
            nn.Sigmoid()
        )
        self.channel_att = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, channel_hidden, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel_hidden, channels, 1),
            nn.Sigmoid()
        )
        self.cross_att = nn.Sequential(
            nn.Conv2d(channels, cross_hidden, 1),
            nn.BatchNorm2d(cross_hidden),
            nn.ReLU(inplace=True),
            nn.Conv2d(cross_hidden, channels, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        spatial_weight = self.spatial_att(x)   # (B, 1, H, W)
        channel_weight = self.channel_att(x)   # (B, C, 1, 1)
        attended = x * spatial_weight * channel_weight
        cross_weight = self.cross_att(attended)
        return x + (attended * cross_weight)
class EnhancedAttentionGate(nn.Module):
    """Enhanced attention gate with global context.

    Builds a single-channel spatial gate from the gating signal ``g`` and the
    skip features ``x``, multiplies it by a per-sample scalar derived from
    globally pooled ``x``, and returns ``x`` scaled by the combined gate.

    Args:
        F_g (int): Channels of ``g``.
        F_l (int): Channels of ``x``.
        F_int (int): Intermediate channels; the psi and global-context
            branches use ``F_int // 2`` and ``F_int // 4``, each clamped to >= 1.
    """
    def __init__(self, F_g, F_l, F_int):
        super().__init__()
        half = max(F_int // 2, 1)
        quarter = max(F_int // 4, 1)
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1),
            nn.BatchNorm2d(F_int)
        )
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1),
            nn.BatchNorm2d(F_int)
        )
        self.psi = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.Conv2d(F_int, half, kernel_size=3, padding=1),
            nn.BatchNorm2d(half),
            nn.ReLU(inplace=True),
            nn.Conv2d(half, 1, kernel_size=1),
            nn.Sigmoid()
        )
        self.global_context = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(F_l, quarter, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(quarter, 1, 1),
            nn.Sigmoid()
        )

    def forward(self, g, x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)  # 1x1 stride-1 conv: same H x W as x
        # g may come from a coarser level; bring it to the skip's resolution.
        if g1.shape[2:] != x1.shape[2:]:
            g1 = F.interpolate(g1, size=x1.shape[2:], mode='bilinear', align_corners=False)
        psi = self.psi(g1 + x1)  # (B, 1, H, W), values in [0, 1]
        # (B, 1, 1, 1) scalar gate; broadcasting keeps psi at x's H x W, so no
        # further resizing is needed before scaling x.
        psi = psi * self.global_context(x)
        return x * psi
class S2FGenerator(nn.Module):
    """
    S2F (Shape2Force) model: U-Net generator for force map prediction.
    Supports substrate-specific settings as additional input channels.

    Layout built in ``__init__``:

    * an optional multi-scale input stem (full, 1/2 and 1/4 resolution fused
      to 64 channels), otherwise a single 3x3 input conv;
    * four encoder stages (64, 128, 256, 512 channels), each followed by 2x
      max-pooling;
    * a 1024-channel bridge;
    * four decoder stages whose skip connections are filtered by
      EnhancedAttentionGate.

    The head ends in Tanh, so outputs lie in [-1, 1] and have the input's H x W.

    Args:
        in_channels (int): Number of input channels (e.g. image plus any
            settings channels).
        out_channels (int): Number of predicted output channels.
        img_size (int): Stored as ``self.img_size``; forward() does not use it.
        bridge_type (str): ``'cbam'`` adds a CBAM block to the bridge; any
            other value builds the bridge without it.
        use_multi_scale_input (bool): Use the three-scale input stem instead
            of a single 3x3 input convolution.
    """
    def __init__(self,
                 in_channels=1,
                 out_channels=1,
                 img_size=1024,
                 bridge_type='cbam',
                 use_multi_scale_input=True):
        super().__init__()
        self.img_size = img_size
        self.bridge_type = bridge_type
        self.use_multi_scale_input = use_multi_scale_input
        if self.use_multi_scale_input:
            # Three 32-channel branches; forward() upsamples the pooled ones
            # back to input size before concatenating (3 * 32 = 96 channels).
            self.scale_pyramid = nn.ModuleList([
                nn.Conv2d(in_channels, 32, 3, padding=1),
                nn.Sequential(
                    nn.AvgPool2d(2, stride=2),
                    nn.Conv2d(in_channels, 32, 3, padding=1)
                ),
                nn.Sequential(
                    nn.AvgPool2d(4, stride=4),
                    nn.Conv2d(in_channels, 32, 3, padding=1)
                )
            ])
            self.initial_conv = nn.Conv2d(96, 64, 1)
        else:
            self.initial_conv = nn.Conv2d(in_channels, 64, 3, padding=1)

        def enhanced_conv_block(in_c, out_c, use_attention=True):
            """3x3 conv -> BN -> ReLU -> ResidualBlock, optionally + HierarchicalAttention."""
            layers = [
                nn.Conv2d(in_c, out_c, 3, padding=1),
                nn.BatchNorm2d(out_c),
                nn.ReLU(inplace=True),
                ResidualBlock(out_c, out_c)
            ]
            if use_attention:
                layers.append(HierarchicalAttention(out_c))
            return nn.Sequential(*layers)

        def dilated_conv_block(in_c, out_c, use_global_context=False):
            """Dilated (d=2) 3x3 conv -> BN -> ReLU -> ResidualBlock, optionally + GlobalContextModule."""
            layers = [
                nn.Conv2d(in_c, out_c, 3, padding=2, dilation=2),
                nn.BatchNorm2d(out_c),
                nn.ReLU(inplace=True),
                ResidualBlock(out_c, out_c)
            ]
            if use_global_context:
                layers.append(GlobalContextModule(out_c))
            return nn.Sequential(*layers)

        self.encoder1 = enhanced_conv_block(64, 64, use_attention=False)
        self.pool1 = nn.MaxPool2d(2)
        self.encoder2 = enhanced_conv_block(64, 128, use_attention=True)
        self.pool2 = nn.MaxPool2d(2)
        self.encoder3 = dilated_conv_block(128, 256, use_global_context=True)
        self.pool3 = nn.MaxPool2d(2)
        self.encoder4 = dilated_conv_block(256, 512, use_global_context=True)
        self.pool4 = nn.MaxPool2d(2)
        if bridge_type == 'cbam':
            self.bridge = nn.Sequential(
                dilated_conv_block(512, 1024, use_global_context=True),
                CBAM(1024),
                GlobalContextModule(1024),
                HierarchicalAttention(1024)
            )
        else:
            self.bridge = nn.Sequential(
                dilated_conv_block(512, 1024, use_global_context=True),
                GlobalContextModule(1024),
                HierarchicalAttention(1024)
            )
        self.att4 = EnhancedAttentionGate(512, 512, 256)
        self.att3 = EnhancedAttentionGate(256, 256, 128)
        self.att2 = EnhancedAttentionGate(128, 128, 64)
        self.att1 = EnhancedAttentionGate(64, 64, 32)
        self.up4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.dec4 = enhanced_conv_block(1024, 512, use_attention=True)
        self.refine4 = HierarchicalAttention(512)
        self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec3 = enhanced_conv_block(512, 256, use_attention=True)
        self.refine3 = HierarchicalAttention(256)
        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec2 = enhanced_conv_block(256, 128, use_attention=True)
        self.refine2 = HierarchicalAttention(128)
        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec1 = enhanced_conv_block(128, 64, use_attention=True)
        self.refine1 = HierarchicalAttention(64)
        self.final_conv = nn.Sequential(
            nn.Conv2d(64, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, out_channels, 1),
            nn.Tanh()
        )

    @staticmethod
    def _match_spatial(tensor, reference):
        """Bilinearly resize ``tensor`` to ``reference``'s H x W when they differ."""
        if tensor.shape[2:] != reference.shape[2:]:
            tensor = F.interpolate(tensor, size=reference.shape[2:], mode='bilinear', align_corners=False)
        return tensor

    def _decode(self, below, skip, up, att, dec, refine):
        """Run one decoder stage: upsample, gate the skip, fuse, refine."""
        # MaxPool2d floors odd sizes, so the upsampled map can be smaller than
        # its skip; align it so the concatenation below is valid for inputs
        # whose H/W are not multiples of 16. A no-op for multiples of 16.
        gate = self._match_spatial(up(below), skip)
        gated_skip = att(gate, skip)
        return refine(dec(torch.cat([gate, gated_skip], dim=1)))

    def forward(self, x):
        """Predict a (B, out_channels, H, W) map in [-1, 1] from x of shape (B, in_channels, H, W)."""
        if self.use_multi_scale_input:
            scale_features = [self.scale_pyramid[0](x)]
            for scale_conv in self.scale_pyramid[1:]:
                # Pooled branches return at reduced resolution; restore input size.
                scale_features.append(
                    F.interpolate(scale_conv(x), size=x.shape[2:], mode='bilinear', align_corners=False)
                )
            initial_features = self.initial_conv(torch.cat(scale_features, dim=1))
        else:
            initial_features = self.initial_conv(x)

        e1 = self.encoder1(initial_features)
        e2 = self.encoder2(self.pool1(e1))
        e3 = self.encoder3(self.pool2(e2))
        e4 = self.encoder4(self.pool3(e3))
        b = self.bridge(self.pool4(e4))

        d4 = self._decode(b, e4, self.up4, self.att4, self.dec4, self.refine4)
        d3 = self._decode(d4, e3, self.up3, self.att3, self.dec3, self.refine3)
        d2 = self._decode(d3, e2, self.up2, self.att2, self.dec2, self.refine2)
        d1 = self._decode(d2, e1, self.up1, self.att1, self.dec1, self.refine1)
        return self.final_conv(d1)

    def _input_conv_weights(self):
        """Return {state_dict key: current weight} for the convs that consume the raw input."""
        if self.use_multi_scale_input:
            return {
                'scale_pyramid.0.weight': self.scale_pyramid[0].weight,
                'scale_pyramid.1.1.weight': self.scale_pyramid[1][1].weight,
                'scale_pyramid.2.1.weight': self.scale_pyramid[2][1].weight,
            }
        return {'initial_conv.weight': self.initial_conv.weight}

    def _needs_input_expansion(self, generator_state):
        """True if any input-conv weight in ``generator_state`` has a different input-channel count than this model."""
        return any(
            key in generator_state and generator_state[key].shape[1] != weight.shape[1]
            for key, weight in self._input_conv_weights().items()
        )

    def load_checkpoint_with_expansion(self, checkpoint_path, strict=False):
        """Load generator weights, zero-padding input-conv weights when the
        checkpoint was trained with fewer input channels than this model.

        Args:
            checkpoint_path: Path to a ``torch.save`` file holding a dict with
                a ``'generator_state_dict'`` entry.
            strict (bool): Forwarded to ``load_state_dict``.

        Returns:
            dict: The full checkpoint as loaded from disk.
        """
        # NOTE(security): weights_only=False unpickles arbitrary objects; only
        # load checkpoints from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
        generator_state = checkpoint['generator_state_dict']
        if self._needs_input_expansion(generator_state):
            generator_state = self._expand_generator_state(generator_state)
        self.load_state_dict(generator_state, strict=strict)
        return checkpoint

    def _expand_generator_state(self, generator_state):
        """Return a copy of ``generator_state`` whose input-conv weights are
        zero-padded along the input-channel axis to match this model.

        The checkpoint's existing channels keep their trained filters. The
        added channels start at zero, so they initially contribute nothing.
        Other entries, including biases, are left as-is.
        """
        expanded_state = generator_state.copy()
        for key, current_weight in self._input_conv_weights().items():
            if key not in generator_state:
                continue
            old_weight = generator_state[key]
            if old_weight.shape[1] == current_weight.shape[1]:
                continue
            new_weight = old_weight.new_zeros(current_weight.shape)
            new_weight[:, :old_weight.shape[1]] = old_weight
            expanded_state[key] = new_weight
        return expanded_state
class SpheroidAttentionGate(nn.Module):
    """Attention Gate from ForceNet2WithAttention (s2f_spheroid). Checkpoint-compatible for ckp_spheroid_*.pth.

    Projects the gating signal ``g`` and the skip features ``x`` to ``F_int``
    channels, sums them, and turns the sum into a single-channel sigmoid mask
    that scales ``x``. Submodule names and Sequential indices must stay
    unchanged so existing checkpoints keep loading.

    Args:
        F_g (int): Channels of the gating signal ``g``.
        F_l (int): Channels of the skip features ``x``.
        F_int (int): Intermediate channel count.
    """
    def __init__(self, F_g, F_l, F_int):
        super().__init__()
        self.W_g = nn.Sequential(nn.Conv2d(F_g, F_int, kernel_size=1), nn.BatchNorm2d(F_int))
        self.W_x = nn.Sequential(nn.Conv2d(F_l, F_int, kernel_size=1), nn.BatchNorm2d(F_int))
        self.psi = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.Conv2d(F_int, 1, kernel_size=1),
            nn.Sigmoid(),
        )

    def forward(self, g, x):
        """Scale ``x`` by the attention mask; no resizing is done, so ``g`` and ``x`` are expected to share H x W."""
        mask = self.psi(self.W_g(g) + self.W_x(x))
        return x * mask
class S2FSpheroidGenerator(nn.Module):
    """
    S2F model tuned for spheroid data: a 4-level attention U-Net.

    The output activation depends on ``use_tanh_output``:

    * True (the constructor default): tanh, range [-1, 1];
    * False: sigmoid, range [0, 1].

    Call ``set_output_mode(False)`` to get the sigmoid output intended for
    inference.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of predicted output channels.
        predict_numbers (bool): Stored as ``self.predict_numbers``; forward()
            does not use it.
        img_size (int): The prediction is bilinearly resized to
            ``img_size x img_size`` regardless of the input size.
        use_tanh_output (bool): Initial output activation (see above).
    """
    def __init__(self, in_channels=1, out_channels=1, predict_numbers=False, img_size=1024, use_tanh_output=True):
        super(S2FSpheroidGenerator, self).__init__()
        self.predict_numbers = predict_numbers
        self.img_size = img_size
        self.use_tanh_output = use_tanh_output

        def conv_block(in_c, out_c):
            """3x3 conv -> BN -> ReLU -> ResidualBlock."""
            return nn.Sequential(
                nn.Conv2d(in_c, out_c, 3, padding=1),
                nn.BatchNorm2d(out_c),
                nn.ReLU(inplace=True),
                ResidualBlock(out_c, out_c)
            )

        # Encoder
        self.encoder1 = conv_block(in_channels, 32)
        self.pool1 = nn.MaxPool2d(2)
        self.encoder2 = conv_block(32, 64)
        self.pool2 = nn.MaxPool2d(2)
        self.encoder3 = conv_block(64, 128)
        self.pool3 = nn.MaxPool2d(2)
        self.encoder4 = conv_block(128, 256)
        self.pool4 = nn.MaxPool2d(2)
        self.bridge = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=3, padding=2, dilation=2),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            ResidualBlock(512, 512)
        )
        self.att3 = SpheroidAttentionGate(256, 256, 128)
        self.att2 = SpheroidAttentionGate(128, 128, 64)
        self.att1 = SpheroidAttentionGate(64, 64, 32)
        self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec3 = conv_block(512, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec2 = conv_block(256, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec1 = conv_block(128, 64)
        self.up0 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.dec0 = conv_block(64, 32)
        self.pred_conv = nn.Conv2d(32, out_channels, kernel_size=1)

    @staticmethod
    def _match_spatial(tensor, reference):
        """Bilinearly resize ``tensor`` to ``reference``'s H x W when they differ."""
        if tensor.shape[2:] != reference.shape[2:]:
            tensor = F.interpolate(tensor, size=reference.shape[2:], mode='bilinear', align_corners=False)
        return tensor

    def forward(self, x):
        """Predict a (B, out_channels, img_size, img_size) map from x of shape (B, in_channels, H, W)."""
        e1 = self.encoder1(x)
        e2 = self.encoder2(self.pool1(e1))
        e3 = self.encoder3(self.pool2(e2))
        e4 = self.encoder4(self.pool3(e3))
        b = self.bridge(self.pool4(e4))

        # Each upsampled map is aligned to its skip connection. MaxPool2d
        # floors odd sizes, so without this, inputs whose H/W are not
        # multiples of 16 fail at the attention sum / concatenation. This is a
        # no-op for multiples of 16.
        g3 = self._match_spatial(self.up3(b), e4)
        d3 = self.dec3(torch.cat([g3, self.att3(g3, e4)], dim=1))
        g2 = self._match_spatial(self.up2(d3), e3)
        d2 = self.dec2(torch.cat([g2, self.att2(g2, e3)], dim=1))
        g1 = self._match_spatial(self.up1(d2), e2)
        d1 = self.dec1(torch.cat([g1, self.att1(g1, e2)], dim=1))
        # The shallowest skip (e1) is concatenated without an attention gate.
        g0 = self._match_spatial(self.up0(d1), e1)
        d0 = self.dec0(torch.cat([g0, e1], dim=1))

        out = self.pred_conv(d0)
        out_resized = F.interpolate(out, size=(self.img_size, self.img_size), mode='bilinear', align_corners=False)
        if self.use_tanh_output:
            return torch.tanh(out_resized)
        return torch.sigmoid(out_resized)

    def set_output_mode(self, use_tanh=True):
        """Set output activation: tanh [-1,1] for training, sigmoid [0,1] for inference."""
        self.use_tanh_output = use_tanh
class PatchGANDiscriminator(nn.Module):
    """PatchGAN Discriminator (included for create_s2f_model compatibility).

    Structure:

    * a stride-2 4x4 input conv;
    * ``n_layers`` hidden 4x4 conv stages, all stride 2 except the last,
      which is stride 1;
    * a sigmoid channel gate applied to the final feature map;
    * a 4x4 conv producing one channel of raw patch scores.

    Args:
        in_channels (int): Input channels (create_s2f_model passes input +
            output channels).
        ndf (int): Width of the first convolution.
        n_layers (int): Number of hidden stages after the first convolution.
        norm_layer: Normalization layer class. Convolution biases are enabled
            only when this is exactly ``nn.InstanceNorm2d``.
    """
    def __init__(self, in_channels=2, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
        super().__init__()
        conv_bias = norm_layer == nn.InstanceNorm2d

        def stage(c_in, c_out, stride):
            """4x4 conv -> norm -> LeakyReLU(0.2)."""
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=stride, padding=1, bias=conv_bias),
                norm_layer(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            )

        self.initial_conv = nn.Sequential(
            nn.Conv2d(in_channels, ndf, kernel_size=4, stride=2, padding=1, bias=conv_bias),
            nn.LeakyReLU(0.2, inplace=True)
        )
        # Width multipliers: 1, then 2 ** n (capped at 8) for each strided
        # stage, then 2 ** n_layers (capped at 8) for the final stride-1 stage.
        mults = [1] + [min(2 ** n, 8) for n in range(1, n_layers)] + [min(2 ** n_layers, 8)]
        last_index = len(mults) - 2
        self.layers = nn.ModuleList(
            stage(ndf * m_in, ndf * m_out, 1 if idx == last_index else 2)
            for idx, (m_in, m_out) in enumerate(zip(mults, mults[1:]))
        )
        width = ndf * mults[-1]
        self.output_conv = nn.Conv2d(width, 1, kernel_size=4, stride=1, padding=1)
        self.attention = nn.Sequential(
            nn.Conv2d(width, width // 4, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(width // 4, width, 1),
            nn.Sigmoid()
        )

    def forward(self, input):
        """Return raw (no final activation) patch scores for ``input``."""
        features = self.initial_conv(input)
        for block in self.layers:
            features = block(features)
        gate = self.attention(features)
        return self.output_conv(features * gate)
def create_s2f_model(
    in_channels=1,
    out_channels=1,
    img_size=1024,
    bridge_type='cbam',
    use_multi_scale_input=True,
    ndf=64,
    n_layers=3,
    model_type='s2f',
):
    """Create S2F model with generator and discriminator.

    Args:
        in_channels (int): Generator input channels.
        out_channels (int): Generator output channels.
        img_size (int): Forwarded to the generator.
        bridge_type (str): Forwarded to S2FGenerator only.
        use_multi_scale_input (bool): Forwarded to S2FGenerator only.
        ndf (int): Discriminator base width.
        n_layers (int): Discriminator hidden stage count.
        model_type (str): 's2f' for single-cell, 's2f_spheroid' for spheroid.

    Returns:
        tuple: (generator, discriminator). The discriminator takes
        ``in_channels + out_channels`` input channels.

    Raises:
        ValueError: If ``model_type`` is not recognised.
    """
    if model_type not in ('s2f', 's2f_spheroid'):
        raise ValueError(f"Invalid model type: {model_type}")

    if model_type == 's2f_spheroid':
        generator = S2FSpheroidGenerator(
            in_channels=in_channels,
            out_channels=out_channels,
            img_size=img_size,
        )
    else:
        generator = S2FGenerator(
            in_channels=in_channels,
            out_channels=out_channels,
            img_size=img_size,
            bridge_type=bridge_type,
            use_multi_scale_input=use_multi_scale_input,
        )

    discriminator = PatchGANDiscriminator(
        in_channels=in_channels + out_channels,
        ndf=ndf,
        n_layers=n_layers,
    )
    return generator, discriminator