# simplicityprevails/models.py
# Initial upload of simplicityprevails from local project (commit ce8f665).
"""Minimal model-loading code for the 7 VFM baselines in the paper."""
from __future__ import annotations
import sys
from pathlib import Path
from typing import Callable
import torch
import torch.nn as nn
from torchvision import transforms
from transformers import AutoConfig, AutoImageProcessor, AutoModel
# Project root (the directory containing this file) and the directory holding
# the bundled linear-probe checkpoints.
ROOT = Path(__file__).resolve().parent
WEIGHTS_DIR = ROOT / "weights"
# Per-model configuration. Keys per entry:
#   paper_name    - human-readable name used in the paper / by ALIASES
#   checkpoint    - filename under WEIGHTS_DIR with the linear-probe weights
#   hf_model      - Hugging Face hub id of the backbone; absent for "pelin"
#                   (vendored PE encoder) and "dinov2lin" (backbone variant is
#                   inferred from the checkpoint's head dimension)
#   feature_dim   - expected input dim of the linear head; the checkpoint can
#                   override this via _infer_feature_dim
#   image_size    - preprocessing resolution; absent for the DINO models,
#                   whose loaders use a fixed 256-resize / 224-crop pipeline
#   pooler_output - informational flag (the loaders hard-code which feature
#                   extraction wrapper class they use)
MODEL_SPECS = {
    "metacliplin": {
        "paper_name": "MetaCLIP-Linear",
        "checkpoint": "metacliplin0.pth",
        "hf_model": "facebook/metaclip-h14-fullcc2.5b",
        "feature_dim": 1280,
        "image_size": 224,
        "pooler_output": True,
    },
    "metaclip2lin": {
        "paper_name": "MetaCLIP2-Linear",
        "checkpoint": "metaclip2lin0.pth",
        "hf_model": "facebook/metaclip-2-worldwide-giant",
        "feature_dim": 1280,
        "image_size": 224,
        "pooler_output": True,
    },
    "sigliplin": {
        "paper_name": "SigLIP-Linear",
        "checkpoint": "sigliplin0.pth",
        "hf_model": "google/siglip-large-patch16-384",
        "feature_dim": 1024,
        "image_size": 384,
        "pooler_output": True,
    },
    "siglip2lin": {
        "paper_name": "SigLIP2-Linear",
        "checkpoint": "siglip2lin0.pth",
        "hf_model": "google/siglip2-giant-opt-patch16-384",
        "feature_dim": 1536,
        "image_size": 384,
        "pooler_output": True,
    },
    "pelin": {
        "paper_name": "PE-CLIP-Linear",
        "checkpoint": "pelin0.pth",
        "feature_dim": 1024,
        "image_size": 336,
        "pooler_output": False,
    },
    "dinov2lin": {
        "paper_name": "DINOv2-Linear",
        "checkpoint": "dinov2lin0.pth",
        "feature_dim": 1024,
        "pooler_output": False,
    },
    "dinov3lin": {
        "paper_name": "DINOv3-Linear",
        "checkpoint": "dinov3lin0.pth",
        "hf_model": "facebook/dinov3-vit7b16-pretrain-lvd1689m",
        "feature_dim": 4096,
        "pooler_output": False,
    },
}
# Paper-facing display names mapped back to the canonical MODEL_SPECS keys.
ALIASES = {
    "MetaCLIP-Linear": "metacliplin",
    "MetaCLIP2-Linear": "metaclip2lin",
    "SigLIP-Linear": "sigliplin",
    "SigLIP2-Linear": "siglip2lin",
    "PE-CLIP-Linear": "pelin",
    "DINOv2-Linear": "dinov2lin",
    "DINOv3-Linear": "dinov3lin",
}
def canonical_model_name(name: str) -> str:
    """Resolve *name* (canonical key or paper-facing alias) to a canonical key.

    Raises:
        KeyError: if *name* is neither a key in MODEL_SPECS nor in ALIASES.
    """
    if name in MODEL_SPECS:
        return name
    try:
        return ALIASES[name]
    except KeyError:
        raise KeyError(f"Unknown model: {name}") from None
def default_checkpoint_path(model_name: str) -> Path:
    """Return the bundled checkpoint path for *model_name* (key or alias)."""
    spec = MODEL_SPECS[canonical_model_name(model_name)]
    return WEIGHTS_DIR / spec["checkpoint"]
def _resolve_device(device: str | torch.device | None = None) -> torch.device:
if device is None:
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
return torch.device(device)
def _load_checkpoint(checkpoint_path: str | Path) -> dict:
checkpoint = torch.load(str(checkpoint_path), map_location="cpu", weights_only=False)
if isinstance(checkpoint, dict):
for key in ("state_dict", "model", "model_state_dict"):
if key in checkpoint and isinstance(checkpoint[key], dict):
checkpoint = checkpoint[key]
break
normalized = {}
for key, value in checkpoint.items():
normalized[key[7:] if key.startswith("module.") else key] = value
return normalized
def _infer_feature_dim(state_dict: dict, default_dim: int) -> int:
head_weight = state_dict.get("head.weight")
if isinstance(head_weight, torch.Tensor) and head_weight.ndim == 2:
return int(head_weight.shape[1])
return default_dim
def _load_image_processor(model_name: str):
    """Best-effort AutoImageProcessor load: cached files first, then the hub.

    Returns None when both attempts fail; callers fall back to default
    normalization statistics.
    """
    for local_only in (True, False):
        try:
            return AutoImageProcessor.from_pretrained(
                model_name, local_files_only=local_only
            )
        except Exception:
            continue
    return None
