import hashlib
import os
import subprocess
from typing import List

import gradio as gr
import spaces  # noqa: F401  (Hugging Face Spaces SDK; kept for Spaces compatibility)
import torch
from gradio_client.client import DEFAULT_TEMP_DIR
from PIL import Image
from playwright.sync_api import sync_playwright
from transformers import AutoModelForCausalLM, AutoProcessor
from transformers.image_transforms import resize, to_channel_dimension_format
from transformers.image_utils import ChannelDimension, PILImageResampling, to_numpy_array

# --- Optimization Parameters ---
RESIZE_IMAGE = 224  # Further reduce image size for faster processing
USE_QUANTIZED_MODEL = True  # Try loading a quantized version if available
CACHE_RENDERED_HTML = True
FORCE_CPU = True  # Force CPU usage

# --- Device Setup ---
DEVICE = torch.device("cpu") if FORCE_CPU else torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

# --- Model Loading ---
MODEL_NAME = "HuggingFaceM4/VLM_WebSight_finetuned"
QUANTIZED_MODEL_NAME = MODEL_NAME + ".quantized"  # Check for a quantized version

processor = None  # Initialize outside the try block
model = None
try:
    if USE_QUANTIZED_MODEL and os.path.exists(QUANTIZED_MODEL_NAME):
        print(f"Loading quantized model: {QUANTIZED_MODEL_NAME}")
        processor = AutoProcessor.from_pretrained(MODEL_NAME)  # Use the original processor
        model = AutoModelForCausalLM.from_pretrained(
            QUANTIZED_MODEL_NAME,
            trust_remote_code=True,
            torch_dtype=torch.float32,  # or torch.float16 if supported
        ).to(DEVICE)
    else:
        print(f"Loading full model: {MODEL_NAME}")
        processor = AutoProcessor.from_pretrained(MODEL_NAME)
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            trust_remote_code=True,
            torch_dtype=torch.float32,
        ).to(DEVICE)  # Load on CPU directly
except Exception as e:
    print(f"Error loading model: {e}")
    raise  # Nothing below can run without the model, so fail fast

# Number of image-token placeholders the model expects per image
if model.config.use_resampler:
    image_seq_len = model.config.perceiver_config.resampler_n_latents
else:
    image_seq_len = (
        model.config.vision_config.image_size // model.config.vision_config.patch_size
    ) ** 2

BOS_TOKEN = processor.tokenizer.bos_token
# Keep the image placeholder tokens out of the generated HTML
BAD_WORDS_IDS = processor.tokenizer(
    ["<image>", "<fake_token_around_image>"], add_special_tokens=False
).input_ids


# --- Utility Functions ---
def convert_to_rgb(image):
    if image.mode == "RGB":
        return image
    # Composite transparent images onto a white background before converting
    image_rgba = image.convert("RGBA")
    background = Image.new("RGBA", image_rgba.size, (255, 255, 255))
    alpha_composite = Image.alpha_composite(background, image_rgba)
    alpha_composite = alpha_composite.convert("RGB")
    return alpha_composite


def custom_transform(x):
    x = convert_to_rgb(x)
    x = to_numpy_array(x)
    x = resize(x, (RESIZE_IMAGE, RESIZE_IMAGE), resample=PILImageResampling.BILINEAR)
    x = processor.image_processor.rescale(x, scale=1 / 255)
    x = processor.image_processor.normalize(
        x,
        mean=processor.image_processor.image_mean,
        std=processor.image_processor.image_std,
    )
    x = to_channel_dimension_format(x, ChannelDimension.FIRST)
    x = torch.tensor(x).float()  # Convert to float32 here
    return x


# --- Playwright Installation ---
def install_playwright():
    try:
        subprocess.run(["playwright", "install"], check=True)
        print("Playwright installation successful.")
    except subprocess.CalledProcessError as e:
        print(f"Error during Playwright installation: {e}")


install_playwright()
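
# Optional sanity check for the preprocessing pipeline. This is an illustrative
# addition: the probe image and the DEBUG_TRANSFORM flag are assumptions, not part
# of the original app. Run `DEBUG_TRANSFORM=1 python app.py` to verify the
# transform's output shape and dtype before serving requests.
if os.environ.get("DEBUG_TRANSFORM"):
    _probe = Image.new("RGB", (64, 64), (255, 0, 0))  # synthetic test image
    _tensor = custom_transform(_probe)
    # Expected: float32 tensor of shape (3, RESIZE_IMAGE, RESIZE_IMAGE), channels first
    print(f"custom_transform: shape={tuple(_tensor.shape)}, dtype={_tensor.dtype}")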

# --- HTML Rendering Cache ---
html_render_cache = {}


def render_webpage(html_css_code: str) -> Image.Image:
    """Renders HTML code to an image using Playwright, with caching."""
    html_hash = hashlib.md5(html_css_code.encode("utf-8")).hexdigest()
    if CACHE_RENDERED_HTML and html_hash in html_render_cache:
        print("Using cached rendered HTML.")
        return html_render_cache[html_hash]

    # A fresh Chromium instance is launched per call; see the shared-browser
    # sketch further below for a lower-latency alternative.
    with sync_playwright() as p:
        browser = p.chromium.launch(headless=True)
        context = browser.new_context(
            user_agent=(
                "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)"
                " Chrome/107.0.0.0 Safari/537.36"
            )
        )
        page = context.new_page()
        page.set_content(html_css_code)
        page.wait_for_load_state("networkidle")
        output_path_screenshot = f"{DEFAULT_TEMP_DIR}/{html_hash}.png"
        page.screenshot(path=output_path_screenshot, full_page=True)
        context.close()
        browser.close()

    image = Image.open(output_path_screenshot)
    if CACHE_RENDERED_HTML:
        html_render_cache[html_hash] = image
    return image


# --- Gallery ---
IMAGE_GALLERY_PATHS = [
    f"example_images/{ex_image}" for ex_image in os.listdir("example_images")
]


def add_file_gallery(
    selected_state: gr.SelectData, gallery_list: List[str]
) -> Image.Image:
    return Image.open(gallery_list.root[selected_state.index].image.path)


# --- Model Inference ---
def model_inference(image: Image.Image) -> tuple[str, Image.Image]:
    """Performs model inference and renders the result."""
    if image is None:
        raise ValueError("`image` is None. It should be a PIL image.")

    inputs = processor.tokenizer(
        f"{BOS_TOKEN}<fake_token_around_image>{'<image>' * image_seq_len}<fake_token_around_image>",
        return_tensors="pt",
        add_special_tokens=False,
    )
    inputs["pixel_values"] = processor.image_processor(
        [image], transform=custom_transform
    )
    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}

    with torch.no_grad():  # Disable gradient calculation
        generation_kwargs = dict(
            inputs,
            bad_words_ids=BAD_WORDS_IDS,
            max_length=4096,
        )
        generated_ids = model.generate(**generation_kwargs)
        generated_text = processor.batch_decode(
            generated_ids, skip_special_tokens=True
        )[0]

    rendered_page = render_webpage(generated_text)
    return generated_text, rendered_page


# --- Gradio Interface ---
generated_html = gr.Code(
    label="Extracted HTML", elem_id="generated_html"
)
rendered_html = gr.Image(
    label="Rendered HTML", show_download_button=False, show_share_button=False
)

css = """
.gradio-container{max-width: 1000px!important}
h1{display: flex;align-items: center;justify-content: center;gap: .25em}
*{transition: width 0.5s ease, flex-grow 0.5s ease}
"""
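
# Alternative renderer (a sketch, not wired into the demo): launching Chromium on
# every request dominates render latency, so one browser can be kept alive across
# calls instead. `_PLAYWRIGHT`, `_SHARED_BROWSER`, and `render_webpage_shared` are
# hypothetical names introduced here for illustration; note that the shared handles
# are never closed, trading clean shutdown for speed.
_PLAYWRIGHT = None
_SHARED_BROWSER = None


def render_webpage_shared(html_css_code: str) -> Image.Image:
    global _PLAYWRIGHT, _SHARED_BROWSER
    if _SHARED_BROWSER is None:
        _PLAYWRIGHT = sync_playwright().start()  # Start Playwright once per process
        _SHARED_BROWSER = _PLAYWRIGHT.chromium.launch(headless=True)
    page = _SHARED_BROWSER.new_page()
    try:
        page.set_content(html_css_code)
        page.wait_for_load_state("networkidle")
        path = f"{DEFAULT_TEMP_DIR}/{hashlib.md5(html_css_code.encode('utf-8')).hexdigest()}.png"
        page.screenshot(path=path, full_page=True)
    finally:
        page.close()  # Only the page is per-request; the browser stays up
    return Image.open(path)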

with gr.Blocks(title="Screenshot to HTML", theme=gr.themes.Base(), css=css) as demo:
    gr.Markdown(
        "Since the model used for this demo *does not generate images*, it is more"
        " effective to input standalone website elements or sites with minimal image"
        " content."
    )
    with gr.Row(equal_height=True):
        with gr.Column(scale=4, min_width=250) as upload_area:
            imagebox = gr.Image(
                type="pil",
                label="Screenshot to extract",
                visible=True,
                sources=["upload", "clipboard"],
            )
            with gr.Group():
                with gr.Row():
                    submit_btn = gr.Button(
                        value="▶️ Submit", visible=True, min_width=120
                    )
                    clear_btn = gr.ClearButton(
                        [imagebox, generated_html, rendered_html],
                        value="🧹 Clear",
                        min_width=120,
                    )
                    regenerate_btn = gr.Button(
                        value="🔄 Regenerate", visible=True, min_width=120
                    )
        with gr.Column(scale=4):
            rendered_html.render()
    with gr.Row():
        generated_html.render()
    with gr.Row():
        template_gallery = gr.Gallery(
            value=IMAGE_GALLERY_PATHS,
            label="Templates Gallery",
            allow_preview=False,
            columns=5,
            elem_id="gallery",
            show_share_button=False,
            height=400,
        )

    # Regenerate is already covered here as a trigger, so it does not need
    # a separate (duplicate) click handler.
    gr.on(
        triggers=[
            imagebox.upload,
            submit_btn.click,
            regenerate_btn.click,
        ],
        fn=model_inference,
        inputs=[imagebox],
        outputs=[generated_html, rendered_html],
    )
    template_gallery.select(
        fn=add_file_gallery,
        inputs=[template_gallery],
        outputs=[imagebox],
    ).success(
        fn=model_inference,
        inputs=[imagebox],
        outputs=[generated_html, rendered_html],
    )

demo.load()
demo.queue(max_size=40, api_open=False)
demo.launch(max_threads=400)