# Change-Hair / app.py
# (Scraped file-page header: author ginipick, commit 14af22f "Update app.py"
#  [verified], file size 10.3 kB.)
import html
import os
import tempfile
from io import BytesIO

import gradio as gr
import replicate
import requests
from PIL import Image
# Module-wide Replicate client. The API token is read once, at import time,
# from the RAPI_TOKEN environment variable (None when the variable is unset).
replicate_client = replicate.Client(api_token=os.environ.get("RAPI_TOKEN"))
# Hairstyle names offered in the UI and forwarded verbatim to the Replicate
# model as its `haircut` input. The list holds 92 unique entries; UI text
# should derive counts from len(HAIRCUT_OPTIONS) rather than hard-coding them.
HAIRCUT_OPTIONS = [
    "Crew Cut", "Faux Hawk", "Slicked Back", "Side-Parted", "Center-Parted",
    "Blunt Bangs", "Side-Swept Bangs", "Shag", "Lob", "Angled Bob",
    "A-Line Bob", "Asymmetrical Bob", "Graduated Bob", "Inverted Bob", "Layered Shag",
    "Choppy Layers", "Razor Cut", "Perm", "Ombré", "Straightened",
    "Soft Waves", "Glamorous Waves", "Hollywood Waves", "Finger Waves", "Tousled",
    "Feathered", "Pageboy", "Pigtails", "Pin Curls", "Rollerset",
    "Twist Out", "Bantu Knots", "Dreadlocks", "Cornrows", "Box Braids",
    "Crochet Braids", "Double Dutch Braids", "French Fishtail Braid", "Waterfall Braid", "Rope Braid",
    "Heart Braid", "Halo Braid", "Crown Braid", "Braided Crown", "Bubble Braid",
    "Bubble Ponytail", "Ballerina Braids", "Milkmaid Braids", "Bohemian Braids", "Flat Twist",
    "Crown Twist", "Twisted Bun", "Twisted Half-Updo", "Twist and Pin Updo", "Chignon",
    "Simple Chignon", "Messy Chignon", "French Twist", "French Twist Updo", "French Roll",
    "Updo", "Messy Updo", "Knotted Updo", "Ballerina Bun", "Banana Clip Updo",
    "Beehive", "Bouffant", "Hair Bow", "Half-Up Top Knot", "Half-Up, Half-Down",
    "Messy Bun with a Headband", "Messy Bun with a Scarf", "Messy Fishtail Braid", "Sideswept Pixie", "Mohawk Fade",
    "Straight", "Wavy", "Curly", "Bob", "Pixie Cut",
    "Layered", "Messy Bun", "High Ponytail", "Low Ponytail", "Braided Ponytail",
    "French Braid", "Dutch Braid", "Fishtail Braid", "Space Buns", "Top Knot",
    "Undercut", "Mohawk",
]
# Hair color choices in display order, mapping name -> representative hex
# color. The names populate the color dropdown and are sent to the model as
# `hair_color`; the hex values are used only to paint the swatch palette.
HAIR_COLOR_OPTIONS = dict(
    (
        ("Blonde", "#F4E4C1"),
        ("Brunette", "#6F4E37"),
        ("Black", "#1C1C1C"),
        ("Dark Brown", "#3B2F2F"),
        ("Medium Brown", "#8B4513"),
        ("Light Brown", "#A0826D"),
        ("Auburn", "#A52A2A"),
        ("Copper", "#B87333"),
        ("Red", "#DC143C"),
        ("Strawberry Blonde", "#FFB6C1"),
        ("Platinum Blonde", "#FAFAD2"),
        ("Silver", "#C0C0C0"),
        ("White", "#FFFFFF"),
        ("Blue", "#4169E1"),
        ("Purple", "#9370DB"),
        ("Pink", "#FF69B4"),
        ("Green", "#228B22"),
        ("Blue-Black", "#1C1C3D"),
        ("Golden Blonde", "#FFD700"),
        ("Honey Blonde", "#F0E68C"),
        ("Caramel", "#C68E17"),
        ("Chestnut", "#954535"),
        ("Mahogany", "#C04000"),
        ("Burgundy", "#800020"),
        ("Jet Black", "#0A0A0A"),
        ("Ash Brown", "#8B7355"),
        ("Ash Blonde", "#D3D3D3"),
        ("Titanium", "#878787"),
        ("Rose Gold", "#E0BFB8"),
    )
)
def create_color_palette_html(colors=None):
    """Build the HTML for the color swatch grid shown next to the dropdown.

    Args:
        colors: Optional mapping of display name -> CSS color value. Defaults
            to ``HAIR_COLOR_OPTIONS``. Swatches appear in mapping order.

    Returns:
        An HTML string: a ``<style>`` element followed by a ``.color-grid``
        container holding one ``.color-cell`` per color.
    """
    if colors is None:
        colors = HAIR_COLOR_OPTIONS

    style = """
    <style>
    .color-grid {
        display: grid;
        grid-template-columns: repeat(auto-fill, minmax(80px, 1fr));
        gap: 8px;
        padding: 10px;
        max-height: 300px;
        overflow-y: auto;
        border: 1px solid #e0e0e0;
        border-radius: 8px;
        background-color: #f9f9f9;
    }
    .color-cell {
        aspect-ratio: 1;
        border-radius: 8px;
        cursor: pointer;
        transition: all 0.3s ease;
        box-shadow: 0 2px 5px rgba(0,0,0,0.2);
        display: flex;
        align-items: flex-end;
        justify-content: center;
        overflow: hidden;
        border: 3px solid transparent;
    }
    .color-cell:hover {
        transform: translateY(-2px);
        box-shadow: 0 4px 10px rgba(0,0,0,0.3);
    }
    .color-label {
        width: 100%;
        background-color: rgba(0,0,0,0.7);
        color: white;
        font-size: 9px;
        padding: 2px;
        text-align: center;
        font-weight: bold;
    }
    </style>
    """

    # NOTE: the cells carry no click handler; the palette is a visual
    # reference only. Selection happens through the dropdown.
    cells = []
    for color_name, color_code in colors.items():
        # Escape interpolated values so a name or color value containing
        # quotes, '&' or '<' cannot break the attribute quoting or inject markup.
        name = html.escape(str(color_name), quote=True)
        code = html.escape(str(color_code), quote=True)
        cells.append(
            f'<div class="color-cell" style="background-color: {code};" title="{name}">'
            f'<span class="color-label">{name}</span>'
            "</div>"
        )

    # Single join instead of repeated `+=` on a growing string.
    return style + '<div class="color-grid">' + "".join(cells) + "</div>"
