import logging
import os
import uuid

import torch
import torchaudio
from transformers import WhisperFeatureExtractor

from speech_tokenizer.modeling_whisper import WhisperVQEncoder
from flow_inference import AudioDecoder
from funasr.utils.load_utils import extract_fbank
from funasr.models.sense_voice.model import SenseVoiceSmall

from .constants import (
    AUD_CONTEXT_TOKEN,
    AUD_END_TOKEN,
    AUD_START_TOKEN,
    AUD_TAG_TOKEN,
    BOX_END_TOKEN,
    BOX_START_TOKEN,
    IMG_CONTEXT_TOKEN,
    IMG_END_TOKEN,
    IMG_START_TOKEN,
    IMG_TAG_TOKEN,
    PATCH_CONTEXT_TOKEN,
    PATCH_END_TOKEN,
    PATCH_START_TOKEN,
    QUAD_END_TOKEN,
    QUAD_START_TOKEN,
    REF_END_TOKEN,
    REF_START_TOKEN,
    VID_CONTEXT_TOKEN,
    VID_END_TOKEN,
    VID_START_TOKEN,
    VID_TAG_TOKEN,
)

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Resample transforms cached by source sample rate, shared by
# SenseVoiceGLM4VoiceTokenizer.encode() and extract_speech_token().
_resample_buffer: dict[int, torchaudio.transforms.Resample] = {}


def update_tokenizer_for_sensevoice_glm4voice(tokenizer):
    """Register the multimodal marker tokens and the 16384 discrete audio tokens."""
    token_list = [
        IMG_START_TOKEN,
        IMG_END_TOKEN,
        IMG_CONTEXT_TOKEN,
        VID_START_TOKEN,
        VID_END_TOKEN,
        VID_CONTEXT_TOKEN,
        PATCH_START_TOKEN,
        PATCH_END_TOKEN,
        PATCH_CONTEXT_TOKEN,
        AUD_START_TOKEN,
        AUD_END_TOKEN,
        AUD_CONTEXT_TOKEN,
        QUAD_START_TOKEN,
        QUAD_END_TOKEN,
        REF_START_TOKEN,
        REF_END_TOKEN,
        BOX_START_TOKEN,
        BOX_END_TOKEN,
        IMG_TAG_TOKEN,
        VID_TAG_TOKEN,
        AUD_TAG_TOKEN,
    ]
    tokenizer.add_tokens(token_list, special_tokens=True)

    # One vocabulary entry per GLM-4-Voice codebook index.
    token_list = [f"<|audio_{i}|>" for i in range(16384)]
    tokenizer.add_tokens(token_list, special_tokens=False)

    return tokenizer


class SenseVoiceGLM4VoiceTokenizer:
    """Audio front end that pairs SenseVoiceSmall fbank features (contiguous input)
    with the GLM-4-Voice whisper-VQ encoder (discrete tokens) and its flow/HiFT decoder."""

    def __init__(self, model_name_or_path, flow_path=None, rank=None):
        self.model_name_or_path = model_name_or_path
        self.flow_path = flow_path

        if rank is None and torch.distributed.is_initialized():
            rank = torch.distributed.get_rank()
            self.rank = rank % 8
        else:
            self.rank = rank
        logger.info(f"{self.rank=}")

        self.sample_rate = 16000

        self.is_discrete = True
        self.is_contiguous = True

        # # T A
        # text_audio_interval_ratio = [13, 26]
        # # T A T A T A
        # text_audio_interval_ratio = [1, 4, 3, 8, 4, 10]
        # # T A T A
        # text_audio_interval_ratio = [1, 10, 4, 10]
        # self.text_audio_interval_ratio = text_audio_interval_ratio

    def load_model(self):
        """Lazily load SenseVoiceSmall, the whisper-VQ encoder and, if configured, the flow/HiFT decoder."""
        if hasattr(self, "whisper_model"):
            return

        import faulthandler

        faulthandler.enable()

        if self.rank is not None:
            self.device = f"cuda:{self.rank}"
        else:
            self.device = "cpu"

        logger.info(f"{self.device=} Loading SenseVoiceSmall")
        from huggingface_hub import snapshot_download

        model_dir = snapshot_download(repo_id="FunAudioLLM/SenseVoiceSmall")
        _, self.kwargs = SenseVoiceSmall.from_pretrained(model=model_dir, device=self.device)
        logger.info(f"{self.device=} Loading SenseVoiceSmall Done")

        logger.info(f"{self.device=} Loading GLM4VoiceTokenizer")
        self.whisper_model = (
            WhisperVQEncoder.from_pretrained(self.model_name_or_path).eval().to(self.device)
        )
        self.feature_extractor = WhisperFeatureExtractor.from_pretrained(self.model_name_or_path)

        if self.flow_path is not None:
            flow_config = os.path.join(self.flow_path, "config.yaml")
            flow_checkpoint = os.path.join(self.flow_path, "flow.pt")
            hift_checkpoint = os.path.join(self.flow_path, "hift.pt")

            # Flow & HiFT vocoder used by decode().
            self.audio_decoder = AudioDecoder(
                config_path=flow_config,
                flow_ckpt_path=flow_checkpoint,
                hift_ckpt_path=hift_checkpoint,
                device=self.device,
            )
        logger.info(f"{self.device=} Loading GLM4VoiceTokenizer Done")

    def encode(self, audio_path, is_discrete=False, is_contiguous=True, **kwargs):
        """Encode an audio file either into discrete GLM-4-Voice tokens or into
        contiguous SenseVoice fbank features; exactly one mode must be selected."""
        if not hasattr(self, "whisper_model"):
            self.load_model()

        assert not (is_discrete and is_contiguous)
        assert is_discrete or is_contiguous

        if is_discrete:
            audio_tokens = extract_speech_token(
                self.whisper_model, self.feature_extractor, [audio_path], device=self.device
            )[0]
            return audio_tokens

        if is_contiguous:
            audio, sample_rate = torchaudio.load(audio_path)
            audio = audio.mean(0)  # downmix to mono
            if sample_rate != self.sample_rate:
                if sample_rate not in _resample_buffer:
                    _resample_buffer[sample_rate] = torchaudio.transforms.Resample(
                        orig_freq=sample_rate, new_freq=self.sample_rate
                    ).to(self.device)
                audio = audio.to(self.device)
                audio = _resample_buffer[sample_rate](audio[None, :])[0, :]
                audio = audio.cpu()

            frontend = self.kwargs["frontend"]
            speech, speech_lengths = extract_fbank(audio, data_type="sound", frontend=frontend)
            speech = speech[0]
            return speech

    def decode(self, audio_tokens, option_steps=10, **kwargs):
        """Synthesize a waveform from discrete audio tokens with the flow/HiFT decoder."""
        if not hasattr(self, "whisper_model"):
            self.load_model()

        this_uuid = str(uuid.uuid4())

        tts_token = torch.tensor(audio_tokens, device=self.device).unsqueeze(0)
        flow_prompt_speech_token = torch.zeros(1, 0, dtype=torch.int64).to(self.device)
        prompt_speech_feat = torch.zeros(1, 0, 80).to(self.device)

        tts_speech, tts_mel = self.audio_decoder.token2wav(
            tts_token,
            uuid=this_uuid,
            prompt_token=flow_prompt_speech_token.to(self.device),
            prompt_feat=prompt_speech_feat.to(self.device),
            finalize=True,
            option_steps=option_steps,
        )

        return tts_speech.squeeze().cpu()

    def apply_to_role(self, role, **kwargs):
        """Discrete tokens apply to assistant/gpt turns, contiguous features to user/human turns."""
        is_discrete = kwargs.get("is_discrete", False)
        if is_discrete and role in ["assistant", "gpt"]:
            return True

        is_contiguous = kwargs.get("is_contiguous", False)
        if is_contiguous and role in ["user", "human"]:
            return True

        return False


def extract_speech_token(model, feature_extractor, utts, device="cuda"):
    """Split each utterance into 30-second segments, run the whisper-VQ encoder in
    batches, and return one list of discrete speech tokens per utterance."""
    with torch.no_grad():
        audios, indices = [], []
        for idx, utt in enumerate(utts):
            if isinstance(utt, tuple):
                audio, sample_rate = utt
            else:
                audio, sample_rate = torchaudio.load(utt)
            audio = audio.to(device)
            if sample_rate != 16000:
                if sample_rate not in _resample_buffer:
                    _resample_buffer[sample_rate] = torchaudio.transforms.Resample(
                        orig_freq=sample_rate, new_freq=16000
                    ).to(device)
                audio = _resample_buffer[sample_rate](audio)
            audio = audio[0]  # keep the first channel
            audio = audio.cpu().numpy()

            # Chunk into non-overlapping 30 s windows, remembering which utterance each came from.
            time_step = 0
            while time_step * 16000 < audio.shape[0]:
                audio_segment = audio[time_step * 16000 : (time_step + 30) * 16000]
                audios.append(audio_segment)
                indices.append(idx)
                time_step += 30

        # Total downsampling factor from raw samples to quantized tokens.
        pooling_kernel_size = model.config.pooling_kernel_size or 1
        stride = (
            model.conv1.stride[0]
            * model.conv2.stride[0]
            * pooling_kernel_size
            * feature_extractor.hop_length
        )

        all_speech_tokens = [[] for _ in range(len(utts))]
        batch_size = 128
        for start in range(0, len(audios), batch_size):
            features = feature_extractor(
                audios[start : start + batch_size],
                sampling_rate=16000,
                return_attention_mask=True,
                return_tensors="pt",
                device=device,
                padding="longest",
                pad_to_multiple_of=stride,
            )
            features = features.to(device=device)
            outputs = model(**features)
            speech_tokens = outputs.quantized_token_ids

            # Downsample the attention mask to token resolution and keep only valid positions.
            attention_mask = features.attention_mask[
                :, :: model.conv1.stride[0] * model.conv2.stride[0]
            ]
            attention_mask = attention_mask[:, :: model.config.pooling_kernel_size]
            assert attention_mask.shape == speech_tokens.shape

            for i in range(len(speech_tokens)):
                idx = indices[start + i]
                speech_token = speech_tokens[i][attention_mask[i].bool()].tolist()
                all_speech_tokens[idx].extend(speech_token)

    return all_speech_tokens
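

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It illustrates the
# intended round trip under assumed inputs: "example.wav" and the two checkpoint
# directories below are placeholders that must exist on disk. Because of the
# relative ".constants" import above, run this with `python -m <package>.<module>`
# rather than as a bare script.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    tok = SenseVoiceGLM4VoiceTokenizer(
        model_name_or_path="/path/to/glm-4-voice-tokenizer",  # placeholder checkpoint dir
        flow_path="/path/to/glm-4-voice-decoder",  # placeholder; only needed for decode()
        rank=0,  # selects cuda:0; pass rank=None to stay on CPU
    )

    # Discrete GLM-4-Voice tokens (the representation used for assistant/gpt turns).
    audio_tokens = tok.encode("example.wav", is_discrete=True, is_contiguous=False)
    print(f"discrete tokens: {len(audio_tokens)}")

    # Contiguous SenseVoice fbank features (the representation used for user/human turns).
    fbank = tok.encode("example.wav", is_discrete=False, is_contiguous=True)
    print(f"fbank features: {tuple(fbank.size())}")

    # Re-synthesize speech from the discrete tokens; 22050 Hz is the rate the upstream
    # GLM-4-Voice demo uses when saving decoder output (assumption, not from this module).
    wav = tok.decode(audio_tokens, option_steps=10)
    torchaudio.save("resynthesized.wav", wav.unsqueeze(0), 22050)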