File size: 5,676 Bytes
f56ede2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
import torch
import argparse
import os
import sys
# Add the project root directory to the Python path
sys.path.append(os.path.abspath(os.path.dirname(__file__)))
from inference.config_loader import load_config, find_config_by_model_id
from inference.model_initializer import (
initialize_controlnet,
initialize_pipeline,
initialize_controlnet_detector
)
from inference.device_manager import setup_device
from inference.image_processor import load_input_image, detect_poses
from inference.image_generator import generate_images, save_images
# Module-level model cache. These are populated lazily on the first call to
# infer() (which declares them `global`) so that repeated invocations reuse
# the already-loaded detector, ControlNet, and diffusion pipeline instead of
# re-initializing them each time.
# NOTE: the original code had a `global` statement here, which is a no-op at
# module scope — plain assignments are all that is needed.
controlnet_detector = None
controlnet = None
pipe = None
def infer(
    config_path,
    input_image,
    image_url,
    prompt,
    negative_prompt,
    num_steps,
    seed,
    width,
    height,
    guidance_scale,
    controlnet_conditioning_scale,
    output_dir=None,
    use_prompt_as_output_name=None,
    save_output=False
):
    """Generate pose-conditioned images with Stable Diffusion + ControlNet.

    Loads the model configuration, lazily initializes the pose detector,
    ControlNet, and diffusion pipeline (cached in module globals across
    calls), detects poses in the input image, and generates one image per
    detected pose.

    Args:
        config_path: Path to the model-checkpoint YAML configuration.
        input_image: Local path to the input image (or None).
        image_url: URL of the input image (or None).
        prompt: Text prompt applied to every generated image.
        negative_prompt: Negative prompt applied to every generated image.
        num_steps: Number of diffusion inference steps.
        seed: Base random seed; pose i uses seed + i.
        width: Output image width in pixels.
        height: Output image height in pixels.
        guidance_scale: Classifier-free guidance scale.
        controlnet_conditioning_scale: ControlNet conditioning strength.
        output_dir: Directory for saved images (only used when save_output).
        use_prompt_as_output_name: Whether saved filenames embed the prompt.
        save_output: When True, persist the generated images to output_dir.

    Returns:
        The list of generated images from generate_images().
    """
    global controlnet_detector, controlnet, pipe

    configs = load_config(config_path)

    # Lazy one-time initialization: later calls reuse the cached models.
    if any(model is None for model in (controlnet_detector, controlnet, pipe)):
        detector_cfg = find_config_by_model_id(configs, "lllyasviel/ControlNet")
        controlnet_cfg = find_config_by_model_id(
            configs, "danhtran2mind/Stable-Diffusion-2.1-Openpose-ControlNet")
        pipeline_cfg = find_config_by_model_id(
            configs, "stabilityai/stable-diffusion-2-1")
        controlnet_detector = initialize_controlnet_detector(detector_cfg)
        controlnet = initialize_controlnet(controlnet_cfg)
        pipe = initialize_pipeline(controlnet, pipeline_cfg)

    # Move the pipeline to the appropriate device (called for its side effect).
    setup_device(pipe)

    # Fetch the source image and extract one pose map per detected person.
    source_image = load_input_image(input_image, image_url)
    poses = detect_poses(controlnet_detector, source_image)

    # One CPU generator per pose, each with a distinct deterministic seed.
    num_outputs = len(poses)
    rngs = [
        torch.Generator(device="cpu").manual_seed(seed + offset)
        for offset in range(num_outputs)
    ]

    images = generate_images(
        pipe,
        [prompt] * num_outputs,
        poses,
        rngs,
        [negative_prompt] * num_outputs,
        num_steps,
        guidance_scale,
        controlnet_conditioning_scale,
        width,
        height
    )

    if save_output:
        save_images(images, output_dir, prompt, use_prompt_as_output_name)
    return images
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="ControlNet image generation with pose detection")

    # Exactly one image source is required: a local file path or a URL.
    image_group = parser.add_mutually_exclusive_group(required=True)
    # FIX: the original help text claimed a default of tests/test_data/yoga1.jpg,
    # but the default is None and the group is required, so no default applies.
    image_group.add_argument("--input_image", type=str, default=None,
                             help="Path to local input image (e.g., tests/test_data/yoga1.jpg)")
    image_group.add_argument("--image_url", type=str, default=None,
                             help="URL of input image (e.g., https://huggingface.co/datasets/YiYiXu/controlnet-testing/resolve/main/yoga1.jpeg)")

    # Model / generation configuration.
    parser.add_argument("--config_path", type=str, default="configs/model_ckpts.yaml",
                        help="Path to configuration YAML file")
    parser.add_argument("--prompt", type=str, default="a man is doing yoga",
                        help="Text prompt for image generation")
    parser.add_argument("--negative_prompt", type=str,
                        default="monochrome, lowres, bad anatomy, worst quality, low quality",
                        help="Negative prompt for image generation")
    parser.add_argument("--num_steps", type=int, default=20,
                        help="Number of inference steps")
    parser.add_argument("--seed", type=int, default=2,
                        help="Random seed for generation")
    parser.add_argument("--width", type=int, default=512,
                        help="Width of the generated image")
    parser.add_argument("--height", type=int, default=512,
                        help="Height of the generated image")
    parser.add_argument("--guidance_scale", type=float, default=7.5,
                        help="Guidance scale for prompt adherence")
    parser.add_argument("--controlnet_conditioning_scale", type=float, default=1.0,
                        help="ControlNet conditioning scale")

    # Output options.
    parser.add_argument("--output_dir", type=str, default="tests/test_data",
                        help="Directory to save generated images")
    parser.add_argument("--use_prompt_as_output_name", action="store_true",
                        help="Use prompt as part of output image filename")
    parser.add_argument("--save_output", action="store_true",
                        help="Save generated images to output directory")

    args = parser.parse_args()
    infer(
        config_path=args.config_path,
        input_image=args.input_image,
        image_url=args.image_url,
        prompt=args.prompt,
        negative_prompt=args.negative_prompt,
        num_steps=args.num_steps,
        seed=args.seed,
        width=args.width,
        height=args.height,
        guidance_scale=args.guidance_scale,
        controlnet_conditioning_scale=args.controlnet_conditioning_scale,
        output_dir=args.output_dir,
        use_prompt_as_output_name=args.use_prompt_as_output_name,
        save_output=args.save_output
    )