# 20250901001 / app.py
# julin90's picture
# Update app.py
# 587dce1 verified
import os
import json
import base64
import sqlite3
import tempfile
import torch
import whisper
from gtts import gTTS
from pydub import AudioSegment
import numpy as np
from sentence_transformers import SentenceTransformer, util
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
# Model and device initialization
# Run every model on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Whisper "medium" checkpoint; used by transcribe() for speech-to-text.
whisper_model = whisper.load_model("medium", device=device)
# Sentence-embedding model; used to embed stored documents and search queries.
embed_model = SentenceTransformer("all-MiniLM-L6-v2", device=device)
# Qwen1.5-1.8B-Chat tokenizer and causal LM; used by generate_response().
# The model is moved to `device` and switched to eval mode.
qwen_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-1.8B-Chat", trust_remote_code=True)
qwen_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen1.5-1.8B-Chat", trust_remote_code=True).to(device).eval()
# Load qa.json
# Read the keyword-based Q&A entries that match_qa() consults.
# Text read mode is open()'s default, so it is not spelled out.
with open("qa.json", encoding="utf-8") as f:
    qa_data = json.load(f)
# ๅˆๅง‹ๅŒ– SQLite ่ณ‡ๆ–™ๅบซ
# Path of the SQLite file that stores document text and its embeddings.
DB_PATH = "db.sqlite"
def init_db():
    """Create the ``documents`` table in the SQLite database if it is missing.

    The table has an autoincrement ``id``, a UNIQUE ``content`` column holding
    the paragraph text, and an ``embedding`` column holding the vector as text.

    Raises:
        sqlite3.Error: if the database cannot be opened or the table cannot
            be created.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        conn.execute('''CREATE TABLE IF NOT EXISTS documents (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            content TEXT UNIQUE,
            embedding TEXT
        )''')
        conn.commit()
    finally:
        # Release the handle even when table creation fails.
        conn.close()
def embed_and_store_documents():
    """Embed every non-empty line of ``documents/*.txt`` and store it in SQLite.

    Each stripped line becomes one row of the ``documents`` table: the text in
    ``content`` and the float32 embedding, base64-encoded, in ``embedding``.
    Lines whose text is already stored are skipped before the embedding is
    computed, so restarting the app does not re-encode the whole corpus.

    Database errors are reported with ``print`` and do not abort the run
    (best-effort indexing). Errors from reading files or from the embedding
    model propagate to the caller.
    """
    try:
        conn = sqlite3.connect(DB_PATH)
    except sqlite3.Error as e:
        print("DB insert error:", e)
        return
    try:
        cur = conn.cursor()
        for filename in os.listdir("documents"):
            if not filename.endswith(".txt"):
                continue
            with open(os.path.join("documents", filename), "r", encoding="utf-8") as f:
                for line in f:
                    content = line.strip()
                    if not content:
                        continue
                    try:
                        # `content` is UNIQUE, so a stored row means the
                        # insert would be ignored anyway; skip the costly
                        # encode step for it.
                        cur.execute(
                            "SELECT 1 FROM documents WHERE content = ?",
                            (content,),
                        )
                        if cur.fetchone() is not None:
                            continue
                        embedding = embed_model.encode(content).astype(np.float32)
                        emb_str = base64.b64encode(embedding.tobytes()).decode("utf-8")
                        cur.execute(
                            "INSERT OR IGNORE INTO documents (content, embedding) VALUES (?, ?)",
                            (content, emb_str),
                        )
                    except sqlite3.Error as e:
                        print("DB insert error:", e)
            try:
                # One commit per file instead of one per row.
                conn.commit()
            except sqlite3.Error as e:
                print("DB insert error:", e)
    finally:
        # A single connection is reused for the whole run and always closed.
        conn.close()
# Module-level startup: create the table, then index the documents folder.
# Both calls run whenever this module is executed or imported.
init_db()
embed_and_store_documents()
# ้—œ้ตๅญ—ๅŒน้…
def match_qa(text):
    """Return the response of the first Q&A entry whose keywords match ``text``.

    Each entry in ``qa_data`` has a ``keywords`` list, a ``response`` and an
    optional ``match`` mode (case-insensitive, default ``"OR"``). ``"OR"``
    needs any keyword to be a substring of ``text``; ``"AND"`` needs all of
    them. Entries with any other mode never match. Returns ``None`` when no
    entry matches.
    """
    # Map each supported mode to the built-in that combines the keyword hits.
    combiners = {"OR": any, "AND": all}
    for entry in qa_data:
        keywords = entry["keywords"]
        combine = combiners.get(entry.get("match", "OR").upper())
        if combine is not None and combine(k in text for k in keywords):
            return entry["response"]
    return None
# ๅ‘้‡็›ธไผผๆœๅฐ‹
def search_similar_paragraphs(query, top_k=5):
    """Return the stored paragraphs most similar to ``query``.

    The query is embedded with ``embed_model`` and compared by cosine
    similarity against every embedding in the ``documents`` table.

    Args:
        query: Text to search for.
        top_k: Maximum number of paragraphs to return.

    Returns:
        Paragraph texts ordered from most to least similar; an empty list
        when the table has no rows.
    """
    query_vec = embed_model.encode(query).astype(np.float32)
    conn = sqlite3.connect(DB_PATH)
    try:
        rows = conn.execute("SELECT content, embedding FROM documents").fetchall()
    finally:
        # Close the handle even if the query fails.
        conn.close()
    if not rows:
        return []
    contents = [content for content, _ in rows]
    # Decode every stored vector and stack them into one (n_rows, dim) matrix
    # so similarity is computed in a single call instead of once per row.
    emb_matrix = np.stack(
        [np.frombuffer(base64.b64decode(emb_str), dtype=np.float32) for _, emb_str in rows]
    )
    # cos_sim returns a (1, n_rows) tensor for a single query vector.
    sims = util.cos_sim(torch.tensor(query_vec), torch.tensor(emb_matrix))[0].tolist()
    # sorted() is stable, so equally scored rows keep their table order.
    ranked = sorted(zip(contents, sims), key=lambda pair: pair[1], reverse=True)
    return [content for content, _ in ranked[:top_k]]
# Whisper speech recognition
def transcribe(audio_path):
    """Transcribe the audio file at ``audio_path`` and return the recognized text."""
    return whisper_model.transcribe(audio_path)["text"]
# AI answer
def generate_response(text):
    """Answer ``text`` from the Q&A table, or else with the language model.

    A truthy keyword match from ``match_qa`` is returned as-is. Otherwise the
    five most similar stored paragraphs are placed in a prompt together with
    the question, and the Qwen model generates up to 300 new tokens.

    Args:
        text: The user's question.

    Returns:
        The canned response or the generated answer, stripped of surrounding
        whitespace.
    """
    matched = match_qa(text)
    if matched:
        return matched
    context = "\n".join(search_similar_paragraphs(text, top_k=5))
    # NOTE(review): this literal looks mis-encoded (mojibake) — confirm the
    # file's encoding; it is reproduced unchanged here.
    prompt = f"ๅชไพๆ“šไปฅไธ‹ๆไพ›็š„่ณ‡ๆ–™๏ผŒๅ›ž็ญ”ๅ•้กŒใ€‚ๅฆ‚ๆžœๆ‰พไธๅˆฐ็›ธ้—œ่ณ‡ๆ–™๏ผŒๅฐฑ่ชชใ€Œๆ‰พไธๅˆฐ็›ธ้—œ่ณ‡่จŠใ€๏ผš\n{context}\nๅ•้กŒ๏ผš{text}\nๅ›ž็ญ”๏ผš"
    inputs = qwen_tokenizer(prompt, return_tensors="pt").to(device)
    outputs = qwen_model.generate(**inputs, max_new_tokens=300)
    # generate() returns prompt tokens followed by new tokens. Drop the prompt
    # by token count rather than by string replacement: decoded text is not
    # guaranteed to reproduce the prompt exactly, which would leave the
    # prompt in the answer.
    prompt_len = inputs["input_ids"].shape[-1]
    answer = qwen_tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    return answer.strip()
# TTS ๅˆๆˆ
def text_to_speech(text):
    """Synthesize ``text`` to an MP3 file with gTTS and return its path.

    Args:
        text: The text to speak (synthesized with ``lang="zh"``).

    Returns:
        Path of the generated ``.mp3`` file, or ``None`` if synthesis failed.
        The caller owns the file; it is not deleted here on success.
    """
    path = None
    try:
        tts = gTTS(text=text, lang="zh")
        # mkstemp creates the file atomically; tempfile.mktemp only returns a
        # name and is deprecated because another process can claim it first.
        fd, path = tempfile.mkstemp(suffix=".mp3")
        os.close(fd)
        tts.save(path)
        return path
    except Exception as e:
        print("TTS error:", e)
        if path is not None:
            # Best-effort cleanup of the empty or partial file.
            try:
                os.remove(path)
            except OSError:
                pass
        return None
# Main processing flow: supports both text and voice input
def main(text, audio):
    """Handle one request from the UI, accepting typed text or recorded audio.

    Args:
        text: Typed question; used when it contains non-whitespace characters.
        audio: Path of an audio file; transcribed only when no text is given.

    Returns:
        A ``(transcript, reply, tts_path)`` tuple. When neither input yields
        any text, the transcript is empty, the reply is a fixed notice and
        ``tts_path`` is ``None``.
    """
    transcript = ""
    if text and text.strip():
        # Typed text takes priority over audio.
        transcript = text.strip()
    elif audio:
        # No text given: fall back to speech recognition. Strip the result
        # like the text path, so a whitespace-only transcription is treated
        # as "no input" instead of being sent to the model.
        transcript = (transcribe(audio) or "").strip()
    if not transcript:
        return "", "ๆœช่ผธๅ…ฅๆ–‡ๅญ—ๆˆ–่ชž้Ÿณ", None
    reply = generate_response(transcript)
    tts_path = text_to_speech(reply)
    return transcript, reply, tts_path
# Gradio UI
# Wires main() to two inputs (a textbox and an audio file path) and three
# outputs (the text that was used, the reply, and the spoken reply).
# NOTE(review): the label/title/description strings below look mis-encoded
# (mojibake) — confirm the file's encoding; they are left unchanged here.
ui = gr.Interface(
    fn=main,
    inputs=[
        gr.Textbox(label="ๆ–‡ๅญ—่ผธๅ…ฅ๏ผˆๅฏ้ธ๏ผ‰"),
        gr.Audio(type="filepath", label="่ชž้Ÿณ่ผธๅ…ฅ๏ผˆๅฏ้Œ„้Ÿณๆˆ–ไธŠๅ‚ณ๏ผ‰")
    ],
    outputs=[
        gr.Textbox(label="่พจ่ญ˜ๆ–‡ๅญ— / ๆ–‡ๅญ—่ผธๅ…ฅ"),
        gr.Textbox(label="AI ๅ›ž็ญ”"),
        gr.Audio(type="filepath", label="่ชž้Ÿณๅ›ž่ฆ†")
    ],
    title="ๅ—่‡บ็ง‘ๆŠ€ๅคงๅญธ่ชž้Ÿณๆ–‡ๅญ—ๅ•็ญ”็ณป็ตฑ",
    description="ไฝ ๅฏไปฅ่ผธๅ…ฅๆ–‡ๅญ—ๆˆ–้Œ„้Ÿณๆๅ•๏ผŒ็ณป็ตฑๆœƒๅพž่ณ‡ๆ–™ไธญๆ‰พๅ‡บ็ญ”ๆกˆ๏ผŒไธฆไปฅ่ชž้Ÿณๅ›ž่ฆ†ใ€‚"
)
# Start the web server only when this file is run as a script.
if __name__ == "__main__":
    ui.launch()