# Captured from a Hugging Face Space file view (Space status shown: "Runtime error").
# File size: 5,277 bytes; revisions: 9ccba7a 3b1849a 9ccba7a 587dce1.
import os
import json
import base64
import sqlite3
import tempfile
import torch
import whisper
from gtts import gTTS
from pydub import AudioSegment
import numpy as np
from sentence_transformers import SentenceTransformer, util
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
# Model and device initialization (executed once, at import time).
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Speech recognition and sentence-embedding models share the chosen device.
whisper_model = whisper.load_model("medium", device=device)
embed_model = SentenceTransformer("all-MiniLM-L6-v2", device=device)

# Chat LLM used to generate answers; put in inference mode on the same device.
qwen_tokenizer = AutoTokenizer.from_pretrained(
    "Qwen/Qwen1.5-1.8B-Chat",
    trust_remote_code=True,
)
qwen_model = (
    AutoModelForCausalLM.from_pretrained("Qwen/Qwen1.5-1.8B-Chat", trust_remote_code=True)
    .to(device)
    .eval()
)
# Load qa.json: the keyword rules consulted by match_qa() before falling back to retrieval.
with open("qa.json", "r", encoding="utf-8") as qa_file:
    qa_data = json.load(qa_file)
# Initialize the SQLite database.
DB_PATH = "db.sqlite"


def init_db(db_path=None):
    """Create the ``documents`` table if it does not already exist.

    The table keeps one row per indexed text line. ``content`` is UNIQUE so
    re-indexing the same corpus (``INSERT OR IGNORE``) does not duplicate rows;
    ``embedding`` stores the base64 text of the float32 vector bytes.

    Args:
        db_path: SQLite file to open; defaults to the module-level ``DB_PATH``.
    """
    conn = sqlite3.connect(DB_PATH if db_path is None else db_path)
    try:
        conn.execute('''CREATE TABLE IF NOT EXISTS documents (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            content TEXT UNIQUE,
            embedding TEXT
        )''')
        conn.commit()
    finally:
        # Close even if CREATE TABLE fails, so the file handle is not leaked.
        conn.close()
def embed_and_store_documents(docs_dir="documents", db_path=None, encoder=None):
    """Embed every non-empty line of the ``.txt`` files in *docs_dir* into the DB.

    Each stripped, non-empty line becomes one ``documents`` row whose embedding
    is the float32 vector bytes, base64-encoded. Existing lines are skipped via
    ``INSERT OR IGNORE`` on the UNIQUE ``content`` column.

    Args:
        docs_dir: directory scanned (non-recursively) for ``*.txt`` files.
        db_path: SQLite file; defaults to the module-level ``DB_PATH``.
        encoder: object with an ``encode(list_of_str)`` method returning one
            vector per string; defaults to the module-level ``embed_model``.

    A missing *docs_dir* is reported and skipped instead of crashing startup.
    Database errors are printed and the remaining files are still processed.
    """
    if not os.path.isdir(docs_dir):
        print(f"Documents directory not found, skipping indexing: {docs_dir}")
        return
    db_path = DB_PATH if db_path is None else db_path
    encoder = embed_model if encoder is None else encoder

    # One connection for the whole run instead of one per line.
    conn = sqlite3.connect(db_path)
    try:
        for filename in sorted(os.listdir(docs_dir)):
            if not filename.endswith(".txt"):
                continue
            with open(os.path.join(docs_dir, filename), "r", encoding="utf-8") as fh:
                lines = [line.strip() for line in fh]
            lines = [line for line in lines if line]
            if not lines:
                continue
            # Batch-encode the whole file; one row of the result per line.
            vectors = np.asarray(encoder.encode(lines), dtype=np.float32)
            rows = [
                (content, base64.b64encode(vec.tobytes()).decode("utf-8"))
                for content, vec in zip(lines, vectors)
            ]
            try:
                conn.executemany(
                    "INSERT OR IGNORE INTO documents (content, embedding) VALUES (?, ?)",
                    rows,
                )
                conn.commit()
            except sqlite3.Error as e:
                # Best effort: report and move on to the next file.
                conn.rollback()
                print("DB insert error:", e)
    finally:
        conn.close()
# Startup: create the schema first, then index the text corpus into it.
for _startup_step in (init_db, embed_and_store_documents):
    _startup_step()
del _startup_step
# Keyword matching
def match_qa(text, entries=None):
    """Return the response of the first QA rule whose keywords match *text*.

    Each rule has ``keywords`` (substrings to look for), an optional ``match``
    mode ("OR" by default, or "AND"; compared case-insensitively) and a
    ``response``. "OR" matches if any keyword occurs in *text*, "AND" if all do.
    The substring test itself is case-sensitive.

    Args:
        text: the user's question.
        entries: rules to search; defaults to the module-level ``qa_data``.

    Returns:
        The first matching rule's ``response``, or None.
    """
    rules = qa_data if entries is None else entries
    for entry in rules:
        keywords = entry.get("keywords") or []
        # all([]) is True, so an empty "AND" rule would otherwise hijack every query.
        if not keywords:
            continue
        match_type = entry.get("match", "OR").upper()
        if match_type == "OR":
            hit = any(k in text for k in keywords)
        elif match_type == "AND":
            hit = all(k in text for k in keywords)
        else:
            hit = False
        if hit:
            return entry["response"]
    return None
