File size: 4,920 Bytes
2b9ff22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54e160a
 
2b9ff22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54e160a
 
2b9ff22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
#!/usr/bin/env python3
"""
S2F evaluation script.
Metrics: MSE, MS-SSIM, Pixel Correlation, Relative Magnitude Error, Force Sum/Mean correlation.

Usage:
  python -m training.evaluate --model single_cell --checkpoint ckp/best_checkpoint.pth --data path/to/test
  python -m training.evaluate --model spheroid --checkpoint ckp/best_checkpoint.pth --data path/to/test
"""
import os
import sys
import argparse

# Put the repository root (the parent of the directory holding this file) at
# the front of sys.path, unless it is already listed, so that the project
# imports performed inside main() are looked up relative to it.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
S2F_ROOT = os.path.dirname(_SCRIPT_DIR)
if S2F_ROOT not in sys.path:
    sys.path.insert(0, S2F_ROOT)


def _summary_path(output_path):
    """Return the text-summary path that sits next to ``output_path``.

    The final extension of ``output_path`` (if any) is dropped and
    ``_summary.txt`` is appended, e.g. ``results.csv`` ->
    ``results_summary.txt``. Only the file extension is touched, so a
    ``.csv`` appearing elsewhere in the path is left alone, and a path
    without a ``.csv`` suffix never maps back onto itself.
    """
    root, _ext = os.path.splitext(output_path)
    return f"{root}_summary.txt"


def main():
    """Evaluate an S2F checkpoint on a test folder and report the metrics.

    Steps, in order:
      1. Parse the command-line arguments.
      2. Load the test data with ``load_folder_data``.
      3. Build the generator with ``create_s2f_model`` and load the
         checkpoint weights into it (strict key matching).
      4. Run ``evaluate_metrics_on_dataset`` and print the report.
      5. Optionally save prediction plots (``--save_plots``).
      6. Optionally write per-sample metrics as CSV (``--output``); if the
         evaluation returned no per-sample predictions, a short text summary
         is written next to the requested path instead.

    If a CUDA device is requested but CUDA is unavailable, evaluation falls
    back to the CPU and a notice is printed.
    """
    parser = argparse.ArgumentParser(description='Evaluate S2F model')
    parser.add_argument('--data', required=True, help='Path to test folder (subfolders with BF_001.tif, *_gray.jpg)')
    parser.add_argument('--model', choices=['single_cell', 'spheroid'], default='single_cell')
    parser.add_argument('--checkpoint', required=True, help='Path to .pth checkpoint')
    parser.add_argument('--substrate', default='fibroblasts_PDMS', help='Substrate for single_cell')
    parser.add_argument('--batch_size', type=int, default=2)
    parser.add_argument('--img_size', type=int, default=1024)
    parser.add_argument('--threshold', type=float, default=0.0, help='Threshold for heatmap metrics')
    parser.add_argument('--output', default=None, help='Optional CSV path for per-sample metrics')
    parser.add_argument('--save_plots', default=None, help='Directory to save prediction plots')
    parser.add_argument('--device', default='cuda')
    args = parser.parse_args()

    # Imports happen after argument parsing, so ``--help`` and argument
    # errors exit before torch/pandas and the project modules are loaded.
    from data.cell_dataset import load_folder_data
    from models.s2f_model import create_s2f_model
    from utils.substrate_settings import compute_settings_normalization
    from utils.metrics import (
        evaluate_metrics_on_dataset,
        print_metrics_report,
        gen_prediction_plots,
        detect_tanh_output_model,
    )
    import torch
    import pandas as pd

    # Only the single_cell model is fed substrate settings; this flag drives
    # the data loading, the input channel count and the normalization below.
    use_settings = args.model == 'single_cell'
    config_path = os.path.join(S2F_ROOT, 'config', 'substrate_settings.json')

    # The default device is 'cuda'; fall back to the CPU instead of failing
    # when the script is run on a machine without a usable CUDA device.
    device = args.device
    if device.startswith('cuda') and not torch.cuda.is_available():
        print(f"CUDA is not available; falling back to CPU (requested device: {device})")
        device = 'cpu'

    print(f"Loading data from {args.data}")
    val_loader = load_folder_data(
        args.data,
        substrate=args.substrate if use_settings else None,
        img_size=args.img_size,
        batch_size=args.batch_size,
        return_metadata=use_settings,
    )

    in_channels = 3 if use_settings else 1
    model_type = 's2f' if use_settings else 's2f_spheroid'
    generator, _ = create_s2f_model(in_channels=in_channels, model_type=model_type)
    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # Only load checkpoints that come from a trusted source.
    ckpt = torch.load(args.checkpoint, map_location='cpu', weights_only=False)
    # Use the 'generator_state_dict' entry when present; otherwise treat the
    # loaded object itself as the state dict.
    generator.load_state_dict(ckpt.get('generator_state_dict', ckpt), strict=True)

    norm_params = compute_settings_normalization(config_path=config_path) if use_settings else None
    uses_tanh = detect_tanh_output_model(generator)

    # Per-sample predictions are only requested when something will consume
    # them (plots or the CSV export).
    # NOTE(review): substrate_override is passed even when use_settings is
    # False; presumably it is ignored in that case - confirm in utils.metrics.
    results = evaluate_metrics_on_dataset(
        generator,
        val_loader,
        device=device,
        description="Evaluating",
        save_predictions=(args.save_plots is not None or args.output is not None),
        threshold=args.threshold,
        use_settings=use_settings,
        normalization_params=norm_params,
        config_path=config_path,
        substrate_override=args.substrate,
    )

    report = {'validation': results}
    print_metrics_report(report, threshold=args.threshold, uses_tanh=uses_tanh)
    print(f"Samples: {len(val_loader.dataset)}")

    if args.save_plots and 'individual_predictions' in results:
        gen_prediction_plots(
            results['individual_predictions'],
            args.save_plots,
            sort_by='mse',
            sort_order='asc',
            threshold=args.threshold,
        )
        print(f"Saved prediction plots to {args.save_plots}")

    if args.output:
        # Make sure the destination directory exists before writing either
        # the CSV or the fallback summary.
        os.makedirs(os.path.dirname(os.path.abspath(args.output)), exist_ok=True)

        preds = results.get('individual_predictions', [])
        if preds:
            rows = [{
                'mse': p['mse'],
                'ms_ssim': p['ms_ssim'],
                'pixel_correlation': p['pixel_correlation'],
                # .get(): this key may be absent, in which case the CSV cell
                # is left empty instead of raising KeyError.
                'relative_magnitude_error': p.get('wfm_relative_magnitude_error'),
                'force_sum_gt': p['force_sum_gt'],
                'force_sum_pred': p['force_sum_pred'],
            } for p in preds]
            pd.DataFrame(rows).to_csv(args.output, index=False)
            print(f"Saved per-sample metrics to {args.output}")
        else:
            # Fallback: no per-sample predictions, write aggregate summary only.
            summary_path = _summary_path(args.output)
            with open(summary_path, 'w', encoding='utf-8') as f:
                f.write(f"MSE: {results['heatmap']['mse']:.6f}\n")
                f.write(f"MS-SSIM: {results['heatmap']['ms_ssim']:.4f}\n")
                f.write(f"Pixel Corr: {results['heatmap']['pixel_correlation']:.4f}\n")
                f.write(f"Rel Mag Error: {results['wfm']['relative_magnitude_error']:.4f}\n")
            print(f"Saved summary to {summary_path}")


# Run the evaluation only when this file is executed as a script
# (e.g. ``python -m training.evaluate``), not when it is imported.
if __name__ == '__main__':
    main()