import numpy as np
import torch
import yaml
from typing import List, Tuple, Dict, Optional, Union
from deepsvg.difflib.tensor import SVGTensor
from deepsvg.svglib.svg import SVG
from deepsvg.svglib.geom import Bbox


class SVGTokenizer:
    """SVG tokenizer for converting between tokens and SVG representations.

    All numeric layout information (token offsets, coordinate grid size,
    colour token ranges, command ids) is read from a YAML configuration file
    that must provide the top-level sections ``tokens``, ``coordinates``,
    ``colors`` and ``svg_commands``.
    """

    def __init__(self, config_path: str = "config.yaml"):
        """Load the YAML configuration and pre-compute the pixel -> xy table.

        Args:
            config_path: Path of the YAML configuration file.

        Raises:
            OSError: If the file cannot be opened.
            KeyError: If one of the required top-level sections is missing.
        """
        # Explicit encoding so the result does not depend on the platform's
        # default locale encoding.
        with open(config_path, 'r', encoding='utf-8') as f:
            self.config = yaml.safe_load(f)

        # Extract configuration values
        self.tokens_config = self.config['tokens']
        self.coordinates_config = self.config['coordinates']
        self.colors_config = self.config['colors']
        self.svg_commands = self.config['svg_commands']

        self.pixel2xy = self._create_pixel2xy_mapping()

    def _create_pixel2xy_mapping(self) -> Dict[int, np.ndarray]:
        """Create mapping from pixel indices to xy coordinates.

        The grid is ``bbox`` x ``bbox`` and is enumerated row by row, i.e.
        index ``p`` corresponds to ``(x, y) = (p % bbox, p // bbox)``. Every
        stored coordinate is shifted by ``coord_pad_offset + svg_end``.

        Returns:
            Dict mapping each flat pixel index to a 2-element integer array.
        """
        bbox = self.coordinates_config['bbox']
        coord_pad = self.coordinates_config['coord_pad_offset']
        svg_end = self.tokens_config['svg_end']

        axis = np.arange(bbox)
        xx, yy = np.meshgrid(axis, axis)
        xy_grid = np.column_stack((xx.ravel(), yy.ravel())).astype(int)

        return {
            pixel: xy + coord_pad + svg_end
            for pixel, xy in enumerate(xy_grid)
        }

    def token_to_color(self, color_token: int) -> str:
        """Translate a colour token into an SVG colour string.

        ``color_token_start`` maps to ``"none"`` and ``color_token_start + 1``
        to ``"currentColor"``. Tokens after that carry a 12-bit RGB index
        (4 bits per channel); each nibble is duplicated to form the 8-bit
        channel value (``0xA`` -> ``0xAA``).

        Args:
            color_token: Colour token id.

        Returns:
            ``"none"``, ``"currentColor"`` or a ``#rrggbb`` string. Gray
            (``#808080``) is returned for out-of-range tokens and whenever
            any error occurs (best-effort; the error is printed).
        """
        try:
            color_token_start = self.colors_config['color_token_start']
            max_color_tokens = self.colors_config['max_color_tokens']

            # Check special color tokens
            if color_token == color_token_start:
                return "none"  # No color
            elif color_token == color_token_start + 1:
                return "currentColor"  # Special color

            color_index = color_token - (color_token_start + 2)
            if color_index < 0 or color_index >= max_color_tokens:
                print(f"Warning: Color token {color_token} out of range, using default color")
                return "#808080"  # Gray as default

            r = (color_index >> 8) & 0xF
            g = (color_index >> 4) & 0xF
            b = color_index & 0xF

            # Expand each 4-bit channel to 8 bits by repeating the nibble.
            r = (r << 4) | r
            g = (g << 4) | g
            b = (b << 4) | b

            return f"#{r:02x}{g:02x}{b:02x}"
        except Exception as e:
            print(f"Error in token_to_color: {e}")
            return "#808080"

    def pixel_to_xy(self, pixel: int) -> np.ndarray:
        """Convert pixel token to xy coordinates.

        Tokens strictly between ``eom`` and ``pix_pad_offset + svg_end`` are
        returned as ``[pixel - base_offset, pixel - base_offset]``. Tokens
        from ``pix_pad_offset + svg_end`` up to (excluding)
        ``cmd_fill + base_offset + svg_end`` are looked up in the
        pre-computed grid table and shifted back by ``base_offset``.

        Args:
            pixel: Token id to decode.

        Returns:
            2-element integer array.

        Raises:
            ValueError: If the token lies outside both ranges or its grid
                index is not in the table.
        """
        base_offset = self.tokens_config['base_offset']
        pix_pad = self.coordinates_config['pix_pad_offset']
        svg_end = self.tokens_config['svg_end']

        if self.tokens_config['eom'] < pixel < pix_pad + svg_end:
            xy = np.array([pixel - base_offset, pixel - base_offset]).astype(int)
            return xy
        elif pix_pad + svg_end <= pixel < self.colors_config['cmd_fill'] + base_offset + svg_end:
            pixel_index = pixel - pix_pad - svg_end
            if pixel_index in self.pixel2xy:
                return self.pixel2xy[pixel_index] - base_offset
            else:
                raise ValueError(f"Invalid pixel index: {pixel_index}")
        else:
            raise ValueError(f"Invalid pixel token: {pixel}")

    @staticmethod
    def _blank_command(code: int) -> np.ndarray:
        """Return a zeroed 14-slot command row with slot 0 set to ``code``."""
        row = np.zeros(14)
        row[0] = code
        return row

    def raster_svg(self, pixels: np.ndarray) -> List[List[torch.Tensor]]:
        """Convert pixel sequence to SVG tensor representation.

        ``pixels`` is indexed as ``pixels[i][0]`` / ``pixels[i]`` and a whole
        row is written into two slots, so it is treated as a sequence of
        2-element rows. After subtracting a fixed adjustment, the first
        element of a row is compared with the ids in ``svg_commands``; the
        rows following a command are consumed as its operands.

        Layout of each emitted 14-slot row, as written by this method:
            slot 0      command code (move=0, line=1, curve=2, arc=3, close=6)
            slots 1-2   arc radius
            slot 3      arc x-axis rotation
            slot 4      arc large-arc flag
            slot 5      arc sweep flag
            slots 8-9   curve control point 1
            slots 10-11 curve control point 2
            slots 12-13 end position

        A ``move`` whose two operands are equal closes the path collected so
        far and starts a new one. Rows that match no command are skipped. A
        command whose operands would run past the end of the input stops
        parsing.

        Args:
            pixels: Decoded token rows.

        Returns:
            A one-element list wrapping the list of per-path tensors, or an
            empty list if any unexpected error occurred (best-effort; the
            error is printed).
        """
        try:
            # The original author annotated this sum with "8"
            # (presumably its value under the stock configuration).
            adjustment = self.tokens_config['num_end_token'] + self.tokens_config['svg_end'] + 2
            pixels = pixels - adjustment
            commands = self.svg_commands

            svg_tensors = []
            path_tensor = []

            i = 0
            while i < len(pixels):
                try:
                    opcode = pixels[i][0]

                    if opcode == commands['move']:  # Move command
                        if i + 2 >= len(pixels):
                            break
                        row = self._blank_command(0)
                        start_pos = pixels[i + 1]
                        end_pos = pixels[i + 2]
                        row[12:14] = end_pos

                        # Equal operands mark the beginning of a new path.
                        if np.all(start_pos == end_pos) and path_tensor:
                            svg_tensors.append(torch.tensor(path_tensor))
                            path_tensor = []

                        path_tensor.append(row.tolist())
                        i += 3

                    elif opcode == commands['line']:  # Line command
                        if i + 1 >= len(pixels):
                            break
                        row = self._blank_command(1)
                        row[12:14] = pixels[i + 1]
                        path_tensor.append(row.tolist())
                        i += 2

                    elif opcode == commands['curve']:  # Curve command
                        if i + 3 >= len(pixels):
                            break
                        row = self._blank_command(2)
                        row[8:10] = pixels[i + 1]
                        row[10:12] = pixels[i + 2]
                        row[12:14] = pixels[i + 3]
                        path_tensor.append(row.tolist())
                        i += 4

                    elif opcode == commands['arc']:  # Arc command
                        if i + 5 >= len(pixels):
                            break
                        row = self._blank_command(3)
                        row[1:3] = pixels[i + 1]      # radius
                        row[3] = pixels[i + 2][0]     # x-axis rotation
                        row[4] = pixels[i + 3][0]     # large-arc flag
                        row[5] = pixels[i + 4][0]     # sweep flag
                        row[12:14] = pixels[i + 5]    # end position
                        path_tensor.append(row.tolist())
                        i += 6

                    elif opcode == commands['close']:  # Close command
                        if i + 1 >= len(pixels):
                            break
                        row = self._blank_command(6)
                        row[12:14] = pixels[i + 1]
                        path_tensor.append(row.tolist())
                        i += 2

                    else:
                        # Not a command id: skip this row.
                        i += 1

                except IndexError:
                    print(f"Index error at position {i}, stopping SVG processing")
                    break

            if path_tensor:
                svg_tensors.append(torch.tensor(path_tensor))

            return [svg_tensors]

        except Exception as e:
            print(f"Error in raster_svg: {e}")
            return []

    def extract_colors_from_tokens(self, tokens: List[int]) -> List[int]:
        """Collect the colour tokens contained in ``tokens``.

        A token counts as a colour when
        ``color_start_offset <= token < color_end_offset``; it is stored as
        ``token - 1 - base_offset``.

        Args:
            tokens: Raw token ids.

        Returns:
            Shifted colour tokens in their order of appearance.
        """
        base_offset = self.tokens_config['base_offset']
        color_start = self.colors_config['color_start_offset']
        color_end = self.colors_config['color_end_offset']

        return [
            token - 1 - base_offset
            for token in tokens
            if color_start <= token < color_end
        ]

    def process_generated_tokens(self, output_ids: torch.Tensor) -> Tuple[np.ndarray, List[int]]:
        """Split generated token ids into xy coordinates and colour tokens.

        The first and last token of every sequence are dropped (presumably
        the begin/end-of-sequence markers). Each remaining token is decoded
        with ``pixel_to_xy`` when it falls into one of the two coordinate
        ranges, recorded as ``token - 1 - base_offset`` when it falls into
        the colour range, and ignored otherwise.

        Args:
            output_ids: 2-D tensor of token ids (sliced as ``[:, 1:-1]``).

        Returns:
            ``(xy, colors)`` taken from the last sequence that produced at
            least one coordinate; ``xy`` is the row-stacked coordinates.
            If no sequence produced a coordinate, both values are empty
            lists.
        """
        # Remove the leading and trailing token of every sequence.
        generated_pixels = output_ids[:, 1:-1].tolist()

        # Range boundaries do not change while iterating; read them once.
        eom = self.tokens_config['eom']
        svg_end = self.tokens_config['svg_end']
        base_offset = self.tokens_config['base_offset']
        grid_start = self.coordinates_config['pix_pad_offset'] + svg_end
        grid_end = self.colors_config['cmd_fill'] + base_offset + svg_end
        color_start = self.colors_config['color_start_offset']
        color_end = self.colors_config['color_end_offset']

        generated_xy = []
        generated_colors = []

        for pixel_sequence in generated_pixels:
            xy_sequence = []
            colors = []

            for pixel in pixel_sequence:
                try:
                    # Both coordinate ranges are decoded the same way;
                    # they take priority over the colour range.
                    if eom < pixel < grid_start or grid_start <= pixel < grid_end:
                        xy_sequence.append(self.pixel_to_xy(pixel))
                    elif color_start <= pixel < color_end:
                        colors.append(pixel - 1 - base_offset)
                except ValueError as e:
                    print(f"Error processing pixel {pixel}: {e}")
                    continue

            # NOTE(review): indentation was reconstructed; colours are taken
            # to be kept only together with a non-empty coordinate list.
            if xy_sequence:
                generated_xy = np.vstack(xy_sequence)
                generated_colors = colors

        return generated_xy, generated_colors

    def apply_colors_to_svg(self, svg_tensors: Union[List[torch.Tensor], List[List[torch.Tensor]]], colors: Optional[List[int]]) -> SVG:
        """Build one coloured SVG from per-path tensors.

        Each tensor is converted through ``SVGTensor.from_data`` and
        ``SVG.from_tensor``. The i-th path is coloured with
        ``token_to_color(colors[i])`` or ``"none"`` when there is no i-th
        colour, and its stroke colour is set to ``"none"``. Paths that fail
        to convert are reported and skipped.

        Args:
            svg_tensors: Flat list of path tensors, or a list of such lists
                (detected from the type of the first element).
            colors: Colour tokens, one per path; ``None`` means no colours.

        Returns:
            An ``SVG`` holding the path groups of all converted paths, with
            a viewbox of ``Bbox(bbox)``.

        Raises:
            ValueError: If no tensors were supplied or none could be
                converted.
        """
        bbox = self.coordinates_config['bbox']

        if svg_tensors and isinstance(svg_tensors[0], list):
            flat_tensors = [tensor for tensor_list in svg_tensors for tensor in tensor_list]
        else:
            flat_tensors = svg_tensors

        if not flat_tensors:
            raise ValueError("No valid SVG tensors provided")

        if colors is None:
            colors = []

        paths = []
        for i, path_tensor in enumerate(flat_tensors):
            try:
                path = SVGTensor.from_data(path_tensor)
                path = SVG.from_tensor(path.data, viewbox=Bbox(bbox))

                if i < len(colors):
                    actual_color = self.token_to_color(colors[i])
                else:
                    actual_color = "none"

                for path_group in path:
                    path_group.color = actual_color
                    path_group.stroke_color = "none"

                # NOTE(review): indentation was reconstructed; fill_ is taken
                # to run once per path, after the groups are coloured.
                path.fill_(True)
                paths.append(path)

            except Exception as e:
                print(f"Error processing path {i}: {e}")
                continue

        if not paths:
            raise ValueError("No valid paths could be generated")

        # Merge into a fresh list so no intermediate SVG is mutated.
        path_groups = [group for path in paths for group in path.svg_path_groups]

        svg = SVG(path_groups, viewbox=Bbox(bbox))
        return svg