"""High‑level RAG pipeline orchestration.""" from __future__ import annotations import logging from typing import Dict, Any, List from .config import PipelineConfig from .retrievers import bm25, dense, hybrid from .generators.hf_generator import HFGenerator from .retrievers.base import Retriever, Context from .rerankers.cross_encoder import CrossEncoderReranker logger = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO) class RAGPipeline: """Run retrieval → generation → scoring in a single object.""" def __init__(self, cfg: PipelineConfig): self.cfg = cfg self.retriever: Retriever = self._build_retriever(cfg) self.generator = HFGenerator( model_name=cfg.generator.model_name, device=cfg.generator.device ) if cfg.reranker.enable: self.reranker = CrossEncoderReranker( cfg.reranker.model_name, device=cfg.reranker.device, max_length=cfg.reranker.max_length, ) else: self.reranker = None # --------------------------------------------------------------------- # Public API # --------------------------------------------------------------------- def run(self, question: str) -> Dict[str, Any]: logger.info("Question: %s", question) # 1. raw retrieval k_first = self.cfg.reranker.first_stage_k if self.reranker else self.cfg.retriever.top_k initial: List[Context] = self.retriever.retrieve(question, top_k=k_first) raw_hits = [ {"text": c.text, "id": c.id, "score": getattr(c, "retrieval_score", None)} for c in initial ] # 2. reranking (if enabled) if self.reranker: final_k = self.cfg.reranker.final_k or self.cfg.retriever.top_k reranked: List[Context] = self.reranker.rerank(question, initial, k=final_k) reranked_hits = [ { "text": c.text, "id": c.id, "score": getattr(c, "cross_encoder_score", None), } for c in reranked ] contexts_for_gen = reranked else: reranked_hits = [] contexts_for_gen = initial # 3. 
        answer = self.generator.generate(
            question,
            [c.text for c in contexts_for_gen],
            max_new_tokens=self.cfg.generator.max_new_tokens,
            temperature=self.cfg.generator.temperature,
        )

        return {
            "question": question,
            "raw_retrieval": raw_hits,
            "reranked": reranked_hits,
            "contexts": [c.text for c in contexts_for_gen],
            "answer": answer,
        }

    __call__ = run  # alias

    def run_queries(self, queries: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Accept a list of {'question': str, 'id': Any}; return a list of result dicts."""
        results: list[dict[str, Any]] = []
        for entry in queries:
            q = entry.get("question", "")
            doc_id = entry.get("id")
            # run() returns the full result dict (raw hits, reranked hits, contexts, answer).
            answer = self.run(q)
            results.append({"id": doc_id, "question": q, "answer": answer})
        return results

    # ---------------------------------------------------------------------
    # Private helpers
    # ---------------------------------------------------------------------
    def _build_retriever(self, cfg: PipelineConfig) -> Retriever:
        r = cfg.retriever
        name = r.name
        if name == "bm25":
            return bm25.BM25Retriever(bm25_idx=str(r.bm25_idx), doc_store=str(r.doc_store))
        if name == "dense":
            return dense.DenseRetriever(
                faiss_index=str(r.faiss_index),
                doc_store=str(r.doc_store),
                model_name=r.model_name,
                embedder_cache=str(r.embedder_cache) if r.embedder_cache else None,
                device=r.device,
            )
        if name == "hybrid":
            return hybrid.HybridRetriever(
                str(r.bm25_index),
                str(r.faiss_index),
                doc_store=str(r.doc_store),
                alpha=r.alpha,
                model_name=r.model_name,
                embedder_cache=str(r.embedder_cache) if r.embedder_cache else None,
                device=r.device,
            )
        raise ValueError(f"Unsupported retriever '{name}'")

    def _retrieve(self, question: str) -> List[Context]:
        logger.info("Retrieving top-%d passages", self.cfg.retriever.top_k)
        k_first = (
            self.cfg.reranker.first_stage_k if self.reranker else self.cfg.retriever.top_k
        )
        initial = self.retriever.retrieve(question, top_k=k_first)
        if self.reranker:
            final_k = self.cfg.reranker.final_k or self.cfg.retriever.top_k
            logger.info("Re-ranking %d docs with cross-encoder ...", len(initial))
            initial = self.reranker.rerank(question, initial, k=final_k)
        return initial

    def _generate(self, question: str, contexts: List[Context]) -> str:
        texts = [c.text for c in contexts]
        logger.info("Generating answer with %d context passages", len(texts))
        return self.generator.generate(
            question,
            texts,
            max_new_tokens=self.cfg.generator.max_new_tokens,
            temperature=self.cfg.generator.temperature,
        )
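

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the pipeline's public API).
# It assumes a fully populated PipelineConfig has already been built elsewhere;
# this module does not define how configs are loaded, so no loader is invoked
# here. The question string and the printed fields are examples, not fixtures.
# ---------------------------------------------------------------------------
def _demo(cfg: PipelineConfig) -> None:
    """Run a single question through the pipeline and print the result fields."""
    pipeline = RAGPipeline(cfg)
    result = pipeline.run("What does enabling the reranker change about retrieval depth?")
    print("Answer:", result["answer"])
    # `reranked` is empty when cfg.reranker.enable is False; fall back to raw hits.
    hits = result["reranked"] or result["raw_retrieval"]
    for hit in hits:
        print(f"{hit['score']}\t{hit['id']}")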