# GuardLLM / precompute.py
# AlephBeth-AI's picture
# Upload precompute.py with huggingface_hub
# 79046c2 verified
# Raw
# History Blame Contribute Delete
# 7.18 kB
"""GuardLLM - Precompute Embeddings & t-SNE (resumable, MiniLM-based).
Uses sentence-transformers/all-MiniLM-L6-v2 (22M params) to compute
embeddings for t-SNE visualization. The downstream risk classifier
(Llama Prompt Guard 2) is *not* loaded here - it is loaded by the
Gradio app on-demand when a user clicks a point.
"""
import sys, os, json, logging, time
from pathlib import Path
import numpy as np
import torch
# Configure root logging once; this module logs through the "precompute" logger.
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger("precompute")

# Cache layout: everything is stored in a "cache" directory next to this file.
CACHE_DIR = Path(__file__).parent / "cache"
CACHE_FILE = CACHE_DIR / "embeddings_tsne.npz"  # final arrays ("embeddings", "tsne_2d")
META_FILE = CACHE_DIR / "metadata.json"  # per-sample metadata written with CACHE_FILE
SAMPLES_FILE = CACHE_DIR / "samples.json"  # flattened (optionally subsampled) dataset
EMB_CHUNKS_DIR = CACHE_DIR / "emb_chunks_mini"  # dedicated folder so chunks from an older layout don't collide

# Embedding model and dataset identifiers.
EMB_MODEL_ID = "sentence-transformers/all-MiniLM-L6-v2"
DATASET_ID = "neuralchemy/Prompt-injection-dataset"
DATASET_CONFIG = "core"

# Texts per embedding batch (env override: BATCH_SIZE). Chunk files are
# numbered by batch index, so this value also fixes the on-disk chunk layout.
BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "32"))
# Tokenizer truncation length (env override: MAX_LENGTH).
MAX_LENGTH = int(os.environ.get("MAX_LENGTH", "256"))

# t-SNE settings: perplexity upper bound and random seed.
TSNE_PERPLEXITY = 30
TSNE_SEED = 42

# Target number of samples after subsampling (env override: SAMPLE_SIZE).
# Unset or "0" becomes None, which disables subsampling.
SAMPLE_SIZE = int(os.environ.get("SAMPLE_SIZE", "0")) or None
# Seconds the embedding loop may run before it stops starting new batches
# (env override: TIME_BUDGET).
TIME_BUDGET = int(os.environ.get("TIME_BUDGET", "35"))
# Pipeline stage to run (env override: STAGE). main() recognises
# "status", "download", "embed", "tsne" and "auto".
STAGE = os.environ.get("STAGE", "auto")
def prepare_samples():
    """Return the list of sample dicts, building and caching it on first use.

    If ``SAMPLES_FILE`` already exists it is loaded and returned unchanged.
    Otherwise the ``DATASET_CONFIG`` configuration of ``DATASET_ID`` is
    downloaded, the ``train`` / ``validation`` / ``test`` splits that are
    present are flattened into one list, the list is optionally subsampled to
    roughly ``SAMPLE_SIZE`` items (proportionally per category, keeping at
    least one item from every category), and the result is written to
    ``SAMPLES_FILE``.

    Returns:
        list[dict]: one dict per sample with the keys ``text``, ``label``,
        ``category``, ``severity``, ``source`` and ``split``.
    """
    if SAMPLES_FILE.exists():
        with open(SAMPLES_FILE, "r", encoding="utf-8") as f:
            cached = json.load(f)
        logger.info("Loaded existing samples.json (%d samples)", len(cached))
        return cached

    # Imported lazily: when the cached file exists this import is never reached.
    from datasets import load_dataset

    logger.info("Downloading %s/%s", DATASET_ID, DATASET_CONFIG)
    ds = load_dataset(DATASET_ID, DATASET_CONFIG)

    all_samples = []
    for split_name in ["train", "validation", "test"]:
        if split_name not in ds:
            continue
        for row in ds[split_name]:
            all_samples.append({
                "text": row["text"],
                "label": int(row["label"]),
                "category": row.get("category", "unknown"),
                "severity": row.get("severity", ""),
                "source": row.get("source", ""),
                "split": split_name,
            })
    logger.info("Total %d", len(all_samples))

    if SAMPLE_SIZE and SAMPLE_SIZE < len(all_samples):
        import random

        # A private generator seeded with 42 produces the same draw as seeding
        # the module-level generator, without resetting global random state.
        rng = random.Random(42)
        by_cat = {}
        for sample in all_samples:
            by_cat.setdefault(sample["category"], []).append(sample)
        total = len(all_samples)
        sampled = []
        for items in by_cat.values():
            # Proportional share of SAMPLE_SIZE, but never drop a category.
            n = max(1, round(len(items) / total * SAMPLE_SIZE))
            sampled.extend(rng.sample(items, min(n, len(items))))
        rng.shuffle(sampled)
        all_samples = sampled
        logger.info("Subsampled to %d", len(all_samples))

    CACHE_DIR.mkdir(parents=True, exist_ok=True)
    # Write to a temporary file and rename it into place: the existence of
    # SAMPLES_FILE short-circuits every later run, so a truncated file left by
    # an interrupted write must never appear under the final name.
    tmp_file = SAMPLES_FILE.with_name(SAMPLES_FILE.name + ".tmp")
    with open(tmp_file, "w", encoding="utf-8") as f:
        json.dump(all_samples, f, ensure_ascii=False)
    tmp_file.replace(SAMPLES_FILE)
    return all_samples
