# Shape2Force/training/s2f_trainer.py  (commit 54e160a)
"""
S2F training logic: loss, metrics, and training loop.
"""
import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
# Add S2F root to path
S2F_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if S2F_ROOT not in sys.path:
sys.path.insert(0, S2F_ROOT)
from models.s2f_model import create_settings_channels
from utils.substrate_settings import compute_settings_normalization
from utils.metrics import calculate_psnr, calculate_ssim_tensor, calculate_pearson_correlation
from scipy.stats import pearsonr
class S2FLoss(nn.Module):
    """Combined S2F objective.

    total = lambda_L1 * reconstruction + lambda_gan * adversarial
            + lambda_force * force-consistency (optional)

    Reconstruction defaults to L1 but any criterion can be injected via
    ``custom_loss``. The adversarial term uses BCE-with-logits for
    ``gan_mode='vanilla'`` and MSE (LSGAN-style) otherwise.
    """

    def __init__(self, lambda_L1=100.0, lambda_gan=1.0, lambda_force=1.0,
                 gan_mode='vanilla', custom_loss=None, use_force_consistency=False,
                 force_consistency_target='mean'):
        super().__init__()
        # Term weights.
        self.lambda_L1 = lambda_L1
        self.lambda_gan = lambda_gan
        self.lambda_force = lambda_force
        self.gan_mode = gan_mode
        self.use_force_consistency = use_force_consistency
        self.force_consistency_target = force_consistency_target
        # Pixel-wise reconstruction criterion (caller-supplied or L1).
        self.reconstruction_loss = nn.L1Loss() if custom_loss is None else custom_loss
        # MSE between per-sample global force values, only when enabled.
        self.force_consistency_loss = nn.MSELoss() if use_force_consistency else None
        # 'vanilla' GAN -> BCE on logits; anything else -> least-squares GAN.
        self.gan_loss = nn.BCEWithLogitsLoss() if gan_mode == 'vanilla' else nn.MSELoss()

    def forward(self, pred, target, disc_pred=None, disc_target=None):
        """Return ``(total, reconstruction, gan, force)`` loss components.

        The gan/force entries come back as plain ``0.0`` floats (not tensors)
        when their term is inactive — callers must handle both cases.
        """
        recon_term = self.reconstruction_loss(pred, target)

        adv_term = 0.0
        if disc_pred is not None and disc_target is not None:
            adv_term = self.gan_loss(disc_pred, disc_target)

        force_term = 0.0
        if self.use_force_consistency and self.force_consistency_loss is not None:
            # Collapse each sample to a single global value (mean or sum).
            reducer = torch.mean if self.force_consistency_target == 'mean' else torch.sum
            flat_pred = pred.view(pred.size(0), -1)
            flat_target = target.view(target.size(0), -1)
            force_term = self.force_consistency_loss(
                reducer(flat_pred, dim=1, keepdim=True),
                reducer(flat_target, dim=1, keepdim=True),
            )

        total = (self.lambda_L1 * recon_term
                 + self.lambda_gan * adv_term
                 + self.lambda_force * force_term)
        return total, recon_term, adv_term, force_term
def calculate_soft_dice_loss(pred, target, smooth=1e-6):
    """Per-sample soft Dice coefficient averaged over the batch (higher is better).

    Despite the name, this returns a *score* in [0, 1] as a Python float, not a
    loss. ``smooth`` stabilizes the ratio when both inputs are all-zero.
    """
    flat_pred = pred.view(pred.size(0), -1)
    flat_target = target.view(target.size(0), -1)
    overlap = (flat_pred * flat_target).sum(dim=1)
    denom = flat_pred.sum(dim=1) + flat_target.sum(dim=1)
    per_sample = (2.0 * overlap + smooth) / (denom + smooth)
    return float(per_sample.mean())
