import math
import typing as tp
from functools import partial
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Union
import copy
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers.models.auto import AutoModel
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
from transformers.modeling_utils import PreTrainedModel
from transformers.activations import ACT2FN

from .configuration_vibevoice import VibeVoiceAcousticTokenizerConfig, VibeVoiceSemanticTokenizerConfig

logger = logging.get_logger(__name__)

# Try to import APEX FusedRMSNorm; it is only used when OPTIMIZE_FOR_SPEED=1.
try:
    from apex.normalization.fused_layer_norm import fused_rms_norm_affine
    if int(os.getenv("OPTIMIZE_FOR_SPEED", "0")) == 0:
        APEX_AVAILABLE = False
        logger.warning("APEX FusedRMSNorm is disabled by environment variable OPTIMIZE_FOR_SPEED=0")
    else:
        APEX_AVAILABLE = True
        logger.info("APEX FusedRMSNorm is available and will be used for optimization")
except ImportError:
    APEX_AVAILABLE = False
    logger.warning("APEX FusedRMSNorm not available, using native implementation")


# Normalization modules
class ConvLayerNorm(nn.LayerNorm):
    """
    Convolution-friendly LayerNorm that moves the channel dimension to the last position
    before running the normalization and moves it back to its original position right after.
    """
    def __init__(self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs):
        super().__init__(normalized_shape, **kwargs)

    def forward(self, x):
        x = x.transpose(1, 2)  # b ... t -> b t ...
        x = nn.functional.layer_norm(
            x.float(), self.normalized_shape, self.weight.float(), self.bias.float(), self.eps
        ).type_as(x)
        x = x.transpose(1, 2)  # b t ... -> b ... t
        return x


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5, elementwise_affine=True, weight_shape=None):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            weight_shape = (dim,) if weight_shape is None else weight_shape
            self.weight = nn.Parameter(torch.ones(weight_shape))
        else:
            self.register_parameter('weight', None)

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        if self.weight is not None:
            output = output * self.weight
        return output

    def extra_repr(self) -> str:
        return f'dim={self.dim}, eps={self.eps}, elementwise_affine={self.elementwise_affine}'


class ConvRMSNorm(RMSNorm):
    def __init__(self, dim: int, eps: float = 1e-5, elementwise_affine=True, weight_shape=None):
        super().__init__(dim, eps, elementwise_affine, weight_shape)

    def forward(self, x):
        x = x.transpose(1, 2)  # b ... t -> b t ...
        if (not APEX_AVAILABLE) or (not self.elementwise_affine):
            # Fallback to native implementation
            output = self._norm(x.float()).type_as(x)
            if self.weight is not None:
                output = output * self.weight
        else:
            output = fused_rms_norm_affine(x, self.weight, self.weight.shape, self.eps)
        output = output.transpose(1, 2)  # b t ... -> b ... t
        return output
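

# A minimal sanity-check sketch (illustrative, not part of the model): it assumes the
# (batch, channels, time) layout that every Conv* norm in this file expects, and shows
# that ConvRMSNorm normalizes over the channel dimension.
def _demo_conv_rms_norm():
    norm = ConvRMSNorm(dim=8)
    x = torch.randn(2, 8, 16)          # (batch, channels, time)
    y = norm(x)
    # With unit weights, the RMS over the channel dimension is ~1 after normalization.
    rms = y.pow(2).mean(dim=1).sqrt()  # (batch, time)
    print(y.shape, rms.mean().item())  # torch.Size([2, 8, 16]) ~1.0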


# Convolutional layers and utilities
CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm',
                                 'time_layer_norm', 'layer_norm', 'time_group_norm'])


def apply_parametrization_norm(module: nn.Module, norm: str = 'none') -> nn.Module:
    assert norm in CONV_NORMALIZATIONS
    if norm == 'weight_norm':
        return nn.utils.weight_norm(module)
    elif norm == 'spectral_norm':
        return nn.utils.spectral_norm(module)
    else:
        # norm was already checked to be in CONV_NORMALIZATIONS, so any other
        # choice needs no reparametrization.
        return module


def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs) -> nn.Module:
    """Return the proper normalization module. If causal is True, this will ensure the returned
    module is causal, or raise an error if the normalization doesn't support causal evaluation.
    """
    assert norm in CONV_NORMALIZATIONS
    if norm == 'layer_norm':
        assert isinstance(module, nn.modules.conv._ConvNd)
        return ConvLayerNorm(module.out_channels, **norm_kwargs)
    elif norm == 'time_group_norm':
        if causal:
            raise ValueError("GroupNorm doesn't support causal evaluation.")
        assert isinstance(module, nn.modules.conv._ConvNd)
        return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
    else:
        return nn.Identity()


def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
                                 padding_total: int = 0) -> int:
    """Calculate the extra right padding needed so the last convolution window is complete,
    i.e. the padded input length maps to a whole number of output frames."""
    length = x.shape[-1]
    n_frames = (length - kernel_size + padding_total) / stride + 1
    ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
    return ideal_length - length
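

# Worked example (a sketch with illustrative numbers): for length=100, kernel_size=8,
# stride=4 and padding_total=4, n_frames = (100 - 8 + 4)/4 + 1 = 25, so the input already
# yields whole frames and no extra padding is needed. For length=102 the same formula
# gives 25.5 frames, an ideal length of 104, and therefore 2 samples of extra padding.
def _demo_extra_padding():
    for length in (100, 102):
        x = torch.zeros(1, 1, length)
        extra = get_extra_padding_for_conv1d(x, kernel_size=8, stride=4, padding_total=4)
        print(length, extra)  # 100 -> 0, 102 -> 2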


def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.):
    """Pad 1D input, with special handling for inputs shorter than the padding in reflect mode.

    Note: the default mode is 'constant' (zeros); F.pad has no 'zero' mode."""
    length = x.shape[-1]
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    if mode == 'reflect':
        max_pad = max(padding_left, padding_right)
        extra_pad = 0
        if length <= max_pad:
            extra_pad = max_pad - length + 1
            x = F.pad(x, (0, extra_pad))
        padded = F.pad(x, paddings, mode, value)
        end = padded.shape[-1] - extra_pad
        return padded[..., :end]
    else:
        return F.pad(x, paddings, mode, value)


def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
    """Remove padding from x, handling zero padding properly. Only for 1D!"""
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    assert (padding_left + padding_right) <= x.shape[-1]
    end = x.shape[-1] - padding_right
    return x[..., padding_left: end]
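

# A small round-trip sketch (illustrative values): padding added by pad1d is removed
# exactly by unpad1d with the same (left, right) amounts.
def _demo_pad_unpad():
    x = torch.arange(6.).view(1, 1, 6)
    padded = pad1d(x, (3, 2), mode='reflect')  # reflect works here since 3 < length
    restored = unpad1d(padded, (3, 2))
    assert torch.equal(restored, x)
    print(padded.shape, restored.shape)        # (1, 1, 11) (1, 1, 6)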


class NormConv1d(nn.Module):
    """Wrapper around Conv1d and normalization applied to this conv"""
    def __init__(self, *args, causal: bool = False, norm: str = 'none',
                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        return x


class NormConvTranspose1d(nn.Module):
    """Wrapper around ConvTranspose1d and normalization applied to this conv"""
    def __init__(self, *args, causal: bool = False, norm: str = 'none',
                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        x = self.convtr(x)
        x = self.norm(x)
        return x


class VibeVoiceTokenizerStreamingCache:
    """Cache for streaming convolution, similar to the KV cache in attention"""
    def __init__(self):
        self.cache = {}  # Dict mapping (layer_id, sample_idx) to state tensor

    def get(self, layer_id: str, sample_indices: torch.Tensor) -> Optional[torch.Tensor]:
        """Get cached states for given layer and sample indices"""
        states = []
        max_length = 0

        # First pass: collect states and find max length
        for idx in sample_indices.tolist():
            key = (layer_id, idx)
            if key not in self.cache:
                return None  # If any sample is missing, return None
            state = self.cache[key]
            states.append(state)
            max_length = max(max_length, state.shape[-1])

        # Second pass: pad states to max length if needed
        if len(states) > 0 and states[0].dim() >= 2:
            padded_states = []
            for state in states:
                if state.shape[-1] < max_length:
                    # Pad on the time dimension (last dimension)
                    pad_size = max_length - state.shape[-1]
                    # Pad with zeros on the LEFT to align the most recent samples
                    padded_state = F.pad(state, (pad_size, 0), mode='constant', value=0)
                    padded_states.append(padded_state)
                else:
                    padded_states.append(state)
            return torch.stack(padded_states, dim=0)
        else:
            return torch.stack(states, dim=0)

    def set(self, layer_id: str, sample_indices: torch.Tensor, states: torch.Tensor):
        """Set cached states for given layer and sample indices"""
        for i, idx in enumerate(sample_indices.tolist()):
            key = (layer_id, idx)
            self.cache[key] = states[i].detach()

    def set_to_zero(self, sample_indices: torch.Tensor):
        """Set all cached states to zero for given sample indices"""
        for key in list(self.cache.keys()):
            layer_id, sample_idx = key
            if sample_idx in sample_indices.tolist():
                # Create a zero tensor with the same shape and dtype as the cached tensor
                cached_tensor = self.cache[key]
                self.cache[key] = torch.zeros_like(cached_tensor)

    def clear(self, layer_id: Optional[str] = None, sample_indices: Optional[torch.Tensor] = None):
        """Clear cache for specific layer/samples or everything"""
        if layer_id is None and sample_indices is None:
            self.cache.clear()
        elif layer_id is not None and sample_indices is None:
            # Clear all samples for a specific layer
            keys_to_remove = [k for k in self.cache.keys() if k[0] == layer_id]
            for k in keys_to_remove:
                del self.cache[k]
        elif layer_id is not None and sample_indices is not None:
            # Clear specific samples for a specific layer
            for idx in sample_indices.tolist():
                key = (layer_id, idx)
                self.cache.pop(key, None)
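

# Usage sketch for the cache (hypothetical layer name and values): states are keyed by
# (layer_id, sample_idx), so samples in a batch can be cached, retrieved, and cleared
# independently; get() returns None as soon as any requested sample is missing.
def _demo_streaming_cache():
    cache = VibeVoiceTokenizerStreamingCache()
    indices = torch.tensor([0, 1])
    states = torch.randn(2, 4, 5)                  # (batch, channels, context)
    cache.set("layer_a", indices, states)
    fetched = cache.get("layer_a", indices)        # stacked back to (2, 4, 5)
    assert fetched.shape == (2, 4, 5)
    cache.clear("layer_a", torch.tensor([0]))      # drop only sample 0
    assert cache.get("layer_a", indices) is None   # a missing sample -> None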


class SConv1d(nn.Module):
    """Conv1d with built-in handling of asymmetric or causal padding and normalization."""
    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: int, stride: int = 1, dilation: int = 1,
                 groups: int = 1, bias: bool = True, causal: bool = False,
                 norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {},
                 pad_mode: str = 'reflect'):
        super().__init__()
        self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride,
                               dilation=dilation, groups=groups, bias=bias, causal=causal,
                               norm=norm, norm_kwargs=norm_kwargs)
        self.causal = causal
        self.pad_mode = pad_mode
        # Store configuration
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.stride = stride
        self.in_channels = in_channels
        self.out_channels = out_channels
        # For causal streaming, keep (kernel_size - 1) * dilation - (stride - 1) samples
        # of context between chunks; this matches the left padding used in the
        # non-streaming path, so both paths see the same receptive field.
        self.context_size = (kernel_size - 1) * dilation - (stride - 1)
        # For non-streaming mode, calculate padding
        self.padding_total = (kernel_size - 1) * dilation - (stride - 1)
        # Create a unique layer ID for cache management
        self._layer_id = None

    @property
    def layer_id(self):
        if self._layer_id is None:
            self._layer_id = f"sconv1d_{id(self)}"
        return self._layer_id

    def forward(self, x: torch.Tensor,
                cache: Optional[VibeVoiceTokenizerStreamingCache] = None,
                sample_indices: Optional[torch.Tensor] = None,
                use_cache: bool = False,
                debug: bool = False) -> torch.Tensor:
        """
        Forward pass with optional streaming support via cache.

        Args:
            x: Input tensor [batch_size, channels, time]
            cache: VibeVoiceTokenizerStreamingCache object for maintaining states
            sample_indices: Indices identifying each sample for cache management
            use_cache: Whether to use cached states for streaming
            debug: Whether to print debug information

        Returns:
            Output tensor
        """
        B, C, T = x.shape

        # Non-streaming mode
        if not use_cache or cache is None:
            return self._forward_non_streaming(x, debug=debug)

        # Streaming mode
        assert self.causal, "Streaming mode is only supported for causal convolutions"
        assert sample_indices is not None, "sample_indices must be provided for streaming mode"
        assert len(sample_indices) == B, "sample_indices must match batch size"
        return self._forward_streaming(x, cache, sample_indices, debug)

    def _forward_streaming(self, x: torch.Tensor,
                           cache: VibeVoiceTokenizerStreamingCache,
                           sample_indices: torch.Tensor,
                           debug: bool = False) -> torch.Tensor:
        """Streaming forward pass with cache operations kept separate from compiled code"""
        B, C, T = x.shape

        # Cache operations (not compiled)
        cached_states = cache.get(self.layer_id, sample_indices)
        if cached_states is None:
            # First chunk - initialize with zeros for context
            if self.context_size > 0:
                cached_states = torch.zeros(B, C, self.context_size, device=x.device, dtype=x.dtype)
                if debug:
                    print(f"[DEBUG] Initialized cache with shape: {cached_states.shape}, context_size={self.context_size}")
            else:
                cached_states = torch.zeros(B, C, 0, device=x.device, dtype=x.dtype)
                if debug:
                    print(f"[DEBUG] No context needed (kernel_size=stride)")

        # Concatenate cached states with input
        if cached_states.shape[2] > 0:
            input_with_context = torch.cat([cached_states, x], dim=2)
        else:
            input_with_context = x
        if debug:
            print(f"[DEBUG] Input shape: {x.shape}, Cache shape: {cached_states.shape}, Combined: {input_with_context.shape}")

        # Apply convolution directly - no extra padding in streaming mode;
        # the conv layer handles its own padding internally
        output = self.conv(input_with_context)
        if debug:
            print(f"[DEBUG] Output shape: {output.shape}")

        # Update cache for next chunk
        if self.context_size > 0:
            # Calculate how many samples to keep
            total_input_length = input_with_context.shape[2]
            # Keep the last context_size samples
            if total_input_length >= self.context_size:
                new_cache_start = total_input_length - self.context_size
                new_cache = input_with_context[:, :, new_cache_start:]
            else:
                # If we have less than context_size samples, keep everything
                new_cache = input_with_context
            if debug:
                print(f"[DEBUG] New cache shape: {new_cache.shape}")
            cache.set(self.layer_id, sample_indices, new_cache)

        return output

    def _forward_non_streaming(self, x: torch.Tensor, debug: bool = False) -> torch.Tensor:
        """Standard forward pass without streaming"""
        B, C, T = x.shape
        kernel_size = self.kernel_size
        stride = self.stride
        dilation = self.dilation
        padding_total = self.padding_total

        # Compute extra padding for stride alignment
        extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
        if debug:
            print(f"[DEBUG NON-STREAMING] Input shape: {x.shape}, padding_total={padding_total}, extra_padding={extra_padding}")

        if self.causal:
            # Left padding for causal
            if self.pad_mode == 'constant':
                x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode, value=0)
            else:
                x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
        else:
            # Symmetric padding for non-causal
            padding_right = padding_total // 2
            padding_left = padding_total - padding_right
            x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)
        if debug:
            print(f"[DEBUG NON-STREAMING] After padding: {x.shape}")

        output = self.conv(x)
        if debug:
            print(f"[DEBUG NON-STREAMING] Output shape: {output.shape}")
        return output
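

# Equivalence sketch for causal streaming (illustrative shapes, not a released
# configuration): feeding chunk-by-chunk through the cache should reproduce the
# full-sequence output. pad_mode='constant' is used so the zero-initialized streaming
# context matches the non-streaming left padding; chunk sizes that are multiples of
# the stride keep the output frames aligned.
def _demo_sconv1d_streaming():
    torch.manual_seed(0)
    conv = SConv1d(4, 4, kernel_size=7, stride=2, causal=True, pad_mode='constant').eval()
    x = torch.randn(1, 4, 32)
    with torch.no_grad():
        full = conv(x)
        cache = VibeVoiceTokenizerStreamingCache()
        indices = torch.tensor([0])
        chunks = [conv(c, cache=cache, sample_indices=indices, use_cache=True)
                  for c in x.split(8, dim=-1)]
        streamed = torch.cat(chunks, dim=-1)
    print(full.shape, streamed.shape, torch.allclose(full, streamed, atol=1e-5))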


class SConvTranspose1d(nn.Module):
    """ConvTranspose1d with built-in handling of asymmetric or causal padding and normalization."""
    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: int, stride: int = 1, causal: bool = False,
                 norm: str = 'none', trim_right_ratio: float = 1.,
                 norm_kwargs: tp.Dict[str, tp.Any] = {}, bias: bool = True):
        super().__init__()
        self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride,
                                          causal=causal, norm=norm, norm_kwargs=norm_kwargs, bias=bias)
        self.causal = causal
        self.trim_right_ratio = trim_right_ratio
        assert self.causal or self.trim_right_ratio == 1., \
            "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
        assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1.
        # Store configuration
        self.kernel_size = kernel_size
        self.stride = stride
        self.in_channels = in_channels
        self.out_channels = out_channels
        # For transposed convolution, the padding calculation is different
        self.padding_total = kernel_size - stride
        # For streaming, we need to keep track of input history:
        # a transposed conv needs to see multiple input samples to produce correct output
        self.context_size = kernel_size - 1
        # Create a unique layer ID for cache management
        self._layer_id = None

    @property
    def layer_id(self):
        if self._layer_id is None:
            self._layer_id = f"sconvtr1d_{id(self)}"
        return self._layer_id

    def forward(self, x: torch.Tensor,
                cache: Optional[VibeVoiceTokenizerStreamingCache] = None,
                sample_indices: Optional[torch.Tensor] = None,
                use_cache: bool = False,
                debug: bool = False) -> torch.Tensor:
        """
        Forward pass with optional streaming support via cache.
        """
        B, C, T = x.shape

        # Non-streaming mode
        if not use_cache or cache is None:
            return self._forward_non_streaming(x, debug=debug)

        # Streaming mode
        assert sample_indices is not None, "sample_indices must be provided for streaming mode"
        assert len(sample_indices) == B, "sample_indices must match batch size"
        return self._forward_streaming(x, cache, sample_indices, debug)

    def _forward_streaming(self, x: torch.Tensor,
                           cache: VibeVoiceTokenizerStreamingCache,
                           sample_indices: torch.Tensor,
                           debug: bool = False) -> torch.Tensor:
        """Streaming forward pass with cache operations kept separate from compiled code"""
        B, C, T = x.shape

        # Cache operations (not compiled)
        cached_input = cache.get(self.layer_id, sample_indices)
        if cached_input is None:
            # First chunk - no history yet
            cached_input = torch.zeros(B, C, 0, device=x.device, dtype=x.dtype)
            if debug:
                print(f"[DEBUG] Initialized empty cache for transposed conv")

        # Concatenate cached input with new input
        full_input = torch.cat([cached_input, x], dim=2)
        if debug:
            print(f"[DEBUG] Input shape: {x.shape}, Cache shape: {cached_input.shape}, Combined: {full_input.shape}")

        # Run the transposed convolution over the full (cached + new) input
        full_output = self.convtr(full_input)
        if debug:
            print(f"[DEBUG] Full transposed conv output shape: {full_output.shape}")

        # Calculate padding to remove
        if self.causal:
            padding_right = math.ceil(self.padding_total * self.trim_right_ratio)
            padding_left = self.padding_total - padding_right
        else:
            padding_right = self.padding_total // 2
            padding_left = self.padding_total - padding_right

        # Remove padding
        if padding_left + padding_right > 0:
            full_output = unpad1d(full_output, (padding_left, padding_right))
        if debug:
            print(f"[DEBUG] After unpadding: {full_output.shape}")

        # Determine which part of the output corresponds to the new input
        if cached_input.shape[2] == 0:
            # First chunk - return all output
            output = full_output
        else:
            # Subsequent chunks - return only the new output
            expected_new_output = T * self.stride
            # Take the last expected_new_output samples
            if full_output.shape[2] >= expected_new_output:
                output = full_output[:, :, -expected_new_output:]
            else:
                output = full_output
        if debug:
            print(f"[DEBUG] Final streaming output shape: {output.shape}")

        # Update cache
        if full_input.shape[2] > self.context_size:
            new_cache = full_input[:, :, -self.context_size:]
        else:
            new_cache = full_input
        if debug:
            print(f"[DEBUG] New cache shape: {new_cache.shape}")
        cache.set(self.layer_id, sample_indices, new_cache)

        return output

    def _forward_non_streaming(self, x: torch.Tensor, debug: bool = False) -> torch.Tensor:
        """Standard forward pass without streaming"""
        if debug:
            print(f"[DEBUG NON-STREAMING] Input shape: {x.shape}")

        # Apply transposed convolution
        y = self.convtr(x)
        if debug:
            print(f"[DEBUG NON-STREAMING] After transposed conv: {y.shape}")

        # Calculate and remove padding
        if self.causal:
            padding_right = math.ceil(self.padding_total * self.trim_right_ratio)
            padding_left = self.padding_total - padding_right
        else:
            padding_right = self.padding_total // 2
            padding_left = self.padding_total - padding_right
        if padding_left + padding_right > 0:
            y = unpad1d(y, (padding_left, padding_right))
        if debug:
            print(f"[DEBUG NON-STREAMING] Final output shape: {y.shape}")
        return y
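

# A shape-oriented sketch of the streaming upsampling path (illustrative numbers):
# after the first chunk has been emitted, each chunk of T latent frames yields
# stride * T new audio samples.
def _demo_sconvtr1d_streaming():
    convtr = SConvTranspose1d(4, 4, kernel_size=8, stride=4, causal=True).eval()
    cache = VibeVoiceTokenizerStreamingCache()
    indices = torch.tensor([0])
    with torch.no_grad():
        first = convtr(torch.randn(1, 4, 6), cache=cache, sample_indices=indices, use_cache=True)
        second = convtr(torch.randn(1, 4, 6), cache=cache, sample_indices=indices, use_cache=True)
    print(first.shape, second.shape)  # (1, 4, 24) (1, 4, 24) since 6 frames * stride 4 = 24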


# FFN
class FFN(nn.Module):
    def __init__(
        self,
        embed_dim,
        ffn_dim,
        bias=False,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.linear1 = nn.Linear(self.embed_dim, ffn_dim, bias=bias)
        self.gelu = ACT2FN["gelu"]
        self.linear2 = nn.Linear(ffn_dim, self.embed_dim, bias=bias)

    def forward(self, x):
        x = self.linear1(x)
        x = self.gelu(x)
        x = self.linear2(x)
        return x


class Convlayer(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        dilation=1,
        groups=1,
        bias=True,
        pad_mode='zeros',
        norm='weight_norm',
        causal=True,
    ):
        super().__init__()
        self.conv = SConv1d(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation,
                            groups=groups, bias=bias, pad_mode=pad_mode, norm=norm, causal=causal)

    def forward(self, x):
        return self.conv(x)


class Block1D(nn.Module):
    def __init__(self, dim, kernel_size=7, drop_path=0., mixer_layer='conv',
                 layer_scale_init_value=1e-6, **kwargs):
        super().__init__()
        layernorm = kwargs.get('layernorm', 'LN')
        if layernorm == 'LN':
            self.norm = ConvLayerNorm(dim, eps=kwargs.get('eps', 1e-6))
            self.ffn_norm = ConvLayerNorm(dim, eps=kwargs.get('eps', 1e-6))
        elif layernorm == 'RMSNorm':
            self.norm = ConvRMSNorm(dim, eps=kwargs.get('eps', 1e-6))
            self.ffn_norm = ConvRMSNorm(dim, eps=kwargs.get('eps', 1e-6))
        else:
            raise ValueError(f"Unsupported layernorm type: {layernorm}")

        if mixer_layer == 'conv':
            self.mixer = Convlayer(dim, dim, groups=kwargs.get('groups', 1),
                                   kernel_size=kernel_size,
                                   pad_mode=kwargs.get('pad_mode', 'reflect'),
                                   norm=kwargs.get('norm', 'none'),
                                   causal=kwargs.get('causal', True),
                                   bias=kwargs.get('bias', True),
                                   )
        elif mixer_layer == 'depthwise_conv':
            self.mixer = Convlayer(dim, dim, groups=dim,
                                   kernel_size=kernel_size,
                                   pad_mode=kwargs.get('pad_mode', 'reflect'),
                                   norm=kwargs.get('norm', 'none'),
                                   causal=kwargs.get('causal', True),
                                   bias=kwargs.get('bias', True),
                                   )
        else:
            raise ValueError(f"Unsupported mixer layer: {mixer_layer}")

        self.ffn = FFN(
            dim,
            kwargs.get('ffn_expansion', 4) * dim,
            bias=kwargs.get('bias', False),
        )
        if drop_path <= 0.:
            self.drop_path = nn.Identity()
        else:
            # PyTorch has no built-in DropPath (stochastic depth) module for residual
            # branches, so fall back to timm's implementation when a non-zero
            # drop-path rate is requested.
            from timm.layers import DropPath
            self.drop_path = DropPath(drop_path)
        if layer_scale_init_value > 0:
            self.gamma = nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
            self.ffn_gamma = nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
        else:
            self.gamma = None
            self.ffn_gamma = None

    def forward(self, x):
        # mixer
        residual = x
        x = self.norm(x)
        x = self.mixer(x)
        if self.gamma is not None:
            x = x * self.gamma.unsqueeze(-1)
        x = residual + self.drop_path(x)

        # ffn
        residual = x
        x = self.ffn_norm(x)
        x = x.permute(0, 2, 1)
        x = self.ffn(x)
        x = x.permute(0, 2, 1)
        if self.ffn_gamma is not None:
            x = x * self.ffn_gamma.unsqueeze(-1)
        x = residual + self.drop_path(x)
        return x


class TokenizerEncoder(nn.Module):
    """
    Encoder component for the VibeVoice tokenizer that converts audio to latent representations.

    Args:
        config: Configuration object with model parameters
    """
    def __init__(self, config):
        super().__init__()
        # Extract parameters from config
        self.channels = config.channels
        self.dimension = config.dimension
        self.n_filters = config.n_filters
        self.ratios = list(reversed(config.ratios))
        self.depths = config.depths
        self.n_residual_layers = getattr(config, "n_residual_layers", 1)
        self.hop_length = np.prod(self.ratios)
        self.causal = config.causal

        # Additional config parameters with defaults
        kernel_size = getattr(config, "kernel_size", 7)
        last_kernel_size = getattr(config, "last_kernel_size", 7)
        norm = getattr(config, "norm", "none")
        norm_params = getattr(config, "norm_params", {})
        pad_mode = getattr(config, "pad_mode", "reflect")
        bias = getattr(config, "bias", True)
        layernorm = getattr(config, "layernorm", "LN")
        layernorm_eps = getattr(config, "layernorm_eps", 1e-6)
        layernorm_elementwise_affine = getattr(config, "layernorm_elementwise_affine", True)
        drop_path_rate = getattr(config, "drop_path_rate", 0.0)
        mixer_layer = getattr(config, "mixer_layer", "conv")
        layer_scale_init_value = getattr(config, "layer_scale_init_value", 0)
        disable_last_norm = getattr(config, "disable_last_norm", False)

        # Determine the norm type based on layernorm
        if layernorm == 'LN':
            norm_type = ConvLayerNorm
        elif layernorm == 'RMSNorm':
            norm_type = partial(ConvRMSNorm, elementwise_affine=layernorm_elementwise_affine)
        else:
            raise ValueError(f"Unsupported norm type: {layernorm}")

        # Stem and intermediate downsampling conv layers
        stem = nn.Sequential(
            SConv1d(self.channels, self.n_filters, kernel_size, norm=norm, norm_kwargs=norm_params,
                    causal=self.causal, pad_mode=pad_mode, bias=bias),
        )
        self.downsample_layers = nn.ModuleList()
        self.downsample_layers.append(stem)
        for i in range(len(self.ratios)):
            in_ch = self.n_filters * (2 ** i)
            out_ch = self.n_filters * (2 ** (i + 1))
            downsample_layer = nn.Sequential(
                SConv1d(in_ch, out_ch, kernel_size=self.ratios[i] * 2, stride=self.ratios[i],
                        causal=self.causal, pad_mode=pad_mode, norm=norm, bias=bias)
            )
            self.downsample_layers.append(downsample_layer)

        # Configure the Block1D (conv-mixer + FFN) stages
        layer_type = partial(
            Block1D,
            mixer_layer=mixer_layer,
            layernorm=layernorm,
            eps=layernorm_eps,
            causal=self.causal,
            pad_mode=pad_mode,
            norm=norm,
            bias=bias,
            layer_scale_init_value=layer_scale_init_value,
        )
        self.stages = nn.ModuleList()
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
        cur = 0
        for i in range(len(self.depths)):
            in_ch = self.n_filters * (2 ** i)
            stage = nn.Sequential(
                *[layer_type(dim=in_ch, drop_path=dp_rates[cur + j]) for j in range(self.depths[i])]
            )
            self.stages.append(stage)
            cur += self.depths[i]

        if not disable_last_norm:
            self.norm = norm_type(in_ch, eps=layernorm_eps)
        else:
            self.norm = nn.Identity()
        self.head = SConv1d(in_ch, self.dimension, kernel_size=last_kernel_size,
                            causal=self.causal, pad_mode=pad_mode, norm=norm, bias=bias)

    def forward_features(self, x, cache=None, sample_indices=None, use_cache=False, debug=False):
        for i in range(len(self.depths)):
            # Apply downsampling
            for layer in self.downsample_layers[i]:
                if isinstance(layer, SConv1d):
                    x = layer(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
                else:
                    x = layer(x)
            # Apply stage (Block1D contains a Convlayer, which contains an SConv1d)
            for block in self.stages[i]:
                if hasattr(block, 'mixer') and hasattr(block.mixer, 'conv') and isinstance(block.mixer.conv, SConv1d):
                    # Block1D forward with cache support
                    residual = x
                    x = block.norm(x)
                    x = block.mixer.conv(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
                    if block.gamma is not None:
                        x = x * block.gamma.unsqueeze(-1)
                    x = residual + x
                    # FFN part
                    residual = x
                    x = block.ffn_norm(x)
                    x = x.permute(0, 2, 1)
                    x = block.ffn(x)
                    x = x.permute(0, 2, 1)
                    if block.ffn_gamma is not None:
                        x = x * block.ffn_gamma.unsqueeze(-1)
                    x = residual + x
                else:
                    x = block(x)
        return self.norm(x)

    def forward(self, x, cache=None, sample_indices=None, use_cache=False, debug=False):
        x = self.forward_features(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
        x = self.head(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
        return x
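

# A minimal sketch of the encoder's downsampling behaviour using an illustrative
# ad-hoc config (types.SimpleNamespace), NOT a released VibeVoice configuration:
# the time axis shrinks by hop_length = prod(ratios) and the channel axis ends at
# config.dimension. Note that len(depths) must equal len(ratios) + 1 (one stage per
# downsampling layer, including the stem).
def _demo_tokenizer_encoder():
    from types import SimpleNamespace
    config = SimpleNamespace(channels=1, dimension=64, n_filters=32,
                             ratios=[8, 5, 4, 2], depths=[2, 2, 2, 2, 2],
                             causal=True)
    encoder = TokenizerEncoder(config).eval()
    x = torch.randn(1, 1, 3200)   # hop_length = 8 * 5 * 4 * 2 = 320
    with torch.no_grad():
        latents = encoder(x)
    print(latents.shape)          # torch.Size([1, 64, 10]) since 3200 / 320 = 10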


class TokenizerDecoder(nn.Module):
    """
    Decoder component for the VibeVoice tokenizer that converts latent representations back to audio.

    Args:
        config: Configuration object with model parameters
    """
    def __init__(self, config):
        super().__init__()
        # Extract parameters from config
        self.dimension = config.dimension
        self.channels = config.channels
        self.n_filters = config.n_filters
        self.ratios = config.ratios
        # Note: depths are NOT reversed here; VibeVoiceAcousticTokenizerModel already
        # passes them in decoder order.
        self.depths = config.depths
        self.n_residual_layers = getattr(config, "n_residual_layers", 1)
        self.hop_length = np.prod(self.ratios)
        self.causal = config.causal

        # Additional config parameters with defaults
        kernel_size = getattr(config, "kernel_size", 7)
        last_kernel_size = getattr(config, "last_kernel_size", 7)
        norm = getattr(config, "norm", "none")
        norm_params = getattr(config, "norm_params", {})
        pad_mode = getattr(config, "pad_mode", "reflect")
        bias = getattr(config, "bias", True)
        layernorm = getattr(config, "layernorm", "LN")
        layernorm_eps = getattr(config, "layernorm_eps", 1e-6)
        trim_right_ratio = getattr(config, "trim_right_ratio", 1.0)
        layernorm_elementwise_affine = getattr(config, "layernorm_elementwise_affine", True)
        drop_path_rate = getattr(config, "drop_path_rate", 0.0)
        mixer_layer = getattr(config, "mixer_layer", "conv")
        layer_scale_init_value = getattr(config, "layer_scale_init_value", 0)
        disable_last_norm = getattr(config, "disable_last_norm", False)

        # Determine the norm type based on layernorm
        if layernorm == 'LN':
            norm_type = ConvLayerNorm
        elif layernorm == 'RMSNorm':
            norm_type = partial(ConvRMSNorm, elementwise_affine=layernorm_elementwise_affine)
        else:
            raise ValueError(f"Unsupported norm type: {layernorm}")

        # Stem and upsampling layers
        stem = nn.Sequential(
            SConv1d(self.dimension, self.n_filters * 2 ** (len(self.depths) - 1), kernel_size, norm=norm,
                    norm_kwargs=norm_params, causal=self.causal, pad_mode=pad_mode, bias=bias),
        )
        self.upsample_layers = nn.ModuleList()
        self.upsample_layers.append(stem)
        for i in range(len(self.ratios)):
            in_ch = self.n_filters * (2 ** (len(self.depths) - 1 - i))
            out_ch = self.n_filters * (2 ** (len(self.depths) - 1 - i - 1))
            upsample_layer = nn.Sequential(
                SConvTranspose1d(in_ch, out_ch,
                                 kernel_size=self.ratios[i] * 2, stride=self.ratios[i],
                                 norm=norm, norm_kwargs=norm_params, bias=bias,
                                 causal=self.causal, trim_right_ratio=trim_right_ratio),
            )
            self.upsample_layers.append(upsample_layer)

        # Configure the Block1D (conv-mixer + FFN) stages
        layer_type = partial(
            Block1D,
            mixer_layer=mixer_layer,
            layernorm=layernorm,
            eps=layernorm_eps,
            causal=self.causal,
            pad_mode=pad_mode,
            norm=norm,
            bias=bias,
            layer_scale_init_value=layer_scale_init_value,
        )
        self.stages = nn.ModuleList()
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
        cur = 0
        # Create stages in the same order as the original model
        for i in range(len(self.depths)):
            in_ch = self.n_filters * (2 ** (len(self.depths) - 1 - i))
            stage = nn.Sequential(
                *[layer_type(dim=in_ch, drop_path=dp_rates[cur + j]) for j in range(self.depths[i])]
            )
            self.stages.append(stage)
            cur += self.depths[i]

        if not disable_last_norm:
            self.norm = norm_type(in_ch, eps=layernorm_eps)
        else:
            self.norm = nn.Identity()
        self.head = SConv1d(in_ch, self.channels, kernel_size=last_kernel_size,
                            causal=self.causal, pad_mode=pad_mode, norm=norm, bias=bias)

    def forward_features(self, x, cache=None, sample_indices=None, use_cache=False, debug=False):
        for i in range(len(self.depths)):
            # Apply upsampling
            for layer in self.upsample_layers[i]:
                if isinstance(layer, (SConv1d, SConvTranspose1d)):
                    x = layer(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
                else:
                    x = layer(x)
            # Apply stage (Block1D contains a Convlayer, which contains an SConv1d)
            for block in self.stages[i]:
                if hasattr(block, 'mixer') and hasattr(block.mixer, 'conv') and isinstance(block.mixer.conv, SConv1d):
                    # Block1D forward with cache support
                    residual = x
                    x = block.norm(x)
                    x = block.mixer.conv(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
                    if block.gamma is not None:
                        x = x * block.gamma.unsqueeze(-1)
                    x = residual + x
                    # FFN part
                    residual = x
                    x = block.ffn_norm(x)
                    x = x.permute(0, 2, 1)
                    x = block.ffn(x)
                    x = x.permute(0, 2, 1)
                    if block.ffn_gamma is not None:
                        x = x * block.ffn_gamma.unsqueeze(-1)
                    x = residual + x
                else:
                    x = block(x)
        return self.norm(x)

    def forward(self, x, cache=None, sample_indices=None, use_cache=False, debug=False):
        x = self.forward_features(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
        x = self.head(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
        return x


@dataclass
class VibeVoiceTokenizerEncoderOutput:
    """
    Output of the VibeVoice tokenizer encoder, representing a Gaussian distribution with fixed variance.

    Args:
        mean (`torch.FloatTensor`): The mean parameters of the distribution.
        std (`float` or `torch.FloatTensor`): Fixed standard deviation value.
    """
    mean: torch.Tensor
    std: Optional[Union[float, torch.Tensor]] = None

    def sample(self, dist_type='fix'):
        """
        Sample from the distribution.

        Args:
            dist_type (`str`): Sampling method, either 'fix' or 'gaussian'; any other
                value returns the mean without added noise.

        Returns:
            `torch.FloatTensor`: Sampled values.
            `torch.FloatTensor` or `float`: Standard deviation used.
        """
        if dist_type == 'fix':
            x = self.mean + self.std * torch.randn_like(self.mean)
            return x, self.std
        elif dist_type == 'gaussian':
            batch_size = self.mean.size(0)
            value = self.std / 0.8
            std = torch.randn(batch_size, device=self.mean.device, dtype=self.mean.dtype) * value
            while std.dim() < self.mean.dim():
                std = std.unsqueeze(-1)
            x = self.mean + std * torch.randn_like(self.mean)
            return x, std
        else:
            return self.mean, self.std

    def kl(self):
        """Compute the KL-style regularization toward a standard normal: with a fixed std,
        the KL divergence reduces (up to constants) to the squared mean."""
        target = torch.zeros_like(self.mean)
        return F.mse_loss(self.mean, target, reduction='none')

    def mode(self):
        """Return the distribution mode (which is the mean for a Gaussian)."""
        return self.mean
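

# Sampling sketch (illustrative numbers): with dist_type='fix' a fixed std scales
# i.i.d. Gaussian noise around the mean; with dist_type='gaussian' a per-sample std
# is drawn first (scaled by std / 0.8), broadcast over the remaining dims, then used
# the same way.
def _demo_encoder_output():
    out = VibeVoiceTokenizerEncoderOutput(mean=torch.zeros(2, 10, 64), std=0.1)
    x_fix, std_fix = out.sample(dist_type='fix')            # std_fix == 0.1
    x_gauss, std_gauss = out.sample(dist_type='gaussian')   # std_gauss: (2, 1, 1)
    print(x_fix.shape, std_fix, std_gauss.shape)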


class VibeVoiceAcousticTokenizerModel(PreTrainedModel):
    """VibeVoice speech tokenizer model combining encoder and decoder for acoustic tokens"""
    config_class = VibeVoiceAcousticTokenizerConfig
    base_model_prefix = "vibevoice_acoustic_tokenizer"
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _no_split_modules = ["TokenizerEncoder", "TokenizerDecoder"]

    def __init__(self, config):
        super().__init__(config)
        self.register_buffer('fix_std', torch.tensor(config.fix_std), persistent=False)
        self.std_dist_type = getattr(config, "std_dist_type", "fix")

        # Parse encoder depths
        if isinstance(config.encoder_depths, str):
            encoder_depths = [int(d) for d in config.encoder_depths.split('-')]
        else:
            encoder_depths = config.encoder_depths

        # Parse decoder depths if provided
        if config.decoder_depths is not None and isinstance(config.decoder_depths, str):
            decoder_depths = [int(d) for d in config.decoder_depths.split('-')]
        else:
            # Default: use reversed encoder depths if decoder_depths is None
            decoder_depths = list(reversed(encoder_depths))

        # Create encoder config
        encoder_config = copy.deepcopy(config)
        encoder_config.dimension = config.vae_dim
        encoder_config.n_filters = config.encoder_n_filters
        encoder_config.ratios = config.encoder_ratios
        encoder_config.depths = encoder_depths
        encoder_config.norm = config.conv_norm
        encoder_config.pad_mode = config.pad_mode
        encoder_config.bias = config.conv_bias
        encoder_config.layernorm_eps = config.layernorm_eps
        encoder_config.layernorm_elementwise_affine = config.layernorm_elementwise_affine
        encoder_config.mixer_layer = config.mixer_layer
        encoder_config.layer_scale_init_value = config.layer_scale_init_value
        encoder_config.disable_last_norm = config.disable_last_norm

        # Create decoder config
        decoder_config = copy.deepcopy(config)
        decoder_config.dimension = config.vae_dim
        decoder_config.n_filters = config.decoder_n_filters
        decoder_config.ratios = config.decoder_ratios
        decoder_config.depths = decoder_depths
        decoder_config.norm = config.conv_norm
        decoder_config.pad_mode = config.pad_mode
        decoder_config.bias = config.conv_bias
        decoder_config.layernorm_eps = config.layernorm_eps
        decoder_config.layernorm_elementwise_affine = config.layernorm_elementwise_affine
        decoder_config.mixer_layer = config.mixer_layer
        decoder_config.layer_scale_init_value = config.layer_scale_init_value
        decoder_config.disable_last_norm = config.disable_last_norm

        # Initialize encoder and decoder
        self.encoder = TokenizerEncoder(encoder_config)
        self.decoder = TokenizerDecoder(decoder_config)

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize weights for the model"""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, std=self.config.weight_init_value)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Conv1d):
            nn.init.normal_(module.weight, std=self.config.weight_init_value)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def encode(self, audio, cache=None, sample_indices=None, use_cache=False, debug=False):
        """Convert audio to latent representations"""
        latents = self.encoder(audio, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
        return VibeVoiceTokenizerEncoderOutput(mean=latents.permute(0, 2, 1), std=self.fix_std)

    def sampling(self, encoder_output, dist_type=None):
        """Sample from the encoder output distribution"""
        dist_type = dist_type or self.std_dist_type
        if dist_type == 'fix':
            return encoder_output.sample(dist_type='fix')
        elif dist_type == 'gaussian':
            return encoder_output.sample(dist_type='gaussian')
        else:
            raise ValueError(f"Unsupported dist_type: {dist_type}, expected 'fix' or 'gaussian'")

    def decode(self, latents, cache=None, sample_indices=None, use_cache=False, debug=False):
        """Convert latent representations back to audio"""
        # The decoder expects (batch, vae_dim, time); permute if latents arrive channels-last.
        if latents.shape[1] != self.config.vae_dim:
            latents = latents.permute(0, 2, 1)
        audio = self.decoder(latents, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
        return audio

    def forward(self, audio, cache=None, sample_indices=None, use_cache=False, debug=False):
        """Full forward pass: encode audio to latents, then decode back to audio"""
        encoder_output = self.encode(audio, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
        sampled_latents, _ = self.sampling(encoder_output)
        reconstructed = self.decode(sampled_latents, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
        return reconstructed, sampled_latents
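

# End-to-end usage sketch: given an already constructed or loaded acoustic tokenizer
# (e.g. via AutoModel/from_pretrained; the checkpoint path is up to the caller),
# encode -> sample -> decode round-trips audio through the VAE latent space. The
# time axis changes by hop_length = prod(encoder ratios).
def _demo_acoustic_roundtrip(model: "VibeVoiceAcousticTokenizerModel", audio: torch.Tensor):
    with torch.no_grad():
        enc = model.encode(audio)        # enc.mean: (batch, frames, vae_dim)
        latents, _ = model.sampling(enc)
        recon = model.decode(latents)    # (batch, channels, ~time)
    return recon, latents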


class VibeVoiceSemanticTokenizerModel(PreTrainedModel):
    """VibeVoice speech tokenizer model with only an encoder, for semantic tokens"""
    config_class = VibeVoiceSemanticTokenizerConfig
    base_model_prefix = "vibevoice_semantic_tokenizer"
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _no_split_modules = ["TokenizerEncoder"]

    def __init__(self, config):
        super().__init__(config)
        # Parse encoder depths
        if isinstance(config.encoder_depths, str):
            encoder_depths = [int(d) for d in config.encoder_depths.split('-')]
        else:
            encoder_depths = config.encoder_depths

        # Create encoder config
        encoder_config = copy.deepcopy(config)
        encoder_config.dimension = config.vae_dim
        encoder_config.n_filters = config.encoder_n_filters
        encoder_config.ratios = config.encoder_ratios
        encoder_config.depths = encoder_depths
        encoder_config.norm = config.conv_norm
        encoder_config.pad_mode = config.pad_mode
        encoder_config.bias = config.conv_bias
        encoder_config.layernorm_eps = config.layernorm_eps
        encoder_config.layernorm_elementwise_affine = config.layernorm_elementwise_affine
        encoder_config.mixer_layer = config.mixer_layer
        encoder_config.layer_scale_init_value = config.layer_scale_init_value
        encoder_config.disable_last_norm = config.disable_last_norm

        # Initialize encoder (the semantic tokenizer has no decoder)
        self.encoder = TokenizerEncoder(encoder_config)

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize weights for the model"""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, std=self.config.weight_init_value)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Conv1d):
            nn.init.normal_(module.weight, std=self.config.weight_init_value)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def encode(self, audio, cache=None, sample_indices=None, use_cache=False, debug=False):
        """Convert audio to latent representations"""
        latents = self.encoder(audio, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
        return VibeVoiceTokenizerEncoderOutput(mean=latents.permute(0, 2, 1))

    def sampling(self, encoder_output, dist_type=None):
        """Return the deterministic latents: the semantic tokenizer always uses the mean."""
        return encoder_output.sample(dist_type='none')

    def forward(self, audio, cache=None, sample_indices=None, use_cache=False, debug=False):
        """Encode audio to latents; the semantic tokenizer has no decoder, so no audio is returned."""
        encoder_output = self.encode(audio, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
        sampled_latents, _ = self.sampling(encoder_output, dist_type='none')
        return None, sampled_latents


AutoModel.register(VibeVoiceAcousticTokenizerConfig, VibeVoiceAcousticTokenizerModel)
AutoModel.register(VibeVoiceSemanticTokenizerConfig, VibeVoiceSemanticTokenizerModel)

__all__ = [
    "VibeVoiceTokenizerStreamingCache",
    "VibeVoiceAcousticTokenizerModel",
    "VibeVoiceSemanticTokenizerModel",
]