Spaces:
Sleeping
Sleeping
from __future__ import annotations | |
import logging | |
from typing import List, Optional | |
from .base import Context, Retriever | |
from .bm25 import BM25Retriever | |
from .dense import DenseRetriever | |
# Logger for this module, named after the module's import path.
logger = logging.getLogger(name=__name__)
class HybridRetriever(Retriever):
    """Combine BM25 and Dense retrievers by normalising and summing scores.

    Each sub-retriever is queried for its own ``top_k`` hits. The scores on
    each side are min-max normalised to ``[0, 1]`` independently, because BM25
    scores and dense similarity scores are typically on different scales. A
    raw weighted sum would let one side dominate regardless of ``alpha``.

    The fused score is::

        alpha * norm_bm25 + (1 - alpha) * norm_dense

    A document missing from one side contributes 0 for that side. This is the
    same value as the lowest-scoring hit on that side after normalisation.

    Args:
        bm25_idx: Forwarded to ``BM25Retriever`` as its index argument.
        faiss_index: Forwarded to ``DenseRetriever(faiss_index=...)``.
        doc_store: Forwarded to both sub-retrievers as ``doc_store``.
        alpha: Weight of the BM25 side, in ``[0, 1]``. ``1`` means BM25 only
            and ``0`` means dense only.
        model_name: Forwarded to ``DenseRetriever``.
        embedder_cache: Forwarded to ``DenseRetriever``.
        device: Forwarded to ``DenseRetriever``.

    Raises:
        ValueError: If ``alpha`` is outside ``[0, 1]`` (or is NaN).
    """

    def __init__(
        self,
        bm25_idx: str,
        faiss_index: str,
        doc_store: str,
        *,
        alpha: float = 0.5,
        model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
        embedder_cache: Optional[str] = None,
        device: str = "cpu",
    ):
        # Validate the cheap argument before loading any index or model, so a
        # bad alpha fails fast instead of after expensive construction work.
        if not 0 <= alpha <= 1:
            raise ValueError("alpha must be in [0, 1]")
        self.alpha = alpha

        # 1) Sparse (BM25) retriever
        self.bm25 = BM25Retriever(bm25_idx, doc_store=doc_store)
        # 2) Dense retriever
        self.dense = DenseRetriever(
            faiss_index=faiss_index,
            doc_store=doc_store,
            model_name=model_name,
            embedder_cache=embedder_cache,
            device=device,
        )

    @staticmethod
    def _normalise(scores: dict) -> dict:
        """Min-max scale the values of ``scores`` to ``[0, 1]``.

        When all values are equal, including the single-hit case, the range
        is zero. Every entry then maps to 1.0, since they are all tied for
        best on that side.
        """
        if not scores:
            return {}
        lo = min(scores.values())
        hi = max(scores.values())
        if hi == lo:
            return {doc_id: 1.0 for doc_id in scores}
        span = hi - lo
        return {doc_id: (s - lo) / span for doc_id, s in scores.items()}

    def retrieve(self, query: str, *, top_k: int = 5) -> List[Context]:
        """Return up to ``top_k`` contexts ranked by the fused hybrid score.

        Args:
            query: Query string forwarded to both sub-retrievers.
            top_k: Number of hits requested from each sub-retriever, and the
                maximum number of fused results returned. A non-positive
                value returns an empty list without querying.

        Returns:
            New ``Context`` objects sorted by descending fused score. Ties
            keep first-seen order: BM25 hits first, then dense-only hits.
        """
        if top_k <= 0:
            return []

        # 1) Candidate hits from each side
        sparse_hits = self.bm25.retrieve(query, top_k=top_k)
        dense_hits = self.dense.retrieve(query, top_k=top_k)

        # 2) Put both score lists on a common [0, 1] scale before weighting.
        sparse_scores = self._normalise({c.id: c.score for c in sparse_hits})
        dense_scores = self._normalise({c.id: c.score for c in dense_hits})

        # 3) Union of ids in deterministic first-seen order (dicts preserve
        #    insertion order). BM25 text is preferred when both sides
        #    return the same id.
        texts = {c.id: c.text for c in sparse_hits}
        for c in dense_hits:
            texts.setdefault(c.id, c.text)

        alpha = self.alpha
        merged: List[Context] = [
            Context(
                id=doc_id,
                text=text,
                score=alpha * sparse_scores.get(doc_id, 0.0)
                + (1 - alpha) * dense_scores.get(doc_id, 0.0),
            )
            for doc_id, text in texts.items()
        ]

        # 4) Sort by score, descending. list.sort is stable even with
        #    reverse=True, so ties keep the insertion order built above.
        merged.sort(key=lambda c: c.score, reverse=True)
        logger.debug(
            "Hybrid retrieval: %d sparse, %d dense, %d fused candidates",
            len(sparse_hits),
            len(dense_hits),
            len(merged),
        )
        return merged[:top_k]