# Spaces:
# Runtime error
# Runtime error
import base64
import contextlib
import json
import os
import sqlite3
import tempfile

import gradio as gr
import numpy as np
import torch
import whisper
from gtts import gTTS
from pydub import AudioSegment
from sentence_transformers import SentenceTransformer, util
from transformers import AutoModelForCausalLM, AutoTokenizer
# Model and device initialisation.
# Everything is placed on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Speech recogniser used by transcribe().
whisper_model = whisper.load_model("medium", device=device)

# Sentence encoder used to embed both the stored document lines and the queries.
embed_model = SentenceTransformer("all-MiniLM-L6-v2", device=device)

# Chat model that writes the answers; tokenizer and weights share one checkpoint.
# NOTE(review): trust_remote_code=True lets the checkpoint repository run its own
# Python code at load time -- confirm it is really required for this model.
_QWEN_CHECKPOINT = "Qwen/Qwen1.5-1.8B-Chat"
qwen_tokenizer = AutoTokenizer.from_pretrained(_QWEN_CHECKPOINT, trust_remote_code=True)
qwen_model = AutoModelForCausalLM.from_pretrained(_QWEN_CHECKPOINT, trust_remote_code=True)
qwen_model = qwen_model.to(device).eval()
# Path of the SQLite database file used by the document store.
DB_PATH = "db.sqlite"

# Load qa.json once at import time; match_qa() scans these entries.
with open("qa.json", mode="r", encoding="utf-8") as qa_file:
    qa_data = json.load(qa_file)
def init_db():
    """Create the ``documents`` table in the SQLite file at ``DB_PATH`` if needed.

    The table has an auto-incrementing ``id``, a ``content`` column declared
    UNIQUE and an ``embedding`` text column. The statement uses
    ``CREATE TABLE IF NOT EXISTS``, so calling this repeatedly is harmless.
    """
    # closing() releases the connection even when the DDL raises; the inner
    # ``with conn`` commits on success and rolls back on failure.
    with contextlib.closing(sqlite3.connect(DB_PATH)) as conn:
        with conn:
            conn.execute(
                '''CREATE TABLE IF NOT EXISTS documents (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    content TEXT UNIQUE,
                    embedding TEXT
                )'''
            )
def embed_and_store_documents():
    """Embed every non-empty line of ``documents/*.txt`` and store it in SQLite.

    Each stripped line is encoded with ``embed_model``, converted to float32,
    base64-encoded and inserted into the ``documents`` table together with its
    text. Lines whose text is already stored are skipped before the (expensive)
    encoding step; ``INSERT OR IGNORE`` remains as a safety net for the UNIQUE
    ``content`` column.

    A missing ``documents`` directory is reported and treated as "nothing to
    index" instead of aborting start-up. Database errors on individual lines
    are printed and do not stop the run.
    """
    docs_dir = "documents"
    if not os.path.isdir(docs_dir):
        print("DB insert error: documents directory not found:", docs_dir)
        return

    # One connection for the whole run, closed even if something raises.
    with contextlib.closing(sqlite3.connect(DB_PATH)) as conn:
        # Sorted so that row ids do not depend on the directory listing order.
        for filename in sorted(os.listdir(docs_dir)):
            if not filename.endswith(".txt"):
                continue
            path = os.path.join(docs_dir, filename)
            with open(path, "r", encoding="utf-8") as f:
                # One transaction per file: committed when the file is done,
                # rolled back if an unexpected error escapes.
                with conn:
                    for line in f:
                        content = line.strip()
                        if not content:
                            continue
                        try:
                            already_stored = conn.execute(
                                "SELECT 1 FROM documents WHERE content = ?",
                                (content,),
                            ).fetchone()
                            if already_stored:
                                continue
                            embedding = embed_model.encode(content).astype(np.float32)
                            emb_str = base64.b64encode(embedding.tobytes()).decode("utf-8")
                            conn.execute(
                                "INSERT OR IGNORE INTO documents (content, embedding) VALUES (?, ?)",
                                (content, emb_str),
                            )
                        except sqlite3.Error as e:
                            print("DB insert error:", e)
