Spaces:
Sleeping
Sleeping
| """GuardLLM - Precompute Embeddings & t-SNE (resumable, MiniLM-based). | |
| Uses sentence-transformers/all-MiniLM-L6-v2 (22M params) to compute | |
| embeddings for t-SNE visualization. The downstream risk classifier | |
| (Llama Prompt Guard 2) is *not* loaded here - it is loaded by the | |
| Gradio app on-demand when a user clicks a point. | |
| """ | |
| import sys, os, json, logging, time | |
| from pathlib import Path | |
| import numpy as np | |
| import torch | |
# --- Logging ---------------------------------------------------------------
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
)
logger = logging.getLogger("precompute")


def _env_int(name, default):
    """Return environment variable *name* parsed as int, or *default* if unset."""
    return int(os.environ.get(name, str(default)))


# --- Cache layout (everything lives next to this script) -------------------
CACHE_DIR = Path(__file__).parent / "cache"
CACHE_FILE = CACHE_DIR / "embeddings_tsne.npz"   # final embeddings + 2-D coords
META_FILE = CACHE_DIR / "metadata.json"          # per-point metadata
SAMPLES_FILE = CACHE_DIR / "samples.json"        # flattened dataset rows
# NEW folder so chunks written by an older run of this script don't collide.
EMB_CHUNKS_DIR = CACHE_DIR / "emb_chunks_mini"

# --- Model / dataset identifiers -------------------------------------------
EMB_MODEL_ID = "sentence-transformers/all-MiniLM-L6-v2"
DATASET_ID = "neuralchemy/Prompt-injection-dataset"
DATASET_CONFIG = "core"

# --- Tunables (overridable through the environment) ------------------------
BATCH_SIZE = _env_int("BATCH_SIZE", 32)    # texts per embedding batch / chunk file
MAX_LENGTH = _env_int("MAX_LENGTH", 256)   # tokenizer truncation length
TSNE_PERPLEXITY = 30
TSNE_SEED = 42
# 0 (or unset) becomes None, which disables subsampling in prepare_samples().
SAMPLE_SIZE = _env_int("SAMPLE_SIZE", 0) or None
# Seconds of embedding work allowed per run before embed_chunked() stops.
TIME_BUDGET = _env_int("TIME_BUDGET", 35)
# One of "status", "download", "embed", "tsne" or "auto" (see main()).
STAGE = os.environ.get("STAGE", "auto")
def prepare_samples():
    """Return the dataset samples, building ``samples.json`` on first use.

    If ``SAMPLES_FILE`` already exists it is loaded and returned as-is
    (``SAMPLE_SIZE`` is not re-applied to an existing file).  Otherwise the
    ``DATASET_CONFIG`` configuration of ``DATASET_ID`` is fetched with
    ``datasets.load_dataset``, the train/validation/test splits that are
    present are flattened into one list, the list is optionally subsampled
    per category, and the result is written to ``SAMPLES_FILE``.

    Returns:
        list[dict]: one dict per sample with the keys ``text``, ``label``,
        ``category``, ``severity``, ``source`` and ``split``.
    """
    if SAMPLES_FILE.exists():
        with open(SAMPLES_FILE, "r", encoding="utf-8") as f:
            cached = json.load(f)
        logger.info("Loaded existing samples.json (%d samples)", len(cached))
        return cached

    # Imported lazily: only needed when the samples file has to be built.
    from datasets import load_dataset

    logger.info("Downloading %s/%s", DATASET_ID, DATASET_CONFIG)
    ds = load_dataset(DATASET_ID, DATASET_CONFIG)

    all_samples = []
    for split_name in ["train", "validation", "test"]:
        if split_name not in ds:
            continue
        for row in ds[split_name]:
            all_samples.append({
                "text": row["text"],
                "label": int(row["label"]),
                "category": row.get("category", "unknown"),
                "severity": row.get("severity", ""),
                "source": row.get("source", ""),
                "split": split_name,
            })
    logger.info("Total %d", len(all_samples))

    if SAMPLE_SIZE and SAMPLE_SIZE < len(all_samples):
        import random

        # A private generator yields the same draws as random.seed(42) on the
        # module-level RNG, without reseeding the process-wide generator.
        rng = random.Random(42)

        by_cat = {}
        for sample in all_samples:
            by_cat.setdefault(sample["category"], []).append(sample)

        total = len(all_samples)
        sampled = []
        for items in by_cat.values():
            # Proportional share per category, but never fewer than one item,
            # so the final size can differ slightly from SAMPLE_SIZE.
            n = max(1, round(len(items) / total * SAMPLE_SIZE))
            sampled.extend(rng.sample(items, min(n, len(items))))
        rng.shuffle(sampled)
        all_samples = sampled
        logger.info("Subsampled to %d", len(all_samples))

    CACHE_DIR.mkdir(parents=True, exist_ok=True)
    # Write to a temp file and rename: the existence of SAMPLES_FILE is what
    # later runs use to skip the download, so it must never be half-written.
    tmp_path = SAMPLES_FILE.with_name(SAMPLES_FILE.name + ".tmp")
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(all_samples, f, ensure_ascii=False)
    os.replace(tmp_path, SAMPLES_FILE)
    return all_samples
