File size: 5,676 Bytes
f56ede2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import torch
import argparse
import os
import sys
# Add the project root directory to the Python path
sys.path.append(os.path.abspath(os.path.dirname(__file__)))

from inference.config_loader import load_config, find_config_by_model_id
from inference.model_initializer import (
    initialize_controlnet, 
    initialize_pipeline, 
    initialize_controlnet_detector
)
from inference.device_manager import setup_device
from inference.image_processor import load_input_image, detect_poses
from inference.image_generator import generate_images, save_images

# Module-level cache for lazily-initialized models.
# Populated on the first call to infer() and reused by subsequent calls so
# the heavyweight pipeline is loaded only once per process.
# NOTE: a bare `global` statement at module scope is a no-op (it only has
# meaning inside a function body), so it was removed.
controlnet_detector = None
controlnet = None
pipe = None

def infer(
    config_path,
    input_image,
    image_url,
    prompt,
    negative_prompt,
    num_steps,
    seed,
    width,
    height,
    guidance_scale,
    controlnet_conditioning_scale,
    output_dir=None,
    use_prompt_as_output_name=None,
    save_output=False
):
    """Run pose-conditioned ControlNet image generation.

    Models are initialized lazily on the first call and cached in the
    module-level globals ``controlnet_detector``, ``controlnet`` and
    ``pipe`` so repeated calls reuse the loaded pipeline.

    Args:
        config_path: Path to the YAML file describing model checkpoints.
        input_image: Local path to the input image (exactly one of
            ``input_image`` / ``image_url`` should be provided — TODO confirm
            against ``load_input_image``).
        image_url: URL of the input image.
        prompt: Text prompt, applied to every detected pose.
        negative_prompt: Negative prompt, applied to every detected pose.
        num_steps: Number of diffusion inference steps.
        seed: Base random seed; the generator for pose ``i`` uses ``seed + i``.
        width: Output image width in pixels.
        height: Output image height in pixels.
        guidance_scale: Classifier-free guidance scale.
        controlnet_conditioning_scale: Strength of the ControlNet conditioning.
        output_dir: Directory for saved images (only used when ``save_output``).
        use_prompt_as_output_name: If truthy, embed the prompt in output
            filenames (forwarded to ``save_images``).
        save_output: When True, write the generated images to ``output_dir``.

    Returns:
        The images produced by ``generate_images`` (one per detected pose).
    """
    global controlnet_detector, controlnet, pipe

    # Load checkpoint configuration.
    configs = load_config(config_path)

    # Initialize models only if they are not already cached from a prior call.
    if controlnet_detector is None or controlnet is None or pipe is None:
        controlnet_detector_config = find_config_by_model_id(
            configs, "lllyasviel/ControlNet")
        controlnet_config = find_config_by_model_id(
            configs, "danhtran2mind/Stable-Diffusion-2.1-Openpose-ControlNet")
        pipeline_config = find_config_by_model_id(
            configs, "stabilityai/stable-diffusion-2-1")

        controlnet_detector = initialize_controlnet_detector(controlnet_detector_config)
        controlnet = initialize_controlnet(controlnet_config)
        pipe = initialize_pipeline(controlnet, pipeline_config)

    # Move the pipeline to the appropriate device.
    device = setup_device(pipe)

    # Load the input image and detect poses to condition generation on.
    demo_image = load_input_image(input_image, image_url)
    poses = detect_poses(controlnet_detector, demo_image)

    # One CPU generator per pose, seeded deterministically (seed + index) so
    # each pose gets a distinct but reproducible noise sample.
    generators = [torch.Generator(device="cpu").manual_seed(seed + i)
                  for i in range(len(poses))]
    output_images = generate_images(
        pipe,
        [prompt] * len(generators),
        poses,
        generators,
        [negative_prompt] * len(generators),
        num_steps,
        guidance_scale,
        controlnet_conditioning_scale,
        width,
        height
    )

    # Persist results only when explicitly requested.
    if save_output:
        save_images(output_images, output_dir, prompt, use_prompt_as_output_name)

    return output_images

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="ControlNet image generation with pose detection")
    # Exactly one image source must be supplied: a local path or a URL.
    image_group = parser.add_mutually_exclusive_group(required=True)
    # NOTE: the group is required, so there is no default image — the previous
    # help text falsely advertised "tests/test_data/yoga1.jpg" as a default.
    image_group.add_argument("--input_image", type=str, default=None,
                             help="Path to local input image (e.g., tests/test_data/yoga1.jpg)")
    image_group.add_argument("--image_url", type=str, default=None,
                             help="URL of input image (e.g., https://huggingface.co/datasets/YiYiXu/controlnet-testing/resolve/main/yoga1.jpeg)")

    parser.add_argument("--config_path", type=str, default="configs/model_ckpts.yaml",
                        help="Path to configuration YAML file")
    parser.add_argument("--prompt", type=str, default="a man is doing yoga",
                        help="Text prompt for image generation")
    parser.add_argument("--negative_prompt", type=str,
                        default="monochrome, lowres, bad anatomy, worst quality, low quality",
                        help="Negative prompt for image generation")
    parser.add_argument("--num_steps", type=int, default=20,
                        help="Number of inference steps")
    parser.add_argument("--seed", type=int, default=2,
                        help="Random seed for generation")
    parser.add_argument("--width", type=int, default=512,
                        help="Width of the generated image")
    parser.add_argument("--height", type=int, default=512,
                        help="Height of the generated image")
    parser.add_argument("--guidance_scale", type=float, default=7.5,
                        help="Guidance scale for prompt adherence")
    parser.add_argument("--controlnet_conditioning_scale", type=float, default=1.0,
                        help="ControlNet conditioning scale")
    parser.add_argument("--output_dir", type=str, default="tests/test_data",
                        help="Directory to save generated images")
    parser.add_argument("--use_prompt_as_output_name", action="store_true",
                        help="Use prompt as part of output image filename")
    parser.add_argument("--save_output", action="store_true",
                        help="Save generated images to output directory")

    args = parser.parse_args()
    # Forward every CLI option to infer() by keyword for clarity.
    infer(
        config_path=args.config_path,
        input_image=args.input_image,
        image_url=args.image_url,
        prompt=args.prompt,
        negative_prompt=args.negative_prompt,
        num_steps=args.num_steps,
        seed=args.seed,
        width=args.width,
        height=args.height,
        guidance_scale=args.guidance_scale,
        controlnet_conditioning_scale=args.controlnet_conditioning_scale,
        output_dir=args.output_dir,
        use_prompt_as_output_name=args.use_prompt_as_output_name,
        save_output=args.save_output
    )