import argparse import os import yaml from main import generate_images, load_unet_controller from unet import utils import torch import queue import threading from tqdm import tqdm # Import tqdm def main_ben(unet_controller, pipe, save_dir, id_prompt, frame_prompt_list, seed, window_length): unet_controller.ipca_index = -1 unet_controller.ipca_time_step = -1 # Ensure each process uses its own assigned device os.makedirs(save_dir, exist_ok=True) images, story_image = generate_images(unet_controller, pipe, id_prompt, frame_prompt_list, save_dir, window_length, seed, verbose=False) return images, story_image def process_instance(unet_controller, pipe, instance): # Unpack instance and execute task save_dir, id_prompt, frame_prompt_list, seed, window_length = instance return main_ben(unet_controller, pipe, save_dir, id_prompt, frame_prompt_list, seed, window_length) def worker(device, unet_controller, pipe, task_queue, pbar): # Process tasks until queue is empty while not task_queue.empty(): instance = task_queue.get() if instance is None: # If None is encountered, stop the worker break # Process the instance result = process_instance(unet_controller, pipe, instance) # Log the completion print(f"Finished processing {instance[1]}") # Log the processed instance (id_prompt) task_queue.task_done() # Mark the task as done pbar.update(1) # Update the progress bar def main(): parser = argparse.ArgumentParser(description="Calculate image similarities using DreamSim or CLIP.") parser.add_argument('--device', type=str, choices=['cuda:0', 'cuda:1', 'cuda'], default='cuda') parser.add_argument('--save_dir', type=str,) parser.add_argument('--benchmark_path', type=str,) parser.add_argument('--model_path', type=str, default='stabilityai/stable-diffusion-xl-base-1.0', help='Path to the model') parser.add_argument('--precision', type=str, choices=["fp16", "fp32"], default="fp16", help='Model precision') parser.add_argument('--window_length', type=int, default=10, help='Window length for story generation') parser.add_argument('--num_gpus', type=int, default=2, help='Number of GPUs to use') parser.add_argument('--fix_seed', type=int, default=42, help='-1 for random seed') args = parser.parse_args() # Create a list of devices devices = [f'cuda:{i}' for i in range(args.num_gpus)] # List of device names if args.num_gpus == 1: devices = [args.device] # Load unet_controllers and pipes for each device unet_controllers = {} pipes = {} for device in devices: pipe, _ = utils.load_pipe_from_path(args.model_path, device, torch.float16 if args.precision == "fp16" else torch.float32, args.precision) unet_controller = load_unet_controller(pipe, device) unet_controller.Save_story_image = False unet_controller.Prompt_embeds_mode = "svr-eot" # unet_controller.Is_freeu_enabled = True unet_controllers[device] = unet_controller pipes[device] = pipe # Load the benchmark data with open(os.path.expanduser(args.benchmark_path), 'r') as file: data = yaml.safe_load(file) instances = [] for subject_domain, subject_domain_instances in data.items(): for index, instance in enumerate(subject_domain_instances): id_prompt = f'{instance["style"]} {instance["subject"]}' frame_prompt_list = instance["settings"] save_dir = os.path.join(args.save_dir, f"{subject_domain}_{index}") if args.fix_seed != -1: seed = args.fix_seed else: import random seed = random.randint(0, 2**32 - 1) instances.append((save_dir, id_prompt, frame_prompt_list, seed, args.window_length)) # Create a task queue and populate it with instances task_queue = queue.Queue() for instance in instances: task_queue.put(instance) # Initialize tqdm progress bar pbar = tqdm(total=len(instances)) # Create threads for each device to process instances threads = [] for device in devices: unet_controller = unet_controllers[device] pipe = pipes[device] thread = threading.Thread(target=worker, args=(device, unet_controller, pipe, task_queue, pbar)) threads.append(thread) thread.start() import time time.sleep(1) # Wait for 1 second before starting the next thread # Wait for all threads to finish for thread in threads: thread.join() # Close the progress bar pbar.close() if __name__ == "__main__": main()