Spaces:
Running
Running
import os | |
import json | |
import shutil | |
import gradio as gr | |
import tempfile | |
from datetime import datetime | |
from typing import List, Dict, Any, Optional | |
from pytube import YouTube | |
from pathlib import Path | |
import re | |
# --- Agent Imports & Safe Fallbacks ---
try:
    from alz_companion.agent import (
        bootstrap_vectorstore, make_rag_chain, answer_query, synthesize_tts,
        transcribe_audio, detect_tags_from_query, describe_image, build_or_load_vectorstore,
        _default_embeddings,
    )
    from alz_companion.prompts import BEHAVIOUR_TAGS, EMOTION_STYLES
    from langchain.schema import Document
    from langchain_community.vectorstores import FAISS
    AGENT_OK = True
except Exception as e:
    # Demo mode: the agent package (or its LangChain dependencies) is missing, so
    # define lightweight stand-ins that keep every UI callback runnable.
    AGENT_OK = False

    class Document:
        """Minimal stand-in for LangChain's ``Document`` (text plus metadata dict)."""

        def __init__(self, page_content, metadata=None):
            self.page_content = page_content
            self.metadata = metadata if metadata is not None else {}

    class FAISS:
        """In-memory stand-in exposing only the parts of the FAISS store this app uses.

        ``docstore._dict`` maps string ids to documents. ``add_documents``,
        ``save_local`` (a no-op) and ``from_documents`` exist so that callbacks
        which add, list, save or delete memories also work in demo mode.
        """

        def __init__(self, docs=None):
            self.docstore = type('obj', (object,), {'_dict': {}})()
            self._next_id = 0
            self.add_documents(docs or [])

        def add_documents(self, docs):
            """Store ``docs`` under fresh sequential ids and return those ids."""
            ids = []
            for doc in docs:
                doc_id = str(self._next_id)
                self._next_id += 1
                self.docstore._dict[doc_id] = doc
                ids.append(doc_id)
            return ids

        def save_local(self, folder_path, *args, **kwargs):
            """Nothing is persisted in demo mode."""
            return None

        @classmethod
        def from_documents(cls, documents, embedding=None, **kwargs):
            """Build a new in-memory store holding ``documents`` (embeddings ignored)."""
            return cls(documents)

    def bootstrap_vectorstore(sample_paths=None, index_path="data/"):
        return FAISS()

    def build_or_load_vectorstore(docs, index_path, is_personal=False):
        return FAISS(docs)

    def make_rag_chain(vs_general, vs_personal, **kwargs):
        return lambda q, **k: {"answer": f"(Demo) You asked: {q}", "sources": []}

    def answer_query(chain, q, **kwargs):
        return chain(q, **kwargs)

    def synthesize_tts(text: str, lang: str = "en"):
        return None

    def transcribe_audio(filepath: str, lang: str = "en"):
        return "This is a transcribed message."

    def detect_tags_from_query(*args, **kwargs):
        # Return exactly the keys the callers read ("detected_behaviors", "detected_topic",
        # "detected_contexts", ...); the legacy singular key is kept for compatibility.
        return {
            "detected_behavior": "None",
            "detected_behaviors": [],
            "detected_emotion": "None",
            "detected_topic": "None",
            "detected_contexts": [],
        }

    def describe_image(image_path: str):
        return "This is a description of an image."

    def _default_embeddings():
        return None

    BEHAVIOUR_TAGS, EMOTION_STYLES = {"None": []}, {"None": {}}
    print(f"WARNING: Could not import from alz_companion ({e}). Running in UI-only demo mode.")
# --- Centralized Configuration ---
# Option lists shared by the UI widgets and by NLU tag detection.
CONFIG = {
    "themes": [
        "All", "The Father", "Still Alice", "Away from Her", "Alive Inside", "General Caregiving",
    ],
    "roles": ["patient", "caregiver"],
    # "None" first so the dropdowns default to "no manual filter".
    "behavior_tags": ["None"] + list(BEHAVIOUR_TAGS.keys()),
    "emotion_tags": ["None"] + list(EMOTION_STYLES.keys()),
    "topic_tags": [
        "None",
        "caregiving_advice",
        "medical_fact",
        "personal_story",
        "research_update",
        "treatment_option:home_safety",
        "treatment_option:long_term_care",
        "treatment_option:music_therapy",
        "treatment_option:reassurance",
        "treatment_option:routine_structuring",
        "treatment_option:validation_therapy",
    ],
    "context_tags": [
        "None",
        "disease_stage_mild",
        "disease_stage_moderate",
        "disease_stage_advanced",
        "disease_stage_unspecified",
        "interaction_mode_one_to_one",
        "interaction_mode_small_group",
        "interaction_mode_group_activity",
        "relationship_family",
        "relationship_spouse",
        "relationship_staff_or_caregiver",
        "relationship_unspecified",
        "setting_home_or_community",
        "setting_care_home",
        "setting_clinic_or_hospital",
    ],
    # Display name -> language code passed to transcription / TTS.
    "languages": {
        "English": "en",
        "Chinese": "zh",
        "Cantonese": "zh-yue",
        "Korean": "ko",
        "Japanese": "ja",
        "Malay": "ms",
        "French": "fr",
        "Spanish": "es",
        "Hindi": "hi",
        "Arabic": "ar",
    },
    "tones": [
        "warm", "empathetic", "caring", "reassuring", "calm",
        "optimistic", "motivating", "neutral", "formal", "humorous",
    ],
}
# --- File Management & Vector Store Logic ---
def _storage_root() -> Path:
    """Return the first writable directory to use as the app's storage root.

    Candidates, in order: ``$SPACE_STORAGE`` (only when set and non-empty),
    ``/data``, then ``~/.cache/alz_companion``. Each candidate is created if
    needed and probed with a throw-away write. If none is usable, falls back to
    ``<system temp dir>/alz_companion``.
    """
    candidates = []
    env_root = os.getenv("SPACE_STORAGE", "")
    # Filter the env var as a *string*: Path("") is Path(".") and Path objects are
    # always truthy, so checking the Path would let the current working directory
    # win over /data whenever SPACE_STORAGE is unset.
    if env_root:
        candidates.append(Path(env_root))
    candidates += [Path("/data"), Path.home() / ".cache" / "alz_companion"]
    for p in candidates:
        probe = p / ".write_test"
        try:
            p.mkdir(parents=True, exist_ok=True)
            probe.write_text("ok")
            probe.unlink(missing_ok=True)
            return p
        except OSError:
            continue  # not creatable / not writable: try the next candidate
    tmp = Path(tempfile.gettempdir()) / "alz_companion"
    tmp.mkdir(parents=True, exist_ok=True)
    return tmp