def mean_pool(last_hidden, attention_mask):
    """Average token vectors over dim 1, ignoring positions the mask zeroes out.

    Args:
        last_hidden: per-token hidden states; the reduction runs over dim 1
            (assumed layout: batch, sequence, hidden).
        attention_mask: mask that is cast to float and given a trailing axis
            before being multiplied with ``last_hidden`` (assumed layout:
            batch, sequence; non-zero marks a real token).

    Returns:
        Tensor of mask-weighted means with dim 1 reduced away.
    """
    weights = attention_mask.float().unsqueeze(-1)
    # Clamp the per-row weight total so an all-zero mask yields zeros, not NaN.
    token_count = weights.sum(dim=1).clamp(min=1e-9)
    weighted_total = (last_hidden * weights).sum(dim=1)
    return weighted_total / token_count
def embed_chunked(samples):
    """Embed ``samples`` batch by batch, saving one ``.npy`` chunk per batch.

    Batch ``b`` covers ``samples[b * BATCH_SIZE:(b + 1) * BATCH_SIZE]`` and is
    stored as ``EMB_CHUNKS_DIR / "<b>.npy"``. Batches whose chunk file already
    exists are skipped, so the function can be called repeatedly until all
    batches are done. No new batch is started once ``TIME_BUDGET`` seconds
    have elapsed since the embedding loop began (model loading is not counted).

    NOTE: chunk indices are only meaningful for the ``BATCH_SIZE`` they were
    produced with; changing it between runs requires clearing
    ``EMB_CHUNKS_DIR``.

    Args:
        samples: list of dicts, each providing a ``"text"`` entry.

    Returns:
        bool: True when every batch has a chunk on disk, False when batches
        remain for a later run.
    """
    EMB_CHUNKS_DIR.mkdir(parents=True, exist_ok=True)
    num_batches = (len(samples) + BATCH_SIZE - 1) // BATCH_SIZE
    # Probe exactly the filenames the assembly step will load. Stray files
    # (non-numeric names, temp files, out-of-range indices) are ignored
    # instead of crashing int() or inflating the "done" count.
    done = {b for b in range(num_batches) if (EMB_CHUNKS_DIR / f"{b}.npy").exists()}
    todo = [b for b in range(num_batches) if b not in done]
    logger.info("Batches: total=%d done=%d todo=%d", num_batches, len(done), len(todo))
    if not todo:
        return True

    # Imported lazily so the model stack is only loaded when there is work.
    from transformers import AutoTokenizer, AutoModel

    logger.info("Loading MiniLM model...")
    t0 = time.time()
    tok = AutoTokenizer.from_pretrained(EMB_MODEL_ID)
    mdl = AutoModel.from_pretrained(EMB_MODEL_ID)
    mdl.eval()
    logger.info("Model loaded in %.1fs", time.time() - t0)

    texts = [s["text"] for s in samples]
    start = time.time()
    processed = 0
    for b in todo:
        if time.time() - start > TIME_BUDGET:
            logger.info("Time budget reached after %d batches", processed)
            break
        i = b * BATCH_SIZE
        bt = texts[i:i + BATCH_SIZE]
        inputs = tok(bt, return_tensors="pt", truncation=True, max_length=MAX_LENGTH, padding=True)
        with torch.no_grad():
            out = mdl(**inputs)
        emb = mean_pool(out.last_hidden_state, inputs["attention_mask"])
        emb = torch.nn.functional.normalize(emb, p=2, dim=1)
        emb = emb.cpu().numpy().astype(np.float32)

        # Write to a temp name (not matched by "*.npy") and rename into place,
        # so a run killed mid-write never leaves a partial chunk that a later
        # run would treat as finished.
        final_path = EMB_CHUNKS_DIR / f"{b}.npy"
        tmp_path = EMB_CHUNKS_DIR / f"{b}.npy.tmp"
        with open(tmp_path, "wb") as fh:
            np.save(fh, emb)
        tmp_path.replace(final_path)

        processed += 1
        if processed % 10 == 0 or processed == len(todo):
            logger.info("batch %d/%d (this run=%d elapsed=%.1fs)", b+1, num_batches, processed, time.time()-start)

    remaining = len(todo) - processed
    logger.info("This run: %d batches; remaining: %d", processed, remaining)
    return remaining == 0
