import os
IS_SPACE = True
if IS_SPACE:
import spaces
import sys
import warnings
import subprocess
from pathlib import Path
from typing import Any, Dict, Optional, Tuple
import torch
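# On Hugging Face Spaces (ZeroGPU), GPU-heavy functions must be wrapped with
# spaces.GPU so a GPU is attached for the duration of the call; when running
# locally (IS_SPACE = False) the decorator degrades to a no-op.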
def space_context(duration: int):
if IS_SPACE:
return spaces.GPU(duration=duration)
return lambda x: x
@space_context(duration=120)
def test_env():
assert torch.cuda.is_available(), "A CUDA GPU is required to run this app"
try:
import flash_attn
except ImportError:
print("Flash-attn not found, installing...")
os.system("pip install flash-attn==2.7.3 --no-build-isolation")
else:
print("Flash-attn found, skipping installation...")
test_env()
warnings.filterwarnings("ignore")
# Add the current directory to Python path
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
try:
import gradio as gr
from PIL import Image
from hyimage.diffusion.pipelines.hunyuanimage_pipeline import HunyuanImagePipeline
from huggingface_hub import snapshot_download
import modelscope
except ImportError as e:
print(f"Missing required dependencies: {e}")
print("Please install with: pip install -r requirements_gradio.txt")
print("For checkpoint downloads, also install: pip install -U 'huggingface_hub[cli]' modelscope")
sys.exit(1)
BASE_DIR = os.environ.get('HUNYUANIMAGE_V2_1_MODEL_ROOT', './ckpts')
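# Expected checkpoint layout under BASE_DIR (default ./ckpts), as configured in
# CheckpointDownloader below:
#   <BASE_DIR>/                              main HunyuanImage-2.1 weights
#   <BASE_DIR>/text_encoder/llm              Qwen/Qwen2.5-VL-7B-Instruct
#   <BASE_DIR>/text_encoder/byt5-small       google/byt5-small
#   <BASE_DIR>/text_encoder/Glyph-SDXL-v2    AI-ModelScope/Glyph-SDXL-v2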
class CheckpointDownloader:
"""Handles downloading of all required checkpoints for HunyuanImage."""
def __init__(self, base_dir: str = BASE_DIR):
self.base_dir = Path(base_dir)
self.base_dir.mkdir(parents=True, exist_ok=True)
print(f'Using checkpoint directory: {self.base_dir}')
# Define all required checkpoints
self.checkpoints = {
"main_model": {
"repo_id": "tencent/HunyuanImage-2.1",
"local_dir": self.base_dir,
},
"mllm_encoder": {
"repo_id": "Qwen/Qwen2.5-VL-7B-Instruct",
"local_dir": self.base_dir / "text_encoder" / "llm",
},
"byt5_encoder": {
"repo_id": "google/byt5-small",
"local_dir": self.base_dir / "text_encoder" / "byt5-small",
},
"glyph_encoder": {
"repo_id": "AI-ModelScope/Glyph-SDXL-v2",
"local_dir": self.base_dir / "text_encoder" / "Glyph-SDXL-v2",
"use_modelscope": True
}
}
def download_checkpoint(self, checkpoint_name: str, progress_callback=None) -> Tuple[bool, str]:
"""Download a specific checkpoint."""
if checkpoint_name not in self.checkpoints:
return False, f"Unknown checkpoint: {checkpoint_name}"
config = self.checkpoints[checkpoint_name]
local_dir = config["local_dir"]
local_dir.mkdir(parents=True, exist_ok=True)
try:
if config.get("use_modelscope", False):
# Use ModelScope for checkpoints hosted there
return self._download_with_modelscope(config, progress_callback)
else:
# Use huggingface_hub for other models
return self._download_with_hf(config, progress_callback)
except Exception as e:
return False, f"Download failed: {str(e)}"
def _download_with_hf(self, config: Dict, progress_callback=None) -> Tuple[bool, str]:
"""Download using huggingface_hub."""
repo_id = config["repo_id"]
local_dir = config["local_dir"]
if progress_callback:
progress_callback(f"Downloading {repo_id}...")
try:
snapshot_download(
repo_id=repo_id,
local_dir=str(local_dir),
local_dir_use_symlinks=False,
resume_download=True
)
return True, f"Successfully downloaded {repo_id}"
except Exception as e:
return False, f"HF download failed: {str(e)}"
def _download_with_modelscope(self, config: Dict, progress_callback=None) -> Tuple[bool, str]:
"""Download using modelscope."""
repo_id = config["repo_id"]
local_dir = config["local_dir"]
if progress_callback:
progress_callback(f"Downloading {repo_id} via ModelScope...")
print(f"Downloading {repo_id} via ModelScope...")
try:
# Use subprocess to call modelscope CLI
cmd = [
"modelscope", "download",
"--model", repo_id,
"--local_dir", str(local_dir)
]
subprocess.run(cmd, capture_output=True, text=True, check=True)
return True, f"Successfully downloaded {repo_id} via ModelScope"
except subprocess.CalledProcessError as e:
return False, f"ModelScope download failed: {e.stderr}"
except FileNotFoundError:
return False, "ModelScope CLI not found. Install with: pip install modelscope"
def download_all_checkpoints(self, progress_callback=None) -> Tuple[bool, str, Dict[str, Any]]:
"""Download all checkpoints."""
results = {}
for name in self.checkpoints:
if progress_callback:
progress_callback(f"Starting download of {name}...")
success, message = self.download_checkpoint(name, progress_callback)
results[name] = {"success": success, "message": message}
if not success:
return False, f"Failed to download {name}: {message}", results
return True, "All checkpoints downloaded successfully", results
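# Usage sketch: pre-download every checkpoint before launching the app
# (assumes network access to the Hugging Face Hub and an installed ModelScope CLI):
#
#   downloader = CheckpointDownloader(BASE_DIR)
#   ok, msg, results = downloader.download_all_checkpoints(progress_callback=print)
#   if not ok:
#       raise RuntimeError(msg)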
@space_context(duration=2000)
def load_pipeline(use_distilled: bool = False, device: str = "cuda"):
"""Load the HunyuanImage pipeline (only load once, refiner and reprompt are accessed from it)."""
try:
assert not use_distilled # use_distilled is a placeholder for the future
print(f"Loading HunyuanImage pipeline (distilled={use_distilled})...")
model_name = "hunyuanimage-v2.1-distilled" if use_distilled else "hunyuanimage-v2.1"
pipeline = HunyuanImagePipeline.from_pretrained(
model_name=model_name,
device=device,
enable_dit_offloading=True,
enable_reprompt_model_offloading=True,
enable_refiner_offloading=True
)
pipeline.to('cpu')
# Touch the refiner and reprompt model so any lazy initialization happens now,
# and share the main text encoder with the refiner to avoid loading it twice.
refiner_pipeline = pipeline.refiner_pipeline
refiner_pipeline.text_encoder.model = pipeline.text_encoder.model
refiner_pipeline.to('cpu')
_ = pipeline.reprompt_model
print("✓ Pipeline loaded successfully")
return pipeline
except Exception as e:
error_msg = f"Error loading pipeline: {str(e)}"
print(f"βœ— {error_msg}")
raise
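# Standalone usage sketch (assumes the checkpoints are already in place and a
# CUDA GPU is available; the call signature mirrors generate_image below):
#
#   pipe = load_pipeline(use_distilled=False, device="cuda")
#   pipe.to("cuda")
#   img = pipe(prompt="a watercolor fox", width=2048, height=2048,
#              num_inference_steps=50, guidance_scale=3.5, seed=42,
#              use_reprompt=False, use_refiner=False)
#   img.save("fox.png")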
# if IS_SPACE:
# downloader = CheckpointDownloader()
# downloader.download_all_checkpoints()
pipeline = load_pipeline(use_distilled=False, device="cuda")
class HunyuanImageApp:
@space_context(duration=290)
def __init__(self, auto_load: bool = True, use_distilled: bool = False, device: str = "cuda"):
"""Initialize the Gradio app. The shared pipeline is loaded once at module import;
the auto_load/use_distilled/device arguments are kept for CLI compatibility."""
global pipeline
self.pipeline = pipeline
self.current_use_distilled = None
self.downloader = CheckpointDownloader()
# Define aspect ratio mappings
self.aspect_ratio_mappings = {
"16:9": (2560, 1536),
"4:3": (2304, 1792),
"1:1": (2048, 2048),
"3:4": (1792, 2304),
"9:16": (1536, 2560)
}
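# All presets above are roughly 4-megapixel (2K-class) canvases, matching the
# aspect-ratio choices exposed in the UI below.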
def print_peak_memory(self):
"""Print the peak GPU memory allocated so far."""
stats = torch.cuda.memory_stats()
peak_bytes = stats["allocated_bytes.all.peak"]
print(f"Peak GPU memory allocated: {peak_bytes / 1024 ** 3:.2f} GB")
def update_resolution(self, aspect_ratio_choice: str) -> Tuple[int, int]:
"""Update width and height based on selected aspect ratio."""
# Extract the aspect ratio key from the choice (e.g., "16:9" from "16:9 (2560×1536)")
aspect_key = aspect_ratio_choice.split(" (")[0]
if aspect_key in self.aspect_ratio_mappings:
return self.aspect_ratio_mappings[aspect_key]
else:
# Default to 1:1 if not found
return self.aspect_ratio_mappings["1:1"]
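# e.g. update_resolution("16:9 (2560×1536)") -> (2560, 1536); unknown labels
# fall back to the 1:1 preset (2048, 2048).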
@space_context(duration=300)
def generate_image(self,
prompt: str,
negative_prompt: str,
width: int,
height: int,
num_inference_steps: int,
guidance_scale: float,
seed: int,
use_reprompt: bool,
use_refiner: bool,
# use_distilled: bool
) -> Tuple[Optional[Image.Image], str]:
"""Generate an image using the HunyuanImage pipeline."""
try:
torch.cuda.empty_cache()
if self.pipeline is None:
return None, "Pipeline not loaded. Please try again."
if hasattr(self.pipeline, '_refiner_pipeline'):
self.pipeline.refiner_pipeline.to('cpu')
self.pipeline.to('cuda')
if seed == -1:
import random
seed = random.randint(100000, 999999)
# Generate image
image = self.pipeline(
prompt=prompt,
negative_prompt=negative_prompt,
width=width,
height=height,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
seed=seed,
shift=5,
use_reprompt=use_reprompt,
use_refiner=use_refiner
)
self.print_peak_memory()
return image, "Image generated successfully!"
except Exception as e:
error_msg = f"Error generating image: {str(e)}"
print(f"βœ— {error_msg}")
return None, error_msg
@space_context(duration=300)
def enhance_prompt(self, prompt: str, # use_distilled: bool
) -> Tuple[str, str]:
"""Enhance a prompt using the reprompt model."""
try:
torch.cuda.empty_cache()
# The pipeline is loaded at module import; bail out if that failed
if self.pipeline is None:
return prompt, "Pipeline not loaded. Please try again."
self.pipeline.to('cpu')
if hasattr(self.pipeline, '_refiner_pipeline'):
self.pipeline.refiner_pipeline.to('cpu')
# Use reprompt model from the main pipeline
enhanced_prompt = self.pipeline.reprompt_model.predict(prompt)
self.print_peak_memory()
return enhanced_prompt, "Prompt enhanced successfully!"
except Exception as e:
error_msg = f"Error enhancing prompt: {str(e)}"
print(f"βœ— {error_msg}")
return prompt, error_msg
@space_context(duration=300)
def refine_image(self,
image: Image.Image,
prompt: str,
negative_prompt: str,
width: int,
height: int,
num_inference_steps: int,
guidance_scale: float,
seed: int) -> Tuple[Optional[Image.Image], str]:
"""Refine an image using the refiner pipeline."""
try:
if image is None:
return None, "Please upload an image to refine."
torch.cuda.empty_cache()
# Resize image to target dimensions if needed
if image.size != (width, height):
image = image.resize((width, height), Image.Resampling.LANCZOS)
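# Swap devices: main pipeline to CPU, refiner to GPU for the refinement pass.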
self.pipeline.to('cpu')
self.pipeline.refiner_pipeline.to('cuda')
if seed == -1:
import random
seed = random.randint(100000, 999999)
# Use refiner from the main pipeline
refined_image = self.pipeline.refiner_pipeline(
image=image,
prompt=prompt,
negative_prompt=negative_prompt,
width=width,
height=height,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
seed=seed
)
self.print_peak_memory()
return refined_image, "Image refined successfully!"
except Exception as e:
error_msg = f"Error refining image: {str(e)}"
print(f"βœ— {error_msg}")
return None, error_msg
def download_single_checkpoint(self, checkpoint_name: str) -> Tuple[bool, str]:
"""Download a single checkpoint."""
try:
success, message = self.downloader.download_checkpoint(checkpoint_name)
return success, message
except Exception as e:
return False, f"Download error: {str(e)}"
def download_all_checkpoints(self) -> Tuple[bool, str, Dict[str, Any]]:
"""Download all required checkpoints."""
try:
success, message, results = self.downloader.download_all_checkpoints()
return success, message, results
except Exception as e:
return False, f"Download error: {str(e)}", {}
def create_interface(auto_load: bool = True, use_distilled: bool = False, device: str = "cuda"):
"""Create the Gradio interface."""
app = HunyuanImageApp(auto_load=auto_load, use_distilled=use_distilled, device=device)
# Custom CSS for better styling with dark mode support
css = """
.gradio-container {
max-width: 1200px !important;
margin: auto !important;
}
.tab-nav {
background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
border-radius: 10px;
padding: 10px;
margin-bottom: 20px;
}
.model-info {
background: var(--background-fill-secondary);
border: 1px solid var(--border-color-primary);
border-radius: 8px;
padding: 15px;
margin-bottom: 20px;
color: var(--body-text-color);
}
.model-info h1, .model-info h2, .model-info h3 {
color: var(--body-text-color) !important;
}
.model-info p, .model-info li {
color: var(--body-text-color) !important;
}
.model-info strong {
color: var(--body-text-color) !important;
}
"""
with gr.Blocks(css=css, title="HunyuanImage Pipeline", theme=gr.themes.Soft()) as demo:
gr.Markdown(
"""
# 🎨 HunyuanImage 2.1 Pipeline
**HunyuanImage-2.1: An Efficient Diffusion Model for High-Resolution (2K) Text-to-Image Generation**
This app provides three main functionalities:
1. **Text-to-Image Generation**: Generate high-quality images from text prompts
2. **Prompt Enhancement**: Improve your prompts using MLLM reprompting
3. **Image Refinement**: Enhance existing images with the refiner model
""",
elem_classes="model-info"
)
with gr.Tabs():
# Tab 1: Text-to-Image Generation
with gr.Tab("πŸ–ΌοΈ Text-to-Image Generation"):
with gr.Row():
with gr.Column(scale=1):
gr.Markdown("### Generation Settings")
gr.Markdown("**Model**: HunyuanImage v2.1 (Non-distilled)")
# use_distilled = gr.Checkbox(
# label="Use Distilled Model",
# value=False,
# info="Faster generation with slightly lower quality"
# )
use_distilled = False
prompt = gr.Textbox(
label="Prompt",
placeholder="",
lines=3,
value="A cute, cartoon-style anthropomorphic penguin plush toy with fluffy fur, standing in a painting studio, wearing a red knitted scarf and a red beret with the word β€œTencent” on it, holding a paintbrush with a focused expression as it paints an oil painting of the Mona Lisa, rendered in a photorealistic photographic style."
)
negative_prompt = gr.Textbox(
label="Negative Prompt",
placeholder="",
lines=2,
value=""
)
# Predefined aspect ratios
aspect_ratios = [
("16:9 (2560Γ—1536)", "16:9"),
("4:3 (2304Γ—1792)", "4:3"),
("1:1 (2048Γ—2048)", "1:1"),
("3:4 (1792Γ—2304)", "3:4"),
("9:16 (1536Γ—2560)", "9:16")
]
aspect_ratio = gr.Radio(
choices=aspect_ratios,
value="1:1",
label="Aspect Ratio",
info="Select the aspect ratio for image generation"
)
# Hidden width and height inputs that get updated based on aspect ratio
width = gr.Number(value=2048, visible=False)
height = gr.Number(value=2048, visible=False)
with gr.Row():
num_inference_steps = gr.Slider(
minimum=10, maximum=100, step=5, value=50,
label="Inference Steps", info="More steps = better quality, slower generation"
)
guidance_scale = gr.Slider(
minimum=1.0, maximum=10.0, step=0.1, value=3.5,
label="Guidance Scale", info="How closely to follow the prompt"
)
with gr.Row():
seed = gr.Number(
label="Seed", value=-1, precision=0,
info="Random seed for reproducibility. (-1 for random seed)"
)
use_reprompt = gr.Checkbox(
label="Use Reprompt", value=True,
info="Enhance prompt automatically"
)
use_refiner = gr.Checkbox(
label="Use Refiner", value=True,
info="Apply refiner after generation ",
interactive=True
)
generate_btn = gr.Button("🎨 Generate Image", variant="primary", size="lg")
with gr.Column(scale=1):
gr.Markdown("### Generated Image")
generated_image = gr.Image(
label="Generated Image",
format="png",
show_download_button=True,
type="pil",
height=600
)
generation_status = gr.Textbox(
label="Status",
interactive=False,
value="Ready to generate"
)
# Tab 2: Prompt Enhancement
with gr.Tab("✨ Prompt Enhancement"):
with gr.Row():
with gr.Column(scale=1):
gr.Markdown("### Prompt Enhancement Settings")
gr.Markdown("**Model**: HunyuanImage v2.1 Reprompt Model")
# enhance_use_distilled = gr.Checkbox(
# label="Use Distilled Model",
# value=False,
# info="For loading the reprompt model"
# )
enhance_use_distilled = False
original_prompt = gr.Textbox(
label="Original Prompt",
placeholder="A cat sitting on a table",
lines=4,
value="A cat sitting on a table"
)
enhance_btn = gr.Button("✨ Enhance Prompt", variant="primary", size="lg")
with gr.Column(scale=1):
gr.Markdown("### Enhanced Prompt")
enhanced_prompt = gr.Textbox(
label="Enhanced Prompt",
lines=6,
interactive=False
)
enhancement_status = gr.Textbox(
label="Status",
interactive=False,
value="Ready to enhance"
)
# Tab 3: Image Refinement
with gr.Tab("🔧 Image Refinement"):
with gr.Row():
with gr.Column(scale=1):
gr.Markdown("### Refinement Settings")
gr.Markdown("**Model**: HunyuanImage v2.1 Refiner")
input_image = gr.Image(
label="Input Image",
type="pil",
height=300
)
refine_prompt = gr.Textbox(
label="Refinement Prompt",
placeholder="Make the image more detailed and high quality",
lines=2,
value="Make the image more detailed and high quality"
)
refine_negative_prompt = gr.Textbox(
label="Negative Prompt",
placeholder="",
lines=2,
value=""
)
with gr.Row():
refine_width = gr.Slider(
minimum=512, maximum=2048, step=64, value=2048,
label="Width", info="Output width"
)
refine_height = gr.Slider(
minimum=512, maximum=2048, step=64, value=2048,
label="Height", info="Output height"
)
with gr.Row():
refine_steps = gr.Slider(
minimum=1, maximum=20, step=1, value=4,
label="Refinement Steps", info="More steps = more refinement"
)
refine_guidance = gr.Slider(
minimum=1.0, maximum=10.0, step=0.1, value=3.5,
label="Guidance Scale", info="How strongly to follow the prompt"
)
refine_seed = gr.Number(
label="Seed", value=-1, precision=0,
info="Random seed for reproducibility. (-1 for random seed)"
)
refine_btn = gr.Button("🔧 Refine Image", variant="primary", size="lg")
with gr.Column(scale=1):
gr.Markdown("### Refined Image")
refined_image = gr.Image(
label="Refined Image",
type="pil",
height=600
)
refinement_status = gr.Textbox(
label="Status",
interactive=False,
value="Ready to refine"
)
# Event handlers
# Update width and height when aspect ratio changes
aspect_ratio.change(
fn=app.update_resolution,
inputs=[aspect_ratio],
outputs=[width, height]
)
generate_btn.click(
fn=app.generate_image,
inputs=[
prompt, negative_prompt, width, height, num_inference_steps,
guidance_scale, seed, use_reprompt, use_refiner # , use_distilled
],
outputs=[generated_image, generation_status]
)
enhance_btn.click(
fn=app.enhance_prompt,
inputs=[original_prompt],
outputs=[enhanced_prompt, enhancement_status]
)
refine_btn.click(
fn=app.refine_image,
inputs=[
input_image, refine_prompt, refine_negative_prompt,
refine_width, refine_height, refine_steps, refine_guidance, refine_seed
],
outputs=[refined_image, refinement_status]
)
# Additional info
gr.Markdown(
"""
### 📝 Usage Tips
**Text-to-Image Generation:**
- Use descriptive prompts with specific details
- Adjust guidance scale: higher values follow prompts more closely
- More inference steps generally produce better quality
- Enable reprompt for automatic prompt enhancement
- Enable refiner for additional quality improvement
**Prompt Enhancement:**
- Enter your basic prompt idea
- The AI will enhance it with better structure and details
- Enhanced prompts often produce better results
**Image Refinement:**
- Upload any image you want to improve
- Describe what improvements you want in the refinement prompt
- The refiner will enhance details and quality
- Works best with images generated by HunyuanImage
""",
elem_classes="model-info"
)
return demo
if __name__ == "__main__":
import argparse
# Parse command line arguments
parser = argparse.ArgumentParser(description="Launch HunyuanImage Gradio App")
parser.add_argument("--no-auto-load", action="store_true", help="Disable auto-loading pipeline on startup")
parser.add_argument("--use-distilled", action="store_true", help="Use distilled model")
parser.add_argument("--device", type=str, default="cuda", help="Device to use (cuda/cpu)")
parser.add_argument("--port", type=int, default=8081, help="Port to run the app on")
parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to")
args = parser.parse_args()
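# Example invocations (sketch; assumes this file is saved as app.py):
#   python app.py                              # auto-load, non-distilled, cuda
#   python app.py --no-auto-load --device cpu --port 7860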
# Create and launch the interface
auto_load = not args.no_auto_load
demo = create_interface(auto_load=auto_load, use_distilled=args.use_distilled, device=args.device)
print("πŸš€ Starting HunyuanImage Gradio App...")
print(f"πŸ“± The app will be available at: http://{args.host}:{args.port}")
print(f"πŸ”§ Auto-load pipeline: {'Yes' if auto_load else 'No'}")
print(f"🎯 Model type: {'Distilled' if args.use_distilled else 'Non-distilled'}")
print(f"πŸ’» Device: {args.device}")
print("⚠️ Make sure you have the required model checkpoints downloaded!")
demo.launch(
server_name=args.host,
# server_port=args.port,
share=False,
show_error=True,
quiet=False,
max_threads=1, # Default: sequential processing (recommended for GPU apps)
# max_threads=4, # Enable parallel processing (requires more GPU memory)
)