File size: 2,938 Bytes
278b37e
 
5bd1204
278b37e
 
5bd1204
 
 
278b37e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5bd1204
 
 
 
 
 
 
 
 
278b37e
 
5bd1204
278b37e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
# controller.py — real VQA inference using BLIP (small, fast, no extra weights)
# Works on CPU Space. Uses HF Hub to download the model at first run.

import os
import torch
from PIL import Image
from typing import Tuple

from transformers import BlipForQuestionAnswering, BlipProcessor

# ---------------------------
# Model configuration and lazily-populated cache
# ---------------------------
# Hub model id; override with the HF_VQA_MODEL environment variable.
HF_MODEL = os.environ.get("HF_VQA_MODEL", "Salesforce/blip-vqa-base")

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Filled in on first use by _load(); both stay None until then
# (nothing is downloaded at import time).
_processor = None
_model = None

def _load():
    """Lazily load the BLIP processor and VQA model into the module-level cache.

    After the first successful call this is a no-op. Both objects are built in
    locals and only published to the ``_processor`` / ``_model`` globals once
    the model has been moved to ``DEVICE`` and switched to eval mode. If any
    step raises (e.g. a Hub download failure or running out of device memory),
    the cache is left untouched and the next call retries from scratch. It
    never reuses a half-initialised model.
    """
    global _processor, _model
    if _processor is not None and _model is not None:
        return
    processor = BlipProcessor.from_pretrained(HF_MODEL)
    model = BlipForQuestionAnswering.from_pretrained(HF_MODEL)
    model.to(DEVICE)
    model.eval()  # inference only: disables dropout etc.
    # Publish together, only after everything above succeeded.
    _processor, _model = processor, model

def _answer_baseline(image: Image.Image, question: str) -> str:
    """Answer ``question`` about ``image`` with the plain BLIP VQA model.

    Ensures the model is loaded, then encodes the pair and moves it to
    ``DEVICE``. It generates at most 10 new tokens under
    ``torch.inference_mode()`` and returns the decoded answer with
    surrounding whitespace removed.
    """
    _load()
    model_inputs = _processor(images=image, text=question, return_tensors="pt")
    model_inputs = model_inputs.to(DEVICE)
    with torch.inference_mode():
        token_ids = _model.generate(**model_inputs, max_new_tokens=10)
    # Decode the first generated sequence (a single image/question pair is
    # assumed here).
    answer_text = _processor.decode(token_ids[0], skip_special_tokens=True)
    return answer_text.strip()

# --- optional future hooks (no-ops for now, keep API stable) ---
def _answer_with_memory(image: Image.Image, question: str) -> str:
    """Answer using retrieved memory context (placeholder).

    Hook for a FAISS/RAG-backed strategy (routed to for action id 1). No
    retriever is wired in yet, so this returns the baseline model's answer.
    """
    # TODO: retrieve supporting context here and condition the answer on it.
    return _answer_baseline(image=image, question=question)

def _gate_auto(image: Image.Image, question: str) -> Tuple[int, str]:
    """Pick an action for the "auto" source (placeholder policy).

    Inputs are currently ignored. Once the PPO or distilled gates are
    available, the choice between them belongs here.

    Returns:
        ``(action_id, strategy_name)``, currently always ``(0, "baseline")``.
    """
    baseline_choice = (0, "baseline")
    return baseline_choice

def _gate_distilled(image: Image.Image, question: str) -> Tuple[int, str]:
    """Pick an action using the distilled classifier (not wired in yet).

    Until the classifier is hooked up, inputs are ignored and the baseline
    action is returned.

    Returns:
        ``(action_id, strategy_name)``, currently always ``(0, "baseline")``.
    """
    # TODO: call the distilled classifier here.
    action_id, strategy = 0, "baseline"
    return action_id, strategy

def _gate_ppo(image: Image.Image, question: str) -> Tuple[int, str]:
    """Pick an action using the PPO policy (not wired in yet).

    Until the policy is hooked up, inputs are ignored and the baseline
    action is returned.

    Returns:
        ``(action_id, strategy_name)``, currently always ``(0, "baseline")``.
    """
    # TODO: query the PPO policy here.
    fallback_action_id = 0
    fallback_strategy = "baseline"
    return (fallback_action_id, fallback_strategy)

# ---------------------------
# Public API for app.py
# ---------------------------
def answer_with_controller(
    image: Image.Image,
    question: str,
    source: str = "auto",
    distilled_model: str = "auto",
) -> Tuple[str, str, int]:
    """Answer a visual question, letting a controller choose the strategy.

    Args:
        image: Image passed to the VQA model.
        question: Natural-language question about ``image``.
        source: Which controller picks the action:
            - ``"baseline"`` forces the plain BLIP answer and skips the gates.
            - ``"distilled"`` or ``"ppo"`` consults the corresponding gate.
            - Anything else, including ``None`` or ``""``, uses the auto gate.
            Matching ignores case and surrounding whitespace.
        distilled_model: Not used by this function (presumably reserved for
            selecting the distilled gate's checkpoint); accepted for API
            stability.

    Returns:
        pred (str): predicted answer
        strategy_name (str): chosen strategy name
        action_id (int): numeric action (1 = memory path; any other id is
            answered by the baseline model)
    """
    # Normalise so a padded/cased value such as " PPO " still selects the
    # intended gate instead of silently falling through to "auto".
    source = (source or "auto").strip().lower()

    if source == "baseline":
        # Forced baseline: no gate is consulted.
        aid, label = 0, "baseline"
    elif source == "distilled":
        aid, label = _gate_distilled(image, question)
    elif source == "ppo":
        aid, label = _gate_ppo(image, question)
    else:  # "auto" and any unrecognised value
        aid, label = _gate_auto(image, question)

    # Route by action id; only action 1 (memory) has a dedicated path today.
    if aid == 1:
        ans = _answer_with_memory(image, question)
    else:
        ans = _answer_baseline(image, question)

    return ans, label, aid