import math

import torch


class EnhancedDDIMScheduler:
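    """Deterministic (eta = 0) DDIM scheduler.

    Supports the "epsilon" and "v_prediction" parameterizations and, optionally,
    zero-terminal-SNR rescaling (https://arxiv.org/abs/2305.08891). Note that a
    rescaled schedule ends with alphas_cumprod = 0, so it is typically paired
    with v-prediction: the epsilon update below divides by alpha_prod_t.
    """
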
    def __init__(
        self,
        num_train_timesteps=1000,
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        prediction_type="epsilon",
        rescale_zero_terminal_snr=False,
    ):
        self.num_train_timesteps = num_train_timesteps
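        # "scaled_linear" interpolates linearly in sqrt(beta) space (the
        # Stable-Diffusion-style schedule); "linear" is the original DDPM one.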
        if beta_schedule == "scaled_linear":
            betas = torch.square(
                torch.linspace(
                    math.sqrt(beta_start),
                    math.sqrt(beta_end),
                    num_train_timesteps,
                    dtype=torch.float32,
                )
            )
        elif beta_schedule == "linear":
            betas = torch.linspace(
                beta_start, beta_end, num_train_timesteps, dtype=torch.float32
            )
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented")
        self.alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
        if rescale_zero_terminal_snr:
            self.alphas_cumprod = self.rescale_zero_terminal_snr(self.alphas_cumprod)
        # Store as a plain Python list: step() indexes it with integer
        # timesteps and the entries are consumed by math.sqrt as scalars.
        self.alphas_cumprod = self.alphas_cumprod.tolist()
        self.set_timesteps(10)
        self.prediction_type = prediction_type

    def rescale_zero_terminal_snr(self, alphas_cumprod):
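        # Shift and rescale sqrt(alphas_cumprod) so the terminal value is
        # exactly zero (zero SNR at t = T), per https://arxiv.org/abs/2305.08891.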
        alphas_bar_sqrt = alphas_cumprod.sqrt()

        # Store old values.
        alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
        alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()

        # Shift so the last timestep is zero.
        alphas_bar_sqrt -= alphas_bar_sqrt_T

        # Scale so the first timestep is back to the old value.
        alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

        # Square to recover alphas_bar; its terminal value is now exactly zero.
        alphas_bar = alphas_bar_sqrt.square()

        return alphas_bar

    def set_timesteps(self, num_inference_steps, denoising_strength=1.0, **kwargs):
        # Timesteps are spaced evenly from num_train_timesteps - 1 down to 0
        # inclusive. Many implementations never reach 0; ending exactly at 0 is
        # arguably the more faithful discretization of the reverse process.
        # denoising_strength < 1.0 starts from a proportionally earlier
        # timestep, as in image-to-image pipelines.
        max_timestep = max(round(self.num_train_timesteps * denoising_strength) - 1, 0)
        # Never schedule more steps than there are available timesteps.
        num_inference_steps = min(num_inference_steps, max_timestep + 1)
        if num_inference_steps == 1:
            self.timesteps = torch.Tensor([max_timestep])
        else:
            step_length = max_timestep / (num_inference_steps - 1)
            self.timesteps = torch.Tensor(
                [
                    round(max_timestep - i * step_length)
                    for i in range(num_inference_steps)
                ]
            )

    def denoise(self, model_output, sample, alpha_prod_t, alpha_prod_t_prev):
        # Deterministic DDIM update (eta = 0): estimate x_0 from the model
        # output, then re-noise it to the previous timestep. The weights are
        # the closed forms of sqrt(abar_prev) * x_0 + sqrt(1 - abar_prev) * eps.
        if self.prediction_type == "epsilon":
            weight_e = math.sqrt(1 - alpha_prod_t_prev) - math.sqrt(
                alpha_prod_t_prev * (1 - alpha_prod_t) / alpha_prod_t
            )
            weight_x = math.sqrt(alpha_prod_t_prev / alpha_prod_t)
            prev_sample = sample * weight_x + model_output * weight_e
        elif self.prediction_type == "v_prediction":
            # Derived by substituting x_0 = sqrt(abar_t) * x_t - sqrt(1 - abar_t) * v
            # and eps = sqrt(abar_t) * v + sqrt(1 - abar_t) * x_t into the update.
            weight_e = -math.sqrt(alpha_prod_t_prev * (1 - alpha_prod_t)) + math.sqrt(
                alpha_prod_t * (1 - alpha_prod_t_prev)
            )
            weight_x = math.sqrt(alpha_prod_t * alpha_prod_t_prev) + math.sqrt(
                (1 - alpha_prod_t) * (1 - alpha_prod_t_prev)
            )
            prev_sample = sample * weight_x + model_output * weight_e
        else:
            raise NotImplementedError(f"{self.prediction_type} is not implemented")
        return prev_sample

    def step(self, model_output, timestep, sample, to_final=False):
        # Normalize the timestep to a CPU tensor up front so that CUDA tensors
        # and plain Python ints are both accepted.
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.cpu()
        else:
            timestep = torch.tensor(timestep)
        alpha_prod_t = self.alphas_cumprod[int(timestep.flatten()[0])]
        timestep_id = torch.argmin((self.timesteps - timestep).abs())
        if to_final or timestep_id + 1 >= len(self.timesteps):
            # Last scheduled step (or a forced final step): target the clean
            # sample directly, i.e. alphas_cumprod = 1 beyond the schedule.
            alpha_prod_t_prev = 1.0
        else:
            timestep_prev = int(self.timesteps[timestep_id + 1])
            alpha_prod_t_prev = self.alphas_cumprod[timestep_prev]

        return self.denoise(model_output, sample, alpha_prod_t, alpha_prod_t_prev)

    def return_to_timestep(self, timestep, sample, sample_stablized):
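        # Invert add_noise: given the noisy sample and a "stabilized" clean
        # estimate, recover the noise prediction that maps between them.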
        alpha_prod_t = self.alphas_cumprod[int(timestep.flatten().tolist()[0])]
        noise_pred = (sample - math.sqrt(alpha_prod_t) * sample_stablized) / math.sqrt(
            1 - alpha_prod_t
        )
        return noise_pred

    def add_noise(self, original_samples, noise, timestep):
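        # Forward diffusion: x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * noise.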
        sqrt_alpha_prod = math.sqrt(
            self.alphas_cumprod[int(timestep.flatten().tolist()[0])]
        )
        sqrt_one_minus_alpha_prod = math.sqrt(
            1 - self.alphas_cumprod[int(timestep.flatten().tolist()[0])]
        )
        noisy_samples = (
            sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        )
        return noisy_samples

    def training_target(self, sample, noise, timestep):
        if self.prediction_type == "epsilon":
            return noise
        else:
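            # v-prediction target: v = sqrt(abar_t) * eps - sqrt(1 - abar_t) * x_0.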
            sqrt_alpha_prod = math.sqrt(
                self.alphas_cumprod[int(timestep.flatten().tolist()[0])]
            )
            sqrt_one_minus_alpha_prod = math.sqrt(
                1 - self.alphas_cumprod[int(timestep.flatten().tolist()[0])]
            )
            target = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
            return target

    def training_weight(self, timestep):
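        # Uniform loss weighting across timesteps.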
        return 1.0
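

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original module): a
    # stand-in "denoiser" that predicts random noise, just to demonstrate the
    # intended call order, set_timesteps followed by step over
    # scheduler.timesteps. The latent shape and dummy model are assumptions
    # made for this demo only.
    scheduler = EnhancedDDIMScheduler(prediction_type="epsilon")
    scheduler.set_timesteps(num_inference_steps=20)
    sample = torch.randn(1, 4, 8, 8)  # stand-in latent tensor
    for timestep in scheduler.timesteps:
        model_output = torch.randn_like(sample)  # placeholder for a real model
        sample = scheduler.step(model_output, timestep, sample)
    print(sample.shape)  # torch.Size([1, 4, 8, 8])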