import os
import io
import asyncio
import random

import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from PIL import Image, ImageFilter
from fastapi import FastAPI, UploadFile, File, Query
from fastapi.responses import StreamingResponse
from huggingface_hub import snapshot_download, login
from transformers import (
    BlipProcessor, BlipForConditionalGeneration,
    ViTImageProcessor, AutoProcessor, AutoModelForCausalLM,
    CLIPModel, CLIPProcessor
)

app = FastAPI(title="XAI Auditor Ensemble with CLIP Jury")

# --- Configuration & Paths ---
REPO_ID = "SaniaE/Image_Captioning_Ensemble"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODELS = {}

# Metadata for loading
MODEL_CONFIGS = {
    "blip": {
        "subfolder": "blip",
        "proc_class": BlipProcessor,
        "model_class": BlipForConditionalGeneration,
        "base_path": "Salesforce/blip-image-captioning-large"
    },
    "vit": {
        "subfolder": "vit",
        "proc_classes": [ViTImageProcessor, AutoProcessor],
        "model_class": AutoModelForCausalLM,
        "base_paths": ["nlpconnect/vit-gpt2-image-captioning", "microsoft/git-large"]
    },
    "clip": {
        "model_subfolder": "clip/clip_model",
        "proc_subfolder": "clip/clip_processor"
    }
}


@app.on_event("startup")
async def startup_event():
    global MODELS
    token = os.getenv("HF_Token")
    if token:
        login(token=token)

    print(f"Syncing weights from {REPO_ID}...")
    local_dir = snapshot_download(repo_id=REPO_ID, token=token, local_dir="weights")

    # 1. Load BLIP
    cfg_b = MODEL_CONFIGS["blip"]
    MODELS["blip"] = {
        "model": cfg_b["model_class"].from_pretrained(
            os.path.join(local_dir, cfg_b["subfolder"])
        ).to(DEVICE),
        "processor": cfg_b["proc_class"].from_pretrained(cfg_b["base_path"])
    }

    # 2. Load ViT/GIT Ensemble
    cfg_v = MODEL_CONFIGS["vit"]
    MODELS["vit"] = {
        "model": cfg_v["model_class"].from_pretrained(
            os.path.join(local_dir, cfg_v["subfolder"])
        ).to(DEVICE),
        "processor": (
            cfg_v["proc_classes"][0].from_pretrained(cfg_v["base_paths"][0]),
            cfg_v["proc_classes"][1].from_pretrained(cfg_v["base_paths"][1])
        )
    }

    # 3. Load Fine-Tuned CLIP (Your Jury)
    cfg_c = MODEL_CONFIGS["clip"]
    MODELS["clip"] = {
        "model": CLIPModel.from_pretrained(os.path.join(local_dir, cfg_c["model_subfolder"])).to(DEVICE),
        "processor": CLIPProcessor.from_pretrained(os.path.join(local_dir, cfg_c["proc_subfolder"]))
    }

    print("All models synchronized. Auditor is active.")
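
# Expected layout of the synced weights directory, inferred from MODEL_CONFIGS above
# (illustrative only; the actual contents of the Hub repo may differ):
#
#   weights/
#   ├── blip/               # BLIP captioning weights
#   ├── vit/                # ViT-GPT2 / GIT captioning weights
#   └── clip/
#       ├── clip_model/     # fine-tuned CLIP weights
#       └── clip_processor/ # matching CLIP processor config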

# --- Utilities ---
def _generate_sync(m_name, image, temp, top_k, top_p):
    """Blocking caption generation for a single model; meant to run via asyncio.to_thread."""
    m_data = MODELS[m_name]
    if m_name == "vit":
        # The ViT/GIT entry stores a (image processor, text processor) tuple.
        i_proc, t_proc = m_data["processor"]
        inputs = i_proc(images=image, return_tensors="pt").to(DEVICE)
        ids = m_data["model"].generate(
            **inputs, max_length=80, do_sample=True,
            temperature=temp, top_k=top_k, top_p=top_p
        )
        return t_proc.batch_decode(ids, skip_special_tokens=True)[0].strip()
    else:
        proc = m_data["processor"]
        inputs = proc(images=image, return_tensors="pt").to(DEVICE)
        ids = m_data["model"].generate(
            **inputs, max_length=80, do_sample=True,
            temperature=temp, top_k=top_k, top_p=top_p
        )
        return proc.batch_decode(ids, skip_special_tokens=True)[0].strip()


# --- Endpoints ---
@app.post("/generate")
async def generate_captions(
    file: UploadFile = File(...),
    temp: float = Query(0.8),
    top_k: int = Query(50),
    top_p: float = Query(0.9)
):
    """Generates 5 diverse captions using the model ensemble."""
    image = Image.open(file.file).convert("RGB")
    architectures = ["blip", "vit"]
    selection = random.choices(architectures, k=5)  # sampled with replacement
    tasks = [asyncio.to_thread(_generate_sync, m, image, temp, top_k, top_p) for m in selection]
    captions = await asyncio.gather(*tasks)
    return {"captions": captions, "metadata": {"models_used": selection, "temp": temp}}


@app.post("/saliency")
async def get_vision_saliency(file: UploadFile = File(...)):
    """Objective Saliency: Shows what the Vision Encoder focuses on (Self-Attention)."""
    image_bytes = await file.read()
    orig_img = Image.open(io.BytesIO(image_bytes)).convert("RGB")

    blip = MODELS["blip"]
    inputs = blip["processor"](images=orig_img, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        outputs = blip["model"].vision_model(inputs.pixel_values, output_attentions=True)

    attentions = outputs.attentions[-1]  # Last layer
    # Average heads, look at CLS token attention to patches
    mask_1d = attentions[0, :, 0, 1:].mean(dim=0)
    grid_size = int(np.sqrt(mask_1d.shape[-1]))
    mask = mask_1d.view(grid_size, grid_size).cpu().numpy()
    mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-8)

    # Upsample, smooth, and colorize the attention map, then blend it over the input image.
    mask_img = Image.fromarray((mask * 255).astype('uint8')).resize(orig_img.size, resample=Image.BICUBIC)
    mask_img = mask_img.filter(ImageFilter.GaussianBlur(radius=10))
    heatmap = plt.get_cmap('magma')(np.array(mask_img) / 255.0)
    heatmap_img = Image.fromarray((heatmap[:, :, :3] * 255).astype('uint8')).convert("RGB")
    blended = Image.blend(orig_img, heatmap_img, alpha=0.6)

    buf = io.BytesIO()
    blended.save(buf, format="PNG")
    buf.seek(0)
    return StreamingResponse(buf, media_type="image/png")


@app.post("/audit")
async def internal_debate_audit(file: UploadFile = File(...), user_prompt: str = Query(...)):
    """The CLIP-Powered Jury: Compares User Intent vs. Model Perception."""
    image_bytes = await file.read()
    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

    # 1. Model Perception
    blip_caption = await asyncio.to_thread(_generate_sync, "blip", image, 0.7, 50, 0.9)

    # 2. CLIP Scoring (Multimodal Alignment)
    clip_m = MODELS["clip"]["model"]
    clip_p = MODELS["clip"]["processor"]
    inputs = clip_p(text=[user_prompt, blip_caption], images=image, return_tensors="pt", padding=True).to(DEVICE)
    with torch.no_grad():
        outputs = clip_m(**inputs)
    probs = outputs.logits_per_image.softmax(dim=-1).cpu().numpy()[0]
    u_score, m_score = float(probs[0]), float(probs[1])

    # 3. Decision Logic
    if u_score < 0.35:
        verdict = "Perspective Divergence: Intent not grounded in image."
    elif abs(u_score - m_score) < 0.15:
        verdict = "Consensus: High Alignment."
    else:
        verdict = "Model Bias Detected."

    return {
        "perspectives": {"user": user_prompt, "ai": blip_caption},
        "audit_scores": {"intent_grounding": round(u_score, 4), "ai_grounding": round(m_score, 4)},
        "verdict": verdict
    }
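

# Minimal sketch for running the service locally (assumes `uvicorn` is installed and
# this module is saved as `app.py`; file name `photo.jpg` below is illustrative):
#
#   uvicorn app:app --host 0.0.0.0 --port 8000
#
#   curl -X POST "http://localhost:8000/generate?temp=0.8&top_k=50&top_p=0.9" -F "file=@photo.jpg"
#   curl -X POST "http://localhost:8000/saliency" -F "file=@photo.jpg" --output saliency.png
#   curl -X POST "http://localhost:8000/audit?user_prompt=a%20dog%20on%20a%20beach" -F "file=@photo.jpg"
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)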