| | import contextlib, io, base64, torch, json, os, threading |
| | from PIL import Image |
| | import open_clip |
| | from huggingface_hub import hf_hub_download, create_commit, CommitOperationAdd |
| | from safetensors.torch import save_file, load_file |
| | from reparam import reparameterize_model |
| |
|
# Deployment configuration, read once from the environment at import time.
ADMIN_TOKEN = os.getenv("ADMIN_TOKEN", "")  # shared secret checked before any label-mutating op
HF_LABEL_REPO = os.getenv("HF_LABEL_REPO", "")  # HF dataset repo for label snapshots; "" disables hub persistence
HF_WRITE_TOKEN = os.getenv("HF_WRITE_TOKEN", "")  # token with write access, used by create_commit when publishing
HF_READ_TOKEN = os.getenv("HF_READ_TOKEN", HF_WRITE_TOKEN)  # download token; falls back to the write token
| |
|
| |
|
def _fingerprint(device: str, dtype: torch.dtype) -> dict:
    """Describe the model/runtime configuration the text embeddings depend on.

    Persisted alongside each snapshot and compared on load, so embeddings
    produced under a different model, library, or dtype setup are rejected.
    """
    cuda_version = torch.version.cuda if torch.cuda.is_available() else None
    return dict(
        model_id="MobileCLIP-B",
        pretrained="datacompdr",
        open_clip=getattr(open_clip, "__version__", "unknown"),
        torch=torch.__version__,
        cuda=cuda_version,
        dtype_runtime=str(dtype),
        text_norm="L2",
        logit_scale=100.0,
    )
