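# AudSemThinker: Gradio demo that streams audio-analysis output from local
# Qwen2.5-Omni "thinker" checkpoints. The "Think + Semantics" mode asks the model
# to structure its response with <think>, <semantic_elements>, and <answer> tags.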
import spaces
import gradio as gr
import torch
import subprocess
import sys
import torchaudio
from threading import Thread
from transformers import (
    AutoProcessor,
    Qwen2_5OmniThinkerForConditionalGeneration,
    TextIteratorStreamer,
    StoppingCriteria,
    StoppingCriteriaList,
)
from qwen_omni_utils import process_mm_info
print(f"PyTorch version: {torch.__version__}") | |
print(f"CUDA available: {torch.cuda.is_available()}") | |
print(f"CUDA version: {torch.version.cuda}") | |
# Check environment | |
result = subprocess.run([sys.executable, '-c', | |
'import torch; print(f"PyTorch: {torch.__version__}"); print(f"CUDA: {torch.version.cuda}")'], | |
capture_output=True, text=True) | |
print(result.stdout) | |
# Model paths and configuration
model_path_1 = "./model"
model_path_2 = "./model2"
base_model_id = "Qwen/Qwen2.5-Omni-7B"

# Dictionary to store loaded models and processors
loaded_models = {}
# Load the model and processor
def load_model(model_path):
    # Check if model is already loaded
    if model_path in loaded_models:
        return loaded_models[model_path]

    # Load the processor from the base model
    processor = AutoProcessor.from_pretrained(
        base_model_id,
        trust_remote_code=True,
    )

    # Load the model
    model = Qwen2_5OmniThinkerForConditionalGeneration.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        trust_remote_code=True,
        device_map="auto",
    )
    model.eval()

    # Store in cache
    loaded_models[model_path] = (model, processor)
    return model, processor
# Initialize first model and processor
model, processor = load_model(model_path_1)
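# process_output reflows each streamed chunk: it places the <think>,
# <semantic_elements>, and <answer> tags (and their closing tags) on their own
# lines and expands escaped newlines so the streamed text stays readable.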
def process_output(output):
    if "<think>" in output:
        rest = output.split("<think>")[1]
        output = "<think>\n" + rest
    elif "<semantic_elements>" in output:
        rest = output.split("<semantic_elements>")[1]
        output = "<semantic_elements>\n" + rest
    elif "<answer>" in output:
        rest = output.split("<answer>")[1]
        output = "<answer>\n" + rest
    elif "</think>" in output:
        rest = output.split("</think>")[0]
        output = rest + "\n</think>\n\n"
    elif "</semantic_elements>" in output:
        rest = output.split("</semantic_elements>")[0]
        output = rest + "\n</semantic_elements>\n\n"
    elif "</answer>" in output:
        rest = output.split("</answer>")[0]
        output = rest + "\n</answer>\n"
    output = output.replace("\\n", "\n")
    output = output.replace("\\", "\n")
    output = output.replace("\n-", "-")
    return output
# Custom stopping criterion: ends generation as soon as any of the given token
# sequences (e.g. the encoded "</answer>" tag) appears at the end of the output.
class StopOnSpecificToken(StoppingCriteria):
    def __init__(self, stop_token_sequences: list[list[int]], device: str = "cuda"):
        super().__init__()
        self.stop_sequence_tensors = []
        for seq in stop_token_sequences:
            if seq:  # Only keep non-empty sequences
                self.stop_sequence_tensors.append(
                    torch.tensor(seq, dtype=torch.long, device=device)
                )

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        for stop_seq_tensor in self.stop_sequence_tensors:
            current_sequence_length = input_ids.shape[-1]
            stop_sequence_length = stop_seq_tensor.shape[-1]
            if current_sequence_length >= stop_sequence_length:
                # Compare the trailing tokens of the first sequence in the batch
                last_tokens = input_ids[0, -stop_sequence_length:]
                if torch.equal(last_tokens, stop_seq_tensor):
                    return True
        return False
# Streaming inference function used by the Gradio interface below
@spaces.GPU  # request a ZeroGPU slot for the duration of this call
def process_audio_streaming(audio_file, model_choice, question="Describe this audio in detail"):
    # Load the selected model
    model_path = model_path_2 if model_choice == "Think" else model_path_1
    model, processor = load_model(model_path)

    # Load and process the audio with torchaudio
    waveform, sr = torchaudio.load(audio_file)

    # Resample to 16kHz if needed
    if sr != processor.feature_extractor.sampling_rate:
        waveform = torchaudio.functional.resample(waveform, sr, processor.feature_extractor.sampling_rate)
        sr = processor.feature_extractor.sampling_rate

    # Convert to mono if stereo
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    # Get the audio data as numpy array
    y = waveform.squeeze().numpy()

    # Set sampling rate for the processor
    sampling_rate = processor.feature_extractor.sampling_rate

    # Define prompts based on model choice
    prompt_think_semantics = f"You are given a question and an audio clip. Your task is to answer the question based on the audio clip. First, think about the question and the audio clip and put your thoughts in <think> and </think> tags. Then reason about the semantic elements involved in the audio clip and put your reasoning in <semantic_elements> and </semantic_elements> tags. Then answer the question based on the audio clip, put your answer in <answer> and </answer> tags. {question}"

    instruction_text = ""
    if model_choice == "Think + Semantics":
        instruction_text = prompt_think_semantics
    else:  # Default to the plain question if no structured prompt is needed
        instruction_text = question
    # Create conversation format
    conversation = [
        {"role": "system", "content": [
            {"type": "text", "text": "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech."}]},
        {"role": "user", "content": [
            {"type": "audio", "audio": y},
            {"type": "text", "text": instruction_text}
        ]}
    ]

    # Format the chat
    chat_text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)

    # Process multimedia info using qwen_omni_utils
    audios, images, videos = process_mm_info(conversation, use_audio_in_video=False)

    # Process the inputs
    inputs = processor(
        text=chat_text,
        audio=audios,
        images=images,
        videos=videos,
        return_tensors="pt",
        sampling_rate=sampling_rate,
    ).to(model.device)
    # Create a standard streamer instance
    streamer = TextIteratorStreamer(
        processor.tokenizer,
        timeout=10.,
        skip_prompt=True,
        skip_special_tokens=True
    )

    # Initialize variables for buffering
    accumulated_output = ""
    buffer = ""
    stop_sequence = "</answer>"
    stop_found = False

    answer_token_ids = processor.tokenizer.encode("</answer>", add_special_tokens=False)
    processed_stop_token_ids = [answer_token_ids]

    # Get the device of the model to ensure tensors are on the same device
    model_device = next(model.parameters()).device
    custom_stopping_criteria = StopOnSpecificToken(stop_token_sequences=processed_stop_token_ids, device=model_device.type)
    stopping_criteria = StoppingCriteriaList([custom_stopping_criteria])
    # Generate the output with streaming
    with torch.no_grad():
        generate_kwargs = dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=768,
            do_sample=False,
            stopping_criteria=stopping_criteria
        )
        t = Thread(target=model.generate, kwargs=generate_kwargs)
        t.start()

    # Process the stream with buffering
    for output in streamer:
        if stop_found:
            break

        output = process_output(output)
        buffer += output

        # Check if stop sequence is in the buffer
        if stop_sequence in buffer:
            # Output everything up to and including the stop sequence
            before_stop = buffer.split(stop_sequence)[0]
            accumulated_output += before_stop + stop_sequence
            yield accumulated_output
            stop_found = True
            break
        else:
            # Only emit text that can no longer be part of the stop sequence:
            # keep the last len(stop_sequence) characters in the buffer
            if len(buffer) > len(stop_sequence):
                safe_output = buffer[:-len(stop_sequence)]
                buffer = buffer[-len(stop_sequence):]
                accumulated_output += safe_output
                yield accumulated_output

    # Output any remaining buffer if no stop sequence was found
    if not stop_found and buffer:
        accumulated_output += buffer
        yield accumulated_output
# Create Gradio interface for audio processing
audio_demo = gr.Interface(
    fn=process_audio_streaming,
    inputs=[
        gr.Audio(type="filepath", label="Upload Audio"),
        gr.Radio(["Think", "Think + Semantics"], label="Select Model", value="Think + Semantics"),
        gr.Textbox(label="Question", value="Describe this audio in detail")
    ],
    outputs=gr.Textbox(label="Generated Output", lines=30),
    title="AudSemThinker",
    description="Upload an audio file and the model will provide detailed analysis and description. Choose between different model versions.",
    examples=[["examples/1.wav", "Think + Semantics", "Describe this audio in detail"]],
    cache_examples=False,
    live=True
)
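
# Note: because process_audio_streaming is a generator (it yields partial text),
# Gradio streams the accumulated output into the textbox as tokens arrive.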
# Launch the app
if __name__ == "__main__":
    audio_demo.launch()