import os, requests, tempfile, logging

import numpy as np
import gradio as gr
from transformers import pipeline
import edge_tts
from collections import Counter
# ─── Configuration ──────────────────────────────────────────────────────────────
ENDPOINT_URL = "https://xzup8268xrmmxcma.us-east-1.aws.endpoints.huggingface.cloud/invocations"
HF_TOKEN = os.getenv("HF_TOKEN")

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
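# HF_TOKEN must be present in the environment before launch, e.g.
#     export HF_TOKEN=hf_xxx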
# ─── Helpers ────────────────────────────────────────────────────────────────────

# 1) Speech -> Text
asr = pipeline("automatic-speech-recognition", model="facebook/wav2vec2-base-960h")
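# Note: wav2vec2-base-960h was trained on 16 kHz mono audio. For raw-array
# input at another rate, recent transformers versions resample via torchaudio;
# file inputs are decoded and resampled by ffmpeg.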
def speech_to_text(audio):
    if audio is None:
        return ""
    # Gradio's numpy mode supplies a (sample_rate, ndarray) tuple; the ASR
    # pipeline wants a dict with "sampling_rate" and float32 "raw" audio.
    if isinstance(audio, tuple):
        sr, arr = audio
        if arr.ndim > 1:
            arr = arr.mean(axis=1)  # down-mix stereo to mono
        if np.issubdtype(arr.dtype, np.integer):
            arr = arr.astype(np.float32) / 32768.0  # int16 PCM -> [-1, 1]
        else:
            arr = arr.astype(np.float32)
        return asr({"sampling_rate": sr, "raw": arr})["text"]
    # Filepath input: the pipeline decodes (and resamples) the file itself.
    return asr(audio)["text"]
# 2) Prompt formatting
def format_prompt(message, history):
    fixed_prompt = """
You are a smart mood analyzer tasked with determining the user's mood for a music recommendation system. Your goal is to classify the user's mood into one of four categories: Happy, Sad, Instrumental, or Party.
Instructions:
1. Engage in a conversation with the user to understand their mood.
2. Ask relevant questions to guide the conversation towards mood classification.
3. If the user's mood is clear, respond with a single word: "Happy", "Sad", "Instrumental", or "Party".
4. If the mood is unclear, continue the conversation with a follow-up question.
5. Limit the conversation to a maximum of 5 exchanges.
6. Do not classify the mood prematurely if it's not evident from the user's responses.
7. Focus on the user's emotional state rather than specific activities or preferences.
8. If unable to classify after 5 exchanges, respond with "Unclear" to indicate the need for more information.
Remember: Your primary goal is mood classification. Stay on topic and guide the conversation towards understanding the user's emotional state.
"""
    prompt = f"{fixed_prompt}\n"
    for i, (u, b) in enumerate(history):
        prompt += f"User: {u}\nAssistant: {b}\n"
        if i == 3:
            prompt += "Note: This is the last exchange. Classify the mood if possible or respond with 'Unclear'.\n"
    prompt += f"User: {message}\nAssistant:"
    return prompt
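# The assembled prompt looks roughly like:
#     <instructions above>
#     User: Hey there!
#     Assistant: How are you feeling today?
#     User: <current message>
#     Assistant: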
# 3) Call HF Invocation Endpoint
def query_model(prompt, max_new_tokens=64, temperature=0.1):
    if not HF_TOKEN:
        raise RuntimeError("HF_TOKEN is not set; the endpoint requires authentication.")
    headers = {
        "Authorization": f"Bearer {HF_TOKEN}",
        "Content-Type": "application/json",
    }
    payload = {
        "inputs": prompt,
        "parameters": {"max_new_tokens": max_new_tokens, "temperature": temperature},
    }
    resp = requests.post(ENDPOINT_URL, headers=headers, json=payload, timeout=30)
    resp.raise_for_status()
    return resp.json()[0]["generated_text"]
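# Assumes the endpoint returns the standard text-generation shape
# [{"generated_text": ...}]; adjust the indexing above if your deployment
# wraps the output differently.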
# 4) Aggregate mood from history
def aggregate_mood_from_history(history):
    """Return the most frequently mentioned mood word across bot replies, or None."""
    mood_words = {"happy", "sad", "instrumental", "party"}
    counts = Counter()
    for _, bot_response in history:
        if not bot_response:
            continue
        for tok in bot_response.split():
            w = tok.strip('.,?!;"\'').lower()
            if w in mood_words:
                counts[w] += 1
    if not counts:
        return None
    return counts.most_common(1)[0][0]
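# e.g. [("", "Hi!"), ("great day", "Sounds like Happy!")] -> "happy"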
# 5) Text -> Speech
async def text_to_speech(text):
    communicate = edge_tts.Communicate(text)
    # edge-tts streams MP3 by default, so use an .mp3 suffix for the temp file;
    # close the file before saving so the write works on all platforms
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp:
        tmp_path = tmp.name
    await communicate.save(tmp_path)
    return tmp_path
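# Usage sketch (from an async context, e.g. an async Gradio callback):
#     path = await text_to_speech("Hello there!")
# A specific voice can be selected with edge_tts.Communicate(text, voice=...).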
# ─── Gradio Callbacks ───────────────────────────────────────────────────────────
def user_turn(user_input, history):
    if not user_input:
        # nothing to do (e.g. the mic component was cleared)
        return history, history, ""
    # format against the prior turns; the current message is appended
    # inside format_prompt itself
    formatted = format_prompt(user_input, history)
    raw = query_model(formatted)
    # temporarily record the raw model output
    history = history + [(user_input, raw)]
    # aggregate mood across the whole conversation
    mood = aggregate_mood_from_history(history)
    reply = f"Playing {mood.capitalize()} playlist for you!" if mood else raw
    history[-1] = (user_input, reply)
    return history, history, ""

async def bot_audio(history):
    last = history[-1][1]
    return await text_to_speech(last)

def speech_callback(audio):
    return speech_to_text(audio)
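# Each turn returns (state, chat, textbox): the updated history twice (once
# for the State, once for the Chatbot display) plus "" to clear the input box.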
# ─── Build the Interface ────────────────────────────────────────────────────────
with gr.Blocks() as demo:
    gr.Markdown("## 🎵 Mood-Based Music Buddy")
    chat = gr.Chatbot()
    txt = gr.Textbox(placeholder="Type your mood...", label="Text")
    send = gr.Button("Send")
    mic = gr.Audio(sources=["microphone", "upload"], label="🎤 Speak or upload audio")
    out_audio = gr.Audio(label="Response (Audio)", autoplay=True)
    state = gr.State([])

    def init():
        greeting = "Hi! I'm your music buddy. Tell me how you're feeling today."
        return [("", greeting)], [("", greeting)], None

    demo.load(init, outputs=[state, chat, out_audio])
    txt.submit(user_turn, [txt, state], [state, chat, txt])\
        .then(bot_audio, [state], [out_audio])
    send.click(user_turn, [txt, state], [state, chat, txt])\
        .then(bot_audio, [state], [out_audio])
    # .change also fires when the mic component is cleared; user_turn's
    # empty-input guard turns that into a no-op
    mic.change(speech_callback, [mic], [txt])\
        .then(user_turn, [txt, state], [state, chat, txt])\
        .then(bot_audio, [state], [out_audio])

if __name__ == "__main__":
    demo.launch(debug=True)