compare 4096D cosine top-100 vs t-SNE 2D top-100
Browse files- tsne_circle_eval.py +157 -0
tsne_circle_eval.py
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""对比:4096D cosine 邻居 vs t-SNE 2D 圆内邻居,看哪种判别更准。
|
| 3 |
+
|
| 4 |
+
复用 cache_emb/。t-SNE 用 1000 golden + 200 ruler 一起做(1200 点),保证投影一致。
|
| 5 |
+
"""
|
| 6 |
+
import argparse
|
| 7 |
+
import json
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import pandas as pd
|
| 12 |
+
from sklearn.manifold import TSNE
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
DEFAULTS = dict(
|
| 16 |
+
cache_dir = "cache_emb",
|
| 17 |
+
csv = "/mnt/bn/tns-algo-ue-my/biaowu/aipf_dm_metric/example/yss_ruler_eval/data/aipf_golden_set.csv",
|
| 18 |
+
ruler = "/mnt/bn/tns-algo-ue-my/biaowu/aipf_dm_metric/ranking_moderation/data/dm/youth_sexual_and_physical_abuse_aigt_v009/ranking_bucket/ruler_items.json",
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def load_npy_pair(cache_dir, n_csv, n_ruler, max_length=4096):
|
| 23 |
+
cd = Path(cache_dir)
|
| 24 |
+
csvs = list(cd.glob(f"csv_*_n{n_csv}_L{max_length}.npy"))
|
| 25 |
+
rulers = list(cd.glob(f"ruler_*_n{n_ruler}_L{max_length}.npy"))
|
| 26 |
+
if not csvs or not rulers:
|
| 27 |
+
raise FileNotFoundError(f"找不到缓存。期望 {cd}/csv_*_n{n_csv}_L{max_length}.npy")
|
| 28 |
+
return np.load(csvs[0]), np.load(rulers[0])
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def load_ruler_meta(path):
|
| 32 |
+
with open(path) as f:
|
| 33 |
+
data = json.load(f)
|
| 34 |
+
items = data if isinstance(data, list) else (data.get("items") or data.get("ruler_items") or data.get("data") or [])
|
| 35 |
+
ranks = np.array([int(it["rank"]) for it in items])
|
| 36 |
+
scores = np.array([float(it["score"]) for it in items])
|
| 37 |
+
return ranks, scores
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def metrics(preds, gts):
|
| 41 |
+
tp = int(((preds == 1) & (gts == 1)).sum())
|
| 42 |
+
fp = int(((preds == 1) & (gts == 0)).sum())
|
| 43 |
+
tn = int(((preds == 0) & (gts == 0)).sum())
|
| 44 |
+
fn = int(((preds == 0) & (gts == 1)).sum())
|
| 45 |
+
p = tp/(tp+fp) if tp+fp else 0.0
|
| 46 |
+
r = tp/(tp+fn) if tp+fn else 0.0
|
| 47 |
+
f = 2*p*r/(p+r) if p+r else 0.0
|
| 48 |
+
a = (tp+tn)/len(preds)
|
| 49 |
+
return tp, fp, tn, fn, p, r, f, a
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def best_threshold(scores, gts):
|
| 53 |
+
cands = sorted(set(scores.tolist()))
|
| 54 |
+
best = (-1.0, None, None, None)
|
| 55 |
+
for c in cands:
|
| 56 |
+
preds = (scores >= c).astype(int)
|
| 57 |
+
_, _, _, _, p, r, f, _ = metrics(preds, gts)
|
| 58 |
+
if f > best[0]:
|
| 59 |
+
best = (f, c, p, r)
|
| 60 |
+
return best
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def topk_neighbors(query_xy, ruler_xy, k):
|
| 64 |
+
"""对每个 query,找 ruler 里最近的 k 个,返回 (idx, dist)"""
|
| 65 |
+
# query_xy (Nq, 2), ruler_xy (Nr, 2)
|
| 66 |
+
diffs = query_xy[:, None, :] - ruler_xy[None, :, :]
|
| 67 |
+
dists = np.linalg.norm(diffs, axis=-1) # (Nq, Nr)
|
| 68 |
+
idx = np.argpartition(dists, k - 1, axis=1)[:, :k]
|
| 69 |
+
row = np.arange(len(query_xy))[:, None]
|
| 70 |
+
selected = dists[row, idx]
|
| 71 |
+
order = np.argsort(selected, axis=1)
|
| 72 |
+
return np.take_along_axis(idx, order, axis=1)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def main():
|
| 76 |
+
p = argparse.ArgumentParser()
|
| 77 |
+
p.add_argument("--cache-dir", default=DEFAULTS["cache_dir"])
|
| 78 |
+
p.add_argument("--csv", default=DEFAULTS["csv"])
|
| 79 |
+
p.add_argument("--ruler", default=DEFAULTS["ruler"])
|
| 80 |
+
p.add_argument("--positive-label", default="Y")
|
| 81 |
+
p.add_argument("--boundary-rank", type=int, default=106)
|
| 82 |
+
p.add_argument("--max-length", type=int, default=4096)
|
| 83 |
+
p.add_argument("--perplexity", type=float, default=30.0)
|
| 84 |
+
p.add_argument("--k", type=int, default=100)
|
| 85 |
+
p.add_argument("--seed", type=int, default=42)
|
| 86 |
+
args = p.parse_args()
|
| 87 |
+
|
| 88 |
+
print("[1] load")
|
| 89 |
+
df = pd.read_csv(args.csv, keep_default_na=False)
|
| 90 |
+
gts = df["label"].astype(str).str.upper().eq(args.positive_label.upper()).astype(int).values
|
| 91 |
+
ruler_rank, ruler_score = load_ruler_meta(args.ruler)
|
| 92 |
+
n_csv, n_ruler = len(gts), len(ruler_rank)
|
| 93 |
+
csv_emb, ruler_emb = load_npy_pair(args.cache_dir, n_csv, n_ruler, args.max_length)
|
| 94 |
+
|
| 95 |
+
K = args.k
|
| 96 |
+
methods = {}
|
| 97 |
+
|
| 98 |
+
# ---- baseline: 4096D cosine ----
|
| 99 |
+
print(f"[2] baseline: 4096D cosine top-{K}")
|
| 100 |
+
sims = csv_emb @ ruler_emb.T
|
| 101 |
+
top_idx = np.argpartition(-sims, K-1, axis=1)[:, :K]
|
| 102 |
+
row = np.arange(n_csv)[:, None]
|
| 103 |
+
top_sims = sims[row, top_idx]
|
| 104 |
+
top_score_4096 = ruler_score[top_idx]
|
| 105 |
+
raw_w = (top_sims * top_score_4096).sum(axis=1) / np.maximum(top_sims.sum(axis=1), 1e-12)
|
| 106 |
+
raw_mean = top_score_4096.mean(axis=1)
|
| 107 |
+
raw_vote = (ruler_rank[top_idx] < args.boundary_rank).sum(axis=1)
|
| 108 |
+
methods["4096D cosine | weighted_score"] = raw_w
|
| 109 |
+
methods["4096D cosine | mean(score)"] = raw_mean
|
| 110 |
+
methods["4096D cosine | vote_count"] = raw_vote.astype(float)
|
| 111 |
+
|
| 112 |
+
# ---- t-SNE 2D ----
|
| 113 |
+
print(f"[3] t-SNE on 1200 points (perplexity={args.perplexity})")
|
| 114 |
+
all_emb = np.vstack([csv_emb, ruler_emb])
|
| 115 |
+
tsne = TSNE(n_components=2, perplexity=args.perplexity,
|
| 116 |
+
init="pca", random_state=args.seed,
|
| 117 |
+
metric="cosine", learning_rate="auto")
|
| 118 |
+
xy = tsne.fit_transform(all_emb)
|
| 119 |
+
csv_xy, ruler_xy = xy[:n_csv], xy[n_csv:]
|
| 120 |
+
|
| 121 |
+
# 2D top-K(等价于"圆扩张到正好包含 100 个 ruler")
|
| 122 |
+
print(f"[4] 2D Euclidean top-{K} (in t-SNE space)")
|
| 123 |
+
top_idx_2d = topk_neighbors(csv_xy, ruler_xy, K)
|
| 124 |
+
top_score_2d = ruler_score[top_idx_2d]
|
| 125 |
+
rank_2d = ruler_rank[top_idx_2d]
|
| 126 |
+
methods["t-SNE 2D | mean(score)"] = top_score_2d.mean(axis=1)
|
| 127 |
+
methods["t-SNE 2D | vote_count"] = (rank_2d < args.boundary_rank).sum(axis=1).astype(float)
|
| 128 |
+
# weighted by 1/dist?也试一下
|
| 129 |
+
diffs = csv_xy[:, None, :] - ruler_xy[None, :, :]
|
| 130 |
+
dists2d = np.linalg.norm(diffs, axis=-1)
|
| 131 |
+
selected_dist = np.take_along_axis(dists2d, top_idx_2d, axis=1)
|
| 132 |
+
weights = 1.0 / (selected_dist + 1e-6)
|
| 133 |
+
weighted_2d = (weights * top_score_2d).sum(axis=1) / weights.sum(axis=1)
|
| 134 |
+
methods["t-SNE 2D | inv_dist weighted"] = weighted_2d
|
| 135 |
+
|
| 136 |
+
# ---- 邻居重叠率 ----
|
| 137 |
+
overlap = []
|
| 138 |
+
for i in range(n_csv):
|
| 139 |
+
a = set(top_idx[i].tolist())
|
| 140 |
+
b = set(top_idx_2d[i].tolist())
|
| 141 |
+
overlap.append(len(a & b) / K)
|
| 142 |
+
print(f"\n[5] 邻居重叠率(4096D vs 2D 各取 top-{K}):")
|
| 143 |
+
print(f" 平均 = {np.mean(overlap):.2%}")
|
| 144 |
+
print(f" 中位数 = {np.median(overlap):.2%}")
|
| 145 |
+
print(f" p10 / p90 = {np.percentile(overlap, 10):.2%} / {np.percentile(overlap, 90):.2%}")
|
| 146 |
+
|
| 147 |
+
# ---- 各方法 best F1 ----
|
| 148 |
+
print(f"\n[6] best F1 by sweeping threshold (K={K})")
|
| 149 |
+
print(f"{'method':<35}{'F1':>9}{'thr':>10}{'P':>9}{'R':>9}")
|
| 150 |
+
print("-" * 75)
|
| 151 |
+
for name, scores in methods.items():
|
| 152 |
+
f1, thr, prec, rec = best_threshold(scores, gts)
|
| 153 |
+
print(f"{name:<35}{f1:>9.4f}{thr:>10.4f}{prec:>9.4f}{rec:>9.4f}")
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
if __name__ == "__main__":
|
| 157 |
+
main()
|