| | import numpy as np |
| | from typing import Dict, List, TYPE_CHECKING |
| |
|
| | import torch |
| | from sklearn.cluster import KMeans |
| |
|
| | if TYPE_CHECKING: |
| | from .model import BitTransformerLM |
| |
|
| |
|
class TelemetrySynthesizer:
    """Analyze telemetry batches and cluster activation patterns.

    Each telemetry dict is reduced to a flat feature vector of
    (mean, variance, attention entropy) triples, and the resulting vectors
    are grouped with KMeans.

    Parameters
    ----------
    n_clusters: int
        Number of clusters requested; also the length of the
        ``representatives`` list returned by :meth:`synthesize`.
    """

    def __init__(self, n_clusters: int = 2) -> None:
        self.n_clusters = n_clusters

    def _summary(self, telemetry: Dict[str, List[torch.Tensor]]) -> np.ndarray:
        """Compute activation/attention summaries for a single telemetry dict.

        Parameters
        ----------
        telemetry: dict
            Must provide the keys ``"activations"`` and ``"attention_maps"``,
            each an iterable of tensors. The two are walked in lockstep, so
            surplus entries in the longer one are ignored.

        Returns
        -------
        np.ndarray
            Flat array holding ``[mean, variance, entropy]`` for every
            activation/attention pair, in order.
        """
        acts = telemetry["activations"]
        attn = telemetry["attention_maps"]
        summaries = []
        for a, m in zip(acts, attn):
            mean = a.mean().item()
            # Population variance (unbiased=False) stays finite for
            # single-element tensors.
            var = a.var(unbiased=False).item()
            # NOTE(review): softmax is applied over the last dim here; if the
            # attention maps are already normalized probabilities this
            # re-normalizes them -- confirm against the model's telemetry.
            prob = m.softmax(-1)
            # Clamp before log so zero probabilities do not produce -inf.
            entropy = -(prob * prob.clamp_min(1e-9).log()).sum(-1).mean().item()
            summaries.append([mean, var, entropy])
        return np.array(summaries).ravel()

    def synthesize(
        self, telemetries: List[Dict[str, List[torch.Tensor]]], bit_seqs: torch.Tensor
    ) -> Dict[str, List]:
        """Cluster telemetry summaries and return cluster info.

        Parameters
        ----------
        telemetries: list[dict]
            One telemetry dict per sequence, in the same order as the rows
            of ``bit_seqs``.
        bit_seqs: torch.Tensor
            Indexed by sample position to fetch representative sequences.

        Returns
        -------
        dict
            ``"cluster_assignments"``: one cluster label per telemetry dict.
            ``"representatives"``: ``n_clusters`` entries; each is the first
            sequence assigned to that cluster, or ``[]`` when the cluster has
            no members.
        """
        if not telemetries:
            # Nothing to cluster: np.stack would raise on an empty list, so
            # report every cluster as empty instead.
            return {
                "cluster_assignments": [],
                "representatives": [[] for _ in range(self.n_clusters)],
            }

        data = np.stack([self._summary(t) for t in telemetries])
        # KMeans rejects n_clusters > n_samples; fit with as many clusters as
        # the data supports and let the remaining ones come back empty.
        effective_clusters = min(self.n_clusters, len(data))
        km = KMeans(n_clusters=effective_clusters, n_init=1)
        labels = km.fit_predict(data)

        representatives: List[List[int]] = []
        for c in range(self.n_clusters):
            idx = np.where(labels == c)[0]
            if len(idx) > 0:
                representatives.append(bit_seqs[idx[0]].tolist())
            else:
                representatives.append([])
        return {"cluster_assignments": labels.tolist(), "representatives": representatives}

    def cluster_sequences(
        self, model: "BitTransformerLM", bit_seqs: torch.Tensor
    ) -> List[List[int]]:
        """Run the model to gather telemetry and return representative sequences.

        Parameters
        ----------
        model: BitTransformerLM
            Model used to compute telemetry for each sequence. It is called
            on one sequence at a time (with a leading batch dimension added)
            and must return a pair whose second item is the telemetry dict.
        bit_seqs: torch.Tensor
            Tensor containing one bit sequence per row.

        Returns
        -------
        list[list[int]]
            Representative sequences chosen from KMeans clusters.
        """
        telemetries: List[Dict[str, List[torch.Tensor]]] = []
        # Telemetry collection is inference-only; skip autograd bookkeeping.
        with torch.no_grad():
            for seq in bit_seqs:
                _, tele = model(seq.unsqueeze(0))
                telemetries.append(tele)
        info = self.synthesize(telemetries, bit_seqs)
        return info["representatives"]
| |
|
| |
|
def detect_metric_drift(
    metrics_log: Dict[str, List[float]],
    window: int = 10,
    threshold: float = 0.2,
) -> Dict[str, bool]:
    """Detect metric drift between consecutive windows.

    For each metric, the mean of the last ``window`` values is compared with
    the mean of the ``window`` values immediately before them.

    Args:
        metrics_log: History of scalar metrics keyed by name.
        window: Number of recent steps to compare. Non-positive values
            disable detection (every metric reports ``False``).
        threshold: Absolute difference required to flag drift; the
            difference must strictly exceed it.

    Returns:
        Dictionary mapping metric keys to a boolean drift indicator. Metrics
        with fewer than ``2 * window`` recorded values report ``False``.
    """
    drift: Dict[str, bool] = {}
    for key, values in metrics_log.items():
        # A non-positive window leaves nothing to average (the "previous"
        # slice is empty), and short histories cannot fill both windows.
        if window <= 0 or len(values) < window * 2:
            drift[key] = False
            continue
        recent = np.mean(values[-window:])
        prev = np.mean(values[-2 * window : -window])
        # The comparison yields numpy.bool_; convert so the result matches
        # the annotated type and stays JSON-serializable.
        drift[key] = bool(abs(recent - prev) > threshold)
    return drift
| |
|