def mean_pool(last_hidden, attention_mask):
    """Average ``last_hidden`` over dim 1, weighting each position by the mask.

    ``attention_mask`` is expanded with a trailing axis and cast to float so
    it broadcasts against ``last_hidden``.  The denominator is clamped to
    ``1e-9`` so a row whose mask is all zeros yields zeros instead of NaN.
    """
    weights = attention_mask.unsqueeze(-1).float()
    token_count = weights.sum(dim=1).clamp(min=1e-9)
    weighted_sum = (last_hidden * weights).sum(dim=1)
    return weighted_sum / token_count
def embed_chunked(samples):
    """Embed ``samples`` batch by batch, saving each batch as ``<index>.npy``.

    The function is resumable: batches whose chunk file already exists in
    ``EMB_CHUNKS_DIR`` are skipped, and the loop stops once ``TIME_BUDGET``
    seconds have elapsed since embedding started (model loading is not
    counted).  Chunk ``b`` covers ``samples[b * BATCH_SIZE:(b + 1) * BATCH_SIZE]``,
    so ``BATCH_SIZE`` has to stay the same across resumed runs.

    Args:
        samples: list of dicts, each with a ``"text"`` entry.

    Returns:
        bool: True when every batch has a chunk on disk, False when batches
        remain for a later run.
    """
    EMB_CHUNKS_DIR.mkdir(parents=True, exist_ok=True)
    num_batches = (len(samples) + BATCH_SIZE - 1) // BATCH_SIZE

    # Only purely numeric stems are batch chunks; any other .npy file in the
    # folder is ignored instead of crashing int().
    done = {
        int(p.stem)
        for p in EMB_CHUNKS_DIR.glob("*.npy")
        if p.stem.isdecimal()
    }
    todo = [b for b in range(num_batches) if b not in done]
    logger.info("Batches: total=%d done=%d todo=%d", num_batches, len(done), len(todo))
    if not todo:
        return True

    # Imported lazily so a fully-embedded cache never pays for transformers.
    from transformers import AutoTokenizer, AutoModel

    logger.info("Loading MiniLM model...")
    t0 = time.time()
    tok = AutoTokenizer.from_pretrained(EMB_MODEL_ID)
    mdl = AutoModel.from_pretrained(EMB_MODEL_ID)
    mdl.eval()
    logger.info("Model loaded in %.1fs", time.time() - t0)

    texts = [s["text"] for s in samples]
    start = time.time()
    processed = 0
    for b in todo:
        if time.time() - start > TIME_BUDGET:
            logger.info("Time budget reached after %d batches", processed)
            break
        i = b * BATCH_SIZE
        bt = texts[i:i + BATCH_SIZE]
        inputs = tok(bt, return_tensors="pt", truncation=True, max_length=MAX_LENGTH, padding=True)
        with torch.no_grad():
            out = mdl(**inputs)
            emb = mean_pool(out.last_hidden_state, inputs["attention_mask"])
            emb = torch.nn.functional.normalize(emb, p=2, dim=1)
        emb = emb.cpu().numpy().astype(np.float32)

        # The chunk file doubles as the "batch done" marker, so publish it
        # atomically: a run killed mid-write must not leave a truncated chunk.
        # Saving through a file handle keeps numpy from appending ".npy" to
        # the temp name, which would make it match the "*.npy" glob above.
        final_path = EMB_CHUNKS_DIR / f"{b}.npy"
        tmp_path = EMB_CHUNKS_DIR / f"{b}.npy.tmp"
        with open(tmp_path, "wb") as fh:
            np.save(fh, emb)
        os.replace(tmp_path, final_path)

        processed += 1
        if processed % 10 == 0 or processed == len(todo):
            logger.info("batch %d/%d (this run=%d elapsed=%.1fs)", b+1, num_batches, processed, time.time()-start)

    remaining = len(todo) - processed
    logger.info("This run: %d batches; remaining: %d", processed, remaining)
    return remaining == 0
