Spaces:
Running
on
Zero
Running
on
Zero
# modified from https://github.com/Stability-AI/sd3.5/blob/main/sd3_impls.py#L23 | |
import torch | |
import torch.nn as nn | |
import numpy as np | |
class ModelSamplingDiscreteFlow(nn.Module): | |
"""Helper for sampler scheduling (ie timestep/sigma calculations) for Discrete Flow models""" | |
def __init__(self, num_train_timesteps=1000, shift=1.0, **kwargs): | |
super().__init__() | |
self.num_train_timesteps = num_train_timesteps | |
self.shift = shift | |
ts = self.to_sigma(torch.arange(1, num_train_timesteps + 1, 1)) # [1/1000, 1] | |
self.register_buffer("sigmas", ts) | |
def sigma_min(self): | |
return self.sigmas[0] | |
def sigma_max(self): | |
return self.sigmas[-1] | |
def to_timestep(self, sigma): | |
return sigma * self.num_train_timesteps | |
def to_sigma(self, timestep: torch.Tensor): | |
timestep = timestep / self.num_train_timesteps | |
if self.shift == 1.0: | |
return timestep | |
return self.shift * timestep / (1 + (self.shift - 1) * timestep) | |
def uniform_sample_t(self, batch_size, device): | |
ts = (self.sigma_max - self.sigma_min) * torch.rand(batch_size, device=device) + self.sigma_min | |
return ts | |
def calculate_denoised(self, sigma, model_output, model_input): | |
# model ouput, vector field, v = dx = (x_1 - x_0) | |
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) | |
return model_input - model_output * sigma | |
def noise_scaling(self, sigma, noise, latent_image): | |
return sigma * noise + (1.0 - sigma) * latent_image | |
def add_noise(self, sample, noise=None, timesteps=None): | |
# sample, B, L, D | |
if timesteps is None: | |
# Sample time step | |
batch_size = sample.shape[0] | |
sigmas = self.uniform_sample_t(batch_size, device=sample.device).to(dtype=sample.dtype) # (B,) | |
timesteps = self.to_timestep(sigmas) | |
else: | |
timesteps = timesteps.to(device=sample.device, dtype=sample.dtype) | |
sigmas = self.to_sigma(timesteps) | |
sigmas = sigmas.view(-1, 1, 1) # (B, 1, 1) | |
noise = torch.randn_like(sample) | |
noisy_samples = sigmas * noise + (1.0 - sigmas) * sample | |
return noisy_samples, noise, noise - sample, timesteps | |
def set_timesteps(self, num_inference_steps, device=None): | |
if num_inference_steps > self.num_train_timesteps: | |
raise ValueError( | |
f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.num_train_timesteps`:" | |
f" {self.num_train_timesteps} as the unet model trained with this scheduler can only handle" | |
f" maximal {self.num_train_timesteps} timesteps." | |
) | |
self.num_inference_steps = num_inference_steps | |
start = self.to_timestep(self.sigma_max) | |
end = self.to_timestep(self.sigma_min) | |
timesteps = torch.linspace(start, end, num_inference_steps) | |
self.timesteps = torch.from_numpy(np.array(timesteps)).to(device) | |
def append_dims(self, x, target_dims): | |
"""Appends dimensions to the end of a tensor until it has target_dims dimensions.""" | |
dims_to_append = target_dims - x.ndim | |
return x[(...,) + (None,) * dims_to_append] | |
def to_d(self, x, sigma, denoised): | |
"""Converts a denoiser output to a Karras ODE derivative.""" | |
return (x - denoised) / self.append_dims(sigma, x.ndim) | |
def step(self, model_output, timestep, sample, method="euler", **kwargs): | |
""" | |
Args: | |
model_output (`torch.Tensor`): | |
The direct output from learned diffusion model, direction (noise - x_0). | |
timestep (`float`): | |
The current discrete timestep in the diffusion chain. | |
sample (`torch.Tensor`): | |
A current instance of a sample created by the diffusion process, x_t. | |
method (`str`): | |
ODE solver, `euler` or `dpmpp_2m` | |
Returns: | |
`tuple`: | |
the sample tensor. | |
""" | |
if self.num_inference_steps is None: | |
raise ValueError( | |
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" | |
) | |
sigma = self.to_sigma(timestep) | |
prev_sigma = sigma - (self.sigma_max - self.sigma_min) / (self.num_inference_steps - 1) | |
prev_sigma = 0.0 if prev_sigma < 0.0 else prev_sigma | |
if method == "euler": | |
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022).""" | |
dt = prev_sigma - sigma | |
prev_sample = sample + model_output * dt | |
elif method == "dpmpp_2m": | |
"""DPM-Solver++(2M).""" | |
raise NotImplementedError | |
else: | |
raise ValueError(f"Unsupported ode solver: {method}, only supports `euler` or `dpmpp_2m`") | |
pred_original_sample = sample - model_output * sigma | |
return ( | |
prev_sample, | |
pred_original_sample | |
) | |
def get_pred_original_sample(self, model_output, timestep, sample): | |
sigma = self.to_sigma(timestep).view(-1, 1, 1) | |
pred_original_sample = sample - model_output * sigma | |
return pred_original_sample |