import math
from functools import partial

import numpy as np
import torch
import torch.nn.functional as F
from einops import reduce
from torch import nn
from tqdm.auto import tqdm

from ..model_utils import default, identity, extract
from .control import *
from .diff_csdi import diff_CSDI
from .csdi import CSDI_base


def linear_beta_schedule(timesteps):
    """Return a linearly spaced beta schedule as a float64 tensor.

    The end points (1e-4 and 2e-2) are scaled by ``1000 / timesteps``.
    """
    scale = 1000 / timesteps
    beta_start = scale * 0.0001
    beta_end = scale * 0.02
    return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)


def cosine_beta_schedule(timesteps, s=0.008):
    """Return the cosine beta schedule as a float64 tensor.

    Cosine schedule as proposed in
    https://openreview.net/forum?id=-NEXDKk8gZ

    The betas are clipped to ``[0, 0.999]``.
    """
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps, dtype=torch.float64)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0, 0.999)


class Tiffusion(nn.Module):
    """Diffusion model that wraps a ``diff_CSDI`` denoiser.

    Two independent noise schedules live on the instance:

    * a CSDI-style schedule (``beta`` / ``alpha_hat`` / ``alpha`` numpy arrays
      and the ``alpha_torch`` buffer, ``num_steps`` steps) that is used by
      ``calc_loss`` and ``impute``;
    * a schedule selected by ``beta_schedule`` / ``timesteps`` whose derived
      quantities are registered as buffers (``betas``, ``alphas_cumprod``, ...)
      and used by the ``predict_*`` / ``q_posterior`` helpers and
      ``langevin_fn``.

    Tensors handed to ``get_side_info``, ``calc_loss`` and ``impute`` are
    unpacked as ``(B, K, L)`` where ``K`` must equal ``feature_size``.

    NOTE(review): ``generate_mts`` refers to ``self.sample`` /
    ``self.fast_sample`` and ``evaluate`` refers to ``self.process_data``;
    none of these are defined in the part of the file visible here.
    """

    # Hyper-parameters of the embedded CSDI denoiser and its noise schedule.
    _CSDI_MODEL_CONFIG = {
        "timeemb": 128,
        "featureemb": 16,
        "is_unconditional": False,
        "target_strategy": "mix",
    }
    _CSDI_DIFFUSION_CONFIG = {
        "layers": 3,
        "channels": 64,
        "nheads": 8,
        "diffusion_embedding_dim": 128,
        "is_linear": False,
        "beta_start": 0.0001,
        "beta_end": 0.5,
        "schedule": "quad",
        "num_steps": 50,
    }

    def __init__(
        self,
        seq_length,
        feature_size,
        n_layer_enc=3,
        n_layer_dec=6,
        d_model=None,
        timesteps=1000,
        sampling_timesteps=None,
        loss_type="l1",
        beta_schedule="cosine",
        n_heads=4,
        mlp_hidden_times=4,
        eta=0.0,
        attn_pd=0.0,
        resid_pd=0.0,
        kernel_size=None,
        padding_size=None,
        use_ff=True,
        reg_weight=None,
        control_signal=None,
        moving_average=False,
        is_unconditional=False,
        target_strategy="mix",
        **kwargs,
    ):
        """Build the denoiser and pre-compute both noise schedules.

        Args:
            seq_length: Sequence length; stored and used for the default
                regularisation weights.
            feature_size: Number of features; used as ``target_dim`` for the
                feature embedding.
            timesteps: Number of steps of the buffer-registered schedule.
            sampling_timesteps: Number of sampling steps; defaults to
                ``timesteps`` and must not exceed it.
            loss_type: Stored as ``self.loss_type``.
            beta_schedule: ``"linear"`` or ``"cosine"``.
            eta, use_ff, moving_average: Stored on the instance.
            reg_weight: Overrides both ``ff_weight`` and ``sum_weight`` when
                given.
            control_signal: Stored as ``self.training_control_signal``; a
                fresh empty dict is used when omitted.
            is_unconditional, target_strategy: NOTE: currently ignored; the
                values from the hard-coded CSDI configuration
                (``False`` / ``"mix"``) are always used.
            n_layer_enc, n_layer_dec, d_model, n_heads, mlp_hidden_times,
            attn_pd, resid_pd, kernel_size, padding_size, **kwargs: Accepted
                but not used by this constructor.

        Raises:
            ValueError: If ``beta_schedule`` is not a known schedule name.
        """
        super().__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.eta, self.use_ff = eta, use_ff
        self.seq_length = seq_length
        self.feature_size = feature_size
        self.ff_weight = default(reg_weight, math.sqrt(self.seq_length) / 5)
        self.sum_weight = default(reg_weight, math.sqrt(self.seq_length // 10) / 50)
        # A fresh dict per instance avoids sharing one mutable default object.
        self.training_control_signal = {} if control_signal is None else control_signal
        self.moving_average = moving_average

        config_model = dict(self._CSDI_MODEL_CONFIG)
        config_diff = dict(self._CSDI_DIFFUSION_CONFIG)

        self.emb_time_dim = config_model["timeemb"]
        self.emb_feature_dim = config_model["featureemb"]
        # The hard-coded configuration wins over the constructor arguments.
        self.is_unconditional = config_model["is_unconditional"]
        self.target_strategy = config_model["target_strategy"]

        # Parameters of the CSDI noise schedule.
        self.num_steps = config_diff["num_steps"]
        if config_diff["schedule"] == "quad":
            self.beta = np.linspace(
                config_diff["beta_start"] ** 0.5,
                config_diff["beta_end"] ** 0.5,
                self.num_steps,
            ) ** 2
        elif config_diff["schedule"] == "linear":
            self.beta = np.linspace(
                config_diff["beta_start"], config_diff["beta_end"], self.num_steps
            )
        else:
            raise ValueError(f"unknown CSDI schedule {config_diff['schedule']}")

        self.alpha_hat = 1 - self.beta
        self.alpha = np.cumprod(self.alpha_hat)
        # Registered as a buffer so it follows ``.to()`` / ``.cuda()``;
        # non-persistent so the keys of ``state_dict()`` stay unchanged.
        self.register_buffer(
            "alpha_torch",
            torch.tensor(self.alpha).float().unsqueeze(1).unsqueeze(1),
            persistent=False,
        )

        self.emb_total_dim = self.emb_time_dim + self.emb_feature_dim
        if not self.is_unconditional:
            self.emb_total_dim += 1  # for conditional mask
        self.target_dim = feature_size
        self.embed_layer = nn.Embedding(
            num_embeddings=self.target_dim, embedding_dim=self.emb_feature_dim
        )
        self.diffmodel = diff_CSDI(
            {**config_diff, "side_dim": self.emb_total_dim},
            1 if self.is_unconditional else 2,
        )

        if beta_schedule == "linear":
            betas = linear_beta_schedule(timesteps)
        elif beta_schedule == "cosine":
            betas = cosine_beta_schedule(timesteps)
        else:
            raise ValueError(f"unknown beta schedule {beta_schedule}")

        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)

        (timesteps,) = betas.shape
        self.num_timesteps = int(timesteps)
        self.loss_type = loss_type

        # Sampling related parameters: the number of sampling steps defaults
        # to the number of training steps.
        self.sampling_timesteps = default(sampling_timesteps, timesteps)
        assert self.sampling_timesteps <= timesteps
        self.fast_sampling = self.sampling_timesteps < timesteps

        # Helper that registers a buffer after casting float64 to float32.
        register_buffer = lambda name, val: self.register_buffer(
            name, val.to(torch.float32)
        )

        register_buffer("betas", betas)
        register_buffer("alphas_cumprod", alphas_cumprod)
        register_buffer("alphas_cumprod_prev", alphas_cumprod_prev)

        # Calculations for diffusion q(x_t | x_{t-1}) and others.
        register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod))
        register_buffer(
            "sqrt_one_minus_alphas_cumprod", torch.sqrt(1.0 - alphas_cumprod)
        )
        register_buffer("log_one_minus_alphas_cumprod", torch.log(1.0 - alphas_cumprod))
        register_buffer("sqrt_recip_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod))
        register_buffer(
            "sqrt_recipm1_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod - 1)
        )

        # Calculations for posterior q(x_{t-1} | x_t, x_0).
        # Equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t).
        posterior_variance = (
            betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        )
        register_buffer("posterior_variance", posterior_variance)

        # Log calculation clipped because the posterior variance is 0 at the
        # beginning of the diffusion chain.
        register_buffer(
            "posterior_log_variance_clipped",
            torch.log(posterior_variance.clamp(min=1e-20)),
        )
        register_buffer(
            "posterior_mean_coef1",
            betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod),
        )
        register_buffer(
            "posterior_mean_coef2",
            (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod),
        )

        # Loss re-weighting.
        register_buffer(
            "loss_weight",
            torch.sqrt(alphas) * torch.sqrt(1.0 - alphas_cumprod) / betas / 100,
        )

    def predict_noise_from_start(self, x_t, t, x0):
        """Recover the noise from ``x_t`` and ``x0`` at step ``t``."""
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0
        ) / extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)

    def predict_start_from_noise(self, x_t, t, noise):
        """Recover ``x0`` from ``x_t`` and the noise at step ``t``."""
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def q_posterior(self, x_start, x_t, t):
        """Return mean, variance and clipped log-variance of q(x_{t-1} | x_t, x_0)."""
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start
            + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(
            self.posterior_log_variance_clipped, t, x_t.shape
        )
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def output(self, x, t, padding_masks=None, control_signal=None):
        """Run the denoiser on ``x`` at step ``t`` and return its first output.

        ``control_signal`` is accepted but not used.

        NOTE(review): ``padding_masks`` is forwarded to ``get_side_info`` as
        the conditioning mask, so the default ``None`` fails there. ``x`` is
        passed to the denoiser unchanged while the time axis is built from
        ``x.shape[1]``; confirm the expected layout of ``x`` before relying
        on this method.
        """
        if isinstance(t, int):
            t = torch.tensor([t]).to(x.device)

        # Prepare side info
        observed_tp = torch.arange(x.shape[1], device=x.device).float()
        observed_tp = observed_tp.unsqueeze(0).expand(x.shape[0], -1)
        side_info = self.get_side_info(observed_tp, padding_masks)

        # Get model prediction
        predicted, _ = self.diffmodel(x, side_info, t)
        return predicted

    def generate_mts(self, batch_size=16):
        """Sample ``batch_size`` series of shape ``(seq_length, feature_size)``.

        NOTE(review): dispatches to ``self.fast_sample`` or ``self.sample``,
        neither of which is defined in the visible part of this file.
        """
        feature_size, seq_length = self.feature_size, self.seq_length
        sample_fn = self.fast_sample if self.fast_sampling else self.sample
        return sample_fn((batch_size, seq_length, feature_size))

    def generate_mts_infill(self, target, partial_mask=None, clip_denoised=True, model_kwargs=None):
        """Impute the unobserved entries of ``target`` with one sample.

        Args:
            target: Observed data, unpacked downstream as ``(B, K, L)``.
            partial_mask: Conditioning mask with the same shape as ``target``
                (non-zero where observed). Required despite the ``None``
                default, because ``get_side_info`` reads its shape.
            clip_denoised, model_kwargs: Accepted but not used.

        Returns:
            Tensor with the same shape as ``target``.
        """
        with torch.no_grad():
            # The time axis is the last dimension; get_side_info requires the
            # number of time points to equal it.
            observed_tp = torch.arange(target.shape[2], device=target.device).float()
            observed_tp = observed_tp.unsqueeze(0).expand(target.shape[0], -1)

            # Generate side info
            side_info = self.get_side_info(observed_tp, partial_mask)

            # Sample using CSDI imputation
            samples = self.impute(
                observed_data=target,
                cond_mask=partial_mask,
                side_info=side_info,
                n_samples=1,
            )
        return samples.squeeze(1)

    def langevin_fn(
        self,
        coef,
        partial_mask,
        tgt_embs,
        learning_rate,
        sample,
        mean,
        sigma,
        t,
        coef_=0.0,
        gradient_control_signal=None,
        model_control_signal=None,
        side_info=None,
        **kwargs,
    ):
        """Refine ``sample`` with a few guided gradient steps.

        More gradient updates are run at large diffusion steps to guide the
        generation; the number of updates (and the learning rate) is reduced
        in stages for smaller steps to accelerate sampling.

        Args:
            coef: Weight of the term that keeps the sample close to ``mean``.
            partial_mask: Mask of observed entries; a float mask when
                ``enable_float_mask=True`` is passed, otherwise it is used
                for boolean indexing.
            tgt_embs: Target values for the observed entries.
            learning_rate: Base Adagrad learning rate.
            sample: Current sample; it is wrapped in a ``Parameter`` and, in
                the boolean-mask case, written to in place.
            mean, sigma: Mean and noise scale of the current sampling step.
            t: Tensor of diffusion steps; only ``t[0]`` decides the number of
                updates.
            coef_: Scale of the noise injected after every update.
            gradient_control_signal: Optional dict of guidance targets
                (``auc``, ``peak_points``, ``bar_regions``, ``target_freq``)
                and their weights / options.
            model_control_signal: Accepted but not used.
            side_info: Side information forwarded to the denoiser.
            **kwargs: Only ``enable_float_mask`` is read.

        Returns:
            The refined sample.
        """
        if gradient_control_signal is None:
            gradient_control_signal = {}
        enable_float_mask = kwargs.get("enable_float_mask", False)

        step = t[0].item()
        if step < self.num_timesteps * 0.02:
            K = 0
        elif step > self.num_timesteps * 0.9:
            K = 3
        elif step > self.num_timesteps * 0.75:
            K = 2
            learning_rate = learning_rate * 0.5
        else:
            K = 1
            learning_rate = learning_rate * 0.25

        input_embs_param = torch.nn.Parameter(sample)

        # Time-dependent weight adjustment factor.
        time_weight = get_time_dependent_weights(t[0], self.num_timesteps)

        # Loop-invariant guidance settings.
        gradient_scale = gradient_control_signal.get("gradient_scale", 1.0)  # global gradient scale
        max_grad_norm = gradient_control_signal.get("max_grad_norm", 1.0)
        auc_sum = gradient_control_signal.get("auc")
        peak_points = gradient_control_signal.get("peak_points")
        bar_regions = gradient_control_signal.get("bar_regions")
        target_freq = gradient_control_signal.get("target_freq")

        with torch.enable_grad():
            for iteration in range(K):
                # x_i+1 = x_i + noise * grad(logp(x_i)) + sqrt(2*noise) * z_i
                optimizer = torch.optim.Adagrad([input_embs_param], lr=learning_rate)
                optimizer.zero_grad()

                if self.is_unconditional:
                    diff_input = input_embs_param.unsqueeze(1)
                else:
                    cond_obs = (partial_mask * tgt_embs).unsqueeze(1)
                    noisy_target = ((1 - partial_mask) * input_embs_param).unsqueeze(1)
                    diff_input = torch.cat([cond_obs, noisy_target], dim=1)
                x_start, _ = self.diffmodel(diff_input, side_info, t)

                # Squared error on the observed entries; the mask is either a
                # float weight or a boolean index.
                if enable_float_mask:
                    infill_loss = (x_start * (partial_mask) - tgt_embs * (partial_mask)) ** 2
                else:
                    infill_loss = (x_start[partial_mask] - tgt_embs[partial_mask]) ** 2

                if sigma.mean() == 0:
                    logp_term = (
                        coef * ((mean - input_embs_param) ** 2 / 1.0).mean(dim=0).sum()
                    )
                    infill_loss = infill_loss.mean(dim=0).sum()
                else:
                    logp_term = (
                        coef * ((mean - input_embs_param) ** 2 / sigma).mean(dim=0).sum()
                    )
                    infill_loss = (infill_loss / sigma.mean()).mean(dim=0).sum()

                control_loss = 0

                # 1. Sum (area under curve) guidance.
                if auc_sum is not None:
                    sum_weight = gradient_control_signal.get("auc_weight", 1.0) * time_weight
                    auc_loss = - sum_weight * sum_guidance(
                        x=input_embs_param,
                        t=t,
                        target_sum=auc_sum,
                        gradient_scale=gradient_scale,
                        segments=gradient_control_signal.get("segments", ()),
                    )
                    control_loss += auc_loss

                # Peak guidance.
                if peak_points is not None:
                    peak_weight = gradient_control_signal.get("peak_weight", 1.0) * time_weight
                    peak_loss = - peak_weight * peak_guidance(
                        x=input_embs_param,
                        t=t,
                        peak_points=peak_points,
                        window_size=gradient_control_signal.get("peak_window_size", 5),
                        alpha_1=gradient_control_signal.get("peak_alpha_1", 1.2),
                        gradient_scale=gradient_scale,
                    )
                    control_loss += peak_loss

                # Region (bar) guidance.
                if bar_regions is not None:
                    bar_weight = gradient_control_signal.get("bar_weight", 1.0) * time_weight
                    bar_loss = -bar_weight * bar_guidance(
                        x=input_embs_param,
                        t=t,
                        bar_regions=bar_regions,
                        gradient_scale=gradient_scale,
                    )
                    control_loss += bar_loss

                # Frequency guidance.
                if target_freq is not None:
                    freq_weight = gradient_control_signal.get("freq_weight", 1.0) * time_weight
                    freq_loss = -freq_weight * frequency_guidance(
                        x=input_embs_param,
                        t=t,
                        target_freq=target_freq,
                        freq_weight=freq_weight,
                        gradient_scale=gradient_scale,
                    )
                    control_loss += freq_loss

                loss = logp_term + infill_loss + control_loss
                loss.backward()
                # Clip before the optimizer consumes the gradients; clipping
                # after ``step()`` would have no effect on the update.
                torch.nn.utils.clip_grad_norm_([input_embs_param], max_grad_norm)
                optimizer.step()

                epsilon = torch.randn_like(input_embs_param.data)
                noise_scale = coef_ * sigma.mean().item()
                input_embs_param = torch.nn.Parameter(
                    (input_embs_param.data + noise_scale * epsilon).detach()
                )

        if enable_float_mask:
            sample = sample * partial_mask + input_embs_param.data * (1 - partial_mask)
        else:
            sample[~partial_mask] = input_embs_param.data[~partial_mask]
        return sample

    def predict_weighted_points(
        self,
        observed_points: torch.Tensor,
        observed_mask: torch.Tensor,
        coef=1e-1,
        stepsize=1e-1,
        sampling_steps=50,
        **kargs,
    ):
        """Infill a single 2-D series given (possibly fractional) point weights.

        The input is mapped with ``x * 2 - 1`` before sampling and the result
        is mapped back with ``(sample + 1) / 2``.

        Args:
            observed_points: 2-D tensor of observations (batch size 1 is
                added internally).
            observed_mask: Tensor of the same shape; values greater than zero
                mark observed points.
            coef, stepsize: Forwarded in ``model_kwargs`` as ``coef`` and
                ``learning_rate``.
            sampling_steps: Number of sampling steps; must differ from
                ``num_timesteps`` because full-length sampling is not
                implemented.
            **kargs: Extra entries merged into ``model_kwargs``.

        Returns:
            Numpy array with the shape of ``observed_points``.

        Raises:
            NotImplementedError: If ``sampling_steps == self.num_timesteps``.
        """
        model_kwargs = {"coef": coef, "learning_rate": stepsize, **kargs}

        assert len(observed_points.shape) == 2, "observed_points should be 2D, batch size = 1"
        x = observed_points.unsqueeze(0)
        float_mask = observed_mask.unsqueeze(0)  # 1 for observed, 0 for missing
        binary_mask = float_mask.clone()
        binary_mask[binary_mask > 0] = 1

        x = x * 2 - 1  # normalize
        self.device = x.device
        x, float_mask, binary_mask = (
            x.to(self.device),
            float_mask.to(self.device),
            binary_mask.to(self.device),
        )

        if sampling_steps == self.num_timesteps:
            print("normal sampling")
            raise NotImplementedError(
                "sampling with sampling_steps == num_timesteps is not implemented"
            )

        print("fast sampling")
        sample = self.fast_sample_infill_float_mask(
            shape=x.shape,
            target=x * binary_mask,  # x * t_m, 1 for observed, 0 for missing
            partial_mask=float_mask,
            model_kwargs=model_kwargs,
            sampling_timesteps=sampling_steps,
        )

        # unnormalize
        sample = (sample + 1) / 2
        return sample.squeeze(0).detach().cpu().numpy()

    def forward(self, x, **kwargs):
        """Compute the CSDI training (or validation) loss for a batch.

        Args:
            x: Input batch; it is permuted with ``(0, 2, 1)`` so that the
                result is laid out as ``(B, K, L)`` with ``K == feature_size``.
            **kwargs: Optional ``observed_mask`` (defaults to all ones),
                ``is_train`` (defaults to 1), ``gt_mask`` and ``pred_length``
                (the last ``pred_length`` time points are hidden during
                validation).

        Returns:
            Scalar loss tensor.
        """
        # Swap the last two axes so the data is (B, K, L) as expected by
        # get_side_info / calc_loss.
        observed_data = x.permute(0, 2, 1)
        observed_mask = kwargs.get("observed_mask", torch.ones_like(observed_data))
        # One time point per entry of the last (time) axis.
        observed_tp = torch.arange(observed_data.shape[2], device=x.device).float()
        observed_tp = observed_tp.unsqueeze(0).expand(x.shape[0], -1)

        # Generate masks
        is_train = kwargs.get("is_train", 1)
        if is_train:
            cond_mask = self.get_randmask(observed_mask)
        else:
            gt_mask = kwargs.get("gt_mask", observed_mask)
            # Work on a copy so the caller's mask is never modified in place.
            gt_mask = gt_mask.clone()
            if "pred_length" in kwargs:
                gt_mask[:, :, -kwargs["pred_length"]:] = 0
            cond_mask = gt_mask

        # Get side info and calculate loss
        side_info = self.get_side_info(observed_tp, cond_mask)
        loss_func = self.calc_loss if is_train else self.calc_loss_valid
        return loss_func(observed_data, cond_mask, observed_mask, side_info, is_train)

    def time_embedding(self, pos, d_model=128):
        """Return sinusoidal embeddings of shape ``pos.shape + (d_model,)``."""
        pe = torch.zeros(pos.shape[0], pos.shape[1], d_model).to(pos.device)
        position = pos.unsqueeze(2)
        div_term = 1 / torch.pow(
            10000.0, torch.arange(0, d_model, 2).to(pos.device) / d_model
        )
        pe[:, :, 0::2] = torch.sin(position * div_term)
        pe[:, :, 1::2] = torch.cos(position * div_term)
        return pe

    def get_randmask(self, observed_mask):
        """Randomly hide a fraction of the observed entries of every sample.

        For each sample a missing ratio is drawn uniformly from ``[0, 1)``
        and that share of its observed entries is removed from the mask.

        Returns:
            Float mask with the shape of ``observed_mask``.
        """
        rand_for_mask = torch.rand_like(observed_mask) * observed_mask
        rand_for_mask = rand_for_mask.reshape(len(rand_for_mask), -1)
        for i in range(len(observed_mask)):
            sample_ratio = np.random.rand()  # missing ratio
            num_observed = observed_mask[i].sum().item()
            num_masked = round(num_observed * sample_ratio)
            rand_for_mask[i][rand_for_mask[i].topk(num_masked).indices] = -1
        cond_mask = (rand_for_mask > 0).reshape(observed_mask.shape).float()
        return cond_mask

    def get_hist_mask(self, observed_mask, for_pattern_mask=None):
        """Build a conditioning mask from the pattern of a neighbouring sample.

        With the ``"mix"`` strategy each sample uses, with probability one
        half, a random mask instead.
        """
        if for_pattern_mask is None:
            for_pattern_mask = observed_mask
        if self.target_strategy == "mix":
            rand_mask = self.get_randmask(observed_mask)

        cond_mask = observed_mask.clone()
        for i in range(len(cond_mask)):
            mask_choice = np.random.rand()
            if self.target_strategy == "mix" and mask_choice > 0.5:
                cond_mask[i] = rand_mask[i]
            else:
                # Draw another sample for the history mask (i - 1 corresponds
                # to another sample).
                cond_mask[i] = cond_mask[i] * for_pattern_mask[i - 1]
        return cond_mask

    def get_test_pattern_mask(self, observed_mask, test_pattern_mask):
        """Return the element-wise product of the two masks."""
        return observed_mask * test_pattern_mask

    def get_side_info(self, observed_tp, cond_mask):
        """Assemble the side information fed to the denoiser.

        Args:
            observed_tp: Time points of shape ``(B, L)``.
            cond_mask: Conditioning mask of shape ``(B, K, L)``; ``K`` must
                equal ``self.target_dim``.

        Returns:
            Tensor of shape ``(B, emb_total_dim, K, L)``: time embedding,
            feature embedding and, unless unconditional, the mask itself.
        """
        B, K, L = cond_mask.shape

        time_embed = self.time_embedding(observed_tp, self.emb_time_dim)  # (B,L,emb)
        time_embed = time_embed.unsqueeze(2).expand(-1, -1, K, -1)
        feature_embed = self.embed_layer(
            torch.arange(self.target_dim).to(observed_tp.device)
        )  # (K,emb)
        feature_embed = feature_embed.unsqueeze(0).unsqueeze(0).expand(B, L, -1, -1)

        side_info = torch.cat([time_embed, feature_embed], dim=-1)  # (B,L,K,*)
        side_info = side_info.permute(0, 3, 2, 1)  # (B,*,K,L)

        if not self.is_unconditional:
            side_mask = cond_mask.unsqueeze(1)  # (B,1,K,L)
            side_info = torch.cat([side_info, side_mask], dim=1)

        return side_info

    def calc_loss_valid(
        self, observed_data, cond_mask, observed_mask, side_info, is_train
    ):
        """Average the detached loss over every step of the CSDI schedule."""
        loss_sum = 0
        for t in range(self.num_steps):  # calculate loss for all t
            loss = self.calc_loss(
                observed_data, cond_mask, observed_mask, side_info, is_train, set_t=t
            )
            loss_sum += loss.detach()
        return loss_sum / self.num_steps

    def set_input_to_diffmodel(self, noisy_data, observed_data, cond_mask):
        """Build the denoiser input from noisy and observed data.

        ``calc_loss`` calls this method, but no definition was present in
        this class. The conditional branch mirrors the construction used in
        ``impute`` and ``langevin_fn``.
        NOTE(review): the unconditional branch (a single channel holding the
        noisy data) is an assumption; it is never reached while
        ``is_unconditional`` is fixed to ``False``.

        Returns:
            Tensor of shape ``(B, 1, K, L)`` when unconditional, otherwise
            ``(B, 2, K, L)``.
        """
        if self.is_unconditional:
            return noisy_data.unsqueeze(1)
        cond_obs = (cond_mask * observed_data).unsqueeze(1)
        noisy_target = ((1 - cond_mask) * noisy_data).unsqueeze(1)
        return torch.cat([cond_obs, noisy_target], dim=1)

    def calc_loss(
        self, observed_data, cond_mask, observed_mask, side_info, is_train, set_t=-1
    ):
        """Noise-prediction loss on the entries hidden from the conditioning.

        Args:
            observed_data: Data of shape ``(B, K, L)``.
            cond_mask: Mask of entries given to the model as conditioning.
            observed_mask: Mask of entries that carry ground truth.
            side_info: Output of ``get_side_info``.
            is_train: ``1`` draws a random step per sample; any other value
                uses ``set_t`` for the whole batch.
            set_t: Diffusion step used when not training.

        Returns:
            Sum of squared errors on ``observed_mask - cond_mask`` divided by
            the number of such entries (or by 1 when there are none).
        """
        B, K, L = observed_data.shape
        # Follow the data rather than a device fixed at construction time.
        device = observed_data.device
        if is_train != 1:  # for validation
            t = (torch.ones(B) * set_t).long().to(device)
        else:
            t = torch.randint(0, self.num_steps, [B]).to(device)
        current_alpha = self.alpha_torch[t]  # (B,1,1)
        noise = torch.randn_like(observed_data)
        noisy_data = (current_alpha ** 0.5) * observed_data + (1.0 - current_alpha) ** 0.5 * noise

        total_input = self.set_input_to_diffmodel(noisy_data, observed_data, cond_mask)

        predicted, _ = self.diffmodel(total_input, side_info, t)  # (B,K,L)

        target_mask = observed_mask - cond_mask
        residual = (noise - predicted) * target_mask
        num_eval = target_mask.sum()
        loss = (residual ** 2).sum() / (num_eval if num_eval > 0 else 1)
        return loss

    def evaluate(self, batch, n_samples):
        """Impute ``n_samples`` samples for ``batch`` and return evaluation data.

        NOTE(review): relies on ``self.process_data``, which is not defined
        in the visible part of this file.

        Returns:
            Tuple ``(samples, observed_data, target_mask, observed_mask,
            observed_tp)``.
        """
        (
            observed_data,
            observed_mask,  # 1 for observed, 0 for missing
            observed_tp,
            gt_mask,
            _,
            cut_length,
        ) = self.process_data(batch)

        with torch.no_grad():
            cond_mask = gt_mask
            target_mask = observed_mask - cond_mask  # 1 for entries to impute

            side_info = self.get_side_info(observed_tp, cond_mask)

            samples = self.impute(observed_data, cond_mask, side_info, n_samples)

            for i in range(len(cut_length)):  # to avoid double evaluation
                target_mask[i, ..., 0 : cut_length[i].item()] = 0
        return samples, observed_data, target_mask, observed_mask, observed_tp

    def impute(self, observed_data, cond_mask, side_info, n_samples):
        """Draw ``n_samples`` imputations with the CSDI reverse process.

        Runs all ``num_steps`` reverse steps for every sample. The Langevin /
        control-signal guidance of ``langevin_fn`` is not wired into this
        loop, and the observed entries are not copied back into the result.

        Args:
            observed_data: Data of shape ``(B, K, L)``.
            cond_mask: Mask of entries used as conditioning.
            side_info: Output of ``get_side_info``.
            n_samples: Number of samples to draw.

        Returns:
            Tensor of shape ``(B, n_samples, K, L)``.
        """
        B, K, L = observed_data.shape
        device = observed_data.device

        imputed_samples = torch.zeros(B, n_samples, K, L).to(device)

        for i in range(n_samples):
            # Initialize with noise
            current_sample = torch.randn_like(observed_data)

            for t in range(self.num_steps - 1, -1, -1):
                time_cond = torch.tensor([t]).to(device)

                # Prepare model input
                if self.is_unconditional:
                    diff_input = cond_mask * observed_data + (1.0 - cond_mask) * current_sample
                    diff_input = diff_input.unsqueeze(1)
                else:
                    cond_obs = (cond_mask * observed_data).unsqueeze(1)
                    noisy_target = ((1 - cond_mask) * current_sample).unsqueeze(1)
                    diff_input = torch.cat([cond_obs, noisy_target], dim=1)

                predicted, _ = self.diffmodel(diff_input, side_info, time_cond)

                coeff1 = 1 / self.alpha_hat[t] ** 0.5
                coeff2 = (1 - self.alpha_hat[t]) / (1 - self.alpha[t]) ** 0.5
                current_sample = coeff1 * (current_sample - coeff2 * predicted)

                # No noise is added on the final step.
                if t > 0:
                    noise = torch.randn_like(current_sample)
                    sigma = (
                        (1.0 - self.alpha[t - 1]) / (1.0 - self.alpha[t]) * self.beta[t]
                    ) ** 0.5
                    current_sample += sigma * noise

            imputed_samples[:, i] = current_sample

        return imputed_samples

    def fast_sample_infill_float_mask(
        self,
        shape,
        target: torch.Tensor,
        sampling_timesteps,
        partial_mask: torch.Tensor = None,
        clip_denoised=True,
        model_kwargs=None,
    ):
        """Infill ``target`` by delegating to ``impute``.

        Args:
            shape: ``(batch, seq_length, feature)`` shape of ``target``.
            target: Observed values in that layout; it is permuted to
                ``(B, K, L)`` for ``impute``.
            partial_mask: Mask in the same layout as ``target``. Required
                despite the ``None`` default.
            sampling_timesteps, clip_denoised, model_kwargs: Accepted for
                interface compatibility but currently ignored; ``impute``
                always runs the full ``num_steps`` reverse chain without
                guidance.

        Returns:
            Tensor in the layout of ``target``.
        """
        batch = shape[0]
        device = target.device

        target = target.permute(0, 2, 1)
        partial_mask = partial_mask.permute(0, 2, 1)

        # Sampling needs no autograd graph.
        with torch.no_grad():
            # Generate timepoints
            observed_tp = torch.arange(shape[1], device=device).float()
            observed_tp = observed_tp.unsqueeze(0).expand(batch, -1)

            # Get side info
            side_info = self.get_side_info(observed_tp, partial_mask)

            samples = self.impute(
                observed_data=target,
                cond_mask=partial_mask,
                side_info=side_info,
                n_samples=1,
            )

        return samples.squeeze(1).permute(0, 2, 1)