Spaces:
Sleeping
Sleeping
File size: 5,679 Bytes
b786614 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 | """
profiler.py — Offline attention pattern collection for any HuggingFace model.
Run once per model family to collect per-head attention distributions.
The resulting patterns are then clustered into prototypes (see prototypes.py)
which drive the O(n) eviction policy at inference time.
Usage:
from proactive_cache import profile_model
patterns = profile_model(model, tokenizer, corpus="wikitext", num_docs=50)
"""
from __future__ import annotations

from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from tqdm import tqdm
def profile_model(
    model,
    tokenizer,
    corpus: Union[str, List[str]] = "wikitext",
    num_docs: int = 50,
    seq_len: int = 512,
    output_attentions: bool = True,
) -> List[Dict[Tuple[int, int], np.ndarray]]:
    """
    Collect per-head attention distributions over a calibration corpus.

    Args:
        model: A HuggingFace CausalLM model (any architecture). It is put into
            eval mode as a side effect.
        tokenizer: Corresponding tokenizer.
        corpus: Either a dataset name ("wikitext", "pg19", or another
            HuggingFace dataset id) or a list of raw text strings to profile on.
        num_docs: Maximum number of documents to sample for profiling.
        seq_len: Sequence length for profiling chunks; longer documents are
            truncated to this many tokens.
        output_attentions: Forwarded to the model's forward call. Patterns are
            computed from the returned attention weights, so this must be True
            in practice: if the model returns no attentions, RuntimeError is
            raised.

    Returns:
        patterns: A list with one dict per profiled document (documents that
        tokenize to fewer than 32 tokens are skipped). Each dict maps
        ``(layer_idx, head_idx) → np.ndarray`` of shape ``(seq_len,)`` and
        dtype float32: the mean attention received at each key position,
        zero-padded beyond the document's token count.

    Raises:
        RuntimeError: If the model returns no attention weights, or if no
            document in the corpus is long enough to profile.
    """
    model.eval()
    device = next(model.parameters()).device

    # ── Load corpus ───────────────────────────────────────────────────────────
    texts = _load_corpus(corpus, num_docs)
    print(f"[ProactiveCache] Profiling on {len(texts)} documents, seq_len={seq_len}")

    all_patterns: List[Dict[Tuple[int, int], np.ndarray]] = []
    for text in tqdm(texts, desc="Profiling attention patterns"):
        enc = tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=seq_len,
        )
        input_ids = enc["input_ids"].to(device)
        # Skip documents too short to give a useful positional profile.
        if input_ids.shape[1] < 32:
            continue

        with torch.no_grad():
            out = model(
                input_ids,
                output_attentions=output_attentions,
                use_cache=False,
            )
            if out.attentions is None:
                raise RuntimeError(
                    "Model did not return attention weights. "
                    "Ensure the model config has `output_attentions=True` support, "
                    "or set output_attentions=True in the model config."
                )
            all_patterns.append(_received_attention(out.attentions, seq_len))
        # Release this document's attention tensors before the next forward
        # pass so two documents' worth of attention maps never coexist.
        del out

    if not all_patterns:
        raise RuntimeError("No valid documents found in corpus. Try increasing num_docs.")
    print(f"[ProactiveCache] Profiled {len(all_patterns)} documents across "
          f"{len(all_patterns[0])} (layer, head) pairs.")
    return all_patterns


def _received_attention(attentions, seq_len: int) -> Dict[Tuple[int, int], np.ndarray]:
    """
    Reduce one document's attention maps to per-(layer, head) received-attention vectors.

    ``attentions`` is the per-layer sequence returned by the model when
    ``output_attentions=True``. Each element is indexed below as
    ``(batch=1, num_heads, query_len, key_len)``.

    Returns a dict mapping ``(layer_idx, head_idx)`` to a float32 array of
    length ``seq_len``. Each array holds the mean over query rows for every
    key position, zero-padded (or truncated) to ``seq_len``.
    """
    doc_pattern: Dict[Tuple[int, int], np.ndarray] = {}
    for layer_idx, attn in enumerate(attentions):
        # Average over the query axis on the model's device. This means only a
        # (num_heads, key_len) array is copied to host memory, rather than the
        # full (num_heads, query_len, key_len) matrix. Every query row counts in
        # the denominator; for a causal model that includes rows where the mask
        # zeroes the entry (behaviour kept as before).
        received = attn[0].float().mean(dim=1).cpu().numpy()
        num_heads, slen = received.shape
        keep = min(slen, seq_len)
        for head_idx in range(num_heads):
            padded = np.zeros(seq_len, dtype=np.float32)
            padded[:keep] = received[head_idx, :keep]
            doc_pattern[(layer_idx, head_idx)] = padded
    return doc_pattern
