# AIPF Warm-Start Patch — 用 embedding_position 当 binary search 起点

## 改动总览

| 文件 | 改什么 | 为什么 |
|---|---|---|
| 1. `aipf_golden_set.csv` | 已加 `estimated_position` 列 | 数据源 |
| 2. `pipeline/prepare_local_eval_data.py` | 把 csv 的 `estimated_position` 复制进 jsonl | 把数据带进 pipeline |
| 3. `vendor/ranking_moderation/scripts/find_positions.py` | 从 item 取 `estimated_position`,传给 `find_sample_index` | 传参 |
| 4. `vendor/ranking_moderation/src/ranking_moderation/true_skill_ranking.py` | `find_sample_index` 和 `_binary_search` 加 `start_pos` 参数;第一轮 mid 用 `start_pos` 而不是中点 | 真正起作用的地方 |
数据流向:

```text
csv (estimated_position)
  ↓ prepare_local_eval_data.py
jsonl (estimated_position 字段)
  ↓ find_positions.py (process_single_item)
bt_ranker.find_sample_index(start_pos=...)
  ↓
_binary_search(start_pos=embedding_position)   # 第一轮 mid = start_pos,而不是 (0+199)//2(以 ruler 有 200 个 item 为例)
```

## 改动 1 — pipeline/prepare_local_eval_data.py

在 `main()` 里构建 `records` 的那段代码(约第 108 行附近),给字典加一个字段:
records.append(
    {
        "report_id": report_id,
        "text": text,
        "label": label,
        "conversation_id": report_id,
        "conv_text": conv_text,
        "store_region": "",
        "alias2age_map": alias2age_map,
        "uid2alias_map": {},
        "msg_metadata": [],
        "msg_dict": {},
        # New: embedding-based warm-start position for the ruler search.
        # Parse through float() so values such as "91.0" (a float column
        # written back to CSV) are accepted. Blank / NaN / None / pandas
        # <NA> mean "no estimate" and become None (cold start downstream).
        # Any other non-numeric value still raises, so bad data fails fast.
        "estimated_position": (
            int(float(row["estimated_position"]))
            if "estimated_position" in df.columns
            and str(row["estimated_position"]).strip().lower()
            not in ("", "nan", "none", "null", "<na>")
            else None
        ),
    }
)
## 改动 2 — vendor/ranking_moderation/scripts/find_positions.py

找到 `process_single_item` 函数(约第 267 行),改成:
def _parse_start_pos(value):
    """Coerce an item's ``estimated_position`` into an int warm-start index.

    Accepts ints, floats and numeric strings (including "91.0"). Returns
    None for a missing or unparsable value, including NaN and +/-inf. A bad
    estimate therefore degrades to a cold start instead of crashing.
    """
    if value is None:
        return None
    try:
        return int(float(value))
    except (TypeError, ValueError, OverflowError):
        return None


def process_single_item(item, bt_ranker, config):
    """Locate ``item`` on the ruler and wrap the outcome in a status dict.

    The optional ``item["estimated_position"]`` is forwarded to
    ``bt_ranker.find_sample_index`` as ``start_pos``. ``search_method``
    comes from ``config["ranking"]`` and defaults to 'similarity_search'.

    Returns:
        ``{"success": True, "payload": result}`` on success, or
        ``{"success": False, "error": message}`` if the ranker raises.
    """
    ranking_cfg = config["ranking"]
    search_method = ranking_cfg.get("search_method", 'similarity_search')
    start_pos = _parse_start_pos(item.get("estimated_position"))
    try:
        result = bt_ranker.find_sample_index(
            new_item=item,
            initial_comparisons=ranking_cfg["initial_comparisons"],
            num_rounds=ranking_cfg["num_rounds"],
            search_method=search_method,
            start_pos=start_pos,  # new: warm-start position (None = cold start)
        )
        return {"success": True, "payload": result}
    except Exception as e:
        # Per-item boundary: one failing item must not abort the whole batch.
        return {"success": False, "error": str(e)}
## 改动 3 — true_skill_ranking.py:`find_sample_index` 加 `start_pos` 参数

约第 936-940 行,函数签名改成:
def find_sample_index(self,
                      new_item: Dict,
                      initial_comparisons: int = 5,
                      num_rounds: int = 3,
                      search_method: str = 'binary_search',
                      start_pos: int = None) -> Dict:  # <- new parameter
    """...original docstring (keep it unchanged)...

    New Args:
        start_pos: Optional warm-start index into the ruler, e.g. the
            item's embedding-based estimated position; may be None.
            Only the 'binary_search' and 'heuristic_search' search
            methods receive it; 'similarity_search' and 'full_traversal'
            ignore it. None means a cold start (for binary search: the
            standard midpoint as the first pivot).
    """
约第 973-982 行的 dispatcher 改成:
# start_pos is only consumed by the binary and heuristic searches. Say so
# instead of silently running a cold start (similarity_search is the default
# in find_positions.py, so this misconfiguration is easy to hit).
if start_pos is not None and search_method in ('similarity_search', 'full_traversal'):
    logger.warning("start_pos=%s is ignored by search_method=%r", start_pos, search_method)
if search_method == 'heuristic_search':
    # Forward start_pos only when it is set. An _heuristic_search that was
    # not updated with the start_pos parameter then keeps working for
    # cold-start items instead of raising TypeError on every call.
    heuristic_kwargs = {} if start_pos is None else {'start_pos': start_pos}
    round_idx, round_candidates = self._heuristic_search(
        new_item, num_rounds, initial_comparisons, ruler_items, **heuristic_kwargs)
elif search_method == 'binary_search':
    round_idx, round_candidates = self._binary_search(
        new_item, ruler_items, start_pos=start_pos)
elif search_method == 'full_traversal':
    round_idx, round_candidates = self._full_traversal(new_item, ruler_items)
elif search_method == 'similarity_search':
    round_idx, round_candidates = self._similarity_search(
        new_item, num_rounds, initial_comparisons, ruler_items)
else:
    raise ValueError(f"Unknown search method: {search_method}")
## 改动 4 — `_binary_search` 用 `start_pos` 当首轮 mid

替换整个 `_binary_search` 函数(约第 851-901 行):
def _binary_search(self, new_item: RulerItem, ruler_items: List[RulerItem],
                   start_pos: int = None) -> Tuple[int, List[Dict]]:
    """Binary-search the ruler for new_item's position.

    Each round picks a pivot and compares new_item against it, unless that
    pair was already compared. It then re-estimates new_item's TrueSkill
    score and keeps the half of [left, right] the score points to: if
    new_item.score <= pivot score, the search continues at higher indices;
    otherwise it continues at lower indices.

    Args:
        new_item: item being placed; its ``score``/``sigma`` are updated
            in place after each comparison.
        ruler_items: ruler to search. Presumably ordered by descending
            score (implied by the direction rule above) -- confirm.
        start_pos: optional round-1 pivot (e.g. the embedding-estimated
            position), clamped into [0, len(ruler_items) - 1]. None uses
            the standard midpoint. Later rounds always bisect what remains,
            so the warm start only moves the first pivot. The remaining
            rounds depend on which side of start_pos the item falls: fewer
            when it lands on the short side of an off-centre pivot, at most
            one more than a cold start when it lands on the long side.

    Returns:
        (number of rounds run, list of per-round sampling-info dicts).
    """
    left, right = 0, len(ruler_items) - 1
    round_candidates = []
    round_idx = 0
    while left <= right:
        logger.debug("Starting round %d", round_idx + 1)
        round_info = {
            'round': round_idx + 1,
            'candidates': [],
            'sampling_method': 'binary_search'
        }
        if round_idx == 0 and start_pos is not None:
            # Warm start: the first pivot is start_pos, clamped into [left, right].
            mid = max(left, min(right, int(start_pos)))
            round_info['sampling_method'] = 'binary_search_warm_start'
        else:
            mid = (left + right) // 2
        mid_item = ruler_items[mid]

        pair_key = self.pairwise_comparison.get_pair_key(new_item.item_id, mid_item.item_id)
        needs_comparison = self.pairwise_comparison.get_comparison_count(pair_key) == 0
        if needs_comparison:
            round_info['candidates'] = [{'id': mid_item.item_id, 'position': mid,
                                         'score': mid_item.score, 'rank': mid_item.rank}]

        logger.debug("Round %d sampling info: method=%s",
                     round_idx + 1, round_info['sampling_method'])
        for i, candidate in enumerate(round_info['candidates']):
            logger.debug(" %d. ID: %s, Position: %s, Score: %.4f",
                         i + 1, candidate['id'], candidate['position'], candidate['score'])

        if needs_comparison:
            self.pairwise_comparison.compare_pairs([(new_item.item, mid_item.item)])
            new_item.score, new_item.sigma = self.estimate_new_item_score(new_item, ruler_items)

        # Narrow the interval every round, even when this pair was compared
        # earlier and no new comparison ran. Otherwise left/right never
        # change and the loop cannot terminate.
        if new_item.score <= mid_item.score:
            left = mid + 1
        else:
            right = mid - 1

        round_idx += 1
        round_info['item_trueskill'] = {'mu': new_item.score, 'sigma': new_item.sigma}
        round_candidates.append(round_info)
    return round_idx, round_candidates
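## 改动 5 — `_heuristic_search` 也支持 `start_pos`(使用 heuristic_search 时必做)

如果你的 pipeline.yaml 里配置了 `search_method: heuristic_search`,这一步就不是可选的。改动 3 的 dispatcher 会以关键字参数 `start_pos` 调用 `_heuristic_search`,至少带 `estimated_position` 的 item 会这样调用。没改签名的旧函数会抛 `TypeError: unexpected keyword argument 'start_pos'`,而 `process_single_item` 会把这个异常吞成 `{"success": False}`,最终表现为大量 item 静默失败。另外,下面只是骨架:每轮的 `compare_pairs` / TrueSkill 更新逻辑被省略了。请把改动合并进原函数,不要整段覆盖。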
def _heuristic_search(self, new_item, num_rounds, initial_comparisons, ruler_items,
                      start_pos: int = None):
    """Choose comparison candidates for up to ``num_rounds`` rounds.

    Round 1 uses a single warm-start candidate at ``start_pos`` (clamped
    into the ruler). It falls back to the original uniform-segment sampling
    when start_pos is None, the ruler is empty, or that pair was already
    compared. Later rounds pick the ``initial_comparisons`` not-yet-compared
    ruler items whose score is closest to ``new_item.score``, and stop
    early when none are left.

    NOTE: this is a skeleton of the patch. The original per-round
    compare_pairs / TrueSkill update must be merged in where marked below.

    Returns:
        (num_rounds, list of per-round sampling-info dicts). num_rounds is
        returned even when the loop stops early, as in the original.
    """
    round_candidates = []
    for round_idx in range(num_rounds):
        round_info = {'round': round_idx + 1, 'candidates': [], 'sampling_method': None}
        if round_idx == 0:
            candidates = []
            # Warm start: a single candidate at the clamped start_pos.
            if start_pos is not None and ruler_items:
                idx = max(0, min(len(ruler_items) - 1, int(start_pos)))
                temp_item = ruler_items[idx]
                pair_key = self.pairwise_comparison.get_pair_key(new_item.item_id, temp_item.item_id)
                if self.pairwise_comparison.get_comparison_count(pair_key) == 0:
                    candidates = [(idx, temp_item)]
                    round_info['sampling_method'] = 'heuristic_warm_start'
            if not candidates:
                # Original uniform segment sampling (unchanged): the first
                # not-yet-compared item of each segment.
                # NOTE(review): the recorded position is the segment start i,
                # not the chosen item's own index, when earlier items in the
                # segment were already compared -- kept as in the original.
                segment_size = max(1, len(ruler_items) // initial_comparisons)
                for i in range(0, len(ruler_items), segment_size):
                    for temp_item in ruler_items[i:i + segment_size]:
                        pair_key = self.pairwise_comparison.get_pair_key(new_item.item_id, temp_item.item_id)
                        if self.pairwise_comparison.get_comparison_count(pair_key) == 0:
                            candidates.append((i, temp_item))
                            break
                round_info['sampling_method'] = 'uniform_segments'
        else:
            # Rounds 2+: original score-based logic (unchanged).
            score_diffs = []
            for i, item in enumerate(ruler_items):
                pair_key = self.pairwise_comparison.get_pair_key(new_item.item_id, item.item_id)
                if self.pairwise_comparison.get_comparison_count(pair_key) > 0:
                    continue
                score_diffs.append((i, item, abs(item.score - new_item.score)))
            score_diffs.sort(key=lambda x: x[2])
            candidates = [(i, item) for i, item, _ in score_diffs[:initial_comparisons]]
            if not candidates:
                logger.info("No more pairs to compare, proceeding with final fitting")
                break
            round_info['sampling_method'] = 'score_based'
        # Record only candidates that will actually be compared this round.
        round_info['candidates'] = [{'id': c[1].item_id, 'position': c[0],
                                     'score': c[1].score, 'rank': c[1].rank} for c in candidates]
        # (The original compare_pairs / TrueSkill update logic follows here, unchanged.)
        # ...your original code, omitted here...
        round_candidates.append(round_info)
    return num_rounds, round_candidates
## 验证步骤

改完之后:
# 1) Re-run prepare (make sure the jsonl now carries estimated_position)
python pipeline/prepare_local_eval_data.py \
    --input_csv /mnt/.../aipf_golden_set.csv \
    --output_jsonl /tmp/test.jsonl
# 2) Check that the field is present in the jsonl
head -1 /tmp/test.jsonl | python -m json.tool | grep estimated_position
# Expected output, e.g.:
#     "estimated_position": 91,
# 3) Run the full pipeline. Make sure pipeline.yaml sets search_method to
#    binary_search or heuristic_search: when it is unset, find_positions.py
#    falls back to similarity_search, which ignores start_pos.
bash adhoc_run.sh
# 4) Compare metrics for warm start vs cold start with an otherwise identical
#    config (same search_method / num_rounds). Save the metrics JSON from a
#    run without estimated_position (or on the pre-patch code) as
#    cold_start_metrics.json, and the one from step 3 as
#    warm_start_metrics.json (the path adhoc_run.sh writes to depends on
#    your setup). jq -S sorts keys so only real value differences show up.
diff <(jq -S . cold_start_metrics.json) <(jq -S . warm_start_metrics.json)
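## 预期效果

| 配置 | num_rounds=4 | num_rounds=8 |
|---|---|---|
| 冷启动(中点开始) | F1=0.81 | F1=0.82 |
| 热启动(embedding) | F1=??? | F1=??? |

对 heuristic_search 而言,理论上热启动 4 轮就能达到原来 8 轮的效果(甚至更好),因为 embedding_position 已经把搜索起点放到了一个比较准的位置附近,不需要从全局均匀采样开始收敛。

如果嵌入位置非常准(mean abs error < 30),4 轮足够;如果不准(error > 60),warm start 反而可能更差(被错误起点带偏)。

注意:`num_rounds` 只对 `heuristic_search` / `similarity_search` 生效。`_binary_search` 不接收 `num_rounds`,会一直二分直到区间为空;N=200 时冷启动最多 8 轮(⌈log₂(N+1)⌉)。warm start 只替换第一轮的 pivot,之后仍在 start_pos 某一侧的完整剩余区间上二分。因此:start_pos 靠近中间时,轮数与冷启动基本相同;start_pos 偏向一端时,真实位置落在短的一侧会少几轮,落在长的一侧最多比冷启动多 1 轮。对 binary_search 不能预期"4 轮顶 8 轮",收益需要用上面的 metrics 对比实测。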
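## 风险点

- embedding 估错的 case:A 类样本被 embedding 估到 89(< 106),warm start 会从 89 开始,反而收敛到附近 → 假阳性风险增加。这是 trade-off。对 binary_search 来说,错误起点主要是浪费第一轮比较,之后仍在完整剩余区间上二分;"被带偏到起点附近"主要发生在 heuristic_search,因为它后续轮次是按当前分数就近选候选的。
- `_heuristic_search` round 0 改动:原版用 `initial_comparisons` 个候选并行启动,warm start 只有 1 个候选 → 早期收敛信息少,后续轮次得跟上。
- `search_method` 未配置时,`process_single_item` 默认用 `similarity_search`。而 dispatcher 不会把 `start_pos` 传给 `_similarity_search` / `_full_traversal`,所以热启动会静默不生效。跑对比实验前,确认 pipeline.yaml 里 `search_method` 是 `binary_search` 或 `heuristic_search`。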