# Training_Celeb / app.py
# Committed by gaur3009 — "Update app.py" (commit b292c7c, verified)
import os
import time
from functools import lru_cache
from io import BytesIO
from urllib.parse import quote_plus

import gradio as gr
import requests
import torch
from diffusers import StableDiffusionPipeline
from PIL import Image
from selenium import webdriver
from selenium.webdriver.chrome.options import Options
from selenium.webdriver.chrome.service import Service
from webdriver_manager.chrome import ChromeDriverManager
# ---------- Step 1: Scrape Celebrity Images ----------
def scrape_images(celebrity_name, num_images=20):
    """Download up to ``num_images`` portrait images of a celebrity from Google Images.

    Opens a headless Chrome session on a Google Images search for
    "<celebrity_name> portrait" and collects the ``src`` of every ``<img>``
    element on the results page. Each http(s) source is downloaded to
    ``data/<celebrity_name>/<celebrity_name>_<index>.jpg``.

    Args:
        celebrity_name: Search term. Also used as the output folder name and
            the file-name prefix.
        num_images: Maximum number of images to save.

    Returns:
        The number of images actually written. This can be lower than
        ``num_images`` if the page has fewer usable images or some downloads
        fail. (Callers that ignore the return value are unaffected.)
    """
    # Encode the whole query so names containing spaces, '&', '#', etc. still
    # produce a valid search URL ("Taylor Swift portrait" -> "Taylor+Swift+portrait").
    query = quote_plus(f"{celebrity_name} portrait")
    search_url = f"https://www.google.com/search?q={query}&tbm=isch"

    out_dir = f"data/{celebrity_name}"
    os.makedirs(out_dir, exist_ok=True)

    chrome_options = Options()
    chrome_options.add_argument("--headless")
    chrome_options.add_argument("--no-sandbox")
    chrome_options.add_argument("--disable-dev-shm-usage")
    driver = webdriver.Chrome(
        service=Service(ChromeDriverManager().install()), options=chrome_options
    )
    # Always shut Chrome down, even if the page load or element lookup raises.
    # Otherwise every failed call leaks a headless browser process.
    try:
        driver.get(search_url)
        sources = [img.get_attribute("src") for img in driver.find_elements("tag name", "img")]
    finally:
        driver.quit()

    count = 0
    for src in sources:
        if count >= num_images:
            break
        # Only fetch remote images; skip empty src attributes and inline data: URIs.
        if not src or not src.startswith(("http://", "https://")):
            continue
        try:
            # The timeout keeps one unresponsive host from hanging the scrape.
            response = requests.get(src, timeout=10)
            # Don't save HTTP error pages to disk as if they were images.
            response.raise_for_status()
            with open(f"{out_dir}/{celebrity_name}_{count}.jpg", "wb") as handler:
                handler.write(response.content)
        except (requests.RequestException, OSError) as e:
            # Best-effort: report the failure and move on to the next image.
            print(f"Error downloading image: {e}")
            continue
        count += 1
    return count
# ---------- Step 2: Fine-Tuning Stable Diffusion ----------
def _count_readable_images(folder):
    """Return how many ``.jpg`` files in *folder* Pillow can open as images."""
    readable = 0
    for name in sorted(os.listdir(folder)):
        if not name.endswith(".jpg"):
            continue
        # The ``with`` block releases the file handle immediately. Image.open
        # parses the header, which is enough to reject non-image payloads
        # (e.g. an HTML page) that were saved under a .jpg name.
        try:
            with Image.open(os.path.join(folder, name)):
                readable += 1
        except OSError as e:  # PIL.UnidentifiedImageError is an OSError subclass
            print(f"Skipping unreadable image {name}: {e}")
    return readable


