| """Minimal model-loading code for the 7 VFM baselines in the paper.""" |
|
|
| from __future__ import annotations |
|
|
| import sys |
| from pathlib import Path |
| from typing import Callable |
|
|
| import torch |
| import torch.nn as nn |
| from torchvision import transforms |
| from transformers import AutoConfig, AutoImageProcessor, AutoModel |
|
|
| ROOT = Path(__file__).resolve().parent |
| WEIGHTS_DIR = ROOT / "weights" |
|
|
| MODEL_SPECS = { |
| "metacliplin": { |
| "paper_name": "MetaCLIP-Linear", |
| "checkpoint": "metacliplin0.pth", |
| "hf_model": "facebook/metaclip-h14-fullcc2.5b", |
| "feature_dim": 1280, |
| "image_size": 224, |
| "pooler_output": True, |
| }, |
| "metaclip2lin": { |
| "paper_name": "MetaCLIP2-Linear", |
| "checkpoint": "metaclip2lin0.pth", |
| "hf_model": "facebook/metaclip-2-worldwide-giant", |
| "feature_dim": 1280, |
| "image_size": 224, |
| "pooler_output": True, |
| }, |
| "sigliplin": { |
| "paper_name": "SigLIP-Linear", |
| "checkpoint": "sigliplin0.pth", |
| "hf_model": "google/siglip-large-patch16-384", |
| "feature_dim": 1024, |
| "image_size": 384, |
| "pooler_output": True, |
| }, |
| "siglip2lin": { |
| "paper_name": "SigLIP2-Linear", |
| "checkpoint": "siglip2lin0.pth", |
| "hf_model": "google/siglip2-giant-opt-patch16-384", |
| "feature_dim": 1536, |
| "image_size": 384, |
| "pooler_output": True, |
| }, |
| "pelin": { |
| "paper_name": "PE-CLIP-Linear", |
| "checkpoint": "pelin0.pth", |
| "feature_dim": 1024, |
| "image_size": 336, |
| "pooler_output": False, |
| }, |
| "dinov2lin": { |
| "paper_name": "DINOv2-Linear", |
| "checkpoint": "dinov2lin0.pth", |
| "feature_dim": 1024, |
| "pooler_output": False, |
| }, |
| "dinov3lin": { |
| "paper_name": "DINOv3-Linear", |
| "checkpoint": "dinov3lin0.pth", |
| "hf_model": "facebook/dinov3-vit7b16-pretrain-lvd1689m", |
| "feature_dim": 4096, |
| "pooler_output": False, |
| }, |
| } |
|
|
| ALIASES = { |
| "MetaCLIP-Linear": "metacliplin", |
| "MetaCLIP2-Linear": "metaclip2lin", |
| "SigLIP-Linear": "sigliplin", |
| "SigLIP2-Linear": "siglip2lin", |
| "PE-CLIP-Linear": "pelin", |
| "DINOv2-Linear": "dinov2lin", |
| "DINOv3-Linear": "dinov3lin", |
| } |
|
|
|
|
| def canonical_model_name(name: str) -> str: |
| if name in MODEL_SPECS: |
| return name |
| if name in ALIASES: |
| return ALIASES[name] |
| raise KeyError(f"Unknown model: {name}") |
|
|
|
|
| def default_checkpoint_path(model_name: str) -> Path: |
| model_name = canonical_model_name(model_name) |
| return WEIGHTS_DIR / MODEL_SPECS[model_name]["checkpoint"] |
|
|
|
|
| def _resolve_device(device: str | torch.device | None = None) -> torch.device: |
| if device is None: |
| return torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| return torch.device(device) |
|
|
|
|
| def _load_checkpoint(checkpoint_path: str | Path) -> dict: |
| checkpoint = torch.load(str(checkpoint_path), map_location="cpu", weights_only=False) |
| if isinstance(checkpoint, dict): |
| for key in ("state_dict", "model", "model_state_dict"): |
| if key in checkpoint and isinstance(checkpoint[key], dict): |
| checkpoint = checkpoint[key] |
| break |
|
|
| normalized = {} |
| for key, value in checkpoint.items(): |
| normalized[key[7:] if key.startswith("module.") else key] = value |
| return normalized |
|
|
|
|
| def _infer_feature_dim(state_dict: dict, default_dim: int) -> int: |
| head_weight = state_dict.get("head.weight") |
| if isinstance(head_weight, torch.Tensor) and head_weight.ndim == 2: |
| return int(head_weight.shape[1]) |
| return default_dim |
|
|
|
|
| def _load_image_processor(model_name: str): |
| try: |
| return AutoImageProcessor.from_pretrained(model_name, local_files_only=True) |
| except Exception: |
| try: |
| return AutoImageProcessor.from_pretrained(model_name) |
| except Exception: |
| return None |
|
|
|
|
| def _load_backbone(model_name: str): |
| try: |
| return AutoModel.from_pretrained(model_name, local_files_only=True) |
| except Exception: |
| config = AutoConfig.from_pretrained(model_name) |
| return AutoModel.from_config(config) |
|
|
|
|
| class _PoolerLinearModel(nn.Module): |
| def __init__(self, backbone: nn.Module, feature_dim: int): |
| super().__init__() |
| self.backbone = backbone |
| self.head = nn.Linear(feature_dim, 2) |
|
|
| def forward(self, x): |
| with torch.no_grad(): |
| outputs = self.backbone(x) |
| features = outputs.pooler_output.float() |
| return self.head(features) |
|
|
|
|
| class _ClsTokenLinearModel(nn.Module): |
| def __init__(self, backbone: nn.Module, feature_dim: int): |
| super().__init__() |
| self.backbone = backbone |
| self.head = nn.Linear(feature_dim, 2) |
|
|
| def forward(self, x): |
| with torch.no_grad(): |
| outputs = self.backbone(x) |
| features = outputs.last_hidden_state[:, 0].float() |
| return self.head(features) |
|
|
|
|
| class _PELinearModel(nn.Module): |
| def __init__(self, backbone: nn.Module, feature_dim: int): |
| super().__init__() |
| self.backbone = backbone |
| self.head = nn.Linear(feature_dim, 2) |
|
|
| def forward(self, x): |
| with torch.no_grad(): |
| features = self.backbone(x) |
| if isinstance(features, torch.Tensor): |
| features = features.float() |
| return self.head(features) |
|
|
|
|
| def _finalize_model(model: nn.Module, state_dict: dict, device=None) -> nn.Module: |
| model.load_state_dict(state_dict, strict=False) |
| model.to(_resolve_device(device)) |
| model.eval() |
| return model |
|
|
|
|
| def _build_clip_transform(image_size: int, image_processor=None): |
| mean = [0.485, 0.456, 0.406] |
| std = [0.229, 0.224, 0.225] |
| if image_processor is not None: |
| mean = getattr(image_processor, "image_mean", mean) |
| std = getattr(image_processor, "image_std", std) |
| return transforms.Compose( |
| [ |
| transforms.Resize(image_size, interpolation=transforms.InterpolationMode.BICUBIC), |
| transforms.CenterCrop(image_size), |
| transforms.ToTensor(), |
| transforms.Normalize(mean=mean, std=std), |
| ] |
| ) |
|
|
|
|
| def _build_dino_transform(): |
| return transforms.Compose( |
| [ |
| transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC), |
| transforms.CenterCrop(224), |
| transforms.ToTensor(), |
| transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), |
| ] |
| ) |
|
|
|
|
| def load_metacliplin(checkpoint_path: str | Path | None = None, device=None): |
| spec = MODEL_SPECS["metacliplin"] |
| checkpoint_path = checkpoint_path or default_checkpoint_path("metacliplin") |
| state_dict = _load_checkpoint(checkpoint_path) |
| feature_dim = _infer_feature_dim(state_dict, spec["feature_dim"]) |
| image_processor = _load_image_processor(spec["hf_model"]) |
| backbone = _load_backbone(spec["hf_model"]) |
| model = _PoolerLinearModel(backbone.vision_model, feature_dim) |
| model = _finalize_model(model, state_dict, device=device) |
| return model, _build_clip_transform(spec["image_size"], image_processor) |
|
|
|
|
| def load_metaclip2lin(checkpoint_path: str | Path | None = None, device=None): |
| spec = MODEL_SPECS["metaclip2lin"] |
| checkpoint_path = checkpoint_path or default_checkpoint_path("metaclip2lin") |
| state_dict = _load_checkpoint(checkpoint_path) |
| feature_dim = _infer_feature_dim(state_dict, spec["feature_dim"]) |
| image_processor = _load_image_processor(spec["hf_model"]) |
| backbone = _load_backbone(spec["hf_model"]) |
| model = _PoolerLinearModel(backbone.vision_model, feature_dim) |
| model = _finalize_model(model, state_dict, device=device) |
| return model, _build_clip_transform(spec["image_size"], image_processor) |
|
|
|
|
| def load_sigliplin(checkpoint_path: str | Path | None = None, device=None): |
| spec = MODEL_SPECS["sigliplin"] |
| checkpoint_path = checkpoint_path or default_checkpoint_path("sigliplin") |
| state_dict = _load_checkpoint(checkpoint_path) |
| feature_dim = _infer_feature_dim(state_dict, spec["feature_dim"]) |
| image_processor = _load_image_processor(spec["hf_model"]) |
| backbone = _load_backbone(spec["hf_model"]) |
| model = _PoolerLinearModel(backbone.vision_model, feature_dim) |
| model = _finalize_model(model, state_dict, device=device) |
| return model, _build_clip_transform(spec["image_size"], image_processor) |
|
|
|
|
| def load_siglip2lin(checkpoint_path: str | Path | None = None, device=None): |
| spec = MODEL_SPECS["siglip2lin"] |
| checkpoint_path = checkpoint_path or default_checkpoint_path("siglip2lin") |
| state_dict = _load_checkpoint(checkpoint_path) |
| feature_dim = _infer_feature_dim(state_dict, spec["feature_dim"]) |
| image_processor = _load_image_processor(spec["hf_model"]) |
| backbone = _load_backbone(spec["hf_model"]) |
| model = _PoolerLinearModel(backbone.vision_model, feature_dim) |
| model = _finalize_model(model, state_dict, device=device) |
| return model, _build_clip_transform(spec["image_size"], image_processor) |
|
|
|
|
| def load_dinov2lin(checkpoint_path: str | Path | None = None, device=None): |
| checkpoint_path = checkpoint_path or default_checkpoint_path("dinov2lin") |
| state_dict = _load_checkpoint(checkpoint_path) |
| feature_dim = _infer_feature_dim(state_dict, MODEL_SPECS["dinov2lin"]["feature_dim"]) |
| if feature_dim == 1536: |
| candidates = ["facebook/dinov2-giant", "facebook/dinov2-large"] |
| elif feature_dim == 1024: |
| candidates = ["facebook/dinov2-large", "facebook/dinov2-base"] |
| elif feature_dim == 768: |
| candidates = ["facebook/dinov2-base", "facebook/dinov2-small"] |
| else: |
| candidates = ["facebook/dinov2-large"] |
|
|
| last_error = None |
| backbone = None |
| for candidate in candidates: |
| try: |
| backbone = _load_backbone(candidate) |
| break |
| except Exception as exc: |
| last_error = exc |
| if backbone is None: |
| raise RuntimeError(f"Failed to load DINOv2 backbone: {last_error}") |
|
|
| model = _ClsTokenLinearModel(backbone, feature_dim) |
| model = _finalize_model(model, state_dict, device=device) |
| return model, _build_dino_transform() |
|
|
|
|
| def load_dinov3lin(checkpoint_path: str | Path | None = None, device=None): |
| checkpoint_path = checkpoint_path or default_checkpoint_path("dinov3lin") |
| state_dict = _load_checkpoint(checkpoint_path) |
| feature_dim = _infer_feature_dim(state_dict, MODEL_SPECS["dinov3lin"]["feature_dim"]) |
| backbone = _load_backbone(MODEL_SPECS["dinov3lin"]["hf_model"]) |
| model = _ClsTokenLinearModel(backbone, feature_dim) |
| model = _finalize_model(model, state_dict, device=device) |
| return model, _build_dino_transform() |
|
|
|
|
| def load_pelin(checkpoint_path: str | Path | None = None, device=None): |
| checkpoint_path = checkpoint_path or default_checkpoint_path("pelin") |
| if str(ROOT) not in sys.path: |
| sys.path.insert(0, str(ROOT)) |
|
|
| import core.vision_encoder.pe as pe |
| import core.vision_encoder.transforms as pe_transforms |
|
|
| state_dict = _load_checkpoint(checkpoint_path) |
| feature_dim = _infer_feature_dim(state_dict, MODEL_SPECS["pelin"]["feature_dim"]) |
| clip_model = pe.CLIP.from_config("PE-Core-L14-336", pretrained=False) |
| model = _PELinearModel(clip_model.visual, feature_dim) |
| model = _finalize_model(model, state_dict, device=device) |
| return model, pe_transforms.get_image_transform(MODEL_SPECS["pelin"]["image_size"]) |
|
|
|
|
| LOADERS: dict[str, Callable] = { |
| "metacliplin": load_metacliplin, |
| "metaclip2lin": load_metaclip2lin, |
| "sigliplin": load_sigliplin, |
| "siglip2lin": load_siglip2lin, |
| "pelin": load_pelin, |
| "dinov2lin": load_dinov2lin, |
| "dinov3lin": load_dinov3lin, |
| } |
|
|
|
|
| def load_model(model_name: str, checkpoint_path: str | Path | None = None, device=None): |
| model_name = canonical_model_name(model_name) |
| return LOADERS[model_name](checkpoint_path=checkpoint_path, device=device) |
|
|