# Vector similarity search
def search_similar_paragraphs(query, top_k=5, db_path=None, encoder=None):
    """Return the *top_k* stored lines most cosine-similar to *query*.

    Args:
        query: text to search for.
        top_k: number of results to return (same slicing semantics as a list).
        db_path: SQLite file; defaults to the module-level ``DB_PATH``.
        encoder: object with ``encode(str)`` returning one vector; defaults to
            the module-level ``embed_model``.

    Returns:
        List of ``content`` strings, best match first; ties keep DB row order.
        Empty list when the table has no rows.
    """
    db_path = DB_PATH if db_path is None else db_path
    encoder = embed_model if encoder is None else encoder

    conn = sqlite3.connect(db_path)
    try:
        rows = conn.execute("SELECT content, embedding FROM documents").fetchall()
    finally:
        conn.close()
    if not rows:
        return []

    query_vec = np.asarray(encoder.encode(query), dtype=np.float32).ravel()
    matrix = np.stack(
        [np.frombuffer(base64.b64decode(emb), dtype=np.float32) for _, emb in rows]
    )
    # Cosine similarity for all rows at once; the eps floor avoids division by
    # zero for all-zero vectors (mirrors the 1e-12 clamp of L2 normalisation).
    eps = 1e-12
    matrix_unit = matrix / np.maximum(np.linalg.norm(matrix, axis=1, keepdims=True), eps)
    query_unit = query_vec / max(float(np.linalg.norm(query_vec)), eps)
    scores = matrix_unit @ query_unit
    # Stable sort on negated scores == stable descending sort.
    order = np.argsort(-scores, kind="stable")[:top_k]
    return [rows[i][0] for i in order]
# Whisper speech recognition
def transcribe(audio_path):
    """Run the Whisper model on the audio file at *audio_path*.

    Returns:
        The recognized text (the ``"text"`` field of Whisper's result).
    """
    transcription = whisper_model.transcribe(audio_path)
    recognized = transcription["text"]
    return recognized
# AI answer
def generate_response(text):
    """Answer *text*: canned QA response if a keyword rule matches, else RAG.

    The fallback retrieves the 5 most similar stored lines as context and asks
    the Qwen chat model to answer using only that context.

    Args:
        text: the user's question.

    Returns:
        The answer string.
    """
    matched = match_qa(text)
    if matched:
        return matched
    context = "\n".join(search_similar_paragraphs(text, top_k=5))
    # NOTE(review): this prompt was recovered from a mis-encoded copy of the file;
    # the full-width punctuation (，/：) is a best reconstruction — confirm.
    prompt = f"只依據以下提供的資料，回答問題。如果找不到相關資料，就說「找不到相關資訊」：\n{context}\n問題：{text}\n回答："
    # Qwen1.5-*-Chat is chat-tuned: wrap the request in the tokenizer's chat
    # template so the model sees the role markers it was trained on.
    chat_prompt = qwen_tokenizer.apply_chat_template(
        [{"role": "user", "content": prompt}],
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = qwen_tokenizer(chat_prompt, return_tensors="pt").to(device)
    outputs = qwen_model.generate(**inputs, max_new_tokens=300)
    # Decode only the newly generated tokens. Stripping the prompt out of the
    # full decoded text is fragile: decode() need not reproduce it exactly.
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    return qwen_tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
# TTS synthesis
def text_to_speech(text):
    """Synthesize *text* into a temporary MP3 file using gTTS (``lang="zh"``).

    Returns:
        Path of the new MP3 file, or None when *text* is empty/blank or synthesis
        fails (the error is printed). The file is not deleted by this function.
    """
    if not text or not text.strip():
        return None
    # mkstemp creates the file atomically; tempfile.mktemp is deprecated because
    # the returned name can be claimed by another process before we write it.
    fd, path = tempfile.mkstemp(suffix=".mp3")
    os.close(fd)
    try:
        tts = gTTS(text=text, lang="zh")
        tts.save(path)
        return path
    except Exception as e:
        print("TTS error:", e)
        # Don't leave an empty temp file behind on failure.
        try:
            os.remove(path)
        except OSError:
            pass
        return None
# Main processing flow: supports text and voice input
def main(text, audio):
    """Handle one request; typed text takes priority over audio.

    Args:
        text: contents of the text box (may be empty or None).
        audio: file path of the recorded/uploaded audio, or None.

    Returns:
        ``(transcript, reply, tts_path)``: the question that was answered, the
        answer, and the synthesized MP3 path (None if synthesis failed). When
        neither input yields any text, returns ``("", <no-input message>, None)``.
    """
    transcript = ""
    if text and text.strip():
        # Prefer the typed text.
        transcript = text.strip()
    elif audio:
        # No text: fall back to speech recognition. Strip so a blank or
        # whitespace-only transcription is treated as "no input".
        transcript = (transcribe(audio) or "").strip()
    if not transcript:
        return "", "未輸入文字或語音", None
    reply = generate_response(transcript)
    tts_path = text_to_speech(reply)
    return transcript, reply, tts_path
# Gradio UI
# NOTE(review): the labels/title/description below were recovered from a
# mis-encoded copy of this file; confirm the exact wording (notably the
# university name in the title) against the original repository.
ui = gr.Interface(
    fn=main,
    inputs=[
        gr.Textbox(label="文字輸入（可選）"),
        gr.Audio(type="filepath", label="語音輸入（可錄音或上傳）"),
    ],
    outputs=[
        gr.Textbox(label="辨識文字 / 文字輸入"),
        gr.Textbox(label="AI 回答"),
        gr.Audio(type="filepath", label="語音回覆"),
    ],
    title="南臺科技大學語音文字回答系統",
    description="你可以輸入文字或錄音提問，系統會從資料中找出答案，並以語音回覆。",
)

if __name__ == "__main__":
    ui.launch()
# (end of file)