import math import torch import torch.nn.functional as F from torch import nn from einops import reduce from tqdm.auto import tqdm from functools import partial from .transformer import Transformer from ..model_utils import default, identity, extract from .control import * import mlflow.pyfunc import mlflow from mlflow.models import infer_signature # import matplotlib.pyplot as plt # images_cache = [] def linear_beta_schedule(timesteps): scale = 1000 / timesteps beta_start = scale * 0.0001 beta_end = scale * 0.02 return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64) def cosine_beta_schedule(timesteps, s=0.008): """ cosine schedule as proposed in https://openreview.net/forum?id=-NEXDKk8gZ """ steps = timesteps + 1 x = torch.linspace(0, timesteps, steps, dtype=torch.float64) alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2 alphas_cumprod = alphas_cumprod / alphas_cumprod[0] betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) return torch.clip(betas, 0, 0.999) class Tiffusion(nn.Module): def __init__( self, seq_length, feature_size, n_layer_enc=3, n_layer_dec=6, d_model=None, timesteps=1000, sampling_timesteps=None, loss_type="l1", beta_schedule="cosine", n_heads=4, mlp_hidden_times=4, eta=0.0, attn_pd=0.0, resid_pd=0.0, kernel_size=None, padding_size=None, use_ff=True, reg_weight=None, control_signal={}, moving_average=False, **kwargs, ): super(Tiffusion, self).__init__() self.eta, self.use_ff = eta, use_ff self.seq_length = seq_length self.feature_size = feature_size self.ff_weight = default(reg_weight, math.sqrt(self.seq_length) / 5) self.sum_weight = default(reg_weight, math.sqrt(self.seq_length // 10) / 50) self.training_control_signal = control_signal # training control signal self.moving_average = moving_average self.model: Transformer = Transformer( n_feat=feature_size, n_channel=seq_length, n_layer_enc=n_layer_enc, n_layer_dec=n_layer_dec, n_heads=n_heads, attn_pdrop=attn_pd, resid_pdrop=resid_pd, mlp_hidden_times=mlp_hidden_times, max_len=seq_length, n_embd=d_model, conv_params=[kernel_size, padding_size], **kwargs, ) if beta_schedule == "linear": betas = linear_beta_schedule(timesteps) elif beta_schedule == "cosine": betas = cosine_beta_schedule(timesteps) else: raise ValueError(f"unknown beta schedule {beta_schedule}") alphas = 1.0 - betas alphas_cumprod = torch.cumprod(alphas, dim=0) alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0) (timesteps,) = betas.shape self.num_timesteps = int(timesteps) self.loss_type = loss_type # sampling related parameters self.sampling_timesteps = default( sampling_timesteps, timesteps ) # default num sampling timesteps to number of timesteps at training assert self.sampling_timesteps <= timesteps self.fast_sampling = self.sampling_timesteps < timesteps # helper function to register buffer from float64 to float32 register_buffer = lambda name, val: self.register_buffer( name, val.to(torch.float32) ) register_buffer("betas", betas) register_buffer("alphas_cumprod", alphas_cumprod) register_buffer("alphas_cumprod_prev", alphas_cumprod_prev) # calculations for diffusion q(x_t | x_{t-1}) and others register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod)) register_buffer( "sqrt_one_minus_alphas_cumprod", torch.sqrt(1.0 - alphas_cumprod) ) register_buffer("log_one_minus_alphas_cumprod", torch.log(1.0 - alphas_cumprod)) register_buffer("sqrt_recip_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod)) register_buffer( "sqrt_recipm1_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod - 1) ) # calculations for posterior q(x_{t-1} | x_t, x_0) posterior_variance = ( betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod) ) # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) register_buffer("posterior_variance", posterior_variance) # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain register_buffer( "posterior_log_variance_clipped", torch.log(posterior_variance.clamp(min=1e-20)), ) register_buffer( "posterior_mean_coef1", betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod), ) register_buffer( "posterior_mean_coef2", (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod), ) # calculate reweighting register_buffer( "loss_weight", torch.sqrt(alphas) * torch.sqrt(1.0 - alphas_cumprod) / betas / 100, ) def predict_noise_from_start(self, x_t, t, x0): return ( extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0 ) / extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) def predict_start_from_noise(self, x_t, t, noise): return ( extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise ) def q_posterior(self, x_start, x_t, t): posterior_mean = ( extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t ) posterior_variance = extract(self.posterior_variance, t, x_t.shape) posterior_log_variance_clipped = extract( self.posterior_log_variance_clipped, t, x_t.shape ) return posterior_mean, posterior_variance, posterior_log_variance_clipped def output(self, x, t, padding_masks=None, control_signal=None): # if ss:=control_signal.get("sum") is not None and len(ss.shape) == 1: # bs = x.shape[0] # control_signal["sum"] = ss.unsqueeze(0).repeat(bs, 1) # print("control_signal", control_signal) trend, season = self.model( x, t, padding_masks=padding_masks, control_signal=control_signal ) model_output = trend + season return model_output def model_predictions( self, x, t, clip_x_start=False, padding_masks=None, control_signal=None ): if padding_masks is None: padding_masks = torch.ones( x.shape[0], self.seq_length, dtype=bool, device=x.device ) maybe_clip = ( partial(torch.clamp, min=-1.0, max=1.0) if clip_x_start else identity ) x_start = self.output(x, t, padding_masks, control_signal=control_signal) x_start = maybe_clip(x_start) pred_noise = self.predict_noise_from_start(x, t, x_start) return pred_noise, x_start def p_mean_variance(self, x, t, clip_denoised=True, control_signal=None): _, x_start = self.model_predictions(x, t, control_signal=control_signal) if clip_denoised: x_start.clamp_(-1.0, 1.0) model_mean, posterior_variance, posterior_log_variance = self.q_posterior( x_start=x_start, x_t=x, t=t ) return model_mean, posterior_variance, posterior_log_variance, x_start def p_sample(self, x, t: int, clip_denoised=True, control_signal=None): batched_times = torch.full((x.shape[0],), t, device=x.device, dtype=torch.long) model_mean, _, model_log_variance, x_start = self.p_mean_variance( x=x, t=batched_times, clip_denoised=clip_denoised, control_signal=control_signal ) noise = torch.randn_like(x) if t > 0 else 0.0 # no noise if t == 0 pred_img = model_mean + (0.5 * model_log_variance).exp() * noise return pred_img, x_start @torch.no_grad() def sample(self, shape, control_signal=None): device = self.betas.device img = torch.randn(shape, device=device) for t in tqdm( reversed(range(0, self.num_timesteps)), desc="sampling loop time step", total=self.num_timesteps, ): img, _ = self.p_sample(img, t, control_signal=control_signal) return img @torch.no_grad() def fast_sample(self, shape, clip_denoised=True, model_kwargs=None, ): batch, device, total_timesteps, sampling_timesteps, eta = ( shape[0], self.betas.device, self.num_timesteps, self.sampling_timesteps, self.eta, ) # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps times = torch.linspace(-1, total_timesteps - 1, steps=sampling_timesteps + 1) times = list(reversed(times.int().tolist())) time_pairs = list( zip(times[:-1], times[1:]) ) # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)] img = torch.randn(shape, device=device) for time, time_next in tqdm(time_pairs, desc="sampling loop time step"): time_cond = torch.full((batch,), time, device=device, dtype=torch.long) pred_noise, x_start, *_ = self.model_predictions( img, time_cond, clip_x_start=clip_denoised, control_signal=model_kwargs.get("model_control_signal", {}) if model_kwargs else {} ) if time_next < 0: img = x_start continue alpha = self.alphas_cumprod[time] alpha_next = self.alphas_cumprod[time_next] sigma = ( eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt() ) c = (1 - alpha_next - sigma**2).sqrt() noise = torch.randn_like(img) img = x_start * alpha_next.sqrt() + c * pred_noise + sigma * noise return img def generate_mts(self, batch_size=16): feature_size, seq_length = self.feature_size, self.seq_length sample_fn = self.fast_sample if self.fast_sampling else self.sample return sample_fn((batch_size, seq_length, feature_size)) def generate_mts_infill(self, target, partial_mask=None, clip_denoised=True, model_kwargs=None): sample_fn = self.fast_sample_infill_float_mask # if self.fast_sampling else self.sample_infill print("model_kwargs", model_kwargs) print("partial_mask", partial_mask.shape) print("target", target.shape) return sample_fn( shape=target.shape, target=target, sampling_timesteps=self.sampling_timesteps, partial_mask=partial_mask, clip_denoised=clip_denoised, model_kwargs=model_kwargs ) @property def loss_fn(self): if self.loss_type == "l1": return F.l1_loss elif self.loss_type == "l2": return F.mse_loss else: raise ValueError(f"invalid loss type {self.loss_type}") def q_sample(self, x_start, t, noise=None): noise = default(noise, lambda: torch.randn_like(x_start)) return ( extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise ) @torch.no_grad() def calculate_dynamic_window(self, t: torch.Tensor) -> torch.Tensor: # Batch-wise time point normalization t_min = 0 # t.min() t_max = 500 # t.max() # t_normalized = (t - t_min) / (t_max - t_min) # Compute window sizes # windows = ((t_normalized.exp2() - 1) * 15 // 1 + 1).long() # plt.scatter(t, ( (5 ** ((t - 0) / 1000))) * 15 // 7 + 1) windows = ((5 ** ( t / 500)) * 15 // 5 - 2).long() return windows @torch.no_grad() def torch_moving_average(self, bs: torch.Tensor, t: torch.Tensor) -> torch.Tensor: """ Compute moving average for a time series tensor with dynamically calculated window sizes for each sample. Parameters: ----------- bs : torch.Tensor Input time series tensor of shape (batch_size, sequence_length, features) t : torch.Tensor Time points tensor of shape (batch_size, sequence_length) Returns: -------- torch.Tensor Moving average tensor with the same shape as input """ # Get tensor dimensions batch_size, total_seq_length, num_features = bs.shape # Calculate dynamic window sizes for each sample windows = self.calculate_dynamic_window(t) # Create output tensor initialized with the original values moving_avg = bs.clone() # Compute moving average for each sample and time point for b in range(batch_size): for i in range(total_seq_length): # Get the window size for this sample and time point current_window = windows[b].item() # Determine the start and end of the window start = max(0, i - current_window + 1) window = bs[b:b+1, start:i+1, :] # Compute average along the time dimension window_avg = window.mean(dim=1) # Replace values where we have enough previous steps if i >= current_window - 1: moving_avg[b, i, :] = window_avg return moving_avg def _train_loss( self, x_start, t, target=None, noise=None, padding_masks=None, control_signal=None, ): noise = default(noise, lambda: torch.randn_like(x_start)) if target is None: target = x_start x = self.q_sample(x_start=x_start, t=t, noise=noise) # noise sample # with torch.no_grad(): # if control_signal is None: # control_signal = { # "sum": target.mean(1), # "top-peak-position": target.topk(self.seq_length // 20, dim=1)[1], # } # .unsqueeze(-1) # # elif self.control_sum: # # ss = control_signal.get("sum") # # if len(ss.shape) == 1: # # bs = x.shape[0] # # control_signal["sum"] = ss.unsqueeze(0).repeat(bs, 1) # # control_signal = control_signal # else: # control_signal = {} model_out = self.output(x, t, padding_masks, control_signal=control_signal) # moving average according to the timestamp t, t larger means more stable, less noise if self.moving_average: target = self.torch_moving_average(target.cpu(), t.cpu()).to(model_out.device) train_loss = self.loss_fn(model_out, target, reduction="none") fourier_loss = torch.tensor([0.0]) if self.use_ff: fft1 = torch.fft.fft(model_out.transpose(1, 2), norm="forward") fft2 = torch.fft.fft(target.transpose(1, 2), norm="forward") fft1, fft2 = fft1.transpose(1, 2), fft2.transpose(1, 2) fourier_loss = self.loss_fn( torch.real(fft1), torch.real(fft2), reduction="none" ) + self.loss_fn(torch.imag(fft1), torch.imag(fft2), reduction="none") train_loss += self.ff_weight * fourier_loss # if self.control_sum: # train_loss += ( # self.loss_fn(model_out[..., 0].sum(1), target[..., 0].sum(1)) # / self.seq_length # ) # * self.sum_weight train_loss = reduce(train_loss, "b ... -> b (...)", "mean") train_loss = train_loss * extract(self.loss_weight, t, train_loss.shape) return train_loss.mean() # fmt: off def forward(self, x, **kwargs): b, c, n, device, feature_size, = *x.shape, x.device, self.feature_size assert n == feature_size, f'number of variable must be {feature_size}' t = torch.randint(0, self.num_timesteps, (b,), device=device).long() return self._train_loss(x_start=x, t=t, **kwargs) def return_components(self, x, t: int): b, c, n, device, feature_size, = *x.shape, x.device, self.feature_size assert n == feature_size, f'number of variable must be {feature_size}' t = torch.tensor([t]) t = t.repeat(b).to(device) x = self.q_sample(x, t) trend, season, residual = self.model(x, t, return_res=True) return trend, season, residual, x # fmt: on def fast_sample_infill_float_mask( self, shape, target: torch.Tensor, # target time series # [B, L, C] sampling_timesteps, partial_mask: torch.Tensor = None, # float mask between 0 and 1 # [B, L, C] clip_denoised=True, model_kwargs=None, ): batch, device, total_timesteps, eta = ( shape[0], self.betas.device, self.num_timesteps, self.eta, ) # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps times = torch.linspace(-1, total_timesteps - 1, steps=sampling_timesteps + 1) times = list(reversed(times.int().tolist())) time_pairs = list( zip(times[:-1], times[1:]) ) # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)] # Initialize with noise img = torch.randn(shape, device=device) # [B, L, C] for time, time_next in tqdm( time_pairs, desc="conditional sampling loop time step" ): time_cond = torch.full((batch,), time, device=device, dtype=torch.long) pred_noise, x_start, *_ = self.model_predictions( img, time_cond, clip_x_start=clip_denoised, control_signal=model_kwargs.get("model_control_signal", {}), ) if time_next < 0: img = x_start continue # Compute the predicted mean alpha = self.alphas_cumprod[time] alpha_next = self.alphas_cumprod[time_next] sigma = ( eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt() ) c = (1 - alpha_next - sigma**2).sqrt() noise = torch.randn_like(img) pred_mean = x_start * alpha_next.sqrt() + c * pred_noise img = pred_mean + sigma * noise # # Apply partial mask to the current sample # if partial_mask is not None: # target_t = self.q_sample(target, t=time_cond) # img = img * (1.0 - partial_mask) + target_t * partial_mask # Langevin Dynamics part for additional gradient updates img = self.langevin_fn( sample=img, mean=pred_mean, sigma=sigma, t=time_cond, tgt_embs=target, partial_mask=partial_mask, enable_float_mask=True, **model_kwargs, ) img = img * (1 - partial_mask) + target * partial_mask img = img * (1 - partial_mask) + target * partial_mask return img def fast_sample_infill( self, shape, target, sampling_timesteps, partial_mask=None, clip_denoised=True, model_kwargs=None, ): batch, device, total_timesteps, eta = ( shape[0], self.betas.device, self.num_timesteps, self.eta, ) # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps times = torch.linspace(-1, total_timesteps - 1, steps=sampling_timesteps + 1) times = list(reversed(times.int().tolist())) time_pairs = list( zip(times[:-1], times[1:]) ) # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)] img = torch.randn(shape, device=device) for time, time_next in tqdm( time_pairs, desc="conditional sampling loop time step" ): time_cond = torch.full((batch,), time, device=device, dtype=torch.long) pred_noise, x_start, *_ = self.model_predictions( img, time_cond, clip_x_start=clip_denoised, control_signal=model_kwargs.get("model_control_signal", {}), ) if time_next < 0: img = x_start continue alpha = self.alphas_cumprod[time] alpha_next = self.alphas_cumprod[time_next] sigma = ( eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt() ) c = (1 - alpha_next - sigma**2).sqrt() pred_mean = x_start * alpha_next.sqrt() + c * pred_noise noise = torch.randn_like(img) img = pred_mean + sigma * noise img = self.langevin_fn( sample=img, mean=pred_mean, sigma=sigma, t=time_cond, tgt_embs=target, partial_mask=partial_mask, # gradient_control_signal=model_kwargs.get("gradient_control_signal", {}), # model_control_signal=model_kwargs.get("model_control_signal", {}), **model_kwargs, ) target_t = self.q_sample(target, t=time_cond) img[partial_mask] = target_t[partial_mask] img[partial_mask] = target[partial_mask] return img def sample_infill( self, shape, target, partial_mask=None, clip_denoised=True, model_kwargs=None, ): """ Generate samples from the model and yield intermediate samples from each timestep of diffusion. """ batch, device = shape[0], self.betas.device img = torch.randn(shape, device=device) for t in tqdm( reversed(range(0, self.num_timesteps)), desc="conditional sampling loop time step", total=self.num_timesteps, ): img = self.p_sample_infill( x=img, t=t, clip_denoised=clip_denoised, target=target, partial_mask=partial_mask, model_kwargs=model_kwargs, ) img[partial_mask] = target[partial_mask] return img def p_sample_infill( self, x, target, t: int, partial_mask=None, clip_denoised=True, model_kwargs=None, ): b, *_, device = *x.shape, self.betas.device batched_times = torch.full((x.shape[0],), t, device=x.device, dtype=torch.long) model_mean, _, model_log_variance, _ = self.p_mean_variance( x=x, t=batched_times, clip_denoised=clip_denoised, control_signal=model_kwargs.get("model_control_signal", {}) # don't pass parameters to control signal, for model itself # Otherwise pass: control_signal=model_kwargs.get("control_signal", {}) ) noise = torch.randn_like(x) if t > 0 else 0.0 # no noise if t == 0 sigma = (0.5 * model_log_variance).exp() pred_img = model_mean + sigma * noise pred_img = self.langevin_fn( sample=pred_img, mean=model_mean, sigma=sigma, t=batched_times, tgt_embs=target, partial_mask=partial_mask, # control_signal=model_kwargs.get("gradient_control_signal", {}), **model_kwargs, ) # fix point (must passed points) target_t = self.q_sample(target, t=batched_times) pred_img[partial_mask] = target_t[partial_mask] return pred_img @staticmethod def classifier_guidance( x: torch.Tensor, t: torch.Tensor, y: torch.Tensor, classifier: torch.nn.Module ): with torch.enable_grad(): # 激活梯度计算 x_with_grad = x.detach().requires_grad_(True) # 获取 log 形式的概率分布 logits = classifier(x_with_grad, t) log_prob = F.log_softmax(logits, dim=-1) # 选取出 y 对应的项 selected = log_prob[range(len(logits)), y.view(-1)] # 计算梯度 return torch.autograd.grad(selected.sum(), x_with_grad)[0] @staticmethod def regression_guidance( x: torch.Tensor, t: torch.Tensor, target_sum: torch.Tensor, # Target sum value sigma: float = 1.0 ): """ Compute gradient for guiding the sum of first channel to match target value Args: x: Input tensor [batch_size, channels, length] or [batch_size, length, channels] t: Time steps target_sum: Target sum value [batch_size] sigma: Standard deviation for Gaussian likelihood """ # with torch.enable_grad(): # x_with_grad = x.detach().requires_grad_(True) # normalize to 0, 1 # x_with_grad = (x + x.min()) / (x.max() - x.min()) # x_with_grad = x / 2 + 0.5 # [-1,1 to 0,1] x_with_grad = x # Calculate sum of first channel/feature # Assuming x shape is [batch_size, channels, length] or [batch_size, length, channels] if x_with_grad.dim() == 3: if x_with_grad.shape[1] < x_with_grad.shape[2]: # [B, C, L] current_sum = x_with_grad[:1, 0] current_sum = current_sum / 2 + 0.5 # [-1, 1 to 0, 1] print("Current Sum: ", current_sum.max().item(), current_sum.min().item()) current_sum = current_sum.sum(dim=1) # Sum over length else: # [B, L, C] current_sum = x_with_grad[:1, :, 0] current_sum = current_sum / 2 + 0.5 # [-1, 1 to 0, 1] print("Current Sum: ", current_sum.max().item(), current_sum.min().item()) current_sum = current_sum.sum(dim=1) # Sum over length # Compute log probability under Gaussian distribution sigma = torch.log(t) / 5 print("sigma", sigma) if sigma.mean() == 0: pred_std = torch.ones_like(current_sum) else: pred_std = torch.ones_like(current_sum) * sigma log_prob = -0.5 * torch.log(2 * torch.pi * pred_std**2) - \ (target_sum - current_sum)**2 / (2 * pred_std**2) # print(target_sum, current_sum) # print("Current Sum: ", current_sum.mean().item()) # print("Current Diff: ", (target_sum - current_sum).mean().item()) return log_prob.mean() # return torch.autograd.grad(log_prob.sum(), x_with_grad)[0] def langevin_fn( self, coef, partial_mask, tgt_embs, learning_rate, sample, mean, sigma, t, coef_=0.0, gradient_control_signal={}, model_control_signal={}, **kwargs, ): # we thus run more gradient updates at large diffusion step t to guide the generation then # reduce the number of gradient steps in stages to accelerate sampling. if t[0].item() < self.num_timesteps * 0.02 : K = 0 elif t[0].item() > self.num_timesteps * 0.9: K = 3 elif t[0].item() > self.num_timesteps * 0.75: K = 2 learning_rate = learning_rate * 0.5 else: K = 1 learning_rate = learning_rate * 0.25 input_embs_param = torch.nn.Parameter(sample) # 获取时间相关的权重调整因子 time_weight = get_time_dependent_weights(t[0], self.num_timesteps) with torch.enable_grad(): for iteration in range(K): # x_i+1 = x_i + noise * grad(logp(x_i)) + sqrt(2*noise) * z_i optimizer = torch.optim.Adagrad([input_embs_param], lr=learning_rate) optimizer.zero_grad() x_start = self.output( x=input_embs_param, t=t, control_signal=model_control_signal, ) if sigma.mean() == 0: logp_term = ( coef * ((mean - input_embs_param) ** 2 / 1.0).mean(dim=0).sum() ) # determine the partical_mask is float if kwargs.get("enable_float_mask", False): infill_loss = (x_start * (partial_mask) - tgt_embs * (partial_mask)) ** 2 else: infill_loss = (x_start[partial_mask] - tgt_embs[partial_mask]) ** 2 infill_loss = infill_loss.mean(dim=0).sum() else: logp_term = ( coef * ((mean - input_embs_param) ** 2 / sigma).mean(dim=0).sum() ) if kwargs.get("enable_float_mask", False): infill_loss = (x_start * (partial_mask) - tgt_embs * (partial_mask)) ** 2 else: infill_loss = (x_start[partial_mask] - tgt_embs[partial_mask]) ** 2 infill_loss = (infill_loss / sigma.mean()).mean(dim=0).sum() # 第二个等号后面最后一项消失了,因为当我们要求模型生成“狗”的图像时,扩散过程始终 # 不变,对应的梯度也是0,可以抹掉。 # https://lichtung612.github.io/posts/3-diffusion-models/ # 第三个等号后面两项中,第一项是扩散模型本身的梯度引导,新增的只能是第二项,即classifier guidance只需要额外添加一个classifier的梯度来引导。 # 控制信号损失 gradient_scale = gradient_control_signal.get("gradient_scale", 1.0) # 全局梯度缩放因子 # Add regression guidance for sum constraint control_loss = 0 # target_sum = # normalize the sum to -1, 1 # seq_length = input_embs_param.shape[1] # target_sum = ((target_sum / seq_length ) * 2 - 1) * seq_length # if target_sum:=gradient_control_signal.get("sum") is not None: # # print("sigma", sigma.shape, sigma, end=" ") # reg_nll = self.regression_guidance( # x=input_embs_param, # t=t, # target_sum=target_sum, # sigma=sigma # ) # control_loss += - gradient_control_signal.get("reg_weight", 1.0) * reg_nll * (5 - K) # (reg_gradient * ).sum() # init control signal loss auc_sum, peak_points, bar_regions, target_freq = \ gradient_control_signal.get("auc"), gradient_control_signal.get("peak_points"), gradient_control_signal.get("bar_regions"), gradient_control_signal.get("target_freq") # 1. 原有的sum控制 if auc_sum is not None: sum_weight = gradient_control_signal.get("auc_weight", 1.0) * time_weight auc_loss = - sum_weight * sum_guidance( x=input_embs_param, t=t, target_sum=auc_sum, gradient_scale=gradient_scale, segments=gradient_control_signal.get("segments", ()) ) control_loss += auc_loss # 峰值引导 if peak_points is not None: peak_weight = gradient_control_signal.get("peak_weight", 1.0) * time_weight peak_loss = - peak_weight * peak_guidance( x=input_embs_param, t=t, peak_points=peak_points, window_size=gradient_control_signal.get("peak_window_size", 5), alpha_1=gradient_control_signal.get("peak_alpha_1", 1.2), gradient_scale=gradient_scale ) control_loss += peak_loss # 区间引导 if bar_regions is not None: bar_weight = gradient_control_signal.get("bar_weight", 1.0) * time_weight bar_loss = -bar_weight * bar_guidance( x=input_embs_param, t=t, bar_regions=bar_regions, gradient_scale=gradient_scale ) control_loss += bar_loss # 频率引导 if target_freq is not None: freq_weight = gradient_control_signal.get("freq_weight", 1.0) * time_weight freq_loss = -freq_weight * frequency_guidance( x=input_embs_param, t=t, target_freq=target_freq, freq_weight=freq_weight, gradient_scale=gradient_scale ) control_loss += freq_loss loss = logp_term + infill_loss + control_loss if iteration == 0: # Only print first iteration to avoid spam # print(f"Losses - Diffusion: {logp_term:.4f}, Infill: {infill_loss:.4f}, Control: {control_loss:.4f}") # if target_sum is not None: # # Print current sum vs target for monitoring # if x_start.shape[1] < x_start.shape[2]: # [B, C, L] # current_sum = input_embs_param[:, 0].sum(dim=1) # else: # [B, L, C] # current_sum = input_embs_param[:, :, 0].sum(dim=1) # print(f"Current sum: {current_sum.data}, Target sum: {target_sum}") # print(f"Losses - Diffusion: {logp_term:.4f}\tInfill: {infill_loss:.4f}", end="\t") # if auc_sum is not None: # print(f"Sum Control: {auc_loss.item():.4f}", end="\t") # if peak_points is not None: # print(f"Peak Control: {peak_loss.item():.4f}", end="\t") # if bar_regions is not None: # print(f"Bar Control: {bar_loss.item():.4f}", end="\t") # if target_freq is not None: # print(f"Freq Control: {freq_loss.item():.4f}", end="\t") # print() pass # loss = logp_term + infill_loss + auc_loss # print(logp_term, infill_loss, auc_loss) loss.backward() optimizer.step() torch.nn.utils.clip_grad_norm_([input_embs_param], gradient_control_signal.get("max_grad_norm", 1.0)) # add more noise epsilon = torch.randn_like(input_embs_param.data) noise_scale = coef_ * sigma.mean().item() # * 2 # noise_scale = noise_scale * time_weight # (1 - time_weight) # 随时间减少噪声 input_embs_param = torch.nn.Parameter( ( input_embs_param.data + noise_scale * epsilon ).detach() ) if kwargs.get("enable_float_mask", False): sample = sample * partial_mask + input_embs_param.data * (1 - partial_mask) else: sample[~partial_mask] = input_embs_param.data[~partial_mask] # if t[0].item() % 10 == 9: # print("Sampled Image") # images_cache.append(plt.plot(sample[0,:,0].detach().cpu().numpy())[0]) # if t[0].item() == 9: # plt.show() # images_cache.clear() # plt.show() # plt.savefig(f"sampled_{t[0].item()}.png") # plt.plot(sample[0,:,0].detach().cpu().numpy()) # plt.show() return sample # def load_weights(self, model_path): # data = torch.load(model_path, map_location="cuda:0", weights_only=True) # self.load_state_dict(data["model"]) # print("Model weights loaded successfully") def predict_weighted_points( self, observed_points: torch.Tensor, observed_mask: torch.Tensor, coef=1e-1, stepsize=1e-1, sampling_steps=50, **kargs, ): model_kwargs = {} model_kwargs["coef"] = coef model_kwargs["learning_rate"] = stepsize model_kwargs = {**model_kwargs, **kargs} assert len(observed_points.shape) == 2, "observed_points should be 2D, batch size = 1" x = observed_points.unsqueeze(0) float_mask = observed_mask.unsqueeze(0) # x != 0, 1 for observed, 0 for missing, bool tensor binary_mask = float_mask.clone() binary_mask[binary_mask > 0] = 1 x = x * 2 - 1 # normalize self.device = x.device x, float_mask, binary_mask = x.to(self.device), float_mask.to(self.device), binary_mask.to(self.device) if sampling_steps == self.num_timesteps: print("normal sampling") raise NotImplementedError sample = self.ema.ema_model.sample_infill_float_mask( shape=x.shape, target=x * binary_mask, # x * t_m, 1 for observed, 0 for missing partial_mask=float_mask, model_kwargs=model_kwargs, ) # x: partially noise : (batch_size, seq_length, feature_dim) else: print("fast sampling") sample = self.fast_sample_infill_float_mask( shape=x.shape, target=x * binary_mask, # x * t_m, 1 for observed, 0 for missing partial_mask=float_mask, model_kwargs=model_kwargs, sampling_timesteps=sampling_steps, ) # unnormalize sample = (sample + 1) / 2 return sample.squeeze(0).detach().cpu().numpy() def register_model(self, registered_model_name, model_path="tiffusion_model", conda_env=None): """Register the model with MLflow model registry. Args: registered_model_name: Name to register the model under model_path: Local path to save model artifacts conda_env: Custom conda environment for the model """ # Create basic conda env if not provided if conda_env is None: conda_env = { 'channels': ['defaults', 'conda-forge'], 'dependencies': [ 'python>=3.8', 'pytorch', 'einops', 'tqdm' ], 'name': 'tiffusion_env' } # Start an MLflow run with mlflow.start_run() as run: # Log model parameters mlflow.log_params({ "seq_length": self.seq_length, "feature_size": self.feature_size, "n_layer_enc": self.model.n_layer_enc, "n_layer_dec": self.model.n_layer_dec, "n_heads": self.model.n_heads, "timesteps": self.num_timesteps, "loss_type": self.loss_type }) # Create a custom Python model class for MLflow class TiffusionWrapper(mlflow.pyfunc.PythonModel): def __init__(self, model): self.model = model def predict(self, context, model_input): # Generate predictions using the model with torch.no_grad(): result = self.model.generate_mts(batch_size=len(model_input)) return result.numpy() # Create wrapper instance wrapped_model = TiffusionWrapper(self) # Log and register the model mlflow.pyfunc.log_model( artifact_path=model_path, python_model=wrapped_model, conda_env=conda_env, registered_model_name=registered_model_name ) print(f"Model registered as: {registered_model_name}") print(f"Run ID: {run.info.run_id}") def load_registered_model(self, registered_model_name, version=None, stage=None): """Load a registered model from MLflow model registry. Args: registered_model_name: Name of registered model version: Optional specific version to load stage: Optional stage to load (e.g. 'Production', 'Staging') """ if version: model_uri = f"models:/{registered_model_name}/{version}" elif stage: model_uri = f"models:/{registered_model_name}/{stage}" else: model_uri = f"models:/{registered_model_name}/latest" return mlflow.pyfunc.load_model(model_uri)