# screenshot2html / app.py
import os
import subprocess
import spaces
import torch
import gradio as gr
from gradio_client.client import DEFAULT_TEMP_DIR
from playwright.sync_api import sync_playwright
from typing import List, Optional
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
from transformers.image_utils import to_numpy_array, PILImageResampling, ChannelDimension
from transformers.image_transforms import resize, to_channel_dimension_format
import hashlib
# --- Optimization Parameters ---
RESIZE_IMAGE = 224 # Further reduce image size for faster processing
USE_QUANTIZED_MODEL = True # Try loading a quantized version if available
CACHE_RENDERED_HTML = True
FORCE_CPU = True # Force CPU usage
# --- Device Setup ---
DEVICE = torch.device("cpu") if FORCE_CPU else torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")
# --- Model Loading ---
MODEL_NAME = "HuggingFaceM4/VLM_WebSight_finetuned"
QUANTIZED_MODEL_NAME = MODEL_NAME + ".quantized" # Check for quantized version
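# NOTE: this checks for a *local* directory named "<model>.quantized"; nothing
# downloads it automatically, so unless one has been prepared in advance the
# full-model branch below is the path that actually runs.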
processor = None # Initialize outside the try block
model = None
try:
    if USE_QUANTIZED_MODEL and os.path.exists(QUANTIZED_MODEL_NAME):
        print(f"Loading quantized model: {QUANTIZED_MODEL_NAME}")
        processor = AutoProcessor.from_pretrained(MODEL_NAME)  # Use the original processor
        model = AutoModelForCausalLM.from_pretrained(
            QUANTIZED_MODEL_NAME,
            trust_remote_code=True,
            torch_dtype=torch.float32,  # or torch.float16 if supported
        ).to(DEVICE)
    else:
        print(f"Loading full model: {MODEL_NAME}")
        processor = AutoProcessor.from_pretrained(MODEL_NAME)
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            trust_remote_code=True,
            torch_dtype=torch.float32,
        ).to(DEVICE)  # Load on CPU directly
except Exception as e:
    print(f"Error loading model: {e}")
    raise  # Without a loaded model, the code below would fail on `model.config`.
if model.config.use_resampler:
    image_seq_len = model.config.perceiver_config.resampler_n_latents
else:
    image_seq_len = (
        model.config.vision_config.image_size // model.config.vision_config.patch_size
    ) ** 2
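# image_seq_len is the number of "<image>" placeholder tokens the text prompt
# must carry: the perceiver resampler's latent count when the resampler is
# enabled, otherwise one token per vision patch ((image_size // patch_size) ** 2).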
BOS_TOKEN = processor.tokenizer.bos_token
BAD_WORDS_IDS = processor.tokenizer(
    ["<image>", "<fake_token_around_image>"], add_special_tokens=False
).input_ids
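# Banning these placeholder tokens at generation time keeps the decoder from
# emitting them inside the produced HTML.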
# --- Utility Functions ---
def convert_to_rgb(image):
    if image.mode == "RGB":
        return image
    # Composite transparent images onto a white background before dropping alpha.
    image_rgba = image.convert("RGBA")
    background = Image.new("RGBA", image_rgba.size, (255, 255, 255))
    alpha_composite = Image.alpha_composite(background, image_rgba)
    alpha_composite = alpha_composite.convert("RGB")
    return alpha_composite
def custom_transform(x):
    x = convert_to_rgb(x)
    x = to_numpy_array(x)
    x = resize(x, (RESIZE_IMAGE, RESIZE_IMAGE), resample=PILImageResampling.BILINEAR)
    x = processor.image_processor.rescale(x, scale=1 / 255)
    x = processor.image_processor.normalize(
        x, mean=processor.image_processor.image_mean, std=processor.image_processor.image_std
    )
    x = to_channel_dimension_format(x, ChannelDimension.FIRST)
    x = torch.tensor(x).float()  # Convert to float32 here
    return x
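# custom_transform mirrors the processor's own preprocessing (RGB -> resize ->
# rescale -> normalize -> channels-first tensor) but at the reduced
# RESIZE_IMAGE resolution, trading some fidelity for faster CPU inference.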
# --- Playwright Installation ---
def install_playwright():
    try:
        subprocess.run(["playwright", "install"], check=True)
        print("Playwright installation successful.")
    except subprocess.CalledProcessError as e:
        print(f"Error during Playwright installation: {e}")
install_playwright()
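# Browsers are installed at import time so the first render never hits a missing
# Chromium binary. "playwright install" downloads every supported browser;
# "playwright install chromium" would be a lighter alternative here.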
# --- HTML Rendering Cache ---
html_render_cache = {}
def render_webpage(html_css_code: str) -> Image.Image:
    """Renders HTML code to an image using Playwright, with caching."""
    if CACHE_RENDERED_HTML:
        html_hash = hashlib.md5(html_css_code.encode("utf-8")).hexdigest()
        if html_hash in html_render_cache:
            print("Using cached rendered HTML.")
            return html_render_cache[html_hash]
    with sync_playwright() as p:
        browser = p.chromium.launch(headless=True)  # Fresh headless browser per render
        context = browser.new_context(
            user_agent=(
                "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/107.0.0.0"
                " Safari/537.36"
            )
        )
        page = context.new_page()
        page.set_content(html_css_code)
        page.wait_for_load_state("networkidle")
        os.makedirs(DEFAULT_TEMP_DIR, exist_ok=True)  # The temp dir may not exist yet.
        output_path_screenshot = f"{DEFAULT_TEMP_DIR}/{hash(html_css_code)}.png"
        page.screenshot(path=output_path_screenshot, full_page=True)
        context.close()
        browser.close()
    image = Image.open(output_path_screenshot)
    if CACHE_RENDERED_HTML:
        html_render_cache[html_hash] = image
    return image
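# Example: render_webpage("<html><body><h1>Hi</h1></body></html>") returns a
# PIL image of the rendered page, cached under the MD5 hash of the markup.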
# --- Gallery ---
IMAGE_GALLERY_PATHS = [
    f"example_images/{ex_image}" for ex_image in os.listdir("example_images")
]
def add_file_gallery(
    selected_state: gr.SelectData, gallery_list: List[str]
) -> Image.Image:
    # Despite the List[str] hint, Gradio passes its gallery payload here; each
    # entry exposes the underlying file path via .image.path.
    return Image.open(gallery_list.root[selected_state.index].image.path)
# --- Model Inference ---
def model_inference(image: Image.Image) -> tuple[str, Image.Image]:
    """Performs model inference and renders the result."""
    if image is None:
        raise ValueError("`image` is None. It should be a PIL image.")
    inputs = processor.tokenizer(
        f"{BOS_TOKEN}<fake_token_around_image>{'<image>' * image_seq_len}<fake_token_around_image>",
        return_tensors="pt",
        add_special_tokens=False,
    )
    inputs["pixel_values"] = processor.image_processor(
        [image],
        transform=custom_transform
    )
    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
    with torch.no_grad():  # Disable gradient calculation
        generation_kwargs = dict(
            inputs, bad_words_ids=BAD_WORDS_IDS, max_length=4096
        )
        generated_ids = model.generate(**generation_kwargs)
        generated_text = processor.batch_decode(
            generated_ids, skip_special_tokens=True
        )[0]
    rendered_page = render_webpage(generated_text)
    return generated_text, rendered_page
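# The handler returns both the raw markup and its Playwright render, so the UI
# can show the generated code next to a visual reconstruction of the input.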
# --- Gradio Interface ---
generated_html = gr.Code(
    label="Extracted HTML", elem_id="generated_html"
)
rendered_html = gr.Image(
    label="Rendered HTML", show_download_button=False, show_share_button=False
)
css = """
.gradio-container{max-width: 1000px!important}
h1{display: flex;align-items: center;justify-content: center;gap: .25em}
*{transition: width 0.5s ease, flex-grow 0.5s ease}
"""
with gr.Blocks(title="Screenshot to HTML", theme=gr.themes.Base(), css=css) as demo:
    gr.Markdown(
        "Since the model used for this demo *does not generate images*, it is more effective to input standalone website elements or sites with minimal image content."
    )
    with gr.Row(equal_height=True):
        with gr.Column(scale=4, min_width=250) as upload_area:
            imagebox = gr.Image(
                type="pil",
                label="Screenshot to extract",
                visible=True,
                sources=["upload", "clipboard"],
            )
            with gr.Group():
                with gr.Row():
                    submit_btn = gr.Button(
                        value="▶️ Submit", visible=True, min_width=120
                    )
                    clear_btn = gr.ClearButton(
                        [imagebox, generated_html, rendered_html], value="🧹 Clear", min_width=120
                    )
                    regenerate_btn = gr.Button(
                        value="🔄 Regenerate", visible=True, min_width=120
                    )
        with gr.Column(scale=4):
            rendered_html.render()
            with gr.Row():
                generated_html.render()
    with gr.Row():
        template_gallery = gr.Gallery(
            value=IMAGE_GALLERY_PATHS,
            label="Templates Gallery",
            allow_preview=False,
            columns=5,
            elem_id="gallery",
            show_share_button=False,
            height=400,
        )
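    # A single gr.on registration fans model_inference out to image upload,
    # Submit, and Regenerate, so the handler is wired once for all three triggers.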
    gr.on(
        triggers=[
            imagebox.upload,
            submit_btn.click,
            regenerate_btn.click,
        ],
        fn=model_inference,
        inputs=[imagebox],
        outputs=[generated_html, rendered_html],
    )
    template_gallery.select(
        fn=add_file_gallery,
        inputs=[template_gallery],
        outputs=[imagebox],
    ).success(
        fn=model_inference,
        inputs=[imagebox],
        outputs=[generated_html, rendered_html],
    )
demo.queue(max_size=40, api_open=False)
demo.launch(max_threads=400)