import os
import time
import requests
from PIL import Image
from io import BytesIO
from selenium import webdriver
from selenium.webdriver.common.by import By
from selenium.webdriver.chrome.options import Options
from selenium.webdriver.chrome.service import Service
from webdriver_manager.chrome import ChromeDriverManager
from diffusers import StableDiffusionPipeline
import torch
import gradio as gr
# ---------- Step 1: Scrape Celebrity Images ----------
def scrape_images(celebrity_name, num_images=20):
    search_url = f"https://www.google.com/search?q={celebrity_name.replace(' ', '+')}+portrait&tbm=isch"
    chrome_options = Options()
    chrome_options.add_argument("--headless")
    chrome_options.add_argument("--no-sandbox")
    chrome_options.add_argument("--disable-dev-shm-usage")
    driver = webdriver.Chrome(service=Service(ChromeDriverManager().install()), options=chrome_options)
    driver.get(search_url)
    time.sleep(2)  # give the results page a moment to render its thumbnails
    os.makedirs(f"data/{celebrity_name}", exist_ok=True)
    images = driver.find_elements(By.TAG_NAME, "img")
    count = 0
    for img in images:
        if count >= num_images:
            break
        src = img.get_attribute("src")
        if src and src.startswith("http"):
            try:
                img_data = requests.get(src, timeout=10).content
                Image.open(BytesIO(img_data)).verify()  # skip responses that are not valid images
                with open(f"data/{celebrity_name}/{celebrity_name}_{count}.jpg", "wb") as handler:
                    handler.write(img_data)
                count += 1
            except Exception as e:
                print(f"Error downloading image: {e}")
    driver.quit()
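# --- Optional helper (assumption): drop tiny thumbnails ---------------------------------
# Google's result page mostly exposes small thumbnails, which are poor fine-tuning data.
# This helper is a suggested add-on (not part of the original workflow) that deletes scraped
# files that cannot be opened or whose shorter side is below min_size pixels.
def filter_scraped_images(celebrity_name, min_size=128):
    folder = f"data/{celebrity_name}"
    kept = 0
    for name in os.listdir(folder):
        path = os.path.join(folder, name)
        try:
            with Image.open(path) as im:
                too_small = min(im.size) < min_size
        except Exception:
            too_small = True  # unreadable file: treat like a bad thumbnail
        if too_small:
            os.remove(path)
        else:
            kept += 1
    print(f"Kept {kept} usable images of {celebrity_name}")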
# ---------- Step 2: Fine-Tuning Stable Diffusion ----------
def fine_tune_sd3(celebrity_name):
    model_id = "runwayml/stable-diffusion-v1-5"  # note: despite the function name, this loads the SD 1.5 base model
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipe = StableDiffusionPipeline.from_pretrained(model_id).to(device)
    celeb_images_path = f"data/{celebrity_name}"
    images = [
        Image.open(os.path.join(celeb_images_path, img))
        for img in os.listdir(celeb_images_path)
        if img.endswith(".jpg")
    ]
    # Placeholder "fine-tuning": the scraped images are only loaded and counted here.
    # Real personalization (DreamBooth, LoRA, textual inversion) needs a proper training
    # loop; a hedged sketch of one option follows this function.
    print(f"Fine-tuning with {len(images)} images of {celebrity_name}...")
    # Save the pipeline (still the unmodified base weights) so later steps can load it by path
    fine_tuned_model_path = f"models/{celebrity_name}_sd3"
    os.makedirs(fine_tuned_model_path, exist_ok=True)
    pipe.save_pretrained(fine_tuned_model_path)
    print(f"Model saved at {fine_tuned_model_path}")
    return fine_tuned_model_path
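# --- Hedged sketch: what real fine-tuning could look like (assumption, not the original code) ---
# fine_tune_sd3 above only re-saves the base weights. The function below sketches a minimal
# DreamBooth-style update using the pipeline's own components (tokenizer, text encoder, VAE,
# UNet, scheduler). Hyperparameters, the instance prompt, and the function name are made up
# for illustration; there is no prior-preservation loss, and updating a full UNet this way
# needs a large GPU.
def sketch_dreambooth_update(pipe, images, celebrity_name, steps=200, lr=5e-6):
    import torch.nn.functional as F
    from torchvision import transforms

    device = pipe.device
    preprocess = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),  # scale pixels to [-1, 1] as the VAE expects
    ])
    pixel_batches = [preprocess(img.convert("RGB")).unsqueeze(0).to(device) for img in images]

    # Encode the instance prompt once; only the UNet is updated in this sketch
    prompt = f"a photo of {celebrity_name}"
    token_ids = pipe.tokenizer(prompt, padding="max_length",
                               max_length=pipe.tokenizer.model_max_length,
                               return_tensors="pt").input_ids.to(device)
    with torch.no_grad():
        text_embeddings = pipe.text_encoder(token_ids)[0]

    pipe.unet.train()
    optimizer = torch.optim.AdamW(pipe.unet.parameters(), lr=lr)
    for step in range(steps):
        pixels = pixel_batches[step % len(pixel_batches)]
        with torch.no_grad():
            latents = pipe.vae.encode(pixels).latent_dist.sample() * 0.18215  # SD 1.x latent scale
        noise = torch.randn_like(latents)
        timesteps = torch.randint(0, pipe.scheduler.config.num_train_timesteps,
                                  (latents.shape[0],), device=device).long()
        noisy_latents = pipe.scheduler.add_noise(latents, noise, timesteps)
        # Standard denoising objective: predict the added noise from the noisy latents
        noise_pred = pipe.unet(noisy_latents, timesteps,
                               encoder_hidden_states=text_embeddings).sample
        loss = F.mse_loss(noise_pred, noise)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    return pipe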
# ---------- Step 3: Generate Phone Cover Designs ----------
def generate_cover(prompt, model_path):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipe = StableDiffusionPipeline.from_pretrained(model_path).to(device)
    image = pipe(prompt).images[0]
    # phone_cover_template.png is assumed to be an RGBA overlay (cover outline with
    # transparent cut-outs); the generated art goes underneath and the template on top,
    # otherwise the opaque artwork would simply hide the template.
    cover_template = Image.open("phone_cover_template.png").convert("RGBA")
    image = image.resize(cover_template.size)
    blended = Image.alpha_composite(image.convert("RGBA"), cover_template)
    output_path = "generated_phone_cover.png"
    blended.save(output_path)
    return output_path
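# --- Optional helper (assumption): aspect-ratio-aware fit -------------------------------
# pipe(prompt) returns a square image, so the plain resize() above stretches it to the tall
# cover template. ImageOps.fit crops and resizes instead; this helper is a suggested
# alternative, not part of the original workflow.
def fit_to_template(image, template_size):
    from PIL import ImageOps
    return ImageOps.fit(image, template_size)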
# ---------- Step 4: Gradio Deployment ----------
def launch_gradio(model_path):
    def infer(prompt):
        return generate_cover(prompt, model_path)

    gr.Interface(
        fn=infer,
        inputs=gr.Textbox(label="Enter a design prompt"),
        outputs=gr.Image(label="Generated Phone Cover"),
        title="Celebrity Phone Cover Generator",
    ).launch()
# ---------- Main Workflow ----------
if __name__ == "__main__":
    celebrity = "Taylor Swift"  # Example celebrity
    scrape_images(celebrity, num_images=30)
    model_path = fine_tune_sd3(celebrity)
    # Deploy on Gradio
    launch_gradio(model_path)