|
import gradio as gr |
|
import json, os, re, traceback, contextlib, math, random |
|
from typing import Any, List, Dict, Optional, Tuple |
|
|
|
import spaces |
|
import torch |
|
from PIL import Image, ImageDraw |
|
import requests |
|
from transformers import AutoModelForImageTextToText, AutoProcessor |
|
from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize |
|
|
|
|
|
MODEL_ID = "Hcompany/Holo1-3B" |
|
|
|
|
|
|
|
def pick_device() -> str: |
|
""" |
|
On HF Spaces (ZeroGPU), CUDA is only available inside @spaces.GPU calls. |
|
We still honor FORCE_DEVICE for local testing. |
|
""" |
|
forced = os.getenv("FORCE_DEVICE", "").lower().strip() |
|
if forced in {"cpu", "cuda", "mps"}: |
|
return forced |
|
if torch.cuda.is_available(): |
|
return "cuda" |
|
if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available(): |
|
return "mps" |
|
return "cpu" |
|
|
|
def pick_dtype(device: str) -> torch.dtype: |
|
if device == "cuda": |
|
major, _ = torch.cuda.get_device_capability() |
|
return torch.bfloat16 if major >= 8 else torch.float16 |
|
if device == "mps": |
|
return torch.float16 |
|
return torch.float32 |
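
# Rationale: bf16 on Ampere-or-newer GPUs (compute capability >= 8), fp16 on older
# CUDA GPUs and on Apple MPS, and plain fp32 on CPU where half precision gains little.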
|
|
|
def move_to_device(batch, device: str): |
|
if isinstance(batch, dict): |
|
return {k: (v.to(device, non_blocking=True) if hasattr(v, "to") else v) for k, v in batch.items()} |
|
if hasattr(batch, "to"): |
|
return batch.to(device, non_blocking=True) |
|
return batch |
|
|
|
|
|
def apply_chat_template_compat(processor, messages: List[Dict[str, Any]]) -> str: |
|
tok = getattr(processor, "tokenizer", None) |
|
if hasattr(processor, "apply_chat_template"): |
|
return processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) |
|
if tok is not None and hasattr(tok, "apply_chat_template"): |
|
return tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) |
|
texts = [] |
|
for m in messages: |
|
for c in m.get("content", []): |
|
if isinstance(c, dict) and c.get("type") == "text": |
|
texts.append(c.get("text", "")) |
|
return "\n".join(texts) |
|
|
|
def batch_decode_compat(processor, token_id_batches, **kw): |
|
tok = getattr(processor, "tokenizer", None) |
|
if tok is not None and hasattr(tok, "batch_decode"): |
|
return tok.batch_decode(token_id_batches, **kw) |
|
if hasattr(processor, "batch_decode"): |
|
return processor.batch_decode(token_id_batches, **kw) |
|
raise AttributeError("No batch_decode available on processor or tokenizer.") |
|
|
|
def get_image_proc_params(processor) -> Dict[str, int]: |
|
ip = getattr(processor, "image_processor", None) |
|
return { |
|
"patch_size": getattr(ip, "patch_size", 14), |
|
"merge_size": getattr(ip, "merge_size", 1), |
|
"min_pixels": getattr(ip, "min_pixels", 256 * 256), |
|
"max_pixels": getattr(ip, "max_pixels", 1280 * 1280), |
|
} |
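
# The getattr fallbacks above are best-effort defaults; the Holo1-3B processor is
# expected to supply its own patch/merge sizes and pixel budget, so these values
# only apply if the loaded image processor does not expose the attributes.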
|
|
|
def trim_generated(generated_ids, inputs): |
|
in_ids = getattr(inputs, "input_ids", None) |
|
if in_ids is None and isinstance(inputs, dict): |
|
in_ids = inputs.get("input_ids", None) |
|
if in_ids is None: |
|
return [out_ids for out_ids in generated_ids] |
|
return [out_ids[len(in_seq):] for in_seq, out_ids in zip(in_ids, generated_ids)] |
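
# generate() returns the prompt tokens followed by the new tokens; trim_generated
# slices off the prompt portion so that only the completion is decoded.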
|
|
|
|
|
print(f"Loading model and processor for {MODEL_ID} on CPU startup (ZeroGPU safe)...") |
|
model = None |
|
processor = None |
|
model_loaded = False |
|
load_error_message = "" |
|
|
|
try: |
|
model = AutoModelForImageTextToText.from_pretrained( |
|
MODEL_ID, |
|
torch_dtype=torch.float32, |
|
trust_remote_code=True, |
|
) |
|
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True) |
|
model.eval() |
|
model_loaded = True |
|
print("Model and processor loaded on CPU.") |
|
except Exception as e: |
|
    load_error_message = (
        f"Error loading model/processor: {e}\n"
        "This might be due to network issues, an incorrect model ID, or incompatible library versions.\n"
        "Check the full traceback in the logs."
    )
|
print(load_error_message) |
|
traceback.print_exc() |
|
|
|
|
|
def get_localization_prompt(pil_image: Image.Image, instruction: str) -> List[dict]: |
|
guidelines: str = ( |
|
"Localize an element on the GUI image according to my instructions and " |
|
"output a click position as Click(x, y) with x num pixels from the left edge " |
|
"and y num pixels from the top edge." |
|
) |
|
return [ |
|
{ |
|
"role": "user", |
|
"content": [ |
|
{"type": "image", "image": pil_image}, |
|
{"type": "text", "text": f"{guidelines}\n{instruction}"} |
|
], |
|
} |
|
] |
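
# The model is expected to reply with a literal "Click(x, y)" string whose
# coordinates are pixel offsets within the (resized) image embedded in the prompt.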
|
|
|
|
|
@torch.inference_mode() |
|
def run_inference_localization( |
|
    messages_for_template: List[Dict[str, Any]],
|
pil_image_for_processing: Image.Image, |
|
device: str, |
|
dtype: torch.dtype, |
|
do_sample: bool = False, |
|
temperature: float = 0.6, |
|
top_p: float = 0.9, |
|
max_new_tokens: int = 128, |
|
) -> str: |
|
text_prompt = apply_chat_template_compat(processor, messages_for_template) |
|
|
|
inputs = processor( |
|
text=[text_prompt], |
|
images=[pil_image_for_processing], |
|
padding=True, |
|
return_tensors="pt", |
|
) |
|
inputs = move_to_device(inputs, device) |
|
|
|
|
|
if device == "cuda": |
|
amp_ctx = torch.autocast(device_type="cuda", dtype=dtype) |
|
elif device == "mps": |
|
amp_ctx = torch.autocast(device_type="mps", dtype=torch.float16) |
|
else: |
|
amp_ctx = contextlib.nullcontext() |
|
|
|
    gen_kwargs = dict(max_new_tokens=max_new_tokens, do_sample=do_sample)
    if do_sample:
        # temperature / top_p only apply when sampling; passing them alongside
        # do_sample=False makes transformers warn about unused generation flags.
        gen_kwargs["temperature"] = temperature
        gen_kwargs["top_p"] = top_p
|
|
|
with amp_ctx: |
|
generated_ids = model.generate(**inputs, **gen_kwargs) |
|
|
|
generated_ids_trimmed = trim_generated(generated_ids, inputs) |
|
decoded_output = batch_decode_compat( |
|
processor, |
|
generated_ids_trimmed, |
|
skip_special_tokens=True, |
|
clean_up_tokenization_spaces=False |
|
) |
|
return decoded_output[0] if decoded_output else "" |
|
|
|
|
|
CLICK_RE = re.compile(r"Click\((\d+),\s*(\d+)\)") |
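# e.g. CLICK_RE.search("I would Click(412, 230) here").groups() == ("412", "230")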
|
|
|
def parse_click(s: str) -> Optional[Tuple[int, int]]: |
|
m = CLICK_RE.search(s) |
|
if not m: |
|
return None |
|
try: |
|
return int(m.group(1)), int(m.group(2)) |
|
except Exception: |
|
return None |
|
|
|
@torch.inference_mode() |
|
def sample_clicks( |
|
messages: List[dict], |
|
img: Image.Image, |
|
device: str, |
|
dtype: torch.dtype, |
|
n_samples: int = 7, |
|
temperature: float = 0.6, |
|
top_p: float = 0.9, |
|
seed: Optional[int] = None, |
|
) -> List[Optional[Tuple[int, int]]]: |
|
""" |
|
Run multiple stochastic decodes to estimate self-consistency. |
|
Returns a list of (x,y) or None (if parsing failed) for each sample. |
|
""" |
|
clicks: List[Optional[Tuple[int, int]]] = [] |
|
|
|
if seed is not None: |
|
torch.manual_seed(seed) |
|
random.seed(seed) |
|
for i in range(n_samples): |
|
|
|
if seed is not None: |
|
torch.manual_seed(seed + i + 1) |
|
random.seed((seed + i + 1) & 0xFFFFFFFF) |
|
out = run_inference_localization( |
|
messages, img, device, dtype, |
|
do_sample=True, temperature=temperature, top_p=top_p |
|
) |
|
clicks.append(parse_click(out)) |
|
return clicks |
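
# Reseeding with seed + i + 1 before each draw keeps the whole batch reproducible
# for a fixed seed while still letting the individual samples differ.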
|
|
|
def cluster_and_confidence( |
|
clicks: List[Optional[Tuple[int,int]]], |
|
img_w: int, |
|
img_h: int, |
|
) -> Dict[str, Any]: |
|
""" |
|
Simple robust consensus: |
|
- Keep only valid points |
|
- Compute median point (x_med, y_med) |
|
- Compute distances to median |
|
- Inlier threshold = max(8 px, 2% of min(img_w, img_h)) |
|
- Confidence = (#inliers / #total_samples) * clamp(1 - (rms_inlier_dist / thr), 0, 1) |
|
Returns dict with consensus point, confidence, dispersion, and counts. |
|
""" |
|
valid = [xy for xy in clicks if xy is not None] |
|
total = len(clicks) |
|
if total == 0: |
|
return dict(ok=False, reason="no_samples") |
|
|
|
if not valid: |
|
return dict(ok=False, reason="no_valid_points", total=total) |
|
|
|
xs = sorted([x for x, _ in valid]) |
|
ys = sorted([y for _, y in valid]) |
|
mid = len(valid) // 2 |
|
if len(valid) % 2 == 1: |
|
x_med = xs[mid] |
|
y_med = ys[mid] |
|
else: |
|
x_med = (xs[mid - 1] + xs[mid]) // 2 |
|
y_med = (ys[mid - 1] + ys[mid]) // 2 |
|
|
|
thr = max(8.0, 0.02 * min(img_w, img_h)) |
|
dists = [math.hypot(x - x_med, y - y_med) for (x, y) in valid] |
|
inliers = [(xy, d) for xy, d in zip(valid, dists) if d <= thr] |
|
outliers = [(xy, d) for xy, d in zip(valid, dists) if d > thr] |
|
inlier_count = len(inliers) |
|
|
|
|
|
if inliers: |
|
rms = math.sqrt(sum(d*d for _, d in inliers) / len(inliers)) |
|
else: |
|
rms = float("inf") |
|
|
|
|
|
if inliers: |
|
sharp = max(0.0, min(1.0, 1.0 - (rms / thr))) |
|
else: |
|
sharp = 0.0 |
|
confidence = (inlier_count / total) * sharp |
|
|
|
return dict( |
|
ok=True, |
|
x=x_med, y=y_med, |
|
confidence=confidence, |
|
total_samples=total, |
|
valid_samples=len(valid), |
|
inliers=inlier_count, |
|
outliers=len(outliers), |
|
sigma_px=rms if math.isfinite(rms) else None, |
|
inlier_threshold_px=thr, |
|
all_points=valid, |
|
inlier_points=[xy for xy,_ in inliers], |
|
outlier_points=[xy for xy,_ in outliers], |
|
) |
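
# Worked example (illustrative numbers, not from a real run): with 7 samples of
# which 6 parse and 5 fall within thr = 20 px of the median at an RMS distance of
# 5 px, confidence = (5 / 7) * (1 - 5 / 20) ≈ 0.54.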
|
|
|
def draw_samples( |
|
base_img: Image.Image, |
|
consensus_xy: Optional[Tuple[int,int]], |
|
inliers: List[Tuple[int,int]], |
|
outliers: List[Tuple[int,int]], |
|
ring_color: str = "red", |
|
) -> Image.Image: |
|
""" |
|
Overlay all sampled points: green=inliers, red=outliers, plus a ring for consensus. |
|
""" |
|
img = base_img.copy().convert("RGB") |
|
draw = ImageDraw.Draw(img) |
|
w, h = img.size |
|
|
|
r = max(3, min(w, h) // 200) |
|
|
|
|
|
for (x, y) in inliers: |
|
draw.ellipse((x - r, y - r, x + r, y + r), fill="green", outline=None) |
|
|
|
|
|
for (x, y) in outliers: |
|
draw.ellipse((x - r, y - r, x + r, y + r), fill="red", outline=None) |
|
|
|
|
|
if consensus_xy is not None: |
|
cx, cy = consensus_xy |
|
ring_r = max(5, min(w, h) // 100, r * 3) |
|
draw.ellipse((cx - ring_r, cy - ring_r, cx + ring_r, cy + ring_r), outline=ring_color, width=max(2, ring_r // 4)) |
|
return img |
|
|
|
|
|
|
|
@spaces.GPU(duration=120) |
|
def predict_click_location( |
|
input_pil_image: Image.Image, |
|
instruction: str, |
|
estimate_confidence: bool = True, |
|
num_samples: int = 7, |
|
temperature: float = 0.6, |
|
top_p: float = 0.9, |
|
seed: Optional[int] = None, |
|
): |
|
    if not model_loaded or model is None or processor is None:
        return f"Model not loaded. Error: {load_error_message}", None, "device: n/a | dtype: n/a"
    if input_pil_image is None:
        return "No image provided. Please upload an image.", None, "device: n/a | dtype: n/a"
    if not instruction or not instruction.strip():
        return "No instruction provided. Please type an instruction.", input_pil_image.copy().convert("RGB"), "device: n/a | dtype: n/a"
|
|
|
|
|
device = pick_device() |
|
dtype = pick_dtype(device) |
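
    # On CUDA, allow TF32 matmuls for speed, then move the model (loaded at startup
    # on CPU in fp32) to the per-request device/dtype chosen above.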
|
|
|
|
|
if device == "cuda": |
|
torch.backends.cuda.matmul.allow_tf32 = True |
|
torch.set_float32_matmul_precision("high") |
|
|
|
|
|
try: |
|
p = next(model.parameters()) |
|
cur_dev = p.device.type |
|
cur_dtype = p.dtype |
|
except StopIteration: |
|
cur_dev, cur_dtype = "cpu", torch.float32 |
|
|
|
if cur_dev != device or cur_dtype != dtype: |
|
model.to(device=device, dtype=dtype) |
|
model.eval() |
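
    # Resize to dimensions the image processor accepts; the model's Click(x, y)
    # answer is interpreted in this resized coordinate space.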
|
|
|
|
|
try: |
|
ip = get_image_proc_params(processor) |
|
resized_height, resized_width = smart_resize( |
|
input_pil_image.height, |
|
input_pil_image.width, |
|
factor=ip["patch_size"] * ip["merge_size"], |
|
min_pixels=ip["min_pixels"], |
|
max_pixels=ip["max_pixels"], |
|
) |
|
resized_image = input_pil_image.resize( |
|
size=(resized_width, resized_height), |
|
resample=Image.Resampling.LANCZOS |
|
) |
|
except Exception as e: |
|
traceback.print_exc() |
|
return f"Error resizing image: {e}", input_pil_image.copy().convert("RGB"), f"device: {device} | dtype: {dtype}" |
|
|
|
|
|
messages = get_localization_prompt(resized_image, instruction) |
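
    # Two paths below: multi-sample voting with a confidence estimate, or a single
    # greedy decode when confidence estimation is disabled or num_samples < 3.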
|
|
|
|
|
try: |
|
if estimate_confidence and num_samples >= 3: |
|
|
|
clicks = sample_clicks( |
|
messages, resized_image, device, dtype, |
|
n_samples=int(num_samples), |
|
temperature=float(temperature), |
|
top_p=float(top_p), |
|
                seed=int(seed) if seed is not None else None
|
) |
|
summary = cluster_and_confidence(clicks, resized_image.width, resized_image.height) |
|
|
|
if not summary.get("ok", False): |
|
|
|
coord_str = run_inference_localization(messages, resized_image, device, dtype, do_sample=False) |
|
out_img = resized_image.copy().convert("RGB") |
|
match = CLICK_RE.search(coord_str or "") |
|
if match: |
|
x, y = int(match.group(1)), int(match.group(2)) |
|
out_img = draw_samples(out_img, (x, y), [], []) |
|
coords_text = f"{coord_str} | confidence=0.00 (fallback)" |
|
return coords_text, out_img, f"device: {device} | dtype: {str(dtype).replace('torch.', '')}" |
|
|
|
|
|
x, y = int(summary["x"]), int(summary["y"]) |
|
conf = summary["confidence"] |
|
inliers = summary["inlier_points"] |
|
outliers = summary["outlier_points"] |
|
sigma = summary["sigma_px"] |
|
thr = summary["inlier_threshold_px"] |
|
total = summary["total_samples"] |
|
valid = summary["valid_samples"] |
|
|
|
|
|
coord_str = f"Click({x}, {y})" |
|
            sigma_str = f"{sigma:.1f}px" if sigma is not None else "n/a"
            diag = (
                f"confidence={conf:.2f} | samples(valid/total)={valid}/{total} | "
                f"inliers={len(inliers)} | σ={sigma_str} | thr={thr:.1f}px | "
                f"T={temperature:.2f}, p={top_p:.2f}"
            )
|
|
|
|
|
out_img = draw_samples(resized_image, (x, y), inliers, outliers) |
|
return f"{coord_str} | {diag}", out_img, f"device: {device} | dtype: {str(dtype).replace('torch.', '')}" |
|
|
|
else: |
|
|
|
coord_str = run_inference_localization(messages, resized_image, device, dtype, do_sample=False) |
|
out_img = resized_image.copy().convert("RGB") |
|
match = CLICK_RE.search(coord_str or "") |
|
if match: |
|
x = int(match.group(1)) |
|
y = int(match.group(2)) |
|
|
|
out_img = draw_samples(out_img, (x, y), [], []) |
|
else: |
|
print(f"Could not parse 'Click(x, y)' from model output: {coord_str}") |
|
return coord_str, out_img, f"device: {device} | dtype: {str(dtype).replace('torch.', '')}" |
|
|
|
except Exception as e: |
|
traceback.print_exc() |
|
return f"Error during model inference: {e}", resized_image.copy().convert("RGB"), f"device: {device} | dtype: {dtype}" |
|
|
|
|
|
example_image = None |
|
example_instruction = "Enter the server address readyforquantum.com to check its security" |
|
try: |
|
example_image_url = "https://readyforquantum.com/img/screentest.jpg" |
|
    example_image = Image.open(requests.get(example_image_url, stream=True, timeout=10).raw)
|
except Exception as e: |
|
print(f"Could not load example image from URL: {e}") |
|
traceback.print_exc() |
|
try: |
|
example_image = Image.new("RGB", (200, 150), color="lightgray") |
|
draw = ImageDraw.Draw(example_image) |
|
draw.text((10, 10), "Example image\nfailed to load", fill="black") |
|
except Exception: |
|
pass |
|
|
|
|
|
title = "Holo1-3B: Holo1 Localization Demo (ZeroGPU-ready)" |
|
article = f""" |
|
<p style='text-align: center'> |
|
Model: <a href='https://huggingface.co/{MODEL_ID}' target='_blank'>{MODEL_ID}</a> by HCompany | |
|
Paper: <a href='https://cdn.prod.website-files.com/67e2dbd9acff0c50d4c8a80c/683ec8095b353e8b38317f80_h_tech_report_v1.pdf' target='_blank'>HCompany Tech Report</a> | |
|
Blog: <a href='https://www.hcompany.ai/surfer-h' target='_blank'>Surfer-H Blog Post</a><br/> |
|
<small>GPU (if available) is requested only during inference via @spaces.GPU.</small> |
|
</p> |
|
""" |
|
|
|
if not model_loaded: |
|
with gr.Blocks() as demo: |
|
gr.Markdown(f"# <center>⚠️ Error: Model Failed to Load ⚠️</center>") |
|
gr.Markdown(f"<center>{load_error_message}</center>") |
|
gr.Markdown("<center>See logs for the full traceback.</center>") |
|
else: |
|
with gr.Blocks(theme=gr.themes.Soft()) as demo: |
|
gr.Markdown(f"<h1 style='text-align: center;'>{title}</h1>") |
|
gr.Markdown(article) |
|
|
|
with gr.Row(): |
|
with gr.Column(scale=1): |
|
input_image_component = gr.Image(type="pil", label="Input UI Image", height=400) |
|
instruction_component = gr.Textbox( |
|
label="Instruction", |
|
placeholder="e.g., Click the 'Login' button", |
|
info="Type the action you want the model to localize on the image." |
|
) |
|
estimate_conf = gr.Checkbox(value=True, label="Estimate confidence (slower)") |
|
num_samples_slider = gr.Slider(3, 15, value=7, step=1, label="Samples (for confidence)") |
|
temperature_slider = gr.Slider(0.2, 1.2, value=0.6, step=0.05, label="Temperature") |
|
top_p_slider = gr.Slider(0.5, 0.99, value=0.9, step=0.01, label="Top-p") |
|
seed_box = gr.Number(value=None, precision=0, label="Seed (optional, for reproducibility)") |
|
submit_button = gr.Button("Localize Click", variant="primary") |
|
|
|
with gr.Column(scale=1): |
|
output_coords_component = gr.Textbox( |
|
label="Predicted Coordinates + Confidence", |
|
interactive=False |
|
) |
|
output_image_component = gr.Image( |
|
type="pil", |
|
label="Image with Samples (green=inliers, red=outliers) and Final Ring", |
|
height=400, |
|
interactive=False |
|
) |
|
runtime_info = gr.Textbox( |
|
label="Runtime Info", |
|
value="device: n/a | dtype: n/a", |
|
interactive=False |
|
) |
|
|
|
if example_image: |
|
gr.Examples( |
|
examples=[[example_image, example_instruction, True, 7, 0.6, 0.9, None]], |
|
inputs=[ |
|
input_image_component, |
|
instruction_component, |
|
estimate_conf, |
|
num_samples_slider, |
|
temperature_slider, |
|
top_p_slider, |
|
seed_box, |
|
], |
|
outputs=[output_coords_component, output_image_component, runtime_info], |
|
fn=predict_click_location, |
|
cache_examples="lazy", |
|
) |
|
|
|
submit_button.click( |
|
fn=predict_click_location, |
|
inputs=[ |
|
input_image_component, |
|
instruction_component, |
|
estimate_conf, |
|
num_samples_slider, |
|
temperature_slider, |
|
top_p_slider, |
|
seed_box, |
|
], |
|
outputs=[output_coords_component, output_image_component, runtime_info] |
|
) |
|
|
|
if __name__ == "__main__": |
|
demo.launch(debug=True) |
|
|