File size: 5,277 Bytes
9ccba7a
3b1849a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9ccba7a
587dce1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import os
import json
import base64
import sqlite3
import tempfile

import torch
import whisper
from gtts import gTTS
from pydub import AudioSegment
import numpy as np
from sentence_transformers import SentenceTransformer, util
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr

# --- Model and device initialisation ---------------------------------------
# Use a CUDA device when PyTorch reports one is available; otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

_WHISPER_SIZE = "medium"
_EMBEDDING_MODEL = "all-MiniLM-L6-v2"
# NOTE(review): Qwen1.5 checkpoints are supported natively by recent
# transformers releases, so trust_remote_code=True may be unnecessary here —
# confirm against the installed transformers version before removing it.
_QWEN_CHECKPOINT = "Qwen/Qwen1.5-1.8B-Chat"

# Speech-to-text, sentence-embedding and text-generation models, all on `device`.
whisper_model = whisper.load_model(_WHISPER_SIZE, device=device)
embed_model = SentenceTransformer(_EMBEDDING_MODEL, device=device)
qwen_tokenizer = AutoTokenizer.from_pretrained(_QWEN_CHECKPOINT, trust_remote_code=True)
qwen_model = AutoModelForCausalLM.from_pretrained(_QWEN_CHECKPOINT, trust_remote_code=True)
qwen_model = qwen_model.to(device).eval()  # eval mode: disables dropout etc.

# --- Keyword Q&A rules (read from qa.json) ---------------------------------
with open("qa.json", "r", encoding="utf-8") as qa_file:
    qa_data = json.load(qa_file)

# --- SQLite document store -------------------------------------------------
DB_PATH = "db.sqlite"

def init_db(db_path=None):
    """Create the ``documents`` table if it does not exist yet.

    Each row holds one line of source text in ``content`` (declared UNIQUE, so
    the same text is never stored twice) and its embedding, serialised as
    text, in ``embedding``.

    Args:
        db_path: SQLite file to initialise; defaults to the module's ``DB_PATH``.
    """
    conn = sqlite3.connect(DB_PATH if db_path is None else db_path)
    try:
        conn.execute(
            """CREATE TABLE IF NOT EXISTS documents (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    content TEXT UNIQUE,
                    embedding TEXT
                )"""
        )
        conn.commit()
    finally:
        # Always release the connection, even if the DDL or commit fails.
        conn.close()

def embed_and_store_documents(docs_dir="documents", db_path=None):
    """Embed every non-blank line of the ``*.txt`` files in *docs_dir* and store it.

    Each stripped line is encoded with ``embed_model``, converted to float32,
    base64-encoded and inserted into the ``documents`` table. Lines already in
    the table, or repeated across files, are skipped before encoding, so
    re-running at start-up only embeds new text.

    Args:
        docs_dir: Directory scanned (non-recursively) for ``.txt`` files.
        db_path: SQLite file to write; defaults to the module's ``DB_PATH``.
    """
    if not os.path.isdir(docs_dir):
        # Keyword answers still work without documents, so warn instead of crashing at import.
        print("Documents directory not found:", docs_dir)
        return

    conn = sqlite3.connect(DB_PATH if db_path is None else db_path)
    try:
        # `content` is UNIQUE, so INSERT OR IGNORE would drop these rows anyway;
        # skipping them here avoids paying for their embeddings.
        seen = {row[0] for row in conn.execute("SELECT content FROM documents")}
        for filename in sorted(os.listdir(docs_dir)):
            if not filename.endswith(".txt"):
                continue
            with open(os.path.join(docs_dir, filename), "r", encoding="utf-8") as f:
                new_lines = []
                for line in f:
                    content = line.strip()
                    if content and content not in seen:
                        seen.add(content)
                        new_lines.append(content)
            if not new_lines:
                continue

            # One batched model call per file instead of one call per line.
            embeddings = np.asarray(embed_model.encode(new_lines), dtype=np.float32)
            rows = [
                (content, base64.b64encode(vec.tobytes()).decode("utf-8"))
                for content, vec in zip(new_lines, embeddings)
            ]
            try:
                conn.executemany(
                    "INSERT OR IGNORE INTO documents (content, embedding) VALUES (?, ?)",
                    rows,
                )
                conn.commit()
            except sqlite3.Error as e:
                conn.rollback()
                print("DB insert error:", e)
    finally:
        conn.close()


