| | """ |
| | S2F training logic: loss, metrics, and training loop. |
| | """ |
| | import os |
| | import sys |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | import matplotlib.pyplot as plt |
| | from tqdm.auto import tqdm |
| |
|
| | |
# Project root = the directory one level above the folder containing this file.
# It is pushed to the front of sys.path (once) so that the `models.*` and
# `utils.*` imports that follow can be resolved.
_module_dir = os.path.dirname(os.path.abspath(__file__))
S2F_ROOT = os.path.dirname(_module_dir)
if S2F_ROOT not in sys.path:
    sys.path.insert(0, S2F_ROOT)
| |
|
| | from models.s2f_model import create_settings_channels |
| | from utils.substrate_settings import compute_settings_normalization |
| | from utils.metrics import calculate_psnr, calculate_ssim_tensor, calculate_pearson_correlation |
| | from scipy.stats import pearsonr |
| |
|
| |
|
class S2FLoss(nn.Module):
    """S2F loss: weighted reconstruction + GAN + optional force-consistency terms.

    The total is
    ``lambda_L1 * recon + lambda_gan * gan + lambda_force * force``.

    Args:
        lambda_L1: Weight of the reconstruction term.
        lambda_gan: Weight of the adversarial term.
        lambda_force: Weight of the force-consistency term.
        gan_mode: ``'vanilla'`` selects ``nn.BCEWithLogitsLoss``; any other
            value selects ``nn.MSELoss`` for the adversarial term.
        custom_loss: Optional callable ``(pred, target) -> loss`` used in
            place of ``nn.L1Loss`` for the reconstruction term.
        use_force_consistency: Enable the force-consistency term.
        force_consistency_target: ``'mean'`` compares per-sample means; any
            other value compares per-sample sums.
    """

    def __init__(self, lambda_L1=100.0, lambda_gan=1.0, lambda_force=1.0,
                 gan_mode='vanilla', custom_loss=None, use_force_consistency=False,
                 force_consistency_target='mean'):
        super().__init__()
        self.lambda_L1 = lambda_L1
        self.lambda_gan = lambda_gan
        self.lambda_force = lambda_force
        self.gan_mode = gan_mode
        self.use_force_consistency = use_force_consistency
        self.force_consistency_target = force_consistency_target
        self.reconstruction_loss = custom_loss if custom_loss is not None else nn.L1Loss()
        self.force_consistency_loss = nn.MSELoss() if use_force_consistency else None
        self.gan_loss = nn.BCEWithLogitsLoss() if gan_mode == 'vanilla' else nn.MSELoss()

    def forward(self, pred, target, disc_pred=None, disc_target=None):
        """Compute the combined loss.

        Args:
            pred: Predicted tensor; the first dimension is treated as batch.
            target: Ground-truth tensor compared against ``pred``.
            disc_pred: Discriminator output for the prediction (optional).
            disc_target: Label tensor for ``disc_pred`` (optional).

        Returns:
            Tuple ``(total, recon_loss, gan_loss, force_loss)``. ``gan_loss``
            is the float ``0.0`` unless both ``disc_pred`` and ``disc_target``
            are given; ``force_loss`` is the float ``0.0`` unless force
            consistency is enabled.
        """
        recon_loss = self.reconstruction_loss(pred, target)

        gan_loss = 0.0
        if disc_pred is not None and disc_target is not None:
            gan_loss = self.gan_loss(disc_pred, disc_target)

        force_loss = 0.0
        if self.use_force_consistency and self.force_consistency_loss is not None:
            reduce = torch.mean if self.force_consistency_target == 'mean' else torch.sum
            # reshape (not view) so non-contiguous inputs, e.g. transposed or
            # channels_last tensors, are flattened instead of raising.
            pred_global = reduce(pred.reshape(pred.size(0), -1), dim=1, keepdim=True)
            target_global = reduce(target.reshape(target.size(0), -1), dim=1, keepdim=True)
            force_loss = self.force_consistency_loss(pred_global, target_global)

        total = (self.lambda_L1 * recon_loss
                 + self.lambda_gan * gan_loss
                 + self.lambda_force * force_loss)
        return total, recon_loss, gan_loss, force_loss
