Spaces:
Running
Running
File size: 1,572 Bytes
7cc8bc0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 |
from typing import List, Dict
from FlagEmbedding import FlagReranker
import logging
import torch
import os
from sentence_transformers import CrossEncoder
class Reranker:
    """Rerank retrieved passages against a query with a ``FlagReranker`` model.

    The wrapped model is stored on ``self.model`` and is only required to
    expose ``compute_score(pairs)``, where ``pairs`` is a list of
    ``[query, text]`` lists.
    """

    def __init__(self, model_path: str = "BAAI/bge-reranker-large") -> None:
        """Load the reranker model onto CUDA when available, otherwise CPU.

        Args:
            model_path: Model name or path handed to ``FlagReranker``.

        Raises:
            Exception: Any error raised while loading the model is logged and
                re-raised unchanged.
        """
        try:
            # Decide the device once so the model and the log line always agree.
            device = "cuda" if torch.cuda.is_available() else "cpu"
            self.model = FlagReranker(
                model_path,
                use_fp16=True,
                device=device,
            )
            logging.info(f"成功加载重排序模型 {model_path} 到 {device} 设备")
        except Exception as e:
            logging.error(f"加载重排序模型失败: {str(e)}")
            raise

    def rerank(self, query: str, passages: List[Dict]) -> List[Dict]:
        """Score each passage against ``query`` and sort best-first.

        Each dict in ``passages`` must carry its text under the ``'passage'``
        key. On success every dict gains a float ``'rerank_score'`` entry
        (the input dicts are mutated in place) and a new list sorted by that
        score in descending order is returned.

        Args:
            query: The query text.
            passages: Candidate passages to rerank.

        Returns:
            The passages sorted by ``'rerank_score'`` (highest first). An empty
            list when there is nothing to rerank. If scoring fails for any
            reason the error is logged and ``passages`` is returned in its
            original order (best-effort fallback).
        """
        # Nothing to score: skip the model call entirely.
        if not passages:
            return []

        try:
            # Build the [query, text] pairs the model scores.
            pairs = [[query, p['passage']] for p in passages]
            scores = self.model.compute_score(pairs)

            # The model may hand back a bare scalar instead of a list when it
            # is given exactly one pair; normalise so zip() below works.
            if not hasattr(scores, "__len__"):
                scores = [scores]
            if len(scores) != len(passages):
                raise ValueError(
                    f"expected {len(passages)} scores, got {len(scores)}"
                )

            # Convert everything up front so a bad score cannot leave the
            # input dicts partially annotated.
            numeric_scores = [float(s) for s in scores]
            for passage, score in zip(passages, numeric_scores):
                passage['rerank_score'] = score

            return sorted(passages, key=lambda x: x['rerank_score'], reverse=True)
        except Exception as e:
            logging.error(f"重排序过程中出错: {str(e)}")
            # Reranking failed: fall back to the original ordering.
            return passages