# Spaces:
# Paused
# Paused
import torch | |
from torch.utils.data import DataLoader | |
import torch.nn as nn | |
from configs.deepsvg.hierarchical_ordered import Config | |
from deepsvg import utils | |
from deepsvg.svglib.svg import SVG | |
from deepsvg.difflib.tensor import SVGTensor | |
from deepsvg.svglib.geom import Bbox | |
from deepsvg.svgtensor_dataset import load_dataset, SVGFinetuneDataset | |
from deepsvg.utils.utils import batchify | |
from .state.project import DeepSVGProject, Frame | |
from .utils import easein_easeout | |
# Module-level state shared by every function in this file.

# Prefer the first CUDA device when one is available; otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Checkpoint that finetune_model() loads into the model before each fine-tuning run.
pretrained_path = "./pretrained/hierarchical_ordered.pth.tar"

cfg = Config()
# Must be set before make_model() so the model is built with this value.
cfg.model_cfg.dropout = 0.0  # for faster convergence

# The model is only built here and put in evaluation mode; weights are loaded
# from `pretrained_path` inside finetune_model().
model = cfg.make_model().to(device).eval()

dataset = load_dataset(cfg)
def decode(z):
    """Turn a latent code into an SVG using the module-level model.

    Args:
        z: Latent code forwarded to ``model.greedy_sample`` (presumably the
            tensor produced by ``encode_svg`` or a blend of such tensors).

    Returns:
        The ``SVG`` built from the first entry of the sampled commands and
        arguments, with a ``Bbox(256)`` viewbox.
    """
    sampled_cmds, sampled_args, _ = model.greedy_sample(z=z)

    # Only the first entry of each sampled sequence is used; it is moved to
    # the CPU before the SVG tensor is built.
    first_cmds = sampled_cmds[0].cpu()
    first_args = sampled_args[0].cpu()

    predicted = SVGTensor.from_cmd_args(first_cmds, first_args)
    return SVG.from_tensor(predicted.data, viewbox=Bbox(256))
def encode_svg(svg):
    """Encode an SVG into a latent code with the module-level model.

    Args:
        svg: The SVG handed to ``dataset.get`` for preprocessing.

    Returns:
        Whatever the model returns when called with ``encode_mode=True``
        (presumably the latent tensor ``z``).
    """
    # "tensor_grouped" is requested from the dataset in addition to the
    # model's own arguments, but only the model arguments are batched below.
    requested_keys = [*cfg.model_args, "tensor_grouped"]
    sample = dataset.get(model_args=requested_keys, svg=svg)

    batch = batchify((sample[name] for name in cfg.model_args), device)
    return model(*batch, encode_mode=True)
def interpolate_svg(svg1, svg2, n=10, ease=True):
    """Generate ``n`` in-between SVGs by blending two latent codes.

    Both SVGs are encoded with ``encode_svg``, their latent codes are linearly
    blended at ``n`` interior positions strictly between 0 and 1 (the two
    endpoints are excluded), and each blend is decoded with ``decode``.

    Args:
        svg1: SVG at the start of the interpolation.
        svg2: SVG at the end of the interpolation.
        n: Number of in-between SVGs to produce.
        ease: When true, the blend positions are remapped through
            ``easein_easeout`` before blending.

    Returns:
        A list of ``n`` decoded SVGs, ordered from ``svg1`` towards ``svg2``.
    """
    # This is pure inference and only SVG objects are returned, so autograd
    # tracking is disabled to avoid building and retaining computation graphs
    # for the encodings and for every decoded frame.
    with torch.no_grad():
        z1, z2 = encode_svg(svg1), encode_svg(svg2)

        # n + 2 evenly spaced points on [0, 1]; dropping the first and last
        # leaves the n interior positions.
        alphas = torch.linspace(0., 1., n + 2)[1:-1]
        if ease:
            alphas = easein_easeout(alphas)

        return [decode((1 - a) * z1 + a * z2) for a in alphas]
