import gradio as gr
import torch

from retrieval import (
    process_law_data_to_chunks,
    tokenize_vi_for_bm25_setup,
    search_relevant_laws,
)
from llm_handler import generate_response
from sentence_transformers import SentenceTransformer
import faiss
from rank_bm25 import BM25Okapi
import json
from unsloth import FastLanguageModel
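
# Paths to the pre-processed law corpus, FAISS index, and model checkpoints.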
JSON_FILE_PATH = "data/luat_chi_tiet_output_openai_sdk_final_cleaned.json"
FAISS_INDEX_PATH = "data/my_law_faiss_flatip_normalized.index"
LLM_MODEL_PATH = "models/lora_model_base"
EMBEDDING_MODEL_PATH = "models/embedding_model"
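
# Load the law corpus from JSON and split it into retrievable chunks.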
print("Loading and processing law data...")
try:
    with open(JSON_FILE_PATH, 'r', encoding='utf-8') as f:
        raw_data_from_file = json.load(f)
    chunks_data = process_law_data_to_chunks(raw_data_from_file)
    print(f"Loaded {len(chunks_data)} chunks.")
    if not chunks_data:
        raise ValueError("Chunks data is empty after processing.")
except Exception as e:
    print(f"Error loading/processing law data: {e}")
    chunks_data = []
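
# Load the sentence-embedding model used for dense (vector) retrieval.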
print(f"Loading embedding model: {EMBEDDING_MODEL_PATH}...")
device = "cuda" if torch.cuda.is_available() else "cpu"
try:
    embedding_model = SentenceTransformer(EMBEDDING_MODEL_PATH, device=device)
    print("Embedding model loaded successfully.")
except Exception as e:
    print(f"Error loading embedding model: {e}")
    embedding_model = None
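
# Load the pre-built FAISS index over the chunk embeddings.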
print(f"Loading FAISS index from: {FAISS_INDEX_PATH}...")
try:
    faiss_index = faiss.read_index(FAISS_INDEX_PATH)
    print(f"FAISS index loaded. Total vectors: {faiss_index.ntotal}")
except Exception as e:
    print(f"Error loading FAISS index: {e}")
    faiss_index = None
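
# Build the BM25 model for sparse (keyword) retrieval over the same chunks.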
print("Creating BM25 model...")
bm25_model = None
if chunks_data:
    try:
        corpus_texts_for_bm25 = [chunk.get('text', '') for chunk in chunks_data]
        tokenized_corpus_bm25 = [tokenize_vi_for_bm25_setup(text) for text in corpus_texts_for_bm25]
        bm25_model = BM25Okapi(tokenized_corpus_bm25)
        print("BM25 model created successfully.")
    except Exception as e:
        print(f"Error creating BM25 model: {e}")
else:
    print("Skipping BM25 model creation as chunks_data is empty.")
print(f"Loading LLM model: {LLM_MODEL_PATH}...")
try:
    llm_model, llm_tokenizer = FastLanguageModel.from_pretrained(
        model_name=LLM_MODEL_PATH,
        max_seq_length=2048,
        dtype=None,
        load_in_4bit=True,
    )
    FastLanguageModel.for_inference(llm_model)
    print("LLM model and tokenizer loaded successfully.")
except Exception as e:
    print(f"Error loading LLM model: {e}")
    llm_model = None
    llm_tokenizer = None
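

# Chat callback: checks that startup succeeded, then runs retrieval + generation for each query.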
def respond(message, history: list[tuple[str, str]]):
    if not all([chunks_data, embedding_model, faiss_index, bm25_model, llm_model, llm_tokenizer]):
        missing_components = []
        if not chunks_data: missing_components.append("chunks_data")
        if not embedding_model: missing_components.append("embedding_model")
        if not faiss_index: missing_components.append("faiss_index")
        if not bm25_model: missing_components.append("bm25_model")
        if not llm_model: missing_components.append("llm_model")
        if not llm_tokenizer: missing_components.append("llm_tokenizer")
        # Vietnamese user-facing message: "Error: one or more system components were not
        # initialized successfully. Missing components: ... Please check the Space logs."
        error_msg = f"Lỗi: Một hoặc nhiều thành phần của hệ thống chưa được khởi tạo thành công. Thành phần thiếu: {', '.join(missing_components)}. Vui lòng kiểm tra logs của Space."
        print(error_msg)
        # respond() is a generator (it yields below), so the error must be yielded as well;
        # a bare `return error_msg` would never reach the UI.
        yield error_msg
        return

    try:
        response_text = generate_response(
            query=message,
            llama_model=llm_model,
            tokenizer=llm_tokenizer,
            faiss_index=faiss_index,
            embed_model=embedding_model,
            chunks_data_list=chunks_data,
            bm25_model=bm25_model,
            search_function=search_relevant_laws,
        )
        yield response_text
    except Exception as e:
        import traceback
        print(f"Error during response generation for query '{message}': {e}")
        print(traceback.format_exc())
        # Vietnamese user-facing message: "A serious error occurred while processing your
        # request. Please try again later or contact the administrator."
        yield "Đã xảy ra lỗi nghiêm trọng khi xử lý yêu cầu của bạn. Vui lòng thử lại sau hoặc liên hệ quản trị viên."
demo = gr.ChatInterface(
    respond,
)

if __name__ == "__main__":
    demo.launch()