def _load_corpus(corpus: Union[str, List[str]], num_docs: int) -> List[str]:
    """
    Resolve ``corpus`` into at most ``num_docs`` raw text strings.

    A list (or tuple) of strings is used directly, keeping its first
    ``num_docs`` items. The names "wikitext" and "pg19" dispatch to dedicated
    loaders. Any other string is treated as a HuggingFace dataset name.
    A non-positive ``num_docs`` yields an empty list.
    """
    if num_docs <= 0:
        # Guard before slicing: a negative bound such as corpus[:-1] would
        # silently return all but the last items instead of none.
        return []
    if isinstance(corpus, (list, tuple)):
        return list(corpus[:num_docs])
    if corpus == "wikitext":
        return _load_wikitext(num_docs)
    if corpus == "pg19":
        return _load_pg19(num_docs)
    # Any other string: try to load it as a HuggingFace dataset name.
    return _load_hf_dataset(corpus, num_docs)
def _load_wikitext(num_docs: int) -> List[str]:
    """
    Stream WikiText-103 (``wikitext-103-v1``, validation split) into documents.

    Non-empty stripped lines are joined with single spaces until the
    accumulated text exceeds 2000 characters. That text is then emitted as one
    document. Trailing text that never reaches the threshold is discarded.

    Returns at most ``num_docs`` documents (fewer if the split runs out), and
    an empty list without touching the dataset when ``num_docs <= 0``.
    """
    if num_docs <= 0:
        # Avoid opening a remote stream when nothing is requested.
        return []
    from datasets import load_dataset

    ds = load_dataset("wikitext", "wikitext-103-v1", split="validation", streaming=True)
    texts: List[str] = []
    current = ""
    for item in ds:
        line = item["text"].strip()
        if not line:
            continue
        current += " " + line
        if len(current) > 2000:
            texts.append(current.strip())
            current = ""
            if len(texts) >= num_docs:
                break
    return texts
def _load_pg19(num_docs: int) -> List[str]:
    """
    Stream the ``emozilla/pg19`` dataset (test split) as profiling documents.

    Records whose ``text`` field is 500 characters or shorter (or missing) are
    skipped. Each kept text is truncated to its first 4000 characters.

    Returns at most ``num_docs`` documents (fewer if the split runs out), and
    an empty list without touching the dataset when ``num_docs <= 0``.
    """
    if num_docs <= 0:
        # Avoid opening a remote stream when nothing is requested.
        return []
    from datasets import load_dataset

    ds = load_dataset("emozilla/pg19", split="test", streaming=True)
    texts: List[str] = []
    for item in ds:
        text = item.get("text", "")
        if len(text) > 500:
            texts.append(text[:4000])
            if len(texts) >= num_docs:
                break
    return texts
def _load_hf_dataset(name: str, num_docs: int) -> List[str]:
    """
    Stream the ``train`` split of HuggingFace dataset ``name`` and collect texts.

    For each record, the document is the first of the fields ``text``,
    ``content``, ``body`` or ``sentence`` that holds a string longer than 100
    characters. Records with no such field are skipped.

    Returns at most ``num_docs`` texts (fewer if the split runs out), and an
    empty list without touching the dataset when ``num_docs <= 0``.

    Raises:
        ValueError: If loading or iterating the dataset raises any exception;
            the original exception is chained as ``__cause__``.
    """
    if num_docs <= 0:
        # Avoid opening a remote stream when nothing is requested.
        return []
    from datasets import load_dataset

    # Common text field names, in priority order.
    candidate_fields = ("text", "content", "body", "sentence")
    try:
        ds = load_dataset(name, split="train", streaming=True)
        texts: List[str] = []
        for item in ds:
            doc = next(
                (
                    item[field]
                    for field in candidate_fields
                    if field in item and isinstance(item[field], str) and len(item[field]) > 100
                ),
                None,
            )
            if doc is None:
                continue
            texts.append(doc)
            if len(texts) >= num_docs:
                break
        return texts
    except Exception as e:
        # Preserve the underlying failure for debugging via exception chaining.
        raise ValueError(f"Could not load corpus '{name}': {e}") from e
|