Spaces:
Sleeping
Sleeping
File size: 2,628 Bytes
67dff27 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 |
import os
import torch
from PIL import Image
from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler
import warnings
from smolagents import tool
# Ignore all warnings issued after this point (no category or module filter
# is given, so the filter is not limited to any one library).
warnings.filterwarnings("ignore")
# Global pipeline variable for reuse.
# Starts as None; get_pipeline() replaces it on first use with either the
# loaded pipeline or the string "mock" when loading fails.
_pipeline = None
def get_pipeline():
    """Return the shared Stable Diffusion pipeline, creating it on first use.

    The first call tries to load ``runwayml/stable-diffusion-v1-5`` (float16
    on CUDA, float32 on CPU) with the safety checker disabled, and stores the
    result in the module-level ``_pipeline``. If anything in that setup
    raises, the error is printed and the string ``"mock"`` is stored instead,
    so later calls return ``"mock"`` without retrying the load.

    Returns:
        The loaded pipeline object, or the string ``"mock"`` when loading
        failed.
    """
    global _pipeline

    # Already resolved (a real pipeline or the "mock" marker): reuse it.
    if _pipeline is not None:
        return _pipeline

    try:
        # Half precision only when a CUDA device is available.
        if torch.cuda.is_available():
            target_device, weights_dtype = "cuda", torch.float16
        else:
            target_device, weights_dtype = "cpu", torch.float32

        loaded = StableDiffusionPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            torch_dtype=weights_dtype,
            safety_checker=None,
            requires_safety_checker=False,
        )
        _pipeline = loaded.to(target_device)

        # Only switch on attention slicing if this pipeline object offers it.
        if hasattr(_pipeline, 'enable_attention_slicing'):
            _pipeline.enable_attention_slicing()
    except Exception as err:
        # Best-effort fallback: report the problem and remember the failure
        # via the "mock" marker that generate_image() checks for.
        print(f"Failed to load pipeline: {err}")
        _pipeline = "mock"

    return _pipeline
@tool
def generate_image(scene_prompt: str) -> Image.Image:
    """
    Generates a cartoon-style image from a scene prompt using Stable Diffusion v1.5.
    Falls back to a placeholder if loading fails.
    Args:
        scene_prompt (str): Description of the scene to generate
    Returns:
        PIL.Image.Image: Generated cartoon-style image
    """
    # NOTE(review): the docstring above is kept word-for-word because the
    # @tool decorator presumably derives the tool description from it --
    # confirm before rewording.
    sd_pipeline = get_pipeline()

    # get_pipeline() hands back the string "mock" when the model could not be
    # loaded; answer with a flat light-blue 512x512 placeholder in that case.
    if sd_pipeline == "mock":
        placeholder_size = (512, 512)
        return Image.new('RGB', placeholder_size, color='lightblue')

    # Fixed seed (42) for reproducibility; the generator is created on the
    # same device the pipeline reports.
    seeded_rng = torch.Generator(device=sd_pipeline.device)
    seeded_rng.manual_seed(42)

    sampling_options = {
        "guidance_scale": 7.5,
        "num_inference_steps": 20,
        "height": 512,
        "width": 512,
        "generator": seeded_rng,
    }

    # Wrap the caller's description in fixed cartoon-style keywords.
    styled_prompt = f"cartoon style, {scene_prompt}, colorful, animated"
    output = sd_pipeline(styled_prompt, **sampling_options)
    return output.images[0]
if __name__ == "__main__":
    # Manual smoke test: push each sample prompt through generate_image()
    # and print the size and mode of whatever image comes back.
    sample_prompts = (
        "Cartoon cat wearing a wizard hat in a magical forest",
        "Cartoon robot dancing in a disco with neon lights",
        "Cartoon dragon flying over a rainbow castle",
    )

    for number, scene in enumerate(sample_prompts, start=1):
        print(f"Generating image {number}: '{scene}'")
        picture = generate_image(scene)
        dimensions, colour_mode = picture.size, picture.mode
        print(f"Result: Image size={dimensions}, mode={colour_mode}")
        # The image is not written to disk; call picture.save(...) here
        # (e.g. test_image_<number>.png) to keep it.