# file: retriever.py import faiss import numpy as np import torch import re from collections import defaultdict from rank_bm25 import BM25Okapi def tokenize_vi_for_bm25_setup(text): """Tokenize tiếng Việt đơn giản cho BM25.""" text = text.lower() text = re.sub(r'[^\w\s]', '', text) return text.split() def _get_vehicle_type(query_lower: str) -> str | None: """Xác định loại xe được đề cập trong câu truy vấn.""" # Từ điển định nghĩa các từ khóa cho từng loại xe vehicle_keywords = { "ô tô": ["ô tô", "xe con", "xe chở người", "xe chở hàng"], "xe máy": ["xe máy", "xe mô tô", "xe gắn máy"], "xe đạp": ["xe đạp", "xe thô sơ"], "máy kéo": ["máy kéo", "xe chuyên dùng"] } for vehicle_type, keywords in vehicle_keywords.items(): if any(keyword in query_lower for keyword in keywords): return vehicle_type return None def search_relevant_laws( query_text: str, embedding_model, faiss_index, chunks_data: list[dict], bm25_model, k: int = 5, initial_k_multiplier: int = 15, rrf_k_constant: int = 60 ) -> list[dict]: """ Thực hiện Tìm kiếm Lai (Hybrid Search) với logic tăng điểm (boosting) cho loại xe. Quy trình: 1. Tìm kiếm song song bằng FAISS (ngữ nghĩa) và BM25 (từ khóa). 2. Kết hợp kết quả bằng Reciprocal Rank Fusion (RRF). 3. Tăng điểm (boost) cho các kết quả khớp với metadata quan trọng (loại xe). 4. Sắp xếp lại và trả về top-k kết quả cuối cùng. """ if k <= 0: return [] num_vectors_in_index = faiss_index.ntotal if num_vectors_in_index == 0: return [] num_candidates = min(k * initial_k_multiplier, num_vectors_in_index) # --- 1. Semantic Search (FAISS) --- try: query_embedding = embedding_model.encode([query_text], convert_to_tensor=True) query_embedding_np = query_embedding.cpu().numpy().astype('float32') faiss.normalize_L2(query_embedding_np) _, semantic_indices = faiss_index.search(query_embedding_np, num_candidates) semantic_indices = semantic_indices[0] except Exception as e: print(f"Lỗi FAISS search: {e}") semantic_indices = [] # --- 2. Keyword Search (BM25) --- try: tokenized_query = tokenize_vi_for_bm25_setup(query_text) bm25_scores = bm25_model.get_scores(tokenized_query) # Lấy top N chỉ mục từ BM25 top_bm25_indices = np.argsort(bm25_scores)[::-1][:num_candidates] except Exception as e: print(f"Lỗi BM25 search: {e}") top_bm25_indices = [] # --- 3. Result Fusion (RRF) --- rrf_scores = defaultdict(float) all_indices = set(semantic_indices) | set(top_bm25_indices) for rank, doc_idx in enumerate(semantic_indices): rrf_scores[doc_idx] += 1.0 / (rrf_k_constant + rank) for rank, doc_idx in enumerate(top_bm25_indices): rrf_scores[doc_idx] += 1.0 / (rrf_k_constant + rank) # --- 4. Metadata Boosting & Final Ranking --- query_lower = query_text.lower() matched_vehicle = _get_vehicle_type(query_lower) final_results = [] for doc_idx in all_indices: try: metadata = chunks_data[doc_idx].get('metadata', {}) final_score = rrf_scores[doc_idx] # **LOGIC BOOSTING QUAN TRỌNG NHẤT** if matched_vehicle: article_title_lower = metadata.get("article_title", "").lower() # Định nghĩa lại từ khóa bên trong để tránh phụ thuộc bên ngoài vehicle_keywords = { "ô tô": ["ô tô", "xe con"], "xe máy": ["xe máy", "xe mô tô"], "xe đạp": ["xe đạp", "xe thô sơ"], "máy kéo": ["máy kéo", "xe chuyên dùng"] } if any(keyword in article_title_lower for keyword in vehicle_keywords.get(matched_vehicle, [])): # Cộng một điểm thưởng rất lớn để đảm bảo nó được ưu tiên final_score += 0.5 final_results.append({ 'index': doc_idx, 'final_score': final_score }) except IndexError: continue final_results.sort(key=lambda x: x['final_score'], reverse=True) # Lấy đầy đủ thông tin cho top-k kết quả cuối cùng top_k_results = [] for res in final_results[:k]: doc_idx = res['index'] top_k_results.append({ 'index': doc_idx, 'final_score': res['final_score'], 'text': chunks_data[doc_idx].get('text', ''), 'metadata': chunks_data[doc_idx].get('metadata', {}) }) return top_k_results