# OpenFakeDemo / app/model.py
# Commit d2354a4 (vicliv): added softmax temperature for better uncertainty calibration
import os
from pathlib import Path
import torch
import torch.nn as nn
from PIL import Image
from safetensors.torch import load_file
from transformers import AutoImageProcessor, AutoModelForImageClassification
# Hugging Face model identifier handed to both from_pretrained calls in
# load_detector (image processor and classification model).
HF_NAME = "microsoft/swinv2-base-patch4-window16-256"
# Safetensors weights file, looked up in the parent of the directory that
# contains this module.
WEIGHTS_PATH = Path(__file__).parent.parent / "model.safetensors"
# Output size of the replacement classifier head built in load_detector
# (labelled there as 0 -> "real", 1 -> "fake").
NUM_LABELS = 2
# predict_image divides the logits by this value before softmax; a value
# above 1 flattens the resulting distribution.
SOFTMAX_TEMPERATURE = 2.0
# The model and its inputs are moved to this device; inference runs on CPU.
_device = torch.device("cpu")
# Module-level cache filled by load_detector on its first call; both stay
# None until then.
_processor = None
_model = None
def _strip_prefixes(state_dict: dict) -> dict:
"""Strip common wrapper prefixes (DDP, Lightning) from state dict keys."""
prefixes = ("module.", "model.")
cleaned = {}
for k, v in state_dict.items():
new_k = k
for p in prefixes:
if new_k.startswith(p):
new_k = new_k[len(p):]
break
cleaned[new_k] = v
return cleaned
def load_detector():
    """Return the cached ``(processor, model)`` pair, building it on first use.

    On the first call this loads the image processor and classification
    model named by ``HF_NAME``, swaps in a ``NUM_LABELS``-way linear
    classifier, loads the weights from ``WEIGHTS_PATH`` non-strictly,
    switches the model to eval mode on ``_device`` and stores both objects
    in the module-level cache. Later calls return the cached objects.

    Returns:
        Tuple ``(processor, model)``.
    """
    global _processor, _model

    # Fast path: everything was already built by an earlier call.
    if _model is not None:
        return _processor, _model

    hf_cache = os.environ.get("HF_HOME", "/tmp/hf-cache")
    image_processor = AutoImageProcessor.from_pretrained(HF_NAME, cache_dir=hf_cache)
    detector = AutoModelForImageClassification.from_pretrained(HF_NAME, cache_dir=hf_cache)

    # Re-label the model as a binary classifier and replace its head so the
    # output size matches NUM_LABELS.
    detector.num_labels = NUM_LABELS
    cfg = detector.config
    cfg.num_labels = NUM_LABELS
    cfg.id2label = {0: "real", 1: "fake"}
    cfg.label2id = {"real": 0, "fake": 1}
    detector.classifier = nn.Linear(detector.swinv2.num_features, NUM_LABELS)

    # Non-strict load: key mismatches are reported below (first 10 of each
    # kind) instead of raising.
    weights = _strip_prefixes(load_file(str(WEIGHTS_PATH)))
    missing, unexpected = detector.load_state_dict(weights, strict=False)
    if missing:
        print(f"[load_detector] missing keys ({len(missing)}): {missing[:10]}")
    if unexpected:
        print(f"[load_detector] unexpected keys ({len(unexpected)}): {unexpected[:10]}")

    detector.eval()
    detector.to(_device)

    _processor, _model = image_processor, detector
    return _processor, _model
@torch.no_grad()
def predict_image(image: Image.Image) -> float:
    """Return the temperature-scaled softmax probability of class index 1.

    Index 1 is the class that ``load_detector`` labels ``"fake"``. The
    image is converted to RGB when it is in any other mode, run through
    the cached processor and model, and the logits are divided by
    ``SOFTMAX_TEMPERATURE`` before the softmax.

    Args:
        image: PIL image to score.

    Returns:
        P(fake) as a Python float in [0, 1].
    """
    processor, model = load_detector()

    rgb_image = image if image.mode == "RGB" else image.convert("RGB")
    batch = processor(images=rgb_image, return_tensors="pt").to(_device)

    scaled_logits = model(**batch).logits / SOFTMAX_TEMPERATURE
    # First (only) item in the batch, class index 1.
    fake_prob = torch.softmax(scaled_logits, dim=-1)[0, 1]
    return float(fake_prob.item())