Spaces:
Paused
Paused
File size: 4,924 Bytes
a53f1d8 b8e326d a53f1d8 b8e326d a53f1d8 3264b15 b8e326d 3264b15 b8e326d 3264b15 b8e326d a53f1d8 3264b15 a53f1d8 3264b15 a53f1d8 3264b15 a53f1d8 3264b15 a53f1d8 3264b15 a53f1d8 3264b15 a53f1d8 3264b15 a53f1d8 3264b15 a53f1d8 3264b15 b8e326d 3264b15 a53f1d8 3264b15 a53f1d8 3264b15 a53f1d8 3264b15 a53f1d8 3264b15 a53f1d8 3264b15 a53f1d8 3264b15 a53f1d8 3264b15 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
# file: retriever.py
import faiss
import numpy as np
import torch
import re
from collections import defaultdict
from rank_bm25 import BM25Okapi
def tokenize_vi_for_bm25_setup(text):
"""Tokenize tiếng Việt đơn giản cho BM25."""
text = text.lower()
text = re.sub(r'[^\w\s]', '', text)
return text.split()
def _get_vehicle_type(query_lower: str) -> str | None:
"""Xác định loại xe được đề cập trong câu truy vấn."""
# Từ điển định nghĩa các từ khóa cho từng loại xe
vehicle_keywords = {
"ô tô": ["ô tô", "xe con", "xe chở người", "xe chở hàng"],
"xe máy": ["xe máy", "xe mô tô", "xe gắn máy"],
"xe đạp": ["xe đạp", "xe thô sơ"],
"máy kéo": ["máy kéo", "xe chuyên dùng"]
}
for vehicle_type, keywords in vehicle_keywords.items():
if any(keyword in query_lower for keyword in keywords):
return vehicle_type
return None
def search_relevant_laws(
query_text: str,
embedding_model,
faiss_index,
chunks_data: list[dict],
bm25_model,
k: int = 5,
initial_k_multiplier: int = 15,
rrf_k_constant: int = 60
) -> list[dict]:
"""
Thực hiện Tìm kiếm Lai (Hybrid Search) với logic tăng điểm (boosting) cho loại xe.
Quy trình:
1. Tìm kiếm song song bằng FAISS (ngữ nghĩa) và BM25 (từ khóa).
2. Kết hợp kết quả bằng Reciprocal Rank Fusion (RRF).
3. Tăng điểm (boost) cho các kết quả khớp với metadata quan trọng (loại xe).
4. Sắp xếp lại và trả về top-k kết quả cuối cùng.
"""
if k <= 0:
return []
num_vectors_in_index = faiss_index.ntotal
if num_vectors_in_index == 0:
return []
num_candidates = min(k * initial_k_multiplier, num_vectors_in_index)
# --- 1. Semantic Search (FAISS) ---
try:
query_embedding = embedding_model.encode([query_text], convert_to_tensor=True)
query_embedding_np = query_embedding.cpu().numpy().astype('float32')
faiss.normalize_L2(query_embedding_np)
_, semantic_indices = faiss_index.search(query_embedding_np, num_candidates)
semantic_indices = semantic_indices[0]
except Exception as e:
print(f"Lỗi FAISS search: {e}")
semantic_indices = []
# --- 2. Keyword Search (BM25) ---
try:
tokenized_query = tokenize_vi_for_bm25_setup(query_text)
bm25_scores = bm25_model.get_scores(tokenized_query)
# Lấy top N chỉ mục từ BM25
top_bm25_indices = np.argsort(bm25_scores)[::-1][:num_candidates]
except Exception as e:
print(f"Lỗi BM25 search: {e}")
top_bm25_indices = []
# --- 3. Result Fusion (RRF) ---
rrf_scores = defaultdict(float)
all_indices = set(semantic_indices) | set(top_bm25_indices)
for rank, doc_idx in enumerate(semantic_indices):
rrf_scores[doc_idx] += 1.0 / (rrf_k_constant + rank)
for rank, doc_idx in enumerate(top_bm25_indices):
rrf_scores[doc_idx] += 1.0 / (rrf_k_constant + rank)
# --- 4. Metadata Boosting & Final Ranking ---
query_lower = query_text.lower()
matched_vehicle = _get_vehicle_type(query_lower)
final_results = []
for doc_idx in all_indices:
try:
metadata = chunks_data[doc_idx].get('metadata', {})
final_score = rrf_scores[doc_idx]
# **LOGIC BOOSTING QUAN TRỌNG NHẤT**
if matched_vehicle:
article_title_lower = metadata.get("article_title", "").lower()
# Định nghĩa lại từ khóa bên trong để tránh phụ thuộc bên ngoài
vehicle_keywords = {
"ô tô": ["ô tô", "xe con"], "xe máy": ["xe máy", "xe mô tô"],
"xe đạp": ["xe đạp", "xe thô sơ"], "máy kéo": ["máy kéo", "xe chuyên dùng"]
}
if any(keyword in article_title_lower for keyword in vehicle_keywords.get(matched_vehicle, [])):
# Cộng một điểm thưởng rất lớn để đảm bảo nó được ưu tiên
final_score += 0.5
final_results.append({
'index': doc_idx,
'final_score': final_score
})
except IndexError:
continue
final_results.sort(key=lambda x: x['final_score'], reverse=True)
# Lấy đầy đủ thông tin cho top-k kết quả cuối cùng
top_k_results = []
for res in final_results[:k]:
doc_idx = res['index']
top_k_results.append({
'index': doc_idx,
'final_score': res['final_score'],
'text': chunks_data[doc_idx].get('text', ''),
'metadata': chunks_data[doc_idx].get('metadata', {})
})
return top_k_results |