def assemble_and_tsne(samples):
    """Concatenate the embedding chunks, run 2-D t-SNE and write the cache.

    Loads ``<b>.npy`` from ``EMB_CHUNKS_DIR`` for every batch index implied by
    ``len(samples)`` and ``BATCH_SIZE``, projects the stacked embeddings to
    two dimensions, then writes ``CACHE_FILE`` (arrays ``embeddings`` and
    ``tsne_2d``) and ``META_FILE`` (one metadata dict per sample).

    Args:
        samples: list of dicts with the keys ``text``, ``label``,
            ``category``, ``severity``, ``source`` and ``split``.

    Raises:
        FileNotFoundError: a chunk file is missing (propagated from ``np.load``).
        ValueError: there is nothing to embed, or the number of embedding
            rows does not match ``len(samples)``.
    """
    from sklearn.manifold import TSNE

    num_batches = (len(samples) + BATCH_SIZE - 1) // BATCH_SIZE
    parts = [np.load(EMB_CHUNKS_DIR / f"{b}.npy") for b in range(num_batches)]
    if not parts:
        raise ValueError("No embedding chunks to assemble (empty sample list)")
    emb = np.concatenate(parts, axis=0)
    logger.info("Embeddings shape %s", emb.shape)

    n = emb.shape[0]
    if n != len(samples):
        # Happens when chunks were produced with a different BATCH_SIZE or a
        # different samples file; writing the cache would misalign points
        # and metadata.
        raise ValueError(
            f"Embedding rows ({n}) do not match sample count ({len(samples)}); "
            f"clear {EMB_CHUNKS_DIR} and re-run the embed stage"
        )

    # scikit-learn requires perplexity < n_samples, so cap it at n - 1 for
    # small inputs (a floor of 5 would be invalid when n <= 5).
    perp = min(TSNE_PERPLEXITY, max(1, n - 1))
    logger.info("t-SNE perp=%d...", perp)
    t0 = time.time()
    tsne_kwargs = dict(
        n_components=2,
        perplexity=perp,
        random_state=TSNE_SEED,
        learning_rate="auto",
        init="pca",
    )
    try:
        tsne = TSNE(max_iter=1000, **tsne_kwargs)
    except TypeError:
        # Older scikit-learn releases call the iteration count ``n_iter``.
        tsne = TSNE(n_iter=1000, **tsne_kwargs)
    coords = tsne.fit_transform(emb)
    logger.info("t-SNE done %.1fs", time.time() - t0)

    # Both outputs are written to a temp file and renamed, because their mere
    # existence is what status() reports as "final cache complete".
    tmp_cache = CACHE_FILE.with_name(CACHE_FILE.name + ".tmp")
    with open(tmp_cache, "wb") as fh:
        np.savez_compressed(fh, embeddings=emb, tsne_2d=coords)
    os.replace(tmp_cache, CACHE_FILE)

    meta = [
        {
            "text": s["text"],
            "label": s["label"],
            "category": s["category"],
            "severity": s["severity"],
            "source": s["source"],
            "split": s["split"],
        }
        for s in samples
    ]
    tmp_meta = META_FILE.with_name(META_FILE.name + ".tmp")
    with open(tmp_meta, "w", encoding="utf-8") as f:
        json.dump(meta, f, ensure_ascii=False)
    os.replace(tmp_meta, META_FILE)
    logger.info("Cache complete at %s", CACHE_DIR)
def status():
    """Print a one-line progress summary of the precompute pipeline.

    Reports the number of entries in ``SAMPLES_FILE`` (0 if it is absent),
    how many ``*.npy`` files sit in ``EMB_CHUNKS_DIR`` versus the number of
    batches expected for that sample count, and whether both ``CACHE_FILE``
    and ``META_FILE`` exist.
    """
    n_samples = 0
    if SAMPLES_FILE.exists():
        with open(SAMPLES_FILE, "r", encoding="utf-8") as f:
            n_samples = len(json.load(f))

    if EMB_CHUNKS_DIR.exists():
        n_done = sum(1 for _ in EMB_CHUNKS_DIR.glob("*.npy"))
    else:
        n_done = 0

    if n_samples:
        n_batches = (n_samples + BATCH_SIZE - 1) // BATCH_SIZE
    else:
        n_batches = 0

    cache_done = all(path.exists() for path in (CACHE_FILE, META_FILE))
    print(f"samples={n_samples} batches_done={n_done}/{n_batches} final_cache={cache_done}")
def main():
    """Run the pipeline stage selected by ``STAGE``.

    * ``status``   - print progress and exit.
    * ``download`` - build/load the samples file and exit.
    * ``embed``    - load the samples file and embed within the time budget.
    * ``tsne``     - load the samples file, assemble chunks and run t-SNE.
    * ``auto``     - download, embed, and run t-SNE only once every batch
      has been embedded.

    Any other value loads the samples file and then does nothing.
    """
    stage = STAGE
    if stage == "status":
        status()
        return

    if stage in {"download", "auto"}:
        samples = prepare_samples()
    else:
        # Every other stage expects the samples file to exist already.
        with open(SAMPLES_FILE, "r", encoding="utf-8") as f:
            samples = json.load(f)
    if stage == "download":
        return

    run_all = stage == "auto"
    if run_all or stage == "embed":
        finished = embed_chunked(samples)
        # In auto mode only continue to t-SNE once all batches are embedded.
        if not run_all or not finished:
            return

    if run_all or stage == "tsne":
        assemble_and_tsne(samples)


if __name__ == "__main__":
    main()