from PIL import Image
from datasets import load_dataset
from torchvision import transforms
import random
import torch
import os
from .pipeline_flux_kontext_control import PREFERRED_KONTEXT_RESOLUTIONS
import numpy as np
from src.condition.edge_extraction import (
    CannyDetector,
    PidiNetDetector,
    TEDDetector,
    LineartStandardDetector,
    HEDdetector,
    AnyLinePreprocessor,
    LineartDetector,
    InformativeDetector,
)

Image.MAX_IMAGE_PIXELS = None


def multiple_16(num: float):
    """Round to the nearest multiple of 16."""
    return int(round(num / 16) * 16)


def load_image_safely(image_path, size, root="/mnt/robby-b1/common/datasets/"):
    """Load an RGB image; on failure, log the path and return a white placeholder."""
    image_path = os.path.join(root, image_path)
    try:
        return Image.open(image_path).convert("RGB")
    except Exception as e:
        print(f"file error: {image_path} ({e})")
        with open("failed_images.txt", "a") as f:
            f.write(f"{image_path}\n")
        return Image.new("RGB", (size, size), (255, 255, 255))


def choose_kontext_resolution_from_wh(width: int, height: int):
    """Pick the preferred Kontext resolution whose aspect ratio is closest to the input's."""
    aspect_ratio = width / max(1, height)
    _, best_w, best_h = min(
        (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
    )
    return best_w, best_h


class EdgeExtractorManager:
    """Process-wide singleton that lazily builds and caches the edge detectors."""

    _instance = None
    _initialized = False

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(EdgeExtractorManager, cls).__new__(cls)
        return cls._instance

    def __init__(self):
        if not self._initialized:
            self.edge_extractors = None
            self.device = None
            self._initialized = True

    def set_device(self, device):
        self.device = device

    def get_edge_extractors(self, device=None):
        # Force initialization on CPU to avoid triggering CUDA initialization
        # inside DataLoader worker processes. Note: this intentionally overrides
        # any requested device with "cpu".
        current_device = "cpu"
        if device is not None:
            self.set_device(current_device)
        if self.edge_extractors is None or len(self.edge_extractors) == 0:
            self.edge_extractors = [
                ("canny", CannyDetector()),
                ("pidinet", PidiNetDetector.from_pretrained().to(current_device)),
                ("ted", TEDDetector.from_pretrained().to(current_device)),
                # ("lineart_standard", LineartStandardDetector()),
                ("hed", HEDdetector.from_pretrained().to(current_device)),
                ("anyline", AnyLinePreprocessor.from_pretrained().to(current_device)),
                # ("lineart", LineartDetector.from_pretrained().to(current_device)),
                ("informative", InformativeDetector.from_pretrained().to(current_device)),
            ]
        return self.edge_extractors


edge_extractor_manager = EdgeExtractorManager()
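
# Illustrative use of the shared manager (a sketch, mirroring the calls made in
# preprocess_train below). Every EdgeExtractorManager() construction returns the
# same cached instance, so the detectors are only built once per process:
#
#   assert EdgeExtractorManager() is edge_extractor_manager
#   name, extractor = random.choice(edge_extractor_manager.get_edge_extractors())
#   edge = extractor(np.array(img))  # edge map; binarized later in preprocess_train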

def collate_fn(examples):
    if examples[0].get("cond_pixel_values") is not None:
        cond_pixel_values = torch.stack([example["cond_pixel_values"] for example in examples])
        cond_pixel_values = cond_pixel_values.to(memory_format=torch.contiguous_format).float()
    else:
        cond_pixel_values = None
    source_pixel_values = None
    target_pixel_values = torch.stack([example["pixel_values"] for example in examples])
    target_pixel_values = target_pixel_values.to(memory_format=torch.contiguous_format).float()
    token_ids_clip = torch.stack([example["token_ids_clip"] for example in examples])
    token_ids_t5 = torch.stack([example["token_ids_t5"] for example in examples])
    return {
        "cond_pixel_values": cond_pixel_values,
        "source_pixel_values": source_pixel_values,
        "pixel_values": target_pixel_values,
        "text_ids_1": token_ids_clip,
        "text_ids_2": token_ids_t5,
    }


def make_train_dataset_inpaint_mask(args, tokenizers, accelerator=None):
    # Load the CSV dataset: three columns, where column 0 is the image's
    # relative path and column 2 is the caption.
    if args.train_data_dir is not None:
        dataset = load_dataset("csv", data_files=args.train_data_dir)

    # Column-name compatibility: use column 0 as the image path and column 2
    # (or the last column, if fewer than three exist) as the caption.
    column_names = dataset["train"].column_names
    image_col = column_names[0]
    caption_col = column_names[2] if len(column_names) >= 3 else column_names[-1]

    size = args.cond_size

    # Device setup (under distributed training, detectors can be moved to the
    # matching GPU).
    if accelerator is not None:
        device = accelerator.device
        edge_extractor_manager.set_device(device)
    else:
        device = "cpu"

    # Transforms
    to_tensor_and_norm = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])
    # Kept consistent with jsonl_datasets_edge.py: Resize -> CenterCrop -> ToTensor -> Normalize
    cond_train_transforms = transforms.Compose([
        transforms.Resize((size, size), interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.CenterCrop((size, size)),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])

    tokenizer_clip = tokenizers[0]
    tokenizer_t5 = tokenizers[1]

    def tokenize_prompt_clip_t5(examples):
        captions_raw = examples[caption_col]
        captions = []
        for c in captions_raw:
            if isinstance(c, str):
                # Drop the caption 25% of the time (unconditional samples,
                # e.g. for classifier-free guidance training).
                captions.append("" if random.random() < 0.25 else c)
            else:
                captions.append("")
        text_inputs_clip = tokenizer_clip(
            captions,
            padding="max_length",
            max_length=77,
            truncation=True,
            return_length=False,
            return_overflowing_tokens=False,
            return_tensors="pt",
        )
        text_input_ids_1 = text_inputs_clip.input_ids
        text_inputs_t5 = tokenizer_t5(
            captions,
            padding="max_length",
            max_length=128,
            truncation=True,
            return_length=False,
            return_overflowing_tokens=False,
            return_tensors="pt",
        )
        text_input_ids_2 = text_inputs_t5.input_ids
        return text_input_ids_1, text_input_ids_2

    def preprocess_train(examples):
        batch = {}
        img_paths = examples[image_col]
        target_tensors = []
        cond_tensors = []
        for p in img_paths:
            # load_image_safely joins the relative path with its dataset root
            # and already returns an RGB image.
            img = load_image_safely(p, size)

            # Resize the target to the closest preferred Kontext resolution.
            w, h = img.size
            best_w, best_h = choose_kontext_resolution_from_wh(w, h)
            img_rs = img.resize((best_w, best_h), resample=Image.BILINEAR)
            target_tensor = to_tensor_and_norm(img_rs)

            # Build the edge condition with a randomly chosen extractor.
            extractor_name, extractor = random.choice(edge_extractor_manager.get_edge_extractors())
            img_np = np.array(img)
            if extractor_name == "informative":
                edge = extractor(img_np, style="contour")
            else:
                edge = extractor(img_np)
            # TED output is binarized with a higher threshold than the other detectors.
            th = 128 if extractor_name == "ted" else 32
            edge_np = np.array(edge) if isinstance(edge, Image.Image) else edge
            if edge_np.ndim == 3:
                edge_np = edge_np[..., 0]
            edge_bin = (edge_np > th).astype(np.float32)
            edge_pil = Image.fromarray((edge_bin * 255).astype(np.uint8))
            edge_tensor = cond_train_transforms(edge_pil)
            # Single-channel edge map -> 3 channels to match the RGB target.
            edge_tensor = edge_tensor.repeat(3, 1, 1)

            target_tensors.append(target_tensor)
            cond_tensors.append(edge_tensor)

        batch["pixel_values"] = target_tensors
        batch["cond_pixel_values"] = cond_tensors
        batch["token_ids_clip"], batch["token_ids_t5"] = tokenize_prompt_clip_t5(examples)
        return batch

    if accelerator is not None:
        with accelerator.main_process_first():
            train_dataset = dataset["train"].with_transform(preprocess_train)
    else:
        train_dataset = dataset["train"].with_transform(preprocess_train)

    return train_dataset
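

# ---------------------------------------------------------------------------
# Smoke-test sketch (illustrative only). The tokenizer checkpoints, the CSV
# path, and cond_size below are placeholder assumptions, not values pinned by
# this repo. Because this file uses a relative import, it must be run as a
# module, e.g. `python -m src.data.csv_datasets_edge` (module path is
# hypothetical). batch_size=1 is used because targets are resized to
# per-aspect-ratio Kontext resolutions and cannot be stacked across mixed
# aspect ratios.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from argparse import Namespace

    from transformers import CLIPTokenizer, T5TokenizerFast

    args = Namespace(
        train_data_dir="data/train.csv",  # hypothetical CSV: image_path, <unused>, caption
        cond_size=512,                    # hypothetical condition resolution
    )
    tokenizers = [
        CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14"),
        T5TokenizerFast.from_pretrained("google/t5-v1_1-base"),
    ]
    train_dataset = make_train_dataset_inpaint_mask(args, tokenizers)
    loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=1, shuffle=True, collate_fn=collate_fn
    )
    batch = next(iter(loader))
    for k, v in batch.items():
        print(k, tuple(v.shape) if isinstance(v, torch.Tensor) else v)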