| |
|
| |
|
def calculate_soft_dice_loss(pred, target, smooth=1e-6):
    """Return the batch-averaged soft Dice score (higher is better).

    Despite the function name, the value returned is the Dice *score*, not
    ``1 - dice``.

    Args:
        pred: Predicted tensor; the first dimension is treated as batch.
        target: Ground-truth tensor with the same number of elements per
            sample as ``pred``.
        smooth: Constant added to numerator and denominator to avoid a
            division by zero when both inputs sum to zero.

    Returns:
        Python float: mean over the batch of
        ``(2 * sum(pred * target) + smooth) / (sum(pred) + sum(target) + smooth)``.
    """
    # reshape (not view) so non-contiguous tensors are accepted as well.
    pred_flat = pred.reshape(pred.size(0), -1)
    target_flat = target.reshape(target.size(0), -1)
    intersection = (pred_flat * target_flat).sum(dim=1)
    denominator = pred_flat.sum(dim=1) + target_flat.sum(dim=1) + smooth
    dice_scores = (2.0 * intersection + smooth) / denominator
    return dice_scores.mean().item()
| |
|
| |
|
def _prepare_batch(batch_data, device, loaded_metadata, use_settings,
                   normalization_params, config_path):
    """Unpack one loader batch, move it to `device`, and append settings channels if enabled."""
    if loaded_metadata:
        input_images, target_images, _, _, metadata = batch_data
    else:
        input_images, target_images, _, _ = batch_data
        metadata = None
    input_images = input_images.to(device, dtype=torch.float32)
    target_images = target_images.to(device, dtype=torch.float32)
    if use_settings and normalization_params is not None:
        settings_channels = create_settings_channels(
            metadata, normalization_params, device, input_images.shape,
            config_path=config_path,
        )
        input_images = torch.cat([input_images, settings_channels], dim=1)
    return input_images, target_images


def _save_prediction_grid(input_images, fake_vis, target_images, out_path, max_samples=4):
    """Save a 3-row figure (input / prediction / target, channel 0) for up to `max_samples` samples."""
    n_vis = min(max_samples, input_images.size(0))
    # squeeze=False keeps `axes` 2-D even when only one column is drawn.
    fig, axes = plt.subplots(3, n_vis, figsize=(4 * n_vis, 12), squeeze=False)
    try:
        for i in range(n_vis):
            axes[0, i].imshow(input_images[i, 0].cpu().numpy(), cmap='gray')
            axes[1, i].imshow(fake_vis[i, 0].cpu().numpy(), cmap='jet', vmin=0, vmax=1)
            axes[2, i].imshow(target_images[i, 0].cpu().numpy(), cmap='jet', vmin=0, vmax=1)
            for row in range(3):
                axes[row, i].axis('off')
        fig.tight_layout()
        fig.savefig(out_path, dpi=150, bbox_inches='tight')
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)


