from PIL import Image
from torchvision import transforms
import torchvision.transforms.functional as TF
import random
import torch
import os
import cv2
import numpy as np
import json

Image.MAX_IMAGE_PIXELS = None

def collate_fn(examples):
    if examples[0].get("cond_pixel_values") is not None:
        cond_pixel_values = torch.stack([example["cond_pixel_values"] for example in examples])
        cond_pixel_values = cond_pixel_values.to(memory_format=torch.contiguous_format).float()
    else:
        cond_pixel_values = None
    if examples[0].get("source_pixel_values") is not None:
        source_pixel_values = torch.stack([example["source_pixel_values"] for example in examples])
        source_pixel_values = source_pixel_values.to(memory_format=torch.contiguous_format).float()
    else:
        source_pixel_values = None
    target_pixel_values = torch.stack([example["pixel_values"] for example in examples])
    target_pixel_values = target_pixel_values.to(memory_format=torch.contiguous_format).float()
    token_ids_clip = torch.stack([example["token_ids_clip"] for example in examples])
    token_ids_t5 = torch.stack([example["token_ids_t5"] for example in examples])
    mask_values = None
    if examples[0].get("mask_values") is not None:
        mask_values = torch.stack([example["mask_values"] for example in examples])
        mask_values = mask_values.to(memory_format=torch.contiguous_format).float()
    return {
        "cond_pixel_values": cond_pixel_values,
        "source_pixel_values": source_pixel_values,
        "pixel_values": target_pixel_values,
        "text_ids_1": token_ids_clip,
        "text_ids_2": token_ids_t5,
        "mask_values": mask_values,
    }
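
# Usage sketch: wiring collate_fn into a DataLoader. `dataset` stands for any
# of the datasets built below; the batch size is illustrative only.
#
#   loader = torch.utils.data.DataLoader(
#       dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
#   batch = next(iter(loader))
#   batch["pixel_values"].shape   # (4, 3, H, W), normalized to [-1, 1]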

def _resolve_jsonl(path_str: str):
    if path_str is None or str(path_str).strip() == "":
        raise ValueError("train_data_jsonl is empty. Please set --train_data_jsonl to a JSON/JSONL file or a folder.")
    if os.path.isdir(path_str):
        files = [
            os.path.join(path_str, f)
            for f in os.listdir(path_str)
            if f.lower().endswith((".jsonl", ".json"))
        ]
        if not files:
            raise ValueError(f"No .json or .jsonl files found under directory: {path_str}")
        return {"train": sorted(files)}
    if not os.path.exists(path_str):
        raise FileNotFoundError(f"train_data_jsonl not found: {path_str}")
    return {"train": [path_str]}

def _tokenize(tokenizers, caption: str):
    tokenizer_clip = tokenizers[0]
    tokenizer_t5 = tokenizers[1]
    text_inputs_clip = tokenizer_clip(
        [caption], padding="max_length", max_length=77, truncation=True, return_tensors="pt"
    )
    text_inputs_t5 = tokenizer_t5(
        [caption], padding="max_length", max_length=128, truncation=True, return_tensors="pt"
    )
    return text_inputs_clip.input_ids[0], text_inputs_t5.input_ids[0]
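
# Example (sketch): building the (CLIP, T5) tokenizer pair this helper expects.
# The model ids below are illustrative; any CLIP/T5 tokenizer pair works.
#
#   from transformers import CLIPTokenizer, T5TokenizerFast
#   tok_clip = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
#   tok_t5 = T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl")
#   ids_clip, ids_t5 = _tokenize((tok_clip, tok_t5), "a red chair on a beach")
#   ids_clip.shape, ids_t5.shape   # torch.Size([77]), torch.Size([128])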

def _prepend_caption(caption: str) -> str:
    """Prepend the instruction; with 20% probability keep only the instruction."""
    instruction = "Fill in the white region naturally and adapt the foreground into the background. Fix the perspective of the foreground object if necessary."
    if random.random() < 0.2:
        return instruction
    caption = caption or ""
    if caption.strip():
        return f"{instruction} {caption.strip()}"
    return instruction

def _color_augment(pil_img: Image.Image) -> Image.Image:
    brightness = random.uniform(0.8, 1.2)
    contrast = random.uniform(0.8, 1.2)
    saturation = random.uniform(0.8, 1.2)
    hue = random.uniform(-0.05, 0.05)
    img = TF.adjust_brightness(pil_img, brightness)
    img = TF.adjust_contrast(img, contrast)
    img = TF.adjust_saturation(img, saturation)
    img = TF.adjust_hue(img, hue)
    return img

def _dilate_mask(mask_bin: np.ndarray, min_px: int = 5, max_px: int = 100) -> np.ndarray:
    """Grow a binary mask by a random radius in [min_px, max_px]. Expects values {0, 1}."""
    radius = random.randint(min_px, max_px)
    if radius <= 0:
        return mask_bin.astype(np.uint8)
    ksize = 2 * radius + 1
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (ksize, ksize))
    grown = cv2.dilate(mask_bin.astype(np.uint8), kernel, iterations=1)
    return (grown > 0).astype(np.uint8)
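
# Sanity check (sketch): a single-pixel mask dilated with radius r becomes an
# elliptical disc of diameter 2*r + 1.
#
#   m = np.zeros((64, 64), dtype=np.uint8); m[32, 32] = 1
#   g = _dilate_mask(m, 10, 10)   # fixed radius of 10
#   g.sum()                       # roughly pi * 10**2 pixels set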

def _random_point_inside_mask(mask_bin: np.ndarray) -> tuple:
    ys, xs = np.where(mask_bin > 0)
    if len(xs) == 0:
        h, w = mask_bin.shape
        return w // 2, h // 2
    idx = random.randrange(len(xs))
    return int(xs[idx]), int(ys[idx])

def _bbox_containing_mask(mask_bin: np.ndarray, img_w: int, img_h: int) -> tuple:
    ys, xs = np.where(mask_bin > 0)
    if len(xs) == 0:
        return 0, 0, img_w - 1, img_h - 1
    x1, x2 = int(xs.min()), int(xs.max())
    y1, y2 = int(ys.min()), int(ys.max())
    # Pad each side by a random amount up to 25% of the shorter image side
    max_pad = int(0.25 * min(img_w, img_h))
    pad_x1 = random.randint(0, max_pad)
    pad_x2 = random.randint(0, max_pad)
    pad_y1 = random.randint(0, max_pad)
    pad_y2 = random.randint(0, max_pad)
    x1 = max(0, x1 - pad_x1)
    y1 = max(0, y1 - pad_y1)
    x2 = min(img_w - 1, x2 + pad_x2)
    y2 = min(img_h - 1, y2 + pad_y2)
    return x1, y1, x2, y2

def _constrained_random_mask(mask_bin: np.ndarray, image_h: int, image_w: int, aug_prob: float = 0.7) -> np.ndarray:
    """Generate a random mask whose boxes contain (or start inside) the given
    object mask, and whose brush strokes start inside it.

    Returns a binary 0/1 array of shape (H, W).
    """
    if random.random() >= aug_prob:
        return np.zeros((image_h, image_w), dtype=np.uint8)
    # Scale box/stroke sizes relative to a 1024px reference resolution
    ref_size = 1024
    scale_factor = max(1.0, min(image_h, image_w) / float(ref_size))
    out = np.zeros((image_h, image_w), dtype=np.uint8)
    # Choose exactly one augmentation: bbox OR stroke
    if random.random() < 0.2:
        # BBox augmentation: draw N boxes; the first box often contains the whole mask
        num_boxes = random.randint(1, 6)
        for b in range(num_boxes):
            if b == 0 and random.random() < 0.5:
                x1, y1, x2, y2 = _bbox_containing_mask(mask_bin, image_w, image_h)
            else:
                sx, sy = _random_point_inside_mask(mask_bin)
                max_w = int(500 * scale_factor)
                min_w = int(100 * scale_factor)
                bw = random.randint(max(1, min_w), max(2, max_w))
                bh = random.randint(max(1, min_w), max(2, max_w))
                x1 = max(0, sx - random.randint(0, bw))
                y1 = max(0, sy - random.randint(0, bh))
                x2 = min(image_w - 1, x1 + bw)
                y2 = min(image_h - 1, y1 + bh)
            out[y1:y2 + 1, x1:x2 + 1] = 1
    else:
        # Stroke augmentation: draw N random-walk strokes starting inside the mask
        num_strokes = random.randint(1, 6)
        for _ in range(num_strokes):
            num_points = random.randint(10, 30)
            stroke_width = random.randint(max(1, int(100 * scale_factor)), max(2, int(400 * scale_factor)))
            max_offset = max(1, int(100 * scale_factor))
            start_x, start_y = _random_point_inside_mask(mask_bin)
            px, py = start_x, start_y
            for _ in range(num_points):
                dx = random.randint(-max_offset, max_offset)
                dy = random.randint(-max_offset, max_offset)
                nx = int(np.clip(px + dx, 0, image_w - 1))
                ny = int(np.clip(py + dy, 0, image_h - 1))
                cv2.line(out, (px, py), (nx, ny), 1, stroke_width)
                px, py = nx, ny
    return (out > 0).astype(np.uint8)
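
# Quick visual check (sketch): generate a constrained mask around a toy blob
# and save it; purely illustrative.
#
#   m = np.zeros((256, 256), dtype=np.uint8); m[100:150, 100:150] = 1
#   rm = _constrained_random_mask(m, 256, 256, aug_prob=1.0)
#   Image.fromarray(rm * 255).save("rand_mask_preview.png")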

def make_placement_dataset_subjects(args, tokenizers, accelerator=None, base_dir=None):
    """
    Dataset for JSONL with fields:
      - generated_image_path: relative to base_dir (target image with object)
      - mask_path: relative to base_dir (mask of object)
      - generated_width, generated_height: image dimensions
      - final_prompt: caption
      - relight_images: list of {mode, path} for relighted versions

    Source image construction:
      - background is the target image with a hole punched by the grown mask
      - foreground is randomly selected from relight_images with weights
      - includes a perspective transformation (moved here from the interactive dataset)

    Args:
        base_dir: Base directory for resolving relative paths. If None, uses args.placement_base_dir.
    """
    if base_dir is None:
        base_dir = getattr(args, "placement_base_dir")
    data_files = _resolve_jsonl(getattr(args, "placement_data_jsonl", None))
    file_paths = data_files.get("train", [])
    records = []
    for p in file_paths:
        with open(p, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    obj = json.loads(line)
                except Exception:
                    # Best-effort: strip a trailing comma and retry
                    try:
                        obj = json.loads(line.rstrip(","))
                    except Exception:
                        continue
                # Keep only the fields we need
                pruned = {
                    "generated_image_path": obj.get("generated_image_path"),
                    "mask_path": obj.get("mask_path"),
                    "generated_width": obj.get("generated_width"),
                    "generated_height": obj.get("generated_height"),
                    "final_prompt": obj.get("final_prompt"),
                    "relight_images": obj.get("relight_images"),
                }
                records.append(pruned)
    to_tensor_and_norm = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])

    class PlacementDataset(torch.utils.data.Dataset):
        def __init__(self, hf_ds, base_dir):
            self.ds = hf_ds
            self.base_dir = base_dir

        def __len__(self):
            return len(self.ds)

        def __getitem__(self, idx):
            rec = self.ds[idx % len(self.ds)]
            t_rel = rec.get("generated_image_path", "")
            m_rel = rec.get("mask_path", "")
            # Both are relative paths
            t_p = os.path.join(self.base_dir, t_rel)
            m_p = os.path.join(self.base_dir, m_rel)
            mask_loaded = cv2.imread(m_p, cv2.IMREAD_GRAYSCALE)
            if mask_loaded is None:
                raise ValueError(f"Failed to read mask: {m_p}")
            tgt_img = Image.open(t_p).convert("RGB")
            fw = int(rec.get("generated_width", tgt_img.width))
            fh = int(rec.get("generated_height", tgt_img.height))
            tgt_img = tgt_img.resize((fw, fh), resample=Image.BILINEAR)
            mask_img = Image.fromarray(mask_loaded.astype(np.uint8)).convert("L").resize((fw, fh), Image.NEAREST)
            target_tensor = to_tensor_and_norm(tgt_img)
            # Binary mask at the final size
            mask_np = np.array(mask_img)
            mask_bin = (mask_np > 127).astype(np.uint8)
            # 1) Grow the mask by a random 50-200 pixels
            grown_mask = _dilate_mask(mask_bin, 50, 200)
            # 2) Optional random augmentation mask constrained by the object mask
            rand_mask = _constrained_random_mask(mask_bin, fh, fw, aug_prob=0.7)
            # 3) Final union mask
            union_mask = np.clip(grown_mask | rand_mask, 0, 1).astype(np.uint8)
            tgt_np = np.array(tgt_img)

            # Helper: choose a relighted image from relight_images with weights
            def _choose_relight_image(rec, width, height):
                relight_list = rec.get("relight_images") or []
                # Build map mode -> path
                mode_to_path = {}
                for it in relight_list:
                    try:
                        mode = str(it.get("mode", "")).strip().lower()
                        path = it.get("path")
                    except Exception:
                        continue
                    if not mode or not path:
                        continue
                    mode_to_path[mode] = path
                weighted_order = [
                    ("grayscale", 0.5),
                    ("low", 0.3),
                    ("high", 0.2),
                ]
                # Keep only the modes that are actually present
                available = [(m, w) for (m, w) in weighted_order if m in mode_to_path]
                chosen_path = None
                if available:
                    # Cumulative sampling over the renormalized weights
                    rnd = random.random()
                    cum = 0.0
                    total_w = sum(w for _, w in available)
                    for m, w in available:
                        cum += w / total_w
                        if rnd <= cum:
                            chosen_path = mode_to_path.get(m)
                            break
                    if chosen_path is None:
                        chosen_path = mode_to_path.get(available[-1][0])
                else:
                    # Fallback to any provided path
                    if mode_to_path:
                        chosen_path = next(iter(mode_to_path.values()))
                # Open the chosen image
                if chosen_path is not None:
                    try:
                        # Paths are relative to base_dir
                        open_path = os.path.join(self.base_dir, chosen_path)
                        img = Image.open(open_path).convert("RGB").resize((width, height), resample=Image.BILINEAR)
                        return img
                    except Exception:
                        pass
                # Fallback: return the target image
                return Image.open(t_p).convert("RGB").resize((width, height), resample=Image.BILINEAR)

            # Choose the base image with probabilities:
            # 20% original target, 20% color-augmented target, 60% relighted
            rsel = random.random()
            if rsel < 0.2:
                base_img = tgt_img
            elif rsel < 0.4:
                base_img = _color_augment(tgt_img)
            else:
                base_img = _choose_relight_image(rec, fw, fh)
            base_np = np.array(base_img)
            fore_np = base_np.copy()
            # Random perspective augmentation (50%): applied to the foreground ROI
            # (mask bbox) and its mask only
            perspective_applied = False
            roi_update = None
            paste_mask_bool = mask_bin.astype(bool)
            if random.random() < 0.5:
                try:
                    ys, xs = np.where(mask_bin > 0)
                    if len(xs) > 0 and len(ys) > 0:
                        x1, x2 = int(xs.min()), int(xs.max())
                        y1, y2 = int(ys.min()), int(ys.max())
                        if x2 > x1 and y2 > y1:
                            roi = base_np[y1:y2 + 1, x1:x2 + 1]
                            roi_mask = mask_bin[y1:y2 + 1, x1:x2 + 1]
                            bh, bw = roi.shape[:2]
                            # Random corner perturbation relative to the ROI size
                            max_ratio = random.uniform(0.1, 0.3)
                            dx = bw * max_ratio
                            dy = bh * max_ratio
                            src = np.float32([[0, 0], [bw - 1, 0], [bw - 1, bh - 1], [0, bh - 1]])
                            dst = np.float32([
                                [np.clip(random.uniform(-dx, dx), 0, bw - 1), np.clip(random.uniform(-dy, dy), 0, bh - 1)],
                                [np.clip(bw - 1 + random.uniform(-dx, dx), 0, bw - 1), np.clip(random.uniform(-dy, dy), 0, bh - 1)],
                                [np.clip(bw - 1 + random.uniform(-dx, dx), 0, bw - 1), np.clip(bh - 1 + random.uniform(-dy, dy), 0, bh - 1)],
                                [np.clip(random.uniform(-dx, dx), 0, bw - 1), np.clip(bh - 1 + random.uniform(-dy, dy), 0, bh - 1)],
                            ])
                            M = cv2.getPerspectiveTransform(src, dst)
                            warped_roi = cv2.warpPerspective(roi, M, (bw, bh), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT101)
                            warped_mask_roi = cv2.warpPerspective((roi_mask.astype(np.uint8) * 255), M, (bw, bh), flags=cv2.INTER_NEAREST, borderMode=cv2.BORDER_CONSTANT, borderValue=0) > 127
                            # Build a fresh foreground canvas
                            fore_np = np.zeros_like(base_np)
                            h_warp, w_warp = warped_mask_roi.shape
                            y2c = y1 + h_warp
                            x2c = x1 + w_warp
                            fore_np[y1:y2c, x1:x2c][warped_mask_roi] = warped_roi[warped_mask_roi]
                            paste_mask_bool = np.zeros_like(mask_bin, dtype=bool)
                            paste_mask_bool[y1:y2c, x1:x2c] = warped_mask_roi
                            roi_update = (x1, y1, h_warp, w_warp, warped_mask_roi)
                            perspective_applied = True
                except Exception:
                    perspective_applied = False
                    paste_mask_bool = mask_bin.astype(bool)
                    fore_np = base_np

            # Optional: simulate resolution artifacts by down/upscaling the foreground
            if random.random() < 0.7:
                ys, xs = np.where(paste_mask_bool)
                if len(xs) > 0 and len(ys) > 0:
                    x1, x2 = int(xs.min()), int(xs.max())
                    y1, y2 = int(ys.min()), int(ys.max())
                    if x2 > x1 and y2 > y1:
                        crop = fore_np[y1:y2 + 1, x1:x2 + 1]
                        ch, cw = crop.shape[:2]
                        scale = random.uniform(0.15, 0.9)
                        dw = max(1, int(cw * scale))
                        dh = max(1, int(ch * scale))
                        try:
                            small = Image.fromarray(crop.astype(np.uint8)).resize((dw, dh), Image.BICUBIC)
                            back = small.resize((cw, ch), Image.BICUBIC)
                            crop_blurred = np.array(back).astype(np.uint8)
                            fore_np[y1:y2 + 1, x1:x2 + 1] = crop_blurred
                        except Exception:
                            pass
            # Build the masked target and compose
            union_mask_for_target = union_mask.copy()
            if roi_update is not None:
                rx, ry, rh, rw, warped_mask_roi = roi_update
                um_roi = union_mask_for_target[ry:ry + rh, rx:rx + rw]
                union_mask_for_target[ry:ry + rh, rx:rx + rw] = np.clip(um_roi | warped_mask_roi.astype(np.uint8), 0, 1)
            masked_t_np = tgt_np.copy()
            masked_t_np[union_mask_for_target.astype(bool)] = 255
            composed_np = masked_t_np.copy()
            m_fore = paste_mask_bool
            composed_np[m_fore] = fore_np[m_fore]
            # Build tensors
            source_tensor = to_tensor_and_norm(Image.fromarray(composed_np.astype(np.uint8)))
            mask_tensor = torch.from_numpy(union_mask.astype(np.float32)).unsqueeze(0)
            # Caption: prepend the instruction
            cap_orig = rec.get("final_prompt", "") or ""
            # final_prompt may be a list; take the first element
            if isinstance(cap_orig, list) and len(cap_orig) > 0:
                cap_orig = cap_orig[0] if isinstance(cap_orig[0], str) else str(cap_orig[0])
            cap = _prepend_caption(cap_orig)
            if perspective_applied:
                cap = f"{cap} Fix the perspective if necessary."
            ids1, ids2 = _tokenize(tokenizers, cap)
            return {
                "source_pixel_values": source_tensor,
                "pixel_values": target_tensor,
                "token_ids_clip": ids1,
                "token_ids_t5": ids2,
                "mask_values": mask_tensor,
            }

    return PlacementDataset(records, base_dir)
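
# Usage sketch (argument names as consumed above; paths and tokenizers are
# placeholders):
#
#   import argparse
#   ns = argparse.Namespace(
#       placement_data_jsonl="/data/placement/metadata.jsonl",
#       placement_base_dir="/data/placement")
#   ds = make_placement_dataset_subjects(ns, (tok_clip, tok_t5))
#   sample = ds[0]   # dict with source/target tensors, token ids, and mask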

def make_interactive_dataset_subjects(args, tokenizers, accelerator=None, base_dir=None):
    """
    Dataset for JSONL with fields:
      - input_path: relative to base_dir (target image)
      - output_path: absolute path to the image with foreground
      - mask_after_completion: absolute path to the mask
      - img_width, img_height: resize dimensions
      - prompt: caption

    Source image construction:
      - background is the target image with a hole punched by the grown `mask_after_completion`
      - foreground comes from the `output_path` image, pasted using the original `mask_after_completion`
      - foreground source is chosen per-sample: 20% original, 20% color-augmented, 60% relighted
      - a random perspective transform is applied to the foreground with 50% probability

    Args:
        base_dir: Base directory for resolving relative paths. If None, uses args.interactive_base_dir.
    """
    if base_dir is None:
        base_dir = getattr(args, "interactive_base_dir")
    data_files = _resolve_jsonl(getattr(args, "train_data_jsonl", None))
    file_paths = data_files.get("train", [])
    records = []
    for p in file_paths:
        with open(p, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    obj = json.loads(line)
                except Exception:
                    # Best-effort: strip a trailing comma and retry
                    try:
                        obj = json.loads(line.rstrip(","))
                    except Exception:
                        continue
                # Keep only the fields we actually need to avoid schema issues
                pruned = {
                    "input_path": obj.get("input_path"),
                    "output_path": obj.get("output_path"),
                    "mask_after_completion": obj.get("mask_after_completion"),
                    "img_width": obj.get("img_width"),
                    "img_height": obj.get("img_height"),
                    "prompt": obj.get("prompt"),
                    # New optional fields
                    "generated_images": obj.get("generated_images"),
                    "positive_prompt_used": obj.get("positive_prompt_used"),
                    "negative_caption_used": obj.get("negative_caption_used"),
                }
                records.append(pruned)
    to_tensor_and_norm = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])

    class SubjectsDataset(torch.utils.data.Dataset):
        def __init__(self, hf_ds, base_dir):
            self.ds = hf_ds
            self.base_dir = base_dir

        def __len__(self):
            return len(self.ds)

        def __getitem__(self, idx):
            rec = self.ds[idx % len(self.ds)]
            t_rel = rec.get("input_path", "")
            foreground_p = rec.get("output_path", "")
            m_abs = rec.get("mask_after_completion", "")
            if not os.path.isabs(m_abs):
                raise ValueError("mask_after_completion must be absolute")
            if not os.path.isabs(foreground_p):
                raise ValueError("output_path must be absolute")
            t_p = os.path.join(self.base_dir, t_rel)
            m_p = m_abs
            mask_loaded = cv2.imread(m_p, cv2.IMREAD_GRAYSCALE)
            if mask_loaded is None:
                raise ValueError(f"Failed to read mask: {m_p}")
            tgt_img = Image.open(t_p).convert("RGB")
            foreground_source_img = Image.open(foreground_p).convert("RGB")
            fw = int(rec.get("img_width", tgt_img.width))
            fh = int(rec.get("img_height", tgt_img.height))
            tgt_img = tgt_img.resize((fw, fh), resample=Image.BILINEAR)
            foreground_source_img = foreground_source_img.resize((fw, fh), resample=Image.BILINEAR)
            mask_img = Image.fromarray(mask_loaded.astype(np.uint8)).convert("L").resize((fw, fh), Image.NEAREST)
            target_tensor = to_tensor_and_norm(tgt_img)
            # Binary mask at the final size
            mask_np = np.array(mask_img)
            mask_bin = (mask_np > 127).astype(np.uint8)
            # 1) Grow the mask by a random 50-200 pixels
            grown_mask = _dilate_mask(mask_bin, 50, 200)
            # 2) Optional random augmentation mask constrained by the object mask
            rand_mask = _constrained_random_mask(mask_bin, fh, fw, aug_prob=0.7)
            # 3) Final union mask
            union_mask = np.clip(grown_mask | rand_mask, 0, 1).astype(np.uint8)
            tgt_np = np.array(tgt_img)

            # Helper: choose a relighted image from generated_images with weights
            def _choose_relight_image(rec, default_img, width, height):
                gen_list = rec.get("generated_images") or []
                # Build map mode -> path
                mode_to_path = {}
                for it in gen_list:
                    try:
                        mode = str(it.get("mode", "")).strip().lower()
                        path = it.get("path")
                    except Exception:
                        continue
                    if not mode or not path:
                        continue
                    mode_to_path[mode] = path
                # Weighted selection among the available modes
                weighted_order = [
                    ("grayscale", 0.5),
                    ("low", 0.3),
                    ("high", 0.2),
                ]
                available = [(m, w) for (m, w) in weighted_order if m in mode_to_path]
                chosen_path = None
                if available:
                    # Cumulative sampling over the renormalized weights
                    rnd = random.random()
                    cum = 0.0
                    total_w = sum(w for _, w in available)
                    for m, w in available:
                        cum += w / total_w
                        if rnd <= cum:
                            chosen_path = mode_to_path.get(m)
                            break
                    if chosen_path is None:
                        chosen_path = mode_to_path.get(available[-1][0])
                else:
                    # Fallback to any provided path
                    if mode_to_path:
                        chosen_path = next(iter(mode_to_path.values()))
                # Open the chosen image
                if chosen_path is not None:
                    try:
                        # Generated paths are typically absolute; if not, use as-is
                        img = Image.open(chosen_path).convert("RGB").resize((width, height), resample=Image.BILINEAR)
                        return img
                    except Exception:
                        pass
                return default_img

            # 5) Choose the base image with probabilities:
            #    20% original, 20% color-augmented, 60% relighted
            rsel = random.random()
            if rsel < 0.2:
                base_img = foreground_source_img
            elif rsel < 0.4:
                base_img = _color_augment(foreground_source_img)
            else:
                base_img = _choose_relight_image(rec, foreground_source_img, fw, fh)
            base_np = np.array(base_img)
            # 5.1) Random perspective augmentation (50%): applied to the foreground ROI
            #      (mask bbox) and its mask only
            perspective_applied = False
            roi_update = None
            paste_mask_bool = mask_bin.astype(bool)
            if random.random() < 0.5:
                try:
                    ys, xs = np.where(mask_bin > 0)
                    if len(xs) > 0 and len(ys) > 0:
                        x1, x2 = int(xs.min()), int(xs.max())
                        y1, y2 = int(ys.min()), int(ys.max())
                        if x2 > x1 and y2 > y1:
                            roi = base_np[y1:y2 + 1, x1:x2 + 1]
                            roi_mask = mask_bin[y1:y2 + 1, x1:x2 + 1]
                            bh, bw = roi.shape[:2]
                            # Random corner perturbation relative to the ROI size
                            max_ratio = random.uniform(0.1, 0.3)
                            dx = bw * max_ratio
                            dy = bh * max_ratio
                            src = np.float32([[0, 0], [bw - 1, 0], [bw - 1, bh - 1], [0, bh - 1]])
                            dst = np.float32([
                                [np.clip(random.uniform(-dx, dx), 0, bw - 1), np.clip(random.uniform(-dy, dy), 0, bh - 1)],
                                [np.clip(bw - 1 + random.uniform(-dx, dx), 0, bw - 1), np.clip(random.uniform(-dy, dy), 0, bh - 1)],
                                [np.clip(bw - 1 + random.uniform(-dx, dx), 0, bw - 1), np.clip(bh - 1 + random.uniform(-dy, dy), 0, bh - 1)],
                                [np.clip(random.uniform(-dx, dx), 0, bw - 1), np.clip(bh - 1 + random.uniform(-dy, dy), 0, bh - 1)],
                            ])
                            M = cv2.getPerspectiveTransform(src, dst)
                            warped_roi = cv2.warpPerspective(roi, M, (bw, bh), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT101)
                            warped_mask_roi = cv2.warpPerspective((roi_mask.astype(np.uint8) * 255), M, (bw, bh), flags=cv2.INTER_NEAREST, borderMode=cv2.BORDER_CONSTANT, borderValue=0) > 127
                            # Build a fresh foreground canvas
                            fore_np = np.zeros_like(base_np)
                            h_warp, w_warp = warped_mask_roi.shape
                            y2c = y1 + h_warp
                            x2c = x1 + w_warp
                            fore_np[y1:y2c, x1:x2c][warped_mask_roi] = warped_roi[warped_mask_roi]
                            paste_mask_bool = np.zeros_like(mask_bin, dtype=bool)
                            paste_mask_bool[y1:y2c, x1:x2c] = warped_mask_roi
                            roi_update = (x1, y1, h_warp, w_warp, warped_mask_roi)
                            perspective_applied = True
                            base_np = fore_np
                except Exception:
                    perspective_applied = False
                    paste_mask_bool = mask_bin.astype(bool)

            # Optional: simulate cut-out foregrounds coming from different resolutions by
            # downscaling the masked foreground region and upscaling it back to the
            # original size. This introduces the realistic blur/aliasing seen in real
            # inpaint workflows.
            if random.random() < 0.7:
                ys, xs = np.where(mask_bin > 0)
                if len(xs) > 0 and len(ys) > 0:
                    x1, x2 = int(xs.min()), int(xs.max())
                    y1, y2 = int(ys.min()), int(ys.max())
                    # Ensure a valid box
                    if x2 > x1 and y2 > y1:
                        crop = base_np[y1:y2 + 1, x1:x2 + 1]
                        ch, cw = crop.shape[:2]
                        scale = random.uniform(0.2, 0.9)
                        dw = max(1, int(cw * scale))
                        dh = max(1, int(ch * scale))
                        try:
                            small = Image.fromarray(crop.astype(np.uint8)).resize((dw, dh), Image.BICUBIC)
                            back = small.resize((cw, ch), Image.BICUBIC)
                            crop_blurred = np.array(back).astype(np.uint8)
                            base_np[y1:y2 + 1, x1:x2 + 1] = crop_blurred
                        except Exception:
                            # Fallback: skip if the resize fails
                            pass
            # 6) Build the masked target using the (possibly) updated union mask; then paste
            union_mask_for_target = union_mask.copy()
            if roi_update is not None:
                rx, ry, rh, rw, warped_mask_roi = roi_update
                # Ensure the union covers the warped foreground area inside the ROI
                um_roi = union_mask_for_target[ry:ry + rh, rx:rx + rw]
                union_mask_for_target[ry:ry + rh, rx:rx + rw] = np.clip(um_roi | warped_mask_roi.astype(np.uint8), 0, 1)
            masked_t_np = tgt_np.copy()
            masked_t_np[union_mask_for_target.astype(bool)] = 255
            composed_np = masked_t_np.copy()
            m_fore = paste_mask_bool
            composed_np[m_fore] = base_np[m_fore]
            # 7) Build tensors
            source_tensor = to_tensor_and_norm(Image.fromarray(composed_np.astype(np.uint8)))
            mask_tensor = torch.from_numpy(union_mask.astype(np.float32)).unsqueeze(0)
            # 8) Caption: prepend the instruction (20% keep only the instruction)
            cap_orig = rec.get("prompt", "") or ""
            cap = _prepend_caption(cap_orig)
            if perspective_applied:
                cap = f"{cap} Fix the perspective if necessary."
            ids1, ids2 = _tokenize(tokenizers, cap)
            return {
                "source_pixel_values": source_tensor,
                "pixel_values": target_tensor,
                "token_ids_clip": ids1,
                "token_ids_t5": ids2,
                "mask_values": mask_tensor,
            }

    return SubjectsDataset(records, base_dir)

def make_pexels_dataset_subjects(args, tokenizers, accelerator=None, base_dir=None):
    """
    Dataset for JSONL with fields:
      - input_path: relative to base_dir (target image)
      - output_path: relative to relight_base_dir (relighted image)
      - final_size: {width, height} resize applied
      - caption: text caption

    Uses segmentation maps instead of a raw_mask_path: 2-5 segments are randomly
    selected from the segmentation map, each is augmented independently, and the
    union is taken. This simulates multiple foreground objects being placed like
    a puzzle. Each segment independently uses: 20% original, 80% color-augmented
    (the relighted branch is currently disabled in __getitem__).

    Args:
        base_dir: Base directory for resolving relative paths. If None, uses args.pexels_base_dir.
    """
    if base_dir is None:
        base_dir = getattr(args, "pexels_base_dir", "/mnt/robby-b1/common/datasets")
    relight_base_dir = getattr(args, "pexels_relight_base_dir", "/robby/share/Editing/lzc/data/relight_outputs")
    seg_base_dir = getattr(args, "seg_base_dir", "/mnt/robby-b1/common/datasets/pexels-mask/20190515093182")
    data_files = _resolve_jsonl(getattr(args, "pexels_data_jsonl", None))
    file_paths = data_files.get("train", [])
    records = []
    for p in file_paths:
        with open(p, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    obj = json.loads(line)
                except Exception:
                    # Best-effort: strip a trailing comma and retry
                    try:
                        obj = json.loads(line.rstrip(","))
                    except Exception:
                        continue
                pruned = {
                    "input_path": obj.get("input_path"),
                    "output_path": obj.get("output_path"),
                    "final_size": obj.get("final_size"),
                    "caption": obj.get("caption"),
                }
                records.append(pruned)
    to_tensor_and_norm = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])

    class PexelsDataset(torch.utils.data.Dataset):
        def __init__(self, hf_ds, base_dir, relight_base_dir, seg_base_dir):
            self.ds = hf_ds
            self.base_dir = base_dir
            self.relight_base_dir = relight_base_dir
            self.seg_base_dir = seg_base_dir

        def __len__(self):
            return len(self.ds)

        def _extract_hash_from_filename(self, filename: str) -> str:
            """Extract the hash from an input filename for segmentation-map lookup."""
            stem = os.path.splitext(os.path.basename(filename))[0]
            if '_' in stem:
                parts = stem.split('_')
                return parts[-1]
            return stem

        def _build_segmap_path(self, input_filename: str) -> str:
            """Build the path to the segmentation map for an input filename."""
            hash_id = self._extract_hash_from_filename(input_filename)
            return os.path.join(self.seg_base_dir, f"{hash_id}_map.png")

        def _load_segmap_uint32(self, seg_path: str):
            """Load a segmentation map as a uint32 label array."""
            try:
                with Image.open(seg_path) as im:
                    if im.mode == 'P':
                        seg = np.array(im)
                    elif im.mode in ('I;16', 'I', 'L'):
                        seg = np.array(im)
                    else:
                        seg = np.array(im.convert('L'))
            except Exception:
                return None
            if seg.ndim == 3:
                # Arrays decoded via PIL are RGB-ordered
                seg = cv2.cvtColor(seg, cv2.COLOR_RGB2GRAY)
            return seg.astype(np.uint32)

        def _extract_multiple_segments(
            self,
            image_h: int,
            image_w: int,
            seg_path: str,
            min_area_ratio: float = 0.02,
            max_area_ratio: float = 0.4,
        ):
            """Extract 2-5 individual segment masks from the segmentation map."""
            seg = self._load_segmap_uint32(seg_path)
            if seg is None:
                return []
            if seg.shape != (image_h, image_w):
                seg = cv2.resize(seg.astype(np.uint16), (image_w, image_h), interpolation=cv2.INTER_NEAREST).astype(np.uint32)
            labels, counts = np.unique(seg, return_counts=True)
            if labels.size == 0:
                return []
            # Exclude background label 0
            bg_mask = labels == 0
            labels = labels[~bg_mask]
            counts = counts[~bg_mask]
            if labels.size == 0:
                return []
            area = image_h * image_w
            min_px = int(round(min_area_ratio * area))
            max_px = int(round(max_area_ratio * area))
            keep = (counts >= min_px) & (counts <= max_px)
            cand_labels = labels[keep]
            if cand_labels.size == 0:
                return []
            # Select 2-5 labels at random
            max_sel = min(5, cand_labels.size)
            min_sel = min(2, cand_labels.size)
            num_to_select = random.randint(min_sel, max_sel)
            chosen = np.random.choice(cand_labels, size=num_to_select, replace=False)
            # Create an individual mask for each chosen label
            individual_masks = []
            for lab in chosen:
                binm = (seg == int(lab)).astype(np.uint8)
                # Morphological opening (erode then dilate) to clean up the mask
                k = max(3, int(round(max(image_h, image_w) * 0.01)))
                if k % 2 == 0:
                    k += 1
                kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k))
                eroded = cv2.erode(binm, kernel, iterations=1)
                opened = cv2.dilate(eroded, kernel, iterations=1)
                individual_masks.append(opened)
            return individual_masks

        def __getitem__(self, idx):
            rec = self.ds[idx % len(self.ds)]
            t_rel = rec.get("input_path", "")
            r_rel = rec.get("output_path", "")
            t_p = os.path.join(self.base_dir, t_rel)
            relight_p = os.path.join(self.relight_base_dir, r_rel)
            tgt_img = Image.open(t_p).convert("RGB")
            # Load the relighted image; fall back to the target if unavailable
            try:
                relighted_img = Image.open(relight_p).convert("RGB")
            except Exception:
                relighted_img = tgt_img.copy()
            final_size = rec.get("final_size", {}) or {}
            fw = int(final_size.get("width", tgt_img.width))
            fh = int(final_size.get("height", tgt_img.height))
            tgt_img = tgt_img.resize((fw, fh), resample=Image.BILINEAR)
            relighted_img = relighted_img.resize((fw, fh), resample=Image.BILINEAR)
            target_tensor = to_tensor_and_norm(tgt_img)
            # Build the segmentation-map path and extract multiple segments
            input_filename = os.path.basename(t_rel)
            seg_path = self._build_segmap_path(input_filename)
            individual_masks = self._extract_multiple_segments(fh, fw, seg_path)
            if not individual_masks:
                # Fallback: empty mask (handled gracefully downstream)
                union_mask = np.zeros((fh, fw), dtype=np.uint8)
                individual_masks = []
            else:
                # Augment each segment mask and take the union
                augmented_masks = []
                for seg_mask in individual_masks:
                    # 1) Grow the mask by a random 50-200 pixels
                    grown = _dilate_mask(seg_mask, 50, 200)
                    # 2) Optional random augmentation mask constrained by this segment
                    rand_mask = _constrained_random_mask(seg_mask, fh, fw, aug_prob=0.7)
                    # 3) Union for this segment
                    seg_union = np.clip(grown | rand_mask, 0, 1).astype(np.uint8)
                    augmented_masks.append(seg_union)
                # Take the union of all augmented masks
                union_mask = np.zeros((fh, fw), dtype=np.uint8)
                for m in augmented_masks:
                    union_mask = np.clip(union_mask | m, 0, 1).astype(np.uint8)
            tgt_np = np.array(tgt_img)
            # Build the masked target first
            masked_t_np = tgt_np.copy()
            masked_t_np[union_mask.astype(bool)] = 255
            composed_np = masked_t_np.copy()

            # Process each segment independently with different augmentations; this
            # simulates multiple foreground objects coming from different sources.
            for seg_mask in individual_masks:
                # 1) Choose the source for this segment: 20% original, 80% color-augmented.
                #    (The relighted-image branch is currently disabled; re-enable it by
                #    adding an `elif r < 0.4` color branch and an `else` relighted branch.)
                r = random.random()
                if r < 0.2:
                    seg_source_img = tgt_img
                else:
                    seg_source_img = _color_augment(tgt_img)
                seg_source_np = np.array(seg_source_img)
                # 2) Apply resolution augmentation to this segment's region
                if random.random() < 0.7:
                    ys, xs = np.where(seg_mask > 0)
                    if len(xs) > 0 and len(ys) > 0:
                        x1, x2 = int(xs.min()), int(xs.max())
                        y1, y2 = int(ys.min()), int(ys.max())
                        if x2 > x1 and y2 > y1:
                            crop = seg_source_np[y1:y2 + 1, x1:x2 + 1]
                            ch, cw = crop.shape[:2]
                            scale = random.uniform(0.2, 0.9)
                            dw = max(1, int(cw * scale))
                            dh = max(1, int(ch * scale))
                            try:
                                small = Image.fromarray(crop.astype(np.uint8)).resize((dw, dh), Image.BICUBIC)
                                back = small.resize((cw, ch), Image.BICUBIC)
                                crop_blurred = np.array(back).astype(np.uint8)
                                seg_source_np[y1:y2 + 1, x1:x2 + 1] = crop_blurred
                            except Exception:
                                pass
                # 3) Paste this segment onto the composed image
                m_fore = seg_mask.astype(bool)
                composed_np[m_fore] = seg_source_np[m_fore]
            # Build tensors
            source_tensor = to_tensor_and_norm(Image.fromarray(composed_np.astype(np.uint8)))
            mask_tensor = torch.from_numpy(union_mask.astype(np.float32)).unsqueeze(0)
            # Caption: prepend the instruction
            cap_orig = rec.get("caption", "") or ""
            cap = _prepend_caption(cap_orig)
            ids1, ids2 = _tokenize(tokenizers, cap)
            return {
                "source_pixel_values": source_tensor,
                "pixel_values": target_tensor,
                "token_ids_clip": ids1,
                "token_ids_t5": ids2,
                "mask_values": mask_tensor,
            }

    return PexelsDataset(records, base_dir, relight_base_dir, seg_base_dir)

def _clone_args(args):
    """Shallow-copy an args namespace so per-dataset overrides don't leak back."""
    clone = type("Args", (), {})()
    for k, v in vars(args).items():
        setattr(clone, k, v)
    return clone


def make_mixed_dataset(args, tokenizers, interactive_jsonl_path=None, placement_jsonl_path=None,
                       pexels_jsonl_path=None, interactive_base_dir=None, placement_base_dir=None,
                       pexels_base_dir=None, interactive_weight=1.0, placement_weight=1.0,
                       pexels_weight=1.0, accelerator=None):
    """
    Create a mixed dataset combining the interactive, placement, and pexels datasets.

    Args:
        args: Arguments object with dataset configuration
        tokenizers: Tuple of tokenizers for text encoding
        interactive_jsonl_path: Path to interactive dataset JSONL (optional)
        placement_jsonl_path: Path to placement dataset JSONL (optional)
        pexels_jsonl_path: Path to pexels dataset JSONL (optional)
        interactive_base_dir: Base directory for interactive dataset paths (optional)
        placement_base_dir: Base directory for placement dataset paths (optional)
        pexels_base_dir: Base directory for pexels dataset paths (optional)
        interactive_weight: Sampling weight for interactive dataset (default: 1.0)
        placement_weight: Sampling weight for placement dataset (default: 1.0)
        pexels_weight: Sampling weight for pexels dataset (default: 1.0)
        accelerator: Optional accelerator object

    Returns:
        A mixed dataset that samples from all provided datasets with the specified weights
    """
    datasets = []
    dataset_names = []
    dataset_weights = []
    # Create the interactive dataset if a path was provided
    if interactive_jsonl_path:
        interactive_args = _clone_args(args)
        interactive_args.train_data_jsonl = interactive_jsonl_path
        if interactive_base_dir:
            interactive_args.interactive_base_dir = interactive_base_dir
        interactive_ds = make_interactive_dataset_subjects(interactive_args, tokenizers, accelerator)
        datasets.append(interactive_ds)
        dataset_names.append("interactive")
        dataset_weights.append(interactive_weight)
    # Create the placement dataset if a path was provided
    if placement_jsonl_path:
        placement_args = _clone_args(args)
        placement_args.placement_data_jsonl = placement_jsonl_path
        if placement_base_dir:
            placement_args.placement_base_dir = placement_base_dir
        placement_ds = make_placement_dataset_subjects(placement_args, tokenizers, accelerator)
        datasets.append(placement_ds)
        dataset_names.append("placement")
        dataset_weights.append(placement_weight)
    # Create the pexels dataset if a path was provided
    if pexels_jsonl_path:
        pexels_args = _clone_args(args)
        pexels_args.pexels_data_jsonl = pexels_jsonl_path
        if pexels_base_dir:
            pexels_args.pexels_base_dir = pexels_base_dir
        pexels_ds = make_pexels_dataset_subjects(pexels_args, tokenizers, accelerator)
        datasets.append(pexels_ds)
        dataset_names.append("pexels")
        dataset_weights.append(pexels_weight)
    if not datasets:
        raise ValueError("At least one dataset path must be provided")
    if len(datasets) == 1:
        return datasets[0]

    # Mixed dataset with balanced sampling (sized by the smallest weighted dataset)
    class MixedDataset(torch.utils.data.Dataset):
        def __init__(self, datasets, dataset_names, dataset_weights):
            self.datasets = datasets
            self.dataset_names = dataset_names
            self.lengths = [len(ds) for ds in datasets]
            # Normalize weights
            total_weight = sum(dataset_weights)
            self.weights = [w / total_weight for w in dataset_weights]
            # Each dataset contributes samples proportional to its weight, scaled so
            # that the most constrained dataset is not oversampled (zero-weight
            # datasets are skipped to avoid dividing by zero)
            min_weighted_size = min(
                length / weight for length, weight in zip(self.lengths, dataset_weights) if weight > 0
            )
            self.samples_per_dataset = [int(min_weighted_size * w) for w in dataset_weights]
            self.total_length = sum(self.samples_per_dataset)
            # Build cumulative sample counts for indexing
            self.cumsum_samples = [0]
            for count in self.samples_per_dataset:
                self.cumsum_samples.append(self.cumsum_samples[-1] + count)
            print("Balanced mixed dataset created:")
            for i, name in enumerate(dataset_names):
                print(f"  {name}: {self.lengths[i]} total, {self.samples_per_dataset[i]} per epoch")
            print(f"  Total samples per epoch: {self.total_length}")

        def __len__(self):
            return self.total_length

        def __getitem__(self, idx):
            # Determine which dataset this idx belongs to
            dataset_idx = 0
            for i in range(len(self.cumsum_samples) - 1):
                if self.cumsum_samples[i] <= idx < self.cumsum_samples[i + 1]:
                    dataset_idx = i
                    break
            # Randomly sample from the selected dataset (different samples each epoch)
            local_idx = random.randint(0, self.lengths[dataset_idx] - 1)
            sample = self.datasets[dataset_idx][local_idx]
            # Record which dataset the sample came from
            sample["dataset_source"] = self.dataset_names[dataset_idx]
            return sample

    return MixedDataset(datasets, dataset_names, dataset_weights)
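
# Usage sketch (paths, tokenizers, and weights are placeholders):
#
#   mixed = make_mixed_dataset(
#       args, (tok_clip, tok_t5),
#       interactive_jsonl_path="/data/interactive/meta.jsonl",
#       placement_jsonl_path="/data/placement/meta.jsonl",
#       interactive_weight=1.0, placement_weight=2.0)
#   mixed[0]["dataset_source"]   # "interactive" or "placement"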

def _run_test_mode(
    interactive_jsonl: str = None,
    placement_jsonl: str = None,
    pexels_jsonl: str = None,
    interactive_base_dir: str = None,
    placement_base_dir: str = None,
    pexels_base_dir: str = None,
    pexels_relight_base_dir: str = None,
    seg_base_dir: str = None,
    interactive_weight: float = 1.0,
    placement_weight: float = 1.0,
    pexels_weight: float = 1.0,
    output_dir: str = "test_output",
    num_samples: int = 100
):
    """Test the datasets by saving samples with source labels.

    Args:
        interactive_jsonl: Path to interactive dataset JSONL (optional)
        placement_jsonl: Path to placement dataset JSONL (optional)
        pexels_jsonl: Path to pexels dataset JSONL (optional)
        interactive_base_dir: Base directory for interactive dataset
        placement_base_dir: Base directory for placement dataset
        pexels_base_dir: Base directory for pexels dataset
        pexels_relight_base_dir: Base directory for pexels relighted images
        seg_base_dir: Directory containing segmentation maps for pexels dataset
        interactive_weight: Sampling weight for interactive dataset (default: 1.0)
        placement_weight: Sampling weight for placement dataset (default: 1.0)
        pexels_weight: Sampling weight for pexels dataset (default: 1.0)
        output_dir: Output directory for test images
        num_samples: Number of samples to save
    """
    if not interactive_jsonl and not placement_jsonl and not pexels_jsonl:
        raise ValueError("At least one dataset path must be provided")
    os.makedirs(output_dir, exist_ok=True)

    # Dummy tokenizers for testing (fixed-length zero ids)
    class DummyTokenizer:
        def __call__(self, text, **kwargs):
            class Result:
                input_ids = torch.zeros(1, 77, dtype=torch.long)
            return Result()

    tokenizers = (DummyTokenizer(), DummyTokenizer())

    # Minimal args object
    class Args:
        cond_size = 512

    args = Args()
    args.train_data_jsonl = interactive_jsonl
    args.placement_data_jsonl = placement_jsonl
    args.pexels_data_jsonl = pexels_jsonl
    args.interactive_base_dir = interactive_base_dir
    args.placement_base_dir = placement_base_dir
    args.pexels_base_dir = pexels_base_dir
    args.pexels_relight_base_dir = pexels_relight_base_dir if pexels_relight_base_dir else "/robby/share/Editing/lzc/data/relight_outputs"
    args.seg_base_dir = seg_base_dir if seg_base_dir else "/mnt/robby-b1/common/datasets/pexels-mask/20190515093182"

    # Create the dataset (single or mixed)
    try:
        num_datasets = sum([bool(interactive_jsonl), bool(placement_jsonl), bool(pexels_jsonl)])
        if num_datasets > 1:
            dataset = make_mixed_dataset(
                args, tokenizers,
                interactive_jsonl_path=interactive_jsonl,
                placement_jsonl_path=placement_jsonl,
                pexels_jsonl_path=pexels_jsonl,
                interactive_base_dir=args.interactive_base_dir,
                placement_base_dir=args.placement_base_dir,
                pexels_base_dir=args.pexels_base_dir,
                interactive_weight=interactive_weight,
                placement_weight=placement_weight,
                pexels_weight=pexels_weight
            )
            print(f"Created mixed dataset with {len(dataset)} samples")
            weights_str = []
            if interactive_jsonl:
                weights_str.append(f"Interactive: {interactive_weight:.2f}")
            if placement_jsonl:
                weights_str.append(f"Placement: {placement_weight:.2f}")
            if pexels_jsonl:
                weights_str.append(f"Pexels: {pexels_weight:.2f}")
            print(f"Sampling weights - {', '.join(weights_str)}")
        elif pexels_jsonl:
            dataset = make_pexels_dataset_subjects(args, tokenizers, base_dir=pexels_base_dir)
            print(f"Created pexels dataset with {len(dataset)} samples")
        elif placement_jsonl:
            dataset = make_placement_dataset_subjects(args, tokenizers, base_dir=args.placement_base_dir)
            print(f"Created placement dataset with {len(dataset)} samples")
        else:
            dataset = make_interactive_dataset_subjects(args, tokenizers, base_dir=args.interactive_base_dir)
            print(f"Created interactive dataset with {len(dataset)} samples")
    except Exception as e:
        print(f"Failed to create dataset: {e}")
        import traceback
        traceback.print_exc()
        return

    # Sample and save
    saved = 0
    counts = {}
    for _ in range(min(num_samples * 3, len(dataset))):
        try:
            idx = random.randint(0, len(dataset) - 1)
            sample = dataset[idx]
            source_name = sample.get("dataset_source", "single")
            counts[source_name] = counts.get(source_name, 0) + 1
            # Denormalize tensors from [-1, 1] to [0, 255]
            source_np = ((sample["source_pixel_values"].permute(1, 2, 0).numpy() + 1.0) * 127.5).clip(0, 255).astype(np.uint8)
            target_np = ((sample["pixel_values"].permute(1, 2, 0).numpy() + 1.0) * 127.5).clip(0, 255).astype(np.uint8)
            # Save the source/target pair
            idx_str = f"{saved:05d}"
            Image.fromarray(source_np).save(os.path.join(output_dir, f"{idx_str}_{source_name}_source.jpg"))
            Image.fromarray(target_np).save(os.path.join(output_dir, f"{idx_str}_{source_name}_target.jpg"))
            saved += 1
            if saved % 10 == 0:
                print(f"Saved {saved}/{num_samples} samples - {counts}")
            if saved >= num_samples:
                break
        except Exception as e:
            print(f"Failed to process sample: {e}")
            continue
    print(f"\nTest complete. Saved {saved} samples to {output_dir}")
    print(f"Distribution: {counts}")

def _parse_test_args():
    import argparse
    parser = argparse.ArgumentParser(description="Test visualization for Kontext datasets")
    parser.add_argument("--interactive_jsonl", type=str, default="/robby/share/Editing/lzc/HOI_v1/final_metadata.jsonl",
                        help="Path to interactive dataset JSONL")
    parser.add_argument("--placement_jsonl", type=str, default="/robby/share/Editing/lzc/subject_placement/metadata_relight.jsonl",
                        help="Path to placement dataset JSONL")
    parser.add_argument("--pexels_jsonl", type=str, default=None,
                        help="Path to pexels dataset JSONL")
    parser.add_argument("--interactive_base_dir", type=str, default="/robby/share/Editing/lzc/HOI_v1",
                        help="Base directory for interactive dataset")
    parser.add_argument("--placement_base_dir", type=str, default=None,
                        help="Base directory for placement dataset")
    parser.add_argument("--pexels_base_dir", type=str, default=None,
                        help="Base directory for pexels dataset")
    parser.add_argument("--pexels_relight_base_dir", type=str, default="/robby/share/Editing/lzc/data/relight_outputs",
                        help="Base directory for pexels relighted images")
    parser.add_argument("--seg_base_dir", type=str, default=None,
                        help="Directory containing segmentation maps for pexels dataset")
    parser.add_argument("--interactive_weight", type=float, default=1.0,
                        help="Sampling weight for interactive dataset (default: 1.0)")
    parser.add_argument("--placement_weight", type=float, default=1.0,
                        help="Sampling weight for placement dataset (default: 1.0)")
    parser.add_argument("--pexels_weight", type=float, default=0.0,
                        help="Sampling weight for pexels dataset (default: 0.0)")
    parser.add_argument("--output_dir", type=str, default="visualize_output",
                        help="Output directory to save pairs")
    parser.add_argument("--num_samples", type=int, default=100,
                        help="Number of pairs to save")
    # Legacy arguments
    parser.add_argument("--test_jsonl", type=str, default=None,
                        help="Legacy: Path to JSONL (used as interactive_jsonl)")
    parser.add_argument("--base_dir", type=str, default=None,
                        help="Legacy: Base directory (used as interactive_base_dir)")
    return parser.parse_args()

if __name__ == "__main__":
    try:
        args = _parse_test_args()
        # Handle legacy args
        interactive_jsonl = args.interactive_jsonl or args.test_jsonl
        interactive_base_dir = args.interactive_base_dir or args.base_dir
        _run_test_mode(
            interactive_jsonl=interactive_jsonl,
            placement_jsonl=args.placement_jsonl,
            pexels_jsonl=args.pexels_jsonl,
            interactive_base_dir=interactive_base_dir,
            placement_base_dir=args.placement_base_dir,
            pexels_base_dir=args.pexels_base_dir,
            pexels_relight_base_dir=args.pexels_relight_base_dir,
            seg_base_dir=args.seg_base_dir,
            interactive_weight=args.interactive_weight,
            placement_weight=args.placement_weight,
            pexels_weight=args.pexels_weight,
            output_dir=args.output_dir,
            num_samples=args.num_samples
        )
    except SystemExit:
        # Allow import-style usage without triggering test mode
        pass