Wendy-Fly commited on
Commit
9d30467
·
verified ·
1 Parent(s): 37c2c5b

Upload batch_top5_match.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. batch_top5_match.py +293 -0
batch_top5_match.py ADDED
@@ -0,0 +1,293 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """把 golden_set.csv (≈1000 条) 全部和 ruler 200 条做 cosine 相似度,
3
+ 每条算 Top-K 最近的 ruler items,并把结果保存到本地。
4
+
5
+ 用法:
6
+ # 默认路径
7
+ python3 batch_top5_match.py
8
+
9
+ # 自定义
10
+ python3 batch_top5_match.py \
11
+ --csv /mnt/.../aipf_golden_set.csv \
12
+ --ruler /mnt/.../ruler_items.json \
13
+ --model /mnt/.../Qwen3-Embedding-8B \
14
+ --output golden_top5.jsonl \
15
+ --top-k 5 \
16
+ --boundary-score 44.72 \
17
+ --cache-dir cache_emb \
18
+ --limit 50 # 先小跑 50 条 sanity check
19
+
20
+ 输出:
21
+ - {output}.jsonl 每行一条样本,含 task_id / label / Top-K 详情 / weighted_score / 预测
22
+ - {output}.summary.csv 按行汇总,便于在 Excel / pandas 里筛
23
+ - cache_emb/*.npy (可选)embedding 缓存,重跑时自动复用
24
+ """
25
+ import argparse
26
+ import json
27
+ import re
28
+ import sys
29
+ import time
30
+ from pathlib import Path
31
+
32
+ import numpy as np
33
+ import pandas as pd
34
+ import torch
35
+ import torch.nn.functional as F
36
+ from torch import Tensor
37
+ from transformers import AutoTokenizer, AutoModel
38
+
39
+
40
+ DEFAULT_MODEL = "/mnt/bn/tns-algo-ue-my/biaowu/WorkSpace/Models/Qwen3-Embedding-8B"
41
+ DEFAULT_RULER = "/mnt/bn/tns-algo-ue-my/biaowu/aipf_dm_metric/ranking_moderation/data/dm/youth_sexual_and_physical_abuse_aigt_v009/ranking_bucket/ruler_items.json"
42
+ DEFAULT_CSV = "/mnt/bn/tns-algo-ue-my/biaowu/aipf_dm_metric/example/yss_ruler_eval/data/aipf_golden_set.csv"
43
+
44
+
45
+ # ---------- model utils ----------
46
+ def last_token_pool(h: Tensor, attn: Tensor) -> Tensor:
47
+ if (attn[:, -1].sum() == attn.shape[0]): # left padding
48
+ return h[:, -1]
49
+ lens = attn.sum(dim=1) - 1
50
+ bsz = h.shape[0]
51
+ return h[torch.arange(bsz, device=h.device), lens]
52
+
53
+
54
+ @torch.no_grad()
55
+ def encode(texts, tokenizer, model, max_length, batch_size, label="encode"):
56
+ embs = []
57
+ n = len(texts)
58
+ t0 = time.time()
59
+ for i in range(0, n, batch_size):
60
+ batch = texts[i:i + batch_size]
61
+ d = tokenizer(batch, padding=True, truncation=True,
62
+ max_length=max_length, return_tensors="pt").to(model.device)
63
+ out = model(**d)
64
+ e = last_token_pool(out.last_hidden_state, d["attention_mask"])
65
+ e = F.normalize(e, p=2, dim=1)
66
+ embs.append(e.cpu().float())
67
+ del out, d, e
68
+ if torch.cuda.is_available():
69
+ torch.cuda.empty_cache()
70
+ done = min(i + batch_size, n)
71
+ if done % (batch_size * 10) == 0 or done == n:
72
+ elapsed = time.time() - t0
73
+ rate = done / max(elapsed, 1e-3)
74
+ eta = (n - done) / max(rate, 1e-3)
75
+ print(f" [{label}] {done}/{n} | {rate:.1f} ex/s | eta {eta:.0f}s", flush=True)
76
+ return torch.cat(embs, dim=0).numpy()
77
+
78
+
79
+ # ---------- data utils ----------
80
+ def load_ruler_items(path):
81
+ with open(path, "r", encoding="utf-8") as f:
82
+ data = json.load(f)
83
+ items = data if isinstance(data, list) else (
84
+ data.get("items") or data.get("ruler_items") or data.get("data") or [])
85
+ out = []
86
+ for it in items:
87
+ inner = it.get("item", {}) if isinstance(it.get("item"), dict) else {}
88
+ conv = inner.get("conv_text") or it.get("conv_text") or ""
89
+ out.append({
90
+ "rank": it.get("rank"),
91
+ "score": float(it.get("score", 0.0)),
92
+ "item_id": str(it.get("item_id")),
93
+ "text": conv,
94
+ })
95
+ return out
96
+
97
+
98
+ _M_PREFIX = re.compile(r"<m\d+>")
99
+
100
+
101
+ def extract_conv(raw):
102
+ """golden_set 的 text 里可能带 alias-age dict 前缀,这里只取 <m0>... 之后的。"""
103
+ if not isinstance(raw, str):
104
+ return ""
105
+ m = _M_PREFIX.search(raw)
106
+ return raw[m.start():] if m else raw.strip()
107
+
108
+
109
+ def load_csv(path, text_col, id_col, label_col, limit=None):
110
+ df = pd.read_csv(path, keep_default_na=False)
111
+ needed = [c for c in (id_col, label_col) if c not in df.columns]
112
+ if needed:
113
+ raise ValueError(f"missing columns: {needed}; available: {list(df.columns)}")
114
+ if text_col not in df.columns:
115
+ if "conv_text" in df.columns:
116
+ text_col = "conv_text"
117
+ else:
118
+ raise ValueError("no text/conv_text column")
119
+ if limit:
120
+ df = df.head(limit).copy()
121
+ rows = []
122
+ for _, r in df.iterrows():
123
+ rows.append({
124
+ "task_id": str(r[id_col]),
125
+ "label": str(r[label_col]).strip().upper(),
126
+ "raw_text": str(r[text_col]),
127
+ "conv_text": extract_conv(r[text_col]),
128
+ })
129
+ return rows
130
+
131
+
132
+ # ---------- cache ----------
133
+ def cache_path(cache_dir, name, n_items, max_length):
134
+ return Path(cache_dir) / f"{name}_n{n_items}_L{max_length}.npy"
135
+
136
+
137
+ def encode_with_cache(texts, tokenizer, model, *, max_length, batch_size,
138
+ cache_dir, name):
139
+ if cache_dir:
140
+ Path(cache_dir).mkdir(parents=True, exist_ok=True)
141
+ p = cache_path(cache_dir, name, len(texts), max_length)
142
+ if p.exists():
143
+ print(f" [{name}] using cached embeddings: {p}")
144
+ return np.load(p)
145
+ emb = encode(texts, tokenizer, model, max_length, batch_size, label=name)
146
+ if cache_dir:
147
+ np.save(p, emb)
148
+ print(f" [{name}] saved cache: {p}")
149
+ return emb
150
+
151
+
152
+ # ---------- args ----------
153
+ def parse_args():
154
+ p = argparse.ArgumentParser()
155
+ p.add_argument("--csv", default=DEFAULT_CSV)
156
+ p.add_argument("--ruler", default=DEFAULT_RULER)
157
+ p.add_argument("--model", default=DEFAULT_MODEL)
158
+ p.add_argument("--output", default="golden_top5.jsonl")
159
+ p.add_argument("--text-col", default="text")
160
+ p.add_argument("--id-col", default="task_id")
161
+ p.add_argument("--label-col", default="label")
162
+ p.add_argument("--top-k", type=int, default=5)
163
+ p.add_argument("--boundary-score", type=float, default=44.72,
164
+ help="预测阈值,weighted_score >= 该值则 pred=1(默认从 pipeline.yaml 抄过来的 youth 类阈值)")
165
+ p.add_argument("--max-length", type=int, default=4096)
166
+ p.add_argument("--batch-size", type=int, default=4)
167
+ p.add_argument("--cache-dir", default="cache_emb",
168
+ help="embedding 缓存目录;设空字符串关闭缓存")
169
+ p.add_argument("--limit", type=int, default=None,
170
+ help="只跑前 N 条做 smoke test")
171
+ p.add_argument("--cpu", action="store_true")
172
+ p.add_argument("--no-flash-attn", action="store_true")
173
+ return p.parse_args()
174
+
175
+
176
+ def main():
177
+ args = parse_args()
178
+
179
+ # 1) data
180
+ print(f"[1/4] load csv: {args.csv}")
181
+ rows = load_csv(args.csv, args.text_col, args.id_col, args.label_col, args.limit)
182
+ print(f" -> {len(rows)} samples")
183
+
184
+ print(f"[2/4] load ruler: {args.ruler}")
185
+ ruler = load_ruler_items(args.ruler)
186
+ print(f" -> {len(ruler)} ruler items")
187
+
188
+ # 2) model
189
+ print(f"[3/4] load model: {args.model}")
190
+ device = "cpu" if args.cpu else ("cuda" if torch.cuda.is_available() else "cpu")
191
+ print(f" device: {device}")
192
+ mk = {}
193
+ if device == "cuda":
194
+ mk["torch_dtype"] = torch.float16
195
+ if not args.no_flash_attn:
196
+ mk["attn_implementation"] = "flash_attention_2"
197
+ tokenizer = AutoTokenizer.from_pretrained(args.model, padding_side="left")
198
+ model = AutoModel.from_pretrained(args.model, **mk).to(device).eval()
199
+
200
+ # 3) encode(分别缓存 csv 和 ruler)
201
+ cd = args.cache_dir or None
202
+ print(f"[4/4] encode (batch_size={args.batch_size}, max_length={args.max_length})")
203
+ csv_emb = encode_with_cache([r["conv_text"] for r in rows],
204
+ tokenizer, model,
205
+ max_length=args.max_length,
206
+ batch_size=args.batch_size,
207
+ cache_dir=cd, name=f"csv_{Path(args.csv).stem}")
208
+ ruler_emb = encode_with_cache([it["text"] for it in ruler],
209
+ tokenizer, model,
210
+ max_length=args.max_length,
211
+ batch_size=args.batch_size,
212
+ cache_dir=cd, name=f"ruler_{Path(args.ruler).parent.name}")
213
+
214
+ # 4) sim matrix + Top-K
215
+ sims = csv_emb @ ruler_emb.T # (N_csv, N_ruler)
216
+ K = min(args.top_k, len(ruler))
217
+ # argpartition 找 K 个最大,再排序
218
+ top_idx_part = np.argpartition(-sims, K - 1, axis=1)[:, :K]
219
+ # 在每行内按 sim 排序
220
+ row_arange = np.arange(sims.shape[0])[:, None]
221
+ top_sims_part = sims[row_arange, top_idx_part]
222
+ order = np.argsort(-top_sims_part, axis=1)
223
+ top_idx = np.take_along_axis(top_idx_part, order, axis=1)
224
+ top_sims = np.take_along_axis(top_sims_part, order, axis=1)
225
+
226
+ # 5) 写 JSONL + summary
227
+ out_path = Path(args.output)
228
+ out_path.parent.mkdir(parents=True, exist_ok=True)
229
+ summary_rows = []
230
+ print(f"[write] {out_path}")
231
+ with out_path.open("w", encoding="utf-8") as f:
232
+ for i, row in enumerate(rows):
233
+ topk = []
234
+ for j in range(K):
235
+ idx = int(top_idx[i, j])
236
+ topk.append({
237
+ "rank": ruler[idx]["rank"],
238
+ "score": ruler[idx]["score"],
239
+ "sim": float(top_sims[i, j]),
240
+ "item_id": ruler[idx]["item_id"],
241
+ })
242
+ sims_arr = np.array([t["sim"] for t in topk], dtype=float)
243
+ scores_arr = np.array([t["score"] for t in topk], dtype=float)
244
+ wsim = float(sims_arr.sum())
245
+ weighted_score = float((sims_arr * scores_arr).sum() / wsim) if wsim > 0 else 0.0
246
+ top1_score = topk[0]["score"]
247
+ pred = int(weighted_score >= args.boundary_score)
248
+ gt = int(row["label"] == "Y")
249
+ record = {
250
+ "task_id": row["task_id"],
251
+ "label": row["label"],
252
+ "ground_truth": gt,
253
+ "weighted_score": weighted_score,
254
+ "top1_score": top1_score,
255
+ "top1_sim": topk[0]["sim"],
256
+ "top1_rank": topk[0]["rank"],
257
+ "pred_by_weighted": pred,
258
+ "topk": topk,
259
+ }
260
+ f.write(json.dumps(record, ensure_ascii=False) + "\n")
261
+ summary_rows.append({
262
+ "task_id": row["task_id"],
263
+ "label": row["label"],
264
+ "ground_truth": gt,
265
+ "weighted_score": round(weighted_score, 4),
266
+ "top1_rank": topk[0]["rank"],
267
+ "top1_score": round(top1_score, 4),
268
+ "top1_sim": round(topk[0]["sim"], 4),
269
+ "top1_item_id": topk[0]["item_id"],
270
+ "pred_by_weighted": pred,
271
+ })
272
+
273
+ summary_csv = out_path.with_suffix(".summary.csv")
274
+ pd.DataFrame(summary_rows).to_csv(summary_csv, index=False)
275
+ print(f"[write] {summary_csv}")
276
+
277
+ # 6) 顺手算个总指标
278
+ sdf = pd.DataFrame(summary_rows)
279
+ if "ground_truth" in sdf.columns and len(sdf):
280
+ tp = int(((sdf.pred_by_weighted == 1) & (sdf.ground_truth == 1)).sum())
281
+ fp = int(((sdf.pred_by_weighted == 1) & (sdf.ground_truth == 0)).sum())
282
+ tn = int(((sdf.pred_by_weighted == 0) & (sdf.ground_truth == 0)).sum())
283
+ fn = int(((sdf.pred_by_weighted == 0) & (sdf.ground_truth == 1)).sum())
284
+ prec = tp / (tp + fp) if tp + fp else 0.0
285
+ rec = tp / (tp + fn) if tp + fn else 0.0
286
+ f1 = 2 * prec * rec / (prec + rec) if prec + rec else 0.0
287
+ print(f"\n[metrics @ weighted_score >= {args.boundary_score}]")
288
+ print(f" TP={tp} FP={fp} TN={tn} FN={fn}")
289
+ print(f" precision={prec:.4f} recall={rec:.4f} f1={f1:.4f}")
290
+
291
+
292
+ if __name__ == "__main__":
293
+ main()