# AIPF Warm-Start Patch: use embedding_position as the binary search starting point

## Overview of changes

| File | What changes | Why |
|---|---|---|
| `aipf_golden_set.csv` | Already has the `estimated_position` column | Data source |
| `pipeline/prepare_local_eval_data.py` (Change 1) | Copy the csv's `estimated_position` into the jsonl | Carries the data into the pipeline |
| `vendor/ranking_moderation/scripts/find_positions.py` (Change 2) | Read `estimated_position` from the item and pass it to `find_sample_index` | Passes the parameter through |
| `vendor/ranking_moderation/src/ranking_moderation/true_skill_ranking.py` (Changes 3-5) | Add a `start_pos` parameter to `find_sample_index` and `_binary_search` (optionally also `_heuristic_search`); the first-round `mid` uses `start_pos` instead of the midpoint | Where the change actually takes effect |

Data flow:

```
csv (estimated_position)
  ↓ prepare_local_eval_data.py
jsonl (estimated_position field)
  ↓ find_positions.py (process_single_item)
bt_ranker.find_sample_index(start_pos=...)
  ↓
_binary_search(start_pos=embedding_position)   # first-round mid = start_pos, not (0+199)//2 (the midpoint of a 200-item ruler)
```

---

## Change 1: `pipeline/prepare_local_eval_data.py`

In `main()`, in the block that builds `records` (around line 108), add one field to the dict:

```python
records.append(
    {
        "report_id": report_id,
        "text": text,
        "label": label,
        "conversation_id": report_id,
        "conv_text": conv_text,
        "store_region": "",
        "alias2age_map": alias2age_map,
        "uid2alias_map": {},
        "msg_metadata": [],
        "msg_dict": {},
        # ↓↓↓ new ↓↓↓
        # int(float(...)) accepts numeric cells (91, 91.0) and string cells ("91", "91.0");
        # a bare int() raises ValueError on the string "91.0".
        "estimated_position": (
            int(float(row["estimated_position"]))
            if "estimated_position" in df.columns
            and str(row["estimated_position"]).strip() not in ("", "nan", "NaN", "None", "<NA>")
            else None
        ),
    }
)
```
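
The missing-value check inside the dict can also be pulled out into a helper and tested on its own. A minimal sketch, assuming the cell arrives from pandas as a number, NaN, `None`, or a string. `parse_estimated_position` is a hypothetical name, not something that exists in the repo. It applies the same missing-value check as Change 1 and parses via `int(float(...))` so that `"91.0"` also works:

```python
from typing import Any, Optional

# Spellings that mean "no estimate": str() of NaN, None, pd.NA, or an empty cell
_MISSING = {"", "nan", "NaN", "None", "<NA>"}


def parse_estimated_position(value: Any) -> Optional[int]:
    """Hypothetical helper: return the position as an int, or None if the cell is empty."""
    text = str(value).strip()
    if text in _MISSING:
        return None
    # float() first so that both 91.0 and "91.0" are accepted
    return int(float(text))
```

A value that is neither missing nor numeric (say `"abc"`) still raises `ValueError`. That surfaces bad rows during prepare instead of silently turning them into cold starts.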

---

## Change 2: `vendor/ranking_moderation/scripts/find_positions.py`

Find the `process_single_item` function (around line 267) and change it to:

```python
def process_single_item(item, bt_ranker, config):
    # Default is 'similarity_search', which does NOT use start_pos (see the dispatcher in Change 3)
    search_method = config["ranking"].get("search_method", "similarity_search")

    # ↓↓↓ new: use the item's estimated_position as the starting point ↓↓↓
    start_pos = item.get("estimated_position")
    if start_pos is not None:
        try:
            start_pos = int(start_pos)
        except (TypeError, ValueError):  # e.g. NaN or a non-numeric string: fall back to cold start
            start_pos = None

    try:
        result = bt_ranker.find_sample_index(
            new_item=item,
            initial_comparisons=config["ranking"]["initial_comparisons"],
            num_rounds=config["ranking"]["num_rounds"],
            search_method=search_method,
            start_pos=start_pos,  # ← new argument
        )
        return {"success": True, "payload": result}
    except Exception as e:
        return {"success": False, "error": str(e)}
```

With the default `similarity_search`, `start_pos` is passed along but ignored. For the warm start to take effect, set `search_method: binary_search` in pipeline.yaml, or `heuristic_search` together with Change 5.
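
You can check the plumbing without a real ranker by calling `process_single_item` with a stub that records its keyword arguments. A minimal sketch: `FakeRanker` and the config values are made up for the test. The function is an equivalent copy of Change 2 so the block runs on its own:

```python
from typing import Any, Dict, List


def process_single_item(item, bt_ranker, config):
    """Equivalent copy of Change 2 so that this block runs on its own."""
    search_method = config["ranking"].get("search_method", "similarity_search")
    start_pos = item.get("estimated_position")
    if start_pos is not None:
        try:
            start_pos = int(start_pos)
        except (TypeError, ValueError):
            start_pos = None
    try:
        result = bt_ranker.find_sample_index(
            new_item=item,
            initial_comparisons=config["ranking"]["initial_comparisons"],
            num_rounds=config["ranking"]["num_rounds"],
            search_method=search_method,
            start_pos=start_pos,
        )
        return {"success": True, "payload": result}
    except Exception as e:
        return {"success": False, "error": str(e)}


class FakeRanker:
    """Hypothetical stand-in for bt_ranker that records its keyword arguments."""

    def __init__(self) -> None:
        self.calls: List[Dict[str, Any]] = []

    def find_sample_index(self, **kwargs: Any) -> Dict[str, Any]:
        self.calls.append(kwargs)
        return {"start_pos_seen": kwargs["start_pos"]}


config = {"ranking": {"search_method": "binary_search",
                      "initial_comparisons": 5, "num_rounds": 8}}
ranker = FakeRanker()
ok = process_single_item({"report_id": "r1", "estimated_position": "91"}, ranker, config)
bad = process_single_item({"report_id": "r2", "estimated_position": "n/a"}, ranker, config)
```

A non-numeric value such as `"n/a"` degrades to a cold start (`start_pos=None`) instead of failing the item.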

---

## Change 3: `find_sample_index` in `true_skill_ranking.py` gets a `start_pos` parameter

Around lines 936-940:

```python
def find_sample_index(self,
                      new_item: Dict,
                      initial_comparisons: int = 5,
                      num_rounds: int = 3,
                      search_method: str = 'binary_search',
                      start_pos: Optional[int] = None) -> Dict:  # ← new parameter
    """...original docstring...

    New Args:
        start_pos: starting position for the binary search (warm start).
            When None, falls back to a standard binary search (mid = midpoint).
    """
```

`Optional` comes from `typing`. Add it to the module's `typing` import if it is not already there.

Change the dispatcher around lines 973-982 to:

```python
if search_method == 'heuristic_search':
    # Requires Change 5. If you skip Change 5, drop start_pos= here,
    # otherwise _heuristic_search raises TypeError (unexpected keyword argument).
    round_idx, round_candidates = self._heuristic_search(
        new_item, num_rounds, initial_comparisons, ruler_items, start_pos=start_pos)
elif search_method == 'binary_search':
    # num_rounds is not passed: _binary_search runs until the interval is empty
    round_idx, round_candidates = self._binary_search(
        new_item, ruler_items, start_pos=start_pos)
elif search_method == 'full_traversal':
    round_idx, round_candidates = self._full_traversal(new_item, ruler_items)
elif search_method == 'similarity_search':
    # start_pos is not used by this method
    round_idx, round_candidates = self._similarity_search(
        new_item, num_rounds, initial_comparisons, ruler_items)
else:
    raise ValueError(f"Unknown search method: {search_method}")
```
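
Only two branches of the dispatcher forward `start_pos`, so it can be worth failing fast when a config would silently ignore the warm start. A sketch: `warm_start_is_active` is a hypothetical helper. Its default mirrors the `'similarity_search'` fallback in `process_single_item`:

```python
from typing import Any, Mapping

# Dispatcher branches that forward start_pos.
# heuristic_search only honours it if Change 5 is applied.
WARM_START_METHODS = frozenset({"binary_search", "heuristic_search"})


def warm_start_is_active(config: Mapping[str, Any]) -> bool:
    """Hypothetical check: True if the configured search method forwards start_pos."""
    method = config["ranking"].get("search_method", "similarity_search")
    return method in WARM_START_METHODS
```

For example, call `assert warm_start_is_active(config)` before launching a warm-start run, so that a missing `search_method` key does not quietly produce a cold-start result.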

---

## Change 4: `_binary_search` uses `start_pos` as the first-round `mid`

Replace the entire `_binary_search` function (around lines 851-901):

```python
def _binary_search(self, new_item: RulerItem, ruler_items: List[RulerItem],
                   start_pos: Optional[int] = None) -> Tuple[int, List[Dict]]:
    """Binary-search the new item's position on the ruler.

    start_pos: position of the first-round mid (embedding warm start).
        When None, the standard midpoint is used.
    """
    left, right = 0, len(ruler_items) - 1
    round_candidates = []
    round_idx = 0
    while left <= right:
        logger.debug(f"Starting round {round_idx + 1}")
        round_info = {
            'round': round_idx + 1,
            'candidates': [],
            'sampling_method': 'binary_search'
        }

        # ↓↓↓ key change: first-round mid = start_pos, clamped into [left, right] ↓↓↓
        if round_idx == 0 and start_pos is not None:
            mid = max(left, min(right, int(start_pos)))
            round_info['sampling_method'] = 'binary_search_warm_start'
        else:
            mid = (left + right) // 2

        mid_item = ruler_items[mid]

        # Only compare against mid if this pair has not been compared yet
        pair_key = self.pairwise_comparison.get_pair_key(new_item.item_id, mid_item.item_id)
        if self.pairwise_comparison.get_comparison_count(pair_key) == 0:
            candidates = [(mid, mid_item)]
            round_info['candidates'] = [{'id': mid_item.item_id, 'position': mid,
                                         'score': mid_item.score, 'rank': mid_item.rank}]
        else:
            candidates = []

        logger.debug(f"Round {round_idx + 1} sampling info: method={round_info['sampling_method']}")
        for i, candidate in enumerate(round_info['candidates']):
            logger.debug(f"  {i+1}. ID: {candidate['id']}, Position: {candidate['position']}, Score: {candidate['score']:.4f}")

        if candidates:
            pairs = [(new_item.item, candidate[1].item) for candidate in candidates]
            self.pairwise_comparison.compare_pairs(pairs)
            new_item.score, new_item.sigma = self.estimate_new_item_score(new_item, ruler_items)

        # Ruler is ordered by descending score: a lower-or-equal score searches the right half
        if new_item.score <= mid_item.score:
            left = mid + 1
        else:
            right = mid - 1

        round_idx += 1
        round_info['item_trueskill'] = {'mu': new_item.score, 'sigma': new_item.sigma}
        round_candidates.append(round_info)

    return round_idx, round_candidates
```
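
The effect of moving only the first pivot is easiest to see on a plain sorted list, without TrueSkill or LLM comparisons. A minimal sketch under that simplification: `warm_start_bisect` is a hypothetical function, the ruler is a list of scores in descending order, and each "comparison" is an exact score comparison using the same update rule as `_binary_search`:

```python
from typing import List, Optional, Tuple


def warm_start_bisect(ruler_scores: List[float], new_score: float,
                      start_pos: Optional[int] = None) -> Tuple[int, List[int]]:
    """Return (insert index, probed positions) for a descending-score ruler.

    Same update rule as _binary_search: a lower-or-equal score moves right.
    """
    left, right = 0, len(ruler_scores) - 1
    probes: List[int] = []
    while left <= right:
        if not probes and start_pos is not None:
            mid = max(left, min(right, start_pos))  # warm start, clamped into range
        else:
            mid = (left + right) // 2               # standard midpoint
        probes.append(mid)
        if new_score <= ruler_scores[mid]:
            left = mid + 1
        else:
            right = mid - 1
    return left, probes


ruler = [float(199 - i) for i in range(200)]  # scores 199.0 down to 0.0
cold_idx, cold_probes = warm_start_bisect(ruler, 108.5)
warm_idx, warm_probes = warm_start_bisect(ruler, 108.5, start_pos=91)
```

Here the estimate is exactly right (the item belongs at index 91). The warm start still takes 8 probes, the same as the cold start, and its second probe jumps to position 45. With an exact comparator the final index never depends on `start_pos`. With noisy pairwise judgements and a TrueSkill re-estimate after every comparison, which items get compared can change the result, and that is where the warm start can help or hurt.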

---

## Change 5 (optional): `_heuristic_search` also supports `start_pos`

If your pipeline.yaml sets `search_method: heuristic_search`, change this function too:

```python
def _heuristic_search(self, new_item, num_rounds, initial_comparisons, ruler_items,
                      start_pos: Optional[int] = None):
    round_candidates = []

    for round_idx in range(num_rounds):
        round_info = {'round': round_idx + 1, 'candidates': [], 'sampling_method': None}

        if round_idx == 0:
            # ↓↓↓ key change: with a start_pos, round 0 uses it as the only candidate ↓↓↓
            if start_pos is not None:
                idx = max(0, min(len(ruler_items) - 1, int(start_pos)))
                temp_item = ruler_items[idx]
                pair_key = self.pairwise_comparison.get_pair_key(new_item.item_id, temp_item.item_id)
                if self.pairwise_comparison.get_comparison_count(pair_key) == 0:
                    candidates = [(idx, temp_item)]
                else:
                    candidates = []  # pair already compared: nothing new this round
                round_info['sampling_method'] = 'heuristic_warm_start'
            else:
                # Original uniform segment sampling, unchanged except that the logged
                # position is now the item's own index j (it used to be the segment start i)
                num_comparisons = initial_comparisons
                segment_size = max(1, len(ruler_items) // num_comparisons)
                candidates = []
                for i in range(0, len(ruler_items), segment_size):
                    segment = ruler_items[i:i + segment_size]
                    for j, temp_item in enumerate(segment, start=i):
                        pair_key = self.pairwise_comparison.get_pair_key(new_item.item_id, temp_item.item_id)
                        if self.pairwise_comparison.get_comparison_count(pair_key) == 0:
                            candidates.append((j, temp_item))
                            break
                round_info['sampling_method'] = 'uniform_segments'
            # Log only the candidates that will actually be compared
            round_info['candidates'] = [{'id': c[1].item_id, 'position': c[0],
                                         'score': c[1].score, 'rank': c[1].rank} for c in candidates]
        else:
            # Rounds 1+ keep the original score-based logic (unchanged)
            score_diffs = []
            for i, item in enumerate(ruler_items):
                pair_key = self.pairwise_comparison.get_pair_key(new_item.item_id, item.item_id)
                if self.pairwise_comparison.get_comparison_count(pair_key) > 0:
                    continue
                score_diff = abs(item.score - new_item.score)
                score_diffs.append((i, item, score_diff))
            score_diffs.sort(key=lambda x: x[2])
            candidates = [(i, item) for i, item, _ in score_diffs[:initial_comparisons]]
            if not candidates:
                logger.info("No more pairs to compare, proceeding with final fitting")
                break
            round_info['sampling_method'] = 'score_based'
            round_info['candidates'] = [{'id': c[1].item_id, 'position': c[0],
                                         'score': c[1].score, 'rank': c[1].rank} for c in candidates]

        # (The original compare_pairs / TrueSkill update logic that follows stays unchanged)
        # ... your original code, omitted here ...
        round_candidates.append(round_info)

    return num_rounds, round_candidates
```

---

## Verification steps

After making the changes:

```bash
# 1) Re-run prepare (make sure the jsonl now contains the estimated_position field)
python pipeline/prepare_local_eval_data.py \
    --input_csv /mnt/.../aipf_golden_set.csv \
    --output_jsonl /tmp/test.jsonl

# 2) Check that the field made it into the jsonl
head -1 /tmp/test.jsonl | python -m json.tool | grep estimated_position

# Expected, something like (null for rows with an empty cell):
#     "estimated_position": 91

# 3) Run the full pipeline
bash adhoc_run.sh

# 4) Compare metrics: same config (e.g. num_rounds=8), warm start vs cold start.
#    Assumes each run's metrics were saved under these two file names.
#    jq -S sorts keys so the diff only shows real value changes.
diff <(jq -S . cold_start_metrics.json) <(jq -S . warm_start_metrics.json)
```
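
Step 2 only inspects the first record. To check coverage across the whole file, a short script can count how many records carry a usable `estimated_position`. A sketch: `count_field` is a hypothetical helper that takes any iterable of jsonl lines (an open file works):

```python
import json
from typing import Iterable, Tuple


def count_field(lines: Iterable[str], field: str = "estimated_position") -> Tuple[int, int]:
    """Return (records where `field` is a non-null int, total records)."""
    present = total = 0
    for line in lines:
        if not line.strip():
            continue  # tolerate blank lines
        total += 1
        value = json.loads(line).get(field)
        # JSON true/false load as bool, which is an int subclass; exclude it
        if isinstance(value, int) and not isinstance(value, bool):
            present += 1
    return present, total
```

Usage: `with open("/tmp/test.jsonl", encoding="utf-8") as f: print(count_field(f))`. A low first number means many golden-set rows will silently run as cold starts.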
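
## Expected results

| Configuration | num_rounds=4 | num_rounds=8 |
|---|---|---|
| Cold start (from the midpoint) | F1=0.81 | F1=0.82 |
| **Warm start (embedding)** | F1=??? | F1=??? |

The expectation is that warm start **matches the 8-round result in only 4 rounds** (or does better). The reasoning is that embedding_position has already narrowed the search to a fairly accurate neighbourhood, so the search no longer has to converge over the whole ruler. For `search_method: binary_search` this needs a caveat. `num_rounds` is not passed to `_binary_search`, whose loop runs until `left > right`: about log2(200) ≈ 8 rounds on a 200-item ruler, warm or cold. Also, only the first pivot moves. After that first comparison the search keeps halving whichever side of `start_pos` is left, which can be up to 199 items wide. Realising a 4-round saving would therefore need a search confined to a window around `start_pos`, and this patch does not implement one.

If the embedding position is very accurate (mean absolute error < 30), 4 rounds should be enough. If it is inaccurate (error > 60), warm start may actually do worse, because a wrong starting point pulls the search off course.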
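
To put numbers on the 4-vs-8 intuition: a binary search over `n` candidate positions needs at most floor(log2 n) + 1 comparisons, one per round. The sketch below compares the full 200-item ruler with a hypothetical search confined to a window of ±`err` positions around `start_pos`. The windowed search is an assumption for illustration. `_binary_search` in Change 4 only moves the first pivot, and `_heuristic_search` compares up to `initial_comparisons` items per round, so its rounds are not directly comparable.

```python
def max_bisect_rounds(n: int) -> int:
    """Worst-case comparisons for a binary search over n positions: floor(log2 n) + 1."""
    return n.bit_length() if n > 0 else 0


def window_size(err: int, ruler_len: int = 200) -> int:
    """Positions in a hypothetical +/- err window around start_pos, capped at the ruler length."""
    return min(2 * err + 1, ruler_len)


for err in (7, 30, 60):
    print(f"err=±{err}: window={window_size(err)}, rounds={max_bisect_rounds(window_size(err))}")
print(f"full ruler: rounds={max_bisect_rounds(200)}")
```

On this arithmetic the full ruler needs 8 rounds. A ±30 window needs 6, a ±60 window needs 7, and 4 rounds only cover a window of about ±7 positions. A window also misses any sample whose true position lies outside it.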
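
## Risks

1. **Cases where the embedding estimate is wrong**: if a class-A sample is placed at 89 (< 106) by the embedding, warm start begins at 89 and may converge near there instead, which raises the false-positive risk. This is a trade-off.
2. **Round 0 change in `_heuristic_search`**: the original round 0 starts with `initial_comparisons` candidates sampled across the ruler, while warm start uses just one. The estimate after round 0 therefore rests on less information, and later rounds have to make up for it.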
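
The MAE thresholds in the expected-results section (< 30 / > 60) and risk 1 can both be checked on the golden set before running the ranker. This requires a reference position per sample, for example the position a cold-start run finds. A sketch: the reference positions and reading 106 as a position cutoff are assumptions, and `embedding_error_report` is a hypothetical helper:

```python
from typing import Dict, Sequence


def embedding_error_report(estimated: Sequence[int], reference: Sequence[int],
                           cutoff: int = 106) -> Dict[str, float]:
    """MAE of the embedding estimate, share of errors above 60, and how often
    estimate and reference fall on different sides of `cutoff`."""
    if len(estimated) != len(reference) or not estimated:
        raise ValueError("need two non-empty sequences of equal length")
    errors = [abs(e - r) for e, r in zip(estimated, reference)]
    flips = sum((e < cutoff) != (r < cutoff) for e, r in zip(estimated, reference))
    n = len(errors)
    return {
        "mae": sum(errors) / n,
        "share_err_gt_60": sum(err > 60 for err in errors) / n,
        "side_flip_rate": flips / n,
    }
```

A high `side_flip_rate` is the situation risk 1 describes: the warm start begins on the wrong side of the cutoff.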