Spaces:
Runtime error
Runtime error
File size: 3,619 Bytes
78f426f 16b76e3 b292c7c 78f426f 1e0d238 78f426f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
import os
import time
import requests
from PIL import Image
from io import BytesIO
from selenium import webdriver
from selenium.webdriver.chrome.options import Options
from selenium.webdriver.chrome.service import Service
from webdriver_manager.chrome import ChromeDriverManager
from diffusers import StableDiffusionPipeline
import torch
import gradio as gr
# ---------- Step 1: Scrape Celebrity Images ----------
def scrape_images(celebrity_name, num_images=20):
    """Download portrait images of a celebrity from a Google Images search.

    Opens a headless Chrome session on the image-search results for
    ``"<celebrity_name> portrait"``, walks the ``<img>`` elements found on
    the loaded page and saves each one that has an HTTP(S) ``src`` as
    ``data/<celebrity_name>/<celebrity_name>_<index>.jpg``.

    Args:
        celebrity_name: Search term; also used as the output directory name
            and as the file-name prefix.
        num_images: Maximum number of images to save.

    Downloads that fail (network error, HTTP error status, or a file that
    cannot be written) are reported on stdout and skipped; they do not count
    towards ``num_images``.
    """
    # Function-scope import so this block does not depend on the file header.
    from urllib.parse import quote_plus

    # Encode the whole query: names such as "Taylor Swift" contain spaces,
    # and other characters (&, #, ...) would otherwise corrupt the URL.
    query = quote_plus(f"{celebrity_name} portrait")
    search_url = f"https://www.google.com/search?q={query}&tbm=isch"

    chrome_options = Options()
    for flag in ("--headless", "--no-sandbox", "--disable-dev-shm-usage"):
        chrome_options.add_argument(flag)

    output_dir = f"data/{celebrity_name}"
    os.makedirs(output_dir, exist_ok=True)

    driver = webdriver.Chrome(
        service=Service(ChromeDriverManager().install()),
        options=chrome_options,
    )
    try:
        driver.get(search_url)
        count = 0
        for img in driver.find_elements("tag name", "img"):
            if count >= num_images:
                break
            src = img.get_attribute("src")
            # Require an http(s) URL prefix: inline "data:" URIs can contain
            # the substring "http" inside their payload and are not fetchable.
            if not src or not src.startswith("http"):
                continue
            try:
                # Bound the wait so one stalled host cannot hang the scrape.
                response = requests.get(src, timeout=10)
                # Do not store an error page under a .jpg name.
                response.raise_for_status()
                with open(f"{output_dir}/{celebrity_name}_{count}.jpg", "wb") as handler:
                    handler.write(response.content)
            except (requests.RequestException, OSError) as e:
                print(f"Error downloading image: {e}")
            else:
                count += 1
    finally:
        # Always shut the browser down, even if navigation or scraping fails.
        driver.quit()
# ---------- Step 2: Fine-Tuning Stable Diffusion ----------
def fine_tune_sd3(celebrity_name):
    """Save a copy of the base Stable Diffusion pipeline for a celebrity.

    NOTE: no training is performed here. The function counts the readable
    ``.jpg`` files in ``data/<celebrity_name>``, loads the pretrained
    ``runwayml/stable-diffusion-v1-5`` pipeline and writes it unchanged to
    ``models/<celebrity_name>_sd3``. Despite the ``sd3`` in the names, the
    model id loaded is Stable Diffusion v1.5.

    Args:
        celebrity_name: Name used to locate the image directory and to build
            the output model directory.

    Returns:
        The directory the pipeline was saved to.

    Raises:
        FileNotFoundError: If ``data/<celebrity_name>`` does not exist.
    """
    celeb_images_path = f"data/{celebrity_name}"

    # Inspect the dataset before the (expensive) pipeline load so a missing
    # directory fails fast.
    usable_images = []
    for file_name in sorted(os.listdir(celeb_images_path)):
        if not file_name.endswith(".jpg"):
            continue
        image_path = os.path.join(celeb_images_path, file_name)
        try:
            # Open only to confirm the file is a readable image; the context
            # manager releases the file handle immediately.
            with Image.open(image_path):
                pass
        except OSError as e:
            # Scraped files are not guaranteed to be real images; skip them
            # instead of aborting the whole run.
            print(f"Skipping unreadable image {image_path}: {e}")
            continue
        usable_images.append(image_path)

    model_id = "runwayml/stable-diffusion-v1-5"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipe = StableDiffusionPipeline.from_pretrained(model_id).to(device)

    # Simple fine-tuning logic (for demonstration; deep fine-tuning requires more work)
    print(f"Fine-tuning with {len(usable_images)} images of {celebrity_name}...")

    # Saving model
    fine_tuned_model_path = f"models/{celebrity_name}_sd3"
    os.makedirs(fine_tuned_model_path, exist_ok=True)
    pipe.save_pretrained(fine_tuned_model_path)
    print(f"Model saved at {fine_tuned_model_path}")
    return fine_tuned_model_path
