|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Optional, Union |
|
|
|
import torch |
|
|
|
|
|
class DiffusionModel(torch.nn.Module):
    """A wrapper of diffusion models for inference.

    The wrapper looks up a method on the wrapped model by name and calls it,
    applying classifier-free guidance when a non-zero guidance scale is given.

    Args:
        model: The diffusion model.
        func_name: The name of the method on ``model`` to call.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        func_name: str = "forward_fm_decoder",
    ):
        super().__init__()
        self.model = model
        self.func_name = func_name
        # Resolve the bound method once so forward() does not repeat the lookup.
        self.model_func = getattr(self.model, func_name)

    def forward(
        self,
        t: torch.Tensor,
        x: torch.Tensor,
        text_condition: torch.Tensor,
        speech_condition: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None,
        guidance_scale: Union[float, torch.Tensor] = 0.0,
        **kwargs
    ) -> torch.Tensor:
        """Run the wrapped model, handling classifier-free guidance.

        When every entry of ``guidance_scale`` is zero the model is called once
        on the inputs as given. Otherwise the batch is doubled into an
        "unconditional" half (first) and a "conditional" half (second), the
        model is called once on the doubled batch, and the two halves are
        combined as ``(1 + scale) * cond - scale * uncond``.

        Args:
            t: The current timestep, a tensor of a single float. It must be
                0-dimensional when guidance is active.
            x: The initial value, with the shape (batch, seq_len, emb_dim).
            text_condition: The text condition of the diffusion model, with
                the shape (batch, seq_len, emb_dim).
            speech_condition: The speech condition of the diffusion model, with
                the shape (batch, seq_len, emb_dim).
            padding_mask: The mask for padding; True means masked position,
                with the shape (batch, seq_len). May be None.
            guidance_scale: The scale of classifier-free guidance, a float or a
                tensor of shape (batch, 1, 1).
            **kwargs: Extra keyword arguments forwarded to the model function.

        Returns:
            The prediction with the shape (batch, seq_len, emb_dim).
        """
        if not torch.is_tensor(guidance_scale):
            guidance_scale = torch.tensor(
                guidance_scale, dtype=t.dtype, device=t.device
            )

        # No guidance requested: a single plain forward pass is enough.
        if (guidance_scale == 0.0).all():
            return self.model_func(
                t=t,
                xt=x,
                text_condition=text_condition,
                speech_condition=speech_condition,
                padding_mask=padding_mask,
                **kwargs
            )

        # The `t > 0.5` test below needs a single scalar timestep.
        assert t.dim() == 0

        x = torch.cat([x] * 2, dim=0)
        # padding_mask is optional; only duplicate it when one was supplied,
        # otherwise torch.cat would fail on a list of None.
        if padding_mask is not None:
            padding_mask = torch.cat([padding_mask] * 2, dim=0)

        # The unconditional half always drops the text condition.
        text_condition = torch.cat(
            [torch.zeros_like(text_condition), text_condition], dim=0
        )

        if t > 0.5:
            # Late timesteps: the unconditional half also drops the speech
            # condition.
            speech_condition = torch.cat(
                [torch.zeros_like(speech_condition), speech_condition], dim=0
            )
        else:
            # Early timesteps: both halves keep the speech condition and the
            # guidance scale is doubled.
            guidance_scale = guidance_scale * 2
            speech_condition = torch.cat(
                [speech_condition, speech_condition], dim=0
            )

        data_uncond, data_cond = self.model_func(
            t=t,
            xt=x,
            text_condition=text_condition,
            speech_condition=speech_condition,
            padding_mask=padding_mask,
            **kwargs
        ).chunk(2, dim=0)

        res = (1 + guidance_scale) * data_cond - guidance_scale * data_uncond
        return res
|
|
|
|
|
class DistillDiffusionModel(DiffusionModel):
    """Inference wrapper for a distilled diffusion model.

    This wrapper does not combine conditional and unconditional predictions
    itself; it hands the guidance scale to the wrapped model function as a
    tensor and returns whatever that function produces.

    Args:
        model: The distilled diffusion model.
        func_name: The name of the method on ``model`` to call.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        func_name: str = "forward_fm_decoder",
    ):
        super().__init__(model=model, func_name=func_name)

    def forward(
        self,
        t: torch.Tensor,
        x: torch.Tensor,
        text_condition: torch.Tensor,
        speech_condition: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None,
        guidance_scale: Union[float, torch.Tensor] = 0.0,
        **kwargs
    ) -> torch.Tensor:
        """Call the wrapped model function with the guidance scale as a tensor.

        Args:
            t: The current timestep, a tensor of a single float.
            x: The initial value, with the shape (batch, seq_len, emb_dim).
            text_condition: The text condition of the diffusion model, with
                the shape (batch, seq_len, emb_dim).
            speech_condition: The speech condition of the diffusion model, with
                the shape (batch, seq_len, emb_dim).
            padding_mask: The mask for padding; True means masked position,
                with the shape (batch, seq_len).
            guidance_scale: The scale of classifier-free guidance, a float or a
                tensor of shape (batch, 1, 1).
            **kwargs: Extra keyword arguments forwarded to the model function.

        Returns:
            The prediction with the shape (batch, seq_len, emb_dim).
        """
        scale = guidance_scale
        if not isinstance(scale, torch.Tensor):
            # Plain numbers are promoted to a tensor matching t's dtype/device.
            scale = torch.tensor(scale, dtype=t.dtype, device=t.device)

        prediction = self.model_func(
            t=t,
            xt=x,
            text_condition=text_condition,
            speech_condition=speech_condition,
            padding_mask=padding_mask,
            guidance_scale=scale,
            **kwargs
        )
        return prediction
