import gradio as gr
import spaces
import uuid
import torch
from datetime import timedelta
from lhotse import Recording
from lhotse.dataset import DynamicCutSampler
from nemo.collections.speechlm2 import SALM

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

SAMPLE_RATE = 16000  # Hz
MAX_AUDIO_MINUTES = 120  # won't try to transcribe if longer than this
CHUNK_SECONDS = 40.0  # max audio length seen by the model
BATCH_SIZE = 192  # for parallel transcription of audio longer than CHUNK_SECONDS

model = SALM.from_pretrained("nvidia/canary-qwen-2.5b").bfloat16().eval().to(device)
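
# Note: the SALM checkpoint bundles a speech encoder with a Qwen LLM decoder. The same
# `model` object is reused twice below: for transcription in transcribe(), and for plain
# text prompting with the speech adapter disabled in postprocess(). Loading in bfloat16
# and calling .eval() keeps it in inference mode with a modest memory footprint.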

def timestamp(idx: int):
    b = str(timedelta(seconds=idx * CHUNK_SECONDS))
    e = str(timedelta(seconds=(idx + 1) * CHUNK_SECONDS))
    return f"[{b} - {e}]"

def as_batches(audio_filepath, utt_id):
    rec = Recording.from_file(audio_filepath, recording_id=utt_id)
    if rec.duration / 60.0 > MAX_AUDIO_MINUTES:
        raise gr.Error(
            f"This demo can transcribe up to {MAX_AUDIO_MINUTES} minutes of audio. "
            "If you wish, you may trim the audio using the Audio viewer in Step 1 "
            "(click on the scissors icon to start trimming audio)."
        )
    cut = rec.resample(SAMPLE_RATE).to_cut()
    if cut.num_channels > 1:
        cut = cut.to_mono(mono_downmix=True)
    return DynamicCutSampler(cut.cut_into_windows(CHUNK_SECONDS), max_cuts=BATCH_SIZE)
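
# as_batches() resamples the recording to 16 kHz, downmixes it to mono, slices it into
# CHUNK_SECONDS windows, and batches up to BATCH_SIZE windows per step. For example, a
# 5-minute file yields ceil(300 / 40) = 8 window cuts, which fit in a single batch.
# transcribe() below runs the model over each batch and stitches the chunk texts together.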

def transcribe(audio_filepath):
    if audio_filepath is None:
        raise gr.Error("Please provide some input audio: either upload an audio file or use the microphone")
    utt_id = uuid.uuid4()
    pred_text = []
    pred_text_ts = []
    chunk_idx = 0
    for batch in as_batches(audio_filepath, str(utt_id)):
        audio, audio_lens = batch.load_audio(collate=True)
        with torch.inference_mode():
            output_ids = model.generate(
                # one identical prompt per window; the audio locator tag marks where the audio goes
                prompts=[[{"role": "user", "content": f"Transcribe the following: {model.audio_locator_tag}"}]] * len(batch),
                audios=torch.as_tensor(audio).to(device, non_blocking=True),
                audio_lens=torch.as_tensor(audio_lens).to(device, non_blocking=True),
                max_new_tokens=256,
            )
        texts = [model.tokenizer.ids_to_text(oids) for oids in output_ids.cpu()]
        for t in texts:
            pred_text.append(t)
            pred_text_ts.append(f"{timestamp(chunk_idx)} {t}\n\n")
            chunk_idx += 1
    return ''.join(pred_text_ts), ' '.join(pred_text)
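
# postprocess() switches to "LLM mode": the speech adapter is disabled so the underlying
# Qwen LLM answers the user's prompt about the raw transcript, and the chat-template
# prefix and any <think>...</think> block are stripped from the generated output.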

def postprocess(transcript, prompt):
    with torch.inference_mode(), model.llm.disable_adapter():
        output_ids = model.generate(
            prompts=[[{"role": "user", "content": f"{prompt}\n\n{transcript}"}]],
            max_new_tokens=2048,
        )
    ans = model.tokenizer.ids_to_text(output_ids[0].cpu())
    ans = ans.split("<|im_start|>assistant")[-1]  # get rid of the prompt
    if "<think>" in ans:
        ans = ans.split("<think>")[-1]
        thoughts, ans = ans.split("</think>")  # get rid of the thinking
    else:
        thoughts = ""
    return ans.strip(), thoughts
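
# The two helpers below lock/unlock both buttons while a job is running; they are
# chained around transcribe() and postprocess() in the click handlers further down.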

def disable_buttons():
    return gr.update(interactive=False), gr.update(interactive=False)


def enable_buttons():
    return gr.update(interactive=True), gr.update(interactive=True)

with gr.Blocks(
    title="NeMo Canary-Qwen-2.5B Model",
    css="""
    textarea { font-size: 18px;}
    #transcript_box span {
        font-size: 18px;
        font-weight: bold;
    }
    """,
    theme=gr.themes.Default(text_size=gr.themes.sizes.text_lg)  # make text slightly bigger (default is text_md)
) as demo:
    gr.HTML(
        "<h1 style='text-align: center'>NeMo Canary-Qwen-2.5B model: Transcribe and prompt</h1>"
        "<p>Canary-Qwen-2.5B is an ASR model capable of transcribing speech to text (ASR mode) and using its inner Qwen3-1.7B LLM for answering questions about the transcript (LLM mode).</p>"
    )
    with gr.Row():
        with gr.Column():
            gr.HTML(
                "<p><b>Step 1:</b> Upload an audio file or record with your microphone.</p>"
                "<p style='color: #A0A0A0;'>This demo supports audio files up to 2 hours long.</p>"
            )
            audio_file = gr.Audio(sources=["microphone", "upload"], type="filepath")
        with gr.Column():
            gr.HTML("<p><b>Step 2:</b> Transcribe the audio.</p>")
            asr_button = gr.Button(
                value="Run model",
                variant="primary",  # make "primary" so it stands out (default is "secondary")
            )
            transcript_box = gr.Textbox(
                label="Model Transcript",
                elem_id="transcript_box",
            )
            raw_transcript = gr.State()
    with gr.Row():
        with gr.Column():
            gr.HTML("<p><b>Step 3:</b> Prompt the model.</p>")
            prompt_box = gr.Textbox(
                "Give me a TL;DR:",
                label="Prompt",
                elem_id="prompt_box",
            )
        with gr.Column():
            gr.HTML("<p><b>Step 4:</b> See the outcome!</p>")
            llm_button = gr.Button(
                value="Apply the prompt",
                variant="primary",  # make "primary" so it stands out (default is "secondary")
            )
            magic_box = gr.Textbox(
                label="Assistant's Response",
                elem_id="magic_box",
            )
            think_box = gr.Textbox(
                label="Assistant's Thinking",
                elem_id="think_box",
            )
    with gr.Row():
        gr.HTML(
            "<p style='text-align: center'>"
            "🐤 <a href='https://huggingface.co/nvidia/canary-qwen-2.5b' target='_blank'>Canary-Qwen-2.5B model</a> | "
            "🧑‍💻 <a href='https://github.com/NVIDIA/NeMo' target='_blank'>NeMo Repository</a>"
            "</p>"
        )
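
    # Event wiring: each click first disables both buttons (trigger_mode="once" also
    # ignores repeated clicks while the chain is running), then runs the model step,
    # then re-enables the buttons.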
    asr_button.click(
        disable_buttons,
        outputs=[asr_button, llm_button],
        trigger_mode="once",
    ).then(
        fn=transcribe,
        inputs=[audio_file],
        outputs=[transcript_box, raw_transcript],
    ).then(
        enable_buttons,
        outputs=[asr_button, llm_button],
    )

    llm_button.click(
        disable_buttons,
        outputs=[asr_button, llm_button],
        trigger_mode="once",
    ).then(
        fn=postprocess,
        inputs=[raw_transcript, prompt_box],
        outputs=[magic_box, think_box],
    ).then(
        enable_buttons,
        outputs=[asr_button, llm_button],
    )
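
# Enable Gradio's request queue so that concurrent visitors wait their turn instead of
# launching GPU inference at the same time (default queue settings), then start the app.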

demo.queue()
demo.launch()