# --------------------------------------------------------
# SenseTime
# Copyright (c) 2025 SenseTime
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import warnings
from typing import Any, List, Optional, Tuple, Union
import re
import json
import math
import librosa
import numpy as np
from PIL import Image
from decord import VideoReader, cpu
from torch import nn
import torch
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from transformers import (GenerationConfig, Qwen3ForCausalLM, WhisperFeatureExtractor)
from transformers.modeling_utils import PreTrainedModel
import onnxruntime
import torchaudio.compliance.kaldi as kaldi
import torchaudio
from transformers.utils.hub import cached_file
from .configuration_interactiveomni import InteractiveOmniConfig
from .modeling_intern_vit import InternVisionModel
from .modeling_whisper import AudioWhisperModel
from .modeling_voicelm import VoiceLM
from .conversation import get_conv_template
from .modeling_flow import CausalMaskedDiffWithXvec
from .modeling_hifigan import HiFTGenerator
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
IMG_START_TOKEN = '<img>'
IMG_END_TOKEN = '</img>'
IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
AUDIO_START_TOKEN = '<audio>'
AUDIO_CONTEXT_TOKEN = '<AUDIO_CONTEXT>'
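# The context tokens above are placeholders expanded in the prompt; at generation time the
# positions of IMG_CONTEXT_TOKEN / AUDIO_CONTEXT_TOKEN in input_ids are overwritten with the
# projected vision / audio embeddings (see generate()).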
class InteractiveOmniModel(PreTrainedModel):
config_class = InteractiveOmniConfig
main_input_name = 'pixel_values'
base_model_prefix = 'language_model'
_no_split_modules = ['InternVisionModel', 'AudioWhisperModel', 'Qwen3DecoderLayer', 'Qwen2DecoderLayer']
def __init__(self, config: InteractiveOmniConfig, vision_model=None, language_model=None, audio_model=None):
super().__init__(config)
image_size = config.force_image_size or config.vision_config.image_size
patch_size = config.vision_config.patch_size
self.patch_size = patch_size
self.select_layer = config.select_layer
self.template = config.template
self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio ** 2))
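        # Illustrative arithmetic with typical InternVL-style values (assumption, not read from this
        # checkpoint's config): image_size=448, patch_size=14, downsample_ratio=0.5 gives
        # (448 // 14) ** 2 * 0.5 ** 2 = 1024 * 0.25 = 256 visual tokens per image tile.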
self.downsample_ratio = config.downsample_ratio
self.ps_version = config.ps_version
self.audio_feature_extractor = WhisperFeatureExtractor(**config.audio_preprocessor_config)
self.transform = self.build_transform(input_size=image_size)
self.campplus_session = None
self.default_speaker_embedding = None
self.default_wav_path = None
logger.info(f'num_image_token: {self.num_image_token}')
logger.info(f'ps_version: {self.ps_version}')
if vision_model is not None:
self.vision_model = vision_model
else:
self.vision_model = InternVisionModel(config.vision_config)
if audio_model is not None:
self.audio_model = audio_model
else:
self.audio_model = AudioWhisperModel(config.audio_config)
if language_model is not None:
self.language_model = language_model
else:
self.language_model = Qwen3ForCausalLM(config.llm_config)
self.voicelm_model = VoiceLM(config.voicelm_config)
self.flow_model = CausalMaskedDiffWithXvec(config.flow_config).float()
self.hifigan_model = HiFTGenerator(config.hifigan_config).float()
vit_hidden_size = config.vision_config.hidden_size
audio_hidden_size = config.audio_config.d_model
llm_hidden_size = config.llm_config.hidden_size
self.mlp1 = nn.Sequential(
nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2),
nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size),
nn.GELU(),
nn.Linear(llm_hidden_size, llm_hidden_size)
)
self.mlp2 = nn.Sequential(
nn.LayerNorm(audio_hidden_size),
nn.Linear(audio_hidden_size, llm_hidden_size),
nn.GELU(),
nn.Linear(llm_hidden_size, llm_hidden_size)
)
self.mlp_llm2voicelm = nn.Sequential(
nn.LayerNorm(llm_hidden_size),
nn.Linear(llm_hidden_size, config.voicelm_config.llm_input_size),
nn.GELU(),
nn.Linear(config.voicelm_config.llm_input_size, config.voicelm_config.llm_input_size)
)
self.gate = nn.Sequential(
nn.Linear(2 * llm_hidden_size, llm_hidden_size),
nn.Sigmoid()
)
self.img_context_token_id = None
self.audio_context_token_id = None
self.neftune_alpha = None
self.post_init()
def fusion(self, rep, emb):
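        # Gated fusion of an LLM hidden state `rep` with a token embedding `emb`: a sigmoid gate is
        # computed from their concatenation and interpolates element-wise between the two,
        # output = rep * gate + emb * (1 - gate).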
gate = self.gate(torch.cat([rep, emb], dim=-1))
return rep * gate + emb * (1 - gate)
def __load_campplus_session(self, campplus_path:str):
        '''Load the CAM++ speaker-verification model as an ONNX Runtime CPU session.'''
logger.info(f"load campplus session: {campplus_path}")
option = onnxruntime.SessionOptions()
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
option.intra_op_num_threads = 1
campplus_session = onnxruntime.InferenceSession(
campplus_path,
sess_options=option,
providers=["CPUExecutionProvider"],
)
self.campplus_session = campplus_session
return campplus_session
def extract_speaker_embedding(self, prompt_wav:str):
'''extract speaker embedding tensor'''
logger.info(f"extract speaker embedding: {prompt_wav}")
target_sr = 16000
prompt_speech_16k, sample_rate = torchaudio.load(prompt_wav)
prompt_speech_16k = prompt_speech_16k.mean(dim=0, keepdim=True)
if sample_rate != target_sr:
assert sample_rate > target_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr)
prompt_speech_16k = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(prompt_speech_16k)
feat = kaldi.fbank(
prompt_speech_16k,
num_mel_bins=80,
dither=0,
sample_frequency=target_sr,
)
feat = feat - feat.mean(dim=0, keepdim=True)
speaker_embedding = self.campplus_session.run(
None,
{self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()},
)[0].flatten().tolist()
speaker_embedding = torch.tensor([speaker_embedding])
return speaker_embedding
def build_transform(self, input_size):
MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
transform = T.Compose([
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
T.ToTensor(),
T.Normalize(mean=MEAN, std=STD)
])
return transform
def find_closest_aspect_ratio(self, image, min_num=1, max_num=6, image_size=448):
assert min_num == 1
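        # Tiling heuristic: estimate how many image_size x image_size tiles are needed to cover the
        # image area, consider grids whose tile count is within one of that estimate (capped at
        # max_num), and pick the grid [cols, rows] whose aspect ratio is closest to the original
        # image's in log space.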
original_width, original_height = image.size
log_ratio = math.log(original_width / original_height)
ratio = original_width * original_height / (image_size * image_size)
multiple = min(math.ceil(ratio), max_num)
if multiple <= 1:
return [1, 1]
candidate_split_grids_nums = []
for i in [multiple - 1, multiple, multiple + 1]:
if i > max_num:
continue
candidate_split_grids_nums.append(i)
candidate_grids = []
for split_grids_nums in candidate_split_grids_nums:
m = 1
while m <= split_grids_nums:
if split_grids_nums % m == 0:
candidate_grids.append([m, split_grids_nums // m])
m += 1
best_grid = [1, 1]
min_error = float("inf")
for grid in candidate_grids:
error = abs(log_ratio - math.log(grid[0] / grid[1]))
if error < min_error:
best_grid = grid
min_error = error
return best_grid
def dynamic_preprocess(self, image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
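        # Resize the image to the grid chosen by find_closest_aspect_ratio and crop it into
        # image_size x image_size tiles; when use_thumbnail is set and the image was split into more
        # than one tile, a resized copy of the whole image is appended as a global thumbnail.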
target_aspect_ratio = self.find_closest_aspect_ratio(image, min_num, max_num, image_size)
target_width = image_size * target_aspect_ratio[0]
target_height = image_size * target_aspect_ratio[1]
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
# resize the image
resized_img = image.resize((target_width, target_height))
processed_images = []
for i in range(blocks):
box = (
(i % (target_width // image_size)) * image_size,
(i // (target_width // image_size)) * image_size,
((i % (target_width // image_size)) + 1) * image_size,
((i // (target_width // image_size)) + 1) * image_size
)
# split the image
split_img = resized_img.crop(box)
processed_images.append(split_img)
assert len(processed_images) == blocks
if use_thumbnail and len(processed_images) != 1:
thumbnail_img = image.resize((image_size, image_size))
processed_images.append(thumbnail_img)
return processed_images
def load_image(self, image, input_size=448, max_num=12):
if not isinstance(image, Image.Image):
image = Image.open(image).convert('RGB')
images = self.dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
return images
def pixel_shuffle(self, x, scale_factor=0.5):
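        # Illustrative shapes for the default scale_factor=0.5: a (N, 32, 32, C) grid of ViT tokens
        # becomes (N, 16, 16, 4C), i.e. 4x fewer visual tokens, each with a 4x wider feature.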
n, w, h, c = x.size()
# N, W, H, C --> N, W, H * scale, C // scale
x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
# N, W, H * scale, C // scale --> N, H * scale, W, C // scale
x = x.permute(0, 2, 1, 3).contiguous()
# N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
x = x.view(n, int(h * scale_factor), int(w * scale_factor),
int(c / (scale_factor * scale_factor)))
if self.ps_version == 'v1':
warnings.warn("In ps_version 'v1', the height and width have not been swapped back, "
'which results in a transposed image.')
else:
x = x.permute(0, 2, 1, 3).contiguous()
return x
def extract_feature(self, pixel_values):
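        # Per-tile vision features: take the selected ViT layer, drop the CLS token, reshape the
        # patch tokens into a square grid, pixel-shuffle so the token count shrinks by a factor of
        # 1 / downsample_ratio ** 2, then project to the LLM hidden size with mlp1.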
if self.select_layer == -1:
vit_embeds = self.vision_model(
pixel_values=pixel_values,
output_hidden_states=False,
return_dict=True).last_hidden_state
else:
vit_embeds = self.vision_model(
pixel_values=pixel_values,
output_hidden_states=True,
return_dict=True).hidden_states[self.select_layer]
vit_embeds = vit_embeds[:, 1:, :]
if self.training and self.neftune_alpha is not None:
vit_embeds = self.noised_embed(vit_embeds, self.neftune_alpha)
h = w = int(vit_embeds.shape[1] ** 0.5)
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
vit_embeds = self.mlp1(vit_embeds)#.to(pixel_values.device)
return vit_embeds
def get_T_after_cnn(self, L_in, dilation=1):
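        # Sequence length after the Whisper-style encoder front-end: two 1-D convolutions with
        # (padding, kernel, stride) of (1, 3, 1) and (1, 3, 2), so the mel-frame count is roughly halved.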
        for (padding, kernel_size, stride) in [(1, 3, 1), (1, 3, 2)]:
L_out = L_in + 2 * padding - dilation * (kernel_size - 1) - 1
L_out = 1 + L_out // stride
L_in = L_out
return L_out
def process_audio(self, audio, return_tensors, sampling_rate=16000):
        L = min(audio.shape[0], 480000)  # cap input at 30 s of 16 kHz audio (480000 samples)
mel_len = L // 160
audio_len_after_cnn = self.get_T_after_cnn(mel_len)
audio_token_num = (audio_len_after_cnn - 2) // 2 + 1
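        # Illustrative budget for 16 kHz input: a full 30 s clip gives 480000 / 160 = 3000 mel frames,
        # 1500 frames after the stride-2 conv, and (1500 - 2) // 2 + 1 = 750 audio tokens from the
        # formula above.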
inputs = self.audio_feature_extractor(audio, return_tensors=return_tensors, sampling_rate=sampling_rate)
inputs['audio_len_after_cnn'] = torch.tensor(audio_len_after_cnn, dtype=torch.long)
inputs['audio_token_num'] = torch.tensor(audio_token_num, dtype=torch.long)
return inputs
def load_audio(self, audio_file, sampling_rate=16000):
audio_values, _ = librosa.load(audio_file, sr=sampling_rate) # sample rate should be 16000
audio_process_values = self.process_audio(audio_values, sampling_rate=sampling_rate, return_tensors="pt")
input_features = audio_process_values['input_features']
audio_len_after_cnn = audio_process_values['audio_len_after_cnn']
audio_token_num = audio_process_values['audio_token_num']
audio_input_dict = {'audio_values': input_features,
'audio_len_after_cnn': audio_len_after_cnn,
'audio_token_num': audio_token_num,
}
return audio_input_dict
def extract_audio_feature(self, audio_values, audio_len_after_cnn):
audio_values = audio_values.squeeze(1)
max_len_in_batch = int(torch.max(audio_len_after_cnn).item())
padding_mask = torch.ones([audio_values.size(0), max_len_in_batch]).to(dtype=audio_values.dtype, device=audio_values.device)
for index in range(len(audio_values)):
padding_mask[index, :int(audio_len_after_cnn[index].item())] = 0
last_hidden_state = self.audio_model(audio_values, padding_mask, audio_len_after_cnn) # (bs, max_token_num, 1280)
audio_embeds = self.mlp2(last_hidden_state)
return audio_embeds
def get_index(self, bound, fps, max_frame, first_idx=0, num_segments=32):
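        # Uniform temporal sampling: split [start_idx, end_idx] into num_segments equal segments and
        # take roughly the center frame of each segment.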
if bound:
start, end = bound[0], bound[1]
else:
start, end = -100000, 100000
start_idx = max(first_idx, round(start * fps))
end_idx = min(round(end * fps), max_frame)
seg_size = float(end_idx - start_idx) / num_segments
frame_indices = np.array([
int(start_idx + (seg_size / 2) + np.round(seg_size * idx))
for idx in range(num_segments)
])
return frame_indices
def load_video(self, video_path, bound=None, num_segments=32):
vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
max_frame = len(vr) - 1
fps = float(vr.get_avg_fps())
frame_indices = self.get_index(bound, fps, max_frame, first_idx=0, num_segments=num_segments)
frames = list()
for frame_index in frame_indices:
img = Image.fromarray(vr[frame_index].asnumpy()).convert('RGB')
frames.append(img)
return frames
def find_second_last_occurrence(self, input_ids_list, target_id):
        '''Return the index of the second-to-last occurrence of target_id in input_ids_list, or -1 if it occurs fewer than twice.'''
reversed_list = list(reversed(input_ids_list))
first_occurrence = -1
second_occurrence = -1
for idx, val in enumerate(reversed_list):
if val == target_id:
if first_occurrence == -1:
first_occurrence = idx # first index
elif second_occurrence == -1:
second_occurrence = idx # second index
break
if second_occurrence == -1:
return -1
return len(input_ids_list) - second_occurrence - 1
def decode_speech_tokens(
self,
speech_tokens,
speaker_embedding=None,
flow_prompt_speech_token=None,
prompt_speech_feat=None,
finalize=True,
token_offset=0,
):
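        # Vocoding pipeline: discrete speech tokens -> flow model (CausalMaskedDiffWithXvec) predicts
        # an 80-dim mel spectrogram conditioned on the speaker embedding -> HiFi-GAN generator
        # (HiFTGenerator) renders the waveform; the first token_offset * token_mel_ratio mel frames
        # are dropped before vocoding.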
        if speaker_embedding is None:
            speaker_embedding = torch.zeros(1, 192)
        if flow_prompt_speech_token is None:
            flow_prompt_speech_token = torch.zeros(1, 0, dtype=torch.int32)
        if prompt_speech_feat is None:
            prompt_speech_feat = torch.zeros(1, 0, 80)
self.flow_model.encoder.static_chunk_size = 2 * self.flow_model.input_frame_rate # 50
self.flow_model.decoder.estimator.static_chunk_size = 2 * self.flow_model.input_frame_rate * self.flow_model.token_mel_ratio # 100
device = speech_tokens.device
tts_mel, _ = self.flow_model.inference(
token=speech_tokens.to(device),
token_len=torch.tensor([speech_tokens.shape[1]], dtype=torch.int32).to(device),
prompt_token=flow_prompt_speech_token.to(device),
prompt_token_len=torch.tensor([flow_prompt_speech_token.shape[1]], dtype=torch.int32).to(device),
prompt_feat=prompt_speech_feat.to(device),
prompt_feat_len=torch.tensor([prompt_speech_feat.shape[1]], dtype=torch.int32).to(device),
embedding=speaker_embedding.to(device),
finalize=finalize,
)
tts_mel = tts_mel[:, :, token_offset * self.config.flow_config.token_mel_ratio:]
hift_cache_source = torch.zeros(1, 1, 0)
tts_speech, tts_source = self.hifigan_model.inference(speech_feat=tts_mel, cache_source=hift_cache_source) # [1, sampling point num]
return tts_speech
@torch.no_grad()
def generate(
self,
pixel_values: torch.FloatTensor,
input_ids: torch.FloatTensor,
attention_mask: torch.LongTensor,
visual_features: Optional[torch.FloatTensor] = None,
audio_values: Optional[torch.FloatTensor] = None,
        audio_len_after_cnn: Optional[torch.Tensor] = None,
        audio_token_num: Optional[torch.Tensor] = None,
generation_config: Optional[GenerationConfig] = None,
output_hidden_states: Optional[bool] = None,
start_token_id:int = 151644,
generate_audio:bool = False,
speaker_embedding:torch.Tensor = torch.zeros(1, 192),
mix_ratio:list=[5,25],
**generate_kwargs,
) -> torch.LongTensor:
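        # Outline:
        #   1. Encode images (extract_feature) and audio (extract_audio_feature) and write the
        #      resulting embeddings into the positions of the IMG_CONTEXT / AUDIO_CONTEXT tokens.
        #   2. Run language_model.generate on the fused input embeddings to produce the text reply.
        #   3. If generate_audio is set, fuse the LLM hidden states with the token embeddings
        #      (self.fusion), project them via mlp_llm2voicelm, let the VoiceLM emit discrete speech
        #      tokens, and render a waveform with decode_speech_tokens.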
assert self.img_context_token_id is not None
assert self.audio_context_token_id is not None
vit_embeds = None
if visual_features is not None:
vit_embeds = visual_features
elif pixel_values is not None:
vit_embeds = self.extract_feature(pixel_values)
cur_conv_start_id = self.find_second_last_occurrence(input_ids.tolist()[0], start_token_id)
input_embeds = self.language_model.get_input_embeddings()(input_ids)
B, N, C = input_embeds.shape
input_embeds = input_embeds.reshape(B * N, C)
input_ids = input_ids.reshape(B * N)
if vit_embeds is not None:
selected = (input_ids == self.img_context_token_id)
input_embeds[selected] = vit_embeds.reshape(-1, C)
if audio_values is not None and audio_len_after_cnn is not None and audio_token_num is not None:
audio_embeds = self.extract_audio_feature(audio_values, audio_len_after_cnn)
output_audios = []
for i in range(len(audio_token_num)):
token_num = int(audio_token_num[i].item())
audio = audio_embeds[i][:token_num]
output_audios.append(audio)
output_audios = torch.cat(output_audios, dim=0)
selected = (input_ids == self.audio_context_token_id)
input_embeds[selected] = output_audios.reshape(-1, C)
input_embeds = input_embeds.reshape(B, N, C)
outputs = self.language_model.generate(
inputs_embeds=input_embeds,
attention_mask=attention_mask,
generation_config=generation_config,
output_hidden_states=output_hidden_states or generate_audio,
return_dict_in_generate=generate_audio,
use_cache=True,
**generate_kwargs,
)
if not generate_audio:
return outputs, None, None
hidden_states = torch.cat(
[outputs.hidden_states[0][-1][:, -1:, :]] + [outputs.hidden_states[i][-1] for i in range(1, len(outputs.hidden_states))],
dim=1,
)
sampled_token = outputs.sequences
if sampled_token.shape[1] == hidden_states.shape[1] + 1:
sampled_token = sampled_token[:, 1:]
sampled_token_embeddings = self.language_model.get_input_embeddings()(sampled_token)
target_text_token_hidden_states = self.fusion(hidden_states, sampled_token_embeddings)
input_token_hidden_states = outputs.hidden_states[0][-1][:, cur_conv_start_id:-1, :]
question_input_embeddings = input_embeds[:, cur_conv_start_id+1:, :]
input_token_hidden_states = self.fusion(input_token_hidden_states, question_input_embeddings)
input_feature = self.mlp_llm2voicelm(input_token_hidden_states)
        target_text_feature = self.mlp_llm2voicelm(target_text_token_hidden_states)
try:
speech_tokens = self.voicelm_model.inference_bistream(input_feature, target_text_feature, mix_ratio=mix_ratio)
speech_tokens = torch.LongTensor([speech_tokens]).to(input_feature.device)
tts_speech = self.decode_speech_tokens(
speech_tokens,
speaker_embedding=speaker_embedding,
)
except Exception as e:
logger.warning(f"=========voice lm except:{e}")
return outputs.sequences,None, None
return outputs.sequences, speech_tokens, tts_speech
def chat(
self,
tokenizer,
generation_config,
messages,
max_patch_num=12,
frame=8,
generate_audio=False,
speaker_embedding=torch.zeros(1, 192),
print_flag=True,
):
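        # Build a multi-turn prompt from `messages` (each round may carry text, image, video, and/or
        # audio content), run generate(), and optionally synthesize a spoken reply when
        # generate_audio=True.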
if self.flow_model.dtype != torch.float32 or self.hifigan_model.dtype != torch.float32:
logger.info(f"reset flow model and higigan model dtype to float32")
self.reset_vocoder()
pass
if messages is None or len(messages) == 0:
raise RuntimeError('no messages')
role_transfer_dict = {
'system': ['user'],
'user': ['assistant'],
'assistant': ['user'],
}
first_role = ['system', 'user']
last_role = ['user']
if messages[-1]['role'] not in last_role:
raise RuntimeError(f"last role error, expect {last_role}, but got {messages[-1]}")
current_role = None
dynamic_images = list()
dynamic_nums = list()
audio_values = list()
audio_len_after_cnn = list()
audio_token_num = list()
template = get_conv_template(self.template)
for index in range(len(messages)):
text = ''
audios = list()
images = list()
message = messages[index]
            if index == 0:
                if message['role'] not in first_role:
                    raise RuntimeError(f'first role error: expected one of {first_role}, but got {message}')
            else:
                if message['role'] not in current_role:
                    raise RuntimeError(f'role error: expected one of {current_role}, but got {message}')
            # role_transfer_dict maps the current role to the roles allowed for the next message
            current_role = role_transfer_dict[message['role']]
if isinstance(message["content"], list):
for item in message["content"]:
if item['type'] == 'text':
if item.get('text', None) is None:
continue
text += item['text']
elif item['type'] == 'audio':
if item.get('audio', None) is None:
continue
if type(item['audio']) is list:
assert len(item['audio']) == 1, f'only support 1 audio file in round, but got {item["audio"]}'
audio = item['audio'][0]
else:
audio = item['audio']
audios.append(audio)
elif item['type'] == 'image':
if item.get('image', None) is None:
continue
if type(item['image']) is not list:
images.append(item['image'])
else:
images.extend(item['image'])
elif item['type'] == 'video':
if item.get('video', None) is None:
continue
if type(item['video']) is list:
assert len(item['video']) == 1, f'only support 1 video file in round, but got {item["video"]}'
video = item['video'][0]
else:
video = item['video']
frames = self.load_video(video, num_segments=frame)
images.extend(frames)
else:
assert isinstance(message["content"], str), message["content"]
text = message["content"]
if len(audios) != 0:
assert len(audios) == 1, f'only support 1 audio file in round, but got {audios}'
if '