# AIPF Warm-Start Patch: Using `embedding_position` as the Binary Search Starting Point

## Overview of Changes

| File | What changes | Why |
|---|---|---|
| 1. `aipf_golden_set.csv` | `estimated_position` column already added | Data source |
| 2. `pipeline/prepare_local_eval_data.py` | Copy `estimated_position` from the CSV into the JSONL | Carries the data into the pipeline |
| 3. `vendor/ranking_moderation/scripts/find_positions.py` | Read `estimated_position` from each item and pass it to `find_sample_index` | Parameter plumbing |
| 4. `vendor/ranking_moderation/src/ranking_moderation/true_skill_ranking.py` | Add a `start_pos` parameter to `find_sample_index` and `_binary_search`; the first-round `mid` uses `start_pos` instead of the midpoint | Where the change actually takes effect |

Data flow:

```
csv (estimated_position)
    ↓ prepare_local_eval_data.py
jsonl (estimated_position field)
    ↓ find_positions.py (process_single_item)
bt_ranker.find_sample_index(start_pos=...)
    ↓
_binary_search(start_pos=embedding_position)
    # first-round mid = start_pos, not (0 + 199) // 2
```

---

## Change 1: `pipeline/prepare_local_eval_data.py`

In `main()`, in the block that builds `records` (around line 108), add one field to the dict:

```python
records.append(
    {
        "report_id": report_id,
        "text": text,
        "label": label,
        "conversation_id": report_id,
        "conv_text": conv_text,
        "store_region": "",
        "alias2age_map": alias2age_map,
        "uid2alias_map": {},
        "msg_metadata": [],
        "msg_dict": {},
        # ↓↓↓ New ↓↓↓
        # None when the column is missing or the cell is empty/NaN (cold start).
        # int(float(...)) also accepts values read as floats or strings like "91.0".
        "estimated_position": (
            int(float(row["estimated_position"]))
            if "estimated_position" in df.columns
            and str(row["estimated_position"]).strip() not in ("", "nan", "None")
            else None
        ),
    }
)
```

---

## Change 2: `vendor/ranking_moderation/scripts/find_positions.py`

Find the `process_single_item` function (around line 267) and change it to:

```python
def process_single_item(item, bt_ranker, config):
    search_method = config["ranking"].get("search_method", "similarity_search")

    # ↓↓↓ New: take estimated_position from the item as the starting point ↓↓↓
    start_pos = item.get("estimated_position")
    if start_pos is not None:
        try:
            start_pos = int(start_pos)
        except (TypeError, ValueError):
            # Unparseable values (e.g. a NaN float) fall back to a cold start
            start_pos = None

    try:
        result = bt_ranker.find_sample_index(
            new_item=item,
            initial_comparisons=config["ranking"]["initial_comparisons"],
            num_rounds=config["ranking"]["num_rounds"],
            search_method=search_method,
            start_pos=start_pos,  # ← new argument
        )
        return {"success": True, "payload": result}
    except Exception as e:
        return {"success": False, "error": str(e)}
```

Note that the default `search_method` here is `similarity_search`, and the dispatcher in change 3 does not pass `start_pos` to `_similarity_search`. Unless the config sets `search_method` to `binary_search` (or `heuristic_search`, together with change 5), `start_pos` is silently ignored.

---

## Change 3: add a `start_pos` parameter to `find_sample_index` in `true_skill_ranking.py`

Around lines 936-940:

```python
def find_sample_index(self, new_item: Dict, initial_comparisons: int = 5,
                      num_rounds: int = 3, search_method: str = 'binary_search',
                      start_pos: Optional[int] = None) -> Dict:  # ← new parameter
    """...original docstring...

    New Args:
        start_pos: Starting position for the binary search (warm start).
            If None, falls back to a standard binary search (mid = midpoint).
    """
```

`Optional` must be imported from `typing` in this file if it isn't already.

Change the dispatcher around lines 973-982 to:

```python
if search_method == 'heuristic_search':
    round_idx, round_candidates = self._heuristic_search(
        new_item, num_rounds, initial_comparisons, ruler_items, start_pos=start_pos)
elif search_method == 'binary_search':
    round_idx, round_candidates = self._binary_search(
        new_item, ruler_items, start_pos=start_pos)
elif search_method == 'full_traversal':
    round_idx, round_candidates = self._full_traversal(new_item, ruler_items)
elif search_method == 'similarity_search':
    # start_pos is not used by this method
    round_idx, round_candidates = self._similarity_search(
        new_item, num_rounds, initial_comparisons, ruler_items)
else:
    raise ValueError(f"Unknown search method: {search_method}")
```

The `heuristic_search` branch passes `start_pos`, so it only works once change 5 is applied. If you skip change 5, drop `start_pos=start_pos` from that call; otherwise `_heuristic_search` raises `TypeError` for the unexpected keyword argument.

---

## Change 4: `_binary_search` uses `start_pos` as the first-round `mid`

Replace the whole `_binary_search` function (around lines 851-901):

```python
def _binary_search(self, new_item: RulerItem, ruler_items: List[RulerItem],
                   start_pos: Optional[int] = None) -> Tuple[int, List[Dict]]:
    """Binary-search the new item's position on the ruler.

    start_pos: position of the first-round mid (embedding warm start).
        If None, the standard binary-search midpoint is used.
    """
    left, right = 0, len(ruler_items) - 1
    round_candidates = []
    round_idx = 0

    while left <= right:
        logger.debug(f"Starting round {round_idx + 1}")
        round_info = {
            'round': round_idx + 1,
            'candidates': [],
            'sampling_method': 'binary_search'
        }

        # ↓↓↓ Key change: first-round mid = start_pos (clamped to [left, right]) ↓↓↓
        if round_idx == 0 and start_pos is not None:
            mid = max(left, min(right, int(start_pos)))
            round_info['sampling_method'] = 'binary_search_warm_start'
        else:
            mid = (left + right) // 2

        mid_item = ruler_items[mid]
        pair_key = self.pairwise_comparison.get_pair_key(new_item.item_id, mid_item.item_id)
        if self.pairwise_comparison.get_comparison_count(pair_key) == 0:
            candidates = [(mid, mid_item)]
            round_info['candidates'] = [{'id': mid_item.item_id, 'position': mid,
                                         'score': mid_item.score, 'rank': mid_item.rank}]
        else:
            # Pair already compared: reuse the current estimate, no new comparison
            candidates = []

        logger.debug(f"Round {round_idx + 1} sampling info: method={round_info['sampling_method']}")
        for i, candidate in enumerate(round_info['candidates']):
            logger.debug(f"  {i+1}. ID: {candidate['id']}, Position: {candidate['position']}, "
                         f"Score: {candidate['score']:.4f}")

        if candidates:
            pairs = [(new_item.item, candidate[1].item) for candidate in candidates]
            self.pairwise_comparison.compare_pairs(pairs)
            new_item.score, new_item.sigma = self.estimate_new_item_score(new_item, ruler_items)

        # Narrow the interval even when no new comparison was made; otherwise
        # left/right never change and the loop does not terminate.
        # The direction assumes the ruler is sorted by descending score.
        if new_item.score <= mid_item.score:
            left = mid + 1
        else:
            right = mid - 1

        round_idx += 1
        round_info['item_trueskill'] = {'mu': new_item.score, 'sigma': new_item.sigma}
        round_candidates.append(round_info)

    return round_idx, round_candidates
```

---

## Change 5 (optional): `_heuristic_search` also supports `start_pos`

If your `pipeline.yaml` uses `search_method: heuristic_search`, change this function too:

```python
def _heuristic_search(self, new_item, num_rounds, initial_comparisons, ruler_items,
                      start_pos: Optional[int] = None):
    round_candidates = []
    for round_idx in range(num_rounds):
        round_info = {'round': round_idx + 1, 'candidates': [], 'sampling_method': None}

        if round_idx == 0:
            # ↓↓↓ Key change: in round 0, if start_pos is given, use it as the only candidate ↓↓↓
            if start_pos is not None:
                idx = max(0, min(len(ruler_items) - 1, int(start_pos)))
                temp_item = ruler_items[idx]
                pair_key = self.pairwise_comparison.get_pair_key(new_item.item_id, temp_item.item_id)
                if self.pairwise_comparison.get_comparison_count(pair_key) == 0:
                    candidates = [(idx, temp_item)]
                else:
                    candidates = []
                round_info['sampling_method'] = 'heuristic_warm_start'
                round_info['candidates'] = [{'id': temp_item.item_id,
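                                             'position': idx,
                                             'score': temp_item.score,
                                             'rank': temp_item.rank}] if candidates else []  # only if compared this round
            else:
                # Original uniform segment sampling, unchanged
                num_comparisons = initial_comparisons
                segment_size = max(1, len(ruler_items) // num_comparisons)
                candidates = []
                for i in range(0, len(ruler_items), segment_size):
                    segment = ruler_items[i:i + segment_size]
                    if segment:
                        for temp_item in segment:
                            pair_key = self.pairwise_comparison.get_pair_key(new_item.item_id, temp_item.item_id)
                            if self.pairwise_comparison.get_comparison_count(pair_key) == 0:
                                candidates.append((i, temp_item))
                                break
                round_info['sampling_method'] = 'uniform_segments'
                round_info['candidates'] = [{'id': c[1].item_id, 'position': c[0],
                                             'score': c[1].score, 'rank': c[1].rank}
                                            for c in candidates]
        else:
            # Rounds 1+ keep the original score-based logic (unchanged)
            score_diffs = []
            for i, item in enumerate(ruler_items):
                pair_key = self.pairwise_comparison.get_pair_key(new_item.item_id, item.item_id)
                if self.pairwise_comparison.get_comparison_count(pair_key) > 0:
                    continue
                score_diff = abs(item.score - new_item.score)
                score_diffs.append((i, item, score_diff))
            score_diffs.sort(key=lambda x: x[2])
            candidates = [(i, item) for i, item, _ in score_diffs[:initial_comparisons]]
            if not candidates:
                logger.info("No more pairs to compare, proceeding with final fitting")
                break
            round_info['sampling_method'] = 'score_based'
            round_info['candidates'] = [{'id': c[1].item_id, 'position': c[0],
                                         'score': c[1].score, 'rank': c[1].rank}
                                        for c in candidates]

        # (The original compare_pairs / TrueSkill update logic that follows stays unchanged)
        # ... rest of your original code omitted ...

        round_candidates.append(round_info)

    return num_rounds, round_candidates
```

---

## Verification

After making the changes:

```bash
# 1) Re-run prepare (make sure the JSONL contains the estimated_position field)
python pipeline/prepare_local_eval_data.py \
    --input_csv /mnt/.../aipf_golden_set.csv \
    --output_jsonl /tmp/test.jsonl

# 2) Check that the field is present in the JSONL
head -1 /tmp/test.jsonl | python -m json.tool | grep estimated_position
# You should see something like:
#     "estimated_position": 91

# 3) Run the full pipeline
bash adhoc_run.sh

# 4) Compare metrics (same num_rounds=8, warm start vs cold start);
#    -S sorts keys so the diff only shows real value changes
diff <(jq -S . cold_start_metrics.json) <(jq -S . warm_start_metrics.json)
```

## Expected Results

| Configuration | num_rounds=4 | num_rounds=8 |
|---|---|---|
| Cold start (from the midpoint) | F1=0.81 | F1=0.82 |
| **Warm start (embedding)** | F1=??? | F1=??? |

In theory, a warm start should match the original 8-round result **in only 4 rounds** (possibly even better), because `embedding_position` has already narrowed the search to a fairly accurate neighborhood, so the search does not need to converge globally.

If the embedding position is very accurate (mean absolute error < 30 positions), 4 rounds should be enough; if it is inaccurate (error > 60), the warm start may actually do worse, because a wrong starting point pulls the search off course.

Note that change 4 only replaces the first-round `mid`: from round 2 onward `_binary_search` halves the remaining interval as before, and it loops until `left > right` rather than stopping after `num_rounds`. For `binary_search`, the warm start therefore changes where the first comparison happens, while the round count stays roughly `ceil(log2(len(ruler_items)))`; the 4-round vs 8-round comparison applies mainly to `heuristic_search` (change 5).

## Risks

1. **Cases where the embedding estimate is wrong**: if a class-A sample is estimated by the embedding at 89 (< 106), the warm start begins at 89 and tends to converge nearby, which increases the false-positive risk. This is a trade-off.
2. **The round-0 change in `_heuristic_search`**: the original starts with `initial_comparisons` candidates compared in parallel, while the warm start uses only 1 candidate. Less information is gathered early, so later rounds have to make up for it.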
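Changes 1 and 2 each validate `estimated_position` on their own. If you would rather keep that logic in one place, a shared helper along these lines could serve both. `parse_estimated_position` is a hypothetical name, not an existing function in the pipeline, and the sketch assumes raw values arrive as CSV strings, ints, or NaN floats from pandas.

```python
import math
from typing import Any, Optional


def parse_estimated_position(value: Any) -> Optional[int]:
    """Hypothetical helper: normalize a raw estimated_position to int or None.

    Mirrors the checks in changes 1 and 2: None, empty strings, "nan"/"None",
    real NaN floats and unparseable values all become None (cold start).
    """
    if value is None:
        return None
    if isinstance(value, float) and math.isnan(value):
        return None
    text = str(value).strip()
    if text in ("", "nan", "None"):
        return None
    try:
        # float() first so strings like "91.0" (an int column that held NaNs) still parse
        return int(float(text))
    except (TypeError, ValueError, OverflowError):
        # e.g. "abc", "NaN", "inf"
        return None
```

In change 1 it could stand in for the inline conditional, e.g. `"estimated_position": parse_estimated_position(row.get("estimated_position"))` (pandas `Series.get` returns `None` for a missing column), and in change 2 it could replace the `try`/`except` block.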
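The warm-start loop from change 4 can be checked outside the pipeline by replaying it with an exact comparator in place of TrueSkill updates. In this sketch, `warm_start_binary_search` and its `true_pos` oracle are hypothetical stand-ins: the ruler is assumed to be sorted by descending score, and the new item is assumed to belong at insertion index `true_pos`.

```python
from typing import List, Optional, Tuple


def warm_start_binary_search(n: int, true_pos: int,
                             start_pos: Optional[int] = None) -> Tuple[int, int, List[int]]:
    """Replay the round structure of the patched _binary_search.

    Hypothetical oracle: ruler indices 0..n-1 are sorted by descending score,
    and every item with index < true_pos (0 <= true_pos <= n) outscores the
    new item. Returns (rounds, final insertion index, probed mids).
    """
    left, right = 0, n - 1
    rounds = 0
    probed: List[int] = []
    while left <= right:
        if rounds == 0 and start_pos is not None:
            # First-round mid = start_pos clamped to [left, right], as in change 4
            mid = max(left, min(right, int(start_pos)))
        else:
            mid = (left + right) // 2
        probed.append(mid)
        # Oracle version of `new_item.score <= mid_item.score`
        if mid < true_pos:
            left = mid + 1
        else:
            right = mid - 1
        rounds += 1
    return rounds, left, probed


if __name__ == "__main__":
    for start in (None, 91, 500):
        print(start, warm_start_binary_search(200, true_pos=91, start_pos=start))
```

On a 200-item ruler with `true_pos=91`, the cold start probes 99, 49, 74, ... and the warm start with `start_pos=91` probes 91, 45, 68, ...; both end at 91 after 8 rounds. An out-of-range `start_pos=500` is clamped to 199 and takes 9 rounds. With an exact comparator the starting point only changes which positions are probed, not the final result.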
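The mean-absolute-error rule of thumb in the expected results (< 30 vs > 60 positions) can be checked against an earlier cold-start run before launching a full warm-start run. The sketch assumes you already have final positions from that run aligned item by item with the embedding estimates; `warm_start_outlook` and its verdict strings are hypothetical.

```python
from statistics import mean
from typing import Dict, Optional, Sequence, Union


def warm_start_outlook(estimated: Sequence[Optional[int]],
                       final: Sequence[int],
                       good_mae: float = 30.0,
                       bad_mae: float = 60.0) -> Dict[str, Union[int, float, str, None]]:
    """Hypothetical check: how far are embedding estimates from final positions?

    `estimated` and `final` must be aligned index by index; items without an
    embedding estimate (None) are skipped. Default thresholds follow the
    MAE < 30 / > 60 rule of thumb from the expected results.
    """
    if len(estimated) != len(final):
        raise ValueError("estimated and final must have the same length")
    errors = [abs(e - f) for e, f in zip(estimated, final) if e is not None]
    if not errors:
        return {"n": 0, "mae": None, "verdict": "no estimates"}
    mae = mean(errors)
    if mae < good_mae:
        verdict = "warm start likely fine"
    elif mae > bad_mae:
        verdict = "warm start may hurt"
    else:
        verdict = "unclear: measure both"
    return {"n": len(errors), "mae": mae, "verdict": verdict}
```

The thresholds are the document's rule of thumb rather than calibrated values; `good_mae` and `bad_mae` can be adjusted once the warm-start F1 numbers are in.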