import gradio as gr
import torch
import os
from PIL import Image
import cairosvg
import io
import tempfile
import argparse
import gc
import yaml
import glob
from huggingface_hub import hf_hub_download
from decoder import SketchDecoder
from transformers import AutoTokenizer, AutoProcessor, Qwen2_5_VLForConditionalGeneration
from qwen_vl_utils import process_vision_info
from tokenizer import SVGTokenizer
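
# Model and generation settings (e.g. max_length and the bos/eos/pad token ids used below) are read from config.yaml.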
with open('config.yaml', 'r') as f:
    config = yaml.safe_load(f)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

tokenizer = None
processor = None
sketch_decoder = None
svg_tokenizer = None

# System prompt
SYSTEM_PROMPT = "You are a multimodal SVG generation assistant capable of generating SVG code from both text descriptions and images."

SUPPORTED_FORMATS = ['.png', '.jpg', '.jpeg', '.webp', '.bmp', '.gif']

def parse_args():
    """Parse command line arguments"""
    parser = argparse.ArgumentParser(description='SVG Generator Service')
    parser.add_argument('--listen', type=str, default='0.0.0.0',
                        help='Listen address (default: 0.0.0.0)')
    parser.add_argument('--port', type=int, default=7860,
                        help='Port number (default: 7860)')
    parser.add_argument('--share', action='store_true',
                        help='Enable gradio share link')
    parser.add_argument('--debug', action='store_true',
                        help='Enable debug mode')
    return parser.parse_args()
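
# Models are loaded lazily on first call and cached in module-level globals so later requests reuse them.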
def load_models():
    """Load models"""
    global tokenizer, processor, sketch_decoder, svg_tokenizer

    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct", padding_side="left")
        processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct", padding_side="left")

        sketch_decoder = SketchDecoder()
        sketch_weight_path = hf_hub_download(repo_id="OmniSVG/OmniSVG", filename="pytorch_model.bin")
        # map_location keeps checkpoint loading working on CPU-only machines as well.
        sketch_decoder.load_state_dict(torch.load(sketch_weight_path, map_location=device))
        sketch_decoder = sketch_decoder.to(device).eval()

        svg_tokenizer = SVGTokenizer('config.yaml')

def process_and_resize_image(image_input, target_size=(200, 200)):
    """Process and resize image to target size"""
    if isinstance(image_input, str):
        image = Image.open(image_input)
    elif isinstance(image_input, Image.Image):
        image = image_input
    else:
        image = Image.fromarray(image_input)

    image = image.resize(target_size, Image.Resampling.LANCZOS)
    return image

def get_example_images():
    """Get example images from the examples directory"""
    example_dir = "./examples"
    example_images = []

    if os.path.exists(example_dir):
        for ext in SUPPORTED_FORMATS:
            pattern = os.path.join(example_dir, f"*{ext}")
            example_images.extend(glob.glob(pattern))

    example_images.sort()
    return example_images

def process_text_to_svg(text_description):
    """Process text-to-svg task"""
    load_models()

    messages = [{
        "role": "system",
        "content": SYSTEM_PROMPT
    }, {
        "role": "user",
        "content": [
            {"type": "text", "text": f"Task: text-to-svg\nDescription: {text_description}\nGenerate SVG code based on the above description."}
        ]
    }]

    text_input = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    inputs = processor(
        text=[text_input],
        truncation=True,
        return_tensors="pt"
    )

    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)
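    # Text-only task: no pixel values or image grid info are passed to the decoder.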
    pixel_values = None
    image_grid_thw = None

    return input_ids, attention_mask, pixel_values, image_grid_thw

def process_image_to_svg(image_path):
    """Process image-to-svg task"""
    load_models()

    messages = [{
        "role": "system",
        "content": SYSTEM_PROMPT
    }, {
        "role": "user",
        "content": [
            {"type": "text", "text": "Task: image-to-svg\nGenerate SVG code that accurately represents the following image."},
            {"type": "image", "image": image_path},
        ]
    }]

    text_input = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, _ = process_vision_info(messages)

    inputs = processor(
        text=[text_input],
        images=image_inputs,
        truncation=True,
        return_tensors="pt"
    )

    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)
    pixel_values = inputs['pixel_values'].to(device) if 'pixel_values' in inputs else None
    image_grid_thw = inputs['image_grid_thw'].to(device) if 'image_grid_thw' in inputs else None

    return input_ids, attention_mask, pixel_values, image_grid_thw
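
# generate_svg autoregressively decodes SVG tokens with the sketch decoder, then converts them into path coordinates and colors via the SVG tokenizer.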
def generate_svg(input_ids, attention_mask, pixel_values=None, image_grid_thw=None, task_type="image-to-svg"):
    """Generate SVG"""
    try:
        # Clean memory before generation
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        print(f"Generating SVG for {task_type}...")
        # Generation configuration; adjust these values for better results.
        if task_type == "image-to-svg":
            # Image-to-SVG configuration
            gen_config = dict(
                do_sample=True,
                temperature=0.1,
                top_p=0.001,
                top_k=1,
                num_beams=5,
                repetition_penalty=1.05,
            )
        else:
            # Text-to-SVG configuration
            gen_config = dict(
                do_sample=True,
                temperature=0.8,
                top_p=0.95,
                top_k=50,
                repetition_penalty=1.05,
                early_stopping=True,
            )

        if torch.cuda.is_available():
            torch.cuda.synchronize()

        # Generate SVG
        model_config = config['model']
        max_length = model_config['max_length']
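        # Pre-fill the output buffer with EOS tokens so unused positions are read as end-of-sequence when decoding.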
        output_ids = torch.ones(1, max_length).long().to(device) * model_config['eos_token_id']

        with torch.no_grad():
            results = sketch_decoder.transformer.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                pixel_values=pixel_values,
                image_grid_thw=image_grid_thw,
                max_new_tokens=max_length - 1,
                num_return_sequences=1,
                bos_token_id=model_config['bos_token_id'],
                eos_token_id=model_config['eos_token_id'],
                pad_token_id=model_config['pad_token_id'],
                use_cache=True,
                **gen_config
            )
            results = results[:, :max_length - 1]
            output_ids[:, :results.shape[1]] = results
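        # Convert the generated token ids back into (x, y) path coordinates and per-path colors.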
        generated_xy, generated_colors = svg_tokenizer.process_generated_tokens(output_ids)

        svg_tensors = svg_tokenizer.raster_svg(generated_xy)
        if not svg_tensors or not svg_tensors[0]:
            return "Error: No valid SVG paths generated", None

        print('Creating SVG...')
        svg = svg_tokenizer.apply_colors_to_svg(svg_tensors[0], generated_colors)
        svg_str = svg.to_str()

        # Convert to PNG for visualization
        png_data = cairosvg.svg2png(bytestring=svg_str.encode('utf-8'))
        png_image = Image.open(io.BytesIO(png_data))

        return svg_str, png_image
    except Exception as e:
        print(f"Generation error: {e}")
        import traceback
        traceback.print_exc()
        return f"Error: {e}", None
def gradio_image_to_svg(image):
    """Gradio interface function - image-to-svg"""
    if image is None:
        return "Please upload an image", None

    processed_image = process_and_resize_image(image)

    # Save temporary image file
    with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp_file:
        processed_image.save(tmp_file.name, format='PNG')
        tmp_path = tmp_file.name

    try:
        input_ids, attention_mask, pixel_values, image_grid_thw = process_image_to_svg(tmp_path)
        svg_code, png_image = generate_svg(input_ids, attention_mask, pixel_values, image_grid_thw, "image-to-svg")
        return svg_code, png_image
    finally:
        # Clean up temporary file
        if os.path.exists(tmp_path):
            os.unlink(tmp_path)

def gradio_text_to_svg(text_description):
    """Gradio interface function - text-to-svg"""
    if not text_description or text_description.strip() == "":
        return "Please enter a description", None

    input_ids, attention_mask, pixel_values, image_grid_thw = process_text_to_svg(text_description)
    svg_code, png_image = generate_svg(input_ids, attention_mask, pixel_values, image_grid_thw, "text-to-svg")
    return svg_code, png_image

def create_interface():
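    """Create the Gradio interface"""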
    # Example texts
    example_texts = [
        "A yellow t-shirt with a heart design represents love and positivity.",
        "A bright yellow emoji with a surprised expression and rosy cheeks hovers above a shadow.",
        "A brown coffee cup on a white saucer is seen from a top-down perspective.",
        "A cartoon firefighter in a red and yellow uniform represents safety and protection.",
        "A cute bunny face with pink ears rosy cheeks and a playful red tongue conveys charm and cheerfulness.",
        "A bearded man with orange hair and a mustache represents a hipster style portrait.",
        "A colorful ice cream popsicle with a hint of chocolate at the bottom on a stick.",
        "A light blue shopping bag features a white flower with a red center and scattered dots.",
        "A yellow phone icon and orange arrow on a blue smartphone screen symbolize an incoming call.",
        "A sad wilted flower with pink petals slumps over an orange cloud with a blue striped background.",
        "A cartoon character with dark blue hair and a mustache wears a blue suit against a light blue circular background.",
        "A blue bookmark icon with a white plus sign in the center.",
        "A computer monitor displays a bar graph with yellow orange and green bars.",
        "A blue and gray database icon is overlaid with a yellow star in the bottom right corner.",
        "An orange thermometer with a circular base represents temperature measurement.",
        "A green delivery truck icon with a checkmark symbolizing a completed delivery.",
        "A yellow t-shirt with a heart design represents love and positivity.",
        "A blue and gray microphone icon symbolizes audio recording or voice input.",
        "Cloud icon with an upward arrow symbolizes uploading or cloud storage.",
        "A brown chocolate bar is depicted in four square segments with a shiny glossy finish.",
        "A colorful moving truck icon with a red and orange cargo container.",
        "A light blue T-shirt icon is outlined with a bold blue border.",
        "A person in a blue shirt and dark pants stands with one hand in a pocket gesturing outward.",
    ]

    example_images = get_example_images()
    with gr.Blocks(title="OmniSVG Demo Page", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# OmniSVG Demo Page")
        gr.Markdown("Generate SVG code from images or text descriptions")

        with gr.Tabs():
            # Image-to-SVG tab
            with gr.TabItem("Image-to-SVG"):
                with gr.Row():
                    with gr.Column():
                        image_input = gr.Image(
                            label="Input Image",
                            type="pil",
                            image_mode="RGBA"
                        )
                        if example_images:
                            gr.Examples(
                                examples=example_images,
                                inputs=[image_input],
                                label="Example Images (click to use)",
                                examples_per_page=12
                            )
                        image_generate_btn = gr.Button("Generate SVG", variant="primary")

                    with gr.Column():
                        image_svg_output = gr.Textbox(
                            label="Generated SVG Code",
                            lines=10,
                            max_lines=20,
                            show_copy_button=True
                        )
                        image_png_preview = gr.Image(label="SVG Preview", type="pil")

                image_generate_btn.click(
                    fn=gradio_image_to_svg,
                    inputs=[image_input],
                    outputs=[image_svg_output, image_png_preview],
                    queue=True
                )

            # Text-to-SVG tab
            with gr.TabItem("Text-to-SVG"):
                with gr.Row():
                    with gr.Column():
                        text_input = gr.Textbox(
                            label="Description",
                            placeholder="Enter SVG description, e.g.: a red circle with a blue square inside",
                            lines=3
                        )
                        # Add example texts
                        gr.Examples(
                            examples=[[text] for text in example_texts],
                            inputs=[text_input],
                            label="Example Descriptions (click to use)",
                            examples_per_page=10
                        )
                        text_generate_btn = gr.Button("Generate SVG", variant="primary")

                    with gr.Column():
                        text_svg_output = gr.Textbox(
                            label="Generated SVG Code",
                            lines=10,
                            max_lines=20,
                            show_copy_button=True
                        )
                        text_png_preview = gr.Image(label="SVG Preview", type="pil")

                text_generate_btn.click(
                    fn=gradio_text_to_svg,
                    inputs=[text_input],
                    outputs=[text_svg_output, text_png_preview],
                    queue=True
                )

        # Add usage instructions
        gr.Markdown("""
        ## Usage Instructions
        - **Image-to-SVG**: Upload a PNG image and click "Generate SVG"
        - **Text-to-SVG**: Enter a text description or click an example, then click "Generate SVG"

        ### Performance Tips:
        - **Image-to-SVG**: Use input images with a white background (JPG) or 4-channel RGBA PNG images for better results.
        - **Text-to-SVG**: Keep descriptions concise and specific. Focus on 2-3 main elements with clear spatial relationships, and avoid overly complex prompts with too many objects or detailed arrangements. If a result is unsatisfying, regenerating a few times often helps.

        ### Supported Descriptions:
        - **Objects**: circle, square, triangle, rectangle, star, polygon, heart, diamond, hexagon, oval
        - **Colors**: red, blue, green, yellow, purple, orange, pink, black, white, gray, brown, cyan
        - **Formations**: arranged in a row, arranged in a circle, grid pattern, scattered, clustered, stacked
        - **Styles**: filled, outlined, gradient, striped, dotted, solid, transparent, shaded
        """)

    return demo

if __name__ == "__main__":
    # Set environment variable to avoid tokenizer parallelization warning
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    args = parse_args()

    # Load models before starting
    print("Loading models...")
    load_models()
    print("Models loaded successfully!")

    # Create and launch interface
    demo = create_interface()
    demo.launch(
        server_name=args.listen,
        server_port=args.port,
        share=args.share,
        debug=args.debug
    )
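
# Example launch (assuming this file is saved as app.py):
#   python app.py --port 7860 --share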