SaniaE committed on
Commit
b5397cf
·
verified ·
1 Parent(s): c441112

updated endpoint logic

Browse files
Files changed (1) hide show
  1. app.py +49 -58
app.py CHANGED
@@ -159,93 +159,84 @@ async def ui_tester(file: UploadFile = File(...), description: str = Query(...))
159
  async def concept_ensemble(file: UploadFile = File(...), user_prompt: str = Query(...)):
160
  image = Image.open(file.file).convert("RGB")
161
  blip = MODELS["blip"]
162
-
163
- # 1. Model Baseline (Generating its own perception)
164
  inputs_gen = blip["processor"](images=image, return_tensors="pt").to(DEVICE)
165
  with torch.no_grad():
166
  generated_ids = blip["model"].generate(**inputs_gen, max_length=40)
167
  model_caption = blip["processor"].decode(generated_ids[0], skip_special_tokens=True)
168
 
169
- # 2. Embedding Calculation
170
- texts = [user_prompt, model_caption]
171
- inputs_text = blip["processor"](text=texts, return_tensors="pt", padding=True).to(DEVICE)
 
 
 
 
 
 
 
 
 
172
 
 
173
  with torch.no_grad():
174
- # 1. Get Image Embeddings from the vision_model
175
  vision_outputs = blip["model"].vision_model(inputs_gen["pixel_values"])
176
- image_embeds = vision_outputs.last_hidden_state[:, 0, :] # Use [CLS] token
177
-
178
- # 2. Get Text Embeddings using the text_decoder's bert model
179
- # BLIP's text_decoder typically wraps a BERT-like architecture
180
- text_outputs = blip["model"].text_decoder.bert(**inputs_text)
181
- text_embeds = text_outputs.last_hidden_state[:, 0, :] # Use [CLS] token
182
-
183
- # Normalize
184
- image_embeds = F.normalize(image_embeds, p=2, dim=-1)
185
- text_embeds = F.normalize(text_embeds, p=2, dim=-1)
186
-
187
- # Similarity Matrix calculation
188
- sim_image_user = torch.matmul(image_embeds, text_embeds[0].T).item()
189
- sim_image_model = torch.matmul(image_embeds, text_embeds[1].T).item()
190
- sim_user_model = torch.matmul(text_embeds[0], text_embeds[1].T).item()
191
 
192
  return {
193
- "captions": {
194
- "user": user_prompt,
195
- "model_best_guess": model_caption
196
- },
197
  "similarity_scores": {
198
- "visual_alignment_user": round(float(sim_image_user), 4),
199
- "visual_alignment_model": round(float(sim_image_model), 4),
200
- "semantic_overlap": round(float(sim_user_model), 4)
201
  },
202
- "interpretation": "Strong Agreement" if sim_user_model > 0.85 else "Diverse Perspectives"
203
  }
204
 
 
205
  @app.post("/saliency-explorer/image")
206
  async def get_saliency_heatmap(file: UploadFile = File(...), query_text: str = Query(...)):
207
- # 1. Load Image
208
  image_bytes = await file.read()
209
  orig_img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
210
 
211
  blip = MODELS["blip"]
212
- # Ensure pixel_values can track gradients
213
  inputs = blip["processor"](images=orig_img, text=query_text, return_tensors="pt").to(DEVICE)
214
- inputs.pixel_values.requires_grad = True
215
 
216
- # 2. Extract Gradients for Saliency
217
- outputs = blip["model"](**inputs, labels=inputs["input_ids"])
218
- loss = outputs.loss
219
- loss.backward()
220
-
221
- # Generate Saliency from gradients of pixel values
222
- # We take the maximum absolute gradient across the RGB channels
223
- grad = inputs.pixel_values.grad.abs().max(dim=1)[0][0].cpu().numpy()
224
-
225
- # 3. Create Heatmap with "Glow" Effect (XAI Style)
226
- # Normalize to [0, 1]
227
- grad = (grad - grad.min()) / (grad.max() - grad.min() + 1e-8)
228
-
229
- # Apply Gaussian Blur to smooth tiny speckles into a professional heatmap
230
- grad_pill = Image.fromarray((grad * 255).astype('uint8'))
231
- grad_pill = grad_pill.filter(ImageFilter.GaussianBlur(radius=8))
232
- grad_smoothed = np.array(grad_pill) / 255.0
233
 
234
- # Apply colormap (jet)
 
 
 
 
235
  cm = plt.get_cmap('jet')
236
- heatmap_rgba = cm(grad_smoothed)
237
-
238
- # Convert heatmap to PIL and resize to original image dimensions
239
  heatmap_img = Image.fromarray((heatmap_rgba[:, :, :3] * 255).astype('uint8')).convert("RGB")
240
- heatmap_img = heatmap_img.resize(orig_img.size, resample=Image.BILINEAR)
241
 
242
- # 4. Blend Original + Heatmap (Adjust alpha for visibility on dark/light UIs)
243
- # 0.5 alpha provides a strong clear highlight for the "Rorompok" sofa
244
  blended_img = Image.blend(orig_img, heatmap_img, alpha=0.5)
