| | |
| | import io |
| | import os |
| | from typing import List |
| |
|
| | import gradio as gr |
| | import torch |
| | import torch.nn as nn |
| | import torchvision.models as models |
| | import torchvision.transforms as T |
| | from PIL import Image |
| | import numpy as np |
| |
|
| | |
# Select the compute device once at import time; every tensor and the model
# are moved here before inference.
_backend = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_backend)
| |
|
| | |
class AgeGenderClassifier(nn.Module):
    """Two-headed MLP: a shared trunk feeding separate sigmoid heads.

    Both heads emit one value in (0, 1): `age_classifier` a normalized age
    score and `gender_classifier` a gender probability. Expects 2048-dim
    input features.
    """

    def __init__(self):
        super().__init__()
        # Shared trunk: 2048 -> 512 -> 128 -> 64, ReLU + dropout between
        # the first two projections.
        trunk = [
            nn.Linear(2048, 512), nn.ReLU(), nn.Dropout(0.4),
            nn.Linear(512, 128), nn.ReLU(), nn.Dropout(0.4),
            nn.Linear(128, 64), nn.ReLU(),
        ]
        self.intermediate = nn.Sequential(*trunk)
        # NOTE: attribute names and Sequential layer order are baked into
        # checkpoint state-dict keys — do not rename or restructure.
        self.age_classifier = nn.Sequential(nn.Linear(64, 1), nn.Sigmoid())
        self.gender_classifier = nn.Sequential(nn.Linear(64, 1), nn.Sigmoid())

    def forward(self, x):
        shared = self.intermediate(x)
        return self.age_classifier(shared), self.gender_classifier(shared)
| |
|
| |
|
def build_model(weights_path: str):
    """Rebuild the VGG16 backbone + custom avgpool/classifier and load weights.

    The architecture must match the one used in training exactly, otherwise
    ``load_state_dict`` fails on mismatched keys.

    Args:
        weights_path: Path to a ``.pth`` file containing either a plain
            state dict or a training checkpoint dict that wraps the weights
            under a ``"model_state_dict"`` key.

    Returns:
        The assembled model, moved to the module-level ``device`` and set
        to eval mode.

    Raises:
        FileNotFoundError: If ``weights_path`` does not exist.
        RuntimeError: If the stored weights do not match the architecture.
    """
    # Fail fast: check the checkpoint BEFORE instantiating the backbone,
    # which may trigger a download of the ImageNet weights.
    if not os.path.exists(weights_path):
        raise FileNotFoundError(f"Model weights not found at {weights_path}")

    model = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)

    # Freeze the whole backbone, then unfreeze the last conv stage
    # (features[24:]) for fine-tuning.
    for p in model.parameters():
        p.requires_grad = False
    for p in model.features[24:].parameters():
        p.requires_grad = True

    # Replace the avgpool with a small conv head: for a 224x224 input the
    # 7x7x512 feature map becomes 5x5 (conv k=3) -> 2x2 (maxpool) -> 2048
    # flattened features, matching AgeGenderClassifier's input size.
    model.avgpool = nn.Sequential(
        nn.Conv2d(512, 512, kernel_size=3),
        nn.MaxPool2d(2),
        nn.ReLU(),
        nn.Flatten()
    )
    model.classifier = AgeGenderClassifier()

    # NOTE(review): torch.load unpickles arbitrary objects unless
    # weights_only=True is passed — only load checkpoints from trusted sources.
    state = torch.load(weights_path, map_location=device)
    # Accept both plain state dicts and checkpoint dicts; this replaces the
    # original broad try/except fallback with an explicit shape check.
    if isinstance(state, dict) and "model_state_dict" in state:
        state = state["model_state_dict"]
    model.load_state_dict(state)

    model.to(device)
    model.eval()
    return model
| |
|
| |
|
| | |
# Standard ImageNet preprocessing: resize to the 224x224 VGG input size,
# convert to a CHW float tensor, and normalize per channel.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])

# The model's sigmoid age output (0..1) is interpreted as a fraction of
# this many years.
INV_AGE_SCALE = 80
| |
|
| |
|
| | |
def predict_images_with_text(images: List[Image.Image], model):
    """Run the model on a batch of PIL images.

    Returns two parallel lists: the original images as numpy arrays, and a
    caption string per image with gender, confidence, and estimated age.
    """
    if not images:
        return [], []

    # Force RGB before preprocessing, then batch everything in one tensor.
    rgb = [im if im.mode == "RGB" else im.convert("RGB") for im in images]
    batch = torch.stack([transform(im) for im in rgb]).to(device)

    with torch.no_grad():
        age_out, gender_out = model(batch)
    ages = age_out.squeeze(-1).cpu().numpy()
    genders = gender_out.squeeze(-1).cpu().numpy()

    output_images = []
    captions = []
    for img, age_score, gender_score in zip(images, ages, genders):
        # Sigmoid age score scaled to years; clip guards against any
        # out-of-range values.
        years = int(np.clip(age_score, 0.0, 1.0) * INV_AGE_SCALE)
        is_female = gender_score > 0.5
        label = "Female" if is_female else "Male"
        emoji = "π©" if is_female else "π¨"
        # Confidence is the probability of the predicted class.
        confidence = float(gender_score if is_female else 1 - gender_score)

        output_images.append(np.array(img))
        captions.append(f"{emoji} {label} ({confidence:.2f}) β’ π Age β {years}")

    return output_images, captions
| |
|
| |
|
| | |
# Checkpoint path is overridable via the MODEL_PATH environment variable.
MODEL_WEIGHTS = os.getenv("MODEL_PATH", "age_gender_model.pth")
# Build once at import time so all Gradio callbacks share one model instance.
model = build_model(MODEL_WEIGHTS)
| |
|
| |
|
| | |
# --- Gradio UI: multi-image upload -> gallery + HTML captions --------------
with gr.Blocks(title="FairFace Age & Gender β Multi-image Demo") as demo:
    gr.Markdown("""
    # π§ FairFace Multi-task Age & Gender Predictor
    Upload **one or more** images (JPG/PNG). The app will predict **gender** and **age** for each image and display results below the image.
    """)

    with gr.Row():
        # file_count="multiple" makes the input a list of uploaded files.
        img_input = gr.File(file_count="multiple", label="Upload images")
        run_btn = gr.Button("Run βΆοΈ")

    # Uploaded images are echoed back here after prediction.
    gallery = gr.Gallery(
        label="Uploaded Images",
        columns=3,
        height="auto"
    )

    # Predictions rendered as raw HTML, one heading per image.
    captions = gr.HTML(label="Predictions")

    def run_and_predict(files):
        """Click handler: open uploads, run the model, format captions."""
        if not files:
            return [], ""
        pil_imgs = []
        for f in files:
            # Gradio may hand back plain path strings or tempfile-like
            # objects exposing .name — handle both.
            path = f if isinstance(f, str) else f.name
            pil_imgs.append(Image.open(path).convert("RGB"))

        imgs, texts = predict_images_with_text(pil_imgs, model)
        # One <h2> line per image, kept in upload order.
        captions_html = "<br>".join([f"<h2>{t}</h2>" for t in texts])
        return imgs, captions_html

    run_btn.click(fn=run_and_predict, inputs=[img_input], outputs=[gallery, captions])

    gr.Markdown("""
    ---
    **Tips & Notes**
    - Age is normalized to 0β80 years (approx.).
    - For best results, upload clear frontal face images.
    - This is a demo β respect privacy when using photos. π
    """)

# Launch the server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()
| |
|