|
|
|
|
|
class EulerSolver:
    """Fixed-step Euler sampler that drives a wrapped diffusion model."""

    def __init__(
        self,
        model: torch.nn.Module,
        func_name: str = "forward_fm_decoder",
    ):
        """Construct a Euler Solver.

        Args:
            model: The diffusion model.
            func_name: The name of the method on ``model`` to call.
        """
        self.model = DiffusionModel(model, func_name=func_name)

    def sample(
        self,
        x: torch.Tensor,
        text_condition: torch.Tensor,
        speech_condition: torch.Tensor,
        padding_mask: torch.Tensor,
        num_step: int = 10,
        guidance_scale: Union[float, torch.Tensor] = 0.0,
        t_start: float = 0.0,
        t_end: float = 1.0,
        t_shift: float = 1.0,
        **kwargs
    ) -> torch.Tensor:
        """Integrate from ``t_start`` to ``t_end`` with ``num_step`` Euler steps.

        Args:
            x: The initial value at time ``t_start``, with the shape
                (batch, seq_len, emb_dim).
            text_condition: The text condition of the diffusion model, with the
                shape (batch, seq_len, emb_dim).
            speech_condition: The speech condition of the diffusion model, with
                the shape (batch, seq_len, emb_dim).
            padding_mask: The mask for padding; True means masked position,
                with the shape (batch, seq_len).
            num_step: The number of ODE steps.
            guidance_scale: The scale for classifier-free guidance, which is
                a float or a tensor with the shape (batch, 1, 1).
            t_start: The start timestep in the range of [0, 1]. Must be a float.
            t_end: The end timestep in the range of [0, 1]. Must be a float.
            t_shift: Shift the t toward smaller numbers so that the sampling
                will emphasize low SNR region. Should be in the range of
                (0, 1]. The shifting is more significant when the number is
                smaller.
            **kwargs: Extra keyword arguments forwarded to the model wrapper.

        Returns:
            The approximated solution at time ``t_end``.
        """
        device = x.device
        assert isinstance(t_start, float) and isinstance(t_end, float)

        schedule = get_time_steps(
            t_start=t_start,
            t_end=t_end,
            num_step=num_step,
            t_shift=t_shift,
            device=device,
        )

        for idx in range(num_step):
            t_now = schedule[idx]
            # The model output is used as the derivative dx/dt at t_now.
            derivative = self.model(
                t=t_now,
                x=x,
                text_condition=text_condition,
                speech_condition=speech_condition,
                padding_mask=padding_mask,
                guidance_scale=guidance_scale,
                **kwargs
            )
            step_size = schedule[idx + 1] - t_now
            x = x + derivative * step_size

        return x
|
|
|
|
|
class DistillEulerSolver(EulerSolver):
    """Euler sampler whose model wrapper is a ``DistillDiffusionModel``."""

    def __init__(
        self,
        model: torch.nn.Module,
        func_name: str = "forward_fm_decoder",
    ):
        """Construct a Euler Solver for distilled diffusion models.

        Args:
            model: The distilled diffusion model.
            func_name: The name of the method on ``model`` to call.
        """
        # The parent constructor is not invoked; the wrapper is set directly.
        wrapped = DistillDiffusionModel(model, func_name=func_name)
        self.model = wrapped
|
|
|
|
|
def get_time_steps(
    t_start: float = 0.0,
    t_end: float = 1.0,
    num_step: int = 10,
    t_shift: float = 1.0,
    device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Build the (optionally shifted) time grid used for sampling.

    An evenly spaced grid from ``t_start`` to ``t_end`` is created and each
    point ``t`` is mapped to ``t_shift * t / (1 + (t_shift - 1) * t)``. With
    ``t_shift == 1`` the grid is left unchanged.

    Args:
        t_start: The starting time of the sampling (default is 0).
        t_end: The end time of the sampling (default is 1).
        num_step: The number of sampling steps.
        t_shift: Shift the t toward smaller numbers so that the sampling
            will emphasize low SNR region. Should be in the range of (0, 1].
            The shifting is more significant when the number is smaller.
        device: A torch device on which the result is placed.

    Returns:
        The time steps, with the shape (num_step + 1,).
    """
    uniform_grid = torch.linspace(t_start, t_end, num_step + 1).to(device)
    denominator = 1 + (t_shift - 1) * uniform_grid
    return t_shift * uniform_grid / denominator
|
|