Sound / AIPF_warmstart_patch.md, last commit 75f3322 by Wendy-Fly: "AIPF warm-start patch: use embedding_position as binary search start"

AIPF Warm-Start Patch: use embedding_position as the binary search starting point

Overview of changes

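| # | File | What changes | Why |
|---|---|---|---|
| 1 | aipf_golden_set.csv | estimated_position column already added | Data source |
| 2 | pipeline/prepare_local_eval_data.py | Copy estimated_position from the CSV into the JSONL | Carries the data into the pipeline |
| 3 | vendor/ranking_moderation/scripts/find_positions.py | Read estimated_position from the item and pass it to find_sample_index | Parameter passing |
| 4 | vendor/ranking_moderation/src/ranking_moderation/true_skill_ranking.py | find_sample_index and _binary_search get a start_pos parameter; the first-round mid uses start_pos instead of the midpoint | Where the change actually takes effect |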

Data flow:

```
csv (estimated_position)
   ↓ prepare_local_eval_data.py
jsonl (estimated_position field)
   ↓ find_positions.py (process_single_item)
bt_ranker.find_sample_index(start_pos=...)
   ↓
_binary_search(start_pos=embedding_position)   # first-round mid = start_pos, not (0+199)//2
```

Change 1: pipeline/prepare_local_eval_data.py

In main(), where the records are built (around line 108), add one field to the dict:

```python
records.append(
    {
        "report_id": report_id,
        "text": text,
        "label": label,
        "conversation_id": report_id,
        "conv_text": conv_text,
        "store_region": "",
        "alias2age_map": alias2age_map,
        "uid2alias_map": {},
        "msg_metadata": [],
        "msg_dict": {},
        # ↓↓↓ New ↓↓↓
        # int(float(...)) also accepts float cells (91.0) and strings such as "91.0";
        # empty / NaN / "None" cells become None, which means "cold start".
        "estimated_position": (
            int(float(row["estimated_position"]))
            if "estimated_position" in df.columns
               and str(row["estimated_position"]).strip() not in ("", "nan", "None")
            else None
        ),
    }
)
```
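If the inline conditional gets hard to read, the same parsing can live in a small helper. This is a sketch: parse_estimated_position is a hypothetical name, not something in the repo. It covers the cases the inline expression guards against, plus float NaN, "NaN", "91.0" and junk values.

```python
import math
from typing import Any, Optional


def parse_estimated_position(value: Any) -> Optional[int]:
    """Hypothetical helper: turn one CSV cell into an int position or None.

    None means "no warm start for this record" (falls back to a cold start).
    """
    if value is None:
        return None
    if isinstance(value, float) and math.isnan(value):
        return None
    text = str(value).strip()
    if text in ("", "nan", "NaN", "None"):
        return None
    try:
        # float() first so that "91.0" and 91.0 both work; int() truncates.
        return int(float(text))
    except (ValueError, OverflowError):
        # Unparseable text ("abc") or infinities ("inf").
        return None
```

In main() the field would then become `parse_estimated_position(row["estimated_position"]) if "estimated_position" in df.columns else None`.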

Change 2: vendor/ranking_moderation/scripts/find_positions.py

Find the process_single_item function (around line 267) and change it to:

```python
def process_single_item(item, bt_ranker, config):
    if "search_method" in config["ranking"]:
        search_method = config["ranking"]["search_method"]
    else:
        # Note: only binary_search / heuristic_search use start_pos,
        # so with this default the warm start has no effect.
        search_method = 'similarity_search'

    # ↓↓↓ New: take estimated_position from the item as the starting point ↓↓↓
    start_pos = item.get("estimated_position")
    if start_pos is not None:
        try:
            start_pos = int(start_pos)
        except (TypeError, ValueError):
            start_pos = None  # unparseable value: fall back to a cold start

    try:
        result = bt_ranker.find_sample_index(
            new_item=item,
            initial_comparisons=config["ranking"]["initial_comparisons"],
            num_rounds=config["ranking"]["num_rounds"],
            search_method=search_method,
            start_pos=start_pos,           # ← new argument
        )
        return {"success": True, "payload": result}
    except Exception as e:
        return {"success": False, "error": str(e)}
```
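To check the wiring without a real ranker, process_single_item can be called with a stand-in object that only records the keyword arguments it receives. RecordingRanker is hypothetical. The function below is a condensed copy of the patched version: it uses dict.get for the search_method default, which behaves the same as the if/else above.

```python
from typing import Any, Dict, List


class RecordingRanker:
    """Hypothetical stand-in for bt_ranker: records the kwargs it receives."""

    def __init__(self) -> None:
        self.calls: List[Dict[str, Any]] = []

    def find_sample_index(self, **kwargs: Any) -> Dict[str, Any]:
        self.calls.append(kwargs)
        return {"start_pos_seen": kwargs.get("start_pos")}


def process_single_item(item, bt_ranker, config):
    # Condensed copy of the patched function above.
    ranking = config["ranking"]
    search_method = ranking.get("search_method", "similarity_search")
    start_pos = item.get("estimated_position")
    if start_pos is not None:
        try:
            start_pos = int(start_pos)
        except (TypeError, ValueError):
            start_pos = None
    try:
        result = bt_ranker.find_sample_index(
            new_item=item,
            initial_comparisons=ranking["initial_comparisons"],
            num_rounds=ranking["num_rounds"],
            search_method=search_method,
            start_pos=start_pos,
        )
        return {"success": True, "payload": result}
    except Exception as e:
        return {"success": False, "error": str(e)}


config = {"ranking": {"initial_comparisons": 5, "num_rounds": 8,
                      "search_method": "binary_search"}}
ranker = RecordingRanker()
r_int = process_single_item({"estimated_position": "91"}, ranker, config)
r_bad = process_single_item({"estimated_position": "n/a"}, ranker, config)
r_missing = process_single_item({}, ranker, config)
```

Note that `int("91.0")` raises ValueError, so a string like that would silently become a cold start here. Change 1 writes a real int into the JSONL, so that case should not arise unless the field is filled in elsewhere.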

Change 3: true_skill_ranking.py, add a start_pos parameter to find_sample_index

Around lines 936-940:

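```python
def find_sample_index(self,
                     new_item: Dict,
                     initial_comparisons: int = 5,
                     num_rounds: int = 3,
                     search_method: str = 'binary_search',
                     start_pos: Optional[int] = None) -> Dict:    # ← new parameter (needs `from typing import Optional`)
    """...original docstring...

    New Args:
        start_pos: Starting position for the binary search (warm start). When None,
            falls back to the standard binary search (mid = midpoint).
    """
```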

Change the dispatcher around lines 973-982 to:

```python
    if search_method == 'heuristic_search':
        # Requires change 5: an unpatched _heuristic_search raises TypeError on start_pos.
        round_idx, round_candidates = self._heuristic_search(
            new_item, num_rounds, initial_comparisons, ruler_items, start_pos=start_pos)
    elif search_method == 'binary_search':
        # num_rounds is not passed: _binary_search loops until left > right.
        round_idx, round_candidates = self._binary_search(
            new_item, ruler_items, start_pos=start_pos)
    elif search_method == 'full_traversal':
        round_idx, round_candidates = self._full_traversal(new_item, ruler_items)
    elif search_method == 'similarity_search':
        round_idx, round_candidates = self._similarity_search(
            new_item, num_rounds, initial_comparisons, ruler_items)
    else:
        raise ValueError(f"Unknown search method: {search_method}")
```
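Change 5 is optional, so the heuristic_search branch raises TypeError if _heuristic_search has not been patched. process_single_item then turns that into `{"success": False}` for every item. One defensive option is sketched below. call_with_optional_start_pos is a hypothetical helper built on the standard `inspect` module: it forwards start_pos only when the target function declares a parameter with that name.

```python
import inspect
from typing import Any, Callable, Optional


def call_with_optional_start_pos(fn: Callable[..., Any], *args: Any,
                                 start_pos: Optional[int] = None,
                                 **kwargs: Any) -> Any:
    """Hypothetical helper: pass start_pos only if fn has a start_pos parameter."""
    if "start_pos" in inspect.signature(fn).parameters:
        kwargs["start_pos"] = start_pos
    return fn(*args, **kwargs)


# Toy stand-ins for a patched and an unpatched search method.
def patched_search(new_item, ruler_items, start_pos=None):
    return ("patched", start_pos)


def unpatched_search(new_item, ruler_items):
    return ("unpatched", None)
```

In the dispatcher this would read `call_with_optional_start_pos(self._heuristic_search, new_item, num_rounds, initial_comparisons, ruler_items, start_pos=start_pos)`. For a bound method, `inspect.signature` already leaves out `self`. The trade-off is that silently dropping start_pos can hide a missing patch, so logging a warning in that case may be worth it.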

Change 4: _binary_search uses start_pos as the first-round mid

Replace the whole _binary_search function (around lines 851-901):

```python
def _binary_search(self, new_item: RulerItem, ruler_items: List[RulerItem],
                   start_pos: Optional[int] = None) -> Tuple[int, List[Dict]]:
    """Binary-search the new item's position on the ruler.

    start_pos: position of the first-round mid (embedding warm start).
               When None, the standard midpoint is used.
    """
    left, right = 0, len(ruler_items) - 1
    round_candidates = []
    round_idx = 0
    while left <= right:
        logger.debug(f"Starting round {round_idx + 1}")
        round_info = {
            'round': round_idx + 1,
            'candidates': [],
            'sampling_method': 'binary_search'
        }

        # ↓↓↓ Key change: the first-round mid uses start_pos (clamped to [left, right]) ↓↓↓
        if round_idx == 0 and start_pos is not None:
            mid = max(left, min(right, int(start_pos)))
            round_info['sampling_method'] = 'binary_search_warm_start'
        else:
            mid = (left + right) // 2

        mid_item = ruler_items[mid]

        pair_key = self.pairwise_comparison.get_pair_key(new_item.item_id, mid_item.item_id)
        if self.pairwise_comparison.get_comparison_count(pair_key) == 0:
            candidates = [(mid, mid_item)]
            round_info['candidates'] = [{'id': mid_item.item_id, 'position': mid,
                                         'score': mid_item.score, 'rank': mid_item.rank}]
        else:
            # Pair already compared: skip it and reuse the current score estimate.
            candidates = []

        logger.debug(f"Round {round_idx + 1} sampling info: method={round_info['sampling_method']}")
        for i, candidate in enumerate(round_info['candidates']):
            logger.debug(f"  {i+1}. ID: {candidate['id']}, Position: {candidate['position']}, Score: {candidate['score']:.4f}")

        if candidates:
            pairs = [(new_item.item, candidate[1].item) for candidate in candidates]
            self.pairwise_comparison.compare_pairs(pairs)
            new_item.score, new_item.sigma = self.estimate_new_item_score(new_item, ruler_items)

        # Score <= mid's score: continue in the right half (higher indices).
        if new_item.score <= mid_item.score:
            left = mid + 1
        else:
            right = mid - 1

        round_idx += 1
        round_info['item_trueskill'] = {'mu': new_item.score, 'sigma': new_item.sigma}
        round_candidates.append(round_info)

    return round_idx, round_candidates
```
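To see what the warm start does to the search itself, the loop above can be replayed with a noise-free comparator. This is a sketch under stated assumptions: there is no TrueSkill update and no skipping of already-compared pairs. The ruler is assumed to be ordered so that "score <= mid's score" means the item belongs at a higher index, and true_pos is a hypothetical ground-truth insertion index.

```python
from typing import List, Optional, Tuple


def simulate_binary_search(true_pos: int, n: int,
                           start_pos: Optional[int] = None) -> Tuple[int, List[int]]:
    """Replay the patched _binary_search loop with a noise-free comparator.

    Assumption: "new_item.score <= mid_item.score" holds exactly when
    mid < true_pos (true_pos in 0..n). Returns (final left, probed mids).
    """
    left, right = 0, n - 1
    probes: List[int] = []
    round_idx = 0
    while left <= right:
        if round_idx == 0 and start_pos is not None:
            # Warm start: first mid is start_pos clamped into [left, right].
            mid = max(left, min(right, int(start_pos)))
        else:
            # Every later round is a plain bisection of what is left.
            mid = (left + right) // 2
        probes.append(mid)
        if mid < true_pos:  # stands in for new_item.score <= mid_item.score
            left = mid + 1
        else:
            right = mid - 1
        round_idx += 1
    return left, probes


cold_pos, cold_probes = simulate_binary_search(true_pos=93, n=200)
warm_pos, warm_probes = simulate_binary_search(true_pos=93, n=200, start_pos=91)
print(cold_pos, len(cold_probes), cold_probes)
print(warm_pos, len(warm_probes), warm_probes)
```

With these inputs both runs end at position 93. The warm start makes 8 probes ([91, 145, 118, 104, 97, 94, 92, 93]), while the cold start makes 7 ([99, 49, 74, 86, 92, 95, 93]). After the first probe, the remaining interval is bisected as usual, so a close start_pos does not by itself shorten the search in this noise-free replay.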

Change 5 (optional): _heuristic_search also supports start_pos

If your pipeline.yaml sets search_method: heuristic_search, you need to change this as well:

```python
def _heuristic_search(self, new_item, num_rounds, initial_comparisons, ruler_items,
                      start_pos: Optional[int] = None):
    round_candidates = []

    for round_idx in range(num_rounds):
        round_info = {'round': round_idx + 1, 'candidates': [], 'sampling_method': None}

        if round_idx == 0:
            # ↓↓↓ Key change: with a start_pos, round 0 uses it as the only candidate ↓↓↓
            # (the ruler_items check avoids an IndexError on an empty ruler)
            if start_pos is not None and ruler_items:
                idx = max(0, min(len(ruler_items) - 1, int(start_pos)))
                temp_item = ruler_items[idx]
                pair_key = self.pairwise_comparison.get_pair_key(new_item.item_id, temp_item.item_id)
                if self.pairwise_comparison.get_comparison_count(pair_key) == 0:
                    candidates = [(idx, temp_item)]
                else:
                    candidates = []
                round_info['sampling_method'] = 'heuristic_warm_start'
                # List only what is actually compared (as _binary_search does).
                round_info['candidates'] = [{'id': c[1].item_id, 'position': c[0],
                                             'score': c[1].score, 'rank': c[1].rank} for c in candidates]
            else:
                # The original uniform-segment sampling is unchanged
                num_comparisons = initial_comparisons
                segment_size = max(1, len(ruler_items) // num_comparisons)
                candidates = []
                for i in range(0, len(ruler_items), segment_size):
                    segment = ruler_items[i:i + segment_size]
                    if segment:
                        for temp_item in segment:
                            pair_key = self.pairwise_comparison.get_pair_key(new_item.item_id, temp_item.item_id)
                            if self.pairwise_comparison.get_comparison_count(pair_key) == 0:
                                # Records the segment start i, not temp_item's own index (original behavior).
                                candidates.append((i, temp_item))
                                break
                round_info['sampling_method'] = 'uniform_segments'
                round_info['candidates'] = [{'id': c[1].item_id, 'position': c[0],
                                             'score': c[1].score, 'rank': c[1].rank} for c in candidates]
        else:
            # Rounds 1+ keep the original score-based logic (unchanged)
            score_diffs = []
            for i, item in enumerate(ruler_items):
                pair_key = self.pairwise_comparison.get_pair_key(new_item.item_id, item.item_id)
                if self.pairwise_comparison.get_comparison_count(pair_key) > 0:
                    continue
                score_diff = abs(item.score - new_item.score)
                score_diffs.append((i, item, score_diff))
            score_diffs.sort(key=lambda x: x[2])
            candidates = [(i, item) for i, item, _ in score_diffs[:initial_comparisons]]
            if not candidates:
                logger.info("No more pairs to compare, proceeding with final fitting")
                break
            round_info['sampling_method'] = 'score_based'
            round_info['candidates'] = [{'id': c[1].item_id, 'position': c[0],
                                         'score': c[1].score, 'rank': c[1].rank} for c in candidates]

        # (The original compare_pairs / TrueSkill update logic that follows stays unchanged)
        # ...your original code, omitted here...
        round_candidates.append(round_info)

    return num_rounds, round_candidates
```
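Round 0 is the only place where change 5 alters behavior. A pure-function sketch makes the two branches easy to compare side by side. round0_candidates is hypothetical, and the `compared` set of positions stands in for `get_comparison_count(pair_key) > 0`. Unlike the original, it returns the actual index of each picked item rather than the segment start.

```python
from typing import List, Optional, Set, Tuple


def round0_candidates(n: int, initial_comparisons: int, compared: Set[int],
                      start_pos: Optional[int] = None) -> Tuple[str, List[int]]:
    """Sketch of which ruler positions _heuristic_search compares in round 0."""
    if start_pos is not None and n > 0:
        # Warm start: a single candidate at the clamped start position.
        idx = max(0, min(n - 1, int(start_pos)))
        return "heuristic_warm_start", ([] if idx in compared else [idx])
    # Cold start: first not-yet-compared position in each segment.
    segment_size = max(1, n // initial_comparisons)
    picks: List[int] = []
    for i in range(0, n, segment_size):
        for j in range(i, min(i + segment_size, n)):
            if j not in compared:
                picks.append(j)
                break
    return "uniform_segments", picks
```

With n=200 and initial_comparisons=5, the cold branch compares 5 items (positions 0, 40, 80, 120 and 160), while the warm branch compares 1. With n=10 and initial_comparisons=3, the uniform branch yields 4 candidates (0, 3, 6 and 9), because `segment_size = n // initial_comparisons` rounds down.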

Verification steps

After making the changes:

```bash
# 1) Re-run prepare (make sure the jsonl contains the estimated_position field)
python pipeline/prepare_local_eval_data.py \
  --input_csv /mnt/.../aipf_golden_set.csv \
  --output_jsonl /tmp/test.jsonl

# 2) Check that the field is present in the jsonl (first record only)
head -1 /tmp/test.jsonl | python -m json.tool | grep estimated_position

# Expected output (last key of the record, so no trailing comma):
#   "estimated_position": 91

# 3) Run the full pipeline
bash adhoc_run.sh

# 4) Compare metrics (same num_rounds=8, warm start vs cold start)
diff <(jq . cold_start_metrics.json) <(jq . warm_start_metrics.json)
```
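Step 2 only inspects the first record. Records whose estimated_position is missing or null quietly fall back to a cold start, so it can help to count coverage over the whole file. Below is a minimal sketch using only the standard library; estimated_position_coverage is a hypothetical name.

```python
import json
from typing import Dict


def estimated_position_coverage(path: str) -> Dict[str, int]:
    """Count JSONL records with and without an integer estimated_position."""
    stats = {"total": 0, "with_position": 0, "missing_or_null": 0}
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # skip blank lines
            record = json.loads(line)
            stats["total"] += 1
            if isinstance(record.get("estimated_position"), int):
                stats["with_position"] += 1
            else:
                stats["missing_or_null"] += 1  # these records run as cold starts
    return stats
```

Usage: `print(estimated_position_coverage("/tmp/test.jsonl"))` after step 1.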

Expected effect

| Configuration | num_rounds=4 | num_rounds=8 |
|---|---|---|
| Cold start (from the midpoint) | F1=0.81 | F1=0.82 |
| Warm start (embedding) | F1=??? | F1=??? |

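In theory, warm start could match the original 8-round result (or do better) in 4 rounds, because embedding_position already narrows the search to a fairly accurate neighborhood, so the binary search would not need to converge globally. There is one caveat with the code as written. The dispatcher does not pass num_rounds to _binary_search, which loops until left > right, and only the first mid comes from start_pos; every later round bisects whatever interval remains. So for search_method: binary_search, the warm start changes which ruler items get compared (and therefore the TrueSkill estimate), but it neither narrows the search around start_pos nor caps the number of rounds. num_rounds only reaches heuristic_search and similarity_search.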

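If the embedding position is very accurate (mean absolute error < 30), 4 rounds should be enough. If it is inaccurate (error > 60), warm start may actually do worse, because a wrong starting point pulls the search off course.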
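To tell which regime a run is in, the mean absolute error has to be measured against some reference position. The sketch below takes a hypothetical list of (estimated_position, reference_position) pairs. The reference could be, for example, the position a cold-start run converged to; that choice is an assumption, not something the pipeline defines. The 30/60 thresholds are the ones quoted above.

```python
from typing import Iterable, Optional, Tuple


def mean_abs_position_error(
        pairs: Iterable[Tuple[Optional[int], Optional[int]]]) -> Optional[float]:
    """Mean |estimated - reference| over pairs where both values exist."""
    errors = [abs(est - ref) for est, ref in pairs if est is not None and ref is not None]
    if not errors:
        return None
    return sum(errors) / len(errors)


def warm_start_regime(mae: Optional[float]) -> str:
    # Thresholds taken from the text: < 30 accurate, > 60 inaccurate.
    if mae is None:
        return "unknown"
    if mae < 30:
        return "accurate"
    if mae > 60:
        return "inaccurate"
    return "in between"
```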

Risks

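  1. Cases where the embedding estimate is wrong: if the embedding places a class-A sample at 89 (< 106), warm start begins at 89 and tends to converge near it → higher false-positive risk. This is a trade-off.
  2. The _heuristic_search round 0 change: the original version starts with initial_comparisons candidates in parallel, while warm start has only 1 candidate → less information for early convergence, so later rounds have to make up for it.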