import os
import json
import base64
import sqlite3
import tempfile

import torch
import whisper
from gtts import gTTS
import numpy as np
from sentence_transformers import SentenceTransformer, util
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr

# Model and device initialization
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
whisper_model = whisper.load_model("medium", device=device)
embed_model = SentenceTransformer("all-MiniLM-L6-v2", device=device)
qwen_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-1.8B-Chat", trust_remote_code=True)
qwen_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen1.5-1.8B-Chat", trust_remote_code=True).to(device).eval()

# Load qa.json
with open("qa.json", "r", encoding="utf-8") as f:
    qa_data = json.load(f)

# Initialize the SQLite database
DB_PATH = "db.sqlite"

def init_db():
    conn = sqlite3.connect(DB_PATH)
    c = conn.cursor()
    c.execute('''CREATE TABLE IF NOT EXISTS documents (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        content TEXT UNIQUE,
        embedding TEXT
    )''')
    conn.commit()
    conn.close()

def embed_and_store_documents():
    # Embed every non-empty line of each .txt file under documents/ and store
    # the text together with its base64-encoded float32 embedding.
    for filename in os.listdir("documents"):
        if filename.endswith(".txt"):
            with open(os.path.join("documents", filename), "r", encoding="utf-8") as f:
                for line in f:
                    content = line.strip()
                    if content:
                        embedding = embed_model.encode(content).astype(np.float32)
                        emb_str = base64.b64encode(embedding.tobytes()).decode("utf-8")
                        try:
                            conn = sqlite3.connect(DB_PATH)
                            c = conn.cursor()
                            c.execute(
                                "INSERT OR IGNORE INTO documents (content, embedding) VALUES (?, ?)",
                                (content, emb_str),
                            )
                            conn.commit()
                            conn.close()
                        except Exception as e:
                            print("DB insert error:", e)

init_db()
embed_and_store_documents()

# Keyword matching
def match_qa(text):
    for entry in qa_data:
        keywords = entry["keywords"]
        match_type = entry.get("match", "OR").upper()
        if match_type == "OR" and any(k in text for k in keywords):
            return entry["response"]
        elif match_type == "AND" and all(k in text for k in keywords):
            return entry["response"]
    return None

# Vector similarity search
def search_similar_paragraphs(query, top_k=5):
    query_vec = embed_model.encode(query).astype(np.float32)
    conn = sqlite3.connect(DB_PATH)
    c = conn.cursor()
    c.execute("SELECT content, embedding FROM documents")
    all_data = c.fetchall()
    conn.close()
    scores = []
    for content, emb_str in all_data:
        emb = np.frombuffer(base64.b64decode(emb_str), dtype=np.float32)
        sim = util.cos_sim(torch.tensor(query_vec), torch.tensor(emb))[0][0].item()
        scores.append((content, sim))
    top = sorted(scores, key=lambda x: x[1], reverse=True)[:top_k]
    return [item[0] for item in top]

# Whisper speech recognition
def transcribe(audio_path):
    result = whisper_model.transcribe(audio_path)
    return result["text"]

# AI answer generation
def generate_response(text):
    matched = match_qa(text)
    if matched:
        return matched
    context = "\n".join(search_similar_paragraphs(text, top_k=5))
    prompt = f"只依據以下提供的資料,回答問題。如果找不到相關資料,就說「找不到相關資訊」:\n{context}\n問題:{text}\n回答:"
    inputs = qwen_tokenizer(prompt, return_tensors="pt").to(device)
    outputs = qwen_model.generate(**inputs, max_new_tokens=300)
    # Decode only the newly generated tokens; stripping the prompt from the
    # decoded string is unreliable because re-tokenized text may not round-trip.
    new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
    return qwen_tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

# TTS synthesis
def text_to_speech(text):
    try:
        tts = gTTS(text=text, lang="zh")
        # NamedTemporaryFile avoids the race condition of the deprecated tempfile.mktemp
        with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as tmp:
            path = tmp.name
        tts.save(path)
        return path
    except Exception as e:
        print("TTS error:", e)
        return None

# Main pipeline: supports both text and audio input
def main(text, audio):
    transcript = ""
    if text and text.strip():
        # Prefer text input
        transcript = text.strip()
    elif audio:
        # Fall back to speech recognition if no text was given
        transcript = transcribe(audio)
    if not transcript:
        return "", "未輸入文字或語音", None
    reply = generate_response(transcript)
    tts_path = text_to_speech(reply)
    return transcript, reply, tts_path

# Gradio UI
ui = gr.Interface(
    fn=main,
    inputs=[
        gr.Textbox(label="文字輸入(可選)"),
        gr.Audio(type="filepath", label="語音輸入(可錄音或上傳)"),
    ],
    outputs=[
        gr.Textbox(label="辨識文字 / 文字輸入"),
        gr.Textbox(label="AI 回答"),
        gr.Audio(type="filepath", label="語音回覆"),
    ],
    title="南臺科技大學語音文字問答系統",
    description="你可以輸入文字或錄音提問,系統會從資料中找出答案,並以語音回覆。",
)

if __name__ == "__main__":
    ui.launch()