# gradio_blip3o_next_min.py
import time
from dataclasses import dataclass

import gradio as gr
import torch
from huggingface_hub import snapshot_download
from PIL import Image
from transformers import AutoTokenizer

from blip3o.model import blip3oQwenForInferenceLM


# -----------------------------
# Minimal config and runner
# -----------------------------
@dataclass
class T2IConfig:
    device: str = "cuda:0"
    dtype: torch.dtype = torch.bfloat16
    # fixed generation config (no UI controls)
    scale: int = 0        # resolution-scale token appended to the prompt (<S0>)
    seq_len: int = 729    # image tokens to generate (likely a 27x27 grid)
    top_p: float = 0.95   # nucleus-sampling threshold
    top_k: int = 1200     # top-k sampling cutoff
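    # The defaults can be overridden at construction, e.g. (hypothetical
    # values): T2IConfig(device="cuda:1", top_k=900)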


class TextToImageInference:
    def __init__(self, config: T2IConfig):
        self.config = config
        self.device = torch.device(config.device)
        self._load_models()

    def _load_models(self):
        model_path = snapshot_download(repo_id='BLIP3o/BLIP3o-NEXT-GRPO-Geneval-3B')
        self.model = blip3oQwenForInferenceLM.from_pretrained(
            model_path, torch_dtype=self.config.dtype
        ).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        # Decoder-only models need left padding so generated tokens
        # continue directly from the end of the prompt.
        if hasattr(self.tokenizer, "padding_side"):
            self.tokenizer.padding_side = "left"

    @torch.inference_mode()
    def generate_image(self, prompt: str) -> Image.Image:
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {
                "role": "user",
                "content": f"Please generate image based on the following caption: {prompt}",
            },
        ]
        input_text = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        input_text += f"<im_start><S{self.config.scale}>"

        inputs = self.tokenizer(
            [input_text], return_tensors="pt", padding=True, truncation=True
        )

        _, images = self.model.generate_images(
            inputs.input_ids.to(self.device),
            inputs.attention_mask.to(self.device),
            max_new_tokens=self.config.seq_len,
            do_sample=True,
            top_p=self.config.top_p,
            top_k=self.config.top_k,
        )
        return images[0]
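
# Minimal programmatic-use sketch (no UI), assuming a CUDA device is
# available and the checkpoint download succeeds:
#
#     inf = TextToImageInference(T2IConfig())
#     img = inf.generate_image("a red cube next to a blue sphere")
#     img.save("out.png")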


# Try loading once at startup for simplicity
LOAD_ERROR = None
inference = None
try:
    inference = TextToImageInference(T2IConfig())
except Exception as e:
    LOAD_ERROR = f"❌ Failed to load model: {e}"


def run_generate(prompt, progress=gr.Progress(track_tqdm=True)):
    # gr.Progress(track_tqdm=True) mirrors any tqdm progress bars from the
    # model into the Gradio UI; the argument itself is not used directly.
    t0 = time.time()
    if LOAD_ERROR:
        return None, LOAD_ERROR
    if not prompt or not prompt.strip():
        return None, "⚠️ Please enter a prompt."

    try:
        img = inference.generate_image(prompt.strip())
        return img, f"✅ Done in {time.time() - t0:.2f}s."
    except torch.cuda.OutOfMemoryError:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return None, "❌ CUDA OOM. Try reducing other GPU workloads."
    except Exception as e:
        return None, f"❌ Error: {e}"
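
# Sampling is stochastic (do_sample=True); for repeatable outputs, one
# option (an assumption, not part of the model's API) is to seed torch
# before each call:
#
#     torch.manual_seed(1234)
#     img = inference.generate_image("a photo of two dogs")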


with gr.Blocks(title="BLIP3o-NEXT-GRPO-Geneval — Text ➜ Image") as demo:
    gr.Markdown("# BLIP3o-NEXT-GRPO-Geneval — Text ➜ Image")

    with gr.Row():
        with gr.Column(scale=3):
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="Describe the image you want to generate...",
                lines=4,
            )
            run_btn = gr.Button("Generate", variant="primary")

        with gr.Column(scale=4):
            out_img = gr.Image(label="Generated Image", format="png")
            status = gr.Markdown("")

    run_btn.click(
        fn=run_generate,
        inputs=[prompt],
        outputs=[out_img, status],
        queue=True,
        api_name="generate",
    )

if __name__ == "__main__":
    demo.queue().launch(share=True)
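
# Note: share=True requests a temporary public *.gradio.live URL; call
# launch() with the default share=False to serve on localhost only.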