import gradio as gr
import spaces
import torch
from pydub import AudioSegment
import io
import base64
from scipy.io import wavfile
from torch.utils.data import DataLoader
from colpali_engine.models import ColQwen2_5Omni, ColQwen2_5OmniProcessor
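# Runtime dependencies assumed for this Space (not pinned in the original):
# gradio, spaces, torch, pydub, scipy, colpali-engine, matplotlib, and openai
# (optional, only for the textual-answer feature). pydub also requires a
# system-level ffmpeg install to decode formats such as .m4a.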
# Global model variables
model = None
processor = None


def load_model():
    """Load the model and processor once and cache them in module globals."""
    global model, processor
    if model is None:
        model = ColQwen2_5Omni.from_pretrained(
            "vidore/colqwen-omni-v0.1",
            torch_dtype=torch.bfloat16,
            device_map="cpu",  # Start on CPU; ZeroGPU only grants a GPU inside @spaces.GPU calls
            attn_implementation="eager",  # Avoids a flash-attn dependency on ZeroGPU
        ).eval()
        processor = ColQwen2_5OmniProcessor.from_pretrained("vidore/colqwen-omni-v0.1")
    return model, processor
def chunk_audio(audio_file_path, chunk_length=30):
    """Split an audio file into mono 16 kHz chunks of `chunk_length` seconds."""
    try:
        # audio_file_path is already a string path when type="filepath"
        audio = AudioSegment.from_file(audio_file_path)
        audios = []
        target_rate = 16000
        chunk_length_ms = int(chunk_length * 1000)  # the slider may pass a float
        for i in range(0, len(audio), chunk_length_ms):
            chunk = audio[i:i + chunk_length_ms]
            chunk = chunk.set_channels(1).set_frame_rate(target_rate)
            buf = io.BytesIO()
            chunk.export(buf, format="wav")
            buf.seek(0)
            rate, data = wavfile.read(buf)
            audios.append(data)
        return audios
    except Exception as e:
        raise gr.Error(f"Error processing audio file: {str(e)}. Make sure ffmpeg is installed.")
@spaces.GPU
def embed_audio_chunks(audios):
    """Embed audio chunks on the GPU."""
    model, processor = load_model()
    model = model.to("cuda")
    # Process the chunks in small batches
    dataloader = DataLoader(
        dataset=audios,
        batch_size=4,
        shuffle=False,
        collate_fn=lambda x: processor.process_audios(x),
    )
    embeddings = []
    for batch_doc in dataloader:
        with torch.no_grad():
            batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
            embeddings_doc = model(**batch_doc)
        embeddings.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
    # Move the model back to CPU to free GPU memory before the ZeroGPU slot is released
    model = model.to("cpu")
    torch.cuda.empty_cache()
    return embeddings
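# Note: ColQwen-style (late-interaction) retrievers emit one embedding per token,
# so each chunk maps to a variable-length multi-vector tensor of shape roughly
# (num_tokens, dim). That is why the embeddings are kept as a Python list rather
# than stacked into a single tensor.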
@spaces.GPU
def search_audio(query, embeddings, audios, top_k=5):
    """Return the indices of the audio chunks most relevant to a text query."""
    model, processor = load_model()
    model = model.to("cuda")
    # Embed the query
    batch_queries = processor.process_queries([query]).to(model.device)
    with torch.no_grad():
        query_embeddings = model(**batch_queries)
    # Late-interaction scoring of the query against every chunk embedding
    scores = processor.score_multi_vector(query_embeddings, embeddings)
    top_k = min(top_k, len(embeddings))  # guard against short recordings with few chunks
    top_indices = scores[0].topk(top_k).indices.tolist()
    # Move the model back to CPU
    model = model.to("cpu")
    torch.cuda.empty_cache()
    return top_indices
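# A minimal sketch of using the retrieval pieces outside the Gradio UI
# (the file name and query below are made up for illustration):
#
#   audios = chunk_audio("podcast.m4a", chunk_length=30)
#   embeddings = embed_audio_chunks(audios)
#   top_indices = search_audio("Who is the podcast host?", embeddings, audios, top_k=3)
#   wavfile.write("best_match.wav", 16000, audios[top_indices[0]])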
def audio_to_base64(data, rate=16000):
    """Encode a mono audio array as a base64 WAV string."""
    buf = io.BytesIO()
    wavfile.write(buf, rate, data)
    buf.seek(0)
    return base64.b64encode(buf.read()).decode("utf-8")
def process_audio_rag(audio_file_path, query, chunk_length=30, use_openai=False, openai_key=None):
    """Main processing function: chunk, embed, retrieve, and optionally answer."""
    if not audio_file_path:
        return "Please upload an audio file", None, None
    if not query:
        return "Please enter a search query", None, None
    try:
        # Chunk audio
        audios = chunk_audio(audio_file_path, chunk_length)
        # Embed chunks
        embeddings = embed_audio_chunks(audios)
        # Search for relevant chunks
        top_indices = search_audio(query, embeddings, audios)
        # Prepare results
        result_text = f"Found {len(top_indices)} relevant audio chunks:\n"
        result_text += f"Chunk indices: {top_indices}\n\n"
        # Save the best-matching chunk as an audio file
        first_chunk_path = "result_chunk.wav"
        wavfile.write(first_chunk_path, 16000, audios[top_indices[0]])
        # Optional: use OpenAI for answer generation
        if use_openai and openai_key:
            from openai import OpenAI  # Lazy import so the dependency stays optional
            client = OpenAI(api_key=openai_key)
            content = [{"type": "text", "text": f"Answer the query using the audio files. Query: {query}"}]
            for idx in top_indices[:3]:  # Use the top 3 chunks
                content.extend([
                    {"type": "text", "text": f"Audio chunk #{idx}:"},
                    {
                        "type": "input_audio",
                        "input_audio": {
                            "data": audio_to_base64(audios[idx]),
                            "format": "wav",
                        },
                    },
                ])
            try:
                completion = client.chat.completions.create(
                    model="gpt-4o-audio-preview",
                    messages=[{"role": "user", "content": content}],
                )
                result_text += f"\nWritten answer: {completion.choices[0].message.content}"
            except Exception as e:
                result_text += f"\nError: {str(e)}"
        # Plot the waveform of the top matching chunk
        import matplotlib.pyplot as plt  # Lazy import to keep startup light
        fig, ax = plt.subplots(figsize=(10, 4))
        ax.plot(audios[top_indices[0]])
        ax.set_title(f"Waveform of top matching chunk (#{top_indices[0]})")
        ax.set_xlabel("Samples")
        ax.set_ylabel("Amplitude")
        plt.tight_layout()
        return result_text, first_chunk_path, fig
    except Exception as e:
        return f"Error: {str(e)}", None, None
# Create the Gradio interface
with gr.Blocks(title="AudioRAG Demo") as demo:
    gr.Markdown("# AudioRAG Demo - Semantic Audio Search")
    gr.Markdown("""
    This demo builds on the work of the ColQwen team, extending retrieval beyond images to audio and video.
    Unlike traditional pipelines, the model searches directly through raw audio without first transcribing it to text: it captures the semantic meaning of speech, sound, and audio patterns, making "AudioRAG" a real possibility.
    [Blog post](https://huggingface.co/blog/manu/colqwen-omni-omnimodal-retrieval) | [Model on Hugging Face](https://huggingface.co/vidore/colqwen-omni-v0.1) | [Colab Notebook](https://colab.research.google.com/drive/1YOlTWfLbiyQqfq1SlqHA2iME1R-nH4aS#scrollTo=w7UyXtEcK0lA) | Sample from [Newsroom Robots](https://www.newsroomrobots.com/p/how-open-source-ai-puts-newsrooms)
    """)
    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(label="Upload Audio File", type="filepath")
            query_input = gr.Textbox(label="Search Query", placeholder="What are you looking for in the audio?")
            chunk_length = gr.Slider(minimum=10, maximum=60, value=30, step=5, label="Chunk Length (seconds)")
            with gr.Accordion("API key for textual answer (Optional)", open=False):
                gr.Markdown("Generate a textual answer based on the retrieved audio chunks with an OpenAI API key")
                use_openai = gr.Checkbox(label="Generate textual answer from retrieved audio")
                openai_key = gr.Textbox(label="OpenAI API Key", type="password")
            search_btn = gr.Button("Search Audio", variant="primary")
        with gr.Column():
            output_text = gr.Textbox(label="Results", lines=10)
            output_audio = gr.Audio(label="Top Matching Audio Chunk", type="filepath")
            output_plot = gr.Plot(label="Waveform of Top Match")
    gr.Examples(
        examples=[
            ["test.m4a", "Who's the podcast host?", 30],
        ],
        inputs=[audio_input, query_input, chunk_length],
    )
    search_btn.click(
        fn=process_audio_rag,
        inputs=[audio_input, query_input, chunk_length, use_openai, openai_key],
        outputs=[output_text, output_audio, output_plot],
    )
if __name__ == "__main__":
    # Load the model on startup so the first request does not pay the download cost
    load_model()
    demo.launch()