import math
from pathlib import Path

import torch
from torch import nn
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
from torchvision import models as tvm
from torchvision.datasets import ImageFolder

from lightly.loss import NTXentLoss
from lightly.models.modules import SimCLRProjectionHead
from lightly.transforms.simclr_transform import SimCLRTransform

# ----------------------------
# Config
# ----------------------------
DATA_ROOT = "data/eurosat_custom/train"   # prepared split (train only, unlabeled)
BATCH_SIZE = 256
EPOCHS = 150
LR = 0.06
NUM_WORKERS = 8
IMG_SIZE = 224             # resize inside transform
OUT_DIR = Path("checkpoints_ssl")
OUT_DIR.mkdir(parents=True, exist_ok=True)
WARMUP_EPOCHS = 10

def lr_lambda(epoch):
    # Linear warmup for the first WARMUP_EPOCHS epochs, then cosine decay to ~0.
    if epoch < WARMUP_EPOCHS:
        return float(epoch + 1) / WARMUP_EPOCHS
    progress = (epoch - WARMUP_EPOCHS) / float(EPOCHS - WARMUP_EPOCHS)
    return 0.5 * (1.0 + math.cos(math.pi * progress))
# The LambdaLR scheduler itself is created after the optimizer, below.
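# For reference: with WARMUP_EPOCHS=10 and EPOCHS=150, the multiplier ramps
# 0.1 -> 1.0 over epochs 0-9 (peak LR = 0.06), then decays to ~0 by epoch 149.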

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = torch.cuda.is_available()  # mixed precision if GPU


# ----------------------------
# Model: ResNet18 encoder + SimCLR projection head
# ----------------------------
class SimCLR(nn.Module):
    def __init__(self, backbone, in_dim=512, proj_hidden=512, proj_out=128):
        super().__init__()
        self.backbone = backbone
        self.projection_head = SimCLRProjectionHead(in_dim, proj_hidden, proj_out)

    def forward(self, x):
        # backbone assumed to output [N, C, 1, 1] after global pooling
        x = self.backbone(x).flatten(start_dim=1)
        z = self.projection_head(x)
        return z

# Build a torchvision resnet18 backbone without the final FC layer: keep
# everything up to and including global average pooling, so the backbone
# outputs 512-d features of shape [N, 512, 1, 1].
resnet = tvm.resnet18(weights=None)
backbone = nn.Sequential(*list(resnet.children())[:-1])
model = SimCLR(backbone, in_dim=512, proj_hidden=512, proj_out=128).to(device)
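# Optional smoke test: one dummy forward pass should yield [N, 128] projections.
with torch.no_grad():
    _z = model(torch.randn(2, 3, IMG_SIZE, IMG_SIZE, device=device))
    assert _z.shape == (2, 128), f"unexpected projection shape: {tuple(_z.shape)}"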


# ----------------------------
# Data: EuroSAT train images as unlabeled pairs of views
# ----------------------------
# Lightly's SimCLRTransform normalizes with ImageNet stats by default,
# which matches the standard ResNet18 input pipeline at 224px.
transform = SimCLRTransform(
    input_size=IMG_SIZE,
    gaussian_blur=0.1,        # blur probability; kept low since EuroSAT tiles are small
    cj_strength=0.5,          # color jitter strength
)

dataset = ImageFolder(DATA_ROOT, transform=transform)
# With SimCLRTransform, each sample is ([view1, view2], label): the transform
# yields two augmented views per image. Labels are ignored during pretraining.
loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
    num_workers=NUM_WORKERS,
    pin_memory=torch.cuda.is_available(),
)
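# Optional: uncomment to check the two-view batch structure before training.
# (v1, v2), _ = next(iter(loader))
# print(v1.shape, v2.shape)   # each view: [BATCH_SIZE, 3, IMG_SIZE, IMG_SIZE]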


# ----------------------------
# Objective & Optimizer
# ----------------------------
criterion = NTXentLoss(temperature=0.5)   # standard SimCLR temperature
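# NTXentLoss pairs z1[i] with z2[i] as positives; all other projections in the
# combined 2*BATCH_SIZE batch act as negatives.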
optimizer = torch.optim.SGD(model.parameters(), lr=LR, momentum=0.9, weight_decay=1e-4)
scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)   # warmup + cosine schedule defined above

scaler = torch.amp.GradScaler("cuda", enabled=use_amp)


# ----------------------------
# Training loop
# ----------------------------
print(f"Starting SimCLR pretraining on {device} for {EPOCHS} epochs…")
model.train()
for epoch in range(1, EPOCHS + 1):
    total_loss = 0.0
    for (v1, v2), _ in loader:      # labels are unused
        v1 = v1.to(device, non_blocking=True)
        v2 = v2.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)

        with torch.amp.autocast("cuda", enabled=use_amp):
            z1 = model(v1)
            z2 = model(v2)
            loss = criterion(z1, z2)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()

    avg_loss = total_loss / len(loader)
    current_lr = scheduler.get_last_lr()[0]
    print(f"epoch {epoch:03d} | loss {avg_loss:.5f} | lr {current_lr:.5f}")

    scheduler.step()

    # (optional) save checkpoints every N epochs
    if epoch % 25 == 0 or epoch == EPOCHS:
        # save only the encoder (backbone) weights for fine-tuning
        enc_state = {k.replace("backbone.", "", 1): v
                     for k, v in model.state_dict().items()
                     if k.startswith("backbone.")}
        torch.save(enc_state, OUT_DIR / f"simclr_resnet18_eurosat_epoch{epoch}.pt")

print("Done.")
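# ----------------------------
# (Reference) Loading a saved encoder for fine-tuning: a minimal sketch,
# left commented out. The checkpoint holds the nn.Sequential backbone's
# state_dict ("0.weight", "1.weight", ...), so it loads into the same layout.
# ----------------------------
# enc = nn.Sequential(*list(tvm.resnet18(weights=None).children())[:-1])
# enc.load_state_dict(torch.load(OUT_DIR / f"simclr_resnet18_eurosat_epoch{EPOCHS}.pt"))
# clf = nn.Sequential(enc, nn.Flatten(1), nn.Linear(512, 10))   # EuroSAT has 10 classes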