245
 
246
- # 5. Stream back
247
  buf = io.BytesIO()
248
  blended_img.save(buf, format="PNG")
249
  buf.seek(0)
250
-
251
  return StreamingResponse(buf, media_type="image/png")
 
159
  async def concept_ensemble(file: UploadFile = File(...), user_prompt: str = Query(...)):
160
  image = Image.open(file.file).convert("RGB")
161
  blip = MODELS["blip"]
162
+
163
+ # Get model's caption
164
  inputs_gen = blip["processor"](images=image, return_tensors="pt").to(DEVICE)
165
  with torch.no_grad():
166
  generated_ids = blip["model"].generate(**inputs_gen, max_length=40)
167
  model_caption = blip["processor"].decode(generated_ids[0], skip_special_tokens=True)
168
 
169
+ # 1. NEW: Localized Keyword Embedding
170
+ # We focus on the core nouns and adjectives to prevent 'template bias'
171
+ def get_focused_embedding(text):
172
+ inputs = blip["processor"](text=text, return_tensors="pt", padding=True).to(DEVICE)
173
+ with torch.no_grad():
174
+ # Get output from the BERT-based text decoder
175
+ outputs = blip["model"].text_decoder.bert(**inputs)
176
+ # Average hidden states of ALL tokens to capture keyword specifics
177
+ return F.normalize(outputs.last_hidden_state.mean(dim=1), p=2, dim=-1)
178
+
179
+ user_embed = get_focused_embedding(user_prompt)
180
+ model_embed = get_focused_embedding(model_caption)
181
 
182
+ # Visual alignment
183
  with torch.no_grad():
 
184
  vision_outputs = blip["model"].vision_model(inputs_gen["pixel_values"])
185
+ image_embed = F.normalize(vision_outputs.last_hidden_state[:, 0, :], p=2, dim=-1)
186
+
187
+ # 2. Calculate Corrected Scores
188
+ sim_image_user = torch.matmul(image_embed, user_embed.T).item()
189
+ sim_image_model = torch.matmul(image_embed, model_embed.T).item()
190
+ sim_user_model = torch.matmul(user_embed, model_embed.T).item()
 
 
 
 
 
 
 
 
 
191
 
192
  return {
193
+ "captions": {"user": user_prompt, "model": model_caption},
 
 
 
194
  "similarity_scores": {
195
+ "visual_alignment_user": round(sim_image_user, 4),
196
+ "visual_alignment_model": round(sim_image_model, 4),
197
+ "semantic_overlap": round(sim_user_model, 4)
198
  },
199
+ "interpretation": "Strong Agreement" if sim_user_model > 0.8 else "Perspective Divergence"
200
  }
201
 
202
+
203
  @app.post("/saliency-explorer/image")
204
  async def get_saliency_heatmap(file: UploadFile = File(...), query_text: str = Query(...)):
 
205
  image_bytes = await file.read()
206
  orig_img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
207
 
208
  blip = MODELS["blip"]
209
+ # We enable 'output_attentions' to grab the internal map directly
210
  inputs = blip["processor"](images=orig_img, text=query_text, return_tensors="pt").to(DEVICE)
 
211
 
212
+ with torch.no_grad():
213
+ outputs = blip["model"](**inputs, output_attentions=True)
214
+ # Use the last layer of vision encoder self-attention
215
+ # Shape: (batch, heads, patches, patches)
216
+ attentions = outputs.vision_model_output.attentions[-1]
217
+
218
+ # Average across heads and take the attention from the [CLS] token to all patches
219
+ # Patch size for BLIP is typically 14x14 or 16x16
220
+ grid_size = int(np.sqrt(attentions.shape[-1] - 1))
221
+ # Remove [CLS] token and reshape to grid
222
+ mask = attentions[0, :, 0, 1:].mean(0).view(grid_size, grid_size).cpu().numpy()
223
+
224
+ # 1. Normalize and Upscale
225
+ mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-8)
226
+ mask_pill = Image.fromarray((mask * 255).astype('uint8')).resize(orig_img.size, resample=Image.BICUBIC)
 
 
227
 
228
+ # 2. Apply Gaussian Glow for XAI Aesthetic
229
+ mask_pill = mask_pill.filter(ImageFilter.GaussianBlur(radius=15))
230
+ mask_final = np.array(mask_pill) / 255.0
231
+
232
+ # 3. Apply Colormap and Blend
233
  cm = plt.get_cmap('jet')
234
+ heatmap_rgba = cm(mask_final)
 
 
235
  heatmap_img = Image.fromarray((heatmap_rgba[:, :, :3] * 255).astype('uint8')).convert("RGB")
 
236
 
 
 
237
  blended_img = Image.blend(orig_img, heatmap_img, alpha=0.5)
238
 
 
239
  buf = io.BytesIO()
240
  blended_img.save(buf, format="PNG")
241
  buf.seek(0)
 
242
  return StreamingResponse(buf, media_type="image/png")