import os

# Fix OpenMP environment variable issue
os.environ['OMP_NUM_THREADS'] = '1'

import gradio as gr
from nemo.collections.speechlm2.models import SALM
import torch
import tempfile

# Load the model using the official NVIDIA NeMo approach
model_id = "nvidia/canary-qwen-2.5b"
print("Loading NVIDIA Canary-Qwen-2.5B model using NeMo...")
model = SALM.from_pretrained(model_id)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
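
# SALM (Speech-Augmented Language Model) exposes two modes, per the model card:
# - ASR mode: pass audio alongside a prompt containing model.audio_locator_tag
# - LLM mode: text-only generation with the speech adapter disabled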

def generate_text(prompt, max_tokens=200, temperature=0.7, top_p=0.9):
    """Generate text using the NVIDIA NeMo model (LLM mode)"""
    try:
        # Use LLM mode (text-only) as per the official documentation:
        # disabling the speech adapter makes the model behave as a plain LLM
        with model.llm.disable_adapter():
            answer_ids = model.generate(
                prompts=[[{"role": "user", "content": prompt}]],
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=True
            )
        # Convert token IDs to text with the model's tokenizer
        # (on CPU, as in the official example)
        response = model.tokenizer.ids_to_text(answer_ids[0].cpu())
        return response
    except Exception as e:
        return f"Error generating text: {str(e)}"

def transcribe_audio(audio_file, user_prompt="Transcribe the following:"):
    """Transcribe audio using ASR mode"""
    try:
        if audio_file is None:
            return "No audio file provided"
        # Use ASR mode (speech-to-text) as per the official documentation:
        # audio_locator_tag marks where the audio is spliced into the prompt
        answer_ids = model.generate(
            prompts=[
                [{"role": "user", "content": f"{user_prompt} {model.audio_locator_tag}", "audio": [audio_file]}]
            ],
            max_new_tokens=128,
        )
        # Convert token IDs to text (on CPU, as in the official example)
        transcript = model.tokenizer.ids_to_text(answer_ids[0].cpu())
        return transcript
    except Exception as e:
        return f"Error transcribing audio: {str(e)}"

def chat_interface(message, history, max_tokens, temperature, top_p):
    """Chat interface for Gradio"""
    # Build conversation context
    conversation = ""
    for user_msg, bot_msg in history:
        conversation += f"User: {user_msg}\nAssistant: {bot_msg}\n"
    conversation += f"User: {message}\nAssistant: "
    # Generate response
    response = generate_text(conversation, max_tokens, temperature, top_p)
    # Update history and clear the input box
    history.append((message, response))
    return "", history

# Create Gradio interface
with gr.Blocks(title="NVIDIA Canary-Qwen-2.5B Chat") as demo:
    gr.HTML("""
    <div style="text-align: center;">
        <h1>🤖 NVIDIA Canary-Qwen-2.5B</h1>
        <p>Official NeMo implementation - Speech-to-Text & Text Generation</p>
        <p><strong>Capabilities:</strong> Audio Transcription + Text Chat</p>
    </div>
    """)

    with gr.Tab("🎤 Audio Transcription (ASR)"):
        with gr.Row():
            with gr.Column():
                audio_input = gr.Audio(
                    label="Upload Audio File (.wav or .flac)",
                    type="filepath",
                    format="wav"
                )
                asr_prompt = gr.Textbox(
                    label="Custom Prompt (optional)",
                    value="Transcribe the following:",
                    placeholder="Enter custom transcription prompt..."
                )
                transcribe_btn = gr.Button("🎤 Transcribe Audio", variant="primary")
                transcript_output = gr.Textbox(
                    label="Transcription Result",
                    lines=8,
                    max_lines=15
                )
        gr.Examples(
            examples=[
                ["Transcribe the following:"],
                ["Please transcribe this audio in detail:"],
                ["Convert this speech to text:"]
            ],
            inputs=[asr_prompt]
        )
with gr.Tab("π¬ Text Chat (LLM)"): | |
with gr.Row(): | |
with gr.Column(scale=3): | |
chatbot = gr.Chatbot(height=400) | |
msg = gr.Textbox(label="Your message", placeholder="Type here...") | |
with gr.Row(): | |
submit_btn = gr.Button("Send", variant="primary") | |
clear_btn = gr.Button("Clear Chat") | |
with gr.Column(scale=1): | |
gr.Markdown("### βοΈ Settings") | |
max_tokens = gr.Slider( | |
minimum=10, maximum=500, value=200, step=10, | |
label="Max Tokens" | |
) | |
temperature = gr.Slider( | |
minimum=0.1, maximum=2.0, value=0.7, step=0.1, | |
label="Temperature" | |
) | |
top_p = gr.Slider( | |
minimum=0.1, maximum=1.0, value=0.9, step=0.05, | |
label="Top-p" | |
) | |
with gr.Tab("π Single Generation"): | |
with gr.Column(): | |
prompt_input = gr.Textbox( | |
label="Prompt", | |
placeholder="Enter your prompt...", | |
lines=5 | |
) | |
generate_btn = gr.Button("Generate", variant="primary") | |
output_text = gr.Textbox( | |
label="Generated Text", | |
lines=10, | |
max_lines=20 | |
) | |
with gr.Row(): | |
single_max_tokens = gr.Slider(10, 500, 200, label="Max Tokens") | |
single_temperature = gr.Slider(0.1, 2.0, 0.7, label="Temperature") | |
single_top_p = gr.Slider(0.1, 1.0, 0.9, label="Top-p") | |
with gr.Tab("βΉοΈ Model Info"): | |
gr.Markdown(""" | |
## NVIDIA Canary-Qwen-2.5B Model Information | |
### Capabilities: | |
- π€ **Audio Transcription (ASR)**: Convert speech to text | |
- π¬ **Text Generation (LLM)**: Chat and text completion | |
- π― **Multimodal**: Combines audio and text processing | |
### Model Details: | |
- **Size**: 2.5 billion parameters | |
- **Framework**: NVIDIA NeMo | |
- **Audio Input**: 16kHz mono-channel .wav or .flac files | |
- **Languages**: Multiple languages supported | |
### Usage Tips: | |
1. **For Audio**: Upload .wav or .flac files (16kHz recommended) | |
2. **For Text**: Use natural language prompts | |
3. **Custom Prompts**: You can modify transcription prompts | |
4. **Parameters**: Adjust temperature and tokens for different outputs | |
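
        For programmatic use outside this demo, the model card sketches roughly
        the following ASR call (the file path here is illustrative):

        ```python
        from nemo.collections.speechlm2.models import SALM

        model = SALM.from_pretrained("nvidia/canary-qwen-2.5b")
        answer_ids = model.generate(
            prompts=[[{
                "role": "user",
                "content": f"Transcribe the following: {model.audio_locator_tag}",
                "audio": ["speech.wav"],  # illustrative 16 kHz mono file
            }]],
            max_new_tokens=128,
        )
        print(model.tokenizer.ids_to_text(answer_ids[0].cpu()))
        ```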

        ### Official Documentation:
        - [Model Card](https://huggingface.co/nvidia/canary-qwen-2.5b)
        - [NVIDIA NeMo](https://github.com/NVIDIA/NeMo)
        """)

    # Event handlers
    transcribe_btn.click(
        transcribe_audio,
        inputs=[audio_input, asr_prompt],
        outputs=[transcript_output]
    )

    submit_btn.click(
        chat_interface,
        inputs=[msg, chatbot, max_tokens, temperature, top_p],
        outputs=[msg, chatbot]
    )
    msg.submit(
        chat_interface,
        inputs=[msg, chatbot, max_tokens, temperature, top_p],
        outputs=[msg, chatbot]
    )
    clear_btn.click(lambda: ([], ""), outputs=[chatbot, msg])

    generate_btn.click(
        generate_text,
        inputs=[prompt_input, single_max_tokens, single_temperature, single_top_p],
        outputs=[output_text]
    )

    # Example prompts
    gr.Examples(
        examples=[
            ["Explain quantum computing in simple terms"],
            ["Write a short story about AI"],
            ["What are the benefits of renewable energy?"],
            ["How do neural networks work?"],
            ["Summarize the key points about machine learning"]
        ],
        inputs=[prompt_input]
    )

if __name__ == "__main__":
    demo.launch(share=True)