import torch


def _channel0(x: torch.Tensor) -> torch.Tensor:
    """Return channel 0 of ``x`` as a ``[B, L]`` tensor.

    The layout is guessed from the shape: when ``x.shape[1] < x.shape[2]`` the
    input is treated as ``[B, C, L]``, otherwise as ``[B, L, C]``.
    NOTE(review): this heuristic picks the wrong axis whenever the sequence is
    shorter than the number of channels.
    """
    if x.shape[1] < x.shape[2]:
        return x[:, 0]
    return x[:, :, 0]


def _zero_loss(x: torch.Tensor) -> torch.Tensor:
    """Return a scalar zero that is still connected to ``x``'s autograd graph.

    Used when there are no constraints to evaluate, so callers always receive a
    tensor (with a zero gradient w.r.t. ``x``) instead of the Python int ``0``.
    """
    return x.sum() * 0.0


def sum_guidance(
    x: torch.Tensor,
    t: torch.Tensor,
    target_sum: torch.Tensor,
    sigma: float = 1.0,
    gradient_scale: float = 1.0,
    segments: tuple = None,
):
    """Gaussian log-likelihood that channel 0 of ``x`` sums to ``target_sum``.

    ``x`` is indexed as ``x[:, :, 0]``, i.e. it is assumed to be laid out as
    ``[B, L, C]`` (TODO confirm with callers). Values are mapped from
    ``[-1, 1]`` to ``[0, 1]`` before summing over the sequence axis.

    Args:
        x: Sample tensor, assumed ``[B, L, C]``.
        t: Diffusion timestep. Currently unused; kept for interface parity.
        target_sum: Desired sum, broadcastable against the per-sample sum ``[B]``.
        sigma: Standard deviation of the Gaussian. ``0`` falls back to ``1``.
        gradient_scale: Currently unused; kept for interface parity.
        segments: Optional iterable of ``(start_idx, end_idx)`` pairs. When
            given, only these half-open slices of the sequence are summed
            (overlapping slices are counted once per slice). When ``None`` or
            empty, the whole sequence is summed.

    Returns:
        Scalar tensor: the log-likelihood averaged over the batch.
    """
    signal = x[:, :, 0] / 2 + 0.5  # [-1, 1] -> [0, 1]

    if segments:
        # Sum each requested slice separately and add the slice sums. The
        # previous code reduced over the whole sequence first and then sliced
        # the resulting 1-D tensor with three indices, which always raised.
        current_sum = sum(signal[:, start:end].sum(dim=1) for start, end in segments)
    else:
        current_sum = signal.sum(dim=1)

    std = 1.0 if sigma == 0 else sigma
    pred_std = torch.ones_like(current_sum) * std
    log_prob = -0.5 * torch.log(2 * torch.pi * pred_std**2) - (
        (target_sum - current_sum) ** 2 / (2 * pred_std**2)
    )
    return log_prob.mean()


def peak_guidance(
    x: torch.Tensor,
    t: torch.Tensor,
    peak_points: list,
    window_size: int = 5,
    alpha_1: float = 1.2,
    sigma: float = 1.0,
    gradient_scale: float = 1.0,
):
    """Log-likelihood that channel 0 of ``x`` peaks at the given positions.

    For every index in ``peak_points``, the value there is compared with
    ``alpha_1`` times the mean of its neighbours inside a window of
    ``window_size`` samples centred on it (the peak itself excluded, window
    clipped at the sequence edges). Each squared difference contributes
    ``-diff**2 / (2 * sigma**2)``, a Gaussian log-likelihood term without the
    normalising constant.

    Args:
        x: Sample tensor; ``[B, C, L]`` or ``[B, L, C]`` layout is guessed from
            its shape. Values are mapped from ``[-1, 1]`` to ``[0, 1]``.
        t: Diffusion timestep. Currently unused; kept for interface parity.
        peak_points: Sequence indices (``0 <= idx < L``) that should be peaks.
        window_size: Width of the neighbourhood window around each peak.
        alpha_1: Target ratio between the peak value and its local mean.
        sigma: Standard deviation of the Gaussian. ``0`` falls back to ``1``.
        gradient_scale: Currently unused; kept for interface parity.

    Returns:
        Scalar tensor: the summed log-likelihood over all peaks, or a zero
        tensor (still attached to ``x``'s graph) when ``peak_points`` is empty.

    Raises:
        IndexError: if a peak index lies outside ``[0, L)``.
        ValueError: if a peak's window contains no neighbouring samples
            (e.g. ``window_size < 2``), which would otherwise yield NaN.
    """
    signal = _channel0(x) / 2 + 0.5
    seq_len = signal.shape[1]
    half_window = window_size // 2
    # sigma == 0 would divide by zero; fall back to 1 like sum_guidance does.
    std = 1.0 if sigma == 0 else sigma

    terms = []
    for x_coord in peak_points:
        if not 0 <= x_coord < seq_len:
            raise IndexError(f"peak index {x_coord} out of range for length {seq_len}")
        start_idx = max(0, x_coord - half_window)
        end_idx = min(seq_len, x_coord + half_window + 1)
        n_neighbours = end_idx - start_idx - 1
        if n_neighbours <= 0:
            raise ValueError(
                f"window around peak {x_coord} has no neighbours "
                f"(window_size={window_size}, length={seq_len})"
            )
        peak_value = signal[:, x_coord]
        # The local mean deliberately excludes the peak point itself.
        local_mean = (signal[:, start_idx:end_idx].sum(dim=1) - peak_value) / n_neighbours
        # NOTE(review): the difference is averaged over the batch *before*
        # squaring, so per-sample deviations of opposite sign cancel out.
        # Kept to preserve the existing guidance strength; confirm intent.
        local_diff = (local_mean * alpha_1 - peak_value).mean()
        terms.append(-(local_diff**2) / (2 * std**2))

    if not terms:
        return _zero_loss(x)
    return sum(terms)


