File size: 15,790 Bytes
00274d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
from typing import Union
import torch
import random
import numpy as np
import cv2
import os


def create_random_mask(batch_size, num_frames, height, width, device, dtype, shape_type=None):
    """
    Create random binary masks for sketch frames.

    Every (batch, frame) slice gets its own random set of shapes:
    1-5 axis-aligned squares, 1-5 filled circles, or one filled polygon
    through 5-16 random vertices.

    Args:
        batch_size: Batch size.
        num_frames: Number of frames to mask.
        height, width: Image dimensions.
        device: Device for the returned tensor.
        dtype: Data type for the returned tensor.
        shape_type: 'square', 'circle' or 'random'. If None, a shape type is
            drawn independently for every (batch, frame) slice. Any other value
            draws nothing, leaving that slice fully unmasked.

    Returns:
        Mask tensor of shape [batch_size, 1, num_frames, height, width] where 0
        marks the masked (covered) areas and 1 the untouched areas.
    """
    # Start fully unmasked (ones); drawn shapes are carved out as zeros below.
    masks = torch.ones(batch_size, 1, num_frames, height, width, device=device, dtype=dtype)

    # Shape sizes are proportional to the shorter image side.
    max_size = min(height, width)
    max_radius = max_size // 2

    for b in range(batch_size):
        for f in range(num_frames):
            # Pick into a local: reassigning ``shape_type`` here would make the
            # first random choice stick for every later (batch, frame) slice.
            if shape_type is None:
                frame_shape = random.choice(['square', 'circle', 'random'])
            else:
                frame_shape = shape_type

            # Draw on a numpy canvas (1 = covered) for easy shape rasterization.
            mask = np.zeros((height, width), dtype=np.float32)

            if frame_shape == 'square':
                for _ in range(random.randint(1, 5)):
                    size = random.randint(max_size // 4, max_size)
                    x = random.randint(0, width - size)
                    y = random.randint(0, height - size)
                    mask[y:y + size, x:x + size] = 1.0

            elif frame_shape == 'circle':
                for _ in range(random.randint(1, 5)):
                    radius = random.randint(max_radius // 4, max_radius)
                    # Keep the centre far enough from the border for the full radius.
                    center_x = random.randint(radius, width - radius)
                    center_y = random.randint(radius, height - radius)
                    cv2.circle(mask, (center_x, center_y), radius, 1.0, -1)  # thickness -1 = filled

            elif frame_shape == 'random':
                # One filled polygon through random vertices (in random order, so it may self-intersect).
                num_points = random.randint(5, 16)
                points = np.array(
                    [[random.randint(0, width - 1), random.randint(0, height - 1)] for _ in range(num_points)],
                    dtype=np.int32,
                )
                cv2.fillPoly(mask, [points], 1.0)

            # Invert: covered pixels become 0, everything else stays 1.
            masks[b, 0, f] = 1.0 - torch.from_numpy(mask).to(device=device, dtype=dtype)

    return masks


@torch.no_grad()
def extract_img_to_sketch(_sketch_model, _img, model_name="random"):
    """
    Run the sketch model on a batch of images at a fixed 2048x2048 resolution.

    The images are resized to 2048x2048 (default ``interpolate`` mode), passed
    to ``_sketch_model(images, model_name=model_name)``, and the result is
    resized back to the input's spatial size. A single-channel result is
    repeated to three channels. Gradients are disabled.

    The original contract states the sketch is in [-1, 1]; presumably this
    depends on the sketch model's output range — verify against the model.
    """
    target_hw = (_img.shape[-2], _img.shape[-1])
    # NOTE(review): CUDA autocast may not accept float32 as a target dtype on
    # every PyTorch version; confirm whether the intent is to disable autocast.
    with torch.amp.autocast(dtype=torch.float32, device_type="cuda"):
        upscaled = torch.nn.functional.interpolate(_img, (2048, 2048))
        raw_sketch = _sketch_model(upscaled, model_name=model_name)
        result = torch.nn.functional.interpolate(raw_sketch, target_hw)
    if result.shape[1] != 1:
        return result
    return result.repeat(1, 3, 1, 1)


def _load_flux_sketch(sketch_path, height, width, device, dtype):
    """
    Load a pre-rendered FLUX line-art image as a sketch tensor.

    Args:
        sketch_path: Path of the ``.lineart.png`` file.
        height, width: Target spatial size; the image is resized to it.
        device, dtype: Device and dtype of the returned tensor.

    Returns:
        Tensor of shape (1, 3, height, width) with values in [-1, 1], or None
        when the file is missing or cannot be decoded (the caller then falls
        back to the sketch pool).
    """
    print(f"Loading flux sketch from {sketch_path}...")
    if not os.path.exists(sketch_path):
        print(f"FLUX Sketch path {sketch_path} does not exist. Falling back to sketch pool.")
        return None
    image = cv2.imread(sketch_path)
    if image is None:
        # cv2.imread reports unreadable/corrupt files by returning None, not by raising.
        print(f"FLUX Sketch path {sketch_path} could not be read. Falling back to sketch pool.")
        return None
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # cv2.resize takes the target size as (width, height).
    image = cv2.resize(image, (width, height))
    sketch = torch.from_numpy(image).to(device, dtype=dtype)
    sketch = sketch / 255.0 * 2.0 - 1.0  # [0, 255] -> [-1, 1]
    return sketch.permute(2, 0, 1).unsqueeze(0)  # HWC -> 1CHW


def video_to_frame_and_sketch(
    sketch_model,
    original_video,
    max_num_preserved_sketch_frames=2,
    max_num_preserved_image_frames=1,
    min_num_preserved_sketch_frames=2,
    min_num_preserved_image_frames=1,
    model_name=None,
    detach_image_and_sketch=False,
    equally_spaced_preserve_sketch=False,
    apply_sketch_mask=False,
    sketch_mask_ratio=0.2,
    sketch_mask_shape=None,
    no_first_sketch: Union[bool, float] = False,
    video_clip_names=None,
    is_flux_sketch_available=None,
    is_evaluation=False,
    flux_sketch_root="/group/40005/gzhiwang/iclora/linearts",
):
    """
    Build image and sketch conditioning tensors from a video clip.

    Frame 0 is always an image condition. Sketch frames are chosen randomly
    (always including the last frame, and frame 0 unless ``no_first_sketch``).
    Extra image frames are drawn from the frames without a sketch.

    Args:
        sketch_model: Sketch pool, called via ``extract_img_to_sketch`` as
            ``sketch_model(images, model_name=...)``.
        original_video: torch.Tensor, shape=(batch_size, num_channels, num_frames, height, width).
        max_num_preserved_sketch_frames: int, maximum number of preserved sketch frames.
            Clamped up to the minimum when smaller (or None).
        max_num_preserved_image_frames: int, maximum number of preserved image frames,
            counting frame 0.
        min_num_preserved_sketch_frames: int, minimum number of preserved sketch frames
            (raised to 2 when smaller or None). The drawn count is capped at num_frames.
        min_num_preserved_image_frames: int, minimum number of preserved image frames,
            counting frame 0.
        model_name: str, name of the sketch model. If None, randomly select from
            ["lineart", "lineart_anime", "anime2sketch"] (unless FLUX sketches are used).
        detach_image_and_sketch: bool. If True, sketches are written to separate
            ``masked_condition_video_sketch`` / ``preserved_condition_mask_sketch``
            tensors and marked with 1.0. If False, sketches share
            ``masked_condition_video`` with the images and are marked with -1.0 in
            ``preserved_condition_mask``; frame 0 keeps its image.
        equally_spaced_preserve_sketch: bool, whether to preserve sketches at equally
            spaced intervals. Default: False.
        apply_sketch_mask: bool, whether to apply random local masking to sketch frames.
        sketch_mask_ratio: float, probability (0-1) that local masking is applied to
            this sample during training. Default: 0.2.
        sketch_mask_shape: str, shape type for masking ('square', 'circle', 'random').
            If None, randomly selected.
        no_first_sketch: bool, or float probability, to drop the frame-0 sketch.
        video_clip_names: Clip names; element [0] names the FLUX sketch directory.
        is_flux_sketch_available: Sequence whose element [0] says whether FLUX
            sketches exist for this clip.
        is_evaluation: bool. Disables random local masking and always uses FLUX
            sketches when they are available.
        flux_sketch_root: Root directory of the pre-rendered FLUX line-art files.

    Returns:
        Tuple of:
        - masked_condition_video: shape=(batch_size, num_channels, num_frames, height, width).
          Image frames are copied at the preserved image indices; sketches too when
          not detached.
        - preserved_condition_mask: shape=(batch_size, num_frames); 1.0 at image frames,
          -1.0 at sketch frames when not detached, 0 elsewhere.
        - masked_condition_video_sketch: same shape as masked_condition_video, or None
          when not detached.
        - preserved_condition_mask_sketch: shape=(batch_size, num_frames), or None when
          not detached.
        - full_sketch_frames: shape=(batch_size, num_channels, num_frames, height, width).
          Inverted (black-background) sketches are stored at the preserved sketch
          indices and zeros elsewhere.
        - sketch_local_mask: shape=(batch_size, 1, num_frames, height, width), where 0
          marks masked regions, or None if apply_sketch_mask=False.
        - cur_model_name: str, sketch source used ("flux" or a sketch-pool model name).
    """
    video_shape = original_video.shape
    video_dtype = original_video.dtype
    video_device = original_video.device
    num_frames = video_shape[2]
    last_frame = num_frames - 1

    # ---- Number of sketch frames -------------------------------------------
    if min_num_preserved_sketch_frames is None or min_num_preserved_sketch_frames < 2:
        min_num_preserved_sketch_frames = 2  # Minimum num: 2 (the first and the last)
    if max_num_preserved_sketch_frames is None or max_num_preserved_sketch_frames < min_num_preserved_sketch_frames:
        # random.randint raises ValueError when max < min; clamp instead.
        max_num_preserved_sketch_frames = min_num_preserved_sketch_frames
    num_preserved_sketch_frames = random.randint(min_num_preserved_sketch_frames, max_num_preserved_sketch_frames)
    num_preserved_sketch_frames = min(num_preserved_sketch_frames, num_frames)

    # A FLUX sketch can only be located when both the clip name and the
    # availability flag are given; such clips keep only the first and last sketch.
    flux_available = bool(
        video_clip_names is not None
        and is_flux_sketch_available is not None
        and is_flux_sketch_available[0]
    )
    if flux_available:
        num_preserved_sketch_frames = min(2, num_frames)

    if isinstance(no_first_sketch, float):
        # A float is the probability of dropping the frame-0 sketch.
        no_first_sketch = random.random() < no_first_sketch

    # ---- Sketch frame indices ----------------------------------------------
    if equally_spaced_preserve_sketch:
        preserved_sketch_indices = torch.linspace(0, last_frame, num_preserved_sketch_frames).long().tolist()
        if no_first_sketch:
            preserved_sketch_indices = preserved_sketch_indices[1:]
    else:
        # Always include the last frame, and the first one unless no_first_sketch.
        if no_first_sketch:
            preserved_sketch_indices = [last_frame]
        else:
            preserved_sketch_indices = sorted({0, last_frame})
        # If we need more frames than just first and last
        if num_preserved_sketch_frames > 2 and num_frames > 4:
            # Candidates exclude the first, last and their adjacent frames.
            candidates = set(range(2, num_frames - 2))
            additional_frames_needed = min(num_preserved_sketch_frames - 2, len(candidates))

            # Pick non-adjacent frames: each pick also removes its neighbours.
            additional_indices = []
            while len(additional_indices) < additional_frames_needed and candidates:
                idx = random.choice(list(candidates))
                additional_indices.append(idx)
                candidates.discard(idx)
                candidates.discard(idx - 1)
                candidates.discard(idx + 1)

            preserved_sketch_indices.extend(additional_indices)
            preserved_sketch_indices.sort()

    # ---- Image frame indices (frame 0 is always an image condition) ---------
    # Later code does not care about the counts, only about the indices.
    preserved_image_indices = [0]
    if max_num_preserved_image_frames is not None and max_num_preserved_image_frames > 1:
        max_extra = max_num_preserved_image_frames - 1
        if min_num_preserved_image_frames is None or min_num_preserved_image_frames < 1:
            min_num_preserved_image_frames = 1
        min_extra = min_num_preserved_image_frames - 1
        other_indices = torch.tensor([i for i in range(num_frames) if i not in preserved_sketch_indices])
        max_extra = min(max_extra, len(other_indices))
        min_extra = min(min_extra, max_extra)
        num_extra = random.randint(min_extra, max_extra)
        other_indices = other_indices[torch.randperm(len(other_indices))]
        if num_extra > 0:
            # Plain ints rather than 0-d tensors.
            preserved_image_indices.extend(other_indices[:num_extra].tolist())

    # ---- Output buffers -----------------------------------------------------
    preserved_condition_mask = torch.zeros(size=(video_shape[0], num_frames), dtype=video_dtype, device=video_device)  # [b, t]
    masked_condition_video = torch.zeros_like(original_video)  # [b, c, t, h, w]
    full_sketch_frames = torch.zeros_like(original_video)  # [b, c, t, h, w]

    if detach_image_and_sketch:
        preserved_condition_mask_sketch = torch.zeros_like(preserved_condition_mask)
        masked_condition_video_sketch = torch.zeros_like(masked_condition_video)
        if 0 not in preserved_sketch_indices and not no_first_sketch:
            preserved_sketch_indices.append(0)
    else:
        preserved_condition_mask_sketch = None
        masked_condition_video_sketch = None

    for idx in preserved_image_indices:
        preserved_condition_mask[:, idx] = 1.0
        masked_condition_video[:, :, idx, :, :] = original_video[:, :, idx, :, :]

    # ---- Optional local masking of sketch frames ----------------------------
    sketch_local_mask = None
    if apply_sketch_mask:
        # All ones (unmasked), shape [b, 1, t, h, w].
        sketch_local_mask = torch.ones(
            video_shape[0], 1, num_frames, video_shape[3], video_shape[4],
            device=video_device,
            dtype=video_dtype,
        )
        if not is_evaluation and random.random() < sketch_mask_ratio:
            for position, frame_idx in enumerate(preserved_sketch_indices):
                # The first preserved sketch (frame 0 unless no_first_sketch) is never masked.
                if position == 0:
                    continue
                sketch_local_mask[:, :, frame_idx:frame_idx + 1, :, :] = create_random_mask(
                    batch_size=video_shape[0],
                    num_frames=1,  # Just one frame at a time
                    height=video_shape[3],
                    width=video_shape[4],
                    device=video_device,
                    dtype=video_dtype,
                    shape_type=sketch_mask_shape,
                )

    # ---- Produce sketches ---------------------------------------------------
    # Sketches either come from pre-rendered FLUX files or from the sketch pool.
    if flux_available:
        should_use_flux_sketch = True if is_evaluation else random.random() < 0.75
    else:
        should_use_flux_sketch = False

    if should_use_flux_sketch:
        cur_model_name = "flux"
    elif model_name is None:
        cur_model_name = random.choice(["lineart", "lineart_anime", "anime2sketch"])
    else:
        cur_model_name = model_name

    for idx in preserved_sketch_indices:
        sketch_frame = None
        if should_use_flux_sketch:
            sketch_frame = _load_flux_sketch(
                f"{flux_sketch_root}/{video_clip_names[0]}/{idx}.lineart.png",
                video_shape[3], video_shape[4], video_device, video_dtype,
            )
        if sketch_frame is None:
            # NOTE(review): on a FLUX fallback cur_model_name is "flux"; confirm the
            # sketch pool accepts that name.
            sketch_frame = extract_img_to_sketch(
                sketch_model, original_video[:, :, idx, :, :].float(),
                model_name=cur_model_name).to(video_device, dtype=video_dtype)
        # Convert white BG (from sketch pool or loaded from flux sketch files) to black BG (for training)
        full_sketch_frames[:, :, idx, :, :] = -torch.clip(sketch_frame, -1, 1)

    # ---- Write sketches into the conditioning tensors -----------------------
    if detach_image_and_sketch:
        mask_to_add, video_to_add, mask_value = preserved_condition_mask_sketch, masked_condition_video_sketch, 1.0
        sketch_condition_indices = preserved_sketch_indices
    else:
        mask_to_add, video_to_add, mask_value = preserved_condition_mask, masked_condition_video, -1.0
        # Frame 0 holds the image condition in the shared layout; skip it explicitly
        # (slicing off the first entry would drop a real sketch when no_first_sketch).
        sketch_condition_indices = [i for i in preserved_sketch_indices if i != 0]

    for idx in sketch_condition_indices:
        mask_to_add[:, idx] = mask_value
        sketch_frame = full_sketch_frames[:, :, idx, :, :]
        if sketch_local_mask is not None:
            # Masked-out regions (mask == 0) become -1.0, the background after inversion.
            sketch_frame = torch.where(sketch_local_mask[:, 0:1, idx, :, :] == 0, -1.0, sketch_frame)
        video_to_add[:, :, idx, :, :] = sketch_frame

    return (
        masked_condition_video,
        preserved_condition_mask,
        masked_condition_video_sketch,
        preserved_condition_mask_sketch,
        full_sketch_frames,
        sketch_local_mask,
        cur_model_name,
    )