def _load_backbone(model_name: str):
    """Load a backbone from the local HF cache, else build it from config.

    NOTE(review): the fallback creates a randomly initialized architecture;
    it is only meaningful if the subsequent checkpoint load supplies the
    backbone weights — confirm against the shipped checkpoints.
    """
    try:
        return AutoModel.from_pretrained(model_name, local_files_only=True)
    except Exception:
        return AutoModel.from_config(AutoConfig.from_pretrained(model_name))
class _PoolerLinearModel(nn.Module):
def __init__(self, backbone: nn.Module, feature_dim: int):
super().__init__()
self.backbone = backbone
self.head = nn.Linear(feature_dim, 2)
def forward(self, x):
with torch.no_grad():
outputs = self.backbone(x)
features = outputs.pooler_output.float()
return self.head(features)
class _ClsTokenLinearModel(nn.Module):
def __init__(self, backbone: nn.Module, feature_dim: int):
super().__init__()
self.backbone = backbone
self.head = nn.Linear(feature_dim, 2)
def forward(self, x):
with torch.no_grad():
outputs = self.backbone(x)
features = outputs.last_hidden_state[:, 0].float()
return self.head(features)
class _PELinearModel(nn.Module):
def __init__(self, backbone: nn.Module, feature_dim: int):
super().__init__()
self.backbone = backbone
self.head = nn.Linear(feature_dim, 2)
def forward(self, x):
with torch.no_grad():
features = self.backbone(x)
if isinstance(features, torch.Tensor):
features = features.float()
return self.head(features)
def _finalize_model(model: nn.Module, state_dict: dict, device=None) -> nn.Module:
    """Load weights into *model*, move it to *device*, and set eval mode.

    NOTE(review): strict=False silently ignores missing/unexpected keys, so a
    mismatched checkpoint will not raise here — verify key coverage upstream.
    """
    model.load_state_dict(state_dict, strict=False)
    return model.to(_resolve_device(device)).eval()
