import spaces

import numpy as np
import cv2
import torch
from typing import Union

def contour_to_points_and_mask(contour: np.ndarray, image_shape: tuple) -> tuple[np.ndarray, np.ndarray]:
    """Convert a contour to a set of points and binary mask.
    
    This function takes a contour and creates both a binary mask and a list of points
    that lie within the contour. The points are represented in (x, y) coordinates.
    
    Args:
        contour (np.ndarray): Input contour of shape (N, 2) or (N, 1, 2) where N is 
            the number of points. Each point should be in (x, y) format.
        image_shape (tuple): Shape of the output mask as (height, width).
            
    Returns:
        tuple:
            - np.ndarray: Array of points in (x, y) format with shape (M, 2),
              where M is the number of points inside the contour.
              Returns empty array of shape (0, 2) if contour is empty.
            - np.ndarray: Binary mask of shape image_shape where pixels inside
              the contour are 255 and outside are 0.
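
    Example:
        A minimal sketch with a small axis-aligned square (the expected
        counts assume OpenCV's inclusive, filled rasterization):

        >>> square = np.array([[1, 1], [3, 1], [3, 3], [1, 3]], dtype=np.int32)
        >>> points, mask = contour_to_points_and_mask(square, (5, 5))
        >>> mask.shape
        (5, 5)
        >>> len(points)  # 3x3 block of pixels, (1, 1) through (3, 3)
        9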
    """
    if len(contour) == 0:
        return np.zeros((0, 2), dtype=np.int32), np.zeros(image_shape, dtype=np.uint8)
    
    # Create empty mask and fill the contour in the mask
    mask = np.zeros(image_shape, dtype=np.uint8)
    cv2.drawContours(mask, [contour.reshape(-1, 1, 2)], -1, 255, cv2.FILLED)
    
    # Get points inside contour (y, x) and convert to (x, y)
    points = np.column_stack(np.where(mask)).astype(np.int32)[:, [1, 0]]
    
    # Return empty array if no points found
    if len(points) == 0:
        points = np.zeros((0, 2), dtype=np.int32)
        
    return points, mask

def find_control_points(
    region_points: torch.Tensor,
    source_control_points: torch.Tensor,
    target_control_points: torch.Tensor,
    distance_threshold: float = 1e-6
) -> tuple[torch.Tensor, torch.Tensor]:
    """Find control points that match points within a region.

    This function identifies which control points lie within or very close to
    the specified region points. It matches source control points to region points
    and returns both source and corresponding target control points that satisfy
    the distance threshold criterion.

    Args:
        region_points (torch.Tensor): Points defining a region, shape (N, 2).
            Each point is in (x, y) format.
        source_control_points (torch.Tensor): Source control points, shape (M, 2).
            Each point is in (x, y) format.
        target_control_points (torch.Tensor): Target control points, shape (M, 2).
            Must have same first dimension as source_control_points.
        distance_threshold (float, optional): Maximum distance for a point to be
            considered matching. Defaults to 1e-6.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: 
            - Matched source control points, shape (K, 2) where K ≤ M
            - Corresponding target control points, shape (K, 2)
            If no matches found or inputs empty, returns empty tensors of shape (0, 2)
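
    Example:
        A minimal sketch: only the source point that coincides with a region
        point is matched, together with its paired target point:

        >>> region = torch.tensor([[0., 0.], [1., 0.]])
        >>> src = torch.tensor([[1., 0.], [5., 5.]])
        >>> tgt = torch.tensor([[2., 0.], [6., 6.]])
        >>> matched_src, matched_tgt = find_control_points(region, src, tgt)
        >>> matched_src.tolist(), matched_tgt.tolist()
        ([[1.0, 0.0]], [[2.0, 0.0]])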
    """
    # Handle empty input cases
    if len(region_points) == 0 or len(source_control_points) == 0:
        return (
            torch.zeros((0, 2), device=source_control_points.device),
            torch.zeros((0, 2), device=target_control_points.device)
        )

    # Calculate pairwise distances between source control points and region points
    distances = torch.cdist(source_control_points, region_points)
    
    # Find points that are within threshold distance of any region point
    min_distances = distances.min(dim=1)[0]
    matching_indices = min_distances < distance_threshold

    # Return matched pairs of control points
    return source_control_points[matching_indices], target_control_points[matching_indices]

def interpolate_points_with_weighted_directions(
    points: torch.Tensor,
    reference_points: torch.Tensor,
    direction_vectors: torch.Tensor,
    max_reference_points: int = 100,
    num_nearest_neighbors: int = 4,
    eps: float = 1e-6
) -> torch.Tensor:
    """Interpolate points based on weighted directions from nearest reference points.
    
    This function moves each point by a weighted combination of direction vectors.
    The weights are determined by the inverse distances to the nearest reference points.
    If there are too many reference points, they are subsampled for efficiency.

    Args:
        points (torch.Tensor): Points to interpolate, shape (N, 2) in (x, y) format
        reference_points (torch.Tensor): Reference point locations, shape (M, 2)
        direction_vectors (torch.Tensor): Direction vectors for each reference point,
            shape (M, 2), must match reference_points first dimension
        max_reference_points (int, optional): Maximum number of reference points to use.
            If exceeded, points are subsampled. Defaults to 100.
        num_nearest_neighbors (int, optional): Number of nearest neighbors to consider
            for interpolation. Defaults to 4.
        eps (float, optional): Small value to avoid division by zero. Defaults to 1e-6.

    Returns:
        torch.Tensor: Interpolated points with shape (N, 2). If input points or
            references are empty, returns the input points unchanged.
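
    Example:
        A minimal sketch with a single reference point, in which case every
        input point is simply shifted by that point's direction vector:

        >>> pts = torch.tensor([[0., 0.], [2., 2.]])
        >>> refs = torch.tensor([[1., 1.]])
        >>> dirs = torch.tensor([[3., 0.]])
        >>> interpolate_points_with_weighted_directions(pts, refs, dirs).tolist()
        [[3.0, 0.0], [5.0, 2.0]]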
    """
    # Handle empty input cases
    if len(points) == 0 or len(reference_points) == 0:
        return points
    
    # Handle single reference point case
    if len(reference_points) == 1:
        return points + direction_vectors

    # Subsample reference points if too many
    if len(reference_points) > max_reference_points:
        indices = torch.linspace(0, len(reference_points)-1, max_reference_points).long()
        reference_points = reference_points[indices]
        direction_vectors = direction_vectors[indices]

    # Calculate distances to all reference points
    distances = torch.cdist(points, reference_points)
    
    # Find k nearest neighbors (k = min(num_nearest_neighbors, num_references))
    k = min(num_nearest_neighbors, len(reference_points))
    topk_distances, neighbor_indices = torch.topk(
        distances, 
        k=k, 
        dim=1, 
        largest=False
    )
    
    # Calculate weights based on inverse distances
    weights = 1.0 / (topk_distances + eps)
    weights = weights / weights.sum(dim=1, keepdim=True)
    
    # Get directions for nearest neighbors and compute weighted average
    neighbor_directions = direction_vectors[neighbor_indices]
    weighted_directions = (weights.unsqueeze(-1) * neighbor_directions).sum(dim=1)
    
    # Apply weighted directions and round to nearest integer
    interpolated_points = (points + weighted_directions).round().float()
    
    return interpolated_points

def get_points_within_image_bounds(
    points: torch.Tensor,
    image_shape: tuple[int, int]
) -> torch.Tensor:
    """Create a boolean mask for points that lie within image boundaries.

    Identifies which points from the input tensor fall within valid image coordinates.
    Points are assumed to be in (x, y) format, while image_shape is in (height, width) format.

    Args:
        points (torch.Tensor): Points to check, shape (N, 2) in (x, y) format.
            x coordinates correspond to width/columns
            y coordinates correspond to height/rows
        image_shape (tuple[int, int]): Image dimensions as (height, width).

    Returns:
        torch.Tensor: Boolean mask of shape (N,) where True indicates the point
            is within bounds. Returns empty tensor of shape (0,) if input is empty.
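
    Example:
        A minimal sketch on a 4x4 image; (3, 3) is in bounds, while (4, 0)
        is not because x must be strictly less than the width:

        >>> pts = torch.tensor([[3., 3.], [4., 0.], [-1., 2.]])
        >>> get_points_within_image_bounds(pts, (4, 4)).tolist()
        [True, False, False]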
    """
    # Handle empty input case
    if len(points) == 0:
        return torch.zeros(0, dtype=torch.bool, device=points.device)

    # Unpack image dimensions
    height, width = image_shape

    # Check both x and y coordinates are within bounds
    x_in_bounds = (points[:, 0] >= 0) & (points[:, 0] < width)
    y_in_bounds = (points[:, 1] >= 0) & (points[:, 1] < height)
    
    # Combine conditions
    valid_points_mask = x_in_bounds & y_in_bounds

    return valid_points_mask

@spaces.GPU
def bi_warp(
    region_mask: np.ndarray,
    control_points: Union[np.ndarray, torch.Tensor],
    kernel_size: int = 5
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Generate corresponding source/target points and inpainting mask for masked regions.
    
    Args:
        region_mask: Binary mask defining regions of interest (2D array;
                    any nonzero pixel is treated as inside a region)
        control_points: Alternating source and target control points, shape (N*2, 2).
                    Even indices are source points, odd indices their targets.
        kernel_size: Thickness of the drawn target contour and size of the
                    dilation kernel applied to the inpainting mask. Should be
                    odd; even values are rounded up to the next odd number
                    (default: 5). Set to 0 to skip contour drawing and dilation.
    
    Returns:
        tuple containing:
            - Source points (M, 2)
            - Target points (M, 2) 
            - Inpainting mask combined with target contour mask
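
    Example:
        An illustrative sketch (falls back to CPU when CUDA is unavailable):
        a filled square is shifted right by 4 pixels via a single
        source/target control-point pair placed inside the region:

        >>> mask = np.zeros((32, 32), dtype=np.uint8)
        >>> mask[8:16, 8:16] = 1
        >>> ctrl = np.array([[10, 10], [14, 10]])  # source point, then target
        >>> src, tgt, inpaint = bi_warp(mask, ctrl)
        >>> src.shape[1], tgt.shape[1], inpaint.shape
        (2, 2, (32, 32))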
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    image_shape = region_mask.shape
    
    # Ensure kernel_size is odd or 0
    kernel_size = max(0, kernel_size)
    if kernel_size > 0 and kernel_size % 2 == 0:
        kernel_size += 1

    # 1. Initialize tensors and masks
    # Move control points to the working device as float32 (handles both
    # np.ndarray inputs and torch.Tensor inputs that may live on the CPU)
    if isinstance(control_points, torch.Tensor):
        control_points = control_points.to(device=device, dtype=torch.float32)
    else:
        control_points = torch.tensor(control_points, dtype=torch.float32, device=device)
    source_control_points = control_points[0::2]
    target_control_points = control_points[1::2]
    
    combined_source_mask = np.zeros(image_shape, dtype=np.uint8)
    combined_target_mask = np.zeros(image_shape, dtype=np.uint8)
    region_mask_binary = np.where(region_mask > 0, 1, 0).astype(np.uint8)
    contour_mask = np.zeros(image_shape, dtype=np.uint8)

    # 2. Process regions
    # cv2.findContours returns (contours, hierarchy) in OpenCV >= 4
    contours = cv2.findContours(region_mask_binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
    all_source_points = []
    all_target_points = []

    for contour in contours:
        if len(contour) == 0:
            continue

        # 3. Get source region points and mask
        source_contour = torch.from_numpy(contour[:, 0, :]).float().to(device)
        source_region_points, source_mask = contour_to_points_and_mask(contour[:, 0, :], image_shape)
        source_mask = (source_mask > 0).astype(np.uint8)
        
        if len(source_region_points) == 0:
            continue
            
        source_region_points = torch.from_numpy(source_region_points).float().to(device)

        # 4. Transform points
        source, target = find_control_points(source_region_points, source_control_points, target_control_points)
        if len(source) == 0:
            continue

        directions = target - source
        target_contour = interpolate_points_with_weighted_directions(source_contour, source, directions)
        interpolated_target = interpolate_points_with_weighted_directions(source_region_points, source, directions)

        # 5. Get target region points and mask
        target_contour_np = target_contour.cpu().int().numpy()
        target_region_points, target_mask = contour_to_points_and_mask(target_contour_np, image_shape)
        target_mask = (target_mask > 0).astype(np.uint8)
        
        if len(target_region_points) == 0:
            continue
            
        # Draw the (thick) target contour so it is included in the final mask
        if kernel_size > 0:
            cv2.drawContours(contour_mask, [target_contour_np], -1, 1, kernel_size)
            
        target_region = torch.from_numpy(target_region_points).float().to(device)

        # 6. Apply reverse transformation
        back_directions = source_region_points - interpolated_target
        interpolated_source = interpolate_points_with_weighted_directions(target_region, interpolated_target, back_directions)

        # 7. Filter valid points
        valid_mask = get_points_within_image_bounds(interpolated_source, image_shape)
        if valid_mask.any():
            all_source_points.append(interpolated_source[valid_mask])
            all_target_points.append(target_region[valid_mask])
            combined_source_mask = np.logical_or(combined_source_mask, source_mask).astype(np.uint8)
            combined_target_mask = np.logical_or(combined_target_mask, target_mask).astype(np.uint8)

    # 8. Handle empty case
    if not all_source_points:
        return np.zeros((0, 2), dtype=np.int32), np.zeros((0, 2), dtype=np.int32), np.zeros(image_shape, dtype=np.uint8)

    # 9. Finalize outputs
    final_source = torch.cat(all_source_points).cpu().numpy().astype(np.int32)
    final_target = torch.cat(all_target_points).cpu().numpy().astype(np.int32)
    
    # Create and combine masks
    inpaint_mask = np.logical_and(combined_source_mask, np.logical_not(combined_target_mask)).astype(np.uint8)
    if kernel_size > 0:
        kernel = np.ones((kernel_size, kernel_size), dtype=np.uint8)
        inpaint_mask = cv2.dilate(inpaint_mask, kernel)
    final_mask = np.logical_or(inpaint_mask, contour_mask).astype(np.uint8)

    return final_source, final_target, final_mask
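
if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch, not part of the module's
    # public API): warp a filled square 4 pixels to the right using a single
    # source/target control-point pair. All values here are arbitrary demo
    # assumptions; without CUDA the computation falls back to CPU.
    demo_mask = np.zeros((64, 64), dtype=np.uint8)
    demo_mask[16:32, 16:32] = 1
    demo_controls = np.array([[24, 24], [28, 24]])  # source point, then target
    src_pts, tgt_pts, inpaint_mask = bi_warp(demo_mask, demo_controls)
    print(f"matched pairs: {len(src_pts)}, inpaint area: {int(inpaint_mask.sum())}")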