# MagicQuillV2 / train / src / jsonl_datasets_kontext_interactive_lora.py
from PIL import Image
from torchvision import transforms
import torchvision.transforms.functional as TF
import random
import torch
import os
from datasets import load_dataset
import numpy as np
import json
Image.MAX_IMAGE_PIXELS = None
def collate_fn(examples):
if examples[0].get("cond_pixel_values") is not None:
cond_pixel_values = torch.stack([example["cond_pixel_values"] for example in examples])
cond_pixel_values = cond_pixel_values.to(memory_format=torch.contiguous_format).float()
else:
cond_pixel_values = None
if examples[0].get("source_pixel_values") is not None:
source_pixel_values = torch.stack([example["source_pixel_values"] for example in examples])
source_pixel_values = source_pixel_values.to(memory_format=torch.contiguous_format).float()
else:
source_pixel_values = None
target_pixel_values = torch.stack([example["pixel_values"] for example in examples])
target_pixel_values = target_pixel_values.to(memory_format=torch.contiguous_format).float()
token_ids_clip = torch.stack([example["token_ids_clip"] for example in examples])
token_ids_t5 = torch.stack([example["token_ids_t5"] for example in examples])
mask_values = None
if examples[0].get("mask_values") is not None:
mask_values = torch.stack([example["mask_values"] for example in examples])
mask_values = mask_values.to(memory_format=torch.contiguous_format).float()
return {
"cond_pixel_values": cond_pixel_values,
"source_pixel_values": source_pixel_values,
"pixel_values": target_pixel_values,
"text_ids_1": token_ids_clip,
"text_ids_2": token_ids_t5,
"mask_values": mask_values,
}
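# A minimal sketch of what collate_fn expects and returns, assuming 3x64x64
# images and 77/128-token id vectors; these shapes are illustrative
# assumptions, not values mandated by the training code.
def _demo_collate_fn():
    example = {
        "source_pixel_values": torch.randn(3, 64, 64),
        "pixel_values": torch.randn(3, 64, 64),
        "token_ids_clip": torch.zeros(77, dtype=torch.long),
        "token_ids_t5": torch.zeros(128, dtype=torch.long),
        "mask_values": torch.zeros(1, 64, 64),
    }
    batch = collate_fn([example, example])
    # Stacked tensors gain a leading batch dimension of 2; absent optional
    # keys (here cond_pixel_values) come back as None.
    assert batch["pixel_values"].shape == (2, 3, 64, 64)
    assert batch["cond_pixel_values"] is None
    return batch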
def _resolve_jsonl(path_str: str):
if path_str is None or str(path_str).strip() == "":
raise ValueError("train_data_jsonl is empty. Please set --train_data_jsonl to a JSON/JSONL file or a folder.")
if os.path.isdir(path_str):
files = [
os.path.join(path_str, f)
for f in os.listdir(path_str)
if f.lower().endswith((".jsonl", ".json"))
]
if not files:
raise ValueError(f"No .json or .jsonl files found under directory: {path_str}")
return {"train": sorted(files)}
if not os.path.exists(path_str):
raise FileNotFoundError(f"train_data_jsonl not found: {path_str}")
return {"train": [path_str]}
def _tokenize(tokenizers, caption: str):
tokenizer_clip = tokenizers[0]
tokenizer_t5 = tokenizers[1]
text_inputs_clip = tokenizer_clip(
[caption], padding="max_length", max_length=77, truncation=True, return_tensors="pt"
)
text_inputs_t5 = tokenizer_t5(
[caption], padding="max_length", max_length=128, truncation=True, return_tensors="pt"
)
return text_inputs_clip.input_ids[0], text_inputs_t5.input_ids[0]
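# Sketch of the _tokenize contract using stand-in callables; real training
# passes a (CLIP, T5) tokenizer pair. The stub mirrors the DummyTokenizer
# defined for test mode near the bottom of this file.
def _demo_tokenize():
    class _StubTokenizer:
        def __call__(self, texts, max_length=77, **kwargs):
            class _Result:
                input_ids = torch.zeros(len(texts), max_length, dtype=torch.long)
            return _Result()
    ids_clip, ids_t5 = _tokenize((_StubTokenizer(), _StubTokenizer()), "a photo of a cat")
    assert ids_clip.shape == (77,) and ids_t5.shape == (128,)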
def _prepend_caption(caption: str) -> str:
"""Prepend instruction and keep only instruction with 20% prob."""
instruction = "Fill in the white region naturally and adapt the foreground into the background. Fix the perspective of the foreground object if necessary."
if random.random() < 0.2:
return instruction
caption = caption or ""
if caption.strip():
return f"{instruction} {caption.strip()}"
return instruction
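# Quick illustration of _prepend_caption: the instruction is always present,
# the original caption follows it 80% of the time, and an empty caption
# degenerates to the instruction alone.
def _demo_prepend_caption():
    print(_prepend_caption("a red mug on a desk"))
    print(_prepend_caption(""))  # always returns just the instruction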
def _color_augment(pil_img: Image.Image) -> Image.Image:
brightness = random.uniform(0.8, 1.2)
contrast = random.uniform(0.8, 1.2)
saturation = random.uniform(0.8, 1.2)
hue = random.uniform(-0.05, 0.05)
img = TF.adjust_brightness(pil_img, brightness)
img = TF.adjust_contrast(img, contrast)
img = TF.adjust_saturation(img, saturation)
img = TF.adjust_hue(img, hue)
return img
def _dilate_mask(mask_bin: np.ndarray, min_px: int = 5, max_px: int = 100) -> np.ndarray:
"""Grow binary mask by a random radius in [min_px, max_px]. Expects values {0,1}."""
import cv2
    radius = random.randint(min_px, max_px)
if radius <= 0:
return mask_bin.astype(np.uint8)
ksize = 2 * radius + 1
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (ksize, ksize))
grown = cv2.dilate(mask_bin.astype(np.uint8), kernel, iterations=1)
return (grown > 0).astype(np.uint8)
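# Toy check of _dilate_mask: a single-pixel mask grows into a disc whose
# diameter matches the elliptical kernel size 2 * radius + 1. Passing
# min_px == max_px pins the otherwise random radius.
def _demo_dilate_mask():
    mask = np.zeros((21, 21), dtype=np.uint8)
    mask[10, 10] = 1
    grown = _dilate_mask(mask, min_px=3, max_px=3)  # fixed radius of 3
    assert grown[10, 13] == 1 and grown[10, 14] == 0
    return grown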
def _random_point_inside_mask(mask_bin: np.ndarray) -> tuple:
ys, xs = np.where(mask_bin > 0)
if len(xs) == 0:
h, w = mask_bin.shape
return w // 2, h // 2
idx = random.randrange(len(xs))
return int(xs[idx]), int(ys[idx])
def _bbox_containing_mask(mask_bin: np.ndarray, img_w: int, img_h: int) -> tuple:
ys, xs = np.where(mask_bin > 0)
if len(xs) == 0:
return 0, 0, img_w - 1, img_h - 1
x1, x2 = int(xs.min()), int(xs.max())
y1, y2 = int(ys.min()), int(ys.max())
# Random padding
max_pad = int(0.25 * min(img_w, img_h))
pad_x1 = random.randint(0, max_pad)
pad_x2 = random.randint(0, max_pad)
pad_y1 = random.randint(0, max_pad)
pad_y2 = random.randint(0, max_pad)
x1 = max(0, x1 - pad_x1)
y1 = max(0, y1 - pad_y1)
x2 = min(img_w - 1, x2 + pad_x2)
y2 = min(img_h - 1, y2 + pad_y2)
return x1, y1, x2, y2
def _constrained_random_mask(mask_bin: np.ndarray, image_h: int, image_w: int, aug_prob: float = 0.7) -> np.ndarray:
"""Generate random mask whose box contains or starts in m_p, and brush strokes start inside m_p.
Returns binary 0/1 array of shape (H,W).
"""
import cv2
if random.random() >= aug_prob:
return np.zeros((image_h, image_w), dtype=np.uint8)
    # Scale stroke/box sizes relative to a 1024px reference resolution
ref_size = 1024
scale_factor = max(1.0, min(image_h, image_w) / float(ref_size))
out = np.zeros((image_h, image_w), dtype=np.uint8)
# Choose exactly one augmentation: bbox OR stroke
if random.random() < 0.2:
# BBox augmentation: draw N boxes (randomized), first box often contains mask
num_boxes = random.randint(1, 6)
for b in range(num_boxes):
if b == 0 and random.random() < 0.5:
x1, y1, x2, y2 = _bbox_containing_mask(mask_bin, image_w, image_h)
else:
sx, sy = _random_point_inside_mask(mask_bin)
max_w = int(500 * scale_factor)
min_w = int(100 * scale_factor)
bw = random.randint(max(1, min_w), max(2, max_w))
bh = random.randint(max(1, min_w), max(2, max_w))
x1 = max(0, sx - random.randint(0, bw))
y1 = max(0, sy - random.randint(0, bh))
x2 = min(image_w - 1, x1 + bw)
y2 = min(image_h - 1, y1 + bh)
out[y1:y2 + 1, x1:x2 + 1] = 1
else:
# Stroke augmentation: draw N strokes starting inside mask
num_strokes = random.randint(1, 6)
for _ in range(num_strokes):
num_points = random.randint(10, 30)
stroke_width = random.randint(max(1, int(100 * scale_factor)), max(2, int(400 * scale_factor)))
max_offset = max(1, int(100 * scale_factor))
start_x, start_y = _random_point_inside_mask(mask_bin)
px, py = start_x, start_y
for _ in range(num_points):
dx = random.randint(-max_offset, max_offset)
dy = random.randint(-max_offset, max_offset)
nx = int(np.clip(px + dx, 0, image_w - 1))
ny = int(np.clip(py + dy, 0, image_h - 1))
cv2.line(out, (px, py), (nx, ny), 1, stroke_width)
px, py = nx, ny
return (out > 0).astype(np.uint8)
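# Sketch exercising _constrained_random_mask on a toy seed mask. With
# aug_prob=1.0 an augmentation is always drawn, and the result overlaps the
# seed mask because boxes and strokes are anchored at points inside it.
def _demo_constrained_random_mask():
    seed_mask = np.zeros((256, 256), dtype=np.uint8)
    seed_mask[100:140, 100:140] = 1
    rand_mask = _constrained_random_mask(seed_mask, 256, 256, aug_prob=1.0)
    assert rand_mask.shape == (256, 256) and rand_mask.max() <= 1
    assert (rand_mask & seed_mask).any()
    return rand_mask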
def make_placement_dataset_subjects(args, tokenizers, accelerator=None, base_dir=None):
"""
Dataset for JSONL with fields:
- generated_image_path: relative to base_dir (target image with object)
- mask_path: relative to base_dir (mask of object)
- generated_width, generated_height: image dimensions
- final_prompt: caption
- relight_images: list of {mode, path} for relighted versions
source image construction:
- background is target_image with a hole punched by grown mask
    - foreground source: 20% original target, 20% color augment, 60% weighted relight_images
    - 50% chance of a random perspective transform on the foreground ROI
Args:
base_dir: Base directory for resolving relative paths. If None, uses args.placement_base_dir.
"""
    if base_dir is None:
        base_dir = getattr(args, "placement_base_dir", None)
    if base_dir is None:
        raise ValueError("placement_base_dir must be provided via args or the base_dir argument")
data_files = _resolve_jsonl(getattr(args, "placement_data_jsonl", None))
file_paths = data_files.get("train", [])
records = []
for p in file_paths:
with open(p, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
continue
try:
obj = json.loads(line)
except Exception:
try:
obj = json.loads(line.rstrip(","))
except Exception:
continue
# Keep only fields we need
pruned = {
"generated_image_path": obj.get("generated_image_path"),
"mask_path": obj.get("mask_path"),
"generated_width": obj.get("generated_width"),
"generated_height": obj.get("generated_height"),
"final_prompt": obj.get("final_prompt"),
"relight_images": obj.get("relight_images"),
}
records.append(pruned)
size = int(getattr(args, "cond_size", 512))
to_tensor_and_norm = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
])
class PlacementDataset(torch.utils.data.Dataset):
def __init__(self, hf_ds, base_dir):
self.ds = hf_ds
self.base_dir = base_dir
def __len__(self):
            # One pass over the records per epoch
return len(self.ds)
def __getitem__(self, idx):
rec = self.ds[idx % len(self.ds)]
t_rel = rec.get("generated_image_path", "")
m_rel = rec.get("mask_path", "")
# Both are relative paths
t_p = os.path.join(self.base_dir, t_rel)
m_p = os.path.join(self.base_dir, m_rel)
import cv2
mask_loaded = cv2.imread(m_p, cv2.IMREAD_GRAYSCALE)
if mask_loaded is None:
raise ValueError(f"Failed to read mask: {m_p}")
tgt_img = Image.open(t_p).convert("RGB")
fw = int(rec.get("generated_width", tgt_img.width))
fh = int(rec.get("generated_height", tgt_img.height))
tgt_img = tgt_img.resize((fw, fh), resample=Image.BILINEAR)
mask_img = Image.fromarray(mask_loaded.astype(np.uint8)).convert("L").resize((fw, fh), Image.NEAREST)
target_tensor = to_tensor_and_norm(tgt_img)
# Binary mask at final_size
mask_np = np.array(mask_img)
mask_bin = (mask_np > 127).astype(np.uint8)
            # 1) Grow the mask by a random 50-200 pixel radius
grown_mask = _dilate_mask(mask_bin, 50, 200)
# 2) Optional random augmentation mask constrained by mask
            rand_mask = _constrained_random_mask(mask_bin, fh, fw, aug_prob=0.7)
# 3) Final union mask
union_mask = np.clip(grown_mask | rand_mask, 0, 1).astype(np.uint8)
tgt_np = np.array(tgt_img)
# Helper: choose relighted image from relight_images with weights
def _choose_relight_image(rec, width, height):
relight_list = rec.get("relight_images") or []
# Build map mode -> path
mode_to_path = {}
for it in relight_list:
try:
mode = str(it.get("mode", "")).strip().lower()
path = it.get("path")
except Exception:
continue
if not mode or not path:
continue
mode_to_path[mode] = path
weighted_order = [
("grayscale", 0.5),
("low", 0.3),
("high", 0.2),
]
# Filter to available
available = [(m, w) for (m, w) in weighted_order if m in mode_to_path]
chosen_path = None
if available:
rnd = random.random()
cum = 0.0
total_w = sum(w for _, w in available)
for m, w in available:
cum += w / total_w
if rnd <= cum:
chosen_path = mode_to_path.get(m)
break
if chosen_path is None:
chosen_path = mode_to_path.get(available[-1][0])
else:
# Fallback to any provided path
if mode_to_path:
chosen_path = next(iter(mode_to_path.values()))
# Open chosen image
if chosen_path is not None:
try:
# Paths are relative to base_dir
open_path = os.path.join(self.base_dir, chosen_path)
img = Image.open(open_path).convert("RGB").resize((width, height), resample=Image.BILINEAR)
return img
except Exception:
pass
# Fallback: return target image
return Image.open(t_p).convert("RGB").resize((width, height), resample=Image.BILINEAR)
# Choose base image with probabilities:
# 20%: original target, 20%: color augment(target), 60%: relight augment
rsel = random.random()
if rsel < 0.2:
base_img = tgt_img
elif rsel < 0.4:
base_img = _color_augment(tgt_img)
else:
base_img = _choose_relight_image(rec, fw, fh)
base_np = np.array(base_img)
fore_np = base_np.copy()
# Random perspective augmentation (50%): apply to foreground ROI (mask bbox) and its mask only
perspective_applied = False
roi_update = None
paste_mask_bool = mask_bin.astype(bool)
if random.random() < 0.5:
try:
import cv2
ys, xs = np.where(mask_bin > 0)
if len(xs) > 0 and len(ys) > 0:
x1, x2 = int(xs.min()), int(xs.max())
y1, y2 = int(ys.min()), int(ys.max())
if x2 > x1 and y2 > y1:
roi = base_np[y1:y2 + 1, x1:x2 + 1]
roi_mask = mask_bin[y1:y2 + 1, x1:x2 + 1]
bh, bw = roi.shape[:2]
# Random perturbation relative to ROI size
max_ratio = random.uniform(0.1, 0.3)
dx = bw * max_ratio
dy = bh * max_ratio
src = np.float32([[0, 0], [bw - 1, 0], [bw - 1, bh - 1], [0, bh - 1]])
dst = np.float32([
[np.clip(random.uniform(-dx, dx), 0, bw - 1), np.clip(random.uniform(-dy, dy), 0, bh - 1)],
[np.clip(bw - 1 + random.uniform(-dx, dx), 0, bw - 1), np.clip(random.uniform(-dy, dy), 0, bh - 1)],
[np.clip(bw - 1 + random.uniform(-dx, dx), 0, bw - 1), np.clip(bh - 1 + random.uniform(-dy, dy), 0, bh - 1)],
[np.clip(random.uniform(-dx, dx), 0, bw - 1), np.clip(bh - 1 + random.uniform(-dy, dy), 0, bh - 1)],
])
M = cv2.getPerspectiveTransform(src, dst)
warped_roi = cv2.warpPerspective(roi, M, (bw, bh), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT101)
warped_mask_roi = cv2.warpPerspective((roi_mask.astype(np.uint8) * 255), M, (bw, bh), flags=cv2.INTER_NEAREST, borderMode=cv2.BORDER_CONSTANT, borderValue=0) > 127
# Build a fresh foreground canvas
fore_np = np.zeros_like(base_np)
h_warp, w_warp = warped_mask_roi.shape
y2c = y1 + h_warp
x2c = x1 + w_warp
fore_np[y1:y2c, x1:x2c][warped_mask_roi] = warped_roi[warped_mask_roi]
paste_mask_bool = np.zeros_like(mask_bin, dtype=bool)
paste_mask_bool[y1:y2c, x1:x2c] = warped_mask_roi
roi_update = (x1, y1, h_warp, w_warp, warped_mask_roi)
perspective_applied = True
except Exception:
perspective_applied = False
paste_mask_bool = mask_bin.astype(bool)
fore_np = base_np
# Optional: simulate resolution artifacts
if random.random() < 0.7:
ys, xs = np.where(paste_mask_bool)
if len(xs) > 0 and len(ys) > 0:
x1, x2 = int(xs.min()), int(xs.max())
y1, y2 = int(ys.min()), int(ys.max())
if x2 > x1 and y2 > y1:
crop = fore_np[y1:y2 + 1, x1:x2 + 1]
ch, cw = crop.shape[:2]
scale = random.uniform(0.15, 0.9)
dw = max(1, int(cw * scale))
dh = max(1, int(ch * scale))
try:
small = Image.fromarray(crop.astype(np.uint8)).resize((dw, dh), Image.BICUBIC)
back = small.resize((cw, ch), Image.BICUBIC)
crop_blurred = np.array(back).astype(np.uint8)
fore_np[y1:y2 + 1, x1:x2 + 1] = crop_blurred
except Exception:
pass
# Build masked target and compose
union_mask_for_target = union_mask.copy()
if roi_update is not None:
rx, ry, rh, rw, warped_mask_roi = roi_update
um_roi = union_mask_for_target[ry:ry + rh, rx:rx + rw]
union_mask_for_target[ry:ry + rh, rx:rx + rw] = np.clip(um_roi | warped_mask_roi.astype(np.uint8), 0, 1)
masked_t_np = tgt_np.copy()
masked_t_np[union_mask_for_target.astype(bool)] = 255
composed_np = masked_t_np.copy()
m_fore = paste_mask_bool
composed_np[m_fore] = fore_np[m_fore]
# Build tensors
source_tensor = to_tensor_and_norm(Image.fromarray(composed_np.astype(np.uint8)))
mask_tensor = torch.from_numpy(union_mask.astype(np.float32)).unsqueeze(0)
# Caption: prepend instruction
cap_orig = rec.get("final_prompt", "") or ""
# Handle list format in final_prompt
if isinstance(cap_orig, list) and len(cap_orig) > 0:
cap_orig = cap_orig[0] if isinstance(cap_orig[0], str) else str(cap_orig[0])
cap = _prepend_caption(cap_orig)
if perspective_applied:
cap = f"{cap} Fix the perspective if necessary."
ids1, ids2 = _tokenize(tokenizers, cap)
return {
"source_pixel_values": source_tensor,
"pixel_values": target_tensor,
"token_ids_clip": ids1,
"token_ids_t5": ids2,
"mask_values": mask_tensor,
}
return PlacementDataset(records, base_dir)
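# Hedged construction sketch for the placement dataset. The SimpleNamespace
# fields and paths are illustrative assumptions; real training wires them up
# from the argument parser.
def _demo_make_placement_dataset(tokenizers):
    from types import SimpleNamespace
    args = SimpleNamespace(
        placement_data_jsonl="/path/to/metadata_relight.jsonl",  # hypothetical path
        placement_base_dir="/path/to/placement_data",            # hypothetical path
        cond_size=512,
    )
    ds = make_placement_dataset_subjects(args, tokenizers)
    return ds[0]  # dict of source/target tensors, token ids and mask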
def make_interactive_dataset_subjects(args, tokenizers, accelerator=None, base_dir=None):
"""
Dataset for JSONL with fields:
- input_path: relative to base_dir (target image)
- output_path: absolute path to image with foreground
- mask_after_completion: absolute path to mask
- img_width, img_height: resize dimensions
- prompt: caption
source image construction:
- background is target_image with a hole punched by grown `mask_after_completion`
- foreground is from `output_path` image, pasted using original `mask_after_completion`
    - foreground source: 20% original, 20% color augment, 60% relighted (from `generated_images`)
    - 50% chance of a random perspective transform on the foreground ROI
Args:
base_dir: Base directory for resolving relative paths. If None, uses args.interactive_base_dir.
"""
    if base_dir is None:
        base_dir = getattr(args, "interactive_base_dir", None)
    if base_dir is None:
        raise ValueError("interactive_base_dir must be provided via args or the base_dir argument")
data_files = _resolve_jsonl(getattr(args, "train_data_jsonl", None))
file_paths = data_files.get("train", [])
records = []
for p in file_paths:
with open(p, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
continue
try:
obj = json.loads(line)
except Exception:
# Best-effort: strip any trailing commas and retry
try:
obj = json.loads(line.rstrip(","))
except Exception:
continue
# Keep only fields we actually need to avoid schema issues
pruned = {
"input_path": obj.get("input_path"),
"output_path": obj.get("output_path"),
"mask_after_completion": obj.get("mask_after_completion"),
"img_width": obj.get("img_width"),
"img_height": obj.get("img_height"),
"prompt": obj.get("prompt"),
# New optional fields
"generated_images": obj.get("generated_images"),
"positive_prompt_used": obj.get("positive_prompt_used"),
"negative_caption_used": obj.get("negative_caption_used"),
}
records.append(pruned)
size = int(getattr(args, "cond_size", 512))
to_tensor_and_norm = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
])
class SubjectsDataset(torch.utils.data.Dataset):
def __init__(self, hf_ds, base_dir):
self.ds = hf_ds
self.base_dir = base_dir
def __len__(self):
            # One pass over the records per epoch
return len(self.ds)
def __getitem__(self, idx):
rec = self.ds[idx % len(self.ds)]
t_rel = rec.get("input_path", "")
foreground_p = rec.get("output_path", "")
m_abs = rec.get("mask_after_completion", "")
if not os.path.isabs(m_abs):
raise ValueError("mask_after_completion must be absolute")
if not os.path.isabs(foreground_p):
raise ValueError("output_path must be absolute")
t_p = os.path.join(self.base_dir, t_rel)
m_p = m_abs
import cv2
mask_loaded = cv2.imread(m_p, cv2.IMREAD_GRAYSCALE)
if mask_loaded is None:
raise ValueError(f"Failed to read mask: {m_p}")
tgt_img = Image.open(t_p).convert("RGB")
foreground_source_img = Image.open(foreground_p).convert("RGB")
fw = int(rec.get("img_width", tgt_img.width))
fh = int(rec.get("img_height", tgt_img.height))
tgt_img = tgt_img.resize((fw, fh), resample=Image.BILINEAR)
foreground_source_img = foreground_source_img.resize((fw, fh), resample=Image.BILINEAR)
mask_img = Image.fromarray(mask_loaded.astype(np.uint8)).convert("L").resize((fw, fh), Image.NEAREST)
            # Convert the resized target image to a normalized tensor
target_tensor = to_tensor_and_norm(tgt_img)
# Binary mask at final_size
mask_np = np.array(mask_img)
mask_bin = (mask_np > 127).astype(np.uint8)
            # 1) Grow the mask by a random 50-200 pixel radius
grown_mask = _dilate_mask(mask_bin, 50, 200)
# 2) Optional random augmentation mask constrained by m_p
rand_mask = _constrained_random_mask(mask_bin, fh, fw, aug_prob=0.7)
# 3) Final union mask
union_mask = np.clip(grown_mask | rand_mask, 0, 1).astype(np.uint8)
tgt_np = np.array(tgt_img)
# Helper: choose relighted image from generated_images with weights
def _choose_relight_image(rec, default_img, width, height):
gen_list = rec.get("generated_images") or []
# Build map mode -> path
mode_to_path = {}
for it in gen_list:
try:
mode = str(it.get("mode", "")).strip().lower()
path = it.get("path")
except Exception:
continue
if not mode or not path:
continue
mode_to_path[mode] = path
# Weighted selection among available modes
weighted_order = [
("grayscale", 0.5),
("low", 0.3),
("high", 0.2),
]
# Filter to available
available = [(m, w) for (m, w) in weighted_order if m in mode_to_path]
chosen_path = None
if available:
rnd = random.random()
cum = 0.0
total_w = sum(w for _, w in available)
for m, w in available:
cum += w / total_w
if rnd <= cum:
chosen_path = mode_to_path.get(m)
break
if chosen_path is None:
chosen_path = mode_to_path.get(available[-1][0])
else:
# Fallback to any provided path
if mode_to_path:
chosen_path = next(iter(mode_to_path.values()))
# Open chosen image
if chosen_path is not None:
try:
open_path = chosen_path
# generated paths are typically absolute; if not, use as-is
img = Image.open(open_path).convert("RGB").resize((width, height), resample=Image.BILINEAR)
return img
except Exception:
pass
return default_img
# 5) Choose base image with probabilities:
# 20%: original, 20%: color augment(original), 60%: relight augment
rsel = random.random()
if rsel < 0.2:
base_img = foreground_source_img
elif rsel < 0.4:
base_img = _color_augment(foreground_source_img)
else:
base_img = _choose_relight_image(rec, foreground_source_img, fw, fh)
base_np = np.array(base_img)
            # 5.1) Random perspective augmentation (50%): apply to foreground ROI (mask bbox) and its mask only
perspective_applied = False
roi_update = None
paste_mask_bool = mask_bin.astype(bool)
if random.random() < 0.5:
try:
import cv2
ys, xs = np.where(mask_bin > 0)
if len(xs) > 0 and len(ys) > 0:
x1, x2 = int(xs.min()), int(xs.max())
y1, y2 = int(ys.min()), int(ys.max())
if x2 > x1 and y2 > y1:
roi = base_np[y1:y2 + 1, x1:x2 + 1]
roi_mask = mask_bin[y1:y2 + 1, x1:x2 + 1]
bh, bw = roi.shape[:2]
# Random perturbation relative to ROI size
max_ratio = random.uniform(0.1, 0.3)
dx = bw * max_ratio
dy = bh * max_ratio
src = np.float32([[0, 0], [bw - 1, 0], [bw - 1, bh - 1], [0, bh - 1]])
dst = np.float32([
[np.clip(random.uniform(-dx, dx), 0, bw - 1), np.clip(random.uniform(-dy, dy), 0, bh - 1)],
[np.clip(bw - 1 + random.uniform(-dx, dx), 0, bw - 1), np.clip(random.uniform(-dy, dy), 0, bh - 1)],
[np.clip(bw - 1 + random.uniform(-dx, dx), 0, bw - 1), np.clip(bh - 1 + random.uniform(-dy, dy), 0, bh - 1)],
[np.clip(random.uniform(-dx, dx), 0, bw - 1), np.clip(bh - 1 + random.uniform(-dy, dy), 0, bh - 1)],
])
M = cv2.getPerspectiveTransform(src, dst)
warped_roi = cv2.warpPerspective(roi, M, (bw, bh), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT101)
warped_mask_roi = cv2.warpPerspective((roi_mask.astype(np.uint8) * 255), M, (bw, bh), flags=cv2.INTER_NEAREST, borderMode=cv2.BORDER_CONSTANT, borderValue=0) > 127
# Build a fresh foreground canvas
fore_np = np.zeros_like(base_np)
h_warp, w_warp = warped_mask_roi.shape
y2c = y1 + h_warp
x2c = x1 + w_warp
fore_np[y1:y2c, x1:x2c][warped_mask_roi] = warped_roi[warped_mask_roi]
paste_mask_bool = np.zeros_like(mask_bin, dtype=bool)
paste_mask_bool[y1:y2c, x1:x2c] = warped_mask_roi
roi_update = (x1, y1, h_warp, w_warp, warped_mask_roi)
perspective_applied = True
base_np = fore_np
except Exception:
perspective_applied = False
paste_mask_bool = mask_bin.astype(bool)
# Optional: simulate cut-out foregrounds coming from different resolutions by
# downscaling the masked foreground region and upscaling back to original size.
# This introduces realistic blur/aliasing seen in real inpaint workflows.
if random.random() < 0.7:
ys, xs = np.where(mask_bin > 0)
if len(xs) > 0 and len(ys) > 0:
x1, x2 = int(xs.min()), int(xs.max())
y1, y2 = int(ys.min()), int(ys.max())
# Ensure valid box
if x2 > x1 and y2 > y1:
crop = base_np[y1:y2 + 1, x1:x2 + 1]
ch, cw = crop.shape[:2]
scale = random.uniform(0.2, 0.9)
dw = max(1, int(cw * scale))
dh = max(1, int(ch * scale))
try:
small = Image.fromarray(crop.astype(np.uint8)).resize((dw, dh), Image.BICUBIC)
back = small.resize((cw, ch), Image.BICUBIC)
crop_blurred = np.array(back).astype(np.uint8)
base_np[y1:y2 + 1, x1:x2 + 1] = crop_blurred
except Exception:
# Fallback: skip if resize fails
pass
# 6) Build masked target using (possibly) updated union mask; then paste
union_mask_for_target = union_mask.copy()
if roi_update is not None:
rx, ry, rh, rw, warped_mask_roi = roi_update
# Ensure union covers the warped foreground area inside ROI using warped shape
um_roi = union_mask_for_target[ry:ry + rh, rx:rx + rw]
union_mask_for_target[ry:ry + rh, rx:rx + rw] = np.clip(um_roi | warped_mask_roi.astype(np.uint8), 0, 1)
masked_t_np = tgt_np.copy()
masked_t_np[union_mask_for_target.astype(bool)] = 255
composed_np = masked_t_np.copy()
m_fore = paste_mask_bool
composed_np[m_fore] = base_np[m_fore]
# 7) Build tensors
source_tensor = to_tensor_and_norm(Image.fromarray(composed_np.astype(np.uint8)))
mask_tensor = torch.from_numpy(union_mask.astype(np.float32)).unsqueeze(0)
# 8) Caption: prepend instruction, 20% keep only instruction
cap_orig = rec.get("prompt", "") or ""
cap = _prepend_caption(cap_orig)
if perspective_applied:
cap = f"{cap} Fix the perspective if necessary."
ids1, ids2 = _tokenize(tokenizers, cap)
return {
"source_pixel_values": source_tensor,
"pixel_values": target_tensor,
"token_ids_clip": ids1,
"token_ids_t5": ids2,
"mask_values": mask_tensor,
}
return SubjectsDataset(records, base_dir)
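# Sketch of wiring one of these datasets into a DataLoader with the
# collate_fn defined at the top of this file; batch_size and num_workers are
# illustrative assumptions.
def _demo_dataloader(dataset):
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=4,
        shuffle=True,
        num_workers=2,
        collate_fn=collate_fn,
    )
    # Batch keys follow collate_fn: pixel_values, source_pixel_values,
    # text_ids_1, text_ids_2, mask_values (cond_pixel_values stays None here).
    return next(iter(loader))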
def make_pexels_dataset_subjects(args, tokenizers, accelerator=None, base_dir=None):
"""
Dataset for JSONL with fields:
- input_path: relative to base_dir (target image)
- output_path: relative to relight_base_dir (relighted image)
- final_size: {width, height} resize applied
- caption: text caption
Modified to use segmentation maps instead of raw_mask_path.
Randomly selects 2-5 segments from segmentation map, applies augmentation to each, and takes union.
This simulates multiple foreground objects being placed like a puzzle.
    Each segment independently uses: 20% original, 80% color_augment (the
    relighted branch is currently disabled in the code below).
Args:
base_dir: Base directory for resolving relative paths. If None, uses args.pexels_base_dir.
"""
    if base_dir is None:
        base_dir = getattr(args, "pexels_base_dir", None) or "/mnt/robby-b1/common/datasets"
relight_base_dir = getattr(args, "pexels_relight_base_dir", "/robby/share/Editing/lzc/data/relight_outputs")
seg_base_dir = getattr(args, "seg_base_dir", "/mnt/robby-b1/common/datasets/pexels-mask/20190515093182")
data_files = _resolve_jsonl(getattr(args, "pexels_data_jsonl", None))
file_paths = data_files.get("train", [])
records = []
for p in file_paths:
with open(p, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
continue
try:
obj = json.loads(line)
except Exception:
try:
obj = json.loads(line.rstrip(","))
except Exception:
continue
pruned = {
"input_path": obj.get("input_path"),
"output_path": obj.get("output_path"),
"final_size": obj.get("final_size"),
"caption": obj.get("caption"),
}
records.append(pruned)
to_tensor_and_norm = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
])
class PexelsDataset(torch.utils.data.Dataset):
def __init__(self, hf_ds, base_dir, relight_base_dir, seg_base_dir):
self.ds = hf_ds
self.base_dir = base_dir
self.relight_base_dir = relight_base_dir
self.seg_base_dir = seg_base_dir
def __len__(self):
return len(self.ds)
def _extract_hash_from_filename(self, filename: str) -> str:
"""Extract hash from input filename for segmentation map lookup."""
stem = os.path.splitext(os.path.basename(filename))[0]
if '_' in stem:
parts = stem.split('_')
return parts[-1]
return stem
def _build_segmap_path(self, input_filename: str) -> str:
"""Build path to segmentation map from input filename."""
hash_id = self._extract_hash_from_filename(input_filename)
return os.path.join(self.seg_base_dir, f"{hash_id}_map.png")
def _load_segmap_uint32(self, seg_path: str):
"""Load segmentation map as uint32 array."""
import cv2
try:
with Image.open(seg_path) as im:
                    if im.mode in ('P', 'I;16', 'I', 'L'):
                        seg = np.array(im)
                    else:
                        seg = np.array(im.convert('L'))
except Exception:
return None
if seg.ndim == 3:
seg = cv2.cvtColor(seg, cv2.COLOR_BGR2GRAY)
return seg.astype(np.uint32)
def _extract_multiple_segments(
self,
image_h: int,
image_w: int,
seg_path: str,
min_area_ratio: float = 0.02,
max_area_ratio: float = 0.4,
):
"""Extract 2-5 individual segment masks from segmentation map."""
import cv2
seg = self._load_segmap_uint32(seg_path)
if seg is None:
return []
if seg.shape != (image_h, image_w):
seg = cv2.resize(seg.astype(np.uint16), (image_w, image_h), interpolation=cv2.INTER_NEAREST).astype(np.uint32)
labels, counts = np.unique(seg, return_counts=True)
if labels.size == 0:
return []
# Exclude background label 0
bg_mask = labels == 0
labels = labels[~bg_mask]
counts = counts[~bg_mask]
if labels.size == 0:
return []
area = image_h * image_w
min_px = int(round(min_area_ratio * area))
max_px = int(round(max_area_ratio * area))
keep = (counts >= min_px) & (counts <= max_px)
cand_labels = labels[keep]
if cand_labels.size == 0:
return []
# Select 2-5 labels randomly
max_sel = min(5, cand_labels.size)
min_sel = min(2, cand_labels.size)
num_to_select = random.randint(min_sel, max_sel)
chosen = np.random.choice(cand_labels, size=num_to_select, replace=False)
# Create individual masks for each chosen label
individual_masks = []
for lab in chosen:
binm = (seg == int(lab)).astype(np.uint8)
# Apply opening operation to clean up mask
k = max(3, int(round(max(image_h, image_w) * 0.01)))
if k % 2 == 0:
k += 1
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k))
eroded = cv2.erode(binm, kernel, iterations=1)
opened = cv2.dilate(eroded, kernel, iterations=1)
individual_masks.append(opened)
return individual_masks
def __getitem__(self, idx):
rec = self.ds[idx % len(self.ds)]
t_rel = rec.get("input_path", "")
r_rel = rec.get("output_path", "")
t_p = os.path.join(self.base_dir, t_rel)
relight_p = os.path.join(self.relight_base_dir, r_rel)
import cv2
tgt_img = Image.open(t_p).convert("RGB")
# Load relighted image, fallback to target if not available
try:
relighted_img = Image.open(relight_p).convert("RGB")
except Exception:
relighted_img = tgt_img.copy()
final_size = rec.get("final_size", {}) or {}
fw = int(final_size.get("width", tgt_img.width))
fh = int(final_size.get("height", tgt_img.height))
tgt_img = tgt_img.resize((fw, fh), resample=Image.BILINEAR)
relighted_img = relighted_img.resize((fw, fh), resample=Image.BILINEAR)
target_tensor = to_tensor_and_norm(tgt_img)
# Build segmentation map path and extract multiple segments
input_filename = os.path.basename(t_rel)
seg_path = self._build_segmap_path(input_filename)
individual_masks = self._extract_multiple_segments(fh, fw, seg_path)
if not individual_masks:
# Fallback: create empty mask (will be handled gracefully)
union_mask = np.zeros((fh, fw), dtype=np.uint8)
individual_masks = []
else:
# Apply augmentation to each segment mask and take union
augmented_masks = []
for seg_mask in individual_masks:
# 1) Grow mask by random 50-200 pixels
grown = _dilate_mask(seg_mask, 50, 200)
# 2) Optional random augmentation mask constrained by this segment
rand_mask = _constrained_random_mask(seg_mask, fh, fw, aug_prob=0.7)
# 3) Union for this segment
seg_union = np.clip(grown | rand_mask, 0, 1).astype(np.uint8)
augmented_masks.append(seg_union)
# Take union of all augmented masks
union_mask = np.zeros((fh, fw), dtype=np.uint8)
for m in augmented_masks:
union_mask = np.clip(union_mask | m, 0, 1).astype(np.uint8)
tgt_np = np.array(tgt_img)
# Build masked target first
masked_t_np = tgt_np.copy()
masked_t_np[union_mask.astype(bool)] = 255
composed_np = masked_t_np.copy()
# Process each segment independently with different augmentations
# This simulates multiple foreground objects from different sources
for seg_mask in individual_masks:
                # 1) Choose source for this segment: 20% original, 80% color_augment
                #    (the relighted branch below is currently disabled)
                r = random.random()
                if r < 0.2:
                    # Original image
                    seg_source_img = tgt_img
                else:
                    seg_source_img = _color_augment(tgt_img)
# elif r < 0.4:
# # Color augmentation
# seg_source_img = _color_augment(tgt_img)
# else:
# # Relighted image
# seg_source_img = relighted_img
seg_source_np = np.array(seg_source_img)
# 2) Apply resolution augmentation to this segment's region
if random.random() < 0.7:
ys, xs = np.where(seg_mask > 0)
if len(xs) > 0 and len(ys) > 0:
x1, x2 = int(xs.min()), int(xs.max())
y1, y2 = int(ys.min()), int(ys.max())
if x2 > x1 and y2 > y1:
crop = seg_source_np[y1:y2 + 1, x1:x2 + 1]
ch, cw = crop.shape[:2]
scale = random.uniform(0.2, 0.9)
dw = max(1, int(cw * scale))
dh = max(1, int(ch * scale))
try:
small = Image.fromarray(crop.astype(np.uint8)).resize((dw, dh), Image.BICUBIC)
back = small.resize((cw, ch), Image.BICUBIC)
crop_blurred = np.array(back).astype(np.uint8)
seg_source_np[y1:y2 + 1, x1:x2 + 1] = crop_blurred
except Exception:
pass
# 3) Paste this segment onto composed image
m_fore = seg_mask.astype(bool)
composed_np[m_fore] = seg_source_np[m_fore]
# Build tensors
source_tensor = to_tensor_and_norm(Image.fromarray(composed_np.astype(np.uint8)))
mask_tensor = torch.from_numpy(union_mask.astype(np.float32)).unsqueeze(0)
# Caption: prepend instruction
cap_orig = rec.get("caption", "") or ""
cap = _prepend_caption(cap_orig)
ids1, ids2 = _tokenize(tokenizers, cap)
return {
"source_pixel_values": source_tensor,
"pixel_values": target_tensor,
"token_ids_clip": ids1,
"token_ids_t5": ids2,
"mask_values": mask_tensor,
}
return PexelsDataset(records, base_dir, relight_base_dir, seg_base_dir)
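# Sketch of the segmentation-map lookup convention PexelsDataset assumes: the
# trailing underscore-delimited token of the input filename is treated as a
# hash id and mapped to "<hash>_map.png" under seg_base_dir. The filename and
# directory below are made-up examples.
def _demo_segmap_path():
    filename = "pexels-photo_0a1b2c3d.jpg"  # hypothetical input name
    stem = os.path.splitext(os.path.basename(filename))[0]
    hash_id = stem.split("_")[-1] if "_" in stem else stem
    assert hash_id == "0a1b2c3d"
    return os.path.join("/path/to/segmaps", f"{hash_id}_map.png")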
def make_mixed_dataset(args, tokenizers, interactive_jsonl_path=None, placement_jsonl_path=None,
pexels_jsonl_path=None, interactive_base_dir=None, placement_base_dir=None,
pexels_base_dir=None, interactive_weight=1.0, placement_weight=1.0,
pexels_weight=1.0, accelerator=None):
"""
Create a mixed dataset combining interactive, placement, and pexels datasets.
Args:
args: Arguments object with dataset configuration
tokenizers: Tuple of tokenizers for text encoding
interactive_jsonl_path: Path to interactive dataset JSONL (optional)
placement_jsonl_path: Path to placement dataset JSONL (optional)
pexels_jsonl_path: Path to pexels dataset JSONL (optional)
interactive_base_dir: Base directory for interactive dataset paths (optional)
placement_base_dir: Base directory for placement dataset paths (optional)
pexels_base_dir: Base directory for pexels dataset paths (optional)
interactive_weight: Sampling weight for interactive dataset (default: 1.0)
placement_weight: Sampling weight for placement dataset (default: 1.0)
pexels_weight: Sampling weight for pexels dataset (default: 1.0)
accelerator: Optional accelerator object
Returns:
Mixed dataset that samples from all provided datasets with specified weights
"""
datasets = []
dataset_names = []
dataset_weights = []
# Create interactive dataset if path provided
if interactive_jsonl_path:
interactive_args = type('Args', (), {})()
for k, v in vars(args).items():
setattr(interactive_args, k, v)
interactive_args.train_data_jsonl = interactive_jsonl_path
if interactive_base_dir:
interactive_args.interactive_base_dir = interactive_base_dir
interactive_ds = make_interactive_dataset_subjects(interactive_args, tokenizers, accelerator)
datasets.append(interactive_ds)
dataset_names.append("interactive")
dataset_weights.append(interactive_weight)
# Create placement dataset if path provided
if placement_jsonl_path:
placement_args = type('Args', (), {})()
for k, v in vars(args).items():
setattr(placement_args, k, v)
placement_args.placement_data_jsonl = placement_jsonl_path
if placement_base_dir:
placement_args.placement_base_dir = placement_base_dir
placement_ds = make_placement_dataset_subjects(placement_args, tokenizers, accelerator)
datasets.append(placement_ds)
dataset_names.append("placement")
dataset_weights.append(placement_weight)
# Create pexels dataset if path provided
if pexels_jsonl_path:
pexels_args = type('Args', (), {})()
for k, v in vars(args).items():
setattr(pexels_args, k, v)
pexels_args.pexels_data_jsonl = pexels_jsonl_path
if pexels_base_dir:
pexels_args.pexels_base_dir = pexels_base_dir
pexels_ds = make_pexels_dataset_subjects(pexels_args, tokenizers, accelerator)
datasets.append(pexels_ds)
dataset_names.append("pexels")
dataset_weights.append(pexels_weight)
if not datasets:
raise ValueError("At least one dataset path must be provided")
if len(datasets) == 1:
return datasets[0]
# Mixed dataset class with balanced sampling (based on smallest dataset)
class MixedDataset(torch.utils.data.Dataset):
def __init__(self, datasets, dataset_names, dataset_weights):
self.datasets = datasets
self.dataset_names = dataset_names
self.lengths = [len(ds) for ds in datasets]
# Normalize weights
total_weight = sum(dataset_weights)
self.weights = [w / total_weight for w in dataset_weights]
# Calculate samples per dataset based on smallest dataset and weights
# Find the minimum weighted size
            # Guard against zero weights: a zero-weight dataset contributes no samples
            min_weighted_size = min(length / weight for length, weight in zip(self.lengths, dataset_weights) if weight > 0)
# Each dataset contributes samples proportional to its weight, scaled by min_weighted_size
self.samples_per_dataset = [int(min_weighted_size * w) for w in dataset_weights]
self.total_length = sum(self.samples_per_dataset)
# Build cumulative sample counts for indexing
self.cumsum_samples = [0]
for count in self.samples_per_dataset:
self.cumsum_samples.append(self.cumsum_samples[-1] + count)
print(f"Balanced mixed dataset created:")
for i, name in enumerate(dataset_names):
print(f" {name}: {self.lengths[i]} total, {self.samples_per_dataset[i]} per epoch")
print(f" Total samples per epoch: {self.total_length}")
def __len__(self):
return self.total_length
def __getitem__(self, idx):
# Determine which dataset this idx belongs to
dataset_idx = 0
for i in range(len(self.cumsum_samples) - 1):
if self.cumsum_samples[i] <= idx < self.cumsum_samples[i + 1]:
dataset_idx = i
break
# Randomly sample from the selected dataset (enables different samples each epoch)
local_idx = random.randint(0, self.lengths[dataset_idx] - 1)
sample = self.datasets[dataset_idx][local_idx]
# Add dataset source information
sample["dataset_source"] = self.dataset_names[dataset_idx]
return sample
return MixedDataset(datasets, dataset_names, dataset_weights)
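# Worked example of the balanced-sampling arithmetic inside MixedDataset,
# kept outside the class so the numbers are easy to follow. With the example
# lengths [9000, 3000, 30000] and weights [1.0, 1.0, 2.0], the binding
# dataset is the second one (3000 / 1.0 = 3000), so one epoch draws 3000,
# 3000 and 6000 samples respectively.
def _demo_balanced_sampling():
    lengths = [9000, 3000, 30000]
    weights = [1.0, 1.0, 2.0]
    min_weighted = min(l / w for l, w in zip(lengths, weights) if w > 0)
    samples = [int(min_weighted * w) for w in weights]
    assert samples == [3000, 3000, 6000]
    return samples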
def _run_test_mode(
interactive_jsonl: str = None,
placement_jsonl: str = None,
pexels_jsonl: str = None,
interactive_base_dir: str = None,
placement_base_dir: str = None,
pexels_base_dir: str = None,
pexels_relight_base_dir: str = None,
seg_base_dir: str = None,
interactive_weight: float = 1.0,
placement_weight: float = 1.0,
pexels_weight: float = 1.0,
output_dir: str = "test_output",
num_samples: int = 100
):
"""Test dataset by saving samples with source labels.
Args:
interactive_jsonl: Path to interactive dataset JSONL (optional)
placement_jsonl: Path to placement dataset JSONL (optional)
pexels_jsonl: Path to pexels dataset JSONL (optional)
interactive_base_dir: Base directory for interactive dataset
placement_base_dir: Base directory for placement dataset
pexels_base_dir: Base directory for pexels dataset
pexels_relight_base_dir: Base directory for pexels relighted images
seg_base_dir: Directory containing segmentation maps for pexels dataset
interactive_weight: Sampling weight for interactive dataset (default: 1.0)
placement_weight: Sampling weight for placement dataset (default: 1.0)
pexels_weight: Sampling weight for pexels dataset (default: 1.0)
output_dir: Output directory for test images
num_samples: Number of samples to save
"""
if not interactive_jsonl and not placement_jsonl and not pexels_jsonl:
raise ValueError("At least one dataset path must be provided")
os.makedirs(output_dir, exist_ok=True)
# Create dummy tokenizers for testing
    class DummyTokenizer:
        def __call__(self, text, max_length=77, **kwargs):
            class Result:
                input_ids = torch.zeros(1, max_length, dtype=torch.long)
            return Result()
tokenizers = (DummyTokenizer(), DummyTokenizer())
# Create args object
class Args:
cond_size = 512
args = Args()
args.train_data_jsonl = interactive_jsonl
args.placement_data_jsonl = placement_jsonl
args.pexels_data_jsonl = pexels_jsonl
args.interactive_base_dir = interactive_base_dir
args.placement_base_dir = placement_base_dir
args.pexels_base_dir = pexels_base_dir
args.pexels_relight_base_dir = pexels_relight_base_dir if pexels_relight_base_dir else "/robby/share/Editing/lzc/data/relight_outputs"
args.seg_base_dir = seg_base_dir if seg_base_dir else "/mnt/robby-b1/common/datasets/pexels-mask/20190515093182"
# Create dataset (single or mixed)
try:
# Count how many datasets are provided
num_datasets = sum([bool(interactive_jsonl), bool(placement_jsonl), bool(pexels_jsonl)])
if num_datasets > 1:
dataset = make_mixed_dataset(
args, tokenizers,
interactive_jsonl_path=interactive_jsonl,
placement_jsonl_path=placement_jsonl,
pexels_jsonl_path=pexels_jsonl,
interactive_base_dir=args.interactive_base_dir,
placement_base_dir=args.placement_base_dir,
pexels_base_dir=args.pexels_base_dir,
interactive_weight=interactive_weight,
placement_weight=placement_weight,
pexels_weight=pexels_weight
)
print(f"Created mixed dataset with {len(dataset)} samples")
weights_str = []
if interactive_jsonl:
weights_str.append(f"Interactive: {interactive_weight:.2f}")
if placement_jsonl:
weights_str.append(f"Placement: {placement_weight:.2f}")
if pexels_jsonl:
weights_str.append(f"Pexels: {pexels_weight:.2f}")
print(f"Sampling weights - {', '.join(weights_str)}")
elif pexels_jsonl:
dataset = make_pexels_dataset_subjects(args, tokenizers, base_dir=pexels_base_dir)
print(f"Created pexels dataset with {len(dataset)} samples")
elif placement_jsonl:
dataset = make_placement_dataset_subjects(args, tokenizers, base_dir=args.placement_base_dir)
print(f"Created placement dataset with {len(dataset)} samples")
else:
dataset = make_interactive_dataset_subjects(args, tokenizers, base_dir=args.interactive_base_dir)
print(f"Created interactive dataset with {len(dataset)} samples")
except Exception as e:
print(f"Failed to create dataset: {e}")
import traceback
traceback.print_exc()
return
# Sample and save
saved = 0
counts = {}
for attempt in range(min(num_samples * 3, len(dataset))):
try:
idx = random.randint(0, len(dataset) - 1)
sample = dataset[idx]
source_name = sample.get("dataset_source", "single")
counts[source_name] = counts.get(source_name, 0) + 1
# Denormalize tensors from [-1, 1] to [0, 255]
source_np = ((sample["source_pixel_values"].permute(1, 2, 0).numpy() + 1.0) * 127.5).clip(0, 255).astype(np.uint8)
target_np = ((sample["pixel_values"].permute(1, 2, 0).numpy() + 1.0) * 127.5).clip(0, 255).astype(np.uint8)
# Save images
idx_str = f"{saved:05d}"
Image.fromarray(source_np).save(os.path.join(output_dir, f"{idx_str}_{source_name}_source.jpg"))
Image.fromarray(target_np).save(os.path.join(output_dir, f"{idx_str}_{source_name}_target.jpg"))
saved += 1
if saved % 10 == 0:
print(f"Saved {saved}/{num_samples} samples - {counts}")
if saved >= num_samples:
break
except Exception as e:
print(f"Failed to process sample: {e}")
continue
print(f"\nTest complete. Saved {saved} samples to {output_dir}")
print(f"Distribution: {counts}")
def _parse_test_args():
import argparse
parser = argparse.ArgumentParser(description="Test visualization for Kontext datasets")
parser.add_argument("--interactive_jsonl", type=str, default="/robby/share/Editing/lzc/HOI_v1/final_metadata.jsonl",
help="Path to interactive dataset JSONL")
parser.add_argument("--placement_jsonl", type=str, default="/robby/share/Editing/lzc/subject_placement/metadata_relight.jsonl",
help="Path to placement dataset JSONL")
parser.add_argument("--pexels_jsonl", type=str, default=None,
help="Path to pexels dataset JSONL")
parser.add_argument("--interactive_base_dir", type=str, default="/robby/share/Editing/lzc/HOI_v1",
help="Base directory for interactive dataset")
parser.add_argument("--placement_base_dir", type=str, default=None,
help="Base directory for placement dataset")
parser.add_argument("--pexels_base_dir", type=str, default=None,
help="Base directory for pexels dataset")
parser.add_argument("--pexels_relight_base_dir", type=str, default="/robby/share/Editing/lzc/data/relight_outputs",
help="Base directory for pexels relighted images")
parser.add_argument("--seg_base_dir", type=str, default=None,
help="Directory containing segmentation maps for pexels dataset")
parser.add_argument("--interactive_weight", type=float, default=1.0,
help="Sampling weight for interactive dataset (default: 1.0)")
parser.add_argument("--placement_weight", type=float, default=1.0,
help="Sampling weight for placement dataset (default: 1.0)")
parser.add_argument("--pexels_weight", type=float, default=0,
help="Sampling weight for pexels dataset (default: 1.0)")
parser.add_argument("--output_dir", type=str, default="visualize_output",
help="Output directory to save pairs")
parser.add_argument("--num_samples", type=int, default=100,
help="Number of pairs to save")
# Legacy arguments
parser.add_argument("--test_jsonl", type=str, default=None,
help="Legacy: Path to JSONL (uses as interactive_jsonl)")
parser.add_argument("--base_dir", type=str, default=None,
help="Legacy: Base directory (uses as interactive_base_dir)")
return parser.parse_args()
if __name__ == "__main__":
try:
args = _parse_test_args()
        # Handle legacy args (when provided, they override the newer flags;
        # the newer flags have non-None defaults, so they must be checked second)
        interactive_jsonl = args.test_jsonl or args.interactive_jsonl
        interactive_base_dir = args.base_dir or args.interactive_base_dir
_run_test_mode(
interactive_jsonl=interactive_jsonl,
placement_jsonl=args.placement_jsonl,
pexels_jsonl=args.pexels_jsonl,
interactive_base_dir=interactive_base_dir,
placement_base_dir=args.placement_base_dir,
pexels_base_dir=args.pexels_base_dir,
pexels_relight_base_dir=args.pexels_relight_base_dir,
seg_base_dir=args.seg_base_dir,
interactive_weight=args.interactive_weight,
placement_weight=args.placement_weight,
pexels_weight=args.pexels_weight,
output_dir=args.output_dir,
num_samples=args.num_samples
)
except SystemExit:
# Allow import usage without triggering test mode
pass
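# Example invocation of the visualization test mode (all paths below are
# placeholders):
#
#   python jsonl_datasets_kontext_interactive_lora.py \
#       --interactive_jsonl /data/HOI_v1/final_metadata.jsonl \
#       --interactive_base_dir /data/HOI_v1 \
#       --placement_jsonl /data/subject_placement/metadata_relight.jsonl \
#       --output_dir visualize_output --num_samples 50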