OmniSVG-3B / deepsvg /gui /interpolate.py
OmniSVG's picture
Upload 80 files
c1ce505 verified
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
from configs.deepsvg.hierarchical_ordered import Config
from deepsvg import utils
from deepsvg.svglib.svg import SVG
from deepsvg.difflib.tensor import SVGTensor
from deepsvg.svglib.geom import Bbox
from deepsvg.svgtensor_dataset import load_dataset, SVGFinetuneDataset
from deepsvg.utils.utils import batchify
from .state.project import DeepSVGProject, Frame
from .utils import easein_easeout
# Run on the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0"if torch.cuda.is_available() else "cpu")
# Checkpoint path that finetune_model() passes to utils.load_model() before each finetuning run.
pretrained_path = "./pretrained/hierarchical_ordered.pth.tar"
cfg = Config()
cfg.model_cfg.dropout = 0. # for faster convergence
# Shared module-level model, moved to `device` and put in eval mode at import time.
# NOTE(review): no checkpoint is loaded in this setup section; in the code shown here,
# weights are only loaded from `pretrained_path` inside finetune_model().
model = cfg.make_model().to(device)
model.eval()
# Dataset object used by encode_svg() to preprocess an SVG and by finetune_model()
# as the base dataset for SVGFinetuneDataset.
dataset = load_dataset(cfg)
def decode(z):
    """Greedily decode the latent `z` with the module-level model and return it as an SVG.

    Only the first element of the sampled batch is converted; the result is
    built with a `Bbox(256)` viewbox.
    """
    sampled_commands, sampled_args, _ = model.greedy_sample(z=z)
    # Move the first sample to the CPU before building the SVG tensor.
    first_commands = sampled_commands[0].cpu()
    first_args = sampled_args[0].cpu()
    predicted = SVGTensor.from_cmd_args(first_commands, first_args)
    return SVG.from_tensor(predicted.data, viewbox=Bbox(256))
def encode_svg(svg):
    """Encode `svg` into a latent with the module-level model.

    The SVG is preprocessed through `dataset.get`, requesting the configured
    model arguments plus "tensor_grouped"; only the configured model arguments
    are batched and fed to the model, which is called with `encode_mode=True`.
    """
    requested_args = list(cfg.model_args) + ["tensor_grouped"]
    sample = dataset.get(model_args=requested_args, svg=svg)
    batch = batchify((sample[name] for name in cfg.model_args), device)
    return model(*batch, encode_mode=True)
def interpolate_svg(svg1, svg2, n=10, ease=True):
    """Return `n` SVGs interpolated in latent space between `svg1` and `svg2`.

    Both SVGs are encoded with `encode_svg`, their latents are blended
    linearly, and each blended latent is decoded with `decode`.

    Args:
        svg1: SVG at the start of the interpolation (not included in the result).
        svg2: SVG at the end of the interpolation (not included in the result).
        n: Number of in-between SVGs to produce.
        ease: When true, the blend weights are passed through `easein_easeout`
            instead of being evenly spaced.

    Returns:
        A list of `n` decoded SVGs, ordered from `svg1` towards `svg2`.
    """
    # Pure inference: nothing returned here carries tensors, so disable autograd
    # to avoid building an unused graph for every encode/decode forward pass.
    with torch.no_grad():
        z1, z2 = encode_svg(svg1), encode_svg(svg2)

        # n + 2 evenly spaced weights in [0, 1]; drop both endpoints (the keyframes).
        alphas = torch.linspace(0., 1., n + 2)[1:-1]
        if ease:
            alphas = easein_easeout(alphas)

        z_list = [(1 - a) * z1 + a * z2 for a in alphas]
        return [decode(z) for z in z_list]
