import os
import sys
import math
import time

import torch
import numpy as np
from pathlib import Path
from tqdm.auto import tqdm
from ema_pytorch import EMA
from torch.optim import Adam
from torch.nn.utils import clip_grad_norm_

# Make the project root importable before loading local utilities.
sys.path.append(os.path.join(os.path.dirname(__file__), "../"))

from utils.io_utils import instantiate_from_config, get_model_parameters_info


def cycle(dl):
    """Yield batches from the dataloader indefinitely."""
    while True:
        for data in dl:
            yield data


class Trainer(object):
    """Training, sampling and imputation loop for the diffusion model."""

    def __init__(self, config, args, model, dataloader, logger=None):
        super().__init__()
        if os.getenv("WANDB_ENABLED") == "true":
            import wandb

            self.run = wandb.init(project="tiffusion-revenue", config=config)
        else:
            self.run = None

        self.model = model
        self.device = self.model.betas.device
        self.train_num_steps = config["solver"]["max_epochs"]
        self.gradient_accumulate_every = config["solver"]["gradient_accumulate_every"]
        self.save_cycle = config["solver"]["save_cycle"]
        self.dl = cycle(dataloader["dataloader"])
        self.step = 0
        self.milestone = 0
        self.args = args
        self.logger = logger

        self.results_folder = Path(
            config["solver"]["results_folder"] + f"_{model.seq_length}"
        )
        os.makedirs(self.results_folder, exist_ok=True)

        start_lr = config["solver"].get("base_lr", 1.0e-4)
        ema_decay = config["solver"]["ema"]["decay"]
        ema_update_every = config["solver"]["ema"]["update_interval"]

        self.opt = Adam(
            filter(lambda p: p.requires_grad, self.model.parameters()),
            lr=start_lr,
            betas=(0.9, 0.96),
        )
        self.ema = EMA(self.model, beta=ema_decay, update_every=ema_update_every).to(
            self.device
        )

        sc_cfg = config["solver"]["scheduler"]
        sc_cfg["params"]["optimizer"] = self.opt
        self.sch = instantiate_from_config(sc_cfg)

        if self.logger is not None:
            self.logger.log_info(str(get_model_parameters_info(self.model)))
        self.log_frequency = 100

    def save(self, milestone, verbose=False):
        """Write model, EMA and optimizer state to `checkpoint-{milestone}.pt`."""
        if self.logger is not None and verbose:
            self.logger.log_info(
                "Save current model to {}".format(
                    str(self.results_folder / f"checkpoint-{milestone}.pt")
                )
            )
        data = {
            "step": self.step,
            "model": self.model.state_dict(),
            "ema": self.ema.state_dict(),
            "opt": self.opt.state_dict(),
        }
        torch.save(data, str(self.results_folder / f"checkpoint-{milestone}.pt"))

    def load(self, milestone, verbose=False, from_folder=None):
        """Restore trainer state from a checkpoint, optionally from another folder."""
        if self.logger is not None and verbose:
            self.logger.log_info(
                "Resume from {}".format(
                    os.path.join(from_folder, f"checkpoint-{milestone}.pt")
                )
            )
        device = self.device
        data = torch.load(
            os.path.join(from_folder, f"checkpoint-{milestone}.pt")
            if from_folder
            else str(self.results_folder / f"checkpoint-{milestone}.pt"),
            map_location=device,
            weights_only=True,
        )
        self.model.load_state_dict(data["model"])
        self.step = data["step"]
        self.opt.load_state_dict(data["opt"])
        self.ema.load_state_dict(data["ema"])
        self.milestone = milestone

    def train(self):
        device = self.device
        step = 0
        if self.logger is not None:
            tic = time.time()
            self.logger.log_info(
                "{}: start training...".format(self.args.name), check_primary=False
            )

        with tqdm(initial=step, total=self.train_num_steps) as pbar:
            while step < self.train_num_steps:
                total_loss = 0.0
                for _ in range(self.gradient_accumulate_every):
                    data = next(self.dl).to(device)
                    loss = self.model(data, target=data)
                    loss = loss / self.gradient_accumulate_every
                    loss.backward()
                    total_loss += loss.item()

                pbar.set_description(
                    f'loss: {total_loss:.6f} lr: {self.opt.param_groups[0]["lr"]:.6f}'
                )

                if self.run is not None:
                    # Log through the run handle so wandb does not need to be
                    # imported at module level.
                    self.run.log(
                        {
                            "step": step,
                            "loss": total_loss,
                            "lr": self.opt.param_groups[0]["lr"],
                        },
                        step=self.step,
                    )

                clip_grad_norm_(self.model.parameters(), 1.0)
                self.opt.step()
                self.sch.step(total_loss)
                self.opt.zero_grad()
                self.step += 1
                step += 1
                self.ema.update()

                with torch.no_grad():
                    if self.step != 0 and self.step % self.save_cycle == 0:
                        self.milestone += 1
                        self.save(self.milestone)

                    if self.logger is not None and self.step % self.log_frequency == 0:
                        self.logger.add_scalar(
                            tag="train/loss",
                            scalar_value=total_loss,
                            global_step=self.step,
                        )

                pbar.update(1)

        print("training complete")
        if self.logger is not None:
            self.logger.log_info(
                "Training done, time: {:.2f}".format(time.time() - tic)
            )

    def sample(self, num, size_every, shape=None):
        """Generate `num` unconditional series in batches of `size_every`."""
        if self.logger is not None:
            tic = time.time()
            self.logger.log_info("Begin to sample...")

        samples = np.empty([0, shape[0], shape[1]])
        num_cycle = int(num // size_every) + 1

        for _ in range(num_cycle):
            sample = self.ema.ema_model.generate_mts(batch_size=size_every)
            samples = np.vstack([samples, sample.detach().cpu().numpy()])
            torch.cuda.empty_cache()

        if self.logger is not None:
            self.logger.log_info(
                "Sampling done, time: {:.2f}".format(time.time() - tic)
            )
        return samples

    def control_sample(
        self,
        num,
        size_every,
        shape=None,
        model_kwargs={},
        target=None,
        partial_mask=None,
    ):
        """Generate series conditioned on `target` values at `partial_mask` positions."""
        samples = np.empty([0, shape[0], shape[1]])
        num_cycle = math.ceil(num / size_every)
        assert not (
            (target is None) ^ (partial_mask is None)
        ), "target and partial_mask must be provided together"

        if self.logger is not None:
            tic = time.time()
            self.logger.log_info("Begin to infill sample...")

        target = (
            torch.tensor(target).to(self.device)
            if target is not None
            else torch.zeros(shape).to(self.device)
        )
        target = target.repeat(size_every, 1, 1) if len(target.shape) == 2 else target
        partial_mask = (
            torch.tensor(partial_mask).to(self.device)
            if partial_mask is not None
            else torch.zeros(shape).to(self.device)
        )
        partial_mask = (
            partial_mask.repeat(size_every, 1, 1)
            if len(partial_mask.shape) == 2
            else partial_mask
        )

        for _ in range(num_cycle):
            sample = self.ema.ema_model.generate_mts_infill(
                target, partial_mask, model_kwargs=model_kwargs
            )
            samples = np.vstack([samples, sample.detach().cpu().numpy()])
            torch.cuda.empty_cache()

        if self.logger is not None:
            self.logger.log_info(
                "Sampling done, time: {:.2f}".format(time.time() - tic)
            )
        return samples

    def predict(
        self,
        observed_points: torch.Tensor,
        coef=1e-1,
        stepsize=1e-1,
        sampling_steps=50,
        **kargs,
    ):
        """Infill a single partially observed series (non-zero entries are observed)."""
        model_kwargs = {}
        model_kwargs["coef"] = coef
        model_kwargs["learning_rate"] = stepsize
        model_kwargs = {**model_kwargs, **kargs}

        assert (
            len(observed_points.shape) == 2
        ), "observed_points should be 2D, batch size = 1"
        x = observed_points.unsqueeze(0)
        t_m = x != 0  # 1 for observed, 0 for missing (bool tensor)
        x = x * 2 - 1  # normalize from [0, 1] to [-1, 1]
        x, t_m = x.to(self.device), t_m.to(self.device)

        if sampling_steps == self.model.num_timesteps:
            print("normal sampling")
            sample = self.ema.ema_model.sample_infill(
                shape=x.shape,
                target=x * t_m,
                partial_mask=t_m,
                model_kwargs=model_kwargs,
            )  # partially noised input: (batch_size, seq_length, feature_dim)
        else:
            print("fast sampling")
            sample = self.ema.ema_model.fast_sample_infill(
                shape=x.shape,
                target=x * t_m,
                partial_mask=t_m,
                model_kwargs=model_kwargs,
                sampling_timesteps=sampling_steps,
            )

        # unnormalize back to [0, 1]
        sample = (sample + 1) / 2
        return sample.squeeze(0).detach().cpu().numpy()

    def predict_weighted_points(
        self,
        observed_points: torch.Tensor,
        observed_mask: torch.Tensor,
        coef=1e-1,
        stepsize=1e-1,
        sampling_steps=50,
        **kargs,
    ):
        """Infill a single series using a float-valued (weighted) observation mask."""
        model_kwargs = {}
        model_kwargs["coef"] = coef
        model_kwargs["learning_rate"] = stepsize
        model_kwargs = {**model_kwargs, **kargs}

        assert (
            len(observed_points.shape) == 2
        ), "observed_points should be 2D, batch size = 1"
        x = observed_points.unsqueeze(0)
        float_mask = observed_mask.unsqueeze(0)  # per-point observation weights
        binary_mask = float_mask.clone()
        binary_mask[binary_mask > 0] = 1  # 1 for observed, 0 for missing
        x = x * 2 - 1  # normalize from [0, 1] to [-1, 1]
        x, float_mask, binary_mask = (
            x.to(self.device),
            float_mask.to(self.device),
            binary_mask.to(self.device),
        )

        if sampling_steps == self.model.num_timesteps:
            print("normal sampling")
            raise NotImplementedError
            sample = self.ema.ema_model.sample_infill_float_mask(
                shape=x.shape,
                target=x * binary_mask,
                partial_mask=float_mask,
                model_kwargs=model_kwargs,
            )  # partially noised input: (batch_size, seq_length, feature_dim)
        else:
            print("fast sampling")
            sample = self.ema.ema_model.fast_sample_infill_float_mask(
                shape=x.shape,
                target=x * binary_mask,
                partial_mask=float_mask,
                model_kwargs=model_kwargs,
                sampling_timesteps=sampling_steps,
            )

        # unnormalize back to [0, 1]
        sample = (sample + 1) / 2
        return sample.squeeze(0).detach().cpu().numpy()

    def restore(
        self,
        raw_dataloader,
        shape=None,
        coef=1e-1,
        stepsize=1e-1,
        sampling_steps=50,
        **kargs,
    ):
        """Impute the first batch of `raw_dataloader`; return samples, inputs and masks."""
        if self.logger is not None:
            tic = time.time()
            self.logger.log_info("Begin to restore...")
        model_kwargs = {}
        model_kwargs["coef"] = coef
        model_kwargs["learning_rate"] = stepsize
        model_kwargs = {**model_kwargs, **kargs}
        test = kargs.get("test", False)

        samples = np.empty([0, shape[0], shape[1]])  # (N, seq_length, feature_dim)
        reals = np.empty([0, shape[0], shape[1]])
        masks = np.empty([0, shape[0], shape[1]])

        for idx, (x, t_m) in enumerate(raw_dataloader):
            t_m = x == 0  # marks zero entries; note that predict() uses x != 0 instead
            if test:
                t_m = t_m.type_as(x)
                binary_mask = t_m.clone()
                binary_mask[binary_mask > 0] = 1
            else:
                binary_mask = t_m

            # x = x * 2 - 1  # unlike predict(), inputs are not rescaled here
            x, t_m = x.to(self.device), t_m.to(self.device)
            binary_mask = binary_mask.to(self.device)

            if sampling_steps == self.model.num_timesteps:
                print("normal sampling")
                sample = self.ema.ema_model.sample_infill(
                    shape=x.shape,
                    target=x * t_m,
                    partial_mask=t_m,
                    model_kwargs=model_kwargs,
                )  # partially noised input: (batch_size, seq_length, feature_dim)
            else:
                print("fast sampling")
                if test:
                    sample = self.ema.ema_model.fast_sample_infill_float_mask(
                        shape=x.shape,
                        target=x * binary_mask,
                        partial_mask=t_m,
                        model_kwargs=model_kwargs,
                        sampling_timesteps=sampling_steps,
                    )
                else:
                    sample = self.ema.ema_model.fast_sample_infill(
                        shape=x.shape,
                        target=x * t_m,
                        partial_mask=t_m,
                        model_kwargs=model_kwargs,
                        sampling_timesteps=sampling_steps,
                    )

            samples = np.vstack([samples, sample.detach().cpu().numpy()])
            reals = np.vstack([reals, x.detach().cpu().numpy()])
            masks = np.vstack([masks, t_m.detach().cpu().numpy()])
            break  # only the first batch is restored

        if self.logger is not None:
            self.logger.log_info(
                "Imputation done, time: {:.2f}".format(time.time() - tic)
            )
        return samples, reals, masks
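

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the training code). It lists the
# solver config keys that __init__ actually reads and the typical call order.
# The scheduler target, folder path, model, dataloader and logger below are
# placeholders / assumptions, not fixed names from this repository.
#
#   config = {
#       "solver": {
#           "max_epochs": 10000,                  # total optimizer steps
#           "gradient_accumulate_every": 2,
#           "save_cycle": 1000,                   # checkpoint every N steps
#           "results_folder": "./results",
#           "base_lr": 1.0e-4,
#           "ema": {"decay": 0.995, "update_interval": 10},
#           "scheduler": {                        # consumed by instantiate_from_config
#               "target": "torch.optim.lr_scheduler.ReduceLROnPlateau",  # assumed
#               "params": {"mode": "min", "factor": 0.5, "patience": 200},
#           },
#       },
#   }
#
#   trainer = Trainer(config, args, model, {"dataloader": train_loader}, logger=logger)
#   trainer.train()
#   trainer.save(milestone=1)
#   imputed = trainer.predict(observed_points)    # (seq_length, feature_dim) in [0, 1]
#
# ReduceLROnPlateau is chosen here only because train() calls
# self.sch.step(total_loss) with a metric value; any scheduler with that
# step signature would fit.
# ---------------------------------------------------------------------------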