def fine_tune_sd3(celebrity_name):
    """Prepare a per-celebrity Stable Diffusion model directory.

    NOTE: despite its name, this performs no fine-tuning, and the base model
    is "runwayml/stable-diffusion-v1-5", not SD3. It counts the readable
    scraped images, loads the base pipeline, and saves it unchanged to
    ``models/<celebrity_name>_sd3``. This gives later steps a model path
    to load.

    Args:
        celebrity_name: Name used for the ``data/<celebrity_name>`` input folder
            and the output directory name.

    Returns:
        Path of the directory the pipeline was saved to.

    Raises:
        FileNotFoundError: If ``data/<celebrity_name>`` does not exist.
    """
    celeb_images_path = f"data/{celebrity_name}"
    # Check the data folder before the (multi-GB) model load, so a missing
    # folder fails fast instead of after the download.
    num_images = _count_readable_images(celeb_images_path)

    model_id = "runwayml/stable-diffusion-v1-5"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipe = StableDiffusionPipeline.from_pretrained(model_id).to(device)

    # Simple fine-tuning logic (for demonstration; deep fine-tuning requires more work)
    print(f"Fine-tuning with {num_images} images of {celebrity_name}...")

    # Saving model (currently identical to the base model; see docstring).
    fine_tuned_model_path = f"models/{celebrity_name}_sd3"
    os.makedirs(fine_tuned_model_path, exist_ok=True)
    pipe.save_pretrained(fine_tuned_model_path)
    print(f"Model saved at {fine_tuned_model_path}")
    return fine_tuned_model_path
# ---------- Step 3: Generate Phone Cover Designs ----------
@lru_cache(maxsize=2)
def _load_pipeline(model_path, device):
    """Load a Stable Diffusion pipeline once per (model_path, device) and reuse it."""
    return StableDiffusionPipeline.from_pretrained(model_path).to(device)


def generate_cover(prompt, model_path):
    """Generate an image from *prompt* and combine it with the phone-cover template.

    Args:
        prompt: Text prompt passed to the Stable Diffusion pipeline.
        model_path: Directory (or hub id) accepted by
            ``StableDiffusionPipeline.from_pretrained``.

    Returns:
        Path of the written PNG, always ``"generated_phone_cover.png"``. The
        file is overwritten on every call.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Loading the weights is slow and memory-heavy, so cache the pipeline
    # instead of reloading it from disk on every Gradio request.
    pipe = _load_pipeline(model_path, device)
    artwork = pipe(prompt).images[0]

    # Close the template file right after reading; convert() returns a copy.
    with Image.open("phone_cover_template.png") as template_file:
        cover_template = template_file.convert("RGBA")
    artwork = artwork.resize(cover_template.size).convert("RGBA")

    # alpha_composite(dst, src) draws src OVER dst. The generated artwork is
    # fully opaque, so drawing it over the template (the previous order) hid
    # the template entirely. Draw the template over the artwork so the
    # template's opaque pixels stay visible.
    # NOTE(review): assumes the template PNG has transparency where the
    # design should show through — confirm against the asset.
    blended = Image.alpha_composite(artwork, cover_template)

    output_path = "generated_phone_cover.png"
    blended.save(output_path)
    return output_path
# ---------- Step 4: Gradio Deployment ----------
def launch_gradio(model_path):
    """Build and launch a Gradio interface that turns a text prompt into a phone cover.

    Args:
        model_path: Model location forwarded to ``generate_cover`` on every request.
    """
    def infer(prompt):
        # Bind model_path here so the UI only has to supply the prompt.
        return generate_cover(prompt, model_path)

    prompt_box = gr.Textbox(label="Enter a design prompt")
    cover_view = gr.Image(label="Generated Phone Cover")
    demo = gr.Interface(
        fn=infer,
        inputs=prompt_box,
        outputs=cover_view,
        title="Celebrity Phone Cover Generator",
    )
    demo.launch()
# ---------- Main Workflow ----------
if __name__ == "__main__":
    # Demo run: scrape images, prepare a model directory, then serve the UI.
    demo_celebrity = "Taylor Swift"  # Example celebrity; drives the search and folder names
    scrape_images(demo_celebrity, num_images=30)
    # Deploy on Gradio using the directory returned by the model step.
    launch_gradio(fine_tune_sd3(demo_celebrity))