OmniSVG-3B / deepsvg /visualize.py
OmniSVG's picture
Upload 80 files
c1ce505 verified
import matplotlib.pyplot as plt
import matplotlib
import matplotlib.patches as patches
from deepsvg.svglib.svg import SVG
from deepsvg.svglib.geom import Bbox
from deepsvg.svgtensor_dataset import SVGTensorDataset, load_dataset
from configs.deepsvg.hierarchical_ordered import Config
from deepsvg.difflib.tensor import SVGTensor
from deepsvg import utils
# utils.set_seed(2022)
cfg = Config()
train, valid = load_dataset(cfg)
def draw_svg(idx):
data = train.get(idx, [*cfg.model_args, 'tensor_grouped'])
svg_tensor = data['tensor_grouped'][0].copy().drop_sos().unpad()
svg = SVG.from_tensor(svg_tensor.data, viewbox=Bbox(256)).normalize()
svg.draw(do_display=False, file_path='test.svg')
draw_svg(10)
def draw_png(idx):
data = train.get(idx, [*cfg.model_args])
print(data['args'].shape, data['commands'].shape)
commands = data['commands']
args = data['args']
fig, ax = plt.subplots()
ax.set_aspect('equal')
ax.set_axis_off()
ax.set_xlim(0, 255)
ax.set_ylim(0, 255)
for path_commands, path_args in zip(commands, args):
verts = []
codes = []
for command, args in zip(path_commands, path_args):
cmd = SVGTensor.COMMANDS_SIMPLIFIED[int(command)]
# MOVETO
if cmd == 'm':
x = args[-2].item()
y = args[-1].item()
verts.append((x, 256 - y))
codes.append(matplotlib.path.Path.MOVETO)
# LINETO
elif cmd == 'l':
x = args[-2].item()
y = args[-1].item()
verts.append((x, 256 - y))
codes.append(matplotlib.path.Path.LINETO)
# CUBIC BEZIER
elif cmd == 'c':
qx1 = args[-6].item()
qy1 = args[-5].item()
qx2 = args[-4].item()
qy2 = args[-3].item()
x = args[-2].item()
y = args[-1].item()
verts.append((qx1, 256 - qy1))
codes.append(matplotlib.path.Path.CURVE4)
verts.append((qx2, 256 - qy2))
codes.append(matplotlib.path.Path.CURVE4)
verts.append((x, 256 - y))
codes.append(matplotlib.path.Path.CURVE4)
# CLOSE POLY
elif cmd == 'z':
# z doesn't appear to be in the data
verts.append((0, 0))
codes.append(matplotlib.path.Path.CLOSEPOLY)
if len(verts) > 0:
path = matplotlib.path.Path(verts, codes)
patch = patches.PathPatch(path, facecolor='none', lw=2)
ax.add_patch(patch)
# plt.show()
plt.savefig('test.png')