File size: 4,596 Bytes
3ae87dc
e3c045e
 
3ae87dc
 
e3c045e
 
 
3ae87dc
 
 
 
e3c045e
e19aff3
e3c045e
3ae87dc
e3c045e
 
 
 
 
 
 
 
e19aff3
e3c045e
e19aff3
e3c045e
3ae87dc
 
6525f09
3ae87dc
6525f09
 
 
 
3ae87dc
 
 
 
 
 
 
 
 
 
6525f09
e19aff3
3ae87dc
e19aff3
3ae87dc
 
 
e3c045e
e19aff3
3ae87dc
e3c045e
12e378f
3ae87dc
 
 
 
12e378f
e19aff3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3ae87dc
 
 
 
 
e19aff3
4b92e5d
e19aff3
 
 
 
 
12e378f
9c8dddf
3ae87dc
 
 
12e378f
 
e19aff3
12e378f
c8da6e9
3ae87dc
 
 
4b92e5d
3ae87dc
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
# app.py
import gradio as gr
import torch
import re
from sentence_transformers import SentenceTransformer, util
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# === Load model embedding
embedder = SentenceTransformer("keepitreal/vietnamese-sbert")

# === Thiết bị
device = torch.device("cpu")
print("✅ Using device:", device)

# === Load mô hình sinh phản hồi
model_name = "vanhai123/vietnamese-ecom-chatbot"
try:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    base_model = AutoModelForCausalLM.from_pretrained(
        "NlpHUST/gpt2-vietnamese", torch_dtype=torch.float32
    ).to(device)
    base_model.resize_token_embeddings(len(tokenizer))
    model = PeftModel.from_pretrained(base_model, model_name).to(device)
    print("Model and tokenizer loaded successfully!")
except Exception as e:
    print(f"Error loading model or tokenizer: {str(e)}")
    raise
def load_qa_from_file(path="examples.txt"):
    qa_pairs = []
    try:
        with open(path, "r", encoding="utf-8") as file:
            content = file.read()
            blocks = content.split("<|human|>")
            for block in blocks:
                if "Hỏi:" in block and "<|assistant|>" in block:
                    q = re.search(r"Hỏi:(.*)", block)
                    a = re.search(r"<\|assistant\|>(.*)", block, re.DOTALL)
                    if q and a:
                        question = q.group(1).strip()
                        answer_block = a.group(1).strip()
                        for line in answer_block.splitlines():
                            if "**Phản hồi**" in line:
                                answer = line.split("**Phản hồi**:")[-1].strip()
                                qa_pairs.append({"q": question, "a": answer})
                                break
    except Exception as e:
        print(f"Lỗi đọc file: {e}")
    return qa_pairs

qa_data = load_qa_from_file("examples.txt")
questions = [qa["q"] for qa in qa_data]
embeddings = embedder.encode(questions, convert_to_tensor=True)

# === Prompt builder
def build_prompt(question):
    try:
        with open("examples.txt", "r", encoding="utf-8") as file:
            example_block = "".join(file.readlines()[:30])
    except:
        example_block = "<|system|>Bạn là một trợ lý thương mại điện tử chuyên nghiệp tại Việt Nam."
    return example_block + f"\n<|human|>Hỏi: {question}\n<|assistant|>"

# === Sinh phản hồi từ mô hình

def generate_with_model(question):
    prompt = build_prompt(question)
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    input_len = inputs["input_ids"].shape[-1]

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=120,
            temperature=0.6,
            top_p=0.9,
            do_sample=True,
            repetition_penalty=1.15,
            no_repeat_ngram_size=3,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    output_text = tokenizer.decode(output[0][input_len:], skip_special_tokens=True).strip()
    lines = [line.strip() for line in output_text.splitlines() if line.strip()]
    for line in lines:
        if "**Phản hồi**" in line:
            return line.split("**Phản hồi**:")[-1].strip()
    return None

def semantic_fallback(question):
    query_embedding = embedder.encode(question, convert_to_tensor=True)
    cos_scores = util.pytorch_cos_sim(query_embedding, embeddings)[0]
    top_idx = torch.argmax(cos_scores).item()
    top_score = cos_scores[top_idx].item()
    if top_score >= 0.75:
        return qa_data[top_idx]["a"]
    return "Vui lòng liên hệ CSKH để được hỗ trợ!"

def answer_question(user_question):
    response = generate_with_model(user_question)
    if response and len(response) > 30:
        return response
    return semantic_fallback(user_question)

# === Giao diện Gradio
interface = gr.Interface(
    fn=answer_question,
    inputs=gr.Textbox(lines=2, placeholder="Nhập câu hỏi của bạn..."),
    outputs="text",
    title="Vietnamese E-commerce Chatbot",
    description="Trợ lý AI thương mại điện tử: Trả lời từ mô hình ngôn ngữ hoặc tra cứu dữ liệu câu hỏi.",
    examples=[
        ["Tôi muốn kiểm tra đơn hàng"],
        ["Có giảm giá khi mua số lượng lớn không?"],
        ["Tôi muốn trả hàng vì sản phẩm lỗi"],
        ["Tư vấn laptop cho dân văn phòng"]
    ]
)

interface.launch()