kaveh committed on
Commit
e732653
·
1 Parent(s): 96605dc

adapted for both single-cell and spheroid

Browse files
Files changed (1) hide show
  1. models/s2f_model.py +211 -105
models/s2f_model.py CHANGED
@@ -8,12 +8,7 @@ import torch.nn.functional as F
8
  from .blocks import ResidualBlock
9
  from .cbam import CBAM
10
 
11
- from utils import config
12
- from utils.substrate_settings import (
13
- get_settings_of_category,
14
- compute_settings_normalization,
15
- load_substrate_config,
16
- )
17
 
18
 
19
  def normalize_settings(substrate_name, normalization_params, config=None, config_path=None):
@@ -38,7 +33,6 @@ def normalize_settings(substrate_name, normalization_params, config=None, config
38
 
39
  return pixelsize_norm, young_norm
40
 
41
-
42
  def create_settings_channels(metadata, normalization_params, device, image_shape, config_path=None):
43
  """
44
  Create settings channels for a batch of images.
@@ -73,9 +67,8 @@ def create_settings_channels(metadata, normalization_params, device, image_shape
73
 
74
  return settings_channels
75
 
76
-
77
  class GlobalContextModule(nn.Module):
78
- """Global context module for capturing cell shape information"""
79
  def __init__(self, in_channels):
80
  super().__init__()
81
  self.global_pool = nn.AdaptiveAvgPool2d(1)
@@ -110,9 +103,8 @@ class GlobalContextModule(nn.Module):
110
  multi_scale_out = self.fusion(multi_scale_out)
111
  return x + (large_features * global_weight) + multi_scale_out
112
 
113
-
114
  class HierarchicalAttention(nn.Module):
115
- """Hierarchical attention combining spatial and channel attention"""
116
  def __init__(self, channels):
117
  super().__init__()
118
  self.spatial_att = nn.Sequential(
@@ -142,9 +134,8 @@ class HierarchicalAttention(nn.Module):
142
  cross_weight = self.cross_att(attended)
143
  return x + (attended * cross_weight)
144
 
145
-
146
- class EnhancedAttentionGate(nn.Module):
147
- """Enhanced attention gate with global context"""
148
  def __init__(self, F_g, F_l, F_int):
149
  super().__init__()
150
  self.W_g = nn.Sequential(
@@ -184,6 +175,70 @@ class EnhancedAttentionGate(nn.Module):
184
  return x * psi
185
 
186
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
  class S2FGenerator(nn.Module):
188
  """
189
  S2F (Shape2Force) model: U-Net generator for force map prediction.
@@ -217,7 +272,7 @@ class S2FGenerator(nn.Module):
217
  else:
218
  self.initial_conv = nn.Conv2d(in_channels, 64, 3, padding=1)
219
 
220
- def enhanced_conv_block(in_c, out_c, use_attention=True):
221
  layers = [
222
  nn.Conv2d(in_c, out_c, 3, padding=1),
223
  nn.BatchNorm2d(out_c),
@@ -239,9 +294,9 @@ class S2FGenerator(nn.Module):
239
  layers.append(GlobalContextModule(out_c))
240
  return nn.Sequential(*layers)
241
 
242
- self.encoder1 = enhanced_conv_block(64, 64, use_attention=False)
243
  self.pool1 = nn.MaxPool2d(2)
244
- self.encoder2 = enhanced_conv_block(64, 128, use_attention=True)
245
  self.pool2 = nn.MaxPool2d(2)
246
  self.encoder3 = dilated_conv_block(128, 256, use_global_context=True)
247
  self.pool3 = nn.MaxPool2d(2)
@@ -262,22 +317,22 @@ class S2FGenerator(nn.Module):
262
  HierarchicalAttention(1024)
263
  )
264
 
265
- self.att4 = EnhancedAttentionGate(512, 512, 256)
266
- self.att3 = EnhancedAttentionGate(256, 256, 128)
267
- self.att2 = EnhancedAttentionGate(128, 128, 64)
268
- self.att1 = EnhancedAttentionGate(64, 64, 32)
269
 
270
  self.up4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
271
- self.dec4 = enhanced_conv_block(1024, 512, use_attention=True)
272
  self.refine4 = HierarchicalAttention(512)
273
  self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
274
- self.dec3 = enhanced_conv_block(512, 256, use_attention=True)
275
  self.refine3 = HierarchicalAttention(256)
276
  self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
277
- self.dec2 = enhanced_conv_block(256, 128, use_attention=True)
278
  self.refine2 = HierarchicalAttention(128)
279
  self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
280
- self.dec1 = enhanced_conv_block(128, 64, use_attention=True)
281
  self.refine1 = HierarchicalAttention(64)
282
 
283
  self.final_conv = nn.Sequential(
@@ -328,87 +383,128 @@ class S2FGenerator(nn.Module):
328
  out = self.final_conv(d1)
329
  return out
330
 
331
- def load_checkpoint_with_expansion(self, checkpoint_path, strict=False):
332
- """Load checkpoint and expand from 1-channel to 3-channel if needed."""
333
- checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
334
- generator_state = checkpoint['generator_state_dict']
335
- needs_expansion = False
336
-
337
- if 'scale_pyramid.0.weight' in generator_state:
338
- old_shape = generator_state['scale_pyramid.0.weight'].shape
339
- current_shape = self.scale_pyramid[0].weight.shape
340
- if old_shape[1] != current_shape[1]:
341
- needs_expansion = True
342
- elif 'initial_conv.weight' in generator_state:
343
- old_shape = generator_state['initial_conv.weight'].shape
344
- current_shape = self.initial_conv.weight.shape
345
- if old_shape[1] != current_shape[1]:
346
- needs_expansion = True
347
-
348
- if needs_expansion:
349
- generator_state = self._expand_generator_state(generator_state)
350
-
351
- self.load_state_dict(generator_state, strict=strict)
352
- return checkpoint
353
-
354
- def _expand_generator_state(self, generator_state):
355
- """Expand generator state dict from 1-channel to 3-channel input."""
356
- expanded_state = generator_state.copy()
357
- if 'scale_pyramid.0.weight' in generator_state:
358
- for i in range(3):
359
- key = f'scale_pyramid.{i}.weight' if i == 0 else f'scale_pyramid.{i}.1.weight'
360
- if key in generator_state:
361
- old_weight = generator_state[key]
362
- new_weight = torch.zeros(32, 3, 3, 3)
363
- new_weight[:, 0:1, :, :] = old_weight
364
- expanded_state[key] = new_weight
365
- elif 'initial_conv.weight' in generator_state:
366
- old_weight = generator_state['initial_conv.weight']
367
- new_weight = torch.zeros(64, 3, 3, 3)
368
- new_weight[:, 0:1, :, :] = old_weight
369
- expanded_state['initial_conv.weight'] = new_weight
370
- return expanded_state
371
 
 
 
 
 
 
 
 
 
 
372
 
373
- class PatchGANDiscriminator(nn.Module):
374
- """PatchGAN Discriminator (included for create_s2f_model compatibility)."""
375
- def __init__(self, in_channels=2, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
376
- super().__init__()
377
- use_bias = norm_layer == nn.InstanceNorm2d
378
- self.initial_conv = nn.Sequential(
379
- nn.Conv2d(in_channels, ndf, kernel_size=4, stride=2, padding=1, bias=use_bias),
380
- nn.LeakyReLU(0.2, inplace=True)
381
- )
382
- self.layers = nn.ModuleList()
383
- nf_mult, nf_mult_prev = 1, 1
384
- for n in range(1, n_layers):
385
- nf_mult_prev, nf_mult = nf_mult, min(2 ** n, 8)
386
- self.layers.append(nn.Sequential(
387
- nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=4, stride=2, padding=1, bias=use_bias),
388
- norm_layer(ndf * nf_mult),
389
- nn.LeakyReLU(0.2, inplace=True)
390
- ))
391
- nf_mult_prev, nf_mult = nf_mult, min(2 ** n_layers, 8)
392
- self.layers.append(nn.Sequential(
393
- nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=4, stride=1, padding=1, bias=use_bias),
394
- norm_layer(ndf * nf_mult),
395
- nn.LeakyReLU(0.2, inplace=True)
396
- ))
397
- self.output_conv = nn.Conv2d(ndf * nf_mult, 1, kernel_size=4, stride=1, padding=1)
398
- self.attention = nn.Sequential(
399
- nn.Conv2d(ndf * nf_mult, ndf * nf_mult // 4, 1),
400
- nn.ReLU(inplace=True),
401
- nn.Conv2d(ndf * nf_mult // 4, ndf * nf_mult, 1),
402
- nn.Sigmoid()
403
- )
404
-
405
- def forward(self, input):
406
- x = self.initial_conv(input)
407
- for layer in self.layers:
408
- x = layer(x)
409
- x = x * self.attention(x)
410
- return self.output_conv(x)
411
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
412
 
413
  def create_s2f_model(
414
  in_channels=1,
@@ -418,15 +514,25 @@ def create_s2f_model(
418
  use_multi_scale_input=True,
419
  ndf=64,
420
  n_layers=3,
 
421
  ):
422
  """Create S2F model with generator and discriminator."""
423
- generator = S2FGenerator(
 
424
  in_channels=in_channels,
425
  out_channels=out_channels,
426
  img_size=img_size,
427
  bridge_type=bridge_type,
428
  use_multi_scale_input=use_multi_scale_input,
429
- )
 
 
 
 
 
 
 
 
430
  discriminator = PatchGANDiscriminator(
431
  in_channels=in_channels + out_channels,
432
  ndf=ndf,
 
8
  from .blocks import ResidualBlock
9
  from .cbam import CBAM
10
 
11
+ from utils.substrate_settings import get_settings_of_category
 
 
 
 
 
12
 
13
 
14
  def normalize_settings(substrate_name, normalization_params, config=None, config_path=None):
 
33
 
34
  return pixelsize_norm, young_norm
35
 
 
36
  def create_settings_channels(metadata, normalization_params, device, image_shape, config_path=None):
37
  """
38
  Create settings channels for a batch of images.
 
67
 
68
  return settings_channels
69
 
 
70
  class GlobalContextModule(nn.Module):
71
+ """A module for capturing cell shape information"""
72
  def __init__(self, in_channels):
73
  super().__init__()
74
  self.global_pool = nn.AdaptiveAvgPool2d(1)
 
103
  multi_scale_out = self.fusion(multi_scale_out)
104
  return x + (large_features * global_weight) + multi_scale_out
105
 
 
106
  class HierarchicalAttention(nn.Module):
107
+ """A module for combining spatial and channel attention"""
108
  def __init__(self, channels):
109
  super().__init__()
110
  self.spatial_att = nn.Sequential(
 
134
  cross_weight = self.cross_att(attended)
135
  return x + (attended * cross_weight)
136
 
137
+ class AttentionGate(nn.Module):
138
+ """Attention gate with global context"""
 
139
  def __init__(self, F_g, F_l, F_int):
140
  super().__init__()
141
  self.W_g = nn.Sequential(
 
175
  return x * psi
176
 
177
 
178
class SpheroidAttentionGate(nn.Module):
    """Additive attention gate used by the spheroid generator.

    Ported from ForceNet2WithAttention (s2f_spheroid). The submodule names
    (W_g, W_x, psi) are kept unchanged so ckp_spheroid_FN.pth checkpoints
    still load.
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()
        # Project the gating signal to the intermediate width.
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1),
            nn.BatchNorm2d(F_int),
        )
        # Project the skip-connection features to the same width.
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1),
            nn.BatchNorm2d(F_int),
        )
        # Collapse to a single-channel attention map in (0, 1).
        self.psi = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.Conv2d(F_int, 1, kernel_size=1),
            nn.Sigmoid(),
        )

    def forward(self, g, x):
        """Return the skip features ``x`` scaled by the learned attention map."""
        gate_proj = self.W_g(g)
        skip_proj = self.W_x(x)
        attention = self.psi(gate_proj + skip_proj)
        return x * attention
