# Source captured from a Hugging Face Space (Space status at capture time: Paused).
import html
import os
import tempfile
from io import BytesIO

import gradio as gr
import replicate
import requests
from PIL import Image
# Shared Replicate API client used by change_haircut(). The token is taken
# from the RAPI_TOKEN environment variable (None when the variable is unset).
replicate_client = replicate.Client(
    api_token=os.environ.get("RAPI_TOKEN"),
)
# Hairstyle options shown in the UI radio list. Each value is sent verbatim
# to the model as its ``haircut`` input. The list currently holds 92 styles;
# use len(HAIRCUT_OPTIONS) instead of hard-coding the count anywhere.
HAIRCUT_OPTIONS = [
    "Crew Cut", "Faux Hawk", "Slicked Back", "Side-Parted", "Center-Parted",
    "Blunt Bangs", "Side-Swept Bangs", "Shag", "Lob", "Angled Bob",
    "A-Line Bob", "Asymmetrical Bob", "Graduated Bob", "Inverted Bob", "Layered Shag",
    "Choppy Layers", "Razor Cut", "Perm", "Ombré", "Straightened",
    "Soft Waves", "Glamorous Waves", "Hollywood Waves", "Finger Waves", "Tousled",
    "Feathered", "Pageboy", "Pigtails", "Pin Curls", "Rollerset",
    "Twist Out", "Bantu Knots", "Dreadlocks", "Cornrows", "Box Braids",
    "Crochet Braids", "Double Dutch Braids", "French Fishtail Braid", "Waterfall Braid", "Rope Braid",
    "Heart Braid", "Halo Braid", "Crown Braid", "Braided Crown", "Bubble Braid",
    "Bubble Ponytail", "Ballerina Braids", "Milkmaid Braids", "Bohemian Braids", "Flat Twist",
    "Crown Twist", "Twisted Bun", "Twisted Half-Updo", "Twist and Pin Updo", "Chignon",
    "Simple Chignon", "Messy Chignon", "French Twist", "French Twist Updo", "French Roll",
    "Updo", "Messy Updo", "Knotted Updo", "Ballerina Bun", "Banana Clip Updo",
    "Beehive", "Bouffant", "Hair Bow", "Half-Up Top Knot", "Half-Up, Half-Down",
    "Messy Bun with a Headband", "Messy Bun with a Scarf", "Messy Fishtail Braid", "Sideswept Pixie", "Mohawk Fade",
    "Straight", "Wavy", "Curly", "Bob", "Pixie Cut",
    "Layered", "Messy Bun", "High Ponytail", "Low Ponytail", "Braided Ponytail",
    "French Braid", "Dutch Braid", "Fishtail Braid", "Space Buns", "Top Knot",
    "Undercut", "Mohawk",
]
# Hair colours offered in the dropdown, in display order, mapped to the CSS
# colour used to paint each reference swatch. Only the name is sent to the
# model (as ``hair_color``); the hex code is purely for the palette display.
HAIR_COLOR_OPTIONS = dict([
    ("Blonde", "#F4E4C1"),
    ("Brunette", "#6F4E37"),
    ("Black", "#1C1C1C"),
    ("Dark Brown", "#3B2F2F"),
    ("Medium Brown", "#8B4513"),
    ("Light Brown", "#A0826D"),
    ("Auburn", "#A52A2A"),
    ("Copper", "#B87333"),
    ("Red", "#DC143C"),
    ("Strawberry Blonde", "#FFB6C1"),
    ("Platinum Blonde", "#FAFAD2"),
    ("Silver", "#C0C0C0"),
    ("White", "#FFFFFF"),
    ("Blue", "#4169E1"),
    ("Purple", "#9370DB"),
    ("Pink", "#FF69B4"),
    ("Green", "#228B22"),
    ("Blue-Black", "#1C1C3D"),
    ("Golden Blonde", "#FFD700"),
    ("Honey Blonde", "#F0E68C"),
    ("Caramel", "#C68E17"),
    ("Chestnut", "#954535"),
    ("Mahogany", "#C04000"),
    ("Burgundy", "#800020"),
    ("Jet Black", "#0A0A0A"),
    ("Ash Brown", "#8B7355"),
    ("Ash Blonde", "#D3D3D3"),
    ("Titanium", "#878787"),
    ("Rose Gold", "#E0BFB8"),
])
def create_color_palette_html(colors=None):
    """Create HTML for the read-only visual colour palette.

    Args:
        colors: Optional mapping of display name -> CSS colour string.
            Defaults to ``HAIR_COLOR_OPTIONS`` when omitted or None.

    Returns:
        A string holding a ``<style>`` element followed by a
        ``<div class="color-grid">`` containing one labelled swatch per colour.
        Names and colour codes are HTML-escaped before being placed in
        attributes or text, so unusual characters cannot break the markup.
    """
    if colors is None:
        colors = HAIR_COLOR_OPTIONS

    header = '''
<style>
.color-grid {
    display: grid;
    grid-template-columns: repeat(auto-fill, minmax(80px, 1fr));
    gap: 8px;
    padding: 10px;
    max-height: 300px;
    overflow-y: auto;
    border: 1px solid #e0e0e0;
    border-radius: 8px;
    background-color: #f9f9f9;
}
.color-cell {
    aspect-ratio: 1;
    border-radius: 8px;
    cursor: pointer;
    transition: all 0.3s ease;
    box-shadow: 0 2px 5px rgba(0,0,0,0.2);
    display: flex;
    align-items: flex-end;
    justify-content: center;
    overflow: hidden;
    border: 3px solid transparent;
}
.color-cell:hover {
    transform: translateY(-2px);
    box-shadow: 0 4px 10px rgba(0,0,0,0.3);
}
.color-label {
    width: 100%;
    background-color: rgba(0,0,0,0.7);
    color: white;
    font-size: 9px;
    padding: 2px;
    text-align: center;
    font-weight: bold;
}
</style>
<div class="color-grid">
'''
    # Build all cells first and join once, instead of repeated string +=.
    cells = []
    for name, code in colors.items():
        safe_name = html.escape(name)
        safe_code = html.escape(code)
        cells.append(
            f'<div class="color-cell" style="background-color: {safe_code};" title="{safe_name}">\n'
            f'<span class="color-label">{safe_name}</span>\n'
            '</div>\n'
        )
    return header + "".join(cells) + "</div>"