def assemble_and_tsne(samples):
    """Concatenate the embedding chunks, run 2-D t-SNE and write the final cache.

    Loads ``EMB_CHUNKS_DIR / "<b>.npy"`` for every batch index implied by
    ``len(samples)`` and ``BATCH_SIZE``, projects the stacked embeddings to two
    dimensions, then writes ``CACHE_FILE`` (arrays ``embeddings`` and
    ``tsne_2d``) and ``META_FILE`` (one metadata dict per sample, in the same
    order as ``samples``).

    Args:
        samples: list of sample dicts with the keys ``text``, ``label``,
            ``category``, ``severity``, ``source`` and ``split``.

    Raises:
        FileNotFoundError: propagated from ``np.load`` if a chunk is missing.
        ValueError: if the number of embedding rows differs from
            ``len(samples)`` (chunks produced for a different sample list or
            batch size).
    """
    from sklearn.manifold import TSNE

    num_batches = (len(samples) + BATCH_SIZE - 1) // BATCH_SIZE
    parts = [np.load(EMB_CHUNKS_DIR / f"{b}.npy") for b in range(num_batches)]
    emb = np.concatenate(parts, axis=0)
    logger.info("Embeddings shape %s", emb.shape)

    n = emb.shape[0]
    if n != len(samples):
        # The metadata below is written in `samples` order; a row-count
        # mismatch would silently pair points with the wrong metadata.
        raise ValueError(
            f"Embedding rows ({n}) do not match number of samples ({len(samples)}); "
            f"clear {EMB_CHUNKS_DIR} and re-run the embed stage"
        )

    # Perplexity must stay below the number of points. Clamp to n - 1 for
    # small inputs (the previous floor of 5 produced perplexity >= n when
    # n <= 5); identical to the old value for n >= 6.
    perp = max(1, min(TSNE_PERPLEXITY, n - 1))
    logger.info("t-SNE perp=%d...", perp)
    t0 = time.time()
    tsne_kwargs = dict(
        n_components=2,
        perplexity=perp,
        random_state=TSNE_SEED,
        learning_rate="auto",
        init="pca",
    )
    try:
        tsne = TSNE(max_iter=1000, **tsne_kwargs)
    except TypeError:
        # Older scikit-learn releases name the iteration count `n_iter`.
        tsne = TSNE(n_iter=1000, **tsne_kwargs)
    coords = tsne.fit_transform(emb)
    logger.info("t-SNE done %.1fs", time.time() - t0)

    np.savez_compressed(CACHE_FILE, embeddings=emb, tsne_2d=coords)
    meta = [
        {
            "text": s["text"],
            "label": s["label"],
            "category": s["category"],
            "severity": s["severity"],
            "source": s["source"],
            "split": s["split"],
        }
        for s in samples
    ]
    with open(META_FILE, "w", encoding="utf-8") as f:
        json.dump(meta, f, ensure_ascii=False)
    logger.info("Cache complete at %s", CACHE_DIR)
def status():
    """Print a one-line summary of pipeline progress to stdout.

    Reports the number of samples in ``SAMPLES_FILE`` (0 if it is absent),
    how many ``*.npy`` files exist in ``EMB_CHUNKS_DIR`` versus the number of
    batches expected for that sample count, and whether both ``CACHE_FILE``
    and ``META_FILE`` exist.
    """
    sample_count = 0
    if SAMPLES_FILE.exists():
        with open(SAMPLES_FILE, "r", encoding="utf-8") as fh:
            sample_count = len(json.load(fh))

    chunks_on_disk = 0
    if EMB_CHUNKS_DIR.exists():
        chunks_on_disk = sum(1 for _ in EMB_CHUNKS_DIR.glob("*.npy"))

    expected_batches = 0
    if sample_count:
        # Ceiling division: a trailing partial batch still counts as a batch.
        expected_batches = (sample_count + BATCH_SIZE - 1) // BATCH_SIZE

    has_final_cache = CACHE_FILE.exists() and META_FILE.exists()
    print(f"samples={sample_count} batches_done={chunks_on_disk}/{expected_batches} final_cache={has_final_cache}")
def main():
    """Run the pipeline stage selected by ``STAGE``.

    * ``status``   - print progress and exit.
    * ``download`` - build/load the samples file and exit.
    * ``embed``    - load the samples file and embed pending batches.
    * ``tsne``     - load the samples file and build the final cache.
    * ``auto``     - download, embed, and (only once every batch is embedded)
      build the final cache.
    """
    stage = STAGE
    if stage == "status":
        status()
        return

    if stage not in ("download", "auto"):
        # Stages that skip the download step read the previously saved samples.
        with open(SAMPLES_FILE, "r", encoding="utf-8") as fh:
            samples = json.load(fh)
    else:
        samples = prepare_samples()
        if stage == "download":
            return

    if stage in ("embed", "auto"):
        finished = embed_chunked(samples)
        # Stop here if embedding is incomplete or was the only requested stage.
        if not finished or stage == "embed":
            return

    if stage in ("tsne", "auto"):
        assemble_and_tsne(samples)
# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()