|
|
|
|
|
import torch |
|
import re |
|
import json |
|
from unsloth import FastLanguageModel |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_llm_model_and_tokenizer( |
|
model_name_or_path: str, |
|
max_seq_length: int = 2048, |
|
load_in_4bit: bool = True, |
|
device_map: str = "auto" |
|
): |
|
""" |
|
Tải mô hình ngôn ngữ lớn (LLM) đã được fine-tune bằng Unsloth và tokenizer tương ứng. |
|
|
|
Args: |
|
model_name_or_path (str): Tên hoặc đường dẫn đến mô hình đã fine-tune. |
|
max_seq_length (int): Độ dài chuỗi tối đa mà mô hình hỗ trợ. |
|
load_in_4bit (bool): Có tải mô hình ở dạng 4-bit quantization hay không. |
|
device_map (str): Cách map model lên các device (ví dụ "auto", "cuda:0"). |
|
|
|
Returns: |
|
tuple: (model, tokenizer) nếu thành công, (None, None) nếu có lỗi. |
|
""" |
|
print(f"Đang tải LLM model: {model_name_or_path}...") |
|
try: |
|
model, tokenizer = FastLanguageModel.from_pretrained( |
|
model_name=model_name_or_path, |
|
max_seq_length=max_seq_length, |
|
dtype=None, |
|
load_in_4bit=load_in_4bit, |
|
device_map=device_map, |
|
|
|
) |
|
FastLanguageModel.for_inference(model) |
|
print("Tải LLM model và tokenizer thành công.") |
|
return model, tokenizer |
|
except Exception as e: |
|
print(f"Lỗi khi tải LLM model và tokenizer: {e}") |
|
return None, None |
|
|
|
|
|
def generate_response( |
|
query: str, |
|
llama_model, |
|
tokenizer, |
|
|
|
faiss_index, |
|
embed_model, |
|
chunks_data_list: list, |
|
bm25_model, |
|
|
|
search_k: int = 5, |
|
search_multiplier: int = 10, |
|
rrf_k_constant: int = 60, |
|
|
|
max_new_tokens: int = 768, |
|
temperature: float = 0.4, |
|
top_p: float = 0.9, |
|
top_k: int = 40, |
|
repetition_penalty: float = 1.15, |
|
|
|
search_function |
|
): |
|
""" |
|
Truy xuất ngữ cảnh bằng hàm search_relevant_laws (được truyền vào) |
|
và tạo câu trả lời từ LLM dựa trên ngữ cảnh đó. |
|
|
|
Args: |
|
query (str): Câu truy vấn của người dùng. |
|
llama_model: Mô hình LLM đã tải. |
|
tokenizer: Tokenizer tương ứng. |
|
faiss_index: FAISS index đã tải. |
|
embed_model: Mô hình embedding đã tải. |
|
chunks_data_list (list): Danh sách các chunk dữ liệu luật. |
|
bm25_model: Mô hình BM25 đã tạo. |
|
search_k (int): Số lượng kết quả cuối cùng muốn lấy từ hàm search. |
|
search_multiplier (int): Hệ số initial_k_multiplier cho hàm search. |
|
rrf_k_constant (int): Hằng số k cho RRF trong hàm search. |
|
max_new_tokens (int): Số token tối đa được tạo mới bởi LLM. |
|
temperature (float): Nhiệt độ cho việc sinh văn bản. |
|
top_p (float): Tham số top-p cho nucleus sampling. |
|
top_k (int): Tham số top-k. |
|
repetition_penalty (float): Phạt cho việc lặp từ. |
|
search_function: Hàm thực hiện tìm kiếm (ví dụ: retrieval.search_relevant_laws). |
|
|
|
Returns: |
|
str: Câu trả lời được tạo ra bởi LLM. |
|
""" |
|
print(f"\n--- [LLM Handler] Bắt đầu xử lý query: '{query}' ---") |
|
|
|
|
|
print("--- [LLM Handler] Bước 1: Truy xuất ngữ cảnh (Hybrid Search)... ---") |
|
try: |
|
retrieved_results = search_function( |
|
query_text=query, |
|
embedding_model=embed_model, |
|
faiss_index=faiss_index, |
|
chunks_data=chunks_data_list, |
|
bm25_model=bm25_model, |
|
k=search_k, |
|
initial_k_multiplier=search_multiplier, |
|
rrf_k_constant=rrf_k_constant |
|
|
|
|
|
) |
|
print(f"--- [LLM Handler] Truy xuất xong, số kết quả: {len(retrieved_results)} ---") |
|
if not retrieved_results: |
|
print("--- [LLM Handler] Không tìm thấy ngữ cảnh nào. ---") |
|
except Exception as e: |
|
print(f"Lỗi trong quá trình truy xuất ngữ cảnh: {e}") |
|
retrieved_results = [] |
|
|
|
|
|
print("--- [LLM Handler] Bước 2: Định dạng context cho LLM... ---") |
|
context_parts = [] |
|
if not retrieved_results: |
|
context = "Không tìm thấy thông tin luật liên quan trong cơ sở dữ liệu để trả lời câu hỏi này." |
|
else: |
|
for i, res in enumerate(retrieved_results): |
|
metadata = res.get('metadata', {}) |
|
article_title = metadata.get('article_title', 'N/A') |
|
article = metadata.get('article', 'N/A') |
|
clause = metadata.get('clause_number', 'N/A') |
|
point = metadata.get('point_id', '') |
|
source = metadata.get('source_document', 'N/A') |
|
text_content = res.get('text', '*Nội dung không có*') |
|
|
|
|
|
header_parts = [f"Trích dẫn {i+1}:"] |
|
if source != 'N/A': |
|
header_parts.append(f"(Nguồn: {source})") |
|
if article != 'N/A': |
|
header_parts.append(f"Điều {article}") |
|
if article_title != 'N/A' and article_title != article: |
|
header_parts.append(f"({article_title})") |
|
if clause != 'N/A': |
|
header_parts.append(f", Khoản {clause}") |
|
if point: |
|
header_parts.append(f", Điểm {point}") |
|
|
|
header = " ".join(header_parts) |
|
|
|
|
|
|
|
query_analysis = metadata.get("query_analysis_for_boost", {}) |
|
|
|
mentions_fine_in_query = bool(re.search(r'tiền|phạt|bao nhiêu đồng|mức phạt', query.lower())) |
|
mentions_points_in_query = bool(re.search(r'điểm|trừ điểm|bằng lái|gplx', query.lower())) |
|
|
|
fine_info_text = [] |
|
if metadata.get("has_fine") and mentions_fine_in_query: |
|
if metadata.get("individual_fine_min") is not None and metadata.get("individual_fine_max") is not None: |
|
fine_info_text.append(f"Phạt tiền: {metadata.get('individual_fine_min'):,} - {metadata.get('individual_fine_max'):,} VND.") |
|
elif metadata.get("overall_fine_note_for_clause_text"): |
|
fine_info_text.append(f"Ghi chú phạt tiền: {metadata.get('overall_fine_note_for_clause_text')}") |
|
|
|
points_info_text = [] |
|
if metadata.get("has_points_deduction") and mentions_points_in_query: |
|
if metadata.get("points_deducted_values_str"): |
|
points_info_text.append(f"Trừ điểm: {metadata.get('points_deducted_values_str')} điểm.") |
|
elif metadata.get("overall_points_deduction_note_for_clause_text"): |
|
points_info_text.append(f"Ghi chú trừ điểm: {metadata.get('overall_points_deduction_note_for_clause_text')}") |
|
|
|
|
|
penalty_summary = "" |
|
if fine_info_text or points_info_text: |
|
penalty_summary = " (Liên quan: " + " ".join(fine_info_text + points_info_text) + ")" |
|
|
|
context_parts.append(f"{header}{penalty_summary}\nNội dung: {text_content}") |
|
|
|
context = "\n\n---\n\n".join(context_parts) |
|
|
|
|
|
|
|
|
|
|
|
prompt = f"""Bạn là một trợ lý AI chuyên tư vấn về luật giao thông đường bộ Việt Nam. |
|
Nhiệm vụ của bạn là dựa vào các thông tin luật được cung cấp dưới đây để trả lời câu hỏi của người dùng một cách chính xác, chi tiết và dễ hiểu. |
|
Nếu thông tin không đủ hoặc không có trong các trích dẫn được cung cấp, hãy trả lời rằng bạn không tìm thấy thông tin cụ thể trong tài liệu được cung cấp. |
|
Tránh đưa ra ý kiến cá nhân hoặc thông tin không có trong ngữ cảnh. Hãy trích dẫn điều, khoản, điểm nếu có thể. |
|
|
|
### Thông tin luật được trích dẫn: |
|
{context} |
|
|
|
### Câu hỏi của người dùng: |
|
{query} |
|
|
|
### Trả lời của bạn:""" |
|
|
|
|
|
|
|
|
|
print("--- [LLM Handler] Bước 3: Tạo câu trả lời từ LLM... ---") |
|
device = llama_model.device |
|
inputs = tokenizer(prompt, return_tensors="pt").to(device) |
|
|
|
generation_config = dict( |
|
max_new_tokens=max_new_tokens, |
|
temperature=temperature, |
|
top_p=top_p, |
|
top_k=top_k, |
|
repetition_penalty=repetition_penalty, |
|
do_sample=True if temperature > 0 else False, |
|
pad_token_id=tokenizer.eos_token_id, |
|
eos_token_id=tokenizer.eos_token_id, |
|
) |
|
|
|
try: |
|
|
|
|
|
|
|
|
|
|
|
|
|
output_ids = llama_model.generate(**inputs, **generation_config) |
|
|
|
|
|
input_length = inputs.input_ids.shape[1] |
|
generated_ids = output_ids[0][input_length:] |
|
response_text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip() |
|
|
|
print("--- [LLM Handler] Tạo câu trả lời hoàn tất. ---") |
|
|
|
return response_text |
|
|
|
except Exception as e: |
|
print(f"Lỗi trong quá trình LLM generating: {e}") |
|
return "Xin lỗi, đã có lỗi xảy ra trong quá trình tạo câu trả lời từ mô hình ngôn ngữ." |
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
print("Chạy test cho llm_handler.py (chưa có mock dữ liệu)...") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pass |