"""Inference utilities for an audio-conditioned LLaMA model.

A frozen Whisper encoder converts audio into hidden states, a small MLP
projector maps them into the LLaMA embedding space, and LoRA adapters are
applied to selected LLaMA modules via forward hooks. Only the projector and
LoRA weights are fetched from the Hugging Face Hub; the base LLaMA and
Whisper checkpoints are loaded separately.
"""

import json
import os

import torch
import torch.nn as nn
import torchaudio
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, LlamaForCausalLM, WhisperModel


class FrozenModelWrapper:
    """Holds a pretrained model with gradients disabled; inference only."""

    def __init__(self, model):
        self.model = model
        self.model.eval()  # disable dropout etc. for inference
        for param in self.model.parameters():
            param.requires_grad = False

    def forward(self, *args, **kwargs):
        with torch.no_grad():
            return self.model(*args, **kwargs)

    def to(self, device):
        self.model = self.model.to(device)
        return self


class AudioProjector(nn.Module):
    """Two-layer MLP that maps Whisper encoder states to the LLaMA embedding size."""

    def __init__(self, input_dim, output_dim, hidden_dim=None):
        super().__init__()
        if hidden_dim is None:
            hidden_dim = (input_dim + output_dim) // 2

        self.layers = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, output_dim),
            nn.LayerNorm(output_dim),
        )

    def forward(self, x):
        return self.layers(x)


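# For a 30 s clip, the Whisper encoder produces roughly 1500 hidden states
# (3000 mel frames downsampled 2x by its convolutional stem), so the projected
# audio contributes about 1500 positions to the LLaMA context in AudioLLM.
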
class LoRALayer(nn.Module):
    """Low-rank adapter that produces the delta scaling * B @ A @ x."""

    def __init__(self, in_dim, out_dim, rank=8, alpha=16):
        super().__init__()
        # Initialization is irrelevant at inference time: both matrices are
        # overwritten by the checkpoint via load_state_dict.
        self.lora_A = nn.Parameter(torch.zeros(rank, in_dim))
        self.lora_B = nn.Parameter(torch.randn(out_dim, rank) * 0.01)
        self.rank = rank
        self.alpha = alpha
        self.scaling = alpha / rank

    def forward(self, x):
        return (x @ (self.lora_B @ self.lora_A).T) * self.scaling


def lora_forward_hook(module, input, output, lora_layer):
    """Forward hook that adds a LoRA delta to the hooked module's output."""
    return output + lora_layer(input[0])


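# Note on the hook above: for a hooked nn.Linear with weight W and bias b, the
# effective forward pass becomes W x + b + (alpha / rank) * B A x, i.e. the
# standard LoRA update kept separate from the frozen base weight.
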
class AudioLLM(nn.Module):
    def __init__(self, llama_model, whisper_encoder, projector, lora_layers, tokenizer):
        super().__init__()

        self.llama = FrozenModelWrapper(llama_model)
        self.whisper_encoder = FrozenModelWrapper(whisper_encoder)
        self.projector = projector
        # Plain dict keyed by LLaMA module name; the adapters are used only for
        # inference here, so they are not registered as submodules.
        self.lora_layers = lora_layers
        self.tokenizer = tokenizer

        # Attach the LoRA adapters to their target modules via forward hooks.
        self.hooks = []
        for name, module in self.llama.model.named_modules():
            if name in self.lora_layers:
                hook = module.register_forward_hook(
                    lambda mod, inp, out, n=name: lora_forward_hook(mod, inp, out, self.lora_layers[n])
                )
                self.hooks.append(hook)

        self.audio_start_token = "<audio>"
        self.audio_end_token = "</audio>"

    def _process_audio(self, audio_path, max_audio_length=30, sample_rate=16000):
        """Load an audio file and convert it to Whisper-style log-mel features.

        Returns a (1, n_mels, frames) tensor. This approximates Whisper's
        standard 16 kHz / 80-mel front end; see the note after this method for
        using the checkpoint's own feature extractor instead.
        """
        if not os.path.exists(audio_path):
            raise FileNotFoundError(f"Audio file not found: {audio_path}")

        waveform, file_sample_rate = torchaudio.load(audio_path)

        # Downmix to mono.
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        # Resample before any length bookkeeping, since max_frames is defined
        # in target-rate samples.
        if file_sample_rate != sample_rate:
            waveform = torchaudio.transforms.Resample(
                orig_freq=file_sample_rate, new_freq=sample_rate
            )(waveform)

        # Truncate or zero-pad to exactly max_audio_length seconds; the Whisper
        # encoder expects fixed 30 s windows.
        max_frames = max_audio_length * sample_rate
        if waveform.shape[1] > max_frames:
            waveform = waveform[:, :max_frames]
        elif waveform.shape[1] < max_frames:
            waveform = nn.functional.pad(waveform, (0, max_frames - waveform.shape[1]))

        # Log-mel spectrogram (assumes an 80-mel Whisper variant), since the
        # Whisper encoder takes mel input features rather than raw waveforms.
        mel = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate, n_fft=400, hop_length=160, n_mels=80,
            norm="slaney", mel_scale="slaney",
        )(waveform)[..., :-1]  # drop the trailing frame: 3000 frames for 30 s
        log_mel = torch.clamp(mel, min=1e-10).log10()
        log_mel = torch.maximum(log_mel, log_mel.max() - 8.0)
        log_mel = (log_mel + 4.0) / 4.0

        return log_mel

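    # For exact parity with a specific Whisper checkpoint's preprocessing, the
    # log-mel step above could instead use the feature extractor shipped with
    # that checkpoint (sketch; whisper_path is the same path passed to
    # load_audio_llm):
    #     from transformers import WhisperFeatureExtractor
    #     extractor = WhisperFeatureExtractor.from_pretrained(whisper_path)
    #     features = extractor(waveform.squeeze(0).numpy(),
    #                          sampling_rate=16000,
    #                          return_tensors="pt").input_features
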
    def generate(self,
                 input_ids=None,
                 attention_mask=None,
                 audio_features=None,
                 max_new_tokens=256,
                 temperature=0.7,
                 top_p=0.9,
                 do_sample=True,
                 **kwargs):

        device = next(self.llama.model.parameters()).device

        # Move all inputs to the model device.
        if input_ids is not None:
            input_ids = input_ids.to(device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(device)
        if audio_features is not None:
            audio_features = audio_features.to(device)

        # Embed the text prompt with the frozen LLaMA embedding table.
        text_embeddings = self.llama.model.model.embed_tokens(input_ids)

        if audio_features is not None:
            # audio_features: (batch, n_mels, frames) log-mel input for the
            # Whisper encoder, as produced by _process_audio.
            with torch.no_grad():
                whisper_output = self.whisper_encoder.model(audio_features)
                whisper_embeddings = whisper_output.last_hidden_state

            projected_audio = self.projector(whisper_embeddings)

            # Embeddings of the <audio> ... </audio> boundary tokens.
            audio_start_id = self.tokenizer.convert_tokens_to_ids(self.audio_start_token)
            audio_end_id = self.tokenizer.convert_tokens_to_ids(self.audio_end_token)

            audio_start_tokens = torch.tensor([[audio_start_id]], device=device)
            audio_end_tokens = torch.tensor([[audio_end_id]], device=device)

            audio_start_embedding = self.llama.model.model.embed_tokens(audio_start_tokens)
            audio_end_embedding = self.llama.model.model.embed_tokens(audio_end_tokens)

            # Sequence layout: <audio> [projected audio frames] </audio> [text prompt]
            combined_embeddings = torch.cat([
                audio_start_embedding,
                projected_audio,
                audio_end_embedding,
                text_embeddings
            ], dim=1)

            # Extend the attention mask over the audio positions.
            batch_size = attention_mask.shape[0]
            audio_seq_len = combined_embeddings.shape[1] - text_embeddings.shape[1]
            audio_attention = torch.ones(batch_size, audio_seq_len, device=device)
            combined_attention_mask = torch.cat([audio_attention, attention_mask], dim=1)
        else:
            combined_embeddings = text_embeddings
            combined_attention_mask = attention_mask

        generation_config = {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "do_sample": do_sample,
            "pad_token_id": self.tokenizer.pad_token_id,
            "bos_token_id": self.tokenizer.bos_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
        }

        # Caller-provided keyword arguments override the defaults above.
        generation_config.update(kwargs)

        with torch.no_grad():
            outputs = self.llama.model.generate(
                inputs_embeds=combined_embeddings,
                attention_mask=combined_attention_mask,
                **generation_config
            )

        # When generating from inputs_embeds alone, recent transformers
        # releases return only the newly generated token ids; strip the prompt
        # positions only if they are actually present in the output.
        prompt_length = combined_embeddings.shape[1]
        if outputs.shape[1] > prompt_length:
            generated_tokens = outputs[0, prompt_length:]
        else:
            generated_tokens = outputs[0]

        generated_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)

        return generated_text


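# Expected Hub repository contents, inferred from the loading code below
# (illustrative; exact keys come from the repo's own config.json):
#     config.json        {"llama_model_path": ..., "whisper_model_path": ..., "lora_rank": ...}
#     projector.pt       state dict for AudioProjector
#     lora_layers.pt     dict mapping LLaMA module names to {"lora_A", "lora_B"} tensors
#     tokenizer/         (optional) tokenizer with the <audio> tokens already added
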
def load_audio_llm(repo_id, llama_path=None, whisper_path=None, device="cuda"):
    # Download the adapter artifacts (config, projector, LoRA weights) from the Hub.
    config_file = hf_hub_download(repo_id=repo_id, filename="config.json")
    projector_file = hf_hub_download(repo_id=repo_id, filename="projector.pt")
    lora_file = hf_hub_download(repo_id=repo_id, filename="lora_layers.pt")

    with open(config_file, "r") as f:
        config = json.load(f)

    # Explicit arguments override the paths stored in the repo config.
    llama_path = llama_path or config["llama_model_path"]
    whisper_path = whisper_path or config["whisper_model_path"]

    print(f"Loading LLaMA model from {llama_path}...")
    llama = LlamaForCausalLM.from_pretrained(llama_path, device_map=device)

    print(f"Loading Whisper model from {whisper_path}...")
    whisper_encoder = WhisperModel.from_pretrained(whisper_path, device_map=device).encoder

    # Prefer a tokenizer shipped with the repo (it already contains the audio
    # tokens); note the tokenizer/ folder is only present locally if it was
    # downloaded alongside config.json. Otherwise fall back to the base LLaMA
    # tokenizer and add the audio tokens here.
    try:
        tokenizer_path = os.path.join(os.path.dirname(config_file), "tokenizer")
        if os.path.exists(tokenizer_path):
            tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
            print("Loaded tokenizer from repository")
        else:
            tokenizer = AutoTokenizer.from_pretrained(llama_path)
            tokenizer.pad_token = tokenizer.eos_token

        audio_tokens = {"additional_special_tokens": ["<audio>", "</audio>"]}
        tokenizer.add_special_tokens(audio_tokens)
        print("Added special tokens to tokenizer")
    except Exception as e:
        print(f"Error loading tokenizer: {e}. Falling back to base tokenizer.")
        tokenizer = AutoTokenizer.from_pretrained(llama_path)
        tokenizer.pad_token = tokenizer.eos_token

    # Make room in the embedding table for the added audio tokens.
    llama.resize_token_embeddings(len(tokenizer))

    # Rebuild the projector, inferring its dimensions from the checkpoint when possible.
    projector_state = torch.load(projector_file, map_location=device)
    if "layers.0.weight" in projector_state:
        input_dim = projector_state["layers.0.weight"].shape[1]
        hidden_dim = projector_state["layers.0.weight"].shape[0]
        output_dim = projector_state["layers.2.weight"].shape[0]
    else:
        # Fall back to the model configs: Whisper d_model -> LLaMA hidden_size.
        input_dim = whisper_encoder.config.d_model
        output_dim = llama.config.hidden_size
        hidden_dim = None

    projector = AudioProjector(input_dim, output_dim, hidden_dim=hidden_dim)
    projector.load_state_dict(projector_state)
    projector = projector.to(device)

    # Rebuild one LoRALayer per adapted LLaMA module; the rank is inferred from
    # the stored lora_A matrix of each entry.
    lora_layers_state = torch.load(lora_file, map_location=device)
    lora_layers = {}

    for name, state_dict in lora_layers_state.items():
        lora_A = state_dict["lora_A"]
        lora_B = state_dict["lora_B"]

        rank = lora_A.shape[0]
        in_dim = lora_A.shape[1]
        out_dim = lora_B.shape[0]

        lora_layer = LoRALayer(in_dim, out_dim, rank=rank)
        lora_layer.load_state_dict(state_dict)
        lora_layers[name] = lora_layer.to(device)

    model = AudioLLM(
        llama_model=llama,
        whisper_encoder=whisper_encoder,
        projector=projector,
        lora_layers=lora_layers,
        tokenizer=tokenizer
    )

    return model


def transcribe_and_generate(model, audio_path, prompt="", max_new_tokens=256, temperature=0.7):
    device = next(model.llama.model.parameters()).device

    # Audio -> log-mel features on the model device.
    audio_features = model._process_audio(audio_path)
    audio_features = audio_features.to(device)

    # Tokenize the prompt. No padding is needed for a single sequence; padding
    # to a fixed length would leave pad tokens between the prompt and the
    # generated continuation.
    encoded_prompt = model.tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    )

    input_ids = encoded_prompt.input_ids
    attention_mask = encoded_prompt.attention_mask

    response = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        audio_features=audio_features,
        max_new_tokens=max_new_tokens,
        temperature=temperature
    )

    return response


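# Example invocation (repo id and script name are illustrative):
#     python audio_llm.py --repo_id your-org/audio-llm-adapter \
#         --audio_path sample.wav --prompt "Describe this audio." --device cuda
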
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="AudioLLM Inference")
    parser.add_argument("--repo_id", type=str, required=True, help="HuggingFace repo ID")
    parser.add_argument("--audio_path", type=str, required=True, help="Path to audio file")
    parser.add_argument("--prompt", type=str, default="", help="Text prompt")
    parser.add_argument("--max_new_tokens", type=int, default=256, help="Max tokens to generate")
    parser.add_argument("--temperature", type=float, default=0.7, help="Generation temperature")
    parser.add_argument("--llama_path", type=str, default=None, help="Optional: path to LLaMA model")
    parser.add_argument("--whisper_path", type=str, default=None, help="Optional: path to Whisper model")
    parser.add_argument("--device", type=str, default="cuda", help="Device (cuda or cpu)")

    args = parser.parse_args()

    model = load_audio_llm(
        repo_id=args.repo_id,
        llama_path=args.llama_path,
        whisper_path=args.whisper_path,
        device=args.device
    )

    response = transcribe_and_generate(
        model=model,
        audio_path=args.audio_path,
        prompt=args.prompt,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature
    )

    print(f"Prompt: {args.prompt}")
    print(f"Response: {response}")