import os
import time
import requests
from PIL import Image
from io import BytesIO
from selenium import webdriver
from selenium.webdriver.common.by import By
from selenium.webdriver.chrome.options import Options
from selenium.webdriver.chrome.service import Service
from webdriver_manager.chrome import ChromeDriverManager
from diffusers import StableDiffusionPipeline
import torch
import gradio as gr

# ---------- Step 1: Scrape Celebrity Images ----------
def scrape_images(celebrity_name, num_images=20):
    """Download up to num_images Google Images thumbnails of the celebrity into data/<celebrity_name>/."""
    search_url = f"https://www.google.com/search?q={celebrity_name}+portrait&tbm=isch"
    chrome_options = Options()
    chrome_options.add_argument("--headless")
    chrome_options.add_argument("--no-sandbox")
    chrome_options.add_argument("--disable-dev-shm-usage")

    driver = webdriver.Chrome(service=Service(ChromeDriverManager().install()), options=chrome_options)
    driver.get(search_url)
    os.makedirs(f"data/{celebrity_name}", exist_ok=True)

    images = driver.find_elements(By.TAG_NAME, "img")
    count = 0

    for img in images:
        if count >= num_images:
            break
        src = img.get_attribute("src")
        if src and src.startswith("http"):  # skip inline base64 data: URIs
            try:
                img_data = requests.get(src, timeout=10).content
                with open(f"data/{celebrity_name}/{celebrity_name}_{count}.jpg", "wb") as handler:
                    handler.write(img_data)
                count += 1
            except Exception as e:
                print(f"Error downloading image: {e}")
    driver.quit()
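
# An optional cleanup pass over the scraped files before training; it assumes downstream code
# wants square RGB JPEGs. The 512x512 target size and the prepare_training_images name are
# illustrative choices, not part of the original workflow (e.g. it could be called between
# scrape_images and fine_tune_sd3 in the main block).
def prepare_training_images(celebrity_name, size=512):
    folder = f"data/{celebrity_name}"
    kept = []
    for name in os.listdir(folder):
        path = os.path.join(folder, name)
        try:
            with Image.open(path) as img:
                cleaned = img.convert("RGB").resize((size, size))
            cleaned.save(path, format="JPEG")
            kept.append(path)
        except Exception:
            os.remove(path)  # drop broken downloads or non-image responses
    print(f"Kept {len(kept)} usable images for {celebrity_name}")
    return kept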

# ---------- Step 2: Fine-Tuning Stable Diffusion ----------
def fine_tune_sd3(celebrity_name):
    # Note: despite the "_sd3" suffix, this loads Stable Diffusion v1.5; point model_id at an SD3 checkpoint if that is what you intend.
    model_id = "runwayml/stable-diffusion-v1-5"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipe = StableDiffusionPipeline.from_pretrained(model_id).to(device)

    celeb_images_path = f"data/{celebrity_name}"
    images = [Image.open(os.path.join(celeb_images_path, img)).convert("RGB")
              for img in os.listdir(celeb_images_path) if img.endswith(".jpg")]

    # Placeholder: no weights are updated here, the base pipeline is saved as-is.
    # A hedged sketch of a real DreamBooth-style training step follows this function.
    print(f"Fine-tuning with {len(images)} images of {celebrity_name}...")

    # Saving model
    fine_tuned_model_path = f"models/{celebrity_name}_sd3"
    os.makedirs(fine_tuned_model_path, exist_ok=True)
    pipe.save_pretrained(fine_tuned_model_path)
    print(f"Model saved at {fine_tuned_model_path}")

    return fine_tuned_model_path
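
# The function above does not change any weights. Below is a minimal sketch of an actual
# DreamBooth-style fine-tuning step built from the pipeline's own components (tokenizer,
# text_encoder, VAE, UNet). The instance prompt, learning rate and step count are illustrative
# assumptions, not values from the original project; for serious training prefer the official
# diffusers DreamBooth/LoRA scripts, which add prior preservation, gradient checkpointing, etc.
def dreambooth_style_finetune(pipe, images, instance_prompt, steps=400, lr=5e-6):
    import torch.nn.functional as F
    from torchvision import transforms
    from diffusers import DDPMScheduler

    device = pipe.device
    noise_scheduler = DDPMScheduler.from_config(pipe.scheduler.config)
    pipe.vae.requires_grad_(False)          # only the UNet is trained in this sketch
    pipe.text_encoder.requires_grad_(False)
    pipe.unet.train()
    optimizer = torch.optim.AdamW(pipe.unet.parameters(), lr=lr)

    to_tensor = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),  # map pixels to [-1, 1] as the VAE expects
    ])
    pixel_batches = [to_tensor(img).unsqueeze(0).to(device) for img in images]

    # Encode the identifying prompt once, e.g. "a photo of sks person" in DreamBooth convention.
    input_ids = pipe.tokenizer(instance_prompt, padding="max_length",
                               max_length=pipe.tokenizer.model_max_length,
                               truncation=True, return_tensors="pt").input_ids.to(device)
    with torch.no_grad():
        text_embeds = pipe.text_encoder(input_ids)[0]

    for step in range(steps):
        pixels = pixel_batches[step % len(pixel_batches)]
        with torch.no_grad():
            latents = pipe.vae.encode(pixels).latent_dist.sample() * pipe.vae.config.scaling_factor
        noise = torch.randn_like(latents)
        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
                                  (latents.shape[0],), device=device).long()
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

        # Predict the added noise and minimise the standard epsilon-prediction objective.
        noise_pred = pipe.unet(noisy_latents, timesteps, encoder_hidden_states=text_embeds).sample
        loss = F.mse_loss(noise_pred, noise)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    return pipe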

# ---------- Step 3: Generate Phone Cover Designs ----------
def generate_cover(prompt, model_path):
    """Generate a design from the prompt and composite the phone-cover template over it."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipe = StableDiffusionPipeline.from_pretrained(model_path).to(device)

    image = pipe(prompt).images[0]
    # phone_cover_template.png is assumed to be a transparent overlay (frame, camera cutout);
    # it is composited on top of the opaque generated design so both remain visible.
    cover_template = Image.open("phone_cover_template.png").convert("RGBA")
    image = image.resize(cover_template.size).convert("RGBA")
    blended = Image.alpha_composite(image, cover_template)

    output_path = "generated_phone_cover.png"
    blended.save(output_path)
    return output_path
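
# A hedged variant of the generation call showing common Stable Diffusion inference knobs
# (negative_prompt, guidance_scale, num_inference_steps); the values are illustrative defaults,
# not tuned settings from the original project.
def generate_design_only(prompt, model_path):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipe = StableDiffusionPipeline.from_pretrained(model_path).to(device)
    return pipe(
        prompt,
        negative_prompt="blurry, low quality, watermark",  # steer away from common artifacts
        guidance_scale=7.5,        # how strongly the prompt is enforced
        num_inference_steps=30,    # speed/quality trade-off
    ).images[0]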

# ---------- Step 4: Gradio Deployment ----------
def launch_gradio(model_path):
    def infer(prompt):
        result = generate_cover(prompt, model_path)
        return result

    gr.Interface(fn=infer, 
                 inputs=gr.Textbox(label="Enter a design prompt"),
                 outputs=gr.Image(label="Generated Phone Cover"),
                 title="Celebrity Phone Cover Generator").launch()

# ---------- Main Workflow ----------
if __name__ == "__main__":
    celebrity = "Taylor Swift"  # Example celebrity
    scrape_images(celebrity, num_images=30)
    
    model_path = fine_tune_sd3(celebrity)

    # Deploy on Gradio
    launch_gradio(model_path)