# Start-up bootstrap, executed at import time: the table must exist before
# the document lines are embedded and inserted, so the order matters.
init_db()
embed_and_store_documents()
# Keyword matching
def match_qa(text):
    """Return the response of the first ``qa_data`` entry whose keywords match.

    Each entry provides ``keywords`` (a list of substrings), ``response`` and an
    optional ``match`` mode, compared case-insensitively and defaulting to
    ``"OR"``. ``"OR"`` needs any keyword to occur in *text*; ``"AND"`` needs
    all of them. Entries with any other mode never match.

    Entries without usable keywords are skipped: ``all()`` over an empty list
    is True, so an ``"AND"`` rule with no keywords would otherwise answer every
    question, and an empty-string keyword is a substring of everything.

    Returns ``None`` when no entry matches.
    """
    for entry in qa_data:
        keywords = [k for k in (entry.get("keywords") or []) if k]
        if not keywords:
            continue
        match_type = str(entry.get("match", "OR")).upper()
        if match_type == "OR":
            matched = any(k in text for k in keywords)
        elif match_type == "AND":
            matched = all(k in text for k in keywords)
        else:
            matched = False
        if matched:
            return entry["response"]
    return None
# Vector similarity search
def search_similar_paragraphs(query, top_k=5):
    """Return the stored document lines most similar to *query*.

    The query is embedded with ``embed_model``; every row of the ``documents``
    table is decoded from base64 back to a float32 vector and scored by cosine
    similarity against it.

    Args:
        query: Text to search for.
        top_k: Number of lines to return (list-slice semantics).

    Returns:
        Up to ``top_k`` ``content`` strings ordered from most to least similar;
        an empty list when the table holds no rows.
    """
    query_vec = embed_model.encode(query).astype(np.float32)

    # closing() releases the connection even if the query raises.
    with contextlib.closing(sqlite3.connect(DB_PATH)) as conn:
        rows = conn.execute("SELECT content, embedding FROM documents").fetchall()
    if not rows:
        return []

    # Stack all stored vectors so cosine similarity is computed in one call
    # instead of building two tensors per row. np.stack also copies out of the
    # read-only buffers that np.frombuffer returns.
    matrix = np.stack(
        [np.frombuffer(base64.b64decode(emb_str), dtype=np.float32) for _, emb_str in rows]
    )
    sims = util.cos_sim(torch.tensor(query_vec), torch.tensor(matrix))[0].tolist()

    # sorted() is stable, so rows with equal scores keep their table order.
    ranked = sorted(
        zip((content for content, _ in rows), sims),
        key=lambda pair: pair[1],
        reverse=True,
    )
    return [content for content, _ in ranked[:top_k]]
# Whisper speech recognition
def transcribe(audio_path):
    """Transcribe the audio at *audio_path* with ``whisper_model``.

    Returns the ``"text"`` field of the transcription result, unmodified.
    """
    return whisper_model.transcribe(audio_path)["text"]
# AI answer
def generate_response(text):
    """Answer *text*, preferring a canned keyword answer over the language model.

    If ``match_qa`` returns a truthy response it is returned as is. Otherwise
    the five most similar stored lines are joined into a context, placed in a
    prompt together with the question, and ``qwen_model`` generates up to 300
    new tokens.

    Returns:
        The stripped answer text.
    """
    matched = match_qa(text)
    if matched:
        return matched

    context = "\n".join(search_similar_paragraphs(text, top_k=5))
    # NOTE(review): the instruction text below looks mis-encoded in this copy of
    # the file; it is reproduced unchanged -- restore it from the original source.
    prompt = f"ๅชไพๆไปฅไธๆไพ็่ณๆ๏ผๅ็ญๅ้กใๅฆๆๆพไธๅฐ็ธ้่ณๆ๏ผๅฐฑ่ชชใๆพไธๅฐ็ธ้่ณ่จใ๏ผ\n{context}\nๅ้ก๏ผ{text}\nๅ็ญ๏ผ"
    inputs = qwen_tokenizer(prompt, return_tensors="pt").to(device)
    outputs = qwen_model.generate(**inputs, max_new_tokens=300)

    # The generated sequence starts with the prompt tokens. Dropping them by
    # position is reliable, whereas str.replace(prompt, "") silently leaves the
    # prompt in the answer whenever decoding does not reproduce it verbatim.
    prompt_len = inputs["input_ids"].shape[1]
    answer = qwen_tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    return answer.strip()