def finetune_model(project: DeepSVGProject, nb_augmentations=3500):
    """Fine-tune the module-level model on the project's keyframe SVGs.

    The model weights are first reloaded from ``pretrained_path``; a single
    pass over an ``SVGFinetuneDataset`` built from the keyframe SVGs is then
    used to update the model in place. Progress is printed every 20 steps.

    Does nothing when the project has fewer than two keyframes.

    Args:
        project: Project whose frames expose ``keyframe`` and ``svg``.
        nb_augmentations: Forwarded to ``SVGFinetuneDataset``.

    Returns:
        None. The module-level ``model`` is modified in place and is returned
        to the train/eval mode it was in before the call, even if training
        raises.
    """
    svgs = [frame.svg for frame in project.frames if frame.keyframe]
    if len(svgs) < 2:
        return

    # Every fine-tuning run restarts from the pretrained checkpoint.
    utils.load_model(pretrained_path, model)
    print("Finetuning...")

    finetune_dataset = SVGFinetuneDataset(dataset, svgs, frac=1.0, nb_augmentations=nb_augmentations)
    dataloader = DataLoader(finetune_dataset, batch_size=cfg.batch_size, shuffle=True, drop_last=False,
                            num_workers=cfg.loader_num_workers, collate_fn=cfg.collate_fn)

    # Optimizer, lr & warmup schedulers
    optimizers = cfg.make_optimizers(model)
    scheduler_lrs = cfg.make_schedulers(optimizers, epoch_size=len(dataloader))
    scheduler_warmups = cfg.make_warmup_schedulers(optimizers, scheduler_lrs)

    loss_fns = [loss_fn.to(device) for loss_fn in cfg.make_losses()]

    # Only one pass over the dataloader is made, so the epoch stays at 0.
    epoch = 0

    # Remember the current mode so the shared model is not left in training
    # mode for later inference calls.
    was_training = model.training
    model.train()
    try:
        for step, data in enumerate(dataloader):
            model_args = [data[arg].to(device) for arg in cfg.model_args]
            labels = data["label"].to(device) if "label" in data else None
            params_dict, weights_dict = cfg.get_params(step, epoch), cfg.get_weights(step, epoch)

            # cfg.optimizer_starts stays in the zip so the number of
            # (loss, optimizer) pairs processed is unchanged; its values are
            # not used here.
            for loss_fn, optimizer, scheduler_lr, scheduler_warmup, _optimizer_start in zip(
                    loss_fns, optimizers, scheduler_lrs, scheduler_warmups, cfg.optimizer_starts):
                optimizer.zero_grad()

                output = model(*model_args, params=params_dict)
                loss_dict = loss_fn(output, labels, weights=weights_dict)

                loss_dict["loss"].backward()
                if cfg.grad_clip is not None:
                    nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)

                optimizer.step()
                if scheduler_lr is not None:
                    scheduler_lr.step()
                if scheduler_warmup is not None:
                    scheduler_warmup.step()

            if step % 20 == 0:
                # Reports the loss of the last (loss, optimizer) pair of this step.
                print(f"Step {step}: loss: {loss_dict['loss']}")
    finally:
        model.train(was_training)

    print("Finetuning done.")
def compute_interpolation(project: DeepSVGProject):
    """Fill the frames between consecutive keyframes with interpolated SVGs.

    The model is fine-tuned on the project first. Then, for every pair of
    neighbouring keyframes with at least one frame between them, the
    in-between frames are replaced by non-keyframe ``Frame`` objects holding
    the SVGs returned by ``interpolate_svg`` (called with ``ease=False``).

    Does not modify any frame when the project has fewer than two keyframes.

    Args:
        project: Project whose ``frames`` list is updated in place.
    """
    finetune_model(project)

    keyframes = [idx for idx, frame in enumerate(project.frames) if frame.keyframe]
    if len(keyframes) < 2:
        return

    model.eval()

    # Walk over each pair of consecutive keyframes.
    for start, end in zip(keyframes, keyframes[1:]):
        gap = end - start - 1
        if gap != 0:
            inbetweens = interpolate_svg(project.frames[start].svg, project.frames[end].svg, n=gap, ease=False)
            # The first generated SVG belongs to the frame right after `start`.
            for frame_idx, svg in enumerate(inbetweens, start + 1):
                project.frames[frame_idx] = Frame(frame_idx, keyframe=False, svg=svg)