# Build the index at import time; the table must exist before rows are inserted.
init_db()
embed_and_store_documents()

# ้—œ้ตๅญ—ๅŒน้…
def match_qa(text):
    for entry in qa_data:
        keywords = entry["keywords"]
        match_type = entry.get("match", "OR").upper()
        if match_type == "OR" and any(k in text for k in keywords):
            return entry["response"]
        elif match_type == "AND" and all(k in text for k in keywords):
            return entry["response"]
    return None

# Vector similarity search over the stored document lines.
def search_similar_paragraphs(query, top_k=5, db_path=None):
    """Return the stored document lines most similar to *query*.

    The query is embedded with ``embed_model``; every row of the ``documents``
    table is scored by cosine similarity against it, and the ``top_k``
    best-scoring ``content`` strings are returned, best first. Ties keep the
    order in which rows were read.

    Args:
        query: Text to search for.
        top_k: Maximum number of lines to return.
        db_path: SQLite file to read; defaults to the module's ``DB_PATH``.

    Returns:
        A list of at most ``top_k`` strings; empty when the table is empty.
    """
    query_vec = embed_model.encode(query).astype(np.float32)
    conn = sqlite3.connect(DB_PATH if db_path is None else db_path)
    try:
        rows = conn.execute("SELECT content, embedding FROM documents").fetchall()
    finally:
        conn.close()
    if not rows:
        return []

    contents = [content for content, _ in rows]
    # Decode all stored embeddings into one (n, dim) matrix so similarity is
    # computed in a single call instead of once per row.
    matrix = np.stack(
        [np.frombuffer(base64.b64decode(emb_str), dtype=np.float32) for _, emb_str in rows]
    )
    sims = util.cos_sim(torch.from_numpy(query_vec), torch.from_numpy(matrix))[0].tolist()

    # sorted() is stable even with reverse=True, so equal scores keep read order.
    ranked = sorted(zip(contents, sims), key=lambda pair: pair[1], reverse=True)
    return [content for content, _ in ranked[:top_k]]

# Whisper speech recognition.
def transcribe(audio_path):
    """Transcribe the audio file at *audio_path* with the loaded Whisper model.

    The text is stripped for consistency with the stripped text-box input, so a
    whitespace-only transcript is treated as empty by the caller.

    Returns:
        The recognised text with surrounding whitespace removed.
    """
    result = whisper_model.transcribe(audio_path)
    return result["text"].strip()