def update_selected_color(color_name):
    """Return a Markdown "Selected Color" label plus the untouched color name.

    The two values are meant for two event outputs: the label display and a
    slot that keeps the raw name.
    """
    label_markdown = f"**Selected Color:** {color_name}"
    return (label_markdown, color_name)
def _load_output_image(output):
    """Convert a Replicate ``run`` result into a fully decoded PIL image.

    Handles the two result shapes this app expects: a URL string, which is
    downloaded, or a file-like object exposing ``read()``.

    Raises:
        requests.HTTPError: if downloading the URL returns an error status.
    """
    if isinstance(output, str):
        # Bounded timeout so a stalled download cannot hang the worker forever.
        response = requests.get(output, timeout=60)
        response.raise_for_status()
        data = response.content
    else:
        data = output.read()
    image = Image.open(BytesIO(data))
    # Decode now so a corrupt payload fails inside the caller's try/except
    # instead of later, when the UI serializes the result.
    image.load()
    return image


def change_haircut(input_image, haircut_style, hair_color):
    """Apply a hairstyle and hair color to a photo via the Replicate model.

    Args:
        input_image: PIL image from the upload widget, or ``None`` when
            nothing has been uploaded.
        haircut_style: Hairstyle name forwarded to the model as ``haircut``.
        hair_color: Color name. Empty values or names not present in
            ``HAIR_COLOR_OPTIONS`` fall back to ``"Black"``.

    Returns:
        ``(result_image, status_message)``. ``result_image`` is ``None`` on
        failure. Errors are reported through the status message rather than
        raised to the UI.
    """
    if input_image is None:
        return None, "Please upload an image first."

    # Initialized before the try so the cleanup in `finally` is always safe,
    # even if an exception fires before the temp file exists.
    temp_path = None
    try:
        # Ensure we have a valid color
        if not hair_color or hair_color not in HAIR_COLOR_OPTIONS:
            hair_color = "Black"
        print(f"Processing with hairstyle: {haircut_style}, color: {hair_color}")

        # Use a unique temp file per request. A fixed filename would let
        # concurrent requests overwrite or delete each other's input.
        fd, temp_path = tempfile.mkstemp(suffix=".png")
        os.close(fd)
        input_image.save(temp_path)

        # The `with` closes the upload handle once the API call returns.
        with open(temp_path, "rb") as image_file:
            output = replicate_client.run(
                "flux-kontext-apps/change-haircut",
                input={
                    "haircut": haircut_style,
                    "hair_color": hair_color,
                    "input_image": image_file,
                },
            )

        if not output:
            return None, "Processing error occurred."

        result_image = _load_output_image(output)
        return result_image, f"Successfully applied {haircut_style} style with {hair_color} color!"
    except Exception as e:
        return None, f"Error: {e}"
    finally:
        # Runs on every exit path, including the empty-output return above.
        if temp_path is not None and os.path.exists(temp_path):
            os.remove(temp_path)
# Create Gradio interface

