import io
import json
import os
import secrets
import warnings

import gradio as gr
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import models, transforms
|
|
|
# Hugging Face Hub repository holding config.json and pytorch_model.bin.
HF_REPO_ID = "ferdaouskachouri/financial_classification"

# API key that callers must supply to predict(). It is read from the
# environment (e.g. a deployment secret) instead of being hard-coded, so the
# value never lands in version control. Surrounding whitespace is stripped and
# a missing or blank value becomes None, so an empty key can never be accepted.
SECRET_KEY = os.environ.get("SECRET_KEY", "").strip() or None
if SECRET_KEY is None:
    warnings.warn(
        "SECRET_KEY environment variable is not set; every API key will be rejected.",
        RuntimeWarning,
        stacklevel=1,
    )
|
|
|
# Model metadata published next to the weights in the Hub repo. hf_hub_download
# returns a local (cached) file path. This script reads the keys "num_classes",
# "classes", "normalize_mean" and "normalize_std" from it. The file is decoded
# explicitly as UTF-8 so that any non-ASCII class names load identically
# regardless of the host's locale default encoding.
with open(hf_hub_download(HF_REPO_ID, "config.json"), encoding="utf-8") as f:
    config = json.load(f)
|
|
|
# EfficientNet-B0 backbone with a custom two-layer classification head. No
# pretrained torchvision weights are fetched; all parameters come from the
# Hub checkpoint. The head must mirror the training-time architecture exactly,
# because load_state_dict() is strict by default and fails on any key mismatch.
m = models.efficientnet_b0(weights=None)
in_f = m.classifier[1].in_features  # input width of the stock Linear layer
m.classifier = nn.Sequential(
    nn.Dropout(0.4),
    nn.Linear(in_f, 256),
    nn.ReLU(),
    nn.Dropout(0.2),
    nn.Linear(256, config["num_classes"]),
)

# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code while being loaded.
m.load_state_dict(
    torch.load(
        hf_hub_download(HF_REPO_ID, "pytorch_model.bin"),
        map_location="cpu",
        weights_only=True,
    )
)
m.eval()  # inference mode: disables the Dropout layers in the head
|
|
|
# Preprocessing applied to each PIL image before it reaches the model.
# Steps: resize to the fixed 224x224 input size, convert to a float tensor,
# then normalise each channel with the statistics stored in config.json.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=config["normalize_mean"],
            std=config["normalize_std"],
        ),
    ]
)
|
|
|
def predict(image, api_key):
    """Classify an uploaded document image, gated by an API key.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the form is submitted without an upload.
        api_key: Text from the password field; may be empty or ``None``.

    Returns:
        A ``(label, probabilities)`` tuple.

        On success, ``label`` is the top-scoring class name from
        ``config["classes"]``. ``probabilities`` maps every class name to its
        softmax probability, rounded to 4 decimals.

        When the key is rejected or no image is given, ``label`` is an error
        message and ``probabilities`` is ``{}``.
    """
    # compare_digest is designed to resist timing analysis, unlike `!=`, which
    # returns as soon as a character differs. With str it only accepts ASCII,
    # hence the explicit UTF-8 encoding. An unset or empty SECRET_KEY rejects
    # every request rather than letting an empty input match it.
    expected = SECRET_KEY.encode("utf-8") if SECRET_KEY else b""
    supplied = (api_key or "").encode("utf-8")
    if not expected or not secrets.compare_digest(supplied, expected):
        return "❌ Clé API invalide", {}

    if image is None:
        return "❌ Aucune image fournie", {}

    # efficientnet_b0's stem convolution takes 3 input channels. Converting
    # stops grayscale or RGBA uploads (e.g. PNGs with alpha) from failing.
    tensor = transform(image.convert("RGB")).unsqueeze(0)
    with torch.no_grad():
        probs = torch.softmax(m(tensor), dim=1)[0]
    top_idx = probs.argmax().item()

    # One bulk tensor->list conversion instead of an .item() call per class.
    scores = probs.tolist()
    return config["classes"][top_idx], {
        cls: round(scores[i], 4) for i, cls in enumerate(config["classes"])
    }
|
|
|
# Web UI wiring. The inputs list is (document image, masked API key) and the
# outputs list is (predicted label, per-class probabilities shown as JSON).
# Both orders must match predict()'s parameters and its returned tuple.
demo = gr.Interface(
    title="Document Classifier",
    fn=predict,
    inputs=[
        gr.Image(label="Document", type="pil"),
        gr.Textbox(type="password", label="Clé API"),
    ],
    outputs=[
        gr.Textbox(label="Classe prédite"),
        gr.JSON(label="Probabilités"),
    ],
)
|
|
|
# Start the web server only when this file is executed as a script. Importing
# the module (for tests, or from another launcher) then builds the model and
# UI without blocking on launch().
if __name__ == "__main__":
    demo.launch()