import os
import random
from typing import Callable

import numpy as np
import torch
import whisper
from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf

from slam_llm.utils.model_utils import get_custom_model_factory
from utils.snac_utils import reconscruct_snac, reconstruct_tensors, layershift  # (sic: name as defined in utils.snac_utils)
from s2s_config import InferenceConfig, CKPT_PATH, CKPT_REPO, CKPT_LOCAL_DIR, CKPT_NAME
def update_progress(progress_callback: Callable[[str], None] | None, message: str):
    if progress_callback:
        progress_callback(message)
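
# Any callable taking a single string works as a progress callback; e.g.
# update_progress(print, "Loading model") simply prints the message, and
# passing None silently drops it.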
def pull_model_ckpt():
    # Download the checkpoint from the Hub once; skip if it is already local.
    os.makedirs(CKPT_LOCAL_DIR, exist_ok=True)
    if os.path.exists(CKPT_PATH):
        return
    hf_hub_download(
        repo_id=CKPT_REPO,
        filename=CKPT_NAME,
        local_dir=CKPT_LOCAL_DIR,
        token=os.getenv("HF_TOKEN"),
    )


pull_model_ckpt()
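
# HF_TOKEN is only required when CKPT_REPO is private or gated; for a public
# checkpoint, hf_hub_download works with token=None. One way to supply it
# (entrypoint name is hypothetical):
#
#   HF_TOKEN=hf_xxx python app.py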
def extract_audio_feature(audio_path, mel_size):
    print("Extracting audio features from", audio_path)
    audio_raw = whisper.load_audio(audio_path)
    audio_raw = whisper.pad_or_trim(audio_raw)  # pad or trim to 30 s
    audio_mel = whisper.log_mel_spectrogram(audio_raw, n_mels=mel_size).permute(1, 0)
    # The Whisper encoder halves the number of frames (stride-2 conv); the
    # extra // 5 matches the encoder projector's additional downsampling.
    audio_length = (audio_mel.shape[0] + 1) // 2
    audio_length = audio_length // 5
    return audio_mel, audio_length
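
# Worked example of the length math above: whisper.pad_or_trim always yields
# 30 s of 16 kHz audio (480,000 samples), and log_mel_spectrogram uses a hop
# of 160 samples, so audio_mel has 3000 frames. The returned length is then
# (3000 + 1) // 2 // 5 = 300 feature positions.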
def get_input_ids(length, special_token_a, special_token_t, vocab_config):
    input_ids = []
    for i in range(vocab_config.code_layer):
        input_ids_item = [layershift(vocab_config.input_a, i)]
        input_ids_item += [layershift(vocab_config.pad_a, i)] * length
        input_ids_item += [
            layershift(vocab_config.eoa, i),
            layershift(special_token_a, i),
        ]
        input_ids.append(torch.tensor(input_ids_item).unsqueeze(0))
    input_id_T = torch.tensor(
        [vocab_config.input_t]
        + [vocab_config.pad_t] * length
        + [vocab_config.eot, special_token_t]
    )
    input_ids.append(input_id_T.unsqueeze(0))
    return input_ids
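
# Layout produced above, one row per layer (L = length):
#
#   audio layer i: [input_a, pad_a * L, eoa, special_token_a]   (layer-shifted)
#   text layer:    [input_t, pad_t * L, eot, special_token_t]
#
# i.e. code_layer audio streams followed by one text stream, each of
# length L + 3.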
def generate_from_wav(
    wav_path, model, codec_decoder, dataset_config, decode_config, device
):
    mel_size = dataset_config.mel_size
    prompt = dataset_config.prompt
    prompt_template = "USER: {}\n ASSISTANT: "
    vocab_config = dataset_config.vocab_config
    special_token_a = vocab_config.answer_a
    special_token_t = vocab_config.answer_t
    code_layer = vocab_config.code_layer
    task_type = dataset_config.task_type

    audio_mel, audio_length = extract_audio_feature(wav_path, mel_size)

    prompt = prompt_template.format(prompt)
    prompt_ids = model.tokenizer.encode(prompt)
    prompt_length = len(prompt_ids)
    prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64)
    example_ids = get_input_ids(
        audio_length + prompt_length, special_token_a, special_token_t, vocab_config
    )

    # Splice the tokenized prompt into the text layer, keeping the trailing
    # <eot> and answer tokens.
    text_layer = example_ids[code_layer]
    text_layer = torch.cat(
        (
            text_layer[:, : audio_length + 1],
            prompt_ids.unsqueeze(0),
            text_layer[:, -2:],
        ),
        dim=1,
    )  # <bos> <audio> <prompt> <eos> <task>
    example_ids[code_layer] = text_layer

    input_length = audio_length
    example_mask = example_ids[0][0].ge(-1)  # all-True mask over the sequence
    example_ids = torch.stack(example_ids).squeeze()

    input_ids = example_ids.unsqueeze(0).to(device)
    attention_mask = example_mask.unsqueeze(0).to(device)
    audio_mel = audio_mel.unsqueeze(0).to(device)
    input_length = torch.tensor([input_length]).to(device)
    audio_length = torch.tensor([audio_length]).to(device)
    task_type = [task_type]

    modality_mask = torch.zeros_like(attention_mask)
    padding_left = 1  # +1 for <bos>
    modality_mask[0, padding_left : padding_left + audio_length] = True
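    # modality_mask presumably marks the positions whose embeddings are
    # replaced by the projected Whisper features; everything before (<bos>)
    # and after (prompt and control tokens) keeps its ordinary token
    # embedding.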
    batch = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "audio_mel": audio_mel,
        "input_length": input_length,
        "audio_length": audio_length,
        "modality_mask": modality_mask,
        "task_types": task_type,
    }
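
    # Assumption (inferred from the indexing below): model.generate returns
    # one generated token stream per codec layer plus a final text stream.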
    model_outputs = model.generate(**batch, **decode_config)
    text_outputs = model_outputs[code_layer]
    audio_outputs = model_outputs[:code_layer]

    output_text = model.tokenizer.decode(text_outputs, skip_special_tokens=True)

    if decode_config.decode_text_only:
        return None, output_text

    audio_tokens = [audio_outputs[layer] for layer in range(code_layer)]
    audiolist = reconscruct_snac(audio_tokens)
    audio = reconstruct_tensors(audiolist)
    with torch.inference_mode():
        audio_hat = codec_decoder.decode(audio)
    return audio_hat, output_text
# Lazily-initialized globals so the model is loaded once per process.
model = None
codec_decoder = None
device = None
def generate(
    wav_path: str, progress_callback: Callable[[str], None] | None = None
) -> tuple[np.ndarray, int | float]:
    global model, codec_decoder, device
    config = OmegaConf.structured(InferenceConfig())
    train_config, model_config, dataset_config, decode_config = (
        config.train_config,
        config.model_config,
        config.dataset_config,
        config.decode_config,
    )

    torch.cuda.manual_seed(train_config.seed)
    torch.manual_seed(train_config.seed)
    random.seed(train_config.seed)

    if model is None or codec_decoder is None or device is None:
        update_progress(progress_callback, "Loading model")
        model_factory = get_custom_model_factory(model_config)
        model, _ = model_factory(train_config, model_config, CKPT_PATH)
        codec_decoder = model.codec_decoder
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)
        model.eval()

    update_progress(progress_callback, "Generating")
    output_wav, output_text = generate_from_wav(
        wav_path, model, codec_decoder, dataset_config, decode_config, device
    )
    if output_wav is None:
        raise RuntimeError("decode_text_only is set, so no audio was generated")
    return output_wav.squeeze().cpu().numpy(), 24000  # codec sample rate
if __name__ == "__main__":
    wav_path = "sample.wav"
    generate(wav_path)
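
# A minimal usage sketch for saving the result to disk (assumption: the
# `soundfile` package is installed; it is not a dependency of this module):
#
#   import soundfile as sf
#   wav, sr = generate("sample.wav")
#   sf.write("response.wav", wav, sr)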