import torch
from torch.utils.data import DataLoader
import torch.nn as nn

from configs.deepsvg.hierarchical_ordered import Config
from deepsvg import utils
from deepsvg.svglib.svg import SVG
from deepsvg.difflib.tensor import SVGTensor
from deepsvg.svglib.geom import Bbox
from deepsvg.svgtensor_dataset import load_dataset, SVGFinetuneDataset
from deepsvg.utils.utils import batchify

from .state.project import DeepSVGProject, Frame
from .utils import easein_easeout

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

pretrained_path = "./pretrained/hierarchical_ordered.pth.tar"

cfg = Config()
cfg.model_cfg.dropout = 0.  # for faster convergence
model = cfg.make_model().to(device)
model.eval()

dataset = load_dataset(cfg)


def decode(z):
    """Greedy-decode a latent vector z back into an SVG."""
    commands_y, args_y, _ = model.greedy_sample(z=z)
    tensor_pred = SVGTensor.from_cmd_args(commands_y[0].cpu(), args_y[0].cpu())
    svg_path_sample = SVG.from_tensor(tensor_pred.data, viewbox=Bbox(256))
    return svg_path_sample


def encode_svg(svg):
    """Encode an SVG into the model's latent space."""
    data = dataset.get(model_args=[*cfg.model_args, "tensor_grouped"], svg=svg)
    model_args = batchify((data[key] for key in cfg.model_args), device)
    z = model(*model_args, encode_mode=True)
    return z


def interpolate_svg(svg1, svg2, n=10, ease=True):
    """Return n in-between SVGs by linear interpolation in latent space."""
    z1, z2 = encode_svg(svg1), encode_svg(svg2)

    # n interior interpolation weights in (0, 1), endpoints excluded
    alphas = torch.linspace(0., 1., n + 2)[1:-1]
    if ease:
        alphas = easein_easeout(alphas)

    z_list = [(1 - a) * z1 + a * z2 for a in alphas]

    svgs = [decode(z) for z in z_list]
    return svgs


def finetune_model(project: DeepSVGProject, nb_augmentations=3500):
    """Finetune the pretrained model on the project's keyframes, with augmentations."""
    keyframe_ids = [i for i, frame in enumerate(project.frames) if frame.keyframe]
    if len(keyframe_ids) < 2:
        return

    svgs = [project.frames[i].svg for i in keyframe_ids]

    # Always restart from the pretrained weights before finetuning
    utils.load_model(pretrained_path, model)

    print("Finetuning...")

    finetune_dataset = SVGFinetuneDataset(dataset, svgs, frac=1.0, nb_augmentations=nb_augmentations)
    dataloader = DataLoader(finetune_dataset, batch_size=cfg.batch_size, shuffle=True, drop_last=False,
                            num_workers=cfg.loader_num_workers, collate_fn=cfg.collate_fn)

    # Optimizer, lr & warmup schedulers
    optimizers = cfg.make_optimizers(model)
    scheduler_lrs = cfg.make_schedulers(optimizers, epoch_size=len(dataloader))
    scheduler_warmups = cfg.make_warmup_schedulers(optimizers, scheduler_lrs)

    loss_fns = [l.to(device) for l in cfg.make_losses()]

    epoch = 0  # single pass over the augmented dataset
    for step, data in enumerate(dataloader):
        model.train()
        model_args = [data[arg].to(device) for arg in cfg.model_args]
        labels = data["label"].to(device) if "label" in data else None
        params_dict, weights_dict = cfg.get_params(step, epoch), cfg.get_weights(step, epoch)

        for i, (loss_fn, optimizer, scheduler_lr, scheduler_warmup, optimizer_start) in enumerate(
                zip(loss_fns, optimizers, scheduler_lrs, scheduler_warmups, cfg.optimizer_starts), 1):
            optimizer.zero_grad()

            output = model(*model_args, params=params_dict)
            loss_dict = loss_fn(output, labels, weights=weights_dict)

            loss_dict["loss"].backward()
            if cfg.grad_clip is not None:
                nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)

            optimizer.step()
            if scheduler_lr is not None:
                scheduler_lr.step()
            if scheduler_warmup is not None:
                scheduler_warmup.step()

        if step % 20 == 0:
            print(f"Step {step}: loss: {loss_dict['loss']}")

    print("Finetuning done.")


def compute_interpolation(project: DeepSVGProject):
    """Fill the frames between consecutive keyframes with latent interpolations."""
    finetune_model(project)

    keyframe_ids = [i for i, frame in enumerate(project.frames) if frame.keyframe]
    if len(keyframe_ids) < 2:
        return

    model.eval()

    for i1, i2 in zip(keyframe_ids[:-1], keyframe_ids[1:]):
        frames_inbetween = i2 - i1 - 1
        if frames_inbetween == 0:
            continue

        svgs = interpolate_svg(project.frames[i1].svg, project.frames[i2].svg, n=frames_inbetween, ease=False)
        for di, svg in enumerate(svgs, 1):
            project.frames[i1 + di] = Frame(i1 + di, keyframe=False, svg=svg)
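

# ---------------------------------------------------------------------------
# Usage sketch (not part of the module): a minimal, hypothetical example of
# driving interpolate_svg directly from two SVG files. The file names are
# placeholders, and the preprocessing chain mirrors the one used in deepsvg's
# example notebooks; adjust both to your data.
#
#   svg1 = SVG.load_svg("icon_a.svg").normalize().zoom(0.9).canonicalize().simplify_heuristic()
#   svg2 = SVG.load_svg("icon_b.svg").normalize().zoom(0.9).canonicalize().simplify_heuristic()
#   for i, svg in enumerate(interpolate_svg(svg1, svg2, n=8)):
#       svg.save_svg(f"frame_{i:02d}.svg")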