File size: 5,679 Bytes
b786614
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
"""
profiler.py — Offline attention pattern collection for any HuggingFace model.

Run once per model family to collect per-head attention distributions.
The resulting patterns are then clustered into prototypes (see prototypes.py)
which drive the O(n) eviction policy at inference time.

Usage:
    from proactive_cache import profile_model
    patterns = profile_model(model, tokenizer, corpus="wikitext", num_docs=50)
"""

from __future__ import annotations
import torch
import numpy as np
from typing import Optional, Union, List, Dict
from tqdm import tqdm


def profile_model(
    model,
    tokenizer,
    corpus: Union[str, List[str]] = "wikitext",
    num_docs: int = 50,
    seq_len: int = 512,
    output_attentions: bool = True,
) -> Dict:
    """
    Collect per-head attention distributions over a calibration corpus.

    Args:
        model:             A HuggingFace CausalLM model (any architecture).
        tokenizer:         Corresponding tokenizer.
        corpus:            Either a dataset name ("wikitext", "pg19") or a list
                           of raw text strings to profile on.
        num_docs:          Number of documents to sample for profiling.
        seq_len:           Sequence length for profiling chunks.
        output_attentions: Whether to collect full attention matrices.

    Returns:
        patterns: Dict mapping ``(layer_idx, head_idx) → np.ndarray`` of shape
                  ``(num_docs, seq_len)`` — mean attention received per position.
    """
    model.eval()
    device = next(model.parameters()).device

    # ── Load corpus ───────────────────────────────────────────────────────────
    texts = _load_corpus(corpus, num_docs)
    print(f"[ProactiveCache] Profiling on {len(texts)} documents, seq_len={seq_len}")

    all_patterns: List[Dict] = []

    for text in tqdm(texts, desc="Profiling attention patterns"):
        enc = tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=seq_len,
        )
        input_ids = enc["input_ids"].to(device)
        if input_ids.shape[1] < 32:
            continue

        with torch.no_grad():
            out = model(
                input_ids,
                output_attentions=output_attentions,
                use_cache=False,
            )

        if out.attentions is None:
            raise RuntimeError(
                "Model did not return attention weights. "
                "Ensure the model config has `output_attentions=True` support, "
                "or set output_attentions=True in the model config."
            )

        doc_pattern = {}
        for layer_idx, attn in enumerate(out.attentions):
            # attn: (batch=1, num_heads, seq_len, seq_len)
            attn_np = attn[0].float().cpu().numpy()   # (heads, seq, seq)
            num_heads, slen, _ = attn_np.shape
            for head_idx in range(num_heads):
                # Mean attention received at each position (column-wise mean)
                received = attn_np[head_idx].mean(axis=0)   # (seq_len,)
                # Pad / truncate to fixed seq_len
                padded = np.zeros(seq_len, dtype=np.float32)
                padded[:min(slen, seq_len)] = received[:seq_len]
                doc_pattern[(layer_idx, head_idx)] = padded

        all_patterns.append(doc_pattern)

    if not all_patterns:
        raise RuntimeError("No valid documents found in corpus. Try increasing num_docs.")

    print(f"[ProactiveCache] Profiled {len(all_patterns)} documents across "
          f"{len(all_patterns[0])} (layer, head) pairs.")
    return all_patterns


def _load_corpus(corpus: Union[str, List[str]], num_docs: int) -> List[str]:
    """Load a text corpus for profiling."""
    if isinstance(corpus, list):
        return corpus[:num_docs]

    if corpus == "wikitext":
        return _load_wikitext(num_docs)
    elif corpus == "pg19":
        return _load_pg19(num_docs)
    else:
        # Try to load as a HuggingFace dataset name
        return _load_hf_dataset(corpus, num_docs)


def _load_wikitext(num_docs: int) -> List[str]:
    from datasets import load_dataset
    ds = load_dataset("wikitext", "wikitext-103-v1", split="validation", streaming=True)
    texts, current = [], ""
    for item in ds:
        t = item["text"].strip()
        if t:
            current += " " + t
        if len(current) > 2000:
            texts.append(current.strip())
            current = ""
        if len(texts) >= num_docs:
            break
    return texts


def _load_pg19(num_docs: int) -> List[str]:
    from datasets import load_dataset
    ds = load_dataset("emozilla/pg19", split="test", streaming=True)
    texts = []
    for item in ds:
        text = item.get("text", "")
        if len(text) > 500:
            texts.append(text[:4000])
        if len(texts) >= num_docs:
            break
    return texts


def _load_hf_dataset(name: str, num_docs: int) -> List[str]:
    from datasets import load_dataset
    try:
        ds = load_dataset(name, split="train", streaming=True)
        texts = []
        for item in ds:
            # Try common text field names
            for field in ["text", "content", "body", "sentence"]:
                if field in item and isinstance(item[field], str) and len(item[field]) > 100:
                    texts.append(item[field])
                    break
            if len(texts) >= num_docs:
                break
        return texts
    except Exception as e:
        raise ValueError(f"Could not load corpus '{name}': {e}")