# Styling for the scrollable hairstyle list. The rules are scoped to the
# radio's explicit `elem_id`. Auto-generated ids such as `#component-8` shift
# whenever components are added or reordered, and Svelte hash classes such as
# `.svelte-1gfkn6j` change between Gradio releases.
# NOTE(review): `.wrap` and `label.selected` are Gradio-internal class names;
# confirm they match the installed Gradio version.
HAIRCUT_RADIO_CSS = """
/* Radio button container */
#haircut-radio .wrap {
    max-height: 600px !important;
    overflow-y: auto !important;
    padding: 10px;
    border: 1px solid #e0e0e0;
    border-radius: 8px;
    background-color: #f9f9f9;
}
/* Style radio button options */
#haircut-radio .wrap label {
    display: block !important;
    padding: 8px 12px !important;
    margin: 4px 0 !important;
    border-radius: 6px !important;
    cursor: pointer !important;
    transition: all 0.2s ease !important;
}
#haircut-radio .wrap label:hover {
    background-color: #e8e8e8 !important;
}
/* Selected radio option */
#haircut-radio .wrap label.selected {
    background-color: #2196F3 !important;
    color: white !important;
}
"""

# The CSS goes to the constructor (the documented parameter) rather than being
# assigned to `demo.css` after the `with` block exits. Some Gradio versions
# build the page config on exit, so a later assignment may never reach the page.
with gr.Blocks(
    title="AI Hairstyle Changer",
    theme=gr.themes.Soft(),
    css=HAIRCUT_RADIO_CSS,
) as demo:
    gr.Markdown("# 🎨 AI Hairstyle Changer")

    with gr.Row():
        with gr.Column(scale=1):
            # Input section
            input_image = gr.Image(
                label="Upload Your Photo",
                type="pil",
                height=400,
            )
            with gr.Group():
                gr.Markdown("### Select Hair Color")
                selected_color_display = gr.Markdown(value="**Selected Color:** Black")
                # The dropdown value is what gets sent to the model.
                hair_color_dropdown = gr.Dropdown(
                    choices=list(HAIR_COLOR_OPTIONS.keys()),
                    value="Black",
                    label="Choose Hair Color",
                    info="Select from dropdown for color",
                )
                # Visual color palette (display only)
                gr.Markdown("**Color Palette Reference:**")
                color_palette_display = gr.HTML(create_color_palette_html())
            submit_btn = gr.Button(
                "🎨 Apply Hairstyle",
                variant="primary",
                size="lg",
            )

        with gr.Column(scale=2):
            # Hairstyle selection with a scrollable radio list. Counts are
            # derived from the list so the text cannot drift from the data.
            style_count = len(HAIRCUT_OPTIONS)
            with gr.Group():
                gr.Markdown(f"### Select Hairstyle ({style_count} total styles)")
                gr.Markdown(f"*Scroll down to see all {style_count} hairstyle options*")
                haircut_radio = gr.Radio(
                    choices=HAIRCUT_OPTIONS,
                    value="Bob",
                    label="Hairstyle Options",
                    info=f"Scroll to browse all {style_count} hairstyle options",
                    elem_id="haircut-radio",  # stable hook for HAIRCUT_RADIO_CSS
                )

        with gr.Column(scale=1):
            # Output section
            output_image = gr.Image(
                label="Result",
                type="pil",
                height=400,
            )
            status_text = gr.Textbox(
                label="Status",
                interactive=False,
                value="Upload a photo and select a hairstyle to begin.",
            )

    # Receives the raw color name, which is the second value returned by
    # update_selected_color. Declared once here instead of creating an
    # anonymous State inside the event wiring.
    selected_color_state = gr.State(value="Black")

    # Update color display when dropdown changes
    hair_color_dropdown.change(
        fn=update_selected_color,
        inputs=[hair_color_dropdown],
        outputs=[selected_color_display, selected_color_state],
    )

    # Button click event
    submit_btn.click(
        fn=change_haircut,
        inputs=[input_image, haircut_radio, hair_color_dropdown],
        outputs=[output_image, status_text],
    )
# Script entry point: print startup diagnostics, then launch the UI.
if __name__ == "__main__":
    # A missing token only produces a warning; the app still launches.
    if not os.environ.get("RAPI_TOKEN"):
        for hint in (
            "⚠️ Warning: RAPI_TOKEN environment variable not set!",
            "Set it using: export RAPI_TOKEN='your_token'",
        ):
            print(hint)

    # Report how many options were loaded into each table.
    print(f"✅ Loaded {len(HAIRCUT_OPTIONS)} hairstyle options")
    print(f"✅ Loaded {len(HAIR_COLOR_OPTIONS)} hair color options")

    # share=True requests a public link; debug=True enables Gradio debug mode.
    launch_options = {"share": True, "debug": True}
    demo.launch(**launch_options)