import io, base64, torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor
class EndpointHandler:
    """
    Custom zero-shot classifier replicating local OpenAI-CLIP logic.

    Client JSON must look like:
        {
          "inputs": {
              "image": "<base64 PNG/JPEG>",
              "candidate_labels": ["car", "teddy bear", ...]
          }
        }
    """

    def __init__(self, path: str = ""):
        # Load the CLIP weights and processor shipped alongside this handler.
        self.model = CLIPModel.from_pretrained(path)
        self.processor = CLIPProcessor.from_pretrained(path)

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device).eval()

        # Per-prompt text-embedding cache: label sets repeat across requests,
        # so identical prompts are encoded only once.
        # NOTE(review): unbounded — fine for a small fixed label vocabulary,
        # but consider an LRU bound if clients send arbitrary labels.
        self.cache: dict[str, torch.Tensor] = {}

    def __call__(self, data):
        """Classify one base64-encoded image against the caller's labels.

        Returns a list of {"label", "score"} dicts sorted by descending
        score, or an {"error": ...} dict on malformed input.
        """
        payload = data.get("inputs", data)

        # Validate the request instead of letting a KeyError escape into
        # the serving stack (the empty-labels case already returned an
        # error dict; the missing-image case now does too).
        img_b64 = payload.get("image")
        if not img_b64:
            return {"error": "missing 'image' field"}
        names = payload.get("candidate_labels", [])
        if not names:
            return {"error": "candidate_labels list is empty"}

        # Standard CLIP prompt template.
        prompts = [f"a photo of a {p}" for p in names]

        # Encode (and cache) any prompts not seen before; embeddings are
        # L2-normalized so the later dot product is a cosine similarity.
        missing = [p for p in prompts if p not in self.cache]
        if missing:
            txt_in = self.processor(text=missing, return_tensors="pt",
                                    padding=True).to(self.device)
            with torch.no_grad():
                txt_emb = self.model.get_text_features(**txt_in)
            txt_emb = txt_emb / txt_emb.norm(dim=-1, keepdim=True)
            for p, e in zip(missing, txt_emb):
                self.cache[p] = e
        txt_feat = torch.stack([self.cache[p] for p in prompts])

        try:
            img = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
        except Exception as exc:  # bad base64 or non-image payload
            return {"error": f"could not decode image: {exc}"}
        img_in = self.processor(images=img, return_tensors="pt").to(self.device)

        # BUG FIX: torch.cuda.amp.autocast() was entered unconditionally,
        # even on CPU, and on CUDA produced fp16 image features that were
        # later matmul'd (outside the autocast scope) against the fp32
        # cached text features — a dtype mismatch. Autocast only on CUDA
        # and cast back to fp32 before normalizing/scoring.
        with torch.no_grad():
            if self.device == "cuda":
                with torch.autocast(device_type="cuda"):
                    img_feat = self.model.get_image_features(**img_in)
            else:
                img_feat = self.model.get_image_features(**img_in)
        img_feat = img_feat.float()
        img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)

        # 100 ~= CLIP's learned logit scale; softmax over the label axis.
        probs = (100 * img_feat @ txt_feat.T).softmax(dim=-1)[0].tolist()

        return [
            {"label": n, "score": float(p)}
            for n, p in sorted(zip(names, probs), key=lambda x: x[1], reverse=True)
        ]
| |
|