"""SimCLR self-supervised pretraining of a ResNet18 encoder on EuroSAT.

Reads unlabeled training images from ``DATA_ROOT``, trains a ResNet18
backbone plus a SimCLR projection head with the NT-Xent loss, and
periodically writes the backbone weights to ``OUT_DIR``.
"""
import math
import os
from pathlib import Path

import torch
from torch import nn
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
from torchvision import models as tvm
from torchvision.datasets import ImageFolder

from lightly.loss import NTXentLoss
from lightly.models.modules import SimCLRProjectionHead
from lightly.transforms.simclr_transform import SimCLRTransform

# ----------------------------
# Config
# ----------------------------
DATA_ROOT = "data/eurosat_custom/train"  # prepared split (train only, unlabeled)
BATCH_SIZE = 256
EPOCHS = 150
LR = 0.06
NUM_WORKERS = 8
IMG_SIZE = 224  # resize inside transform
OUT_DIR = Path("checkpoints_ssl")
OUT_DIR.mkdir(parents=True, exist_ok=True)

warmup_epochs = 10
total_epochs = EPOCHS

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = torch.cuda.is_available()  # mixed precision if GPU


def lr_lambda(epoch):
    """Return the LR multiplier for a linear-warmup + cosine-decay schedule.

    Args:
        epoch: Zero-based epoch counter as passed in by ``LambdaLR``.

    Returns:
        A plain Python float: ramps linearly from ``1 / warmup_epochs`` to 1
        during warmup, then follows half a cosine wave from 1 down to 0 over
        the remaining epochs.
    """
    if epoch < warmup_epochs:
        return float(epoch + 1) / warmup_epochs
    # Guard the denominator so EPOCHS == warmup_epochs cannot divide by zero,
    # and clamp so a stray extra scheduler step never goes past the end.
    decay_epochs = max(1, total_epochs - warmup_epochs)
    progress = min(1.0, (epoch - warmup_epochs) / float(decay_epochs))
    # math.cos, not torch.cos: torch.cos only accepts tensors, and LambdaLR
    # expects a plain float multiplier.
    return 0.5 * (1.0 + math.cos(math.pi * progress))


# ----------------------------
# Model: ResNet18 encoder + SimCLR projection head
# ----------------------------
class SimCLR(nn.Module):
    """Backbone encoder followed by a SimCLR projection head.

    Args:
        backbone: Feature extractor; its output is flattened from dim 1 on.
        in_dim: Feature width fed into the projection head.
        proj_hidden: Hidden width of the projection head.
        proj_out: Output (embedding) width of the projection head.
    """

    def __init__(self, backbone, in_dim=512, proj_hidden=512, proj_out=128):
        super().__init__()
        self.backbone = backbone
        self.projection_head = SimCLRProjectionHead(in_dim, proj_hidden, proj_out)

    def forward(self, x):
        """Encode ``x`` and return its projected embedding."""
        # backbone assumed to output [N, C, 1, 1] after global pooling
        x = self.backbone(x).flatten(start_dim=1)
        z = self.projection_head(x)
        return z


# Build a torchvision resnet18 backbone without the FC layer: keep every child
# module up to and including the global avgpool, drop only the final fc.
resnet = tvm.resnet18(weights=None)
backbone = nn.Sequential(*list(resnet.children())[:-1])  # outputs [N,512,1,1]
model = SimCLR(backbone, in_dim=512, proj_hidden=512, proj_out=128).to(device)

# ----------------------------
# Data: EuroSAT train images as unlabeled pairs of views
# ----------------------------
# SimCLR default normalization in Lightly is ImageNet stats; perfect for ResNet18 at 224.
transform = SimCLRTransform(
    input_size=IMG_SIZE,
    gaussian_blur=0.1,  # EuroSAT is small; mild blur helps but keep modest
    cj_strength=0.5,  # color jitter strength
)
dataset = ImageFolder(DATA_ROOT, transform=transform)

# ImageFolder returns ( (v1, v2), label ) because transform yields two views.
# We'll ignore labels during pretraining.
# NOTE(review): with NUM_WORKERS > 0 on platforms that start workers via
# "spawn" (Windows, macOS), worker processes re-import the main module, so an
# unguarded top-level script like this one may need an
# `if __name__ == "__main__":` guard there -- confirm on the target platform.
loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
    num_workers=NUM_WORKERS,
    pin_memory=torch.cuda.is_available(),
)
if len(loader) == 0:
    # drop_last=True discards the only (partial) batch when the dataset is
    # smaller than BATCH_SIZE; fail early instead of dividing by zero later.
    raise RuntimeError(
        f"No full batch available: {len(dataset)} images in {DATA_ROOT!r} "
        f"with BATCH_SIZE={BATCH_SIZE} and drop_last=True."
    )

# ----------------------------
# Objective & Optimizer
# ----------------------------
criterion = NTXentLoss(temperature=0.5)  # standard SimCLR temperature
optimizer = torch.optim.SGD(model.parameters(), lr=LR, momentum=0.9, weight_decay=1e-4)
# The scheduler wraps the optimizer, so it must be created after it.
scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

# ----------------------------
# Training loop
# ----------------------------
print(f"Starting SimCLR pretraining on {device} for {EPOCHS} epochs…")
model.train()

for epoch in range(1, EPOCHS + 1):
    total_loss = 0.0
    for (v1, v2), _ in loader:  # labels are unused
        v1 = v1.to(device, non_blocking=True)
        v2 = v2.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast(enabled=use_amp):
            z1 = model(v1)
            z2 = model(v2)
            loss = criterion(z1, z2)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()

    avg_loss = total_loss / len(loader)
    # Read the LR before stepping so the log shows the rate this epoch used.
    current_lr = scheduler.get_last_lr()[0]
    print(f"epoch {epoch:03d} | loss {avg_loss:.5f} | lr {current_lr:.5f}")
    scheduler.step()

    # (optional) save checkpoints every N epochs
    if epoch % 25 == 0 or epoch == EPOCHS:
        # Save only the encoder (backbone) weights for fine-tuning.
        # NOTE(review): the backbone is an nn.Sequential, so the saved keys are
        # positional (e.g. "0.weight"), not torchvision's "conv1.weight" names;
        # the fine-tuning code must rebuild the same Sequential to load them.
        enc_state = {
            k.replace("backbone.", "", 1): v
            for k, v in model.state_dict().items()
            if k.startswith("backbone.")
        }
        torch.save(enc_state, OUT_DIR / f"simclr_resnet18_eurosat_epoch{epoch}.pt")

print("Done.")