def train_s2f(generator, discriminator, train_loader, val_loader, device='cuda',
              num_epochs=100, g_lr=2e-4, d_lr=2e-4, beta1=0.5, beta2=0.999,
              save_dir='ckp', lambda_L1=100.0, lambda_gan=1.0, lambda_force=1.0,
              gan_mode='vanilla', save_predictions_every=5, custom_loss=None,
              loaded_metadata=False, use_settings=False, use_force_consistency=False,
              force_consistency_target='mean', config_path=None):
    """
    Train the S2F generator/discriminator pair (pix2pix-style conditional GAN).

    Per epoch: a training pass (alternating G/D updates with per-batch cosine
    LR scheduling), a validation pass (reconstruction loss + image metrics),
    checkpointing (last + best by val loss), and — every
    ``save_predictions_every`` epochs — a prediction grid saved under
    ``<save_dir>/visualizations``.

    Args:
        generator, discriminator: torch.nn.Module pair. The generator output is
            assumed to lie in [-1, 1] (it is mapped to [0, 1] via (x+1)/2 for
            the reconstruction loss and metrics) — TODO confirm against model.
        train_loader, val_loader: yield ``(input, target, _, _)`` tuples, plus a
            trailing metadata element when ``loaded_metadata`` is True.
        device: 'cuda' or 'cpu'; silently falls back to CPU if CUDA is absent.
        num_epochs, g_lr, d_lr, beta1, beta2: optimization hyperparameters.
        save_dir: output directory for checkpoints/visualizations (created).
        lambda_L1, lambda_gan, lambda_force, gan_mode, custom_loss,
            use_force_consistency, force_consistency_target: forwarded to S2FLoss.
        save_predictions_every: epoch interval for saving prediction images.
        loaded_metadata: whether batches carry a metadata element.
        use_settings: append normalized substrate-settings channels to the
            input (requires ``loaded_metadata=True``).
        config_path: substrate settings JSON; defaults to
            ``<S2F_ROOT>/config/substrate_settings.json``.

    Returns:
        dict: per-epoch lists of losses and metrics (g/d losses, per-component
        generator losses, SSIM, PSNR, MSE, Dice) — the ``history`` dict.

    Raises:
        ValueError: if ``use_settings`` is True but ``loaded_metadata`` is False.
    """
    # Local import: diffusers is only needed for the LR schedulers.
    from diffusers.optimization import get_cosine_schedule_with_warmup
    config_path = config_path or os.path.join(S2F_ROOT, 'config', 'substrate_settings.json')
    normalization_params = None
    if use_settings:
        if not loaded_metadata:
            raise ValueError("loaded_metadata must be True when use_settings=True")
        normalization_params = compute_settings_normalization(config_path=config_path)
    history = {'g_loss': [], 'd_loss': [], 'g_recon_loss': [], 'g_gan_loss': [], 'g_force_loss': [],
               'train_loss': [], 'val_loss': [], 'train_ssim': [], 'val_ssim': [],
               'train_psnr': [], 'val_psnr': [], 'train_mse': [], 'val_mse': [],
               'train_dice_score': [], 'val_dice_score': []}
    # Graceful CPU fallback when CUDA was requested but is unavailable.
    if not torch.cuda.is_available() and device == 'cuda':
        device = 'cpu'
    generator = generator.to(device)
    discriminator = discriminator.to(device)
    criterion = S2FLoss(lambda_L1=lambda_L1, lambda_gan=lambda_gan, lambda_force=lambda_force,
                        gan_mode=gan_mode, custom_loss=custom_loss,
                        use_force_consistency=use_force_consistency,
                        force_consistency_target=force_consistency_target)
    g_optimizer = torch.optim.Adam(generator.parameters(), lr=g_lr, betas=(beta1, beta2))
    d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=d_lr, betas=(beta1, beta2))
    # Cosine schedule with 10% linear warmup, stepped once per batch.
    num_steps = len(train_loader) * num_epochs
    g_scheduler = get_cosine_schedule_with_warmup(g_optimizer, int(num_steps * 0.1), num_steps)
    d_scheduler = get_cosine_schedule_with_warmup(d_optimizer, int(num_steps * 0.1), num_steps)
    os.makedirs(save_dir, exist_ok=True)
    vis_dir = os.path.join(save_dir, 'visualizations')
    os.makedirs(vis_dir, exist_ok=True)
    best_val_loss = float('inf')
    # Spatial shape of the (patch) discriminator output; probed lazily once.
    disc_output_shape = None
    for epoch in range(num_epochs):
        # ---------------- training pass ----------------
        generator.train()
        discriminator.train()
        g_loss_total = d_loss_total = g_recon_total = g_gan_total = g_force_total = 0.0
        train_ssim = train_psnr = train_mse = train_dice = 0.0
        pbar = tqdm(train_loader, desc=f'Epoch {epoch}')
        for batch_data in pbar:
            if loaded_metadata:
                input_images, target_images, _, _, metadata = batch_data
            else:
                input_images, target_images, _, _ = batch_data
            input_images = input_images.to(device, dtype=torch.float32)
            target_images = target_images.to(device, dtype=torch.float32)
            batch_size = input_images.size(0)
            if use_settings and normalization_params is not None:
                # Extra per-sample settings channels concatenated onto the input.
                settings_channels = create_settings_channels(
                    metadata, normalization_params, device, input_images.shape,
                    config_path=config_path
                )
                input_images = torch.cat([input_images, settings_channels], dim=1)
            # Targets are in [0, 1]; the discriminator sees them in [-1, 1]
            # to match the generator's output range.
            target_scaled = target_images * 2.0 - 1.0
            if disc_output_shape is None:
                with torch.no_grad():
                    dummy = torch.cat([input_images[:1], target_scaled[:1]], dim=1)
                    disc_output_shape = discriminator(dummy).shape[2:]
            # NOTE(review): assumes the discriminator emits 1 channel — confirm.
            real_labels = torch.ones(batch_size, 1, *disc_output_shape).to(device)
            fake_labels = torch.zeros(batch_size, 1, *disc_output_shape).to(device)
            # --- generator step: reconstruct target + fool discriminator ---
            g_optimizer.zero_grad()
            fake_images = generator(input_images)
            fake_for_loss = (fake_images + 1.0) / 2.0  # back to [0, 1] for recon/metrics
            fake_input = torch.cat([input_images, fake_images], dim=1)
            fake_pred = discriminator(fake_input)
            g_loss, g_recon, g_gan, g_force = criterion(fake_for_loss, target_images, fake_pred, real_labels)
            g_loss.backward()
            g_optimizer.step()
            # --- discriminator step: real pair vs detached fake pair ---
            d_optimizer.zero_grad()
            real_input = torch.cat([input_images, target_scaled], dim=1)
            real_pred = discriminator(real_input)
            d_real = criterion.gan_loss(real_pred, real_labels)
            fake_input_d = torch.cat([input_images, fake_images.detach()], dim=1)
            fake_pred_d = discriminator(fake_input_d)
            d_fake = criterion.gan_loss(fake_pred_d, fake_labels)
            d_loss = (d_real + d_fake) * 0.5
            d_loss.backward()
            d_optimizer.step()
            g_scheduler.step()
            d_scheduler.step()
            # --- running totals ---
            g_loss_total += g_loss.item()
            d_loss_total += d_loss.item()
            g_recon_total += g_recon.item()
            g_gan_total += g_gan.item()
            # g_force is a plain 0.0 float when force consistency is disabled.
            g_force_total += g_force.item() if isinstance(g_force, torch.Tensor) else g_force
            train_ssim += calculate_ssim_tensor(fake_for_loss, target_images)
            train_psnr += calculate_psnr(fake_for_loss, target_images)
            train_mse += F.mse_loss(fake_for_loss, target_images).item()
            train_dice += calculate_soft_dice_loss(fake_for_loss, target_images)
            # pbar.n is the number of already-completed iterations (0-based here).
            pbar.set_postfix({'G': g_loss.item(),
                              'D': d_loss.item(), 'Dice': train_dice / (pbar.n + 1)})
        n_train = len(train_loader)
        g_loss_total /= n_train
        d_loss_total /= n_train
        # BUGFIX: the per-component generator losses were accumulated but never
        # averaged nor recorded, leaving history['g_recon_loss'/'g_gan_loss'/
        # 'g_force_loss'] permanently empty. Average and record them below.
        g_recon_total /= n_train
        g_gan_total /= n_train
        g_force_total /= n_train
        train_ssim /= n_train
        train_psnr /= n_train
        train_mse /= n_train
        train_dice /= n_train
        # ---------------- validation pass ----------------
        generator.eval()
        val_loss = val_ssim = val_psnr = val_mse = val_dice = 0.0
        with torch.no_grad():
            for batch_data in val_loader:
                if loaded_metadata:
                    input_images, target_images, _, _, metadata = batch_data
                else:
                    input_images, target_images, _, _ = batch_data
                input_images = input_images.to(device, dtype=torch.float32)
                target_images = target_images.to(device, dtype=torch.float32)
                if use_settings and normalization_params is not None:
                    settings_channels = create_settings_channels(
                        metadata, normalization_params, device, input_images.shape,
                        config_path=config_path
                    )
                    input_images = torch.cat([input_images, settings_channels], dim=1)
                fake_images = generator(input_images)
                fake_for_loss = (fake_images + 1.0) / 2.0
                # No discriminator during validation: only the reconstruction
                # term is meaningful here.
                _, recon_loss, _, force_loss = criterion(fake_for_loss, target_images)
                val_loss += recon_loss.item()
                val_ssim += calculate_ssim_tensor(fake_for_loss, target_images)
                val_psnr += calculate_psnr(fake_for_loss, target_images)
                val_mse += F.mse_loss(fake_for_loss, target_images).item()
                val_dice += calculate_soft_dice_loss(fake_for_loss, target_images)
        n_val = len(val_loader)
        val_loss /= n_val
        val_ssim /= n_val
        val_psnr /= n_val
        val_mse /= n_val
        val_dice /= n_val
        # ---------------- bookkeeping ----------------
        history['g_loss'].append(g_loss_total)
        history['d_loss'].append(d_loss_total)
        history['g_recon_loss'].append(g_recon_total)
        history['g_gan_loss'].append(g_gan_total)
        history['g_force_loss'].append(g_force_total)
        history['train_loss'].append(g_loss_total)
        history['val_loss'].append(val_loss)
        history['train_ssim'].append(train_ssim)
        history['val_ssim'].append(val_ssim)
        history['train_psnr'].append(train_psnr)
        history['val_psnr'].append(val_psnr)
        history['train_mse'].append(train_mse)
        history['val_mse'].append(val_mse)
        history['train_dice_score'].append(train_dice)
        history['val_dice_score'].append(val_dice)
        best_mark = "✓" if val_loss < best_val_loss else ""
        print(f"Train: G_Loss:{g_loss_total:.4f} D_Loss:{d_loss_total:.4f} "
              f"MSE:{train_mse:.4f} SSIM:{train_ssim:.4f} Dice:{train_dice:.4f}")
        print(f"Valid: Loss:{val_loss:.4f} MSE:{val_mse:.4f} SSIM:{val_ssim:.4f} Dice:{val_dice:.4f} {best_mark}")
        checkpoint = {
            'epoch': epoch,
            'generator_state_dict': generator.state_dict(),
            'discriminator_state_dict': discriminator.state_dict(),
            'g_optimizer_state_dict': g_optimizer.state_dict(),
            'd_optimizer_state_dict': d_optimizer.state_dict(),
            'val_loss': val_loss,
            'history': history
        }
        torch.save(checkpoint, os.path.join(save_dir, 'last_checkpoint.pth'))
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(checkpoint, os.path.join(save_dir, 'best_checkpoint.pth'))
        # ---------------- periodic visualization ----------------
        if epoch % save_predictions_every == 0:
            generator.eval()
            with torch.no_grad():
                batch_data = next(iter(val_loader))
                if loaded_metadata:
                    input_images, target_images, _, _, metadata = batch_data
                else:
                    input_images, target_images, _, _ = batch_data
                input_images = input_images.to(device, dtype=torch.float32)
                target_images = target_images.to(device, dtype=torch.float32)
                if use_settings and normalization_params is not None:
                    settings_channels = create_settings_channels(
                        metadata, normalization_params, device, input_images.shape,
                        config_path=config_path
                    )
                    input_images = torch.cat([input_images, settings_channels], dim=1)
                fake_images = generator(input_images)
                fake_vis = (fake_images + 1.0) / 2.0
            # Rows: input (gray) / prediction (jet) / ground truth (jet).
            n_vis = min(4, input_images.size(0))
            fig, axes = plt.subplots(3, n_vis, figsize=(4 * n_vis, 12))
            if n_vis == 1:
                axes = axes.reshape(3, 1)
            for i in range(n_vis):
                axes[0, i].imshow(input_images[i, 0].cpu().numpy(), cmap='gray')
                axes[0, i].axis('off')
                axes[1, i].imshow(fake_vis[i, 0].cpu().numpy(), cmap='jet', vmin=0, vmax=1)
                axes[1, i].axis('off')
                axes[2, i].imshow(target_images[i, 0].cpu().numpy(), cmap='jet', vmin=0, vmax=1)
                axes[2, i].axis('off')
            plt.tight_layout()
            plt.savefig(os.path.join(vis_dir, f'predictions_epoch_{epoch:02d}.png'), dpi=150, bbox_inches='tight')
            plt.close()
            print(f"Saved visualization for epoch {epoch}")
    return history