# TTS synthesis
def text_to_speech(text):
    """Synthesise *text* to an MP3 file with gTTS and return the file path.

    Returns:
        Path of the newly created ``.mp3`` file, or ``None`` when *text* is
        empty/blank or synthesis fails (the error is printed, never raised).
    """
    # Nothing to say: skip the network call instead of relying on gTTS to fail.
    if not text or not text.strip():
        return None

    path = None
    try:
        # mkstemp creates the file atomically; the deprecated mktemp only
        # returns a name, which another process could claim first.
        fd, path = tempfile.mkstemp(suffix=".mp3")
        os.close(fd)
        gTTS(text=text, lang="zh").save(path)
        return path
    except Exception as e:
        print("TTS error:", e)
        # Do not leave an empty or partial file behind on failure.
        if path is not None:
            try:
                os.remove(path)
            except OSError:
                pass
        return None
# Main processing flow: supports both text and speech input
def main(text, audio):
    """Handle one request coming from the Gradio interface.

    Args:
        text: Typed question; takes priority when it is non-blank.
        audio: Audio input, passed to ``transcribe`` only when no usable
            text was given.

    Returns:
        A ``(transcript, reply, tts_path)`` tuple. When neither input yields
        any text, the transcript is ``""``, the reply is the fixed
        "no input" message and the audio path is ``None``.
    """
    transcript = ""
    if text and text.strip():
        # Typed text wins over audio.
        transcript = text.strip()
    elif audio:
        # No text: fall back to speech recognition. Strip the result just like
        # typed text so whitespace-only output counts as "no input".
        transcript = (transcribe(audio) or "").strip()

    if not transcript:
        return "", "ๆช่ผธๅ ฅๆๅญๆ่ช้ณ", None

    reply = generate_response(transcript)
    tts_path = text_to_speech(reply)
    return transcript, reply, tts_path
# Gradio UI
# Components are built first (inputs, then outputs) and handed to the Interface.
# They line up positionally with main()'s parameters and its returned tuple.
_input_components = [
    gr.Textbox(label="ๆๅญ่ผธๅ ฅ๏ผๅฏ้ธ๏ผ"),
    gr.Audio(type="filepath", label="่ช้ณ่ผธๅ ฅ๏ผๅฏ้้ณๆไธๅณ๏ผ"),
]
_output_components = [
    gr.Textbox(label="่พจ่ญๆๅญ / ๆๅญ่ผธๅ ฅ"),
    gr.Textbox(label="AI ๅ็ญ"),
    gr.Audio(type="filepath", label="่ช้ณๅ่ฆ"),
]

ui = gr.Interface(
    title="ๅ่บ็งๆๅคงๅญธ่ช้ณๆๅญๅ็ญ็ณป็ตฑ",
    description="ไฝ ๅฏไปฅ่ผธๅ ฅๆๅญๆ้้ณๆๅ๏ผ็ณป็ตฑๆๅพ่ณๆไธญๆพๅบ็ญๆก๏ผไธฆไปฅ่ช้ณๅ่ฆใ",
    fn=main,
    inputs=_input_components,
    outputs=_output_components,
)

# Start the web server only when the file is run as a script.
if __name__ == "__main__":
    ui.launch()