| |
|
| |
|
class EndpointHandler:
    """Zero-shot image classifier built on MobileCLIP-B with a mutable,
    version-controlled label set.

    The master copy of the label text embeddings lives on CPU in float32
    (``text_features_cpu``); ``_to_device`` mirrors it onto the compute
    device (float16 on CUDA) as ``text_features`` for inference.  Label
    snapshots are published to / restored from a Hugging Face dataset repo
    so restarts and replicas converge on the same label set.  ``_lock``
    serializes every mutation of the label state.
    """

    def __init__(self, path: str = ""):
        """Load the model, then the label set (Hub snapshot, else items.json).

        ``path``: optional directory containing the bundled ``items.json``
        fallback.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.dtype = torch.float16 if self.device == "cuda" else torch.float32

        model, _, self.preprocess = open_clip.create_model_and_transforms(
            "MobileCLIP-B", pretrained="datacompdr"
        )
        model.eval()
        # Fold MobileCLIP's train-time branches into inference form before
        # moving/casting the weights.
        model = reparameterize_model(model)
        model.to(self.device)
        if self.device == "cuda":
            model = model.to(torch.float16)
        self.model = model
        self.tokenizer = open_clip.get_tokenizer("MobileCLIP-B")
        self.fingerprint = _fingerprint(self.device, self.dtype)
        self._lock = threading.Lock()

        # Prefer the latest published snapshot; fall back to the bundled
        # items.json when the repo is unset or the download fails.
        loaded = False
        if HF_LABEL_REPO:
            with contextlib.suppress(Exception):
                loaded = self._load_snapshot_from_hub_latest()
        if not loaded:
            items_path = "items.json" if not path else f"{path}/items.json"
            with open(items_path, "r", encoding="utf-8") as f:
                items = json.load(f)
            prompts = [it["prompt"] for it in items]
            self.class_ids = [int(it["id"]) for it in items]
            self.class_names = [it["name"] for it in items]
            # Reuse the shared text-encoding helper instead of duplicating it.
            self.text_features_cpu = (
                self._encode_text(prompts).detach().cpu().to(torch.float32).contiguous()
            )
            self._to_device()
            self.labels_version = 1

    def __call__(self, data):
        """Handle one request: an admin op (upsert/reload/remove labels) or an
        image-classification payload.

        Returns a status dict for admin ops, or a score-sorted list of
        ``{"id", "label", "score"}`` dicts for classification.
        """
        payload = data.get("inputs", data)

        op = payload.get("op")
        if op == "upsert_labels":
            if payload.get("token") != ADMIN_TOKEN:
                return {"error": "unauthorized"}
            items = payload.get("items", []) or []
            added = self._upsert_items(items)
            if added > 0:
                new_ver = int(getattr(self, "labels_version", 1)) + 1
                try:
                    # Bump the local version only after the snapshot is safely
                    # published, so a failed commit can be retried.
                    self._persist_snapshot_to_hub(new_ver)
                    self.labels_version = new_ver
                except Exception as e:
                    return {"status": "error", "added": added, "detail": str(e)}
            return {"status": "ok", "added": added, "labels_version": getattr(self, "labels_version", 1)}

        if op == "reload_labels":
            if payload.get("token") != ADMIN_TOKEN:
                return {"error": "unauthorized"}
            try:
                ver = int(payload.get("version"))
            except Exception:
                return {"error": "invalid_version"}
            ok = self._load_snapshot_from_hub_version(ver)
            return {"status": "ok" if ok else "nochange", "labels_version": getattr(self, "labels_version", 0)}

        if op == "remove_labels":
            if payload.get("token") != ADMIN_TOKEN:
                return {"error": "unauthorized"}
            ids_to_remove = set(payload.get("ids", []))
            if not ids_to_remove:
                return {"error": "no_ids_provided"}

            removed = self._remove_items(ids_to_remove)
            if removed > 0:
                new_ver = int(getattr(self, "labels_version", 1)) + 1
                try:
                    self._persist_snapshot_to_hub(new_ver)
                    self.labels_version = new_ver
                except Exception as e:
                    return {"status": "error", "removed": removed, "detail": str(e)}
            return {"status": "ok", "removed": removed, "labels_version": getattr(self, "labels_version", 1)}

        # Clients may demand a minimum label version; catch up best-effort.
        min_ver = payload.get("min_labels_version")
        if isinstance(min_ver, int) and min_ver > getattr(self, "labels_version", 0):
            with contextlib.suppress(Exception):
                self._load_snapshot_from_hub_version(min_ver)

        img_b64 = payload["image"]
        image = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
        img_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
        if self.device == "cuda":
            img_tensor = img_tensor.to(torch.float16)

        # Fix: snapshot the label state under the lock so a concurrent
        # upsert/remove cannot leave ids/names/features mutually
        # inconsistent partway through this request.
        with self._lock:
            text_features = self.text_features
            class_ids = list(self.class_ids)
            class_names = list(self.class_names)
        if not class_ids:
            # Softmax over zero classes would raise; an empty label set
            # simply yields no predictions.
            return []

        with torch.no_grad():
            img_feat = self.model.encode_image(img_tensor)
            img_feat /= img_feat.norm(dim=-1, keepdim=True)
            # 100.0 matches the logit_scale recorded in the fingerprint.
            probs = (100.0 * img_feat @ text_features.T).softmax(dim=-1)[0]
        results = zip(class_ids, class_names, probs.detach().cpu().tolist())
        top_k = int(payload.get("top_k", len(class_ids)))
        return sorted(
            [{"id": i, "label": name, "score": float(p)} for i, name, p in results],
            key=lambda x: x["score"],
            reverse=True,
        )[:top_k]

    def _encode_text(self, prompts):
        """Tokenize and encode ``prompts``; returns L2-normalized features."""
        with torch.no_grad():
            toks = self.tokenizer(prompts).to(self.device)
            feats = self.model.encode_text(toks)
            feats = feats / feats.norm(dim=-1, keepdim=True)
        return feats

    def _to_device(self):
        """Mirror the CPU float32 master embeddings to the compute device
        (float16 on CUDA, float32 on CPU)."""
        self.text_features = self.text_features_cpu.to(
            self.device, dtype=(torch.float16 if self.device == "cuda" else torch.float32)
        )

    def _upsert_items(self, new_items):
        """Embed and append items whose id AND name are both new.

        Returns the number of items actually added.  Ids/names accepted
        earlier in the same request are also tracked (fix: the original only
        checked existing state, so one request repeating an id or name could
        insert duplicates).
        """
        if not new_items:
            return 0
        with self._lock:
            known_ids = set(getattr(self, "class_ids", []))
            known_names_lower = {n.lower() for n in getattr(self, "class_names", [])}

            batch = []
            for it in new_items:
                item_id = int(it.get("id"))
                item_name = it.get("name")
                if item_id in known_ids or item_name.lower() in known_names_lower:
                    continue
                # Track within-request acceptances to reject in-batch dupes.
                known_ids.add(item_id)
                known_names_lower.add(item_name.lower())
                batch.append(it)

            if not batch:
                return 0

            prompts = [it["prompt"] for it in batch]
            feats = self._encode_text(prompts).detach().cpu().to(torch.float32)

            if not hasattr(self, "text_features_cpu"):
                self.text_features_cpu = feats.contiguous()
                self.class_ids = [int(it["id"]) for it in batch]
                self.class_names = [it["name"] for it in batch]
            else:
                self.text_features_cpu = torch.cat([self.text_features_cpu, feats], dim=0).contiguous()
                self.class_ids.extend(int(it["id"]) for it in batch)
                self.class_names.extend(it["name"] for it in batch)

            self._to_device()
            return len(batch)

    def _remove_items(self, ids_to_remove):
        """Drop every label whose id is in ``ids_to_remove``.

        Returns the number of labels removed (0 if none matched).
        """
        if not ids_to_remove or not hasattr(self, "class_ids"):
            return 0
        with self._lock:
            ids_to_remove = {int(i) for i in ids_to_remove}

            keep = [i for i, cid in enumerate(self.class_ids) if cid not in ids_to_remove]
            removed_count = len(self.class_ids) - len(keep)
            if removed_count == 0:
                return 0

            if keep:
                self.text_features_cpu = self.text_features_cpu[keep].contiguous()
                self.class_ids = [self.class_ids[i] for i in keep]
                self.class_names = [self.class_names[i] for i in keep]
            else:
                # Preserve the embedding width (and float32 dtype) so a later
                # torch.cat in _upsert_items still lines up.
                self.text_features_cpu = torch.empty(
                    0, self.text_features_cpu.shape[1], dtype=torch.float32
                )
                self.class_ids = []
                self.class_names = []

            self._to_device()
            return removed_count

    def _persist_snapshot_to_hub(self, version: int):
        """Publish the current label set as snapshot ``v{version}`` and repoint
        ``snapshots/latest.json`` at it in a single commit.

        Raises RuntimeError when the repo or write token is unconfigured;
        hub errors propagate to the caller (which keeps the old version).
        """
        if not HF_LABEL_REPO:
            raise RuntimeError("HF_LABEL_REPO not set")
        if not HF_WRITE_TOKEN:
            raise RuntimeError("HF_WRITE_TOKEN not set for publishing")

        emb_path = "/tmp/embeddings.safetensors"
        meta_path = "/tmp/meta.json"
        latest_bytes = io.BytesIO(json.dumps({"version": int(version)}).encode("utf-8"))

        save_file({"embeddings": self.text_features_cpu.to(torch.float32)}, emb_path)
        meta = {
            "items": [{"id": int(i), "name": n} for i, n in zip(self.class_ids, self.class_names)],
            "fingerprint": self.fingerprint,
            "dims": int(self.text_features_cpu.shape[1]),
            "count": int(self.text_features_cpu.shape[0]),
            "version": int(version),
        }
        with open(meta_path, "w", encoding="utf-8") as f:
            json.dump(meta, f)

        ops = [
            CommitOperationAdd(
                path_in_repo=f"snapshots/v{version}/embeddings.safetensors",
                path_or_fileobj=emb_path,
            ),
            CommitOperationAdd(
                path_in_repo=f"snapshots/v{version}/meta.json",
                path_or_fileobj=meta_path,
            ),
            CommitOperationAdd(
                path_in_repo="snapshots/latest.json",
                path_or_fileobj=latest_bytes,
            ),
        ]
        create_commit(
            repo_id=HF_LABEL_REPO,
            repo_type="dataset",
            operations=ops,
            token=HF_WRITE_TOKEN,
            commit_message=f"labels v{version}",
        )

    def _load_snapshot_from_hub_version(self, version: int) -> bool:
        """Replace the in-memory label state with Hub snapshot ``v{version}``.

        Returns False when no repo is configured.  Raises on download failure
        or when the snapshot was produced under a different model/runtime
        fingerprint (its embeddings would not be comparable).
        """
        if not HF_LABEL_REPO:
            return False
        with self._lock:
            emb_p = hf_hub_download(
                HF_LABEL_REPO,
                f"snapshots/v{version}/embeddings.safetensors",
                repo_type="dataset",
                token=HF_READ_TOKEN,
                force_download=True,
            )
            meta_p = hf_hub_download(
                HF_LABEL_REPO,
                f"snapshots/v{version}/meta.json",
                repo_type="dataset",
                token=HF_READ_TOKEN,
                force_download=True,
            )
            # Fix: original used json.load(open(...)) and leaked the handle.
            with open(meta_p, "r", encoding="utf-8") as f:
                meta = json.load(f)
            if meta.get("fingerprint") != self.fingerprint:
                raise RuntimeError("Embedding/model fingerprint mismatch")
            feats = load_file(emb_p)["embeddings"]
            self.text_features_cpu = feats.contiguous()
            self.class_ids = [int(x["id"]) for x in meta.get("items", [])]
            self.class_names = [x["name"] for x in meta.get("items", [])]
            self.labels_version = int(meta.get("version", version))
            self._to_device()
            return True

    def _load_snapshot_from_hub_latest(self) -> bool:
        """Resolve ``snapshots/latest.json`` and load the version it names.

        Returns False (rather than raising) when the pointer file is missing
        or does not name a positive version.
        """
        if not HF_LABEL_REPO:
            return False
        try:
            latest_p = hf_hub_download(
                HF_LABEL_REPO,
                "snapshots/latest.json",
                repo_type="dataset",
                token=HF_READ_TOKEN,
            )
        except Exception:
            return False
        # Fix: original used json.load(open(...)) and leaked the handle.
        with open(latest_p, "r", encoding="utf-8") as f:
            latest = json.load(f)
        ver = int(latest.get("version", 0))
        if ver <= 0:
            return False
        return self._load_snapshot_from_hub_version(ver)
| |
|
| | |
| | |
| | |
| |
|
| | |
| | |
| |
|
| | |
| | |
| |
|
| | |
| | |
| |
|
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| |
|
| | |
| |
|
| |
|
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
| | |
| |
|