"""Inference script for AudioLLM (audio-llama-v1.1)."""
import os
import torch
import torch.nn as nn
import torchaudio
from transformers import LlamaForCausalLM, WhisperModel, AutoTokenizer
from huggingface_hub import hf_hub_download
import json
class FrozenModelWrapper:
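    # Thin wrapper that freezes a model's parameters and runs its forward pass under torch.no_grad()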
def __init__(self, model):
self.model = model
for param in self.model.parameters():
param.requires_grad = False
def forward(self, *args, **kwargs):
with torch.no_grad():
return self.model(*args, **kwargs)
def to(self, device):
self.model = self.model.to(device)
return self
class AudioProjector(nn.Module):
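    # Two-layer MLP (Linear -> GELU -> Linear -> LayerNorm) that maps Whisper encoder
    # hidden states into the LLaMA embedding space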
def __init__(self, input_dim, output_dim, hidden_dim=None):
super().__init__()
if hidden_dim is None:
hidden_dim = (input_dim + output_dim) // 2
self.layers = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, output_dim),
nn.LayerNorm(output_dim)
)
def forward(self, x):
return self.layers(x)
class LoRALayer(nn.Module):
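    # Low-rank adapter: computes x @ (B @ A)^T scaled by alpha / rank; the result is
    # added to the frozen linear layer's output via a forward hook (see lora_forward_hook)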
def __init__(self, in_dim, out_dim, rank=8, alpha=16):
super().__init__()
self.lora_A = nn.Parameter(torch.zeros(rank, in_dim))
self.lora_B = nn.Parameter(torch.randn(out_dim, rank) * 0.01)
self.rank = rank
self.alpha = alpha
self.scaling = alpha / rank
def forward(self, x):
return (x @ (self.lora_B @ self.lora_A).T) * self.scaling
def lora_forward_hook(module, input, output, lora_layer):
# Add LoRA output to the original linear layer output
return output + lora_layer(input[0])
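# A minimal sketch of how the hook combines a frozen linear layer with a LoRA
# adapter (hypothetical dimensions, for illustration only):
#
#     base = nn.Linear(4096, 4096)
#     lora = LoRALayer(4096, 4096, rank=8)
#     base.register_forward_hook(
#         lambda mod, inp, out: lora_forward_hook(mod, inp, out, lora)
#     )
#     y = base(torch.randn(1, 4096))  # base output + scaled low-rank update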
class AudioLLM(nn.Module):
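    # Frozen LLaMA + frozen Whisper encoder, bridged by a trained audio projector;
    # LoRA adapters are applied to selected LLaMA modules through forward hooks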
def __init__(self, llama_model, whisper_encoder, projector, lora_layers, tokenizer):
super().__init__()
self.llama = FrozenModelWrapper(llama_model)
self.whisper_encoder = FrozenModelWrapper(whisper_encoder)
self.projector = projector
self.lora_layers = lora_layers
self.tokenizer = tokenizer
# Register forward hooks to apply LoRA
self.hooks = []
for name, module in self.llama.model.named_modules():
if name in self.lora_layers:
hook = module.register_forward_hook(
lambda mod, inp, out, n=name: lora_forward_hook(mod, inp, out, self.lora_layers[n])
)
self.hooks.append(hook)
self.audio_start_token = "<audio>"
self.audio_end_token = "</audio>"
def _process_audio(self, audio_path, max_audio_length=30, sample_rate=16000):
# Process audio file for model input
if not os.path.exists(audio_path):
raise FileNotFoundError(f"Audio file not found: {audio_path}")
waveform, file_sample_rate = torchaudio.load(audio_path)
        # Convert to mono if stereo
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        # Resample to the target sample rate if needed (before trimming/padding,
        # so the final length matches max_audio_length at the target rate)
        if file_sample_rate != sample_rate:
            resampler = torchaudio.transforms.Resample(
                orig_freq=file_sample_rate, new_freq=sample_rate
            )
            waveform = resampler(waveform)
        # Trim or pad to the maximum audio length
        max_frames = max_audio_length * sample_rate
        if waveform.shape[1] > max_frames:
            waveform = waveform[:, :max_frames]
        elif waveform.shape[1] < max_frames:
            pad_len = max_frames - waveform.shape[1]
            waveform = nn.functional.pad(waveform, (0, pad_len))
        # Add batch dimension
        waveform = waveform.unsqueeze(0)
return waveform
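    # NOTE (editorial sketch, not part of the released checkpoint's documented pipeline):
    # WhisperModel's encoder consumes log-mel `input_features` of shape
    # (batch, num_mel_bins, frames) rather than a raw waveform. If the projector was
    # trained on Whisper features, a conversion along these lines could be applied to
    # the waveform returned by `_process_audio`; the feature-extractor checkpoint name
    # below is an assumption.
    def _extract_whisper_features(self, waveform, sample_rate=16000,
                                  feature_extractor_name="openai/whisper-large-v2"):
        from transformers import WhisperFeatureExtractor
        feature_extractor = WhisperFeatureExtractor.from_pretrained(feature_extractor_name)
        # The feature extractor expects a 1-D float array at the target sample rate
        features = feature_extractor(
            waveform.squeeze().cpu().numpy(),
            sampling_rate=sample_rate,
            return_tensors="pt"
        )
        return features.input_features  # (1, num_mel_bins, frames)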
def generate(self,
input_ids=None,
attention_mask=None,
audio_features=None,
max_new_tokens=256,
temperature=0.7,
top_p=0.9,
do_sample=True,
**kwargs):
# Generate text with optional audio context
device = next(self.llama.model.parameters()).device
# Move inputs to the model's device
if input_ids is not None:
input_ids = input_ids.to(device)
if attention_mask is not None:
attention_mask = attention_mask.to(device)
if audio_features is not None:
audio_features = audio_features.to(device)
# Get the initial text embeddings
text_embeddings = self.llama.model.model.embed_tokens(input_ids)
# Process audio if provided
if audio_features is not None:
audio_features = audio_features.squeeze(1)
with torch.no_grad():
whisper_output = self.whisper_encoder.model(audio_features)
whisper_embeddings = whisper_output.last_hidden_state
projected_audio = self.projector(whisper_embeddings)
# Get embeddings for audio delimiter tokens
audio_start_id = self.tokenizer.convert_tokens_to_ids(self.audio_start_token)
audio_end_id = self.tokenizer.convert_tokens_to_ids(self.audio_end_token)
audio_start_tokens = torch.tensor([[audio_start_id]], device=device)
audio_end_tokens = torch.tensor([[audio_end_id]], device=device)
audio_start_embedding = self.llama.model.model.embed_tokens(audio_start_tokens)
audio_end_embedding = self.llama.model.model.embed_tokens(audio_end_tokens)
# Concatenate: <audio> + audio_embeddings + </audio> + text_embeddings
combined_embeddings = torch.cat([
audio_start_embedding,
projected_audio,
audio_end_embedding,
text_embeddings
], dim=1)
# Create extended attention mask that includes audio tokens
batch_size, text_seq_len = attention_mask.shape
audio_seq_len = combined_embeddings.shape[1] - text_embeddings.shape[1]
audio_attention = torch.ones(batch_size, audio_seq_len, device=device)
combined_attention_mask = torch.cat([audio_attention, attention_mask], dim=1)
else:
combined_embeddings = text_embeddings
combined_attention_mask = attention_mask
# Set generation parameters
generation_config = {
"max_new_tokens": max_new_tokens,
"temperature": temperature,
"top_p": top_p,
"do_sample": do_sample,
"pad_token_id": self.tokenizer.pad_token_id,
"bos_token_id": self.tokenizer.bos_token_id,
"eos_token_id": self.tokenizer.eos_token_id,
}
# Add any additional kwargs
generation_config.update(kwargs)
# Generate tokens
with torch.no_grad():
outputs = self.llama.model.generate(
inputs_embeds=combined_embeddings,
attention_mask=combined_attention_mask,
**generation_config
)
        # When generating from inputs_embeds (no input_ids passed to generate),
        # the returned sequences contain only the newly generated token ids,
        # so the output can be decoded directly
        generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
return generated_text
def load_audio_llm(repo_id, llama_path=None, whisper_path=None, device="cuda"):
# Load AudioLLM model from Hugging Face Hub
# Download config and weights
config_file = hf_hub_download(repo_id=repo_id, filename="config.json")
projector_file = hf_hub_download(repo_id=repo_id, filename="projector.pt")
lora_file = hf_hub_download(repo_id=repo_id, filename="lora_layers.pt")
# Load configuration
with open(config_file, "r") as f:
config = json.load(f)
# Use provided model paths or fall back to config
llama_path = llama_path or config["llama_model_path"]
whisper_path = whisper_path or config["whisper_model_path"]
lora_rank = config.get("lora_rank", 64)
print(f"Loading LLaMA model from {llama_path}...")
llama = LlamaForCausalLM.from_pretrained(llama_path, device_map=device)
print(f"Loading Whisper model from {whisper_path}...")
whisper_encoder = WhisperModel.from_pretrained(whisper_path, device_map=device).encoder
# Load tokenizer
try:
tokenizer_path = os.path.join(os.path.dirname(config_file), "tokenizer")
if os.path.exists(tokenizer_path):
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
print("Loaded tokenizer from repository")
else:
tokenizer = AutoTokenizer.from_pretrained(llama_path)
tokenizer.pad_token = tokenizer.eos_token
# Add special tokens for audio
audio_tokens = {"additional_special_tokens": ["<audio>", "</audio>"]}
tokenizer.add_special_tokens(audio_tokens)
print("Added special tokens to tokenizer")
except Exception as e:
print(f"Error loading tokenizer: {e}. Falling back to base tokenizer.")
tokenizer = AutoTokenizer.from_pretrained(llama_path)
tokenizer.pad_token = tokenizer.eos_token
# Resize token embeddings if needed
llama.resize_token_embeddings(len(tokenizer))
# Load projector state
projector_state = torch.load(projector_file, map_location=device)
# Determine dimensions from state dict
if "layers.0.weight" in projector_state:
input_dim = projector_state["layers.0.weight"].shape[1]
output_dim = projector_state["layers.2.weight"].shape[0]
else:
# Approximate based on typical Whisper and LLaMA dimensions
input_dim = whisper_encoder.config.d_model # typically 1024 for large Whisper
output_dim = llama.config.hidden_size # typically 4096 for 7B LLaMA
# Create and load projector
projector = AudioProjector(input_dim, output_dim)
projector.load_state_dict(projector_state)
projector = projector.to(device)
# Load LoRA layers
lora_layers_state = torch.load(lora_file, map_location=device)
lora_layers = {}
# Reinstantiate LoRA layers
for name, state_dict in lora_layers_state.items():
# Extract dimensions from state dict
lora_A = state_dict["lora_A"]
lora_B = state_dict["lora_B"]
rank = lora_A.shape[0]
in_dim = lora_A.shape[1]
out_dim = lora_B.shape[0]
# Create layer
lora_layer = LoRALayer(in_dim, out_dim, rank=rank)
lora_layer.load_state_dict(state_dict)
lora_layers[name] = lora_layer.to(device)
# Create model
model = AudioLLM(
llama_model=llama,
whisper_encoder=whisper_encoder,
projector=projector,
lora_layers=lora_layers,
tokenizer=tokenizer
)
return model
def transcribe_and_generate(model, audio_path, prompt="", max_new_tokens=256, temperature=0.7):
# Process audio and generate text response
device = next(model.llama.model.parameters()).device
# Process audio
audio_features = model._process_audio(audio_path)
audio_features = audio_features.to(device)
# Tokenize prompt
    encoded_prompt = model.tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=512
    )
input_ids = encoded_prompt.input_ids
attention_mask = encoded_prompt.attention_mask
# Generate response
response = model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
audio_features=audio_features,
max_new_tokens=max_new_tokens,
temperature=temperature
)
return response
# Example usage
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="AudioLLM Inference")
parser.add_argument("--repo_id", type=str, required=True, help="HuggingFace repo ID")
parser.add_argument("--audio_path", type=str, required=True, help="Path to audio file")
parser.add_argument("--prompt", type=str, default="", help="Text prompt")
parser.add_argument("--max_new_tokens", type=int, default=256, help="Max tokens to generate")
parser.add_argument("--temperature", type=float, default=0.7, help="Generation temperature")
parser.add_argument("--llama_path", type=str, default=None, help="Optional: path to LLaMA model")
parser.add_argument("--whisper_path", type=str, default=None, help="Optional: path to Whisper model")
parser.add_argument("--device", type=str, default="cuda", help="Device (cuda or cpu)")
args = parser.parse_args()
# Load model
model = load_audio_llm(
repo_id=args.repo_id,
llama_path=args.llama_path,
whisper_path=args.whisper_path,
device=args.device
)
# Generate response
response = transcribe_and_generate(
model=model,
audio_path=args.audio_path,
prompt=args.prompt,
max_new_tokens=args.max_new_tokens,
temperature=args.temperature
)
print(f"Prompt: {args.prompt}")
print(f"Response: {response}")
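# Example invocation (repo id and paths are placeholders):
#   python inference.py --repo_id cdreetz/audio-llama-v1.1 \
#       --audio_path sample.wav --prompt "Transcribe the audio." --device cuda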