| |
| |
|
|
| import os |
| import torch |
| from PIL import Image |
| from typing import Tuple |
|
|
| from transformers import BlipForQuestionAnswering, BlipProcessor |
|
|
| |
| |
| |
# Hugging Face model id for BLIP VQA; overridable via the HF_VQA_MODEL env var.
HF_MODEL = os.getenv("HF_VQA_MODEL", "Salesforce/blip-vqa-base")
# Prefer GPU when one is available, else fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


# Lazily-initialized module-level singletons; populated on first use by _load().
_processor = None
_model = None
|
|
def _load():
    """Lazily initialize the BLIP processor/model singletons.

    Safe to call repeatedly: the expensive `from_pretrained` downloads run
    only on the first call (or if a prior attempt left either singleton
    unset). The model is moved to DEVICE and put in eval mode.
    """
    global _processor, _model
    # Guard clause: nothing to do once both singletons exist.
    if _processor is not None and _model is not None:
        return
    _processor = BlipProcessor.from_pretrained(HF_MODEL)
    _model = BlipForQuestionAnswering.from_pretrained(HF_MODEL)
    _model.to(DEVICE)
    _model.eval()
|
|
def _answer_baseline(image: Image.Image, question: str) -> str:
    """Answer `question` about `image` with a single BLIP VQA forward pass.

    Returns the decoded answer string, stripped of surrounding whitespace.
    """
    _load()
    batch = _processor(images=image, text=question, return_tensors="pt").to(DEVICE)
    # inference_mode disables autograd bookkeeping for the generation step.
    with torch.inference_mode():
        generated_ids = _model.generate(**batch, max_new_tokens=10)
    decoded = _processor.decode(generated_ids[0], skip_special_tokens=True)
    return decoded.strip()
|
|
| |
def _answer_with_memory(image: Image.Image, question: str) -> str:
    """Memory-augmented answering strategy (placeholder).

    NOTE(review): currently just delegates to the baseline answerer; a real
    retrieval/memory pipeline can be swapped in here without touching callers.
    """
    prediction = _answer_baseline(image, question)
    return prediction
|
|
def _gate_auto(image: Image.Image, question: str) -> Tuple[int, str]:
    """Auto gating policy (stub).

    Always selects the baseline strategy for now.

    Returns:
        (action_id, strategy_label): currently fixed to (0, "baseline").
    """
    return (0, "baseline")
|
|
def _gate_distilled(image: Image.Image, question: str) -> Tuple[int, str]:
    """Distilled-controller gating policy (stub).

    Always selects the baseline strategy for now.

    Returns:
        (action_id, strategy_label): currently fixed to (0, "baseline").
    """
    return (0, "baseline")
|
|
def _gate_ppo(image: Image.Image, question: str) -> Tuple[int, str]:
    """PPO-controller gating policy (stub).

    Always selects the baseline strategy for now.

    Returns:
        (action_id, strategy_label): currently fixed to (0, "baseline").
    """
    return (0, "baseline")
|
|
| |
| |
| |
def answer_with_controller(
    image: Image.Image,
    question: str,
    source: str = "auto",
    distilled_model: str = "auto",
) -> Tuple[str, str, int]:
    """Dispatch a VQA query through the requested controller.

    Args:
        image: input image the question refers to.
        question: natural-language question about the image.
        source: controller selector — "baseline" skips gating entirely;
            "distilled" and "ppo" use their respective gates; any other
            value (including None/empty) falls back to the auto gate.
        distilled_model: reserved selector for a distilled-gate variant;
            currently unused by the stub gates.

    Returns:
        pred (str): predicted answer
        strategy_name (str): chosen strategy name
        action_id (int): numeric action (0=baseline, 1=memory in future, etc.)
    """
    mode = (source or "auto").lower()

    # Pure-baseline path: answer directly, no gating decision needed.
    if mode == "baseline":
        return _answer_baseline(image, question), "baseline", 0

    # Pick the gate via a dispatch table; unknown modes fall back to auto.
    gate = {
        "distilled": _gate_distilled,
        "ppo": _gate_ppo,
    }.get(mode, _gate_auto)
    action_id, strategy_name = gate(image, question)

    # Action 1 routes to the memory strategy; everything else is baseline.
    if action_id == 1:
        prediction = _answer_with_memory(image, question)
    else:
        prediction = _answer_baseline(image, question)

    return prediction, strategy_name, action_id
|
|