SaniaE committed on
Commit
4debe0a
·
verified ·
1 Parent(s): 0129fbe

revamped entire API logic

Browse files
Files changed (1) hide show
  1. app.py +87 -147
app.py CHANGED
@@ -2,29 +2,27 @@ import os
2
  import torch
3
  import random
4
  import asyncio
 
 
 
5
  from PIL import Image, ImageFilter
6
  from fastapi import FastAPI, UploadFile, File, Query
7
- from fastapi.middleware.cors import CORSMiddleware
8
  from huggingface_hub import snapshot_download, login
 
 
9
  from transformers import (
10
  BlipProcessor, BlipForConditionalGeneration,
11
  ViTImageProcessor, AutoProcessor, AutoModelForCausalLM
12
  )
13
- import torch.nn.functional as F
14
- import numpy as np
15
- import io
16
- from fastapi.responses import StreamingResponse
17
- import matplotlib.pyplot as plt
18
-
19
 
20
- app = FastAPI()
21
 
22
- # Configuration
23
  REPO_ID = "SaniaE/Image_Captioning_Ensemble"
24
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
25
  MODELS = {}
26
 
27
- # Removed GIT, kept BLIP and ViT
28
  MODEL_SETTINGS = {
29
  "blip": {
30
  "subfolder": "blip",
@@ -53,10 +51,8 @@ async def startup_event():
53
  ckpt_path = os.path.join(local_dir, cfg["subfolder"])
54
  print(f"Loading {name} from {ckpt_path}...")
55
 
56
- # Load Model
57
  model = cfg["inference_model"].from_pretrained(ckpt_path).to(DEVICE)
58
 
59
- # Load Processor
60
  if name == "vit":
61
  i_proc = cfg["processor"][0].from_pretrained(cfg["pretrained_path"][0])
62
  t_proc = cfg["processor"][1].from_pretrained(cfg["pretrained_path"][1])
@@ -65,183 +61,127 @@ async def startup_event():
65
  processor = cfg["processor"].from_pretrained(cfg["pretrained_path"])
66
 
67
  MODELS[name] = {"model": model, "processor": processor}
68
- print("Optimization Complete: GIT and Search removed. Ensemble is live!")
69
 
70
- # --- Helper for Parallel Inference ---
71
def _generate_sync(m_name, image, temp, top_k, top_p):
    """Sample one caption for ``image`` from the loaded model named ``m_name``.

    Blocking helper: the endpoint in this file dispatches it through
    ``asyncio.to_thread``.

    Args:
        m_name: Key into ``MODELS``. For ``"vit"`` the stored processor is
            unpacked as a pair (image processor, text decoder); for any other
            name a single processor handles both encoding and decoding.
        image: Image object handed to the processor's ``images=`` argument.
        temp: Sampling temperature forwarded to ``generate``.
        top_k: Top-k cutoff forwarded to ``generate``.
        top_p: Nucleus (top-p) cutoff forwarded to ``generate``.

    Returns:
        The first decoded sequence, with special tokens skipped and
        surrounding whitespace stripped.
    """
    entry = MODELS[m_name]
    net = entry["model"]

    # Pick the object that encodes pixels and the one that decodes token ids.
    if m_name != "vit":
        encoder = decoder = entry["processor"]
    else:
        encoder, decoder = entry["processor"]

    pixel_inputs = encoder(images=image, return_tensors="pt").to(DEVICE)
    token_ids = net.generate(
        **pixel_inputs,
        max_length=300,
        do_sample=True,
        temperature=temp,
        top_k=top_k,
        top_p=top_p,
    )
    decoded = decoder.batch_decode(token_ids, skip_special_tokens=True)
    return decoded[0].strip()
91
 
92
@app.post("/generate")
async def generate_endpoint(
    file: UploadFile = File(...),
    temp: float = Query(0.8),
    top_k: int = Query(100),
    top_p: float = Query(0.9),
):
    """Produce five sampled captions for the uploaded image.

    Five model names are drawn at random (with replacement) from the keys of
    ``MODELS``; each draw runs ``_generate_sync`` in a worker thread and the
    results are collected in draw order.

    Returns:
        ``{"captions": [...], "mix": [...]}`` where ``mix`` lists the model
        name used for the caption at the same index.
    """
    picture = Image.open(file.file).convert("RGB")
    loaded_names = list(MODELS.keys())
    chosen = random.choices(loaded_names, k=5)

    pending = []
    for model_name in chosen:
        pending.append(
            asyncio.to_thread(_generate_sync, model_name, picture, temp, top_k, top_p)
        )
    results = await asyncio.gather(*pending)

    return {"captions": results, "mix": chosen}
108
 
109
@app.post("/ui-tester")
async def ui_tester(file: UploadFile = File(...), description: str = Query(...)):
    """Score ``description`` against BLIP's own caption for the uploaded image.

    Steps:
      1. Generate a beam-search caption with BLIP (the "baseline").
      2. Compute the model's language-modelling loss for the baseline caption
         and for the user's ``description``, both conditioned on the image.
      3. Turn ``baseline_loss / user_loss`` into a percentage, capped at 100.

    All model work runs in a worker thread so the event loop stays responsive,
    and every forward pass is wrapped in ``torch.no_grad()``.

    Returns:
        A dict with ``confidence_score`` (string with a ``%`` suffix),
        ``model_perceived_caption``, ``raw_metrics`` (both losses and their
        difference), ``status`` and ``is_valid``.
    """
    image = Image.open(file.file).convert("RGB")
    blip_data = MODELS["blip"]
    processor = blip_data["processor"]
    model = blip_data["model"]

    def _caption_loss(text):
        # Loss of `text` as a caption of `image`; no gradients are needed.
        inputs = processor(images=image, text=text, return_tensors="pt").to(DEVICE)
        with torch.no_grad():
            outputs = model(**inputs, labels=inputs["input_ids"])
        return outputs.loss.item()

    def _evaluate():
        # 1. Baseline: the model's own best-guess caption (beam search).
        inputs_gen = processor(images=image, return_tensors="pt").to(DEVICE)
        with torch.no_grad():
            generated_ids = model.generate(
                **inputs_gen,
                max_length=50,
                num_beams=5,
                temperature=1.0,
            )
        caption = processor.decode(generated_ids[0], skip_special_tokens=True)
        # 2. Losses for the model's caption and for the user's description.
        return caption, _caption_loss(caption), _caption_loss(description)

    baseline_caption, baseline_loss, user_loss = await asyncio.to_thread(_evaluate)

    # 3. Relative scoring. A non-positive user loss cannot be divided by;
    # treat it as a perfect match instead of raising ZeroDivisionError.
    if user_loss > 0:
        relative_ratio = baseline_loss / user_loss
        confidence_score = min(100.0, round((relative_ratio ** 1.5) * 100, 2))
    else:
        confidence_score = 100.0

    if confidence_score > 55:
        status = "Match Found"
    elif confidence_score > 30:
        status = "Partial Match"
    else:
        status = "No Match"

    return {
        "confidence_score": f"{confidence_score}%",
        "model_perceived_caption": baseline_caption,
        "raw_metrics": {
            "user_loss": round(user_loss, 4),
            "baseline_loss": round(baseline_loss, 4),
            "delta": round(user_loss - baseline_loss, 4),
        },
        "status": status,
        "is_valid": confidence_score > 55,
    }
157
 
158
@app.post("/concept-ensemble")
async def concept_ensemble(file: UploadFile = File(...), user_prompt: str = Query(...)):
    """Compare ``user_prompt`` with BLIP's caption of the uploaded image.

    Returns:
        A dict with both captions, three scores and an interpretation label:
          * ``semantic_overlap``: 0.4 * cosine similarity of mean-pooled,
            L2-normalised text embeddings + 0.6 * Jaccard word overlap.
          * ``visual_alignment``: dot product between the normalised first
            vision token and the user's text embedding.
          * ``word_match_penalty``: ``1 - jaccard``.
    """
    picture = Image.open(file.file).convert("RGB")
    blip = MODELS["blip"]
    processor = blip["processor"]
    network = blip["model"]

    # Caption the image with BLIP.
    pixel_inputs = processor(images=picture, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        caption_ids = network.generate(**pixel_inputs, max_length=40)
    model_caption = processor.decode(caption_ids[0], skip_special_tokens=True)

    def embed_text(text):
        # Mean-pool the text decoder's hidden states, then L2-normalise.
        encoded = processor(text=text, return_tensors="pt", padding=True).to(DEVICE)
        with torch.no_grad():
            hidden = network.text_decoder.bert(**encoded).last_hidden_state
        return F.normalize(hidden.mean(dim=1), p=2, dim=-1)

    user_vec = embed_text(user_prompt)
    model_vec = embed_text(model_caption)

    # Word-level calibration: Jaccard overlap of lower-cased whitespace tokens.
    prompt_tokens = set(user_prompt.lower().split())
    caption_tokens = set(model_caption.lower().split())
    shared = prompt_tokens & caption_tokens
    combined = prompt_tokens | caption_tokens
    if len(combined) > 0:
        jaccard_sim = len(shared) / len(combined)
    else:
        jaccard_sim = 0

    # Blend embedding similarity with literal word overlap.
    raw_sim = torch.matmul(user_vec, model_vec.T).item()
    calibrated_overlap = (raw_sim * 0.4) + (jaccard_sim * 0.6)

    # Visual alignment between the first vision token and the user's text.
    with torch.no_grad():
        vision_hidden = network.vision_model(pixel_inputs["pixel_values"]).last_hidden_state
        image_vec = F.normalize(vision_hidden[:, 0, :], p=2, dim=-1)
        sim_image_user = torch.matmul(image_vec, user_vec.T).item()

    if calibrated_overlap < 0.6:
        interpretation = "Perspective Divergence"
    else:
        interpretation = "Strong Agreement"

    return {
        "captions": {"user": user_prompt, "model": model_caption},
        "similarity_scores": {
            "semantic_overlap": round(calibrated_overlap, 4),
            "visual_alignment": round(sim_image_user, 4),
            "word_match_penalty": round(1 - jaccard_sim, 2),
        },
        "interpretation": interpretation,
    }
207
-
208
-
209
@app.post("/saliency-explorer/image")
async def get_saliency_heatmap(file: UploadFile = File(...), query_text: str = Query(...)):
    """Return a PNG of the upload blended with a BLIP cross-attention heatmap.

    The heatmap comes from the last layer of the text decoder's
    cross-attentions between ``query_text`` tokens and image positions,
    averaged over heads and text tokens, upscaled to the image size, blurred,
    and coloured with the ``jet`` colormap.

    Inference and rendering run in a worker thread so the event loop is not
    blocked while the model executes.

    Returns:
        ``StreamingResponse`` with media type ``image/png``.
    """
    image_bytes = await file.read()

    def _render():
        orig_img = Image.open(io.BytesIO(image_bytes)).convert("RGB")

        blip = MODELS["blip"]
        inputs = blip["processor"](images=orig_img, text=query_text, return_tensors="pt").to(DEVICE)

        # The text decoder is used because its cross-attention links text
        # tokens to image positions.
        with torch.no_grad():
            vision_hidden = blip["model"].vision_model(inputs.pixel_values).last_hidden_state
            outputs = blip["model"].text_decoder(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                encoder_hidden_states=vision_hidden,
                output_attentions=True,
            )

        # Last layer, first batch item; drop the first image position
        # (presumably the vision [CLS] token - TODO confirm).
        patch_attn = outputs.cross_attentions[-1][0, :, :, 1:]
        # Drop the first and last text tokens (presumably special tokens).
        inner = patch_attn[:, 1:-1, :]
        if inner.shape[1] == 0:
            # No interior tokens (e.g. empty query): averaging an empty slice
            # would give NaN, so fall back to every available text token.
            inner = patch_attn
        mask_1d = inner.mean(dim=(0, 1))
        grid_size = int(np.sqrt(mask_1d.shape[-1]))
        mask = mask_1d.view(grid_size, grid_size).cpu().numpy()

        # Min-max normalise, upscale to the image size, then blur for the glow.
        mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-8)
        mask_pill = Image.fromarray((mask * 255).astype('uint8')).resize(orig_img.size, resample=Image.BICUBIC)
        mask_pill = mask_pill.filter(ImageFilter.GaussianBlur(radius=12))

        heatmap_rgba = plt.get_cmap('jet')(np.array(mask_pill) / 255.0)
        heatmap_img = Image.fromarray((heatmap_rgba[:, :, :3] * 255).astype('uint8')).convert("RGB")
        blended_img = Image.blend(orig_img, heatmap_img, alpha=0.5)

        out = io.BytesIO()
        blended_img.save(out, format="PNG")
        out.seek(0)
        return out

    buf = await asyncio.to_thread(_render)
    return StreamingResponse(buf, media_type="image/png")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import torch
3
  import random
4
  import asyncio
5
+ import io
6
+ import numpy as np
7
+ import matplotlib.pyplot as plt
8
  from PIL import Image, ImageFilter
9
  from fastapi import FastAPI, UploadFile, File, Query
10
+ from fastapi.responses import StreamingResponse
11
  from huggingface_hub import snapshot_download, login
12
+ import torch.nn.functional as F
13
+
14
  from transformers import (
15
  BlipProcessor, BlipForConditionalGeneration,
16
  ViTImageProcessor, AutoProcessor, AutoModelForCausalLM
17
  )
 
 
 
 
 
 
18
 
19
+ app = FastAPI(title="XAI Auditor Ensemble")
20
 
21
+ # --- Configuration & State ---
22
  REPO_ID = "SaniaE/Image_Captioning_Ensemble"
23
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
24
  MODELS = {}
25
 
 
26
  MODEL_SETTINGS = {
27
  "blip": {
28
  "subfolder": "blip",
 
51
  ckpt_path = os.path.join(local_dir, cfg["subfolder"])
52
  print(f"Loading {name} from {ckpt_path}...")
53
 
 
54
  model = cfg["inference_model"].from_pretrained(ckpt_path).to(DEVICE)
55
 
 
56
  if name == "vit":
57
  i_proc = cfg["processor"][0].from_pretrained(cfg["pretrained_path"][0])
58
  t_proc = cfg["processor"][1].from_pretrained(cfg["pretrained_path"][1])
 
61
  processor = cfg["processor"].from_pretrained(cfg["pretrained_path"])
62
 
63
  MODELS[name] = {"model": model, "processor": processor}
64
+ print("Optimization Complete: Ensemble is live!")
65
 
66
+ # --- Core Logic Helpers ---
67
+
68
def _generate_sync(m_name, image, temp=0.7):
    """Sample one caption for ``image`` from the loaded model named ``m_name``.

    Blocking helper: the endpoints in this file dispatch it through
    ``asyncio.to_thread``.

    Args:
        m_name: Key into ``MODELS``. For ``"vit"`` the stored processor is
            unpacked as a pair (image processor, text decoder); for any other
            name a single processor handles both encoding and decoding.
        image: Image object handed to the processor's ``images=`` argument.
        temp: Sampling temperature forwarded to ``generate``.

    Returns:
        The first decoded sequence, with special tokens skipped and
        surrounding whitespace stripped.
    """
    entry = MODELS[m_name]
    net = entry["model"]

    # Pick the object that encodes pixels and the one that decodes token ids.
    if m_name != "vit":
        encoder = decoder = entry["processor"]
    else:
        encoder, decoder = entry["processor"]

    pixel_inputs = encoder(images=image, return_tensors="pt").to(DEVICE)
    token_ids = net.generate(**pixel_inputs, max_length=50, do_sample=True, temperature=temp)
    decoded = decoder.batch_decode(token_ids, skip_special_tokens=True)
    return decoded[0].strip()
83
 
84
+ # --- Endpoint 1: The Multi-Perspective Generator ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
@app.post("/generate-caption")
async def generate_caption(file: UploadFile = File(...), temp: float = Query(0.7)):
    """Caption the uploaded image with both loaded models concurrently.

    Each model runs ``_generate_sync`` in its own worker thread with the
    requested sampling temperature.

    Returns:
        ``{"blip_caption": ..., "vit_git_caption": ...}``.
    """
    raw = await file.read()
    picture = Image.open(io.BytesIO(raw)).convert("RGB")

    # One worker-thread job per architecture; gather keeps argument order.
    blip_job = asyncio.to_thread(_generate_sync, "blip", picture, temp)
    vit_job = asyncio.to_thread(_generate_sync, "vit", picture, temp)
    blip_text, vit_text = await asyncio.gather(blip_job, vit_job)

    return {
        "blip_caption": blip_text,
        "vit_git_caption": vit_text,
    }
102
 
103
+ # --- Endpoint 2: The Saliency Explorer (XAI Glow) ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
 
105
@app.post("/saliency-explorer")
async def get_saliency_map(file: UploadFile = File(...), query_text: str = Query(...)):
    """Return a PNG of the upload blended with a BLIP cross-attention heatmap.

    The heatmap comes from the last layer of the text decoder's
    cross-attentions between ``query_text`` tokens and image positions,
    averaged over heads and text tokens, upscaled to the image size, blurred,
    and coloured with the ``jet`` colormap.

    Inference and rendering run in a worker thread so the event loop is not
    blocked while the model executes.

    Returns:
        ``StreamingResponse`` with media type ``image/png``.
    """
    image_bytes = await file.read()

    def _render():
        orig_img = Image.open(io.BytesIO(image_bytes)).convert("RGB")

        blip = MODELS["blip"]
        inputs = blip["processor"](images=orig_img, text=query_text, return_tensors="pt").to(DEVICE)

        with torch.no_grad():
            vision_hidden = blip["model"].vision_model(inputs.pixel_values).last_hidden_state
            outputs = blip["model"].text_decoder(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                encoder_hidden_states=vision_hidden,
                output_attentions=True,
            )

        # Last layer, first batch item; drop the first image position
        # (the [CLS] token, per the original comment).
        patch_attn = outputs.cross_attentions[-1][0, :, :, 1:]
        # Drop the first and last text tokens (presumably special tokens).
        inner = patch_attn[:, 1:-1, :]
        if inner.shape[1] == 0:
            # No interior tokens (e.g. empty query): averaging an empty slice
            # would give NaN, so fall back to every available text token.
            inner = patch_attn
        mask_1d = inner.mean(dim=(0, 1))
        grid_size = int(np.sqrt(mask_1d.shape[-1]))
        mask = mask_1d.view(grid_size, grid_size).cpu().numpy()

        # Min-max normalise, upscale to the image size, then blur for the glow.
        mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-8)
        mask_pill = Image.fromarray((mask * 255).astype('uint8')).resize(orig_img.size, resample=Image.BICUBIC)
        mask_pill = mask_pill.filter(ImageFilter.GaussianBlur(radius=12))

        heatmap_rgba = plt.get_cmap('jet')(np.array(mask_pill) / 255.0)
        heatmap_img = Image.fromarray((heatmap_rgba[:, :, :3] * 255).astype('uint8')).convert("RGB")
        blended_img = Image.blend(orig_img, heatmap_img, alpha=0.5)

        out = io.BytesIO()
        blended_img.save(out, format="PNG")
        out.seek(0)
        return out

    buf = await asyncio.to_thread(_render)
    return StreamingResponse(buf, media_type="image/png")
141
+
142
+ # --- Endpoint 3: Internal Debate (Audit Mode) ---
143
+
144
@app.post("/internal-debate")
async def internal_debate(file: UploadFile = File(...), user_prompt: str = Query(...)):
    """Compare the user's prompt with both models' captions of the image.

    The two captions are generated concurrently in worker threads. The
    embedding and scoring step also runs in a worker thread so that no model
    forward pass executes on the event loop.

    Each pairwise score is ``0.4 * cosine + 0.6 * jaccard``, where cosine is
    taken between mean-pooled, L2-normalised BLIP text-decoder embeddings and
    jaccard is the overlap of lower-cased whitespace tokens.

    Returns:
        A dict with ``perspectives`` (the three texts), ``audit_metrics``
        (user-vs-BLIP, user-vs-ViT and inter-model scores, rounded to four
        decimals) and ``verdict`` (``"Consensus"`` when the inter-model score
        exceeds 0.65, otherwise ``"Perspective Divergence"``).
    """
    image_bytes = await file.read()
    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

    # 1. Gather model perceptions (both architectures in parallel).
    blip_caption, vit_caption = await asyncio.gather(
        asyncio.to_thread(_generate_sync, "blip", image),
        asyncio.to_thread(_generate_sync, "vit", image),
    )

    # 2. Semantic embedding logic.
    blip_data = MODELS["blip"]

    def get_emb(text):
        inputs = blip_data["processor"](text=text, return_tensors="pt", padding=True).to(DEVICE)
        with torch.no_grad():
            hidden = blip_data["model"].text_decoder.bert(**inputs).last_hidden_state
        return F.normalize(hidden.mean(dim=1), p=2, dim=-1)

    # 3. Calibration: blend embedding cosine with Jaccard word overlap.
    def calibrate(emb1, emb2, t1, t2):
        s1, s2 = set(t1.lower().split()), set(t2.lower().split())
        union = s1 | s2
        jaccard = len(s1 & s2) / len(union) if union else 0
        cosine = torch.matmul(emb1, emb2.T).item()
        return (cosine * 0.4) + (jaccard * 0.6)

    def score_all():
        u_emb = get_emb(user_prompt)
        b_emb = get_emb(blip_caption)
        v_emb = get_emb(vit_caption)
        return (
            calibrate(u_emb, b_emb, user_prompt, blip_caption),
            calibrate(u_emb, v_emb, user_prompt, vit_caption),
            calibrate(b_emb, v_emb, blip_caption, vit_caption),
        )

    score_blip, score_vit, consensus = await asyncio.to_thread(score_all)

    return {
        "perspectives": {
            "user_intent": user_prompt,
            "blip_view": blip_caption,
            "vit_git_view": vit_caption,
        },
        "audit_metrics": {
            "user_vs_blip": round(score_blip, 4),
            "user_vs_vit": round(score_vit, 4),
            "inter_model_consensus": round(consensus, 4),
        },
        "verdict": "Consensus" if consensus > 0.65 else "Perspective Divergence",
    }