# MagicQuillV2/train/src/jsonl_datasets_kontext_local.py
from PIL import Image
from datasets import Dataset
from torchvision import transforms
import random
import torch
import os
from .pipeline_flux_kontext_control import PREFERRED_KONTEXT_RESOLUTIONS
from .jsonl_datasets_kontext import make_train_dataset_inpaint_mask
import numpy as np
import json
import cv2
from .generate_diff_mask import generate_final_difference_mask, align_images
# Allow arbitrarily large images (disable PIL's decompression-bomb limit)
Image.MAX_IMAGE_PIXELS = None
# Default for args.blend_pixel_values: blend target into source with a blurred mask
BLEND_PIXEL_VALUES = True
def multiple_16(num: float):
    # Round to the nearest multiple of 16 (latent-grid-friendly dimensions).
    return int(round(num / 16) * 16)
def choose_kontext_resolution_from_wh(width: int, height: int):
    # Return the preferred Kontext (width, height) whose aspect ratio is
    # closest to the input's; height is guarded against division by zero.
    aspect_ratio = width / max(1, height)
_, best_w, best_h = min(
(abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
)
return best_w, best_h
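# A quick sanity check (illustrative; the exact pick depends on the entries in
# PREFERRED_KONTEXT_RESOLUTIONS shipped with the pipeline):
#   choose_kontext_resolution_from_wh(1024, 1024)  # -> the square entry, if one exists
#   choose_kontext_resolution_from_wh(1600, 1200)  # -> the (w, h) minimizing |4/3 - w/h|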
def collate_fn(examples):
    # Stack per-sample tensors into batch tensors; optional keys
    # (cond/source/mask) stay None when absent from the first example.
if examples[0].get("cond_pixel_values") is not None:
cond_pixel_values = torch.stack([example["cond_pixel_values"] for example in examples])
cond_pixel_values = cond_pixel_values.to(memory_format=torch.contiguous_format).float()
else:
cond_pixel_values = None
if examples[0].get("source_pixel_values") is not None:
source_pixel_values = torch.stack([example["source_pixel_values"] for example in examples])
source_pixel_values = source_pixel_values.to(memory_format=torch.contiguous_format).float()
else:
source_pixel_values = None
target_pixel_values = torch.stack([example["pixel_values"] for example in examples])
target_pixel_values = target_pixel_values.to(memory_format=torch.contiguous_format).float()
token_ids_clip = torch.stack([example["token_ids_clip"] for example in examples])
token_ids_t5 = torch.stack([example["token_ids_t5"] for example in examples])
mask_values = None
if examples[0].get("mask_values") is not None:
mask_values = torch.stack([example["mask_values"] for example in examples])
mask_values = mask_values.to(memory_format=torch.contiguous_format).float()
return {
"cond_pixel_values": cond_pixel_values,
"source_pixel_values": source_pixel_values,
"pixel_values": target_pixel_values,
"text_ids_1": token_ids_clip,
"text_ids_2": token_ids_t5,
"mask_values": mask_values,
}
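# Example wiring (a sketch; batch size and worker count are placeholders):
#   loader = torch.utils.data.DataLoader(
#       dataset, batch_size=4, shuffle=True, collate_fn=collate_fn, num_workers=4
#   )
# Each batch exposes "pixel_values", "text_ids_1"/"text_ids_2", plus the optional
# "cond_pixel_values"/"source_pixel_values"/"mask_values" entries (None when absent).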
# Dataset for the local_edits JSON mapping: per-edit masks are loaded from disk
# and refined on the fly (fixed pre-expansion plus morphological smoothing)
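# Expected shape of args.local_edits_json (inferred from the parsing below;
# the concrete paths are illustrative):
#   [
#     {"path": "<...>/0139/<prefix>.png",
#      "local_edits": ["first edit prompt", "second edit prompt"]},
#     ...
#   ]
# Images resolve to <source_frames_dir>/<group>/<prefix>.png,
# <target_frames_dir>/<group>/<prefix>_<k>.png and
# <masks_dir>/<group>/<prefix>_<k>.png for the k-th edit (1-based).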
def make_train_dataset_local_edits(args, tokenizers, accelerator=None):
# Read JSON entries
with open(args.local_edits_json, "r", encoding="utf-8") as f:
entries = json.load(f)
samples = []
for item in entries:
rel_path = item.get("path", "")
local_edits = item.get("local_edits", []) or []
if not rel_path or not local_edits:
continue
base_name = os.path.basename(rel_path)
prefix = os.path.splitext(base_name)[0]
        # The group id is the parent directory name from the JSON path, e.g. "0139"
        group_str = os.path.basename(os.path.dirname(rel_path))
        # Resolve source/target directories strictly as <base>/<group_str>
        src_dir = os.path.join(args.source_frames_dir, group_str)
        tgt_dir = os.path.join(args.target_frames_dir, group_str)
        if not (os.path.isdir(src_dir) and os.path.isdir(tgt_dir)):
            continue
src_path = os.path.join(src_dir, f"{prefix}.png")
for idx, prompt in enumerate(local_edits, start=1):
tgt_path = os.path.join(tgt_dir, f"{prefix}_{idx}.png")
mask_path = os.path.join(args.masks_dir, group_str, f"{prefix}_{idx}.png")
if not (os.path.exists(src_path) and os.path.exists(tgt_path) and os.path.exists(mask_path)):
continue
samples.append({
"source_image": src_path,
"target_image": tgt_path,
"mask_image": mask_path,
"prompt": prompt,
})
    size = args.cond_size
    # Normalize to [-1, 1] without resizing (targets/sources are resized separately)
    to_tensor_and_norm = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])
    # Condition branch: resize to a square cond_size, then normalize to [-1, 1]
    cond_train_transforms = transforms.Compose(
        [
            transforms.Resize((size, size), interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )
tokenizer_clip = tokenizers[0]
tokenizer_t5 = tokenizers[1]
    def tokenize_prompt_single(caption: str):
        # Tokenize one caption for both text encoders: CLIP (77 tokens max)
        # and T5 (128 tokens max), returning unbatched id tensors.
text_inputs_clip = tokenizer_clip(
[caption],
padding="max_length",
max_length=77,
truncation=True,
return_tensors="pt",
)
text_input_ids_1 = text_inputs_clip.input_ids[0]
text_inputs_t5 = tokenizer_t5(
[caption],
padding="max_length",
max_length=128,
truncation=True,
return_tensors="pt",
)
text_input_ids_2 = text_inputs_t5.input_ids[0]
return text_input_ids_1, text_input_ids_2
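    # e.g. tokenize_prompt_single("make the sky pink") returns two 1-D id
    # tensors of shape [77] (CLIP) and [128] (T5).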
class LocalEditsDataset(torch.utils.data.Dataset):
def __init__(self, samples_ls):
self.samples = samples_ls
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
sample = self.samples[idx]
s_p = sample["source_image"]
t_p = sample["target_image"]
m_p = sample["mask_image"]
cap = sample["prompt"]
            # Random radius/iterations for the morphological smoothing below
            rr = random.randint(10, 20)  # ellipse radius in pixels
            ri = random.randint(3, 5)    # close/open iterations
            mask_loaded = cv2.imread(m_p, cv2.IMREAD_GRAYSCALE)
            if mask_loaded is None:
                raise ValueError(f"Failed to load mask image: {m_p}")
            mask = mask_loaded.copy()
# Pre-expand mask by a fixed number of pixels before any random expansion
# Uses a cross-shaped kernel when tapered_corners is True to emulate "tapered" growth
pre_expand_px = int(getattr(args, "pre_expand_mask_px", 50))
pre_expand_tapered = bool(getattr(args, "pre_expand_tapered_corners", True))
if pre_expand_px != 0:
c = 0 if pre_expand_tapered else 1
pre_kernel = np.array([[c, 1, c],
[1, 1, 1],
[c, 1, c]], dtype=np.uint8)
if pre_expand_px > 0:
mask = cv2.dilate(mask, pre_kernel, iterations=pre_expand_px)
else:
mask = cv2.erode(mask, pre_kernel, iterations=abs(pre_expand_px))
            if rr > 0 and ri > 0:  # always true for the ranges above; kept as a safety guard
ksize = max(1, 2 * int(rr) + 1)
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (ksize, ksize))
for _ in range(max(1, ri)):
mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
src_aligned, tgt_aligned = align_images(s_p, t_p)
best_w, best_h = choose_kontext_resolution_from_wh(tgt_aligned.width, tgt_aligned.height)
final_img_rs = tgt_aligned.resize((best_w, best_h), resample=Image.BILINEAR)
raw_img_rs = src_aligned.resize((best_w, best_h), resample=Image.BILINEAR)
target_tensor = to_tensor_and_norm(final_img_rs)
source_tensor = to_tensor_and_norm(raw_img_rs)
mask_img = Image.fromarray(mask.astype(np.uint8)).convert("L")
if mask_img.size != src_aligned.size:
mask_img = mask_img.resize(src_aligned.size, Image.NEAREST)
mask_np = np.array(mask_img)
mask_bin = (mask_np > 127).astype(np.uint8)
inv_mask = (1 - mask_bin).astype(np.uint8)
src_np = np.array(src_aligned)
masked_raw_np = src_np * inv_mask[..., None]
masked_raw_img = Image.fromarray(masked_raw_np.astype(np.uint8))
cond_tensor = cond_train_transforms(masked_raw_img)
# Prepare mask_values tensor at training resolution (best_w, best_h)
mask_img_rs = mask_img.resize((best_w, best_h), Image.NEAREST)
mask_np_rs = np.array(mask_img_rs)
mask_bin_rs = (mask_np_rs > 127).astype(np.float32)
mask_tensor = torch.from_numpy(mask_bin_rs).unsqueeze(0) # [1, H, W]
ids1, ids2 = tokenize_prompt_single(cap if isinstance(cap, str) else "")
# Optionally blend target and source using a blurred mask, controlled by args
if getattr(args, "blend_pixel_values", BLEND_PIXEL_VALUES):
blend_kernel = int(getattr(args, "blend_kernel", 21))
if blend_kernel % 2 == 0:
blend_kernel += 1
blend_sigma = float(getattr(args, "blend_sigma", 10.0))
gb = transforms.GaussianBlur(kernel_size=(blend_kernel, blend_kernel), sigma=(blend_sigma, blend_sigma))
# mask_tensor: [1, H, W] in [0,1]
blurred_mask = gb(mask_tensor) # [1, H, W]
# Expand to 3 channels to match image tensors
blurred_mask_3c = blurred_mask.expand(target_tensor.shape[0], -1, -1) # [3, H, W]
# Blend in normalized space (both tensors already normalized to [-1, 1])
target_tensor = (source_tensor * (1.0 - blurred_mask_3c)) + (target_tensor * blurred_mask_3c)
target_tensor = target_tensor.clamp(-1.0, 1.0)
return {
"source_pixel_values": source_tensor,
"pixel_values": target_tensor,
"cond_pixel_values": cond_tensor,
"token_ids_clip": ids1,
"token_ids_t5": ids2,
"mask_values": mask_tensor,
}
return LocalEditsDataset(samples)
class BalancedMixDataset(torch.utils.data.Dataset):
"""
A wrapper dataset that mixes two datasets with a configurable ratio.
ratio_b_per_a defines how many samples from dataset_b for each sample from dataset_a:
- 0 => only dataset_a (local edits)
- 1 => 1:1 mix (default)
- 2 => 1:2 mix (A:B)
- any float supported (e.g., 0.5 => 2:1 mix)
"""
def __init__(self, dataset_a, dataset_b, ratio_b_per_a: float = 1.0):
self.dataset_a = dataset_a
self.dataset_b = dataset_b
self.ratio_b_per_a = max(0.0, float(ratio_b_per_a))
len_a = len(dataset_a)
len_b = len(dataset_b)
# If ratio is 0, use all of dataset_a only
if self.ratio_b_per_a == 0 or len_b == 0:
a_indices = list(range(len_a))
random.shuffle(a_indices)
self.mapping = [(0, i) for i in a_indices]
return
# Determine how many we can draw without replacement
# n_a limited by A size and B availability according to ratio
n_a_by_ratio = int(len_b / self.ratio_b_per_a)
n_a = min(len_a, max(1, n_a_by_ratio))
n_b = min(len_b, max(1, int(round(n_a * self.ratio_b_per_a))))
a_indices = list(range(len_a))
b_indices = list(range(len_b))
random.shuffle(a_indices)
random.shuffle(b_indices)
a_indices = a_indices[: n_a]
b_indices = b_indices[: n_b]
mixed = [(0, i) for i in a_indices] + [(1, i) for i in b_indices]
random.shuffle(mixed)
self.mapping = mixed
def __len__(self):
return len(self.mapping)
def __getitem__(self, idx):
which, real_idx = self.mapping[idx]
if which == 0:
return self.dataset_a[real_idx]
else:
return self.dataset_b[real_idx]
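# Worked example of the mapping sizes (assumed lengths, for illustration):
#   len(dataset_a)=100, len(dataset_b)=30, ratio_b_per_a=1.0
#   -> n_a = min(100, int(30 / 1.0)) = 30, n_b = min(30, round(30 * 1.0)) = 30,
#      giving a shuffled mixed dataset of 60 samples.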
def make_train_dataset_mixed(args, tokenizers, accelerator=None):
"""
Create a mixed dataset from:
- Local edits dataset (this file)
- Inpaint-mask JSONL dataset (jsonl_datasets_kontext.make_train_dataset_inpaint_mask)
Ratio control via args.mix_ratio (float):
- 0 => only local edits dataset
- 1 => 1:1 mix (local:inpaint)
- 2 => 1:2 mix, etc.
Requirements:
- args.local_edits_json and related dirs must be set for local edits
- args.train_data_dir must be set for the JSONL inpaint dataset
"""
ds_local = make_train_dataset_local_edits(args, tokenizers, accelerator)
ds_inpaint = make_train_dataset_inpaint_mask(args, tokenizers, accelerator)
mix_ratio = getattr(args, "mix_ratio", 1.0)
return BalancedMixDataset(ds_local, ds_inpaint, ratio_b_per_a=mix_ratio)
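if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the training entry point):
    # the data paths below are placeholders, and the tokenizer repo/subfolders
    # are an assumption based on the usual FLUX checkpoint layout (CLIP under
    # "tokenizer", T5 under "tokenizer_2"). Run as a module so the relative
    # imports resolve, e.g. `python -m src.jsonl_datasets_kontext_local`.
    from types import SimpleNamespace
    from transformers import CLIPTokenizer, T5TokenizerFast

    args = SimpleNamespace(
        local_edits_json="local_edits.json",  # placeholder path
        source_frames_dir="frames/source",    # placeholder path
        target_frames_dir="frames/target",    # placeholder path
        masks_dir="frames/masks",             # placeholder path
        cond_size=512,
    )
    repo = "black-forest-labs/FLUX.1-Kontext-dev"  # assumed checkpoint id
    tok_clip = CLIPTokenizer.from_pretrained(repo, subfolder="tokenizer")
    tok_t5 = T5TokenizerFast.from_pretrained(repo, subfolder="tokenizer_2")
    ds = make_train_dataset_local_edits(args, [tok_clip, tok_t5])
    loader = torch.utils.data.DataLoader(
        ds, batch_size=2, shuffle=True, collate_fn=collate_fn
    )
    batch = next(iter(loader))
    print({k: (tuple(v.shape) if v is not None else None) for k, v in batch.items()})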