File size: 4,924 Bytes
a53f1d8
b8e326d
a53f1d8
 
 
b8e326d
a53f1d8
 
 
 
 
 
 
 
3264b15
 
 
 
 
 
 
 
 
 
 
 
 
 
b8e326d
3264b15
 
 
 
 
 
 
 
 
b8e326d
3264b15
 
 
 
 
 
 
b8e326d
a53f1d8
 
 
 
 
 
 
3264b15
a53f1d8
3264b15
a53f1d8
3264b15
 
a53f1d8
3264b15
 
a53f1d8
3264b15
 
a53f1d8
3264b15
a53f1d8
3264b15
 
 
 
a53f1d8
3264b15
 
a53f1d8
3264b15
b8e326d
3264b15
a53f1d8
3264b15
a53f1d8
 
3264b15
 
a53f1d8
3264b15
 
 
 
a53f1d8
3264b15
a53f1d8
3264b15
 
a53f1d8
3264b15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a53f1d8
3264b15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
# file: retriever.py
import faiss
import numpy as np
import torch
import re
from collections import defaultdict
from rank_bm25 import BM25Okapi

def tokenize_vi_for_bm25_setup(text):
    """Tokenize tiếng Việt đơn giản cho BM25."""
    text = text.lower()
    text = re.sub(r'[^\w\s]', '', text)
    return text.split()

def _get_vehicle_type(query_lower: str) -> str | None:
    """Xác định loại xe được đề cập trong câu truy vấn."""
    # Từ điển định nghĩa các từ khóa cho từng loại xe
    vehicle_keywords = {
        "ô tô": ["ô tô", "xe con", "xe chở người", "xe chở hàng"],
        "xe máy": ["xe máy", "xe mô tô", "xe gắn máy"],
        "xe đạp": ["xe đạp", "xe thô sơ"],
        "máy kéo": ["máy kéo", "xe chuyên dùng"]
    }
    for vehicle_type, keywords in vehicle_keywords.items():
        if any(keyword in query_lower for keyword in keywords):
            return vehicle_type
    return None

def search_relevant_laws(
    query_text: str,
    embedding_model,
    faiss_index,
    chunks_data: list[dict],
    bm25_model,
    k: int = 5,
    initial_k_multiplier: int = 15,
    rrf_k_constant: int = 60
) -> list[dict]:
    """
    Thực hiện Tìm kiếm Lai (Hybrid Search) với logic tăng điểm (boosting) cho loại xe.

    Quy trình:
    1. Tìm kiếm song song bằng FAISS (ngữ nghĩa) và BM25 (từ khóa).
    2. Kết hợp kết quả bằng Reciprocal Rank Fusion (RRF).
    3. Tăng điểm (boost) cho các kết quả khớp với metadata quan trọng (loại xe).
    4. Sắp xếp lại và trả về top-k kết quả cuối cùng.
    """
    if k <= 0:
        return []

    num_vectors_in_index = faiss_index.ntotal
    if num_vectors_in_index == 0:
        return []

    num_candidates = min(k * initial_k_multiplier, num_vectors_in_index)

    # --- 1. Semantic Search (FAISS) ---
    try:
        query_embedding = embedding_model.encode([query_text], convert_to_tensor=True)
        query_embedding_np = query_embedding.cpu().numpy().astype('float32')
        faiss.normalize_L2(query_embedding_np)
        _, semantic_indices = faiss_index.search(query_embedding_np, num_candidates)
        semantic_indices = semantic_indices[0]
    except Exception as e:
        print(f"Lỗi FAISS search: {e}")
        semantic_indices = []

    # --- 2. Keyword Search (BM25) ---
    try:
        tokenized_query = tokenize_vi_for_bm25_setup(query_text)
        bm25_scores = bm25_model.get_scores(tokenized_query)
        # Lấy top N chỉ mục từ BM25
        top_bm25_indices = np.argsort(bm25_scores)[::-1][:num_candidates]
    except Exception as e:
        print(f"Lỗi BM25 search: {e}")
        top_bm25_indices = []

    # --- 3. Result Fusion (RRF) ---
    rrf_scores = defaultdict(float)
    all_indices = set(semantic_indices) | set(top_bm25_indices)

    for rank, doc_idx in enumerate(semantic_indices):
        rrf_scores[doc_idx] += 1.0 / (rrf_k_constant + rank)

    for rank, doc_idx in enumerate(top_bm25_indices):
        rrf_scores[doc_idx] += 1.0 / (rrf_k_constant + rank)

    # --- 4. Metadata Boosting & Final Ranking ---
    query_lower = query_text.lower()
    matched_vehicle = _get_vehicle_type(query_lower)
    final_results = []

    for doc_idx in all_indices:
        try:
            metadata = chunks_data[doc_idx].get('metadata', {})
            final_score = rrf_scores[doc_idx]
            
            # **LOGIC BOOSTING QUAN TRỌNG NHẤT**
            if matched_vehicle:
                article_title_lower = metadata.get("article_title", "").lower()
                # Định nghĩa lại từ khóa bên trong để tránh phụ thuộc bên ngoài
                vehicle_keywords = {
                    "ô tô": ["ô tô", "xe con"], "xe máy": ["xe máy", "xe mô tô"],
                    "xe đạp": ["xe đạp", "xe thô sơ"], "máy kéo": ["máy kéo", "xe chuyên dùng"]
                }
                if any(keyword in article_title_lower for keyword in vehicle_keywords.get(matched_vehicle, [])):
                    # Cộng một điểm thưởng rất lớn để đảm bảo nó được ưu tiên
                    final_score += 0.5 
            
            final_results.append({
                'index': doc_idx,
                'final_score': final_score
            })
        except IndexError:
            continue

    final_results.sort(key=lambda x: x['final_score'], reverse=True)

    # Lấy đầy đủ thông tin cho top-k kết quả cuối cùng
    top_k_results = []
    for res in final_results[:k]:
        doc_idx = res['index']
        top_k_results.append({
            'index': doc_idx,
            'final_score': res['final_score'],
            'text': chunks_data[doc_idx].get('text', ''),
            'metadata': chunks_data[doc_idx].get('metadata', {})
        })

    return top_k_results