# -------------------------------------------------------- # SenseTime # Copyright (c) 2025 SenseTime # Licensed under The MIT License [see LICENSE for details] # -------------------------------------------------------- import warnings from typing import Any, List, Optional, Tuple, Union import re import json import math import librosa import numpy as np from PIL import Image from decord import VideoReader, cpu from torch import nn import torch import torchvision.transforms as T from torchvision.transforms.functional import InterpolationMode from transformers import (GenerationConfig, Qwen3ForCausalLM, WhisperFeatureExtractor) from transformers.modeling_utils import PreTrainedModel import onnxruntime import torchaudio.compliance.kaldi as kaldi import torchaudio from transformers.utils.hub import cached_file from .configuration_interactiveomni import InteractiveOmniConfig from .modeling_intern_vit import InternVisionModel from .modeling_whisper import AudioWhisperModel from .modeling_voicelm import VoiceLM from .conversation import get_conv_template from .modeling_flow import CausalMaskedDiffWithXvec from .modeling_hifigan import HiFTGenerator import logging logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') logger = logging.getLogger(__name__) IMAGENET_MEAN = (0.485, 0.456, 0.406) IMAGENET_STD = (0.229, 0.224, 0.225) IMG_START_TOKEN = '' IMG_END_TOKEN = '' IMG_CONTEXT_TOKEN = '' AUDIO_START_TOKEN = '' AUDIO_CONTEXT_TOKEN = '' class InteractiveOmniModel(PreTrainedModel): config_class = InteractiveOmniConfig main_input_name = 'pixel_values' base_model_prefix = 'language_model' _no_split_modules = ['InternVisionModel', 'AudioWhisperModel', 'Qwen3DecoderLayer', 'Qwen2DecoderLayer'] def __init__(self, config: InteractiveOmniConfig, vision_model=None, language_model=None, audio_model=None): super().__init__(config) image_size = config.force_image_size or config.vision_config.image_size patch_size = config.vision_config.patch_size self.patch_size = patch_size self.select_layer = config.select_layer self.template = config.template self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio ** 2)) self.downsample_ratio = config.downsample_ratio self.ps_version = config.ps_version self.audio_feature_extractor = WhisperFeatureExtractor(**config.audio_preprocessor_config) self.transform = self.build_transform(input_size=image_size) self.campplus_session = None self.default_speaker_embedding = None self.default_wav_path = None logger.info(f'num_image_token: {self.num_image_token}') logger.info(f'ps_version: {self.ps_version}') if vision_model is not None: self.vision_model = vision_model else: self.vision_model = InternVisionModel(config.vision_config) if audio_model is not None: self.audio_model = audio_model else: self.audio_model = AudioWhisperModel(config.audio_config) if language_model is not None: self.language_model = language_model else: self.language_model = Qwen3ForCausalLM(config.llm_config) self.voicelm_model = VoiceLM(config.voicelm_config) self.flow_model = CausalMaskedDiffWithXvec(config.flow_config).float() self.hifigan_model = HiFTGenerator(config.hifigan_config).float() vit_hidden_size = config.vision_config.hidden_size audio_hidden_size = config.audio_config.d_model llm_hidden_size = config.llm_config.hidden_size self.mlp1 = nn.Sequential( nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2), nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size), nn.GELU(), nn.Linear(llm_hidden_size, llm_hidden_size) ) self.mlp2 = nn.Sequential( nn.LayerNorm(audio_hidden_size), nn.Linear(audio_hidden_size, llm_hidden_size), nn.GELU(), nn.Linear(llm_hidden_size, llm_hidden_size) ) self.mlp_llm2voicelm = nn.Sequential( nn.LayerNorm(llm_hidden_size), nn.Linear(llm_hidden_size, config.voicelm_config.llm_input_size), nn.GELU(), nn.Linear(config.voicelm_config.llm_input_size, config.voicelm_config.llm_input_size) ) self.gate = nn.Sequential( nn.Linear(2 * llm_hidden_size, llm_hidden_size), nn.Sigmoid() ) self.img_context_token_id = None self.audio_context_token_id = None self.neftune_alpha = None self.post_init() pass def fusion(self, rep, emb): gate = self.gate(torch.cat([rep, emb], dim=-1)) return rep * gate + emb * (1 - gate) def __load_campplus_session(self, campplus_path:str): '''''' logger.info(f"load campplus session: {campplus_path}") option = onnxruntime.SessionOptions() option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL option.intra_op_num_threads = 1 campplus_session = onnxruntime.InferenceSession( campplus_path, sess_options=option, providers=["CPUExecutionProvider"], ) self.campplus_session = campplus_session return campplus_session def extract_speaker_embedding(self, prompt_wav:str): '''extract speaker embedding tensor''' logger.info(f"extract speaker embedding: {prompt_wav}") target_sr = 16000 prompt_speech_16k, sample_rate = torchaudio.load(prompt_wav) prompt_speech_16k = prompt_speech_16k.mean(dim=0, keepdim=True) if sample_rate != target_sr: assert sample_rate > target_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr) prompt_speech_16k = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(prompt_speech_16k) feat = kaldi.fbank( prompt_speech_16k, num_mel_bins=80, dither=0, sample_frequency=target_sr, ) feat = feat - feat.mean(dim=0, keepdim=True) speaker_embedding = self.campplus_session.run( None, {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()}, )[0].flatten().tolist() speaker_embedding = torch.tensor([speaker_embedding]) return speaker_embedding def build_transform(self, input_size): MEAN, STD = IMAGENET_MEAN, IMAGENET_STD transform = T.Compose([ T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC), T.ToTensor(), T.Normalize(mean=MEAN, std=STD) ]) return transform def find_closest_aspect_ratio(self, image, min_num=1, max_num=6, image_size=448): assert min_num == 1 original_width, original_height = image.size log_ratio = math.log(original_width / original_height) ratio = original_width * original_height / (image_size * image_size) multiple = min(math.ceil(ratio), max_num) if multiple <= 1: return [1, 1] candidate_split_grids_nums = [] for i in [multiple - 1, multiple, multiple + 1]: if i > max_num: continue candidate_split_grids_nums.append(i) candidate_grids = [] for split_grids_nums in candidate_split_grids_nums: m = 1 while m <= split_grids_nums: if split_grids_nums % m == 0: candidate_grids.append([m, split_grids_nums // m]) m += 1 best_grid = [1, 1] min_error = float("inf") for grid in candidate_grids: error = abs(log_ratio - math.log(grid[0] / grid[1])) if error < min_error: best_grid = grid min_error = error return best_grid def dynamic_preprocess(self, image, min_num=1, max_num=12, image_size=448, use_thumbnail=False): target_aspect_ratio = self.find_closest_aspect_ratio(image, min_num, max_num, image_size) target_width = image_size * target_aspect_ratio[0] target_height = image_size * target_aspect_ratio[1] blocks = target_aspect_ratio[0] * target_aspect_ratio[1] # resize the image resized_img = image.resize((target_width, target_height)) processed_images = [] for i in range(blocks): box = ( (i % (target_width // image_size)) * image_size, (i // (target_width // image_size)) * image_size, ((i % (target_width // image_size)) + 1) * image_size, ((i // (target_width // image_size)) + 1) * image_size ) # split the image split_img = resized_img.crop(box) processed_images.append(split_img) assert len(processed_images) == blocks if use_thumbnail and len(processed_images) != 1: thumbnail_img = image.resize((image_size, image_size)) processed_images.append(thumbnail_img) return processed_images def load_image(self, image, input_size=448, max_num=12): if not isinstance(image, Image.Image): image = Image.open(image).convert('RGB') images = self.dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num) return images def pixel_shuffle(self, x, scale_factor=0.5): n, w, h, c = x.size() # N, W, H, C --> N, W, H * scale, C // scale x = x.view(n, w, int(h * scale_factor), int(c / scale_factor)) # N, W, H * scale, C // scale --> N, H * scale, W, C // scale x = x.permute(0, 2, 1, 3).contiguous() # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2) x = x.view(n, int(h * scale_factor), int(w * scale_factor), int(c / (scale_factor * scale_factor))) if self.ps_version == 'v1': warnings.warn("In ps_version 'v1', the height and width have not been swapped back, " 'which results in a transposed image.') else: x = x.permute(0, 2, 1, 3).contiguous() return x def extract_feature(self, pixel_values): if self.select_layer == -1: vit_embeds = self.vision_model( pixel_values=pixel_values, output_hidden_states=False, return_dict=True).last_hidden_state else: vit_embeds = self.vision_model( pixel_values=pixel_values, output_hidden_states=True, return_dict=True).hidden_states[self.select_layer] vit_embeds = vit_embeds[:, 1:, :] if self.training and self.neftune_alpha is not None: vit_embeds = self.noised_embed(vit_embeds, self.neftune_alpha) h = w = int(vit_embeds.shape[1] ** 0.5) vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1) vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio) vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1]) vit_embeds = self.mlp1(vit_embeds)#.to(pixel_values.device) return vit_embeds def get_T_after_cnn(self, L_in, dilation=1): for (padding, kernel_size, stride) in eval("[(1,3,1)] + [(1,3,2)] "): L_out = L_in + 2 * padding - dilation * (kernel_size - 1) - 1 L_out = 1 + L_out // stride L_in = L_out return L_out def process_audio(self, audio, return_tensors, sampling_rate=16000): L = (audio.shape[0] if audio.shape[0] <= 480000 else 480000) # max_length < 30s mel_len = L // 160 audio_len_after_cnn = self.get_T_after_cnn(mel_len) audio_token_num = (audio_len_after_cnn - 2) // 2 + 1 inputs = self.audio_feature_extractor(audio, return_tensors=return_tensors, sampling_rate=sampling_rate) inputs['audio_len_after_cnn'] = torch.tensor(audio_len_after_cnn, dtype=torch.long) inputs['audio_token_num'] = torch.tensor(audio_token_num, dtype=torch.long) return inputs def load_audio(self, audio_file, sampling_rate=16000): audio_values, _ = librosa.load(audio_file, sr=sampling_rate) # sample rate should be 16000 audio_process_values = self.process_audio(audio_values, sampling_rate=sampling_rate, return_tensors="pt") input_features = audio_process_values['input_features'] audio_len_after_cnn = audio_process_values['audio_len_after_cnn'] audio_token_num = audio_process_values['audio_token_num'] audio_input_dict = {'audio_values': input_features, 'audio_len_after_cnn': audio_len_after_cnn, 'audio_token_num': audio_token_num, } return audio_input_dict def extract_audio_feature(self, audio_values, audio_len_after_cnn): audio_values = audio_values.squeeze(1) max_len_in_batch = int(torch.max(audio_len_after_cnn).item()) padding_mask = torch.ones([audio_values.size(0), max_len_in_batch]).to(dtype=audio_values.dtype, device=audio_values.device) for index in range(len(audio_values)): padding_mask[index, :int(audio_len_after_cnn[index].item())] = 0 last_hidden_state = self.audio_model(audio_values, padding_mask, audio_len_after_cnn) # (bs, max_token_num, 1280) audio_embeds = self.mlp2(last_hidden_state) return audio_embeds def get_index(self, bound, fps, max_frame, first_idx=0, num_segments=32): if bound: start, end = bound[0], bound[1] else: start, end = -100000, 100000 start_idx = max(first_idx, round(start * fps)) end_idx = min(round(end * fps), max_frame) seg_size = float(end_idx - start_idx) / num_segments frame_indices = np.array([ int(start_idx + (seg_size / 2) + np.round(seg_size * idx)) for idx in range(num_segments) ]) return frame_indices def load_video(self, video_path, bound=None, num_segments=32): vr = VideoReader(video_path, ctx=cpu(0), num_threads=1) max_frame = len(vr) - 1 fps = float(vr.get_avg_fps()) frame_indices = self.get_index(bound, fps, max_frame, first_idx=0, num_segments=num_segments) frames = list() for frame_index in frame_indices: img = Image.fromarray(vr[frame_index].asnumpy()).convert('RGB') frames.append(img) return frames def find_second_last_occurrence(self, input_ids_list, target_id): '''find taget_id index''' reversed_list = list(reversed(input_ids_list)) first_occurrence = -1 second_occurrence = -1 for idx, val in enumerate(reversed_list): if val == target_id: if first_occurrence == -1: first_occurrence = idx # first index elif second_occurrence == -1: second_occurrence = idx # second index break if second_occurrence == -1: return -1 return len(input_ids_list) - second_occurrence - 1 def decode_speech_tokens( self, speech_tokens, speaker_embedding=None, flow_prompt_speech_token=None, prompt_speech_feat=None, finalize=True, token_offset=0, ): if speaker_embedding is None: speaker_embedding = torch.zeros(1, 192) pass if flow_prompt_speech_token is None: flow_prompt_speech_token = torch.zeros(1, 0, dtype=torch.int32) pass if prompt_speech_feat is None: prompt_speech_feat = torch.zeros(1, 0, 80) pass self.flow_model.encoder.static_chunk_size = 2 * self.flow_model.input_frame_rate # 50 self.flow_model.decoder.estimator.static_chunk_size = 2 * self.flow_model.input_frame_rate * self.flow_model.token_mel_ratio # 100 device = speech_tokens.device tts_mel, _ = self.flow_model.inference( token=speech_tokens.to(device), token_len=torch.tensor([speech_tokens.shape[1]], dtype=torch.int32).to(device), prompt_token=flow_prompt_speech_token.to(device), prompt_token_len=torch.tensor([flow_prompt_speech_token.shape[1]], dtype=torch.int32).to(device), prompt_feat=prompt_speech_feat.to(device), prompt_feat_len=torch.tensor([prompt_speech_feat.shape[1]], dtype=torch.int32).to(device), embedding=speaker_embedding.to(device), finalize=finalize, ) tts_mel = tts_mel[:, :, token_offset * self.config.flow_config.token_mel_ratio:] hift_cache_source = torch.zeros(1, 1, 0) tts_speech, tts_source = self.hifigan_model.inference(speech_feat=tts_mel, cache_source=hift_cache_source) # [1, sampling point num] return tts_speech @torch.no_grad() def generate( self, pixel_values: torch.FloatTensor, input_ids: torch.FloatTensor, attention_mask: torch.LongTensor, visual_features: Optional[torch.FloatTensor] = None, audio_values: Optional[torch.FloatTensor] = None, audio_len_after_cnn: Optional[bool] = None, audio_token_num: Optional[bool] = None, generation_config: Optional[GenerationConfig] = None, output_hidden_states: Optional[bool] = None, start_token_id:int = 151644, generate_audio:bool = False, speaker_embedding:torch.Tensor = torch.zeros(1, 192), mix_ratio:list=[5,25], **generate_kwargs, ) -> torch.LongTensor: assert self.img_context_token_id is not None assert self.audio_context_token_id is not None vit_embeds = None if visual_features is not None: vit_embeds = visual_features elif pixel_values is not None: vit_embeds = self.extract_feature(pixel_values) cur_conv_start_id = self.find_second_last_occurrence(input_ids.tolist()[0], start_token_id) input_embeds = self.language_model.get_input_embeddings()(input_ids) B, N, C = input_embeds.shape input_embeds = input_embeds.reshape(B * N, C) input_ids = input_ids.reshape(B * N) if vit_embeds is not None: selected = (input_ids == self.img_context_token_id) input_embeds[selected] = vit_embeds.reshape(-1, C) if audio_values is not None and audio_len_after_cnn is not None and audio_token_num is not None: audio_embeds = self.extract_audio_feature(audio_values, audio_len_after_cnn) output_audios = [] for i in range(len(audio_token_num)): token_num = int(audio_token_num[i].item()) audio = audio_embeds[i][:token_num] output_audios.append(audio) output_audios = torch.cat(output_audios, dim=0) selected = (input_ids == self.audio_context_token_id) input_embeds[selected] = output_audios.reshape(-1, C) input_embeds = input_embeds.reshape(B, N, C) outputs = self.language_model.generate( inputs_embeds=input_embeds, attention_mask=attention_mask, generation_config=generation_config, output_hidden_states=output_hidden_states or generate_audio, return_dict_in_generate=generate_audio, use_cache=True, **generate_kwargs, ) if not generate_audio: return outputs, None, None hidden_states = torch.cat( [outputs.hidden_states[0][-1][:, -1:, :]] + [outputs.hidden_states[i][-1] for i in range(1, len(outputs.hidden_states))], dim=1, ) sampled_token = outputs.sequences if sampled_token.shape[1] == hidden_states.shape[1] + 1: sampled_token = sampled_token[:, 1:] sampled_token_embeddings = self.language_model.get_input_embeddings()(sampled_token) target_text_token_hidden_states = self.fusion(hidden_states, sampled_token_embeddings) input_token_hidden_states = outputs.hidden_states[0][-1][:, cur_conv_start_id:-1, :] question_input_embeddings = input_embeds[:, cur_conv_start_id+1:, :] input_token_hidden_states = self.fusion(input_token_hidden_states, question_input_embeddings) input_feature = self.mlp_llm2voicelm(input_token_hidden_states) target_text_feature = self.mlp_llm2voicelm(target_text_token_hidden_states) # try: speech_tokens = self.voicelm_model.inference_bistream(input_feature, target_text_feature, mix_ratio=mix_ratio) speech_tokens = torch.LongTensor([speech_tokens]).to(input_feature.device) tts_speech = self.decode_speech_tokens( speech_tokens, speaker_embedding=speaker_embedding, ) except Exception as e: logger.warning(f"=========voice lm except:{e}") return outputs.sequences,None, None return outputs.sequences, speech_tokens, tts_speech def chat( self, tokenizer, generation_config, messages, max_patch_num=12, frame=8, generate_audio=False, speaker_embedding=torch.zeros(1, 192), print_flag=True, ): if self.flow_model.dtype != torch.float32 or self.hifigan_model.dtype != torch.float32: logger.info(f"reset flow model and higigan model dtype to float32") self.reset_vocoder() pass if messages is None or len(messages) == 0: raise RuntimeError('no messages') role_transfer_dict = { 'system': ['user'], 'user': ['assistant'], 'assistant': ['user'], } first_role = ['system', 'user'] last_role = ['user'] if messages[-1]['role'] not in last_role: raise RuntimeError(f"last role error, expect {last_role}, but got {messages[-1]}") current_role = None dynamic_images = list() dynamic_nums = list() audio_values = list() audio_len_after_cnn = list() audio_token_num = list() template = get_conv_template(self.template) for index in range(len(messages)): text = '' audios = list() images = list() message = messages[index] if index == 0: if message['role'] not in first_role: raise RuntimeError(f'first role error expect {first_role}, but got {message}') else: if message['role'] not in current_role: raise RuntimeError(f'role error expect {current_role}, but got {message}') current_role = message['role'] if isinstance(message["content"], list): for item in message["content"]: if item['type'] == 'text': if item.get('text', None) is None: continue text += item['text'] elif item['type'] == 'audio': if item.get('audio', None) is None: continue if type(item['audio']) is list: assert len(item['audio']) == 1, f'only support 1 audio file in round, but got {item["audio"]}' audio = item['audio'][0] else: audio = item['audio'] audios.append(audio) elif item['type'] == 'image': if item.get('image', None) is None: continue if type(item['image']) is not list: images.append(item['image']) else: images.extend(item['image']) elif item['type'] == 'video': if item.get('video', None) is None: continue if type(item['video']) is list: assert len(item['video']) == 1, f'only support 1 video file in round, but got {item["video"]}' video = item['video'][0] else: video = item['video'] frames = self.load_video(video, num_segments=frame) images.extend(frames) else: assert isinstance(message["content"], str), message["content"] text = message["content"] if len(audios) != 0: assert len(audios) == 1, f'only support 1 audio file in round, but got {audios}' if '