File size: 3,413 Bytes
2b9ff22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54e160a
 
2b9ff22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
#!/usr/bin/env python3
"""
S2F training script.
Usage:
  python -m training.train --data path/to/dataset --model single_cell --epochs 100
  python -m training.train --data path/to/dataset --model spheroid --epochs 50
"""
import os
import sys
import argparse

# Make the repository root importable so the script also works when run
# directly (e.g. `python training/train.py`), not only via `python -m`.
S2F_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if S2F_ROOT not in sys.path:
    sys.path.insert(0, S2F_ROOT)


def main():
    """Parse CLI arguments, build the data loaders and S2F GAN, and train.

    Checkpoints are written to ``--save_dir`` by the trainer. Returns None.
    """
    parser = argparse.ArgumentParser(description='Train S2F model')
    parser.add_argument('--data', required=True, help='Path to dataset (must have train/ and test/ subfolders)')
    parser.add_argument('--model', choices=['single_cell', 'spheroid'], default='single_cell',
                        help='Model type: single_cell (with substrate) or spheroid')
    parser.add_argument('--substrate', default=None,
                        help='Substrate name for single_cell when metadata not in dataset (e.g. fibroblasts_PDMS)')
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--img_size', type=int, default=1024)
    parser.add_argument('--save_dir', default='ckp', help='Checkpoint save directory')
    parser.add_argument('--g_lr', type=float, default=2e-4)
    parser.add_argument('--d_lr', type=float, default=2e-4)
    parser.add_argument('--resume', type=str, default=None, help='Path to checkpoint to resume from')
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--no_augment', action='store_true', help='Disable augmentations')
    parser.add_argument('--use_force_consistency', action='store_true')
    parser.add_argument('--force_target', choices=['mean', 'sum'], default='mean')
    args = parser.parse_args()

    # Deferred project imports: they rely on the sys.path bootstrap above and
    # keep `--help` fast by avoiding heavy framework imports at module load.
    from data.cell_dataset import prepare_data
    from models.s2f_model import create_s2f_model
    from training.s2f_trainer import train_s2f

    # single_cell models consume substrate metadata ("settings"); spheroid does not.
    use_settings = args.model == 'single_cell'
    substrate = args.substrate or 'fibroblasts_PDMS'
    return_metadata = use_settings

    print(f"Loading data from {args.data} (model={args.model})")
    train_loader, val_loader = prepare_data(
        args.data,
        batch_size=args.batch_size,
        target_size=(args.img_size, args.img_size),
        use_augmentations=not args.no_augment,
        train_test_sep_folder=True,
        return_metadata=return_metadata,
        substrate=substrate if use_settings else None,
    )

    # single_cell inputs carry 2 extra metadata channels alongside the image.
    in_channels = 3 if use_settings else 1
    model_type = 's2f' if use_settings else 's2f_spheroid'
    generator, discriminator = create_s2f_model(in_channels=in_channels, model_type=model_type)

    if args.resume:
        import torch  # local import: torch is only needed on the resume path
        # SECURITY: weights_only=False unpickles arbitrary objects — only
        # load checkpoints from trusted sources.
        ckpt = torch.load(args.resume, map_location='cpu', weights_only=False)
        # Accept both a trainer checkpoint dict and a bare state_dict.
        generator.load_state_dict(ckpt.get('generator_state_dict', ckpt), strict=True)
        # Also restore the discriminator when available: resuming G against a
        # freshly initialized D destabilizes adversarial training. Old
        # generator-only checkpoints are unaffected.
        if isinstance(ckpt, dict) and 'discriminator_state_dict' in ckpt:
            discriminator.load_state_dict(ckpt['discriminator_state_dict'], strict=True)
        print(f"Resumed from {args.resume}")

    # Return value (training history) was previously bound but never used.
    train_s2f(
        generator, discriminator,
        train_loader, val_loader,
        device=args.device,
        num_epochs=args.epochs,
        g_lr=args.g_lr, d_lr=args.d_lr,
        save_dir=args.save_dir,
        loaded_metadata=return_metadata,
        use_settings=use_settings,
        use_force_consistency=args.use_force_consistency,
        force_consistency_target=args.force_target,
    )
    print(f"Training complete. Checkpoints saved to {args.save_dir}")


# Standard entry-point guard: run training only when executed as a script.
if __name__ == '__main__':
    main()