201
+
202
class PatchGANDiscriminator(nn.Module):
    """PatchGAN discriminator with a channel-attention head.

    Included here so create_s2f_model can build a full generator/discriminator
    pair. Attribute names (initial_conv, layers, output_conv, attention) are
    preserved for checkpoint compatibility.
    """

    def __init__(self, in_channels=2, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
        super().__init__()
        # InstanceNorm carries no affine shift by default, so convs need a bias.
        use_bias = norm_layer == nn.InstanceNorm2d

        self.initial_conv = nn.Sequential(
            nn.Conv2d(in_channels, ndf, kernel_size=4, stride=2, padding=1, bias=use_bias),
            nn.LeakyReLU(0.2, inplace=True),
        )

        self.layers = nn.ModuleList()
        prev_mult = 1
        # Stride-2 stages: width doubles (capped at 8 * ndf), resolution halves.
        for stage in range(1, n_layers):
            mult = min(2 ** stage, 8)
            self.layers.append(nn.Sequential(
                nn.Conv2d(ndf * prev_mult, ndf * mult, kernel_size=4, stride=2, padding=1, bias=use_bias),
                norm_layer(ndf * mult),
                nn.LeakyReLU(0.2, inplace=True),
            ))
            prev_mult = mult
        # Final stride-1 stage before the patch logits.
        mult = min(2 ** n_layers, 8)
        self.layers.append(nn.Sequential(
            nn.Conv2d(ndf * prev_mult, ndf * mult, kernel_size=4, stride=1, padding=1, bias=use_bias),
            norm_layer(ndf * mult),
            nn.LeakyReLU(0.2, inplace=True),
        ))

        # One real/fake logit per receptive-field patch.
        self.output_conv = nn.Conv2d(ndf * mult, 1, kernel_size=4, stride=1, padding=1)
        # Squeeze/excite-style per-channel gating applied before the logits.
        self.attention = nn.Sequential(
            nn.Conv2d(ndf * mult, ndf * mult // 4, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(ndf * mult // 4, ndf * mult, 1),
            nn.Sigmoid(),
        )

    def forward(self, input):
        """Map an input image (pair) to a grid of patch logits."""
        features = self.initial_conv(input)
        for stage in self.layers:
            features = stage(features)
        features = features * self.attention(features)
        return self.output_conv(features)
240
+
241
+
242
  class S2FGenerator(nn.Module):
243
  """
244
  S2F (Shape2Force) model: U-Net generator for force map prediction.
 
272
  else:
273
  self.initial_conv = nn.Conv2d(in_channels, 64, 3, padding=1)
274
 
275
+ def reg_conv_block(in_c, out_c, use_attention=True):
276
  layers = [
277
  nn.Conv2d(in_c, out_c, 3, padding=1),
278
  nn.BatchNorm2d(out_c),
 
294
  layers.append(GlobalContextModule(out_c))
295
  return nn.Sequential(*layers)
296
 
297
+ self.encoder1 = reg_conv_block(64, 64, use_attention=False)
298
  self.pool1 = nn.MaxPool2d(2)
299
+ self.encoder2 = reg_conv_block(64, 128, use_attention=True)
300
  self.pool2 = nn.MaxPool2d(2)
301
  self.encoder3 = dilated_conv_block(128, 256, use_global_context=True)
302
  self.pool3 = nn.MaxPool2d(2)
 
317
  HierarchicalAttention(1024)
318
  )
319
 
320
+ self.att4 = AttentionGate(512, 512, 256)
321
+ self.att3 = AttentionGate(256, 256, 128)
322
+ self.att2 = AttentionGate(128, 128, 64)
323
+ self.att1 = AttentionGate(64, 64, 32)
324
 
325
  self.up4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
326
+ self.dec4 = reg_conv_block(1024, 512, use_attention=True)
327
  self.refine4 = HierarchicalAttention(512)
328
  self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
329
+ self.dec3 = reg_conv_block(512, 256, use_attention=True)
330
  self.refine3 = HierarchicalAttention(256)
331
  self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
332
+ self.dec2 = reg_conv_block(256, 128, use_attention=True)
333
  self.refine2 = HierarchicalAttention(128)
334
  self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
335
+ self.dec1 = reg_conv_block(128, 64, use_attention=True)
336
  self.refine1 = HierarchicalAttention(64)
337
 
338
  self.final_conv = nn.Sequential(
 
383
  out = self.final_conv(d1)
384
  return out
385
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
386
 
387
class S2FSpheroidGenerator(nn.Module):
    """U-Net-style s2f generator tuned for spheroid data.

    Four-level encoder/decoder with a dilated bridge and attention-gated skip
    connections. Registered submodule names match the s2f_spheroid checkpoint
    (ckp_spheroid_FN.pth), so do not rename them.
    """

    def __init__(self, in_channels=1, out_channels=1, predict_numbers=False, img_size=1024, use_tanh_output=True):
        super().__init__()
        self.predict_numbers = predict_numbers
        self.img_size = img_size
        self.use_tanh_output = use_tanh_output

        def block(in_c, out_c):
            # Conv -> BN -> ReLU followed by a residual refinement stage.
            return nn.Sequential(
                nn.Conv2d(in_c, out_c, 3, padding=1),
                nn.BatchNorm2d(out_c),
                nn.ReLU(inplace=True),
                ResidualBlock(out_c, out_c),
            )

        # Encoder: channels double while spatial resolution halves per level.
        self.encoder1 = block(in_channels, 32)  # full resolution
        self.pool1 = nn.MaxPool2d(2)
        self.encoder2 = block(32, 64)           # 1/2 resolution
        self.pool2 = nn.MaxPool2d(2)
        self.encoder3 = block(64, 128)          # 1/4 resolution
        self.pool3 = nn.MaxPool2d(2)
        self.encoder4 = block(128, 256)         # 1/8 resolution
        self.pool4 = nn.MaxPool2d(2)

        # Dilated bridge widens the receptive field at 1/16 resolution.
        self.bridge = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=3, padding=2, dilation=2),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            ResidualBlock(512, 512),
        )

        # Attention gates filter encoder skips before concatenation
        # (SpheroidAttentionGate from s2f_spheroid, matches ckp_spheroid_FN.pth).
        self.att3 = SpheroidAttentionGate(256, 256, 128)
        self.att2 = SpheroidAttentionGate(128, 128, 64)
        self.att1 = SpheroidAttentionGate(64, 64, 32)

        # Decoder: transpose convs upsample, blocks fuse skip + upsampled maps.
        self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec3 = block(512, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec2 = block(256, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec1 = block(128, 64)
        self.up0 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.dec0 = block(64, 32)

        # 1x1 conv produces the output force heatmap.
        self.pred_conv = nn.Conv2d(32, out_channels, kernel_size=1)

    def forward(self, x):
        """Predict a force heatmap from a shape image.

        Output is resized to (img_size, img_size); tanh in [-1, 1] when
        use_tanh_output is True, otherwise sigmoid in [0, 1].
        """
        # Contracting path.
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool1(enc1))
        enc3 = self.encoder3(self.pool2(enc2))
        enc4 = self.encoder4(self.pool3(enc3))
        bottleneck = self.bridge(self.pool4(enc4))

        # Expanding path with attention-gated skip connections.
        up = self.up3(bottleneck)
        dec = self.dec3(torch.cat([up, self.att3(up, enc4)], dim=1))
        up = self.up2(dec)
        dec = self.dec2(torch.cat([up, self.att2(up, enc3)], dim=1))
        up = self.up1(dec)
        dec = self.dec1(torch.cat([up, self.att1(up, enc2)], dim=1))
        up = self.up0(dec)
        dec = self.dec0(torch.cat([up, enc1], dim=1))

        logits = self.pred_conv(dec)
        # Resize to the configured output resolution.
        resized = F.interpolate(logits, size=(self.img_size, self.img_size), mode='bilinear', align_corners=False)

        if self.use_tanh_output:
            return torch.tanh(resized)     # [-1, 1] for Pix2Pix training
        return torch.sigmoid(resized)      # [0, 1] for direct inference

    def predict(self, loader):
        """Run inference on the first batch from *loader*.

        Returns (inputs, ground-truth heatmaps, predicted heatmaps); tanh
        outputs are rescaled from [-1, 1] to [0, 1] for comparison.
        NOTE(review): assumes the loader yields 4-tuples
        (image, heatmap, cell_area, cell_force) — confirm against the dataset.
        """
        self.eval()
        with torch.no_grad():
            # Only the first two elements of the batch are used.
            inputs, targets, _, _ = next(iter(loader))
            device = next(self.parameters()).device
            inputs = inputs.to(device)
            targets = targets.to(device)
            preds = self(inputs)
            if self.use_tanh_output:
                preds = (preds + 1.0) / 2.0
        return inputs, targets, preds

    def set_output_mode(self, use_tanh=True):
        """Switch the output activation.

        Args:
            use_tanh: True -> tanh output [-1, 1] for GAN training;
                      False -> sigmoid output [0, 1] for direct inference.
        """
        self.use_tanh_output = use_tanh
        if use_tanh:
            print("Generator set to tanh output mode [-1, 1] for GAN training")
        else:
            print("Generator set to sigmoid output mode [0, 1] for inference/evaluation")
508
 
509
  def create_s2f_model(
510
  in_channels=1,
 
514
  use_multi_scale_input=True,
515
  ndf=64,
516
  n_layers=3,
517
+ model_type='s2f',
518
  ):
519
  """Create S2F model with generator and discriminator."""
520
+ if model_type == 's2f':
521
+ generator = S2FGenerator(
522
  in_channels=in_channels,
523
  out_channels=out_channels,
524
  img_size=img_size,
525
  bridge_type=bridge_type,
526
  use_multi_scale_input=use_multi_scale_input,
527
+ )
528
+ elif model_type == 's2f_spheroid':
529
+ generator = S2FSpheroidGenerator(
530
+ in_channels=in_channels,
531
+ out_channels=out_channels,
532
+ img_size=img_size,
533
+ )
534
+ else:
535
+ raise ValueError(f"Invalid model type: {model_type}")
536
  discriminator = PatchGANDiscriminator(
537
  in_channels=in_channels + out_channels,
538
  ndf=ndf,