# Directory layout under the chosen storage root.
STORAGE_ROOT = _storage_root()
INDEX_BASE = STORAGE_ROOT / "index"
PERSONAL_DATA_BASE = STORAGE_ROOT / "personal"
UPLOADS_BASE = INDEX_BASE / "uploads"
PERSONAL_INDEX_PATH = str(PERSONAL_DATA_BASE / "personal_faiss_index")
NLU_EXAMPLES_INDEX_PATH = str(INDEX_BASE / "nlu_examples_faiss_index")
# One index folder per theme, e.g. "The Father" -> ".../faiss_index_thefather".
THEME_PATHS = dict(
    (theme_name, str(INDEX_BASE / ("faiss_index_" + theme_name.replace(" ", "").lower())))
    for theme_name in CONFIG["themes"]
)
for p in [UPLOADS_BASE, PERSONAL_DATA_BASE, *THEME_PATHS.values()]:
    Path(p).mkdir(parents=True, exist_ok=True)

# Mutable runtime state shared by the callbacks below.
vectorstores, personal_vectorstore, nlu_vectorstore, test_fixtures = {}, None, None, []
# Best-effort eager load of the personal index; on any failure it stays None and
# the callbacks that need it build it on demand.
try:
    personal_vectorstore = build_or_load_vectorstore([], PERSONAL_INDEX_PATH, is_personal=True)
except Exception:
    personal_vectorstore = None
