#!/usr/bin/env python3
# Copyright         2024  Xiaomi Corp.        (authors: Han Zhu)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Union

import torch


class DiffusionModel(torch.nn.Module):
    """A wrapper of diffusion models for inference.

    Args:
        model: The diffusion model.
        func_name: The name of the method of ``model`` to call.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        func_name: str = "forward_fm_decoder",
    ):
        super().__init__()
        self.model = model
        self.func_name = func_name
        # Resolve the bound method once so that forward() does not repeat
        # the attribute lookup at every solver step.
        self.model_func = getattr(self.model, func_name)

    def forward(
        self,
        t: torch.Tensor,
        x: torch.Tensor,
        text_condition: torch.Tensor,
        speech_condition: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None,
        guidance_scale: Union[float, torch.Tensor] = 0.0,
        **kwargs
    ) -> torch.Tensor:
        """Forward function that handles the classifier-free guidance.

        Args:
            t: The current timestep, a tensor of a single float.
            x: The initial value, with the shape (batch, seq_len, emb_dim).
            text_condition: The text condition of the diffusion model, with
                the shape (batch, seq_len, emb_dim).
            speech_condition: The speech condition of the diffusion model,
                with the shape (batch, seq_len, emb_dim).
            padding_mask: The mask for padding; True means masked position,
                with the shape (batch, seq_len). May be None, in which case
                None is forwarded to the wrapped model.
            guidance_scale: The scale of classifier-free guidance, a float
                or a tensor of shape (batch, 1, 1).
            **kwargs: Extra keyword arguments forwarded to the wrapped model
                function unchanged.

        Returns:
            The prediction with the shape (batch, seq_len, emb_dim).
        """
        if not torch.is_tensor(guidance_scale):
            guidance_scale = torch.tensor(
                guidance_scale, dtype=t.dtype, device=t.device
            )

        # No guidance requested: a single conditional pass is sufficient.
        if (guidance_scale == 0.0).all():
            return self.model_func(
                t=t,
                xt=x,
                text_condition=text_condition,
                speech_condition=speech_condition,
                padding_mask=padding_mask,
                **kwargs
            )

        # The branch selection below compares `t` against a scalar
        # threshold, so a 0-dim timestep is required.
        assert t.dim() == 0

        # Stack the unconditional and conditional inputs along the batch
        # axis so that both are evaluated by one call of the model.
        x = torch.cat([x] * 2, dim=0)
        if padding_mask is not None:
            # The mask is optional; torch.cat() cannot concatenate None.
            padding_mask = torch.cat([padding_mask] * 2, dim=0)

        text_condition = torch.cat(
            [torch.zeros_like(text_condition), text_condition], dim=0
        )

        if t > 0.5:
            # The unconditional half drops the speech condition as well.
            speech_condition = torch.cat(
                [torch.zeros_like(speech_condition), speech_condition], dim=0
            )
        else:
            # The unconditional half keeps the speech condition (only the
            # text condition is zeroed) and the guidance scale is doubled.
            guidance_scale = guidance_scale * 2
            speech_condition = torch.cat(
                [speech_condition, speech_condition], dim=0
            )

        data_uncond, data_cond = self.model_func(
            t=t,
            xt=x,
            text_condition=text_condition,
            speech_condition=speech_condition,
            padding_mask=padding_mask,
            **kwargs
        ).chunk(2, dim=0)

        res = (1 + guidance_scale) * data_cond - guidance_scale * data_uncond
        return res


class DistillDiffusionModel(DiffusionModel):
    """A wrapper of distilled diffusion models for inference.

    Unlike ``DiffusionModel``, the guidance scale is passed to the wrapped
    model function as an argument and only one pass is evaluated.

    Args:
        model: The distilled diffusion model.
        func_name: The name of the method of ``model`` to call.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        func_name: str = "forward_fm_decoder",
    ):
        super().__init__(model=model, func_name=func_name)

    def forward(
        self,
        t: torch.Tensor,
        x: torch.Tensor,
        text_condition: torch.Tensor,
        speech_condition: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None,
        guidance_scale: Union[float, torch.Tensor] = 0.0,
        **kwargs
    ) -> torch.Tensor:
        """Forward function that passes the guidance scale to the model.

        Args:
            t: The current timestep, a tensor of a single float.
            x: The initial value, with the shape (batch, seq_len, emb_dim).
            text_condition: The text condition of the diffusion model, with
                the shape (batch, seq_len, emb_dim).
            speech_condition: The speech condition of the diffusion model,
                with the shape (batch, seq_len, emb_dim).
            padding_mask: The mask for padding; True means masked position,
                with the shape (batch, seq_len).
            guidance_scale: The scale of classifier-free guidance, a float
                or a tensor of shape (batch, 1, 1).
            **kwargs: Extra keyword arguments forwarded to the wrapped model
                function unchanged.

        Returns:
            The prediction with the shape (batch, seq_len, emb_dim).
        """
        if not torch.is_tensor(guidance_scale):
            guidance_scale = torch.tensor(
                guidance_scale, dtype=t.dtype, device=t.device
            )
        return self.model_func(
            t=t,
            xt=x,
            text_condition=text_condition,
            speech_condition=speech_condition,
            padding_mask=padding_mask,
            guidance_scale=guidance_scale,
            **kwargs
        )


