| | |
| | """ |
| | S2F evaluation script. |
| | Metrics: MSE, MS-SSIM, Pixel Correlation, Relative Magnitude Error, Force Sum/Mean correlation. |
| | |
| | Usage: |
| | python -m training.evaluate --model single_cell --checkpoint ckp/best_checkpoint.pth --data path/to/test |
| | python -m training.evaluate --model spheroid --checkpoint ckp/best_checkpoint.pth --data path/to/test |
| | """ |
| | import os |
| | import sys |
| | import argparse |
| |
|
# Make the repository root importable so the project-local imports inside
# main() (data.*, models.*, utils.*) resolve when this file is run directly
# rather than via `python -m training.evaluate`.
S2F_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if S2F_ROOT not in sys.path:
    sys.path.insert(0, S2F_ROOT)
| |
|
| |
|
def main():
    """Evaluate an S2F checkpoint on a test folder and report metrics.

    Parses CLI arguments, loads the test data and the generator checkpoint,
    computes heatmap/WFM metrics over the dataset, and optionally writes
    prediction plots (``--save_plots``) and a per-sample metrics CSV
    (``--output``). When no individual predictions are collected, a small
    aggregate text summary is written next to ``--output`` instead.
    """
    parser = argparse.ArgumentParser(description='Evaluate S2F model')
    parser.add_argument('--data', required=True, help='Path to test folder (subfolders with BF_001.tif, *_gray.jpg)')
    parser.add_argument('--model', choices=['single_cell', 'spheroid'], default='single_cell')
    parser.add_argument('--checkpoint', required=True, help='Path to .pth checkpoint')
    parser.add_argument('--substrate', default='fibroblasts_PDMS', help='Substrate for single_cell')
    parser.add_argument('--batch_size', type=int, default=2)
    parser.add_argument('--img_size', type=int, default=1024)
    parser.add_argument('--threshold', type=float, default=0.0, help='Threshold for heatmap metrics')
    parser.add_argument('--output', default=None, help='Optional CSV path for per-sample metrics')
    parser.add_argument('--save_plots', default=None, help='Directory to save prediction plots')
    parser.add_argument('--device', default='cuda')
    args = parser.parse_args()

    # Deferred imports: torch/pandas and the project modules are heavy, and
    # importing them after argument parsing keeps `--help` fast. They also
    # rely on the sys.path bootstrap at the top of this file.
    from data.cell_dataset import load_folder_data
    from models.s2f_model import create_s2f_model
    from utils.substrate_settings import compute_settings_normalization
    from utils.metrics import (
        evaluate_metrics_on_dataset,
        print_metrics_report,
        gen_prediction_plots,
        detect_tanh_output_model,
    )
    import torch
    import pandas as pd

    # Only the single-cell model is conditioned on substrate settings.
    use_settings = args.model == 'single_cell'
    config_path = os.path.join(S2F_ROOT, 'config', 'substrate_settings.json')

    print(f"Loading data from {args.data}")
    val_loader = load_folder_data(
        args.data,
        substrate=args.substrate if use_settings else None,
        img_size=args.img_size,
        batch_size=args.batch_size,
        return_metadata=use_settings,
    )

    # Single-cell uses 3 input channels, spheroid 1 — presumably the extra
    # channels carry the substrate settings (TODO confirm in s2f_model).
    in_channels = 3 if use_settings else 1
    model_type = 's2f' if use_settings else 's2f_spheroid'
    generator, _ = create_s2f_model(in_channels=in_channels, model_type=model_type)
    # NOTE(review): weights_only=False deserializes arbitrary pickled objects
    # and can execute code — only load checkpoints from trusted sources.
    ckpt = torch.load(args.checkpoint, map_location='cpu', weights_only=False)
    # Accept both wrapped checkpoints ({'generator_state_dict': ...}) and
    # raw state dicts saved directly.
    generator.load_state_dict(ckpt.get('generator_state_dict', ckpt), strict=True)

    norm_params = compute_settings_normalization(config_path=config_path) if use_settings else None
    uses_tanh = detect_tanh_output_model(generator)

    results = evaluate_metrics_on_dataset(
        generator,
        val_loader,
        device=args.device,
        description="Evaluating",
        # Per-sample predictions are only needed when we plot or export them.
        save_predictions=(args.save_plots is not None or args.output is not None),
        threshold=args.threshold,
        use_settings=use_settings,
        normalization_params=norm_params,
        config_path=config_path,
        substrate_override=args.substrate,
    )

    report = {'validation': results}
    print_metrics_report(report, threshold=args.threshold, uses_tanh=uses_tanh)
    print(f"Samples: {len(val_loader.dataset)}")

    if args.save_plots and 'individual_predictions' in results:
        gen_prediction_plots(
            results['individual_predictions'],
            args.save_plots,
            sort_by='mse',
            sort_order='asc',
            threshold=args.threshold,
        )
        print(f"Saved prediction plots to {args.save_plots}")

    if args.output:
        preds = results.get('individual_predictions', [])
        if preds:
            df = pd.DataFrame([{
                'mse': p['mse'],
                'ms_ssim': p['ms_ssim'],
                'pixel_correlation': p['pixel_correlation'],
                'relative_magnitude_error': p.get('wfm_relative_magnitude_error'),
                'force_sum_gt': p['force_sum_gt'],
                'force_sum_pred': p['force_sum_pred'],
            } for p in preds])
            df.to_csv(args.output, index=False)
            print(f"Saved per-sample metrics to {args.output}")
        else:
            # Fall back to an aggregate text summary when no per-sample
            # predictions were collected.
            # BUGFIX: the previous `args.output.replace('.csv', ...)` silently
            # did nothing for outputs without a '.csv' suffix (the summary
            # would then overwrite args.output) and could match '.csv'
            # anywhere in the path; derive the path from the extension once.
            summary_path = os.path.splitext(args.output)[0] + '_summary.txt'
            with open(summary_path, 'w') as f:
                f.write(f"MSE: {results['heatmap']['mse']:.6f}\n")
                f.write(f"MS-SSIM: {results['heatmap']['ms_ssim']:.4f}\n")
                f.write(f"Pixel Corr: {results['heatmap']['pixel_correlation']:.4f}\n")
                f.write(f"Rel Mag Error: {results['wfm']['relative_magnitude_error']:.4f}\n")
            print(f"Saved summary to {summary_path}")
| |
|
| |
|
# Script entry point: run evaluation only when executed directly, not on import.
if __name__ == '__main__':
    main()
| |
|