import os
import io
import asyncio
import random
import numpy as np
import torch
from matplotlib import colormaps
from PIL import Image, ImageFilter
from fastapi import FastAPI, UploadFile, File, Query
from fastapi.responses import StreamingResponse
from huggingface_hub import snapshot_download, login
from transformers import (
    BlipProcessor, BlipForConditionalGeneration,
    ViTImageProcessor, AutoProcessor, AutoModelForCausalLM,
    CLIPModel, CLIPProcessor,
)

app = FastAPI(title="XAI Auditor Ensemble with CLIP Jury")
# --- Configuration & Paths ---
REPO_ID = "SaniaE/Image_Captioning_Ensemble"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODELS = {}
# Metadata for loading
MODEL_CONFIGS = {
    "blip": {
        "subfolder": "blip",
        "proc_class": BlipProcessor,
        "model_class": BlipForConditionalGeneration,
        "base_path": "Salesforce/blip-image-captioning-large",
    },
    "vit": {
        "subfolder": "vit",
        "proc_classes": [ViTImageProcessor, AutoProcessor],
        "model_class": AutoModelForCausalLM,
        "base_paths": ["nlpconnect/vit-gpt2-image-captioning", "microsoft/git-large"],
    },
    "clip": {
        "model_subfolder": "clip/clip_model",
        "proc_subfolder": "clip/clip_processor",
    },
}

@app.on_event("startup")
async def startup_event():
    global MODELS
    token = os.getenv("HF_TOKEN")
    if token:
        login(token=token)
    print(f"Syncing weights from {REPO_ID}...")
    local_dir = snapshot_download(repo_id=REPO_ID, token=token, local_dir="weights")

    # 1. Load BLIP
    cfg_b = MODEL_CONFIGS["blip"]
    MODELS["blip"] = {
        "model": cfg_b["model_class"].from_pretrained(os.path.join(local_dir, cfg_b["subfolder"])).to(DEVICE),
        "processor": cfg_b["proc_class"].from_pretrained(cfg_b["base_path"]),
    }

    # 2. Load ViT/GIT ensemble: the image processor comes from the ViT-GPT2 base,
    # the text processor (used for decoding) from the GIT base.
    cfg_v = MODEL_CONFIGS["vit"]
    MODELS["vit"] = {
        "model": cfg_v["model_class"].from_pretrained(os.path.join(local_dir, cfg_v["subfolder"])).to(DEVICE),
        "processor": (
            cfg_v["proc_classes"][0].from_pretrained(cfg_v["base_paths"][0]),
            cfg_v["proc_classes"][1].from_pretrained(cfg_v["base_paths"][1]),
        ),
    }

    # 3. Load the fine-tuned CLIP jury
    cfg_c = MODEL_CONFIGS["clip"]
    MODELS["clip"] = {
        "model": CLIPModel.from_pretrained(os.path.join(local_dir, cfg_c["model_subfolder"])).to(DEVICE),
        "processor": CLIPProcessor.from_pretrained(os.path.join(local_dir, cfg_c["proc_subfolder"])),
    }
    print("All models synchronized. Auditor is active.")

# --- Utilities ---
def _generate_sync(m_name, image, temp, top_k, top_p):
    """Run one captioning model synchronously (called from a worker thread)."""
    m_data = MODELS[m_name]
    if m_name == "vit":
        # The ViT/GIT ensemble keeps separate image and text processors.
        i_proc, t_proc = m_data["processor"]
    else:
        i_proc = t_proc = m_data["processor"]
    inputs = i_proc(images=image, return_tensors="pt").to(DEVICE)
    with torch.no_grad():  # inference only; avoids building autograd graphs
        ids = m_data["model"].generate(
            **inputs, max_length=80, do_sample=True,
            temperature=temp, top_k=top_k, top_p=top_p,
        )
    return t_proc.batch_decode(ids, skip_special_tokens=True)[0].strip()

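# Quick smoke test for the generator path (illustrative only; assumes the
# startup hook has populated MODELS and that `test.jpg` exists on disk):
#
#   img = Image.open("test.jpg").convert("RGB")
#   print(_generate_sync("blip", img, temp=0.8, top_k=50, top_p=0.9))
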
# --- Endpoints ---
@app.post("/generate")
async def generate_captions(
file: UploadFile = File(...),
temp: float = Query(0.8),
top_k: int = Query(50),
top_p: float = Query(0.9)
):
"""Generates 5 diverse captions using the model ensemble."""
image = Image.open(file.file).convert("RGB")
architectures = ["blip", "vit"]
selection = random.choices(architectures, k=5)
tasks = [asyncio.to_thread(_generate_sync, m, image, temp, top_k, top_p) for m in selection]
captions = await asyncio.gather(*tasks)
return {"captions": captions, "metadata": {"models_used": selection, "temp": temp}}
@app.post("/saliency")
async def get_vision_saliency(file: UploadFile = File(...)):
"""Objective Saliency: Shows what the Vision Encoder focuses on (Self-Attention)."""
image_bytes = await file.read()
orig_img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
blip = MODELS["blip"]
inputs = blip["processor"](images=orig_img, return_tensors="pt").to(DEVICE)
with torch.no_grad():
outputs = blip["model"].vision_model(inputs.pixel_values, output_attentions=True)
attentions = outputs.attentions[-1] # Last layer
# Average heads, look at CLS token attention to patches
mask_1d = attentions[0, :, 0, 1:].mean(dim=0)
grid_size = int(np.sqrt(mask_1d.shape[-1]))
mask = mask_1d.view(grid_size, grid_size).cpu().numpy()
mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-8)
mask_img = Image.fromarray((mask * 255).astype('uint8')).resize(orig_img.size, resample=Image.BICUBIC)
mask_img = mask_img.filter(ImageFilter.GaussianBlur(radius=10))
heatmap = plt.get_cmap('magma')(np.array(mask_img)/255.0)
heatmap_img = Image.fromarray((heatmap[:, :, :3] * 255).astype('uint8')).convert("RGB")
blended = Image.blend(orig_img, heatmap_img, alpha=0.6)
buf = io.BytesIO()
blended.save(buf, format="PNG")
buf.seek(0)
return StreamingResponse(buf, media_type="image/png")
@app.post("/audit")
async def internal_debate_audit(file: UploadFile = File(...), user_prompt: str = Query(...)):
"""The CLIP-Powered Jury: Compares User Intent vs. Model Perception."""
image_bytes = await file.read()
image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
# 1. Model Perception
blip_caption = await asyncio.to_thread(_generate_sync, "blip", image, 0.7, 50, 0.9)
# 2. CLIP Scoring (Multimodal Alignment)
clip_m = MODELS["clip"]["model"]
clip_p = MODELS["clip"]["processor"]
inputs = clip_p(text=[user_prompt, blip_caption], images=image, return_tensors="pt", padding=True).to(DEVICE)
with torch.no_grad():
outputs = clip_m(**inputs)
probs = outputs.logits_per_image.softmax(dim=-1).cpu().numpy()[0]
u_score, m_score = float(probs[0]), float(probs[1])
# 3. Decision Logic
if u_score < 0.35:
verdict = "Perspective Divergence: Intent not grounded in image."
elif abs(u_score - m_score) < 0.15:
verdict = "Consensus: High Alignment."
else:
verdict = "Model Bias Detected."
return {
"perspectives": {"user": user_prompt, "ai": blip_caption},
"audit_scores": {"intent_grounding": round(u_score, 4), "ai_grounding": round(m_score, 4)},
"verdict": verdict
}
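
# --- Local entry point & example calls ---
# A minimal sketch for running the service directly; assumes `uvicorn` is
# installed and that the host/port below suit your deployment.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)

# Illustrative requests against a local instance (`cat.jpg` is a placeholder image):
#   curl -X POST "http://localhost:8000/generate?temp=0.8&top_k=50&top_p=0.9" -F "file=@cat.jpg"
#   curl -X POST "http://localhost:8000/saliency" -F "file=@cat.jpg" -o heatmap.png
#   curl -X POST "http://localhost:8000/audit?user_prompt=a%20cat%20on%20a%20sofa" -F "file=@cat.jpg"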