def update_selected_color(color_name):
    """Return the markdown label for the chosen colour and the colour itself.

    Args:
        color_name: The colour currently selected in the dropdown.

    Returns:
        A ``(markdown_text, color_name)`` pair: the first value updates the
        "Selected Color" display, the second is passed through unchanged.
    """
    label = "**Selected Color:** " + f"{color_name}"
    return (label, color_name)
def change_haircut(input_image, haircut_style, hair_color): | |
""" | |
Process image through API to change hairstyle | |
""" | |
try: | |
# Check if image is provided | |
if input_image is None: | |
return None, "Please upload an image first." | |
# Ensure we have a valid color | |
if not hair_color or hair_color not in HAIR_COLOR_OPTIONS: | |
hair_color = "Black" | |
print(f"Processing with hairstyle: {haircut_style}, color: {hair_color}") | |
# Save PIL Image to temporary file | |
temp_path = "temp_input.png" | |
input_image.save(temp_path) | |
# Call Replicate API | |
output = replicate_client.run( | |
"flux-kontext-apps/change-haircut", | |
input={ | |
"haircut": haircut_style, | |
"hair_color": hair_color, | |
"input_image": open(temp_path, "rb") | |
} | |
) | |
# Process results | |
if output: | |
# If output is URL | |
if isinstance(output, str): | |
response = requests.get(output) | |
result_image = Image.open(BytesIO(response.content)) | |
# If output is file object | |
else: | |
result_image = Image.open(BytesIO(output.read())) | |
# Clean up temporary file | |
if os.path.exists(temp_path): | |
os.remove(temp_path) | |
return result_image, f"Successfully applied {haircut_style} style with {hair_color} color!" | |
else: | |
return None, "Processing error occurred." | |
except Exception as e: | |
# Clean up temporary file on error | |
if os.path.exists(temp_path): | |
os.remove(temp_path) | |
return None, f"Error: {str(e)}" | |
# Style customization. Passed to gr.Blocks(css=...) so it is part of the app
# config from the start. Selectors are scoped to the hairstyle radio via its
# elem_id rather than an auto-generated "#component-N" id or a build-specific
# "svelte-xxxx" class, both of which change between layouts and Gradio builds.
# NOTE(review): still relies on Gradio's internal ".wrap" / "label.selected"
# markup inside the Radio component -- verify after Gradio upgrades.
CUSTOM_CSS = """
/* Radio button container: make the long option list scrollable */
#haircut-radio .wrap {
    max-height: 600px !important;
    overflow-y: auto !important;
    padding: 10px;
    border: 1px solid #e0e0e0;
    border-radius: 8px;
    background-color: #f9f9f9;
}
/* Style radio button options */
#haircut-radio .wrap label {
    display: block !important;
    padding: 8px 12px !important;
    margin: 4px 0 !important;
    border-radius: 6px !important;
    cursor: pointer !important;
    transition: all 0.2s ease !important;
}
#haircut-radio .wrap label:hover {
    background-color: #e8e8e8 !important;
}
/* Selected radio option */
#haircut-radio .wrap label.selected {
    background-color: #2196F3 !important;
    color: white !important;
}
"""

