File size: 1,572 Bytes
7cc8bc0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from typing import List, Dict
from FlagEmbedding import FlagReranker
import logging
import torch
import os
from sentence_transformers import CrossEncoder

class Reranker:
    """Cross-encoder reranker backed by FlagEmbedding's ``FlagReranker``.

    Scores (query, passage) pairs with the loaded model and reorders the
    passages by descending score.
    """

    def __init__(self, model_path: str = "BAAI/bge-reranker-large"):
        """Load the reranking model onto GPU if available, else CPU.

        Args:
            model_path: Model id or local path passed to ``FlagReranker``.

        Raises:
            Exception: any error raised while loading the model is logged
                and re-raised unchanged.
        """
        use_cuda = torch.cuda.is_available()
        device = "cuda" if use_cuda else "cpu"
        try:
            self.model = FlagReranker(
                model_path,
                # Half precision only pays off (and is only reliably
                # supported) on GPU; on CPU fp16 kernels may be missing/slow.
                use_fp16=use_cuda,
                device=device,
            )
            logging.info(f"成功加载重排序模型 {model_path} 到 {device} 设备")
        except Exception as e:
            logging.error(f"加载重排序模型失败: {str(e)}")
            raise

    def rerank(self, query: str, passages: List[Dict]) -> List[Dict]:
        """Rerank passages by relevance to ``query``.

        Each dict in ``passages`` must carry its text under the ``'passage'``
        key. On success every dict is annotated in place with a float
        ``'rerank_score'`` and a new list sorted by that score (highest
        first) is returned. If scoring fails for any reason, the error is
        logged and ``passages`` is returned unchanged, in its original order.

        Args:
            query: The query text.
            passages: Candidate passages; mutated in place on success.

        Returns:
            The passages sorted by descending ``'rerank_score'``, or the
            original list object if it was empty or reranking failed.
        """
        # Nothing to score; avoid calling the model with an empty batch.
        if not passages:
            return passages

        try:
            pairs = [[query, p['passage']] for p in passages]
            raw_scores = self.model.compute_score(pairs)

            # compute_score may return a bare scalar (not a one-element list)
            # when given a single pair; normalize so a lone passage still
            # gets scored instead of failing in zip().
            try:
                raw_scores = list(raw_scores)
            except TypeError:
                raw_scores = [raw_scores]

            # zip() would silently drop passages on a count mismatch.
            if len(raw_scores) != len(passages):
                raise ValueError(
                    f"expected {len(passages)} scores, got {len(raw_scores)}"
                )

            # Convert everything before touching the inputs so a failure
            # cannot leave only some dicts annotated.
            scores = [float(s) for s in raw_scores]
        except Exception as e:
            logging.exception(f"重排序过程中出错: {str(e)}")
            # Fall back to the original ordering.
            return passages

        for passage, score in zip(passages, scores):
            passage['rerank_score'] = score

        return sorted(passages, key=lambda p: p['rerank_score'], reverse=True)