def finetune_model(project: DeepSVGProject, nb_augmentations=3500):
    """Finetune the module-level model on the keyframe SVGs of `project`.

    Does nothing when the project has fewer than two keyframes. Otherwise the
    checkpoint at `pretrained_path` is loaded into the model via
    `utils.load_model`, and a single pass is made over an `SVGFinetuneDataset`
    built from the keyframe SVGs.

    Args:
        project: Project whose frames flagged `keyframe` supply the SVGs.
        nb_augmentations: Forwarded to `SVGFinetuneDataset`.

    Side effects:
        Mutates the shared module-level `model` in place. The model's
        train/eval mode is restored to what it was on entry, even if training
        raises, so callers doing inference afterwards are not left with a
        model in training mode.
    """
    svgs = [frame.svg for frame in project.frames if frame.keyframe]
    if len(svgs) < 2:
        return

    utils.load_model(pretrained_path, model)
    print("Finetuning...")

    finetune_dataset = SVGFinetuneDataset(dataset, svgs, frac=1.0, nb_augmentations=nb_augmentations)
    dataloader = DataLoader(finetune_dataset, batch_size=cfg.batch_size, shuffle=True, drop_last=False,
                            num_workers=cfg.loader_num_workers, collate_fn=cfg.collate_fn)

    # Optimizer, lr & warmup schedulers
    optimizers = cfg.make_optimizers(model)
    scheduler_lrs = cfg.make_schedulers(optimizers, epoch_size=len(dataloader))
    scheduler_warmups = cfg.make_warmup_schedulers(optimizers, scheduler_lrs)

    loss_fns = [l.to(device) for l in cfg.make_losses()]

    # Finetuning is a single pass over the dataloader, so the epoch stays at 0.
    epoch = 0

    was_training = model.training
    model.train()
    try:
        for step, data in enumerate(dataloader):
            model_args = [data[arg].to(device) for arg in cfg.model_args]
            labels = data["label"].to(device) if "label" in data else None
            params_dict, weights_dict = cfg.get_params(step, epoch), cfg.get_weights(step, epoch)

            # cfg.optimizer_starts stays in the zip so its length still bounds the
            # iteration exactly as before, although its values are not used here.
            for loss_fn, optimizer, scheduler_lr, scheduler_warmup, _optimizer_start in zip(
                    loss_fns, optimizers, scheduler_lrs, scheduler_warmups, cfg.optimizer_starts):
                optimizer.zero_grad()

                output = model(*model_args, params=params_dict)
                loss_dict = loss_fn(output, labels, weights=weights_dict)

                loss_dict["loss"].backward()
                if cfg.grad_clip is not None:
                    nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)

                optimizer.step()
                if scheduler_lr is not None:
                    scheduler_lr.step()
                if scheduler_warmup is not None:
                    scheduler_warmup.step()

                # Progress is reported once per optimizer every 20 steps.
                if step % 20 == 0:
                    print(f"Step {step}: loss: {loss_dict['loss']}")
    finally:
        # Put the shared model back in the mode it was in before finetuning.
        model.train(was_training)

    print("Finetuning done.")
def compute_interpolation(project: DeepSVGProject):
    """Fill the frames between consecutive keyframes of `project` with interpolated SVGs.

    The model is first finetuned on the project's keyframes. Then, for every
    pair of neighbouring keyframes with at least one frame between them, the
    in-between frames are replaced by non-keyframe `Frame`s holding SVGs from
    `interpolate_svg` (without easing). Nothing is interpolated when the
    project has fewer than two keyframes.
    """
    finetune_model(project)

    keyframe_positions = []
    for position, frame in enumerate(project.frames):
        if frame.keyframe:
            keyframe_positions.append(position)
    if len(keyframe_positions) < 2:
        return

    model.eval()

    # Walk over neighbouring keyframe pairs (first, second), (second, third), ...
    for start, end in zip(keyframe_positions, keyframe_positions[1:]):
        gap = end - start - 1
        if gap != 0:
            start_svg = project.frames[start].svg
            end_svg = project.frames[end].svg
            inbetween = interpolate_svg(start_svg, end_svg, n=gap, ease=False)
            for offset, svg in enumerate(inbetween, 1):
                target = start + offset
                project.frames[target] = Frame(target, keyframe=False, svg=svg)