# Create Gradio interface
with gr.Blocks(
    title="AI Hairstyle Changer",
    theme=gr.themes.Soft(),
    css=CUSTOM_CSS,
) as demo:
    gr.Markdown("# 🎨 AI Hairstyle Changer")

    with gr.Row():
        with gr.Column(scale=1):
            # Input section
            input_image = gr.Image(
                label="Upload Your Photo",
                type="pil",
                height=400,
            )
            with gr.Group():
                gr.Markdown("### Select Hair Color")
                selected_color_display = gr.Markdown(value="**Selected Color:** Black")
                # Dropdown for color selection (functional)
                hair_color_dropdown = gr.Dropdown(
                    choices=list(HAIR_COLOR_OPTIONS.keys()),
                    value="Black",
                    label="Choose Hair Color",
                    info="Select from dropdown for color",
                )
                # Holds the second value returned by update_selected_color.
                selected_color_state = gr.State("Black")
                # Visual color palette (for reference only; not clickable)
                gr.Markdown("**Color Palette Reference:**")
                color_palette_display = gr.HTML(create_color_palette_html())
            submit_btn = gr.Button(
                "🎨 Apply Hairstyle",
                variant="primary",
                size="lg",
            )

        with gr.Column(scale=2):
            # Hairstyle selection with scrollable radio buttons. The count is
            # derived from the list so the labels cannot drift out of sync.
            n_styles = len(HAIRCUT_OPTIONS)
            with gr.Group():
                gr.Markdown(f"### Select Hairstyle ({n_styles} total styles)")
                gr.Markdown(f"*Scroll down to see all {n_styles} hairstyle options*")
                haircut_radio = gr.Radio(
                    choices=HAIRCUT_OPTIONS,
                    value="Bob",
                    label="Hairstyle Options",
                    info=f"Scroll to browse all {n_styles} hairstyle options",
                    elem_id="haircut-radio",  # target of CUSTOM_CSS selectors
                )

        with gr.Column(scale=1):
            # Output section
            output_image = gr.Image(
                label="Result",
                type="pil",
                height=400,
            )
            status_text = gr.Textbox(
                label="Status",
                interactive=False,
                value="Upload a photo and select a hairstyle to begin.",
            )

    # Update color display when dropdown changes
    hair_color_dropdown.change(
        fn=update_selected_color,
        inputs=[hair_color_dropdown],
        outputs=[selected_color_display, selected_color_state],
    )

    # Button click event
    submit_btn.click(
        fn=change_haircut,
        inputs=[input_image, haircut_radio, hair_color_dropdown],
        outputs=[output_image, status_text],
    )
# Entry point: report configuration, then start the Gradio server.
if __name__ == "__main__":
    # Warn (but continue) when the token that replicate_client is built from is missing.
    has_token = bool(os.environ.get("RAPI_TOKEN"))
    if not has_token:
        print("⚠️ Warning: RAPI_TOKEN environment variable not set!")
        print("Set it using: export RAPI_TOKEN='your_token'")

    # Sanity report of how many options were loaded.
    option_counts = (
        (len(HAIRCUT_OPTIONS), "hairstyle"),
        (len(HAIR_COLOR_OPTIONS), "hair color"),
    )
    for count, label in option_counts:
        print(f"✅ Loaded {count} {label} options")

    # share=True creates a shareable link; debug=True enables debug mode.
    launch_options = {"share": True, "debug": True}
    demo.launch(**launch_options)