bdappv-models / model.py
gabrielkasmi's picture
Upload model.py with huggingface_hub
7f0920d verified
"""
model.py — helpers to load BDAPPV baseline models from HuggingFace Hub.
Usage:
from huggingface_hub import hf_hub_download
import importlib.util, sys
path = hf_hub_download("gabrielkasmi/bdappv-models", "model.py")
spec = importlib.util.spec_from_file_location("bdappv_model", path)
mod = importlib.util.module_from_spec(spec); spec.loader.exec_module(mod)
seg = mod.load_segmentation_model("google") # or "ign"
clf = mod.load_classification_model("google") # or "ign"
"""
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from torchvision.models.segmentation import deeplabv3_resnet101
from torchvision.models import inception_v3
REPO = "gabrielkasmi/bdappv-models"
def load_segmentation_model(provider: str = "google", device: str = "cpu"):
"""
Load the DeepLabV3-ResNet101 segmentation model.
Args:
provider : "google" or "ign"
device : "cpu", "cuda", or "mps"
Returns:
model in eval mode
"""
assert provider in ("google", "ign"), "provider must be 'google' or 'ign'"
path = hf_hub_download(REPO, f"deeplab_{provider}_best.pth")
model = deeplabv3_resnet101(weights=None, aux_loss=False)
model.classifier[-1] = nn.Conv2d(256, 1, kernel_size=1)
state = torch.load(path, map_location=device, weights_only=False)
model_dict = model.state_dict()
compatible = {k: v for k, v in state.items()
if k in model_dict and v.shape == model_dict[k].shape}
model_dict.update(compatible)
model.load_state_dict(model_dict)
return model.eval().to(device)
def load_classification_model(provider: str = "google", device: str = "cpu"):
"""
Load the InceptionV3 classification model (panel / no panel).
Args:
provider : "google" or "ign"
device : "cpu", "cuda", or "mps"
Returns:
model in eval mode
"""
assert provider in ("google", "ign"), "provider must be 'google' or 'ign'"
path = hf_hub_download(REPO, f"inception_{provider}_best.pth")
model = inception_v3(weights=None, aux_logits=True)
model.fc = nn.Linear(model.fc.in_features, 1)
model.AuxLogits.fc = nn.Linear(model.AuxLogits.fc.in_features, 1)
state = torch.load(path, map_location=device, weights_only=False)
model.load_state_dict(state)
model.aux_logits = False # désactive pour l'inférence
return model.eval().to(device)