File size: 4,125 Bytes
60cec25
 
4772902
60cec25
 
 
4772902
60cec25
 
 
 
4772902
60cec25
 
4772902
60cec25
 
4772902
60cec25
 
4772902
60cec25
 
 
 
 
 
 
4772902
60cec25
4772902
60cec25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4772902
60cec25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4772902
 
60cec25
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# app.py - Space Gradio para SDXL (listo para pegar)
import os

import gradio as gr
import torch
from diffusers import (
    DPMSolverMultistepScheduler,
    StableDiffusionXLImg2ImgPipeline,
    StableDiffusionXLPipeline,
)
from PIL import Image

# --- CONFIG ---
MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"  # SDXL base model
REFINER_ID = "stabilityai/stable-diffusion-xl-refiner-1.0"  # optional refiner
USE_REFINER = True

# Access token read from the HF_ACCESS_TOKEN environment variable
# (recommended: set it in Settings -> Secrets). None means anonymous access.
HF_TOKEN = os.environ.get("HF_ACCESS_TOKEN")

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

# Half precision on GPU to save memory; CPU inference stays in float32.
torch_dtype = torch.float16 if device == "cuda" else torch.float32
# fp16 weights are published as a file *variant* of the repo (selected with
# `variant=`), not as a separate git revision, so `revision="fp16"` is wrong.
weights_variant = "fp16" if device == "cuda" else None

pipe = StableDiffusionXLPipeline.from_pretrained(
    MODEL_ID,
    torch_dtype=torch_dtype,
    use_safetensors=True,
    variant=weights_variant,
    token=HF_TOKEN,  # `use_auth_token` is the deprecated name of this argument
)

# Swap in the DPM-Solver++ multistep scheduler, keeping the model's scheduler config.
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe.to(device)

if device == "cuda":
    # xformers is an optional dependency: if it is missing (or its kernels fail
    # the self-test), keep the default attention instead of crashing start-up.
    try:
        pipe.enable_xformers_memory_efficient_attention()
    except Exception as err:
        print("xformers no disponible, se usa la atención por defecto:", err)

# Optionally load the SDXL refiner, which reworks the base output to add detail.
# `refiner` stays None when disabled or when loading fails; generation then
# simply skips the refinement step.
refiner = None
if USE_REFINER:
    try:
        # The refiner works on an existing image, so it must be loaded as an
        # image-to-image pipeline; the text-to-image StableDiffusionXLPipeline
        # does not take the `image=` argument that generation passes to it.
        refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
            REFINER_ID,
            # Reuse the base pipeline's second text encoder and VAE instead of
            # loading a second copy of each into memory.
            text_encoder_2=pipe.text_encoder_2,
            vae=pipe.vae,
            torch_dtype=torch_dtype,
            use_safetensors=True,
            variant="fp16" if device == "cuda" else None,
            token=HF_TOKEN,
        )
        refiner.to(device)
        if device == "cuda":
            # Optional speed-up; a missing xformers must not discard the refiner.
            try:
                refiner.enable_xformers_memory_efficient_attention()
            except Exception as err:
                print("xformers no disponible para el refiner:", err)
    except Exception as e:
        # Best effort: the app keeps working with the base model only.
        print("No se pudo cargar refiner:", e)
        refiner = None

# Main generation function
def _make_generator(seed, device):
    """Return a ``torch.Generator`` seeded from the UI seed field, or ``None``.

    Args:
        seed: Raw seed value from the UI, usually a string; may be None/empty.
        device: Device the generator is created on ("cuda" or "cpu").

    Returns:
        A generator seeded with ``int(seed)``, or ``None`` when the seed is
        missing or not a usable integer. ``None`` lets the pipelines draw from
        the global RNG, so each unseeded run differs. A freshly constructed,
        unseeded ``torch.Generator`` would always start from the same default
        seed and make every "random" run identical.
    """
    if seed is None:
        return None
    text = str(seed).strip()
    if not text:
        return None
    try:
        value = int(text)
    except ValueError:
        print(f"Seed inválida {seed!r}; se usará una semilla aleatoria.")
        return None
    try:
        return torch.Generator(device=device).manual_seed(value)
    except (RuntimeError, OverflowError):
        # manual_seed rejects values outside the 64-bit range.
        print(f"Seed fuera de rango {seed!r}; se usará una semilla aleatoria.")
        return None


def generate(prompt, negative_prompt, steps, guidance_scale, width, height, seed):
    """Generate one image with the SDXL base pipeline, then refine it if available.

    Args:
        prompt: Text describing the desired image.
        negative_prompt: Text describing what to avoid; empty means none.
        steps: Number of denoising steps for the base pipeline.
        guidance_scale: Classifier-free guidance scale.
        width: Output width in pixels.
        height: Output height in pixels.
        seed: Optional seed (string from the UI); blank/invalid means random.

    Returns:
        The first generated image (a PIL image, per the Gradio output type).
    """
    generator = _make_generator(seed, device)
    width = int(width)
    height = int(height)
    guidance_scale = float(guidance_scale)
    negative_prompt = negative_prompt or None

    # No torch.autocast: the pipeline already runs in its configured dtype, and
    # autocast with fp16 pipelines is discouraged (slower, can give black images).
    image = pipe(
        prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=int(steps),
        guidance_scale=guidance_scale,
        generator=generator,
        height=height,
        width=width,
    ).images[0]

    # Optional refinement pass (img2img over the base output).
    if refiner is not None:
        image = refiner(
            prompt,
            image=image,
            negative_prompt=negative_prompt,
            num_inference_steps=10,
            guidance_scale=guidance_scale,
            generator=generator,
        ).images[0]

    return image

# Gradio user interface: prompt/parameter inputs on the left, result on the right.
_SIZE_CHOICES = (512, 640, 768, 1024)

with gr.Blocks() as demo:
    gr.Markdown("# SDXL — Generador fotorrealista")
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(
                label="Prompt (describe la escena fotorrealista)",
                lines=4,
                placeholder="Portrait of a young woman... cinematic lighting, 85mm lens, photorealistic",
            )
            negative = gr.Textbox(
                label="Negative prompt (evitar)",
                lines=2,
                placeholder="lowres, deformed, cartoon, watermark",
            )
            steps = gr.Slider(
                minimum=10, maximum=60, step=1, value=28, label="Steps"
            )
            scale = gr.Slider(
                minimum=1.0,
                maximum=12.0,
                step=0.5,
                value=7.5,
                label="Guidance scale (fidelity)",
            )
            # Each dropdown gets its own fresh list of size options.
            width = gr.Dropdown(list(_SIZE_CHOICES), value=1024, label="Width")
            height = gr.Dropdown(list(_SIZE_CHOICES), value=1024, label="Height")
            seed = gr.Textbox(label="Seed (opcional)", placeholder="123456")
            btn = gr.Button("Generar")
        with gr.Column():
            gallery = gr.Image(label="Resultado", type="pil")

    # Order must match generate(prompt, negative_prompt, steps, guidance_scale,
    # width, height, seed).
    _generate_inputs = [prompt, negative, steps, scale, width, height, seed]
    btn.click(fn=generate, inputs=_generate_inputs, outputs=[gallery])


def _main():
    """Start the Gradio web server for the demo."""
    demo.launch()


if __name__ == "__main__":
    _main()