def bootstrap_nlu_vectorstore(example_file: str, index_path: str) -> FAISS:
    """Build a fresh NLU few-shot index from a JSONL file of labelled examples.

    Each usable line is a JSON object with a non-empty string ``"query"``. The
    query becomes the document text and the whole object its metadata. Blank
    and malformed lines are skipped. Any existing index at ``index_path`` is
    removed first, so the index always mirrors the example file.

    If ``example_file`` does not exist, a warning is printed and
    ``build_or_load_vectorstore([], index_path)`` is returned instead.
    """
    if not os.path.exists(example_file):
        print(f"WARNING: NLU example file not found at {example_file}. NLU will be less accurate.")
        return build_or_load_vectorstore([], index_path)
    docs = []
    with open(example_file, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            try:
                data = json.loads(line)
                query = data["query"]
            except (json.JSONDecodeError, KeyError, TypeError):
                # TypeError: valid JSON that is not an object (list, number, string).
                continue
            if not isinstance(query, str) or not query.strip():
                continue
            docs.append(Document(page_content=query, metadata=data))
    print(f"Found and loaded {len(docs)} NLU training examples.")
    # Force a rebuild from the current examples (handle a stray file at the path too).
    if os.path.isdir(index_path):
        shutil.rmtree(index_path)
    elif os.path.exists(index_path):
        os.remove(index_path)
    return build_or_load_vectorstore(docs, index_path)
def canonical_theme(tk: str) -> str:
    """Return ``tk`` unchanged if it is a configured theme, otherwise ``"All"``."""
    if tk not in CONFIG["themes"]:
        return "All"
    return tk
def theme_upload_dir(theme: str) -> str:
    """Return the upload folder for ``theme``, creating it if it does not exist.

    Unknown themes resolve to "All". The folder is ``UPLOADS_BASE/theme_<slug>``,
    where the slug is the theme name with spaces removed, lower-cased.
    """
    slug = canonical_theme(theme).replace(' ', '').lower()
    folder = UPLOADS_BASE / ("theme_" + slug)
    folder.mkdir(exist_ok=True)
    return str(folder)
def load_manifest(theme: str) -> Dict[str, Any]:
    """Load the theme's ``manifest.json`` (``{"files": {file_name: enabled}}``).

    Returns ``{"files": {}}`` when the manifest is missing, unreadable, or not a
    JSON object. A readable manifest is always returned with a dict-valued
    ``"files"`` key, so callers can use ``man["files"]`` directly.
    """
    p = os.path.join(theme_upload_dir(theme), "manifest.json")
    man = None
    if os.path.exists(p):
        try:
            with open(p, "r", encoding="utf-8") as f:
                man = json.load(f)
        except (OSError, ValueError) as err:
            # Best effort: a corrupt manifest is treated as empty, but report it.
            print(f"WARNING: Could not read manifest {p}: {err}")
    if not isinstance(man, dict):
        return {"files": {}}
    if not isinstance(man.get("files"), dict):
        man["files"] = {}
    return man
def save_manifest(theme: str, man: Dict[str, Any]):
    """Overwrite the theme's ``manifest.json`` with ``man`` as 2-space-indented JSON."""
    manifest_path = os.path.join(theme_upload_dir(theme), "manifest.json")
    with open(manifest_path, "w", encoding="utf-8") as handle:
        json.dump(man, handle, indent=2)
def list_theme_files(theme: str) -> List[tuple[str, bool]]:
    """Reconcile the theme's manifest with its upload folder and return the file list.

    Returns ``(file_name, enabled)`` pairs. Manifest entries whose file still
    exists come first, in manifest order. Untracked files found on disk follow,
    sorted and disabled. The manifest is rewritten to exactly this list, so
    entries for deleted files are dropped.
    """
    man = load_manifest(theme)
    base = theme_upload_dir(theme)
    # The manifest is stored in the same folder; it must never be offered
    # (or indexed) as a knowledge file.
    reserved = {"manifest.json"}
    tracked = man.get("files")
    if not isinstance(tracked, dict):
        tracked = {}
    found = [
        (n, bool(e)) for n, e in tracked.items()
        if n not in reserved and os.path.exists(os.path.join(base, n))
    ]
    existing = {n for n, _ in found}
    for name in sorted(os.listdir(base)):
        if name in existing or name in reserved:
            continue
        if os.path.isfile(os.path.join(base, name)):
            found.append((name, False))
    man["files"] = dict(found)
    save_manifest(theme, man)
    return found
def copy_into_theme(theme: str, src_path: str) -> str:
    """Copy ``src_path`` (with file metadata) into the theme's upload folder.

    The file keeps its base name; an existing file with that name is
    overwritten. Returns the destination path.
    """
    destination = os.path.join(theme_upload_dir(theme), os.path.basename(src_path))
    shutil.copy2(src_path, destination)
    return destination
def seed_files_into_theme(theme: str):
    """Copy the bundled sample files into ``theme``'s upload folder.

    A sample is copied only if it exists at its relative ``sample_data/`` path
    and is not already in the folder. Each newly copied file is recorded in the
    manifest with its enabled flag. If anything was added, the theme's cached
    vector store is dropped so the next ``ensure_index`` call rebuilds it with
    the new files.
    """
    SEED_FILES = [
        ("sample_data/caregiving_tips.txt", True),
        ("sample_data/the_father_segments_enriched_harmonized_plus.jsonl", True),
        ("sample_data/still_alice_enriched_harmonized_plus.jsonl", True),
        ("sample_data/away_from_her_enriched_harmonized_plus.jsonl", True),
        ("sample_data/alive_inside_enriched_harmonized.jsonl", True),
    ]
    man = load_manifest(theme)
    files = man.get("files")
    if not isinstance(files, dict):
        files = man["files"] = {}
    upload_dir = theme_upload_dir(theme)
    changed = False
    for path, enable in SEED_FILES:
        if not os.path.exists(path):
            continue
        fname = os.path.basename(path)
        if os.path.exists(os.path.join(upload_dir, fname)):
            continue
        copy_into_theme(theme, path)
        files[fname] = bool(enable)
        changed = True
    if changed:
        save_manifest(theme, man)
        # ensure_index caches by canonical theme name; drop it so new files are indexed.
        vectorstores.pop(canonical_theme(theme), None)
def ensure_index(theme='All'):
    """Return the general vector store for ``theme``, building and caching it on first use.

    Unknown themes resolve to "All". A new store is bootstrapped from the
    theme's *enabled* upload files with ``index_path=THEME_PATHS[theme]`` and
    cached in ``vectorstores`` under the canonical theme name.
    """
    key = canonical_theme(theme)
    if key not in vectorstores:
        folder = theme_upload_dir(key)
        sources = [os.path.join(folder, name) for name, is_on in list_theme_files(key) if is_on]
        vectorstores[key] = bootstrap_vectorstore(sample_paths=sources, index_path=THEME_PATHS.get(key))
    return vectorstores[key]
# --- Gradio Callbacks ---
def collect_settings(*args):
    """Pack positional settings-widget values into a dict keyed by setting name.

    Values are paired with names by position. Surplus values are ignored, and
    names without a value are omitted (``zip`` semantics).
    """
    names = (
        "role", "patient_name", "caregiver_name", "tone", "language", "tts_lang",
        "temperature", "behaviour_tag", "emotion_tag", "topic_tag", "active_theme",
        "tts_on", "debug_mode",
    )
    return {name: value for name, value in zip(names, args)}
def _split_title_content(entry: str) -> tuple:
    """Split one raw entry into ``(title, content)``.

    A first line starting with ``Title:`` (case-insensitive) supplies the title
    and is removed from the body. Otherwise the title is "Untitled Text Entry"
    and every line is kept. If a line of the body starts with ``Content:``
    (case-insensitive), only the text after that marker is used as content.
    """
    lines = entry.strip().split('\n')
    first = lines[0]
    if first.lower().startswith("title:"):
        title = first.split(':', 1)[1].strip() or "Untitled Text Entry"
        body = "\n".join(lines[1:])
    else:
        title = "Untitled Text Entry"
        body = "\n".join(lines)
    # Anchor on the real "Content:" marker, not on the first colon in the body.
    marker = re.search(r'^\s*content:', body, flags=re.IGNORECASE | re.MULTILINE)
    content = body[marker.end():].strip() if marker else body.strip()
    return title, content


def parse_and_tag_entries(text_content: str, source: str, settings: dict = None) -> "List[Document]":
    """Split ``text_content`` into entries and return one tagged ``Document`` per entry.

    Entries are separated by a line containing only ``---``, ``--``, ``-``,
    ``-*-`` or ``-.-``. Each entry becomes ``"Title: ...\\n\\nContent: ..."``.
    Its metadata holds ``source`` and ``title``, plus any behaviour, emotion,
    topic and context tags returned by ``detect_tags_from_query`` (lower-cased;
    missing or "None" tags are left out).
    """
    docs_to_add = []
    for entry in re.split(r'\n(?:---|--|-|-\*-|-\.-)\n', text_content):
        if not entry.strip():
            continue
        title, content = _split_title_content(entry)
        full_content = f"Title: {title}\n\nContent: {content}"
        detected_tags = detect_tags_from_query(
            content, nlu_vectorstore=nlu_vectorstore, behavior_options=CONFIG["behavior_tags"],
            emotion_options=CONFIG["emotion_tags"], topic_options=CONFIG["topic_tags"],
            context_options=CONFIG["context_tags"], settings=settings) or {}
        metadata = {"source": source, "title": title}
        behaviors = detected_tags.get("detected_behaviors")
        if behaviors:
            metadata["behaviors"] = [b.lower() for b in behaviors]
        emotion = detected_tags.get("detected_emotion")
        if emotion and emotion != "None":
            metadata["emotion"] = emotion.lower()
        topic = detected_tags.get("detected_topic")
        if topic and topic != "None":
            metadata["topic_tags"] = [topic.lower()]
        contexts = detected_tags.get("detected_contexts")
        if contexts:
            metadata["context_tags"] = [c.lower() for c in contexts]
        docs_to_add.append(Document(page_content=full_content, metadata=metadata))
    return docs_to_add
def handle_add_knowledge(title, text_input, file_input, image_input, yt_url, settings):
    """Add one piece of user-provided knowledge to the personal vector store.

    Exactly one source is used, in priority order:
      1. free text;
      2. an uploaded file (``.txt`` is read as UTF-8 text, anything else is
         transcribed with ``transcribe_audio``);
      3. an image (described with ``describe_image``);
      4. a YouTube link (audio downloaded with pytube, then transcribed).

    The content is split and tagged by ``parse_and_tag_entries``, added to the
    personal store, and saved to ``PERSONAL_INDEX_PATH``. Returns a status
    string for the UI.
    """
    global personal_vectorstore
    source, content = "Unknown", ""
    if text_input and text_input.strip():
        source, content = "Text Input", f"Title: {title or 'Untitled'}\n\nContent: {text_input}"
    elif file_input:
        # Gradio may hand over a path string or an object exposing ``.name``.
        file_path = getattr(file_input, "name", file_input)
        source = os.path.basename(file_path)
        if file_path.lower().endswith('.txt'):
            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read()
        else:
            transcribed = transcribe_audio(file_path)
            content = f"Title: {title or 'Audio/Video Note'}\n\nContent: {transcribed}"
    elif image_input:
        source, description = "Image Input", describe_image(image_input)
        content = f"Title: {title or 'Image Note'}\n\nContent: {description}"
    elif yt_url and ("youtube.com" in yt_url or "youtu.be" in yt_url):
        temp_path = None
        try:
            yt = YouTube(yt_url)
            # Reserve a unique path and close the handle before downloading into it.
            with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_audio_file:
                temp_path = temp_audio_file.name
            yt.streams.get_audio_only().download(filename=temp_path)
            transcribed = transcribe_audio(temp_path)
            source, content = f"YouTube: {yt.title}", f"Title: {title or yt.title}\n\nContent: {transcribed}"
        except Exception as e:
            return f"Error processing YouTube link: {e}"
        finally:
            # Remove the downloaded audio even when download or transcription failed.
            if temp_path and os.path.exists(temp_path):
                try:
                    os.remove(temp_path)
                except OSError:
                    pass
    else:
        return "Please provide content to add."
    docs_to_add = parse_and_tag_entries(content, source, settings=settings) if content else []
    if not docs_to_add:
        return "No processable content found to add."
    if personal_vectorstore is None:
        personal_vectorstore = build_or_load_vectorstore(docs_to_add, PERSONAL_INDEX_PATH, is_personal=True)
    else:
        personal_vectorstore.add_documents(docs_to_add)
    personal_vectorstore.save_local(PERSONAL_INDEX_PATH)
    return f"Successfully added {len(docs_to_add)} new memory/memories."
def chat_fn(user_text, audio_file, settings, chat_history):
    """Handle one chat turn from the UI.

    The chatbot shows messages newest-first, so ``chat_history`` is flipped to
    chronological order for processing and flipped back on return.

    The question is the typed text or, if that is empty, the transcription of
    ``audio_file``. When all three manual tag filters are "None", tags are
    auto-detected and announced in a "*(Auto-detected context: ...)*" note.
    Otherwise the manual filters are used. The RAG chain then answers, using
    the earlier conversation (without the "*(...)*" notes) as history.

    Returns ``(new_textbox_value, audio_component_update, history_newest_first)``.
    """
    global personal_vectorstore
    settings = settings or {}
    history = list(reversed(chat_history or []))  # chronological working copy
    question = (user_text or "").strip()
    if audio_file and not question:
        try:
            question = transcribe_audio(audio_file, lang=CONFIG["languages"].get(settings.get("tts_lang", "English"), "en"))
        except Exception as e:
            err_msg = f"Audio Error: {e}" if settings.get("debug_mode") else "Sorry, I couldn't understand the audio."
            history.append({"role": "assistant", "content": err_msg})
            return "", None, history[::-1]
        question = (question or "").strip()
    if not question:
        return "", None, history[::-1]

    # History handed to the chain: earlier turns only (not the new question),
    # without the "*(...)*" status notes this function adds to the display.
    prior_turns = [
        m for m in history
        if not (isinstance(m, dict) and isinstance(m.get("content"), str)
                and m["content"].strip().startswith("*("))
    ]
    history.append({"role": "user", "content": question})

    final_tags = {"scenario_tag": None, "emotion_tag": None, "topic_tag": None, "context_tags": []}
    manual_behavior = settings.get("behaviour_tag", "None")
    manual_emotion = settings.get("emotion_tag", "None")
    manual_topic = settings.get("topic_tag", "None")
    if all(m == "None" for m in [manual_behavior, manual_emotion, manual_topic]):
        detected_tags = detect_tags_from_query(
            question, nlu_vectorstore=nlu_vectorstore, behavior_options=CONFIG["behavior_tags"],
            emotion_options=CONFIG["emotion_tags"], topic_options=CONFIG["topic_tags"],
            context_options=CONFIG["context_tags"], settings=settings) or {}
        behaviors = detected_tags.get("detected_behaviors")
        emotion = detected_tags.get("detected_emotion")
        topic = detected_tags.get("detected_topic")
        final_tags["scenario_tag"] = behaviors[0] if behaviors else None
        # Normalise "None" to None, matching how the manual filters are handled.
        final_tags["emotion_tag"] = emotion if emotion and emotion != "None" else None
        final_tags["topic_tag"] = topic if topic and topic != "None" else None
        final_tags["context_tags"] = detected_tags.get("detected_contexts") or []
        # Label each tag by its key prefix: scenario / emotion / topic / context.
        detected_parts = [f"{k.split('_')[0]}=`{v}`" for k, v in final_tags.items() if v and v != "None"]
        if detected_parts:
            history.append({"role": "assistant", "content": f"*(Auto-detected context: {', '.join(detected_parts)})*"})
    else:
        final_tags["scenario_tag"] = manual_behavior if manual_behavior != "None" else None
        final_tags["emotion_tag"] = manual_emotion if manual_emotion != "None" else None
        final_tags["topic_tag"] = manual_topic if manual_topic != "None" else None

    vs_general = ensure_index(settings.get("active_theme", "All"))
    if personal_vectorstore is None:
        personal_vectorstore = build_or_load_vectorstore([], PERSONAL_INDEX_PATH, is_personal=True)
    rag_settings = {k: settings.get(k) for k in ["role", "temperature", "language", "patient_name", "caregiver_name", "tone"]}
    chain = make_rag_chain(vs_general, personal_vectorstore, **rag_settings)
    response = answer_query(chain, question, chat_history=prior_turns, **final_tags) or {}
    answer = response.get("answer", "[No answer found]")
    history.append({"role": "assistant", "content": answer})
    if response.get("sources"):
        sources_text = ", ".join(str(s) for s in response["sources"])
        history.append({"role": "assistant", "content": f"*(Sources used: {sources_text})*"})
    audio_out = None
    if settings.get("tts_on") and answer:
        audio_out = synthesize_tts(answer, lang=CONFIG["languages"].get(settings.get("tts_lang"), "en"))
    return "", gr.update(value=audio_out, visible=bool(audio_out)), history[::-1]
def save_chat_to_memory(chat_history):
    """Store the displayed conversation as one document in the personal store.

    ``chat_history`` is shown newest-first, so it is reversed in place back to
    chronological order. "*(...)*" status notes are skipped, and each remaining
    message becomes a "Role: text" line. The personal index is saved to
    ``PERSONAL_INDEX_PATH``. Returns a status string.
    """
    global personal_vectorstore
    if chat_history:
        chat_history.reverse()
    if not chat_history:
        return "Nothing to save."
    transcript = []
    for message in chat_history:
        if message['content'].strip().startswith("*("):
            continue
        transcript.append(f"{message['role'].title()}: {message['content'].strip()}")
    if not transcript:
        return "No conversation to save."
    stamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    title = f"Conversation from {stamp}"
    body = "\n".join(transcript)
    doc = Document(
        page_content=f"Title: {title}\n\nContent:\n{body}",
        metadata={"source": "Saved Chat", "title": title},
    )
    if personal_vectorstore is not None:
        personal_vectorstore.add_documents([doc])
    else:
        personal_vectorstore = build_or_load_vectorstore([doc], PERSONAL_INDEX_PATH, is_personal=True)
    personal_vectorstore.save_local(PERSONAL_INDEX_PATH)
    return f"Conversation from {stamp} saved."
def list_personal_memories():
    """Build UI updates listing every personal memory.

    Returns a DataFrame update with ``[title, source, content]`` rows and a
    Dropdown update whose choices are the memories' full texts. When the store
    is missing or empty, shows one placeholder row and clears the dropdown.
    """
    if personal_vectorstore is None:
        docs = []
    else:
        docs = list(getattr(personal_vectorstore.docstore, '_dict', {}).values())
    if not docs:
        return gr.update(value=[["No memories", "", ""]]), gr.update(choices=[], value=None)
    rows = [
        [doc.metadata.get('title', '...'), doc.metadata.get('source', '...'), doc.page_content]
        for doc in docs
    ]
    return gr.update(value=rows), gr.update(choices=[doc.page_content for doc in docs])
def delete_personal_memory(memory_to_delete):
    """Delete every personal memory whose full text equals ``memory_to_delete``.

    The remaining documents are rebuilt into a new store via
    ``FAISS.from_documents(..., _default_embeddings())`` and saved to
    ``PERSONAL_INDEX_PATH``. Deleting the last memory instead removes the index
    directory and builds an empty store. Returns a status string.
    """
    global personal_vectorstore
    if personal_vectorstore is None or not memory_to_delete:
        return "No memory selected."
    everything = list(personal_vectorstore.docstore._dict.values())
    survivors = [doc for doc in everything if doc.page_content != memory_to_delete]
    if len(survivors) == len(everything):
        return "Error: Could not find memory."
    if survivors:
        rebuilt = FAISS.from_documents(survivors, _default_embeddings())
        rebuilt.save_local(PERSONAL_INDEX_PATH)
        personal_vectorstore = rebuilt
    else:
        if os.path.isdir(PERSONAL_INDEX_PATH):
            shutil.rmtree(PERSONAL_INDEX_PATH)
        personal_vectorstore = build_or_load_vectorstore([], PERSONAL_INDEX_PATH, is_personal=True)
    return "Successfully deleted memory."
def upload_knowledge(files, theme):
    """Copy uploaded files into ``theme``'s folder and drop its cached index.

    ``files`` is a multi-file upload value: a list of path strings or objects
    exposing ``.name``. Returns a status string. When no files were provided,
    nothing is copied or invalidated.
    """
    if not files:
        return "No files selected."
    for f in files:
        copy_into_theme(theme, getattr(f, "name", f))
    # ensure_index caches under the canonical theme name; drop both spellings.
    vectorstores.pop(canonical_theme(theme), None)
    vectorstores.pop(theme, None)
    return f"Uploaded {len(files)} file(s)."
def save_file_selection(theme, enabled):
    """Persist which of the theme's tracked files are enabled.

    ``enabled`` holds the checked file names from the UI; ``None`` means none
    are checked. Every manifest entry is set to whether it was checked. Files
    not yet in the manifest are ignored. The theme's cached index is dropped so
    the next query rebuilds it.
    """
    selected = set(enabled or [])
    man = load_manifest(theme)
    files = man.get("files")
    if not isinstance(files, dict):
        files = man["files"] = {}
    for fname in files:
        files[fname] = fname in selected
    save_manifest(theme, man)
    # ensure_index caches under the canonical theme name; drop both spellings.
    vectorstores.pop(canonical_theme(theme), None)
    vectorstores.pop(theme, None)
    return f"Settings saved for theme '{theme}'."
def refresh_file_list_ui(theme):
    """Return a CheckboxGroup update (all files, enabled ones ticked) and a status message."""
    entries = list_theme_files(theme)
    names = [name for name, _ in entries]
    checked = [name for name, is_on in entries if is_on]
    return gr.update(choices=names, value=checked), f"Found {len(entries)} file(s)."
def auto_setup_on_load(theme):
    """Initialise per-session state when the page loads.

    Seeds the sample files if the theme's upload folder is empty, then builds
    the initial settings dict and refreshes the theme's file list. Returns
    ``(settings, files_checkbox_update, status_message)``.
    """
    if not os.listdir(theme_upload_dir(theme)):
        seed_files_into_theme(theme)
    # Mirror the Settings-tab widget defaults (the Role radio starts at "patient")
    # and the theme actually selected. Otherwise the chat would run with settings
    # the user never sees until some widget changes.
    settings = collect_settings("patient", "", "", "warm", "English", "English", 0.7, "None", "None", "None", canonical_theme(theme), True, False)
    files_ui, status = refresh_file_list_ui(theme)
    return settings, files_ui, status
def run_nlu_test(test_title: str):
    """Run one NLU fixture and compare the detected tags with its expectations.

    Only the fixture's first turn is sent to ``detect_tags_from_query``. Each
    category with a non-empty expected list is scored: it passes when at least
    one expected tag was detected (empty and "None" detections are ignored).
    Returns ``(markdown_summary, rows)``, each row being
    ``[category, expected, actual, verdict]``. For a bad selection it returns
    ``(message, None)``.
    """
    if not (test_title and test_fixtures):
        return "Please select a test case.", None
    fixture = next((f for f in test_fixtures if f["title"] == test_title), None)
    if not fixture:
        return f"Error: Could not find test case '{test_title}'.", None
    raw = detect_tags_from_query(
        fixture["turns"][0]["text"], nlu_vectorstore, CONFIG["behavior_tags"], CONFIG["emotion_tags"], CONFIG["topic_tags"], CONFIG["context_tags"]
    )
    detected = {
        "emotion": [raw.get("detected_emotion")],
        "behaviors": raw.get("detected_behaviors", []),
        "topic_tags": [raw.get("detected_topic")],
        "context_tags": raw.get("detected_contexts", []),
    }
    expected = fixture["expected"]
    rows = []
    passed = 0
    for category in sorted(set(expected) | set(detected)):
        wanted = set(expected.get(category, []))
        if not wanted:
            continue
        got = {tag for tag in detected.get(category, []) if tag and tag != "None"}
        ok = bool(wanted & got)
        passed += ok
        verdict = "✅ Pass" if ok else "❌ Fail"
        rows.append([category, ", ".join(sorted(wanted)), ", ".join(sorted(got)) or "None", verdict])
    return f"## Test Result: {passed} / {len(rows)} Passed", rows
def load_test_fixtures():
    """(Re)load NLU test fixtures from ``conversation_test_fixtures.jsonl`` beside this file.

    Replaces the module-level ``test_fixtures`` list. Each non-blank line is
    parsed as JSON. Malformed lines are reported and skipped, and only objects
    carrying a ``"title"`` are kept, since every consumer looks fixtures up by
    title. Returns a Dropdown update listing the titles; the list is empty when
    the file is missing.
    """
    global test_fixtures
    test_fixtures = []
    fixtures_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "conversation_test_fixtures.jsonl")
    if not os.path.exists(fixtures_path):
        return gr.update(choices=[])
    with open(fixtures_path, "r", encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            if not line.strip():
                continue
            try:
                fixture = json.loads(line)
            except json.JSONDecodeError as err:
                print(f"WARNING: Skipping malformed test fixture on line {lineno}: {err}")
                continue
            if isinstance(fixture, dict) and "title" in fixture:
                test_fixtures.append(fixture)
            else:
                print(f"WARNING: Skipping test fixture without a title on line {lineno}.")
    return gr.update(choices=[f["title"] for f in test_fixtures])
def run_all_nlu_tests():
    """Run every NLU fixture and summarise the results.

    Fixtures are loaded on demand. Within a fixture, each category with a
    non-empty expected list counts as a hit when any expected tag was detected.
    A fixture is "✅ Pass" when all categories hit, "⚠️ Partial" when more than
    65% hit, and "❌ Fail" otherwise (or when it has nothing to check).
    Returns ``(markdown_summary, rows)``, each row being
    ``[title, verdict, "hits / checked"]``.
    """
    if not test_fixtures:
        load_test_fixtures()
    if not test_fixtures:
        return "## No test fixtures found.", []
    rows = []
    fully_passed = 0
    for fixture in test_fixtures:
        raw = detect_tags_from_query(fixture["turns"][0]["text"], nlu_vectorstore, CONFIG["behavior_tags"], CONFIG["emotion_tags"], CONFIG["topic_tags"], CONFIG["context_tags"])
        detected = {
            "emotion": [raw.get("detected_emotion")],
            "behaviors": raw.get("detected_behaviors", []),
            "topic_tags": [raw.get("detected_topic")],
            "context_tags": raw.get("detected_contexts", []),
        }
        expected = fixture["expected"]
        checked = hits = 0
        for category in sorted(expected):
            wanted = set(expected.get(category, []))
            if not wanted:
                continue
            checked += 1
            got = {tag for tag in detected.get(category, []) if tag and tag != "None"}
            if wanted & got:
                hits += 1
        if checked and hits == checked:
            verdict = "✅ Pass"
            fully_passed += 1
        elif checked and hits / checked > 0.65:
            verdict = "⚠️ Partial"
        else:
            verdict = "❌ Fail"
        rows.append([fixture["title"], verdict, f"{hits} / {checked}"])
    total = len(test_fixtures)
    rate = (fully_passed / total) * 100 if test_fixtures else 0
    return f"## Batch Summary: {fully_passed} / {total} Tests Passed ({rate:.1f}%)", rows
def test_save_file():
    """Persistence check, step 1: write a timestamped marker file into PERSONAL_DATA_BASE.

    Returns a success or error message for the UI; it never raises.
    """
    try:
        marker = PERSONAL_DATA_BASE / "persistence_test.txt"
        stamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        marker.write_text(f"File saved at: {stamp}")
    except Exception as e:
        return f"❌ Error! Failed to write file: {e}"
    return f"✅ Success! Wrote test file to: {marker}"
def check_test_file():
    """Persistence check: report whether the marker file written by ``test_save_file`` exists."""
    marker = PERSONAL_DATA_BASE / "persistence_test.txt"
    if not marker.exists():
        return f"❌ Failure. Test file not found at: {marker}"
    return f"✅ Success! Found test file. Contents: '{marker.read_text()}'"
# --- UI Definition ---
CSS = """
.gradio-container { font-size: 14px; }
#chatbot { min-height: 400px; }
#audio_in audio, #audio_out audio { max-height: 40px; }
#audio_in .waveform, #audio_out .waveform { display: none !important; }
#audio_in, #audio_out { min-height: 0px !important; }
"""

with gr.Blocks(theme=gr.themes.Soft(), css=CSS) as demo:
    # Per-session settings dict (see collect_settings), filled on load and on every settings change.
    settings_state = gr.State({})

    with gr.Tab("Chat"):
        with gr.Row():
            user_text = gr.Textbox(show_label=False, placeholder="Type your message here...", scale=7)
            submit_btn = gr.Button("Send", variant="primary", scale=1)
        with gr.Row():
            audio_in = gr.Audio(sources=["microphone"], type="filepath", label="Voice Input", elem_id="audio_in")
            audio_out = gr.Audio(label="Response Audio", autoplay=True, visible=True, elem_id="audio_out")
        chatbot = gr.Chatbot(elem_id="chatbot", label="Conversation", type="messages")
        chat_status = gr.Markdown()
        with gr.Row():
            clear_btn = gr.Button("Clear")
            save_btn = gr.Button("Save to Memory")

    with gr.Tab("Personalize"):
        with gr.Accordion("Add to Personal Knowledge Base", open=True):
            personal_title = gr.Textbox(label="Title")
            personal_text = gr.Textbox(lines=5, label="Text Content")
            with gr.Row():
                personal_file = gr.File(label="Upload Audio/Video/Text File")
                personal_image = gr.Image(type="filepath", label="Upload Image")
            personal_yt_url = gr.Textbox(label="Or, provide a YouTube URL")
            personal_add_btn = gr.Button("Add Knowledge", variant="primary")
            personal_status = gr.Markdown()
        gr.Markdown("### **Manage Personal Knowledge**")
        with gr.Accordion("View/Hide Details", open=False):
            personal_memory_display = gr.DataFrame(headers=["Title", "Source", "Content"], label="Saved Memories", row_count=(5, "dynamic"))
            personal_refresh_btn = gr.Button("Refresh Memories")
            personal_delete_selector = gr.Dropdown(label="Select memory to delete", scale=3, interactive=True)
            personal_delete_btn = gr.Button("Delete Selected", variant="stop", scale=1)
            personal_delete_status = gr.Markdown()

    with gr.Tab("Testing"):
        gr.Markdown("## NLU Context Detection Tests")
        batch_summary_md = gr.Markdown("### Batch Test Summary: Not yet run.")
        with gr.Row():
            test_case_dropdown = gr.Dropdown(label="Select Single Test Case", scale=2)
            run_test_btn = gr.Button("Run Single Test", scale=1)
            run_all_btn = gr.Button("Run All Tests", variant="primary", scale=1)
        test_status_md = gr.Markdown("### Test Results")
        test_results_df = gr.DataFrame(label="Test Comparison", headers=["Test Case Title", "Result", "Categories Passed"], interactive=False)

    with gr.Tab("Settings"):
        with gr.Group():
            gr.Markdown("## Conversation & Persona Settings")
            with gr.Row():
                role = gr.Radio(CONFIG["roles"], value="patient", label="Your Role")
                patient_name = gr.Textbox(label="Patient's Name")
                caregiver_name = gr.Textbox(label="Caregiver's Name")
            with gr.Row():
                temperature = gr.Slider(0.0, 1.2, value=0.7, step=0.1, label="Creativity")
                tone = gr.Dropdown(CONFIG["tones"], value="warm", label="Response Tone")
            with gr.Row():
                behaviour_tag = gr.Dropdown(CONFIG["behavior_tags"], value="None", label="Behaviour Filter (Manual)")
                emotion_tag = gr.Dropdown(CONFIG["emotion_tags"], value="None", label="Emotion Filter (Manual)")
                topic_tag = gr.Dropdown(CONFIG["topic_tags"], value="None", label="Topic Tag Filter (Manual)")
        with gr.Accordion("Language, Voice & Debugging", open=False):
            language = gr.Dropdown(list(CONFIG["languages"].keys()), value="English", label="Response Language")
            tts_lang = gr.Dropdown(list(CONFIG["languages"].keys()), value="English", label="Voice Language")
            tts_on = gr.Checkbox(True, label="Enable Voice Response")
            debug_mode = gr.Checkbox(False, label="Show Debug Info")
        gr.Markdown("--- \n ## General Knowledge Base Management")
        with gr.Row():
            with gr.Column(scale=1):
                files_in = gr.File(file_count="multiple", file_types=[".jsonl", ".txt"], label="Upload Knowledge Files")
                upload_btn = gr.Button("Upload to Theme")
                seed_btn = gr.Button("Import Sample Data")
                mgmt_status = gr.Markdown()
            with gr.Column(scale=2):
                active_theme = gr.Radio(CONFIG["themes"], value="All", label="Active Knowledge Theme")
                files_box = gr.CheckboxGroup(choices=[], label="Enable Files for Selected Theme")
                with gr.Row():
                    save_files_btn = gr.Button("Save Selection", variant="primary")
                    refresh_btn = gr.Button("Refresh List")
        with gr.Accordion("Persistence Test", open=False):
            test_save_btn = gr.Button("1. Run Persistence Test (Save File)")
            check_save_btn = gr.Button("3. Check for Test File")
            test_status = gr.Markdown()

    # --- Event Wiring ---
    # Order must match collect_settings' key order.
    all_settings = [role, patient_name, caregiver_name, tone, language, tts_lang, temperature, behaviour_tag, emotion_tag, topic_tag, active_theme, tts_on, debug_mode]
    for c in all_settings:
        c.change(fn=collect_settings, inputs=all_settings, outputs=settings_state)

    chat_inputs = [user_text, audio_in, settings_state, chatbot]
    chat_outputs = [user_text, audio_out, chatbot]
    submit_btn.click(fn=chat_fn, inputs=chat_inputs, outputs=chat_outputs)
    # Pressing Enter in the textbox sends the message as well.
    user_text.submit(fn=chat_fn, inputs=chat_inputs, outputs=chat_outputs)
    save_btn.click(fn=save_chat_to_memory, inputs=[chatbot], outputs=[chat_status])
    # Each output component is listed once: text box, reply audio, history, mic input, status.
    clear_btn.click(lambda: ("", None, [], None, ""), outputs=[user_text, audio_out, chatbot, audio_in, chat_status])

    personal_add_btn.click(
        fn=handle_add_knowledge,
        inputs=[personal_title, personal_text, personal_file, personal_image, personal_yt_url, settings_state],
        outputs=[personal_status],
    ).then(
        lambda: (None, None, None, None, None),
        outputs=[personal_title, personal_text, personal_file, personal_image, personal_yt_url],
    )
    personal_refresh_btn.click(fn=list_personal_memories, inputs=None, outputs=[personal_memory_display, personal_delete_selector])
    personal_delete_btn.click(
        fn=delete_personal_memory, inputs=[personal_delete_selector], outputs=[personal_delete_status]
    ).then(fn=list_personal_memories, inputs=None, outputs=[personal_memory_display, personal_delete_selector])

    upload_btn.click(upload_knowledge, inputs=[files_in, active_theme], outputs=[mgmt_status]).then(
        refresh_file_list_ui, inputs=[active_theme], outputs=[files_box, mgmt_status])
    save_files_btn.click(save_file_selection, inputs=[active_theme, files_box], outputs=[mgmt_status])
    seed_btn.click(seed_files_into_theme, inputs=[active_theme]).then(
        refresh_file_list_ui, inputs=[active_theme], outputs=[files_box, mgmt_status])
    refresh_btn.click(refresh_file_list_ui, inputs=[active_theme], outputs=[files_box, mgmt_status])
    active_theme.change(refresh_file_list_ui, inputs=[active_theme], outputs=[files_box, mgmt_status])

    demo.load(auto_setup_on_load, inputs=[active_theme], outputs=[settings_state, files_box, mgmt_status])
    demo.load(load_test_fixtures, outputs=[test_case_dropdown])
    run_test_btn.click(fn=run_nlu_test, inputs=[test_case_dropdown], outputs=[test_status_md, test_results_df])
    run_all_btn.click(fn=run_all_nlu_tests, outputs=[batch_summary_md, test_results_df])
    test_save_btn.click(fn=test_save_file, inputs=None, outputs=[test_status])
    check_save_btn.click(fn=check_test_file, inputs=None, outputs=[test_status])
# --- Startup Logic ---
def pre_load_indexes():
    """Eagerly build or load every index at startup so the first request is fast.

    Loads, in order:
      - the NLU few-shot index, into the ``nlu_vectorstore`` global;
      - each theme's general index, cached via ``ensure_index``;
      - the personal index, into the ``personal_vectorstore`` global.

    Every step is guarded: a failure is printed and the next step still runs,
    so one broken index cannot stop the app from starting.
    """
    global personal_vectorstore, nlu_vectorstore
    print("Pre-loading all indexes at startup...")
    print(" - Loading NLU examples index...")
    try:
        nlu_vectorstore = bootstrap_nlu_vectorstore("nlu_training_examples.jsonl", NLU_EXAMPLES_INDEX_PATH)
        print(" ...NLU index loaded.")
    except Exception as e:
        # nlu_vectorstore keeps its previous value (None at startup).
        print(f" ...Error loading NLU index: {e}")
    for theme in CONFIG["themes"]:
        print(f" - Loading general index for theme: '{theme}'")
        try:
            ensure_index(theme)
            print(f" ...'{theme}' theme loaded.")
        except Exception as e:
            print(f" ...Error loading theme '{theme}': {e}")
    print(" - Loading personal knowledge index...")
    try:
        personal_vectorstore = build_or_load_vectorstore([], PERSONAL_INDEX_PATH, is_personal=True)
        print(" ...Personal knowledge loaded.")
    except Exception as e:
        print(f" ...Error loading personal knowledge: {e}")
    print("All indexes loaded. Application is ready.")
if __name__ == "__main__":
    # Give the default theme its sample files before any index is built.
    seed_files_into_theme('All')
    pre_load_indexes()
    app = demo.queue()
    app.launch(debug=True)