class EulerSolver:
    def __init__(
        self,
        model: torch.nn.Module,
        func_name: str = "forward_fm_decoder",
    ):
        """Construct a Euler Solver.

        Args:
            model: The diffusion model.
            func_name: The name of the method of ``model`` to call.
        """
        self.model = DiffusionModel(model, func_name=func_name)

    def sample(
        self,
        x: torch.Tensor,
        text_condition: torch.Tensor,
        speech_condition: torch.Tensor,
        padding_mask: torch.Tensor,
        num_step: int = 10,
        guidance_scale: Union[float, torch.Tensor] = 0.0,
        t_start: float = 0.0,
        t_end: float = 1.0,
        t_shift: float = 1.0,
        **kwargs
    ) -> torch.Tensor:
        """Compute the sample at time `t_end` by Euler Solver.

        Args:
            x: The initial value at time `t_start`, with the shape
                (batch, seq_len, emb_dim).
            text_condition: The text condition of the diffusion model, with
                the shape (batch, seq_len, emb_dim).
            speech_condition: The speech condition of the diffusion model,
                with the shape (batch, seq_len, emb_dim).
            padding_mask: The mask for padding; True means masked position,
                with the shape (batch, seq_len).
            num_step: The number of ODE steps.
            guidance_scale: The scale for classifier-free guidance, which is
                a float or a tensor with the shape (batch, 1, 1).
            t_start: The start timestep in the range of [0, 1].
            t_end: The end timestep in the range of [0, 1].
            t_shift: Shift the t toward smaller numbers so that the sampling
                will emphasize low SNR region. Should be in the range of
                (0, 1]. The shifting will be more significant when the
                number is smaller.
            **kwargs: Extra keyword arguments forwarded to the model at
                every step.

        Returns:
            The approximated solution at time `t_end`.
        """
        device = x.device
        assert isinstance(t_start, float) and isinstance(t_end, float)

        timesteps = get_time_steps(
            t_start=t_start,
            t_end=t_end,
            num_step=num_step,
            t_shift=t_shift,
            device=device,
        )

        for step in range(num_step):
            # Predicted velocity at the current state and timestep.
            v = self.model(
                t=timesteps[step],
                x=x,
                text_condition=text_condition,
                speech_condition=speech_condition,
                padding_mask=padding_mask,
                guidance_scale=guidance_scale,
                **kwargs
            )
            # Explicit Euler update: x <- x + v * dt.
            x = x + v * (timesteps[step + 1] - timesteps[step])
        return x


class DistillEulerSolver(EulerSolver):
    def __init__(
        self,
        model: torch.nn.Module,
        func_name: str = "forward_fm_decoder",
    ):
        """Construct a Euler Solver for distilled diffusion models.

        Args:
            model: The diffusion model.
            func_name: The name of the method of ``model`` to call.
        """
        # The parent constructor is intentionally not called: it would only
        # build a DiffusionModel wrapper that is replaced right here.
        self.model = DistillDiffusionModel(model, func_name=func_name)


def get_time_steps(
    t_start: float = 0.0,
    t_end: float = 1.0,
    num_step: int = 10,
    t_shift: float = 1.0,
    device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Compute the intermediate time steps for sampling.

    Args:
        t_start: The starting time of the sampling (default is 0).
        t_end: The ending time of the sampling (default is 1).
        num_step: The number of sampling steps.
        t_shift: Shift the t toward smaller numbers so that the sampling
            will emphasize low SNR region. Should be in the range of (0, 1].
            The shifting will be more significant when the number is
            smaller.
        device: A torch device.

    Returns:
        The time steps with the shape (num_step + 1,).
    """
    timesteps = torch.linspace(t_start, t_end, num_step + 1).to(device)

    # With t_shift == 1 this is the identity mapping; the end points 0 and 1
    # are mapped to themselves for any t_shift.
    timesteps = t_shift * timesteps / (1 + (t_shift - 1) * timesteps)

    return timesteps