# AI answer generation.
def generate_response(text):
    """Answer *text*, preferring a keyword rule over the language model.

    1. If ``match_qa`` returns a (truthy) canned response, return it as-is.
    2. Otherwise retrieve the 5 most similar stored lines and build a prompt
       instructing the model to answer only from that material (and to say
       that no relevant information was found otherwise). Then generate up to
       300 new tokens with the Qwen chat model.

    Returns:
        The canned response or the model's answer text.
    """
    matched = match_qa(text)
    if matched:
        return matched

    context = "\n".join(search_similar_paragraphs(text, top_k=5))
    prompt = f"ๅชไพๆ“šไปฅไธ‹ๆไพ›็š„่ณ‡ๆ–™๏ผŒๅ›ž็ญ”ๅ•้กŒใ€‚ๅฆ‚ๆžœๆ‰พไธๅˆฐ็›ธ้—œ่ณ‡ๆ–™๏ผŒๅฐฑ่ชชใ€Œๆ‰พไธๅˆฐ็›ธ้—œ่ณ‡่จŠใ€๏ผš\n{context}\nๅ•้กŒ๏ผš{text}\nๅ›ž็ญ”๏ผš"
    # The checkpoint is a chat-tuned model: wrap the prompt as a user turn and
    # append the assistant header so generation starts an answer.
    chat_prompt = qwen_tokenizer.apply_chat_template(
        [{"role": "user", "content": prompt}],
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = qwen_tokenizer(chat_prompt, return_tensors="pt").to(device)
    outputs = qwen_model.generate(**inputs, max_new_tokens=300)
    # Decode only the tokens produced after the prompt. Slicing by token count
    # is exact, whereas removing the prompt string from the decoded text fails
    # whenever decoding does not reproduce the prompt character-for-character.
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    return qwen_tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

# Text-to-speech synthesis.
def text_to_speech(text):
    """Synthesise *text* as speech with gTTS (``lang="zh"``) into an MP3 file.

    Returns:
        Path of a newly created temporary ``.mp3`` file that the caller now
        owns, or ``None`` when *text* is empty or synthesis fails (the error
        is printed).
    """
    if not text or not text.strip():
        return None  # nothing to speak; skip the network round-trip
    path = None
    try:
        tts = gTTS(text=text, lang="zh")
        # mktemp() only returns a name another process could claim first;
        # NamedTemporaryFile creates the file atomically. delete=False keeps it
        # on disk after closing so the returned path stays valid.
        with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as tmp:
            path = tmp.name
        tts.save(path)
        return path
    except Exception as e:
        print("TTS error:", e)
        if path is not None:
            # Don't leave an empty/partial file behind on failure.
            try:
                os.remove(path)
            except OSError:
                pass
        return None

# Main request handler: accepts either typed text or recorded/uploaded audio.
def main(text, audio):
    """Handle one submission from the Gradio form.

    Typed text wins; the audio is transcribed only when the text box is empty
    or whitespace. The result is a ``(transcript, reply, tts_path)`` triple for
    the three output components. When neither input yields any text, the
    transcript is empty, the reply is a "no text or voice input" message, and
    the audio path is ``None``.
    """
    query = (text or "").strip()
    if not query and audio:
        # No usable text, so fall back to speech recognition.
        query = transcribe(audio)

    if not query:
        return "", "ๆœช่ผธๅ…ฅๆ–‡ๅญ—ๆˆ–่ชž้Ÿณ", None

    answer = generate_response(query)
    return query, answer, text_to_speech(answer)

# Gradio UI: an optional text box and an optional audio input; the outputs are
# the text that was answered, the answer, and a spoken version of the answer.
_ui_inputs = [
    gr.Textbox(label="ๆ–‡ๅญ—่ผธๅ…ฅ๏ผˆๅฏ้ธ๏ผ‰"),
    gr.Audio(type="filepath", label="่ชž้Ÿณ่ผธๅ…ฅ๏ผˆๅฏ้Œ„้Ÿณๆˆ–ไธŠๅ‚ณ๏ผ‰"),
]
_ui_outputs = [
    gr.Textbox(label="่พจ่ญ˜ๆ–‡ๅญ— / ๆ–‡ๅญ—่ผธๅ…ฅ"),
    gr.Textbox(label="AI ๅ›ž็ญ”"),
    gr.Audio(type="filepath", label="่ชž้Ÿณๅ›ž่ฆ†"),
]

ui = gr.Interface(
    fn=main,
    inputs=_ui_inputs,
    outputs=_ui_outputs,
    title="ๅ—่‡บ็ง‘ๆŠ€ๅคงๅญธ่ชž้Ÿณๆ–‡ๅญ—ๅ•็ญ”็ณป็ตฑ",
    description="ไฝ ๅฏไปฅ่ผธๅ…ฅๆ–‡ๅญ—ๆˆ–้Œ„้Ÿณๆๅ•๏ผŒ็ณป็ตฑๆœƒๅพž่ณ‡ๆ–™ไธญๆ‰พๅ‡บ็ญ”ๆกˆ๏ผŒไธฆไปฅ่ชž้Ÿณๅ›ž่ฆ†ใ€‚",
)

if __name__ == "__main__":
    # Start the Gradio server only when run as a script, not on import.
    ui.launch()