import os, tempfile
import gradio as gr
from PIL import Image, ImageDraw, ImageFont
import torch
from transformers import pipeline
from diffusers import StableDiffusionPipeline
from spaces import GPU  # ZeroGPU support
from fpdf import FPDF   # make sure requirements.txt includes: fpdf==1.7.2
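# Implied requirements.txt entries (everything imported above; pins other than fpdf left open):
#   gradio, torch, transformers, diffusers, Pillow, spaces, fpdf==1.7.2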

# Avoid tokenizers parallelism warning after fork
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

# ------------------ Globals (CPU-safe) ------------------
_txtgen = None           # text generator stays on CPU
_t2i_cpu = None          # CPU fallback pipeline

STYLE_PRESETS = {
    "Realistic": "realistic photography, finely detailed, natural lighting, 35mm",
    "Anime": "anime, vibrant colors, cel shading, clean lineart",
    "Comic": "comic book style, halftone, bold lines, dramatic shading",
    "Watercolor": "watercolor painting, soft edges, gentle colors, textured paper",
    "Sketch": "pencil sketch, cross-hatching, grayscale, paper texture",
}
NEGATIVE = "nsfw, nudity, gore, deformed, extra limbs, low quality, blurry, worst quality, lowres, text artifacts, watermark, logo"


# ------------------ Loaders ------------------
def get_txtgen_cpu():
    """Load text generator on CPU (ZeroGPU-safe)."""
    global _txtgen
    if _txtgen is None:
        _txtgen = pipeline("text-generation", model="distilgpt2", device=-1)
    return _txtgen


def get_t2i_cpu():
    """CPU Stable Diffusion pipeline (fallback)."""
    global _t2i_cpu
    if _t2i_cpu is None:
        _t2i_cpu = StableDiffusionPipeline.from_pretrained(
            "stabilityai/sd-turbo",
            torch_dtype=torch.float32,
            safety_checker=None,
        )
        _t2i_cpu.enable_attention_slicing()
    return _t2i_cpu


# ------------------ GPU path (ZeroGPU) ------------------
@GPU(duration=120)
def t2i_generate_batch_gpu(prompts, width, height, steps, guidance, negative_prompt, seed=None):
    """Runs inside a GPU-allocated context (ZeroGPU)."""
    pipe = StableDiffusionPipeline.from_pretrained(
        "stabilityai/sd-turbo",
        torch_dtype=torch.float16,
        safety_checker=None,
    ).to("cuda")

    generator = torch.Generator(device="cuda")
    if seed is not None and str(seed).strip().isdigit():
        generator = generator.manual_seed(int(seed))

    images = []
    for p in prompts:
        img = pipe(
            prompt=p,
            negative_prompt=negative_prompt,
            num_inference_steps=steps,
            guidance_scale=guidance,
            width=width,
            height=height,
            generator=generator,
        ).images[0]
        images.append(img)
    return images


# ------------------ Helpers ------------------
def build_prompt(user_prompt: str, style: str, panel_idx: int, num_panels: int) -> str:
    style_desc = STYLE_PRESETS.get(style, "")
    beat = ["opening shot", "rising action", "key moment", "twist", "resolution"]
    beat_text = beat[min(panel_idx, len(beat) - 1)]
    return f"{user_prompt}, {style_desc}, storyboard panel {panel_idx+1} of {num_panels}, {beat_text}, cinematic composition, wide shot"
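# Example: build_prompt("a cyberpunk detective", "Comic", 0, 3) ->
#   "a cyberpunk detective, comic book style, halftone, bold lines, dramatic shading,
#    storyboard panel 1 of 3, opening shot, cinematic composition, wide shot"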


def generate_captions(user_prompt: str, n: int = 3):
    gen = get_txtgen_cpu()
    # Simple, fast prompts; keep it short
    outputs = []
    for i in range(n):
        text = gen(
            f"Write a very short scene caption (<=10 words) about: {user_prompt}",
            max_new_tokens=30,
            do_sample=True,
            temperature=0.9,
            top_p=0.95,
            num_return_sequences=1,
            return_full_text=False,  # return only the continuation, not the instruction prompt
        )[0]["generated_text"].strip()
        # Fallback if something weird comes out
        if not text or len(text.split()) < 2:
            text = f"Scene {i+1}"
        outputs.append(text[:80])
    return outputs


def add_caption_strip(img: Image.Image, text: str, width_hint: int) -> Image.Image:
    """Add a black strip with white text at the bottom. Uses textbbox (Pillow>=10)."""
    out = img.copy()
    draw = ImageDraw.Draw(out)
    try:
        font = ImageFont.truetype("DejaVuSans.ttf", size=max(16, width_hint // 28))
    except Exception:
        font = ImageFont.load_default()

    bbox = draw.textbbox((0, 0), text, font=font)
    text_w = bbox[2] - bbox[0]
    text_h = bbox[3] - bbox[1]
    strip_h = text_h + 14

    strip = Image.new("RGB", (out.width, strip_h), (0, 0, 0))
    d2 = ImageDraw.Draw(strip)
    d2.text(((out.width - text_w) // 2, 7), text, font=font, fill=(255, 255, 255))

    combined = Image.new("RGB", (out.width, out.height + strip_h), (0, 0, 0))
    combined.paste(out, (0, 0))
    combined.paste(strip, (0, out.height))
    return combined


def images_to_pdf_with_fpdf(images):
    """Write a simple multipage PDF using FPDF."""
    if not images:
        return None
    # NamedTemporaryFile(delete=False) sidesteps the race condition of the deprecated tempfile.mktemp
    pdf_path = tempfile.NamedTemporaryFile(suffix=".pdf", delete=False).name
    pdf = FPDF()
    for img in images:
        # Save a temp PNG to embed in the PDF
        tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False).name
        img.save(tmp)
        pdf.add_page()
        # Fit the image nicely within margins
        pdf.image(tmp, x=10, y=10, w=190)
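        # (FPDF defaults to A4 in mm: the page is 210 mm wide, so x=10 with w=190 leaves
        #  10 mm margins; height scales automatically to preserve the aspect ratio)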
    pdf.output(pdf_path)
    return pdf_path


# ------------------ Core logic ------------------
def create_storyboard(user_prompt, style, num_panels, width, height, seed):
    if not user_prompt or not user_prompt.strip():
        return [], None

    # Gradio sliders may deliver floats; cast before using as counts/dimensions
    num_panels = int(num_panels)
    width, height = int(width), int(height)

    # Build prompts + captions
    captions = generate_captions(user_prompt, n=num_panels)
    prompts = [build_prompt(user_prompt, style, i, num_panels) for i in range(num_panels)]

    # Try GPU (ZeroGPU). If it fails (no GPU), fallback to CPU.
    images = None
    try:
        images = t2i_generate_batch_gpu(prompts, width, height, steps=2, guidance=0.0,
                                        negative_prompt=NEGATIVE, seed=seed)
    except Exception:
        # GPU not available → CPU fallback (slower)
        pipe = get_t2i_cpu()
        images = []
        # No seed control on CPU path by default; can be added with torch.Generator("cpu")
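        # Minimal sketch if seeded CPU runs are wanted (mirrors the GPU path above):
        #   gen_cpu = None
        #   if seed is not None and str(seed).strip().isdigit():
        #       gen_cpu = torch.Generator("cpu").manual_seed(int(seed))
        #   ...then pass generator=gen_cpu to pipe(...) in the loop below.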
        for p in prompts:
            img = pipe(
                prompt=p,
                negative_prompt=NEGATIVE,
                num_inference_steps=4,
                guidance_scale=0.0,
                width=width,
                height=height,
            ).images[0]
            images.append(img)

    # Add caption strips
    final_images = [add_caption_strip(img, cap, width_hint=width) for img, cap in zip(images, captions)]
    # Build PDF
    pdf_path = images_to_pdf_with_fpdf(final_images)
    return final_images, pdf_path


# ------------------ UI ------------------
with gr.Blocks(title="AI Storyboard Creator") as demo:
    gr.Markdown(
        """
        # 🎬 AI Storyboard Creator
        Turn a single prompt into a mini storyboard: 3–6 panels, captions, and a downloadable PDF.  
        Works on **CPU basic** and supports **ZeroGPU** (GPU on-demand).
        """
    )
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Story prompt", placeholder="A cyberpunk detective in the rain", lines=2)
            style = gr.Dropdown(choices=list(STYLE_PRESETS.keys()), value="Comic", label="Style")
            num_panels = gr.Slider(3, 6, value=3, step=1, label="Number of panels")
            width = gr.Slider(384, 768, value=448, step=64, label="Panel width (px)")
            height = gr.Slider(384, 768, value=448, step=64, label="Panel height (px)")
            seed = gr.Textbox(label="Seed (optional)", placeholder="e.g., 42")
            run_btn = gr.Button("Create Storyboard")
        with gr.Column():
            # NOTE: no .style(); use columns=2 instead
            gallery = gr.Gallery(label="Preview (grid)", columns=2, height="auto")
            pdf_file = gr.File(label="Download PDF")

    run_btn.click(
        create_storyboard,
        inputs=[prompt, style, num_panels, width, height, seed],
        outputs=[gallery, pdf_file],
    )

if __name__ == "__main__":
    demo.launch()