def _build_clip_transform(image_size: int, image_processor=None):
    """Resize / center-crop / normalize pipeline for the CLIP-style models.

    Normalization statistics come from *image_processor* when it provides
    them, otherwise the ImageNet defaults are used.
    """
    default_mean = [0.485, 0.456, 0.406]
    default_std = [0.229, 0.224, 0.225]
    if image_processor is None:
        mean, std = default_mean, default_std
    else:
        mean = getattr(image_processor, "image_mean", default_mean)
        std = getattr(image_processor, "image_std", default_std)
    steps = [
        transforms.Resize(image_size, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]
    return transforms.Compose(steps)
def _build_dino_transform():
    """Fixed 256-resize / 224-crop ImageNet-normalized pipeline for DINO models."""
    steps = [
        transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    return transforms.Compose(steps)
def load_metacliplin(checkpoint_path: str | Path | None = None, device=None):
    """Load the MetaCLIP-Linear baseline; returns (model, preprocess)."""
    spec = MODEL_SPECS["metacliplin"]
    ckpt = checkpoint_path or default_checkpoint_path("metacliplin")
    state = _load_checkpoint(ckpt)
    dim = _infer_feature_dim(state, spec["feature_dim"])
    processor = _load_image_processor(spec["hf_model"])
    # Only the vision tower is needed for linear probing.
    vision = _load_backbone(spec["hf_model"]).vision_model
    model = _finalize_model(_PoolerLinearModel(vision, dim), state, device=device)
    return model, _build_clip_transform(spec["image_size"], processor)
def load_metaclip2lin(checkpoint_path: str | Path | None = None, device=None):
    """Load the MetaCLIP2-Linear baseline; returns (model, preprocess)."""
    spec = MODEL_SPECS["metaclip2lin"]
    ckpt = checkpoint_path or default_checkpoint_path("metaclip2lin")
    state = _load_checkpoint(ckpt)
    dim = _infer_feature_dim(state, spec["feature_dim"])
    processor = _load_image_processor(spec["hf_model"])
    # Only the vision tower is needed for linear probing.
    vision = _load_backbone(spec["hf_model"]).vision_model
    model = _finalize_model(_PoolerLinearModel(vision, dim), state, device=device)
    return model, _build_clip_transform(spec["image_size"], processor)
def load_sigliplin(checkpoint_path: str | Path | None = None, device=None):
    """Load the SigLIP-Linear baseline; returns (model, preprocess)."""
    spec = MODEL_SPECS["sigliplin"]
    ckpt = checkpoint_path or default_checkpoint_path("sigliplin")
    state = _load_checkpoint(ckpt)
    dim = _infer_feature_dim(state, spec["feature_dim"])
    processor = _load_image_processor(spec["hf_model"])
    # Only the vision tower is needed for linear probing.
    vision = _load_backbone(spec["hf_model"]).vision_model
    model = _finalize_model(_PoolerLinearModel(vision, dim), state, device=device)
    return model, _build_clip_transform(spec["image_size"], processor)
def load_siglip2lin(checkpoint_path: str | Path | None = None, device=None):
    """Load the SigLIP2-Linear baseline; returns (model, preprocess)."""
    spec = MODEL_SPECS["siglip2lin"]
    ckpt = checkpoint_path or default_checkpoint_path("siglip2lin")
    state = _load_checkpoint(ckpt)
    dim = _infer_feature_dim(state, spec["feature_dim"])
    processor = _load_image_processor(spec["hf_model"])
    # Only the vision tower is needed for linear probing.
    vision = _load_backbone(spec["hf_model"]).vision_model
    model = _finalize_model(_PoolerLinearModel(vision, dim), state, device=device)
    return model, _build_clip_transform(spec["image_size"], processor)
def load_dinov2lin(checkpoint_path: str | Path | None = None, device=None):
    """Load the DINOv2-Linear baseline; returns (model, preprocess).

    The backbone variant is picked from the checkpoint head's input width,
    with a smaller variant as fallback if the preferred one cannot be loaded.
    """
    ckpt = checkpoint_path or default_checkpoint_path("dinov2lin")
    state = _load_checkpoint(ckpt)
    dim = _infer_feature_dim(state, MODEL_SPECS["dinov2lin"]["feature_dim"])
    candidates_by_dim = {
        1536: ["facebook/dinov2-giant", "facebook/dinov2-large"],
        1024: ["facebook/dinov2-large", "facebook/dinov2-base"],
        768: ["facebook/dinov2-base", "facebook/dinov2-small"],
    }
    candidates = candidates_by_dim.get(dim, ["facebook/dinov2-large"])
    backbone = None
    last_error = None
    for candidate in candidates:
        try:
            backbone = _load_backbone(candidate)
        except Exception as exc:
            last_error = exc
        else:
            break
    if backbone is None:
        raise RuntimeError(f"Failed to load DINOv2 backbone: {last_error}")
    model = _finalize_model(_ClsTokenLinearModel(backbone, dim), state, device=device)
    return model, _build_dino_transform()
def load_dinov3lin(checkpoint_path: str | Path | None = None, device=None):
    """Load the DINOv3-Linear baseline; returns (model, preprocess)."""
    spec = MODEL_SPECS["dinov3lin"]
    ckpt = checkpoint_path or default_checkpoint_path("dinov3lin")
    state = _load_checkpoint(ckpt)
    dim = _infer_feature_dim(state, spec["feature_dim"])
    backbone = _load_backbone(spec["hf_model"])
    model = _finalize_model(_ClsTokenLinearModel(backbone, dim), state, device=device)
    return model, _build_dino_transform()
def load_pelin(checkpoint_path: str | Path | None = None, device=None):
    """Load the PE-CLIP-Linear baseline; returns (model, preprocess).

    The vendored perception-encoder package (``core/``) is imported lazily,
    adding the project root to sys.path if it is not already there.
    """
    ckpt = checkpoint_path or default_checkpoint_path("pelin")
    if str(ROOT) not in sys.path:
        sys.path.insert(0, str(ROOT))
    import core.vision_encoder.pe as pe
    import core.vision_encoder.transforms as pe_transforms

    spec = MODEL_SPECS["pelin"]
    state = _load_checkpoint(ckpt)
    dim = _infer_feature_dim(state, spec["feature_dim"])
    # Build the architecture only; weights come from the local checkpoint.
    clip_model = pe.CLIP.from_config("PE-Core-L14-336", pretrained=False)
    model = _finalize_model(_PELinearModel(clip_model.visual, dim), state, device=device)
    return model, pe_transforms.get_image_transform(spec["image_size"])
# Registry mapping canonical model keys to loader functions; every loader
# returns a (model, preprocess_transform) pair.
LOADERS: dict[str, Callable] = {
    "metacliplin": load_metacliplin,
    "metaclip2lin": load_metaclip2lin,
    "sigliplin": load_sigliplin,
    "siglip2lin": load_siglip2lin,
    "pelin": load_pelin,
    "dinov2lin": load_dinov2lin,
    "dinov3lin": load_dinov3lin,
}
def load_model(model_name: str, checkpoint_path: str | Path | None = None, device=None):
    """Dispatch to the per-model loader; returns (model, preprocess_transform).

    *model_name* may be a canonical key or a paper-facing alias.
    """
    loader = LOADERS[canonical_model_name(model_name)]
    return loader(checkpoint_path=checkpoint_path, device=device)