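"""Training datasets for FLUX Kontext local-edit training.

Provides a JSON-driven local-edits dataset with on-the-fly mask expansion and
smoothing, a collate function for batching, and a BalancedMixDataset that
mixes it with the JSONL inpaint-mask dataset at a configurable ratio.
"""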

from PIL import Image
from datasets import Dataset
from torchvision import transforms
import random
import torch
import os
import cv2
from .pipeline_flux_kontext_control import PREFERRED_KONTEXT_RESOLUTIONS
from .jsonl_datasets_kontext import make_train_dataset_inpaint_mask
import numpy as np
import json
from .generate_diff_mask import generate_final_difference_mask, align_images

Image.MAX_IMAGE_PIXELS = None  # allow very large frames (disables PIL's decompression-bomb guard)
# Fallback for the optional target/source blend in LocalEditsDataset.__getitem__;
# overridden per run when args.blend_pixel_values is set.
BLEND_PIXEL_VALUES = True


def multiple_16(num: float):
    # Round to the nearest multiple of 16 (not referenced in this module).
    return int(round(num / 16) * 16)


def choose_kontext_resolution_from_wh(width: int, height: int):
    # Pick the preferred Kontext (width, height) whose aspect ratio is closest
    # to the input's; max(1, height) guards against division by zero.
    aspect_ratio = width / max(1, height)
    _, best_w, best_h = min(
        (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
    )
    return best_w, best_h
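
# Illustrative call (the returned values depend on the contents of
# PREFERRED_KONTEXT_RESOLUTIONS, defined in the pipeline module):
#   best_w, best_h = choose_kontext_resolution_from_wh(1920, 1080)
#   # -> the preferred (w, h) pair whose w/h is closest to 16:9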


def collate_fn(examples):
    if examples[0].get("cond_pixel_values") is not None:
        cond_pixel_values = torch.stack([example["cond_pixel_values"] for example in examples])
        cond_pixel_values = cond_pixel_values.to(memory_format=torch.contiguous_format).float()
    else:
        cond_pixel_values = None
    if examples[0].get("source_pixel_values") is not None:
        source_pixel_values = torch.stack([example["source_pixel_values"] for example in examples])
        source_pixel_values = source_pixel_values.to(memory_format=torch.contiguous_format).float()
    else:
        source_pixel_values = None
    target_pixel_values = torch.stack([example["pixel_values"] for example in examples])
    target_pixel_values = target_pixel_values.to(memory_format=torch.contiguous_format).float()
    token_ids_clip = torch.stack([example["token_ids_clip"] for example in examples])
    token_ids_t5 = torch.stack([example["token_ids_t5"] for example in examples])
    mask_values = None
    if examples[0].get("mask_values") is not None:
        mask_values = torch.stack([example["mask_values"] for example in examples])
        mask_values = mask_values.to(memory_format=torch.contiguous_format).float()
    return {
        "cond_pixel_values": cond_pixel_values,
        "source_pixel_values": source_pixel_values,
        "pixel_values": target_pixel_values,
        "text_ids_1": token_ids_clip,
        "text_ids_2": token_ids_t5,
        "mask_values": mask_values,
    }
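
# Typical hookup (sketch; batch_size and shuffle are illustrative):
#   loader = torch.utils.data.DataLoader(
#       dataset, batch_size=4, shuffle=True, collate_fn=collate_fn
#   )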


# New dataset for the local_edits JSON mapping; precomputed difference masks
# are loaded from disk and expanded/smoothed on the fly in __getitem__.
def make_train_dataset_local_edits(args, tokenizers, accelerator=None):
    # Read JSON entries
    with open(args.local_edits_json, "r", encoding="utf-8") as f:
        entries = json.load(f)
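    # Expected JSON layout, inferred from the fields read below:
    #   [{"path": ".../0139/frame.png", "local_edits": ["edit prompt 1", ...]}, ...]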
    samples = []
    for item in entries:
        rel_path = item.get("path", "")
        local_edits = item.get("local_edits", []) or []
        if not rel_path or not local_edits:
            continue
        base_name = os.path.basename(rel_path)
        prefix = os.path.splitext(base_name)[0]
        group_dir = os.path.basename(os.path.dirname(rel_path))
        # Parse a numeric group id as a fallback; gid_int is not used below.
        try:
            gid_int = int(group_dir)
        except ValueError:
            digits = "".join(ch for ch in group_dir if ch.isdigit())
            gid_int = int(digits) if digits else None
        group_str = group_dir  # e.g., "0139" from the JSON path segment
        # Resolve source/target directories strictly as base/<group>
        src_dir = os.path.join(args.source_frames_dir, group_str)
        tgt_dir = os.path.join(args.target_frames_dir, group_str)
        if not (os.path.isdir(src_dir) and os.path.isdir(tgt_dir)):
            continue
        src_path = os.path.join(src_dir, f"{prefix}.png")
        for idx, prompt in enumerate(local_edits, start=1):
            tgt_path = os.path.join(tgt_dir, f"{prefix}_{idx}.png")
            mask_path = os.path.join(args.masks_dir, group_str, f"{prefix}_{idx}.png")
            if not (os.path.exists(src_path) and os.path.exists(tgt_path) and os.path.exists(mask_path)):
                continue
            samples.append({
                "source_image": src_path,
                "target_image": tgt_path,
                "mask_image": mask_path,
                "prompt": prompt,
            })
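    # File layout assumed by the loop above:
    #   <source_frames_dir>/<group>/<prefix>.png
    #   <target_frames_dir>/<group>/<prefix>_<i>.png   (i is 1-based over local_edits)
    #   <masks_dir>/<group>/<prefix>_<i>.png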

    size = args.cond_size
    to_tensor_and_norm = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),  # maps [0, 1] -> [-1, 1]
    ])
    cond_train_transforms = transforms.Compose(
        [
            transforms.Resize((size, size), interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )
    tokenizer_clip = tokenizers[0]
    tokenizer_t5 = tokenizers[1]

    def tokenize_prompt_single(caption: str):
        text_inputs_clip = tokenizer_clip(
            [caption],
            padding="max_length",
            max_length=77,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids_1 = text_inputs_clip.input_ids[0]
        text_inputs_t5 = tokenizer_t5(
            [caption],
            padding="max_length",
            max_length=128,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids_2 = text_inputs_t5.input_ids[0]
        return text_input_ids_1, text_input_ids_2
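    # 77 is CLIP's context length; 128 is a training-time cap for the T5
    # prompt (FLUX pipelines commonly allow up to 512 T5 tokens at inference).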

    class LocalEditsDataset(torch.utils.data.Dataset):
        def __init__(self, samples_ls):
            self.samples = samples_ls

        def __len__(self):
            return len(self.samples)

        def __getitem__(self, idx):
            sample = self.samples[idx]
            s_p = sample["source_image"]
            t_p = sample["target_image"]
            m_p = sample["mask_image"]
            cap = sample["prompt"]
            # Random kernel radius / iteration count for the morphological smoothing below
            rr = random.randint(10, 20)
            ri = random.randint(3, 5)
            mask_loaded = cv2.imread(m_p, cv2.IMREAD_GRAYSCALE)
            if mask_loaded is None:
                raise ValueError(f"mask load failed: {m_p}")
            mask = mask_loaded.copy()
            # Pre-expand the mask by a fixed number of pixels before any random expansion.
            # A cross-shaped kernel is used when tapered_corners is True to emulate "tapered" growth.
            pre_expand_px = int(getattr(args, "pre_expand_mask_px", 50))
            pre_expand_tapered = bool(getattr(args, "pre_expand_tapered_corners", True))
            if pre_expand_px != 0:
                c = 0 if pre_expand_tapered else 1
                pre_kernel = np.array([[c, 1, c],
                                       [1, 1, 1],
                                       [c, 1, c]], dtype=np.uint8)
                if pre_expand_px > 0:
                    mask = cv2.dilate(mask, pre_kernel, iterations=pre_expand_px)
                else:
                    mask = cv2.erode(mask, pre_kernel, iterations=abs(pre_expand_px))
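            # Smooth the mask boundary: MORPH_CLOSE fills small holes, then
            # MORPH_OPEN removes small speckles; repeating with an elliptical
            # kernel of radius rr smooths structure at roughly that scale.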
            if rr > 0 and ri > 0:
                ksize = max(1, 2 * int(rr) + 1)
                kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (ksize, ksize))
                for _ in range(max(1, ri)):
                    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
                    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
            src_aligned, tgt_aligned = align_images(s_p, t_p)
            best_w, best_h = choose_kontext_resolution_from_wh(tgt_aligned.width, tgt_aligned.height)
            final_img_rs = tgt_aligned.resize((best_w, best_h), resample=Image.BILINEAR)
            raw_img_rs = src_aligned.resize((best_w, best_h), resample=Image.BILINEAR)
            target_tensor = to_tensor_and_norm(final_img_rs)
            source_tensor = to_tensor_and_norm(raw_img_rs)
            mask_img = Image.fromarray(mask.astype(np.uint8)).convert("L")
            if mask_img.size != src_aligned.size:
                mask_img = mask_img.resize(src_aligned.size, Image.NEAREST)
            mask_np = np.array(mask_img)
            mask_bin = (mask_np > 127).astype(np.uint8)
            inv_mask = (1 - mask_bin).astype(np.uint8)
            src_np = np.array(src_aligned)
            masked_raw_np = src_np * inv_mask[..., None]
            masked_raw_img = Image.fromarray(masked_raw_np.astype(np.uint8))
            cond_tensor = cond_train_transforms(masked_raw_img)
            # Prepare mask_values tensor at training resolution (best_w, best_h)
            mask_img_rs = mask_img.resize((best_w, best_h), Image.NEAREST)
            mask_np_rs = np.array(mask_img_rs)
            mask_bin_rs = (mask_np_rs > 127).astype(np.float32)
            mask_tensor = torch.from_numpy(mask_bin_rs).unsqueeze(0)  # [1, H, W]
            ids1, ids2 = tokenize_prompt_single(cap if isinstance(cap, str) else "")
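            # The blend below confines edits to the masked region: outside the
            # blurred mask the target is replaced by the source, so the model is
            # trained to reproduce untouched pixels as-is, and the Gaussian
            # feather avoids a hard seam at the mask edge.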
            # Optionally blend target and source using a blurred mask, controlled by args
            if getattr(args, "blend_pixel_values", BLEND_PIXEL_VALUES):
                blend_kernel = int(getattr(args, "blend_kernel", 21))
                if blend_kernel % 2 == 0:
                    blend_kernel += 1  # GaussianBlur requires an odd kernel size
                blend_sigma = float(getattr(args, "blend_sigma", 10.0))
                gb = transforms.GaussianBlur(kernel_size=(blend_kernel, blend_kernel), sigma=(blend_sigma, blend_sigma))
                # mask_tensor: [1, H, W] in [0, 1]
                blurred_mask = gb(mask_tensor)  # [1, H, W]
                # Expand to 3 channels to match the image tensors
                blurred_mask_3c = blurred_mask.expand(target_tensor.shape[0], -1, -1)  # [3, H, W]
                # Blend in normalized space (both tensors are already in [-1, 1])
                target_tensor = (source_tensor * (1.0 - blurred_mask_3c)) + (target_tensor * blurred_mask_3c)
                target_tensor = target_tensor.clamp(-1.0, 1.0)
            return {
                "source_pixel_values": source_tensor,
                "pixel_values": target_tensor,
                "cond_pixel_values": cond_tensor,
                "token_ids_clip": ids1,
                "token_ids_t5": ids2,
                "mask_values": mask_tensor,
            }

    return LocalEditsDataset(samples)
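
# Minimal usage sketch (assumes args carries local_edits_json, source_frames_dir,
# target_frames_dir, masks_dir, and cond_size; the optional pre_expand/blend
# attributes fall back to the defaults above):
#   ds = make_train_dataset_local_edits(args, (tokenizer_clip, tokenizer_t5))
#   batch = collate_fn([ds[0], ds[1]])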


class BalancedMixDataset(torch.utils.data.Dataset):
    """
    A wrapper dataset that mixes two datasets at a configurable ratio.

    ratio_b_per_a defines how many samples are drawn from dataset_b for each
    sample drawn from dataset_a:
    - 0 => only dataset_a (local edits)
    - 1 => 1:1 mix (default)
    - 2 => 1:2 mix (A:B)
    - any float is supported (e.g., 0.5 => 2:1 mix)
    """

    def __init__(self, dataset_a, dataset_b, ratio_b_per_a: float = 1.0):
        self.dataset_a = dataset_a
        self.dataset_b = dataset_b
        self.ratio_b_per_a = max(0.0, float(ratio_b_per_a))
        len_a = len(dataset_a)
        len_b = len(dataset_b)
        # If the ratio is 0 (or B is empty), use all of dataset_a only
        if self.ratio_b_per_a == 0 or len_b == 0:
            a_indices = list(range(len_a))
            random.shuffle(a_indices)
            self.mapping = [(0, i) for i in a_indices]
            return
        # Determine how many samples we can draw without replacement:
        # n_a is limited by A's size and by B's availability under the ratio.
        n_a_by_ratio = int(len_b / self.ratio_b_per_a)
        n_a = min(len_a, max(1, n_a_by_ratio))
        n_b = min(len_b, max(1, int(round(n_a * self.ratio_b_per_a))))
        a_indices = list(range(len_a))
        b_indices = list(range(len_b))
        random.shuffle(a_indices)
        random.shuffle(b_indices)
        a_indices = a_indices[:n_a]
        b_indices = b_indices[:n_b]
        mixed = [(0, i) for i in a_indices] + [(1, i) for i in b_indices]
        random.shuffle(mixed)
        self.mapping = mixed

    def __len__(self):
        return len(self.mapping)

    def __getitem__(self, idx):
        which, real_idx = self.mapping[idx]
        if which == 0:
            return self.dataset_a[real_idx]
        else:
            return self.dataset_b[real_idx]
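
# Note: the A/B index mapping is shuffled once at construction, so one pass over
# BalancedMixDataset is one fixed epoch mix; rebuild the dataset per epoch if a
# fresh subsample of the larger side is wanted.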


def make_train_dataset_mixed(args, tokenizers, accelerator=None):
    """
    Create a mixed dataset from:
    - the local-edits dataset (this file)
    - the inpaint-mask JSONL dataset (jsonl_datasets_kontext.make_train_dataset_inpaint_mask)

    Ratio control via args.mix_ratio (float):
    - 0 => only the local-edits dataset
    - 1 => 1:1 mix (local:inpaint)
    - 2 => 1:2 mix, etc.

    Requirements:
    - args.local_edits_json and the related dirs must be set for local edits
    - args.train_data_dir must be set for the JSONL inpaint dataset
    """
    ds_local = make_train_dataset_local_edits(args, tokenizers, accelerator)
    ds_inpaint = make_train_dataset_inpaint_mask(args, tokenizers, accelerator)
    mix_ratio = getattr(args, "mix_ratio", 1.0)
    return BalancedMixDataset(ds_local, ds_inpaint, ratio_b_per_a=mix_ratio)
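
# Worked example (hypothetical sizes): with len(ds_local) = 100,
# len(ds_inpaint) = 500 and args.mix_ratio = 2.0, the mixer keeps
# n_a = min(100, int(500 / 2.0)) = 100 local samples and
# n_b = min(500, round(100 * 2.0)) = 200 inpaint samples, 300 in total.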