Spaces:
Paused
Paused
# file: retriever.py | |
import faiss | |
import numpy as np | |
import torch | |
import re | |
from collections import defaultdict | |
from rank_bm25 import BM25Okapi | |
def tokenize_vi_for_bm25_setup(text): | |
"""Tokenize tiếng Việt đơn giản cho BM25.""" | |
text = text.lower() | |
text = re.sub(r'[^\w\s]', '', text) | |
return text.split() | |
def _get_vehicle_type(query_lower: str) -> str | None: | |
"""Xác định loại xe được đề cập trong câu truy vấn.""" | |
# Từ điển định nghĩa các từ khóa cho từng loại xe | |
vehicle_keywords = { | |
"ô tô": ["ô tô", "xe con", "xe chở người", "xe chở hàng"], | |
"xe máy": ["xe máy", "xe mô tô", "xe gắn máy"], | |
"xe đạp": ["xe đạp", "xe thô sơ"], | |
"máy kéo": ["máy kéo", "xe chuyên dùng"] | |
} | |
for vehicle_type, keywords in vehicle_keywords.items(): | |
if any(keyword in query_lower for keyword in keywords): | |
return vehicle_type | |
return None | |
def search_relevant_laws( | |
query_text: str, | |
embedding_model, | |
faiss_index, | |
chunks_data: list[dict], | |
bm25_model, | |
k: int = 5, | |
initial_k_multiplier: int = 15, | |
rrf_k_constant: int = 60 | |
) -> list[dict]: | |
""" | |
Thực hiện Tìm kiếm Lai (Hybrid Search) với logic tăng điểm (boosting) cho loại xe. | |
Quy trình: | |
1. Tìm kiếm song song bằng FAISS (ngữ nghĩa) và BM25 (từ khóa). | |
2. Kết hợp kết quả bằng Reciprocal Rank Fusion (RRF). | |
3. Tăng điểm (boost) cho các kết quả khớp với metadata quan trọng (loại xe). | |
4. Sắp xếp lại và trả về top-k kết quả cuối cùng. | |
""" | |
if k <= 0: | |
return [] | |
num_vectors_in_index = faiss_index.ntotal | |
if num_vectors_in_index == 0: | |
return [] | |
num_candidates = min(k * initial_k_multiplier, num_vectors_in_index) | |
# --- 1. Semantic Search (FAISS) --- | |
try: | |
query_embedding = embedding_model.encode([query_text], convert_to_tensor=True) | |
query_embedding_np = query_embedding.cpu().numpy().astype('float32') | |
faiss.normalize_L2(query_embedding_np) | |
_, semantic_indices = faiss_index.search(query_embedding_np, num_candidates) | |
semantic_indices = semantic_indices[0] | |
except Exception as e: | |
print(f"Lỗi FAISS search: {e}") | |
semantic_indices = [] | |
# --- 2. Keyword Search (BM25) --- | |
try: | |
tokenized_query = tokenize_vi_for_bm25_setup(query_text) | |
bm25_scores = bm25_model.get_scores(tokenized_query) | |
# Lấy top N chỉ mục từ BM25 | |
top_bm25_indices = np.argsort(bm25_scores)[::-1][:num_candidates] | |
except Exception as e: | |
print(f"Lỗi BM25 search: {e}") | |
top_bm25_indices = [] | |
# --- 3. Result Fusion (RRF) --- | |
rrf_scores = defaultdict(float) | |
all_indices = set(semantic_indices) | set(top_bm25_indices) | |
for rank, doc_idx in enumerate(semantic_indices): | |
rrf_scores[doc_idx] += 1.0 / (rrf_k_constant + rank) | |
for rank, doc_idx in enumerate(top_bm25_indices): | |
rrf_scores[doc_idx] += 1.0 / (rrf_k_constant + rank) | |
# --- 4. Metadata Boosting & Final Ranking --- | |
query_lower = query_text.lower() | |
matched_vehicle = _get_vehicle_type(query_lower) | |
final_results = [] | |
for doc_idx in all_indices: | |
try: | |
metadata = chunks_data[doc_idx].get('metadata', {}) | |
final_score = rrf_scores[doc_idx] | |
# **LOGIC BOOSTING QUAN TRỌNG NHẤT** | |
if matched_vehicle: | |
article_title_lower = metadata.get("article_title", "").lower() | |
# Định nghĩa lại từ khóa bên trong để tránh phụ thuộc bên ngoài | |
vehicle_keywords = { | |
"ô tô": ["ô tô", "xe con"], "xe máy": ["xe máy", "xe mô tô"], | |
"xe đạp": ["xe đạp", "xe thô sơ"], "máy kéo": ["máy kéo", "xe chuyên dùng"] | |
} | |
if any(keyword in article_title_lower for keyword in vehicle_keywords.get(matched_vehicle, [])): | |
# Cộng một điểm thưởng rất lớn để đảm bảo nó được ưu tiên | |
final_score += 0.5 | |
final_results.append({ | |
'index': doc_idx, | |
'final_score': final_score | |
}) | |
except IndexError: | |
continue | |
final_results.sort(key=lambda x: x['final_score'], reverse=True) | |
# Lấy đầy đủ thông tin cho top-k kết quả cuối cùng | |
top_k_results = [] | |
for res in final_results[:k]: | |
doc_idx = res['index'] | |
top_k_results.append({ | |
'index': doc_idx, | |
'final_score': res['final_score'], | |
'text': chunks_data[doc_idx].get('text', ''), | |
'metadata': chunks_data[doc_idx].get('metadata', {}) | |
}) | |
return top_k_results |