File size: 1,672 Bytes
942415c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import hmac
import os

import gradio as gr
import torch, torch.nn as nn, json, io
from torchvision import models, transforms
from huggingface_hub import hf_hub_download
from PIL import Image

# Hugging Face Hub repository that provides both config.json and the model weights.
HF_REPO_ID = "ferdaouskachouri/financial_classification"

# Shared secret that callers must type into the "Clé API" field.
# It is read from the environment (e.g. a deployment secret) so that it is not
# published together with the source code. The hard-coded fallback only keeps
# existing deployments working; set API_SECRET_KEY and delete the fallback to
# actually keep the key private. An empty variable also falls back, so an
# empty key can never become valid by accident.
SECRET_KEY = os.environ.get("API_SECRET_KEY") or "mon-api-key-secret-2004"

# Model metadata downloaded from the Hub. This script reads "classes",
# "num_classes", "normalize_mean" and "normalize_std" from it.
# The encoding is explicit so that non-ASCII class names decode identically
# regardless of the host's locale (open() otherwise uses the platform default).
with open(hf_hub_download(HF_REPO_ID, "config.json"), encoding="utf-8") as f:
    config = json.load(f)

# EfficientNet-B0 backbone. weights=None because the fine-tuned weights are
# loaded from the Hub below, not from torchvision's pretrained checkpoint.
m = models.efficientnet_b0(weights=None)

# Replace the stock classifier with the custom head. Its layer sizes must match
# the saved state dict, because load_state_dict is strict by default.
in_f = m.classifier[1].in_features
m.classifier = nn.Sequential(
    nn.Dropout(0.4), nn.Linear(in_f, 256),
    nn.ReLU(), nn.Dropout(0.2),
    nn.Linear(256, config["num_classes"]),
)

# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint fetched over the network cannot run arbitrary code.
# (Requires torch >= 1.13; a plain state dict loads fine under this mode.)
m.load_state_dict(torch.load(
    hf_hub_download(HF_REPO_ID, "pytorch_model.bin"),
    map_location="cpu",
    weights_only=True,
))

# Switch to inference behaviour: dropout off, batch-norm uses running stats.
m.eval()

# Preprocessing applied to every uploaded image before inference:
#   1. resize to a fixed 224x224 (the aspect ratio is NOT preserved),
#   2. convert the PIL image to a float tensor (values in [0, 1] for 8-bit images),
#   3. normalise each channel with the mean/std stored in config.json.
# Presumably these mirror the training-time preprocessing -- TODO confirm.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=config["normalize_mean"],
            std=config["normalize_std"],
        ),
    ]
)

def predict(image, api_key):
    """Classify a document image, gated by a shared API key.

    Args:
        image: PIL image from the Gradio ``Image`` component, or ``None`` when
            the user submitted without uploading anything.
        api_key: Key typed by the caller; it must equal ``SECRET_KEY``.

    Returns:
        A ``(label, probabilities)`` pair: the predicted class name and a
        ``{class_name: probability}`` dict with values rounded to 4 decimals.
        When the key is wrong or the image is missing, ``label`` is an error
        message and the dict is empty.
    """
    # Constant-time comparison, so response timing does not reveal how much of
    # the key was guessed correctly. Bytes are compared because compare_digest
    # raises TypeError for non-ASCII str arguments.
    supplied = (api_key or "").encode("utf-8")
    if not hmac.compare_digest(supplied, SECRET_KEY.encode("utf-8")):
        return "❌ Clé API invalide", {}
    if image is None:
        return "❌ Aucune image fournie", {}

    # The network expects 3 channels. Greyscale, RGBA or palette uploads would
    # otherwise fail in Normalize or in the first convolution.
    tensor = transform(image.convert("RGB")).unsqueeze(0)
    with torch.no_grad():
        probs = torch.softmax(m(tensor), dim=1)[0]
    top_idx = probs.argmax().item()

    # A single tensor->list conversion instead of one .item() call per class.
    values = probs.tolist()
    classes = config["classes"]
    return classes[top_idx], {
        cls: round(values[i], 4)
        for i, cls in enumerate(classes)
    }

# Gradio UI. The uploaded document and the API key are passed positionally to
# predict(image, api_key); its (label, probabilities) result fills the two
# output components in order.
_ui_inputs = [
    gr.Image(label="Document", type="pil"),
    gr.Textbox(label="Clé API", type="password"),
]
_ui_outputs = [
    gr.Textbox(label="Classe prédite"),
    gr.JSON(label="Probabilités"),
]
demo = gr.Interface(predict, _ui_inputs, _ui_outputs, title="Document Classifier")

# Start the Gradio server as soon as this module runs.
demo.launch()