File size: 1,672 Bytes
import hmac
import io
import json
import os

import gradio as gr
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import models, transforms
# Hugging Face Hub repository that hosts config.json and pytorch_model.bin.
HF_REPO_ID = "ferdaouskachouri/financial_classification"

# Shared secret that callers must type into the "Clé API" field.
# Prefer the API_SECRET_KEY environment variable (e.g. a Space secret) so the
# key does not have to live in source control. An unset or empty variable
# falls back to the previous literal, so existing deployments keep working.
# NOTE(review): this fallback is committed in plain text — rotate the key and
# remove the default once API_SECRET_KEY is configured in every deployment.
SECRET_KEY = os.environ.get("API_SECRET_KEY") or "mon-api-key-secret-2004"
# Download config.json from the Hub (hf_hub_download returns a local file
# path) and parse it. Decode explicitly as UTF-8 so class names with accents
# do not depend on the platform's default encoding.
with open(hf_hub_download(HF_REPO_ID, "config.json"), encoding="utf-8") as f:
    config = json.load(f)

# The model head, the preprocessing pipeline and predict() all read these
# keys. Check them once at startup, so a bad config fails immediately with a
# clear message rather than as a KeyError/IndexError at prediction time.
_REQUIRED_CONFIG_KEYS = ("num_classes", "classes", "normalize_mean", "normalize_std")
_missing = [key for key in _REQUIRED_CONFIG_KEYS if key not in config]
if _missing:
    raise ValueError(f"config.json is missing required keys: {_missing}")
if len(config["classes"]) != config["num_classes"]:
    raise ValueError(
        f"config.json lists {len(config['classes'])} classes but "
        f"num_classes is {config['num_classes']}"
    )
# Rebuild EfficientNet-B0 with the custom classification head, then load the
# fine-tuned weights from the Hub. The head must match the checkpoint's
# layers exactly, because load_state_dict is strict by default.
m = models.efficientnet_b0(weights=None)  # architecture only; weights come from the checkpoint
in_f = m.classifier[1].in_features  # input width of the stock final Linear layer
m.classifier = nn.Sequential(
    nn.Dropout(0.4),
    nn.Linear(in_f, 256),
    nn.ReLU(),
    nn.Dropout(0.2),
    nn.Linear(256, config["num_classes"]),
)
# The checkpoint is a downloaded pickle. weights_only=True restricts
# unpickling to tensors and plain containers, so a tampered file cannot
# execute arbitrary code when loaded.
m.load_state_dict(
    torch.load(
        hf_hub_download(HF_REPO_ID, "pytorch_model.bin"),
        map_location="cpu",
        weights_only=True,
    )
)
m.eval()  # switch the Dropout layers to inference behaviour
# Preprocessing applied to each uploaded document before inference:
# resize to a fixed 224x224, convert the PIL image to a float tensor, then
# normalize each channel with the mean/std stored in config.json.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=config["normalize_mean"],
            std=config["normalize_std"],
        ),
    ]
)
def predict(image, api_key):
    """Classify a document image, gated by the shared API key.

    Args:
        image: The uploaded document as a PIL image (the UI declares
            ``gr.Image(type="pil")``), or ``None`` if nothing was uploaded.
        api_key: The key typed by the user; checked against ``SECRET_KEY``.

    Returns:
        A ``(label, probabilities)`` tuple: the top-scoring class name and a
        dict mapping every entry of ``config["classes"]`` to its softmax
        probability rounded to 4 decimals. On an invalid key or a missing
        image, an error message and an empty dict are returned instead.
    """
    # Constant-time comparison, so response timing does not reveal how much
    # of the key was guessed correctly. Comparing UTF-8 bytes also accepts
    # non-ASCII input, which hmac.compare_digest rejects for str arguments.
    supplied = (api_key or "").encode("utf-8")
    if not hmac.compare_digest(supplied, SECRET_KEY.encode("utf-8")):
        return "❌ Clé API invalide", {}

    # Gradio passes None when the form is submitted without an upload.
    if image is None:
        return "❌ Aucune image fournie", {}

    # Grayscale, palette or RGBA uploads would produce a tensor with the
    # wrong channel count; EfficientNet's stem expects 3-channel RGB input.
    if image.mode != "RGB":
        image = image.convert("RGB")

    batch = transform(image).unsqueeze(0)  # add a batch dimension of 1
    with torch.no_grad():
        probs = torch.softmax(m(batch), dim=1)[0]

    classes = config["classes"]
    scores = probs.tolist()  # one tensor -> list conversion instead of per-class .item()
    return classes[int(probs.argmax())], {
        cls: round(scores[i], 4) for i, cls in enumerate(classes)
    }
# Web UI: takes a document image and the API key, and shows the predicted
# class alongside the per-class probabilities returned by predict().
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil", label="Document"),
        gr.Textbox(label="Clé API", type="password"),
    ],
    outputs=[
        gr.Textbox(label="Classe prédite"),
        gr.JSON(label="Probabilités"),
    ],
    title="Document Classifier",
)
# NOTE(review): depending on the Gradio version, Interface flagging may save
# submitted inputs (including the API key) to disk. Consider disabling it:
# `flagging_mode="never"` in Gradio 5, `allow_flagging="never"` in earlier
# versions — confirm against the installed version.

# Start the server only when run as a script (e.g. `python app.py`), so that
# importing this module builds `demo` without side effects.
if __name__ == "__main__":
    demo.launch()