def train_s2f(generator, discriminator, train_loader, val_loader, device='cuda',
              num_epochs=100, g_lr=2e-4, d_lr=2e-4, beta1=0.5, beta2=0.999,
              save_dir='ckp', lambda_L1=100.0, lambda_gan=1.0, lambda_force=1.0,
              gan_mode='vanilla', save_predictions_every=5, custom_loss=None,
              loaded_metadata=False, use_settings=False, use_force_consistency=False,
              force_consistency_target='mean', config_path=None):
    """
    Train S2F model.

    Alternates one generator step and one discriminator step per batch,
    validates after every epoch, writes ``last_checkpoint.pth`` every epoch
    and ``best_checkpoint.pth`` whenever the validation reconstruction loss
    improves, and periodically saves a prediction figure under
    ``<save_dir>/visualizations``.

    Args:
        generator: Model mapping the (optionally settings-augmented) input to
            a prediction. Its output is rescaled with ``(x + 1) / 2`` before
            losses and metrics, so it is presumably in [-1, 1] -- TODO confirm.
        discriminator: Model receiving ``cat([input, image], dim=1)``.
        train_loader: Sized iterable of batches ``(input, target, _, _)`` or,
            when ``loaded_metadata`` is True, ``(input, target, _, _, metadata)``.
        val_loader: Validation loader with the same batch layout.
        device: Target device; ``'cuda'`` falls back to ``'cpu'`` when CUDA
            is unavailable.
        num_epochs: Number of epochs to run.
        g_lr, d_lr: Adam learning rates for generator / discriminator.
        beta1, beta2: Adam betas shared by both optimizers.
        save_dir: Directory for checkpoints and visualizations (created if
            missing).
        lambda_L1, lambda_gan, lambda_force, gan_mode, custom_loss,
        use_force_consistency, force_consistency_target: Forwarded to
            ``S2FLoss``.
        save_predictions_every: Save a prediction figure every this many
            epochs; ``0`` or ``None`` disables the figures.
        loaded_metadata: Whether batches carry a fifth ``metadata`` element.
        use_settings: Append settings channels built from ``metadata`` to the
            input; requires ``loaded_metadata=True``.
        config_path: Substrate settings JSON; defaults to
            ``<S2F_ROOT>/config/substrate_settings.json``.

    Returns:
        dict mapping metric names to per-epoch lists (generator/discriminator
        losses, generator loss components, and train/val loss, SSIM, PSNR,
        MSE and Dice).

    Raises:
        ValueError: If ``use_settings`` is True but ``loaded_metadata`` is not.
    """
    from diffusers.optimization import get_cosine_schedule_with_warmup

    config_path = config_path or os.path.join(S2F_ROOT, 'config', 'substrate_settings.json')
    normalization_params = None
    if use_settings:
        if not loaded_metadata:
            raise ValueError("loaded_metadata must be True when use_settings=True")
        normalization_params = compute_settings_normalization(config_path=config_path)

    history = {'g_loss': [], 'd_loss': [], 'g_recon_loss': [], 'g_gan_loss': [], 'g_force_loss': [],
               'train_loss': [], 'val_loss': [], 'train_ssim': [], 'val_ssim': [],
               'train_psnr': [], 'val_psnr': [], 'train_mse': [], 'val_mse': [],
               'train_dice_score': [], 'val_dice_score': []}

    if not torch.cuda.is_available() and device == 'cuda':
        device = 'cpu'
    generator = generator.to(device)
    discriminator = discriminator.to(device)
    criterion = S2FLoss(lambda_L1=lambda_L1, lambda_gan=lambda_gan, lambda_force=lambda_force,
                        gan_mode=gan_mode, custom_loss=custom_loss,
                        use_force_consistency=use_force_consistency,
                        force_consistency_target=force_consistency_target)
    g_optimizer = torch.optim.Adam(generator.parameters(), lr=g_lr, betas=(beta1, beta2))
    d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=d_lr, betas=(beta1, beta2))
    # Schedulers are stepped once per batch; the first 10% of steps are warmup.
    num_steps = len(train_loader) * num_epochs
    warmup_steps = int(num_steps * 0.1)
    g_scheduler = get_cosine_schedule_with_warmup(g_optimizer, warmup_steps, num_steps)
    d_scheduler = get_cosine_schedule_with_warmup(d_optimizer, warmup_steps, num_steps)

    os.makedirs(save_dir, exist_ok=True)
    vis_dir = os.path.join(save_dir, 'visualizations')
    os.makedirs(vis_dir, exist_ok=True)
    best_val_loss = float('inf')
    disc_output_shape = None
    batch_args = (device, loaded_metadata, use_settings, normalization_params, config_path)

    for epoch in range(num_epochs):
        generator.train()
        discriminator.train()
        g_loss_total = d_loss_total = g_recon_total = g_gan_total = g_force_total = 0.0
        train_ssim = train_psnr = train_mse = train_dice = 0.0
        pbar = tqdm(train_loader, desc=f'Epoch {epoch}')

        # Count steps explicitly: tqdm only refreshes `pbar.n` at display
        # intervals, so it cannot be used as the running-average divisor.
        for step, batch_data in enumerate(pbar, start=1):
            input_images, target_images = _prepare_batch(batch_data, *batch_args)
            batch_size = input_images.size(0)

            # The discriminator sees targets rescaled with x * 2 - 1
            # (presumably from [0, 1] to [-1, 1] -- TODO confirm).
            target_scaled = target_images * 2.0 - 1.0
            if disc_output_shape is None:
                # Probe the discriminator once to learn its spatial output size.
                with torch.no_grad():
                    dummy = torch.cat([input_images[:1], target_scaled[:1]], dim=1)
                    disc_output_shape = discriminator(dummy).shape[2:]
            real_labels = torch.ones(batch_size, 1, *disc_output_shape, device=device)
            fake_labels = torch.zeros(batch_size, 1, *disc_output_shape, device=device)

            # ---- Generator step ----
            g_optimizer.zero_grad()
            fake_images = generator(input_images)
            fake_for_loss = (fake_images + 1.0) / 2.0
            fake_pred = discriminator(torch.cat([input_images, fake_images], dim=1))
            g_loss, g_recon, g_gan, g_force = criterion(
                fake_for_loss, target_images, fake_pred, real_labels)
            g_loss.backward()
            g_optimizer.step()

            # ---- Discriminator step ----
            # zero_grad also clears the gradients the generator step left on D.
            d_optimizer.zero_grad()
            real_pred = discriminator(torch.cat([input_images, target_scaled], dim=1))
            d_real = criterion.gan_loss(real_pred, real_labels)
            fake_pred_d = discriminator(torch.cat([input_images, fake_images.detach()], dim=1))
            d_fake = criterion.gan_loss(fake_pred_d, fake_labels)
            d_loss = (d_real + d_fake) * 0.5
            d_loss.backward()
            d_optimizer.step()
            g_scheduler.step()
            d_scheduler.step()

            g_loss_total += g_loss.item()
            d_loss_total += d_loss.item()
            g_recon_total += g_recon.item()
            g_gan_total += g_gan.item()
            # force loss is a plain float when force consistency is disabled
            g_force_total += g_force.item() if isinstance(g_force, torch.Tensor) else g_force
            train_ssim += calculate_ssim_tensor(fake_for_loss, target_images)
            train_psnr += calculate_psnr(fake_for_loss, target_images)
            train_mse += F.mse_loss(fake_for_loss, target_images).item()
            train_dice += calculate_soft_dice_loss(fake_for_loss, target_images)
            pbar.set_postfix({'G': g_loss.item(),
                              'D': d_loss.item(), 'Dice': train_dice / step})

        n_train = len(train_loader)
        g_loss_total /= n_train
        d_loss_total /= n_train
        g_recon_total /= n_train
        g_gan_total /= n_train
        g_force_total /= n_train
        train_ssim /= n_train
        train_psnr /= n_train
        train_mse /= n_train
        train_dice /= n_train

        # ---- Validation (reconstruction term only, no discriminator) ----
        generator.eval()
        val_loss = val_ssim = val_psnr = val_mse = val_dice = 0.0
        with torch.no_grad():
            for batch_data in val_loader:
                input_images, target_images = _prepare_batch(batch_data, *batch_args)
                fake_images = generator(input_images)
                fake_for_loss = (fake_images + 1.0) / 2.0
                _, recon_loss, _, _ = criterion(fake_for_loss, target_images)
                val_loss += recon_loss.item()
                val_ssim += calculate_ssim_tensor(fake_for_loss, target_images)
                val_psnr += calculate_psnr(fake_for_loss, target_images)
                val_mse += F.mse_loss(fake_for_loss, target_images).item()
                val_dice += calculate_soft_dice_loss(fake_for_loss, target_images)
        n_val = len(val_loader)
        val_loss /= n_val
        val_ssim /= n_val
        val_psnr /= n_val
        val_mse /= n_val
        val_dice /= n_val

        history['g_loss'].append(g_loss_total)
        history['d_loss'].append(d_loss_total)
        # Generator loss components: previously accumulated but never recorded.
        history['g_recon_loss'].append(g_recon_total)
        history['g_gan_loss'].append(g_gan_total)
        history['g_force_loss'].append(g_force_total)
        history['train_loss'].append(g_loss_total)
        history['val_loss'].append(val_loss)
        history['train_ssim'].append(train_ssim)
        history['val_ssim'].append(val_ssim)
        history['train_psnr'].append(train_psnr)
        history['val_psnr'].append(val_psnr)
        history['train_mse'].append(train_mse)
        history['val_mse'].append(val_mse)
        history['train_dice_score'].append(train_dice)
        history['val_dice_score'].append(val_dice)

        is_best = val_loss < best_val_loss
        best_mark = "✓" if is_best else ""
        print(f"Train: G_Loss:{g_loss_total:.4f} D_Loss:{d_loss_total:.4f} "
              f"MSE:{train_mse:.4f} SSIM:{train_ssim:.4f} Dice:{train_dice:.4f}")
        print(f"Valid: Loss:{val_loss:.4f} MSE:{val_mse:.4f} SSIM:{val_ssim:.4f} Dice:{val_dice:.4f} {best_mark}")

        checkpoint = {
            'epoch': epoch,
            'generator_state_dict': generator.state_dict(),
            'discriminator_state_dict': discriminator.state_dict(),
            'g_optimizer_state_dict': g_optimizer.state_dict(),
            'd_optimizer_state_dict': d_optimizer.state_dict(),
            'val_loss': val_loss,
            'history': history
        }
        torch.save(checkpoint, os.path.join(save_dir, 'last_checkpoint.pth'))
        if is_best:
            best_val_loss = val_loss
            torch.save(checkpoint, os.path.join(save_dir, 'best_checkpoint.pth'))

        # A falsy interval (0 / None) disables the figures instead of
        # crashing on the modulo.
        if save_predictions_every and epoch % save_predictions_every == 0:
            generator.eval()
            with torch.no_grad():
                input_images, target_images = _prepare_batch(next(iter(val_loader)), *batch_args)
                fake_vis = (generator(input_images) + 1.0) / 2.0
            _save_prediction_grid(
                input_images, fake_vis, target_images,
                os.path.join(vis_dir, f'predictions_epoch_{epoch:02d}.png'))
            print(f"Saved visualization for epoch {epoch}")

    return history
| |
|