def bar_guidance(
    x: torch.Tensor,
    t: torch.Tensor,
    bar_regions: list,
    sigma: float = 1.0,
    gradient_scale: float = 1.0,
):
    """Likelihood that channel 0 of ``x`` has the given mean level per region.

    Each ``(start_idx, end_idx, target_value)`` region contributes the
    batch-averaged Gaussian likelihood
    ``exp(-0.5 * (region_mean - target_value)**2 / sigma**2)``. Unlike
    ``sum_guidance`` and ``peak_guidance``, this is the likelihood itself, not
    its logarithm, so the result lies in ``[0, len(bar_regions)]``.

    Args:
        x: Sample tensor; ``[B, C, L]`` or ``[B, L, C]`` layout is guessed from
            its shape. Values are mapped from ``[-1, 1]`` to ``[0, 1]``.
        t: Diffusion timestep. Currently unused; kept for interface parity.
        bar_regions: Iterable of ``(start_idx, end_idx, target_value)`` with a
            half-open index range and the desired mean over that range.
        sigma: Standard deviation of the Gaussian; values ``<= 0`` fall back
            to ``1``.
        gradient_scale: Currently unused; kept for interface parity.

    Returns:
        Scalar tensor: the summed likelihood, or a zero tensor (still attached
        to ``x``'s graph) when ``bar_regions`` is empty.

    Raises:
        ValueError: if a region selects no samples (its mean would be NaN).
    """
    signal = _channel0(x) / 2 + 0.5
    std = sigma if sigma > 0 else 1.0

    terms = []
    for start_idx, end_idx, target_value in bar_regions:
        region = signal[:, start_idx:end_idx]
        if region.shape[1] == 0:
            raise ValueError(f"bar region [{start_idx}, {end_idx}) selects no samples")
        region_mean = region.mean(dim=1)
        terms.append(torch.exp(-0.5 * ((region_mean - target_value) ** 2) / (std**2)).mean())

    if not terms:
        return _zero_loss(x)
    return sum(terms)


def frequency_guidance(
    x: torch.Tensor,
    t: torch.Tensor,
    target_freq: float,
    freq_weight: float = 1.0,
    gradient_scale: float = 1.0,
):
    """Reward spectral energy of channel 0 of ``x`` near ``target_freq``.

    The real FFT magnitude is weighted by a Gaussian window centred on
    ``target_freq`` with standard deviation ``0.1 / gradient_scale``.
    Frequencies come from ``rfftfreq`` with unit spacing, so they are in
    cycles per sample over ``[0, 0.5]``. Unlike the other guidance functions,
    the signal is NOT rescaled from ``[-1, 1]`` to ``[0, 1]`` first.

    Args:
        x: Sample tensor; ``[B, C, L]`` or ``[B, L, C]`` layout is guessed from
            its shape.
        t: Diffusion timestep. Currently unused; kept for interface parity.
        target_freq: Centre of the frequency window, in cycles per sample.
        freq_weight: Multiplier applied to the mean weighted magnitude before
            exponentiation.
        gradient_scale: Larger values narrow the window. Must be non-zero.

    Returns:
        Scalar tensor ``exp(freq_weight * mean(|FFT| * window))``.
        NOTE(review): the exponential can overflow for large magnitudes.
    """
    signal = _channel0(x)
    spectrum = torch.fft.rfft(signal, dim=1)
    # Build the frequency grid directly on the input's device.
    freqs = torch.fft.rfftfreq(signal.shape[1], d=1.0, device=x.device)

    # A narrower Gaussian window (larger gradient_scale) emphasises target_freq.
    bandwidth = 0.1 / gradient_scale
    freq_window = torch.exp(-((freqs - target_freq) ** 2) / (2 * bandwidth**2))

    magnitude = torch.abs(spectrum) * freq_window[None, :]
    return torch.exp(freq_weight * magnitude.mean())


def get_time_dependent_weights(t, num_timesteps, decay: float = 5.0):
    """Return the per-timestep weight ``exp(-decay * t / num_timesteps)``.

    The weight is ``1`` at ``t == 0`` and decays exponentially as ``t`` grows,
    so smaller timestep values receive larger control-signal weights.

    Args:
        t: Tensor of timesteps (converted to float).
        num_timesteps: Total number of timesteps used for normalisation.
        decay: Exponential decay rate; the historical value is ``5``.

    Returns:
        Float tensor with the same shape as ``t``.
    """
    progress = t.float() / num_timesteps
    return torch.exp(-decay * progress)