chatbot_demo / rag_pipeline.py
deddoggo's picture
update
0902442
# file: rag_pipeline.py
import torch
import json
import faiss
import numpy as np
import re
from unsloth import FastLanguageModel
from sentence_transformers import SentenceTransformer
from rank_bm25 import BM25Okapi
from transformers import TextStreamer
# Import các hàm từ file khác
from data_processor import process_law_data_to_chunks
from retriever import search_relevant_laws, tokenize_vi_for_bm25_setup
def initialize_components(data_path):
"""
Khởi tạo và tải tất cả các thành phần cần thiết cho RAG pipeline.
Hàm này chỉ nên được gọi một lần khi ứng dụng khởi động.
"""
print("--- Bắt đầu khởi tạo các thành phần ---")
# 1. Tải LLM và Tokenizer (Processor) từ Unsloth
print("1. Tải mô hình LLM (Unsloth)...")
model, tokenizer = FastLanguageModel.from_pretrained(
model_name="unsloth/Llama-3.2-11B-Vision-Instruct-unsloth-bnb-4bit",
max_seq_length=4096, # Có thể tăng cho các mô hình mới
dtype=None,
load_in_4bit=True,
)
print("✅ Tải LLM và Tokenizer thành công.")
# 2. Tải mô hình Embedding
print("2. Tải mô hình Embedding...")
embedding_model = SentenceTransformer(
"bkai-foundation-models/vietnamese-bi-encoder",
device="cuda" if torch.cuda.is_available() else "cpu"
)
print("✅ Tải mô hình Embedding thành công.")
# 3. Tải và xử lý dữ liệu JSON
print(f"3. Tải và xử lý dữ liệu từ {data_path}...")
with open(data_path, 'r', encoding='utf-8') as f:
raw_data = json.load(f)
chunks_data = process_law_data_to_chunks(raw_data)
print(f"✅ Xử lý dữ liệu thành công, có {len(chunks_data)} chunks.")
# 4. Tạo Embeddings và FAISS Index
print("4. Tạo embeddings và FAISS index...")
texts_to_encode = [chunk.get('text', '') for chunk in chunks_data]
chunk_embeddings_tensor = embedding_model.encode(
texts_to_encode,
convert_to_tensor=True,
device=embedding_model.device
)
chunk_embeddings_np = chunk_embeddings_tensor.cpu().numpy().astype('float32')
faiss.normalize_L2(chunk_embeddings_np)
dimension = chunk_embeddings_np.shape[1]
faiss_index = faiss.IndexFlatIP(dimension)
faiss_index.add(chunk_embeddings_np)
print(f"✅ Tạo FAISS index thành công với {faiss_index.ntotal} vector.")
# 5. Tạo BM25 Model
print("5. Tạo mô hình BM25...")
corpus_texts_for_bm25 = [chunk.get('text', '') for chunk in chunks_data]
tokenized_corpus_bm25 = [tokenize_vi_for_bm25_setup(text) for text in corpus_texts_for_bm25]
bm25_model = BM25Okapi(tokenized_corpus_bm25)
print("✅ Tạo mô hình BM25 thành công.")
print("--- ✅ Khởi tạo tất cả thành phần hoàn tất ---")
return {
"llm_model": model,
"tokenizer": tokenizer,
"embedding_model": embedding_model,
"chunks_data": chunks_data,
"faiss_index": faiss_index,
"bm25_model": bm25_model
}
def generate_response(query: str, components: dict) -> str:
"""
Tạo câu trả lời (single-turn) bằng cách sử dụng các thành phần đã được khởi tạo.
Phiên bản cuối cùng, sửa lỗi ValueError cho mô hình Vision bằng cách
sử dụng apply_chat_template để tokenization trực tiếp.
"""
print("--- Bắt đầu quy trình RAG cho query mới ---")
# --- Bước 1: Truy xuất Ngữ cảnh ---
retrieved_results = search_relevant_laws(
query_text=query,
embedding_model=components["embedding_model"],
faiss_index=components["faiss_index"],
chunks_data=components["chunks_data"],
bm25_model=components["bm25_model"],
k=5,
initial_k_multiplier=15
)
# --- Bước 2: Định dạng Ngữ cảnh ---
if not retrieved_results:
context = "Không tìm thấy thông tin luật liên quan trong cơ sở dữ liệu."
else:
context_parts = []
for i, res in enumerate(retrieved_results):
metadata = res.get('metadata', {})
header = f"Trích dẫn {i+1}: Điều {metadata.get('article', 'N/A')}, Khoản {metadata.get('clause_number', 'N/A')} (Nguồn: {metadata.get('source_document', 'N/A')})"
text = res.get('text', '*Nội dung không có*')
context_parts.append(f"{header}\n{text}")
context = "\n\n---\n\n".join(context_parts)
# --- Bước 3: Chuẩn bị Dữ liệu và Tokenize bằng Chat Template (Phần sửa lỗi cốt lõi) ---
print("--- Chuẩn bị và tokenize prompt bằng chat template ---")
llm_model = components["llm_model"]
tokenizer = components["tokenizer"]
# Tạo cấu trúc tin nhắn theo chuẩn
messages = [
{
"role": "system",
"content": [
{"type": "text", "text": "Bạn là một trợ lý pháp luật chuyên trả lời các câu hỏi về luật giao thông Việt Nam..."}
]
},
{
"role": "user",
"content": [
{"type": "text", "text": f"""Dựa vào các trích dẫn luật dưới đây:
### Thông tin luật:
{context}
### Câu hỏi:
{query}
"""}
]
}
]
# SỬA LỖI: Dùng apply_chat_template để tokenize trực tiếp
# Nó sẽ tự động định dạng và chuyển thành tensor, tương thích với mô hình Vision
inputs = tokenizer.apply_chat_template(
messages,
tokenize=True,
add_generation_prompt=True,
return_tensors="pt"
).to(llm_model.device)
# --- Bước 4: Tạo câu trả lời từ LLM ---
print("--- Bắt đầu tạo câu trả lời từ LLM ---")
generation_config = dict(
max_new_tokens=256,
temperature=0.1,
repetition_penalty=1.1,
do_sample=True,
pad_token_id=tokenizer.eos_token_id
)
output_ids = llm_model.generate(inputs, **generation_config)
# Decode như cũ, nhưng đầu vào là `inputs` thay vì `inputs.input_ids`
response_text = tokenizer.decode(output_ids[0][inputs.shape[1]:], skip_special_tokens=True)
print("--- Tạo câu trả lời hoàn tất ---")
return response_text