# ---------- Step 3: Generate Phone Cover Designs ----------
def generate_cover(prompt, model_path):
    """Generate an image from ``prompt`` and merge it with the cover template.

    The generated image is resized to the size of
    ``phone_cover_template.png`` (read from the current working directory),
    the template is composited on top of it, and the result is written to
    ``generated_phone_cover.png``.

    Args:
        prompt: Text prompt passed to the Stable Diffusion pipeline.
        model_path: Directory (or model id) to load the pipeline from.

    Returns:
        The path of the saved image, always ``"generated_phone_cover.png"``;
        each call overwrites the previous result.
    """
    pipe = _load_pipeline(model_path)
    artwork = pipe(prompt).images[0]

    # Close the template file as soon as its pixels have been copied.
    with Image.open("phone_cover_template.png") as template_file:
        cover_template = template_file.convert("RGBA")

    artwork = artwork.resize(cover_template.size).convert("RGBA")

    # alpha_composite(im1, im2) draws im2 over im1. The artwork is expected
    # to be fully opaque after conversion (assuming the pipeline returns an
    # image without transparency), so it has to be the bottom layer;
    # otherwise it would hide the template completely.
    # NOTE(review): this relies on the template having transparent regions
    # where the design should show through - confirm against the asset.
    blended = Image.alpha_composite(artwork, cover_template)

    output_path = "generated_phone_cover.png"
    blended.save(output_path)
    return output_path


# Loaded pipelines keyed by (model_path, device), so repeated requests do not
# reload the model weights from disk.
_PIPELINE_CACHE = {}


def _load_pipeline(model_path):
    """Return the pipeline for ``model_path``, loading it on first use only."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    cache_key = (model_path, device)
    if cache_key not in _PIPELINE_CACHE:
        _PIPELINE_CACHE[cache_key] = StableDiffusionPipeline.from_pretrained(model_path).to(device)
    return _PIPELINE_CACHE[cache_key]
# ---------- Step 4: Gradio Deployment ----------
def launch_gradio(model_path):
    """Serve a Gradio UI that turns a text prompt into a phone cover image.

    Args:
        model_path: Model location forwarded to ``generate_cover`` for every
            prompt submitted through the UI.
    """

    def infer(prompt):
        # Bind the model path so the UI callback only needs the prompt.
        return generate_cover(prompt, model_path)

    prompt_box = gr.Textbox(label="Enter a design prompt")
    cover_view = gr.Image(label="Generated Phone Cover")
    demo = gr.Interface(
        fn=infer,
        inputs=prompt_box,
        outputs=cover_view,
        title="Celebrity Phone Cover Generator",
    )
    demo.launch()
# ---------- Main Workflow ----------
def main():
    """Run the demo workflow: scrape images, build the model, serve the UI."""
    celebrity = "Taylor Swift"  # Example celebrity
    scrape_images(celebrity, num_images=30)
    model_path = fine_tune_sd3(celebrity)
    # Deploy on Gradio
    launch_gradio(model_path)


if __name__ == "__main__":
    main()