import os
import subprocess
import hashlib

import spaces  # noqa: F401  (kept for Hugging Face Spaces compatibility)
import torch
import gradio as gr
from gradio_client.client import DEFAULT_TEMP_DIR
from playwright.sync_api import sync_playwright
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
from transformers.image_utils import to_numpy_array, PILImageResampling, ChannelDimension
from transformers.image_transforms import resize, to_channel_dimension_format
# --- Optimization Parameters ---
RESIZE_IMAGE = 224  # Further reduce image size for faster processing
USE_QUANTIZED_MODEL = True  # Try loading a quantized version if available
CACHE_RENDERED_HTML = True
FORCE_CPU = True  # Force CPU usage

# --- Device Setup ---
DEVICE = torch.device("cpu") if FORCE_CPU else torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")
# --- Model Loading ---
MODEL_NAME = "HuggingFaceM4/VLM_WebSight_finetuned"
QUANTIZED_MODEL_NAME = MODEL_NAME + ".quantized"  # Local directory holding a quantized checkpoint, if present

processor = None  # Initialize outside the try block
model = None
try:
    if USE_QUANTIZED_MODEL and os.path.exists(QUANTIZED_MODEL_NAME):
        print(f"Loading quantized model: {QUANTIZED_MODEL_NAME}")
        processor = AutoProcessor.from_pretrained(MODEL_NAME)  # Use the original processor
        model = AutoModelForCausalLM.from_pretrained(
            QUANTIZED_MODEL_NAME,
            trust_remote_code=True,
            torch_dtype=torch.float32,  # or torch.float16 if supported
        ).to(DEVICE)
    else:
        print(f"Loading full model: {MODEL_NAME}")
        processor = AutoProcessor.from_pretrained(MODEL_NAME)
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            trust_remote_code=True,
            torch_dtype=torch.float32,
        ).to(DEVICE)  # Load on CPU directly
except Exception as e:
    print(f"Error loading model: {e}")
    raise  # Without a model (and processor), nothing below can run
# Number of <image> placeholder tokens the model expects per input image.
if model.config.use_resampler:
    image_seq_len = model.config.perceiver_config.resampler_n_latents
else:
    image_seq_len = (
        model.config.vision_config.image_size // model.config.vision_config.patch_size
    ) ** 2
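# For example (illustrative numbers, not necessarily this checkpoint's
# config): a 224-px vision tower with 14-px patches yields
# (224 // 14) ** 2 = 256 image tokens, whereas a perceiver resampler fixes
# the count at resampler_n_latents regardless of resolution.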
BOS_TOKEN = processor.tokenizer.bos_token
# Forbid the model from emitting the image placeholder tokens in its output.
BAD_WORDS_IDS = processor.tokenizer(
    ["<image>", "<fake_token_around_image>"], add_special_tokens=False
).input_ids
# --- Utility Functions ---
def convert_to_rgb(image):
    """Flatten any transparency onto a white background and return RGB."""
    if image.mode == "RGB":
        return image
    image_rgba = image.convert("RGBA")
    background = Image.new("RGBA", image_rgba.size, (255, 255, 255))
    alpha_composite = Image.alpha_composite(background, image_rgba)
    return alpha_composite.convert("RGB")
def custom_transform(x):
    """Resize, rescale, and normalize a PIL image into a CHW float32 tensor."""
    x = convert_to_rgb(x)
    x = to_numpy_array(x)
    x = resize(x, (RESIZE_IMAGE, RESIZE_IMAGE), resample=PILImageResampling.BILINEAR)
    x = processor.image_processor.rescale(x, scale=1 / 255)
    x = processor.image_processor.normalize(
        x, mean=processor.image_processor.image_mean, std=processor.image_processor.image_std
    )
    x = to_channel_dimension_format(x, ChannelDimension.FIRST)
    return torch.tensor(x).float()  # Convert to float32 here
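# Sanity check (illustrative): with RESIZE_IMAGE = 224, any input image comes
# out as a channels-first tensor of shape (3, 224, 224):
#
#     custom_transform(Image.new("RGB", (640, 480))).shape == (3, 224, 224)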
# --- Playwright Installation ---
def install_playwright():
    try:
        subprocess.run(["playwright", "install"], check=True)
        print("Playwright installation successful.")
    except subprocess.CalledProcessError as e:
        print(f"Error during Playwright installation: {e}")

install_playwright()
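# Note: on minimal containers the browser's OS-level dependencies may also be
# missing; `playwright install --with-deps chromium` installs the browser and
# its system packages in one step (requires root).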
# --- HTML Rendering Cache ---
html_render_cache = {}

def render_webpage(html_css_code: str) -> Image.Image:
    """Renders HTML code to an image using Playwright, with optional caching."""
    html_hash = hashlib.md5(html_css_code.encode("utf-8")).hexdigest()
    if CACHE_RENDERED_HTML and html_hash in html_render_cache:
        print("Using cached rendered HTML.")
        return html_render_cache[html_hash]
    with sync_playwright() as p:
        # A fresh headless browser is launched per call; reusing a long-lived
        # instance would be faster but needs explicit lifecycle management.
        browser = p.chromium.launch(headless=True)
        context = browser.new_context(
            user_agent=(
                "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/107.0.0.0"
                " Safari/537.36"
            )
        )
        page = context.new_page()
        page.set_content(html_css_code)
        page.wait_for_load_state("networkidle")
        # Use the stable md5 digest for the filename (Python's hash() is
        # salted per process, so it would not be reproducible).
        output_path_screenshot = f"{DEFAULT_TEMP_DIR}/{html_hash}.png"
        page.screenshot(path=output_path_screenshot, full_page=True)
        context.close()
        browser.close()
    image = Image.open(output_path_screenshot)
    if CACHE_RENDERED_HTML:
        html_render_cache[html_hash] = image
    return image
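# Example usage (illustrative):
#
#     img = render_webpage("<html><body><h1>Hello</h1></body></html>")
#     img.save("preview.png")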
# --- Gallery ---
IMAGE_GALLERY_PATHS = [
    f"example_images/{ex_image}" for ex_image in os.listdir("example_images")
]

def add_file_gallery(selected_state: gr.SelectData, gallery_list) -> Image.Image:
    # `gallery_list` is the Gallery component's current value; in recent
    # Gradio versions this is a root model whose entries expose `.image.path`
    # (it is not a plain List[str]).
    return Image.open(gallery_list.root[selected_state.index].image.path)
# --- Model Inference ---
def model_inference(image: Image.Image) -> tuple[str, Image.Image]:
    """Performs model inference and renders the result."""
    if image is None:
        raise ValueError("`image` is None. It should be a PIL image.")
    # Build the text prompt: BOS, then the run of <image> placeholder tokens
    # that stand in for the pixels.
    inputs = processor.tokenizer(
        f"{BOS_TOKEN}<fake_token_around_image>{'<image>' * image_seq_len}<fake_token_around_image>",
        return_tensors="pt",
        add_special_tokens=False,
    )
    inputs["pixel_values"] = processor.image_processor(
        [image], transform=custom_transform
    )
    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
    with torch.no_grad():  # Disable gradient calculation
        generated_ids = model.generate(
            **inputs, bad_words_ids=BAD_WORDS_IDS, max_length=4096
        )
    generated_text = processor.batch_decode(
        generated_ids, skip_special_tokens=True
    )[0]
    rendered_page = render_webpage(generated_text)
    return generated_text, rendered_page
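# Illustrative call (assumes a screenshot file exists at this path): the model
# returns a complete HTML document, which is then re-rendered for comparison.
#
#     html, preview = model_inference(Image.open("screenshot.png"))
#     print(html[:200])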
# --- Gradio Interface ---
generated_html = gr.Code(
    label="Extracted HTML", elem_id="generated_html"
)
rendered_html = gr.Image(
    label="Rendered HTML", show_download_button=False, show_share_button=False
)

css = """
.gradio-container{max-width: 1000px!important}
h1{display: flex;align-items: center;justify-content: center;gap: .25em}
*{transition: width 0.5s ease, flex-grow 0.5s ease}
"""
with gr.Blocks(title="Screenshot to HTML", theme=gr.themes.Base(), css=css) as demo:
    gr.Markdown(
        "Since the model used for this demo *does not generate images*, it is more effective to input standalone website elements or sites with minimal image content."
    )
    with gr.Row(equal_height=True):
        with gr.Column(scale=4, min_width=250) as upload_area:
            imagebox = gr.Image(
                type="pil",
                label="Screenshot to extract",
                visible=True,
                sources=["upload", "clipboard"],
            )
            with gr.Group():
                with gr.Row():
                    submit_btn = gr.Button(
                        value="▶️ Submit", visible=True, min_width=120
                    )
                    clear_btn = gr.ClearButton(
                        [imagebox, generated_html, rendered_html], value="🧹 Clear", min_width=120
                    )
                    regenerate_btn = gr.Button(
                        value="🔄 Regenerate", visible=True, min_width=120
                    )
        with gr.Column(scale=4):
            rendered_html.render()
    with gr.Row():
        generated_html.render()
    with gr.Row():
        template_gallery = gr.Gallery(
            value=IMAGE_GALLERY_PATHS,
            label="Templates Gallery",
            allow_preview=False,
            columns=5,
            elem_id="gallery",
            show_share_button=False,
            height=400,
        )
    # One handler serves upload, Submit, and Regenerate; gr.on registers the
    # same fn for every listed trigger, so no separate regenerate_btn.click
    # listener is needed (a second one would run inference twice per click).
    gr.on(
        triggers=[
            imagebox.upload,
            submit_btn.click,
            regenerate_btn.click,
        ],
        fn=model_inference,
        inputs=[imagebox],
        outputs=[generated_html, rendered_html],
    )
    template_gallery.select(
        fn=add_file_gallery,
        inputs=[template_gallery],
        outputs=[imagebox],
    ).success(
        fn=model_inference,
        inputs=[imagebox],
        outputs=[generated_html, rendered_html],
    )

demo.queue(max_size=40, api_open=False)
demo.launch(max_threads=400)