import os import torch import torch.nn as nn import torchaudio from transformers import LlamaForCausalLM, WhisperModel, AutoTokenizer from huggingface_hub import hf_hub_download import json class FrozenModelWrapper: def __init__(self, model): self.model = model for param in self.model.parameters(): param.requires_grad = False def forward(self, *args, **kwargs): with torch.no_grad(): return self.model(*args, **kwargs) def to(self, device): self.model = self.model.to(device) return self class AudioProjector(nn.Module): def __init__(self, input_dim, output_dim, hidden_dim=None): super().__init__() if hidden_dim is None: hidden_dim = (input_dim + output_dim) // 2 self.layers = nn.Sequential( nn.Linear(input_dim, hidden_dim), nn.GELU(), nn.Linear(hidden_dim, output_dim), nn.LayerNorm(output_dim) ) def forward(self, x): return self.layers(x) class LoRALayer(nn.Module): def __init__(self, in_dim, out_dim, rank=8, alpha=16): super().__init__() self.lora_A = nn.Parameter(torch.zeros(rank, in_dim)) self.lora_B = nn.Parameter(torch.randn(out_dim, rank) * 0.01) self.rank = rank self.alpha = alpha self.scaling = alpha / rank def forward(self, x): return (x @ (self.lora_B @ self.lora_A).T) * self.scaling def lora_forward_hook(module, input, output, lora_layer): # Add LoRA output to the original linear layer output return output + lora_layer(input[0]) class AudioLLM(nn.Module): def __init__(self, llama_model, whisper_encoder, projector, lora_layers, tokenizer): super().__init__() self.llama = FrozenModelWrapper(llama_model) self.whisper_encoder = FrozenModelWrapper(whisper_encoder) self.projector = projector self.lora_layers = lora_layers self.tokenizer = tokenizer # Register forward hooks to apply LoRA self.hooks = [] for name, module in self.llama.model.named_modules(): if name in self.lora_layers: hook = module.register_forward_hook( lambda mod, inp, out, n=name: lora_forward_hook(mod, inp, out, self.lora_layers[n]) ) self.hooks.append(hook) self.audio_start_token = "" def _process_audio(self, audio_path, max_audio_length=30, sample_rate=16000): # Process audio file for model input if not os.path.exists(audio_path): raise FileNotFoundError(f"Audio file not found: {audio_path}") waveform, file_sample_rate = torchaudio.load(audio_path) max_frames = max_audio_length * sample_rate # Trim or pad audio if waveform.shape[1] > max_frames: waveform = waveform[:, :max_frames] elif waveform.shape[1] < max_frames: pad_len = max_frames - waveform.shape[1] waveform = nn.functional.pad(waveform, (0, pad_len)) # Convert to mono if stereo if waveform.shape[0] > 1: waveform = torch.mean(waveform, dim=0, keepdim=True) # Resample if needed if file_sample_rate != sample_rate: resampler = torchaudio.transforms.Resample( orig_freq=file_sample_rate, new_freq=sample_rate ) waveform = resampler(waveform) # Add batch dimension waveform = waveform.unsqueeze(0) return waveform def generate(self, input_ids=None, attention_mask=None, audio_features=None, max_new_tokens=256, temperature=0.7, top_p=0.9, do_sample=True, **kwargs): # Generate text with optional audio context device = next(self.llama.model.parameters()).device # Move inputs to the model's device if input_ids is not None: input_ids = input_ids.to(device) if attention_mask is not None: attention_mask = attention_mask.to(device) if audio_features is not None: audio_features = audio_features.to(device) # Get the initial text embeddings text_embeddings = self.llama.model.model.embed_tokens(input_ids) # Process audio if provided if audio_features is not None: audio_features = audio_features.squeeze(1) with torch.no_grad(): whisper_output = self.whisper_encoder.model(audio_features) whisper_embeddings = whisper_output.last_hidden_state projected_audio = self.projector(whisper_embeddings) # Get embeddings for audio delimiter tokens audio_start_id = self.tokenizer.convert_tokens_to_ids(self.audio_start_token) audio_end_id = self.tokenizer.convert_tokens_to_ids(self.audio_end_token) audio_start_tokens = torch.tensor([[audio_start_id]], device=device) audio_end_tokens = torch.tensor([[audio_end_id]], device=device) audio_start_embedding = self.llama.model.model.embed_tokens(audio_start_tokens) audio_end_embedding = self.llama.model.model.embed_tokens(audio_end_tokens) # Concatenate: + text_embeddings combined_embeddings = torch.cat([ audio_start_embedding, projected_audio, audio_end_embedding, text_embeddings ], dim=1) # Create extended attention mask that includes audio tokens batch_size, text_seq_len = attention_mask.shape audio_seq_len = combined_embeddings.shape[1] - text_embeddings.shape[1] audio_attention = torch.ones(batch_size, audio_seq_len, device=device) combined_attention_mask = torch.cat([audio_attention, attention_mask], dim=1) else: combined_embeddings = text_embeddings combined_attention_mask = attention_mask # Set generation parameters generation_config = { "max_new_tokens": max_new_tokens, "temperature": temperature, "top_p": top_p, "do_sample": do_sample, "pad_token_id": self.tokenizer.pad_token_id, "bos_token_id": self.tokenizer.bos_token_id, "eos_token_id": self.tokenizer.eos_token_id, } # Add any additional kwargs generation_config.update(kwargs) # Generate tokens with torch.no_grad(): outputs = self.llama.model.generate( inputs_embeds=combined_embeddings, attention_mask=combined_attention_mask, **generation_config ) # Calculate where the actual generated content starts input_length = input_ids.shape[1] if audio_features is not None: input_length += audio_seq_len # Get only the newly generated tokens generated_tokens = outputs[0, input_length:] # Decode to text generated_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True) return generated_text def load_audio_llm(repo_id, llama_path=None, whisper_path=None, device="cuda"): # Load AudioLLM model from Hugging Face Hub # Download config and weights config_file = hf_hub_download(repo_id=repo_id, filename="config.json") projector_file = hf_hub_download(repo_id=repo_id, filename="projector.pt") lora_file = hf_hub_download(repo_id=repo_id, filename="lora_layers.pt") # Load configuration with open(config_file, "r") as f: config = json.load(f) # Use provided model paths or fall back to config llama_path = llama_path or config["llama_model_path"] whisper_path = whisper_path or config["whisper_model_path"] lora_rank = config.get("lora_rank", 64) print(f"Loading LLaMA model from {llama_path}...") llama = LlamaForCausalLM.from_pretrained(llama_path, device_map=device) print(f"Loading Whisper model from {whisper_path}...") whisper_encoder = WhisperModel.from_pretrained(whisper_path, device_map=device).encoder # Load tokenizer try: tokenizer_path = os.path.join(os.path.dirname(config_file), "tokenizer") if os.path.exists(tokenizer_path): tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) print("Loaded tokenizer from repository") else: tokenizer = AutoTokenizer.from_pretrained(llama_path) tokenizer.pad_token = tokenizer.eos_token # Add special tokens for audio audio_tokens = {"additional_special_tokens": [""]} tokenizer.add_special_tokens(audio_tokens) print("Added special tokens to tokenizer") except Exception as e: print(f"Error loading tokenizer: {e}. Falling back to base tokenizer.") tokenizer = AutoTokenizer.from_pretrained(llama_path) tokenizer.pad_token = tokenizer.eos_token # Resize token embeddings if needed llama.resize_token_embeddings(len(tokenizer)) # Load projector state projector_state = torch.load(projector_file, map_location=device) # Determine dimensions from state dict first_layer = list(projector_state.keys())[0] if "layers.0.weight" in projector_state: input_dim = projector_state["layers.0.weight"].shape[1] output_dim = projector_state["layers.2.weight"].shape[0] else: # Approximate based on typical Whisper and LLaMA dimensions input_dim = whisper_encoder.config.d_model # typically 1024 for large Whisper output_dim = llama.config.hidden_size # typically 4096 for 7B LLaMA # Create and load projector projector = AudioProjector(input_dim, output_dim) projector.load_state_dict(projector_state) projector = projector.to(device) # Load LoRA layers lora_layers_state = torch.load(lora_file, map_location=device) lora_layers = {} # Reinstantiate LoRA layers for name, state_dict in lora_layers_state.items(): # Extract dimensions from state dict lora_A = state_dict["lora_A"] lora_B = state_dict["lora_B"] rank = lora_A.shape[0] in_dim = lora_A.shape[1] out_dim = lora_B.shape[0] # Create layer lora_layer = LoRALayer(in_dim, out_dim, rank=rank) lora_layer.load_state_dict(state_dict) lora_layers[name] = lora_layer.to(device) # Create model model = AudioLLM( llama_model=llama, whisper_encoder=whisper_encoder, projector=projector, lora_layers=lora_layers, tokenizer=tokenizer ) return model def transcribe_and_generate(model, audio_path, prompt="", max_new_tokens=256, temperature=0.7): # Process audio and generate text response device = next(model.llama.model.parameters()).device # Process audio audio_features = model._process_audio(audio_path) audio_features = audio_features.to(device) # Tokenize prompt encoded_prompt = model.tokenizer( prompt, return_tensors="pt", padding="max_length", max_length=512, truncation=True ) input_ids = encoded_prompt.input_ids attention_mask = encoded_prompt.attention_mask # Generate response response = model.generate( input_ids=input_ids, attention_mask=attention_mask, audio_features=audio_features, max_new_tokens=max_new_tokens, temperature=temperature ) return response # Example usage if __name__ == "__main__": import argparse parser = argparse.ArgumentParser(description="AudioLLM Inference") parser.add_argument("--repo_id", type=str, required=True, help="HuggingFace repo ID") parser.add_argument("--audio_path", type=str, required=True, help="Path to audio file") parser.add_argument("--prompt", type=str, default="", help="Text prompt") parser.add_argument("--max_new_tokens", type=int, default=256, help="Max tokens to generate") parser.add_argument("--temperature", type=float, default=0.7, help="Generation temperature") parser.add_argument("--llama_path", type=str, default=None, help="Optional: path to LLaMA model") parser.add_argument("--whisper_path", type=str, default=None, help="Optional: path to Whisper model") parser.add_argument("--device", type=str, default="cuda", help="Device (cuda or cpu)") args = parser.parse_args() # Load model model = load_audio_llm( repo_id=args.repo_id, llama_path=args.llama_path, whisper_path=args.whisper_path, device=args.device ) # Generate response response = transcribe_and_generate( model=model, audio_path=args.audio_path, prompt=args.prompt, max_new_tokens=args.max_new_tokens, temperature=args.temperature ) print(f"Prompt: {args.prompt}") print(f"Response: {response}")