import spaces import numpy as np import cv2 import torch from typing import Union def contour_to_points_and_mask(contour: np.ndarray, image_shape: tuple) -> tuple[np.ndarray, np.ndarray]: """Convert a contour to a set of points and binary mask. This function takes a contour and creates both a binary mask and a list of points that lie within the contour. The points are represented in (x, y) coordinates. Args: contour (np.ndarray): Input contour of shape (N, 2) or (N, 1, 2) where N is the number of points. Each point should be in (x, y) format. image_shape (tuple): Shape of the output mask as (height, width). Returns: tuple: - np.ndarray: Array of points in (x, y) format with shape (M, 2), where M is the number of points inside the contour. Returns empty array of shape (0, 2) if contour is empty. - np.ndarray: Binary mask of shape image_shape where pixels inside the contour are 255 and outside are 0. """ if len(contour) == 0: return np.zeros((0, 2), dtype=np.int32), np.zeros(image_shape, dtype=np.uint8) # Create empty mask and fill the contour in the mask mask = np.zeros(image_shape, dtype=np.uint8) cv2.drawContours(mask, [contour.reshape(-1, 1, 2)], -1, 255, cv2.FILLED) # Get points inside contour (y, x) and convert to (x, y) points = np.column_stack(np.where(mask)).astype(np.int32)[:, [1, 0]] # Return empty array if no points found if len(points) == 0: points = np.zeros((0, 2), dtype=np.int32) return points, mask def find_control_points( region_points: torch.Tensor, source_control_points: torch.Tensor, target_control_points: torch.Tensor, distance_threshold: float = 1e-6 ) -> tuple[torch.Tensor, torch.Tensor]: """Find control points that match points within a region. This function identifies which control points lie within or very close to the specified region points. It matches source control points to region points and returns both source and corresponding target control points that satisfy the distance threshold criterion. Args: region_points (torch.Tensor): Points defining a region, shape (N, 2). Each point is in (x, y) format. source_control_points (torch.Tensor): Source control points, shape (M, 2). Each point is in (x, y) format. target_control_points (torch.Tensor): Target control points, shape (M, 2). Must have same first dimension as source_control_points. distance_threshold (float, optional): Maximum distance for a point to be considered matching. Defaults to 1e-6. Returns: tuple[torch.Tensor, torch.Tensor]: - Matched source control points, shape (K, 2) where K ≤ M - Corresponding target control points, shape (K, 2) If no matches found or inputs empty, returns empty tensors of shape (0, 2) """ # Handle empty input cases if len(region_points) == 0 or len(source_control_points) == 0: return ( torch.zeros((0, 2), device=source_control_points.device), torch.zeros((0, 2), device=target_control_points.device) ) # Calculate pairwise distances between source control points and region points distances = torch.cdist(source_control_points, region_points) # Find points that are within threshold distance of any region point min_distances = distances.min(dim=1)[0] matching_indices = min_distances < distance_threshold # Return matched pairs of control points return source_control_points[matching_indices], target_control_points[matching_indices] def interpolate_points_with_weighted_directions( points: torch.Tensor, reference_points: torch.Tensor, direction_vectors: torch.Tensor, max_reference_points: int = 100, num_nearest_neighbors: int = 4, eps: float = 1e-6 ) -> torch.Tensor: """Interpolate points based on weighted directions from nearest reference points. This function moves each point by a weighted combination of direction vectors. The weights are determined by the inverse distances to the nearest reference points. If there are too many reference points, they are subsampled for efficiency. Args: points (torch.Tensor): Points to interpolate, shape (N, 2) in (x, y) format reference_points (torch.Tensor): Reference point locations, shape (M, 2) direction_vectors (torch.Tensor): Direction vectors for each reference point, shape (M, 2), must match reference_points first dimension max_reference_points (int, optional): Maximum number of reference points to use. If exceeded, points are subsampled. Defaults to 100. num_nearest_neighbors (int, optional): Number of nearest neighbors to consider for interpolation. Defaults to 4. eps (float, optional): Small value to avoid division by zero. Defaults to 1e-6. Returns: torch.Tensor: Interpolated points with shape (N, 2). If input points or references are empty, returns the input points unchanged. """ # Handle empty input cases if len(points) == 0 or len(reference_points) == 0: return points # Handle single reference point case if len(reference_points) == 1: return points + direction_vectors # Subsample reference points if too many if len(reference_points) > max_reference_points: indices = torch.linspace(0, len(reference_points)-1, max_reference_points).long() reference_points = reference_points[indices] direction_vectors = direction_vectors[indices] # Calculate distances to all reference points distances = torch.cdist(points, reference_points) # Find k nearest neighbors (k = min(num_nearest_neighbors, num_references)) k = min(num_nearest_neighbors, len(reference_points)) topk_distances, neighbor_indices = torch.topk( distances, k=k, dim=1, largest=False ) # Calculate weights based on inverse distances weights = 1.0 / (topk_distances + eps) weights = weights / weights.sum(dim=1, keepdim=True) # Get directions for nearest neighbors and compute weighted average neighbor_directions = direction_vectors[neighbor_indices] weighted_directions = (weights.unsqueeze(-1) * neighbor_directions).sum(dim=1) # Apply weighted directions and round to nearest integer interpolated_points = (points + weighted_directions).round().float() return interpolated_points def get_points_within_image_bounds( points: torch.Tensor, image_shape: tuple[int, int] ) -> torch.Tensor: """Create a boolean mask for points that lie within image boundaries. Identifies which points from the input tensor fall within valid image coordinates. Points are assumed to be in (x, y) format, while image_shape is in (height, width) format. Args: points (torch.Tensor): Points to check, shape (N, 2) in (x, y) format. x coordinates correspond to width/columns y coordinates correspond to height/rows image_shape (tuple[int, int]): Image dimensions as (height, width). Returns: torch.Tensor: Boolean mask of shape (N,) where True indicates the point is within bounds. Returns empty tensor of shape (0,) if input is empty. """ # Handle empty input case if len(points) == 0: return torch.zeros(0, dtype=torch.bool, device=points.device) # Unpack image dimensions height, width = image_shape # Check both x and y coordinates are within bounds x_in_bounds = (points[:, 0] >= 0) & (points[:, 0] < width) y_in_bounds = (points[:, 1] >= 0) & (points[:, 1] < height) # Combine conditions valid_points_mask = x_in_bounds & y_in_bounds return valid_points_mask @spaces.GPU def bi_warp( region_mask: np.ndarray, control_points: Union[np.ndarray, torch.Tensor], kernel_size: int = 5 ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: """Generate corresponding source/target points and inpainting mask for masked regions. Args: region_mask: Binary mask defining regions of interest (2D array with 0s and 1s) control_points: Alternating source and target control points. Shape (N*2, 2) kernel_size: Controls dilation kernel size. Must be odd number or 0. Contour thickness will be (kernel_size-1)*2 (default: 5) Set to 0 for no contour drawing and no dilation. Returns: tuple containing: - Source points (M, 2) - Target points (M, 2) - Inpainting mask combined with target contour mask """ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') image_shape = region_mask.shape # Ensure kernel_size is odd or 0 kernel_size = max(0, kernel_size) if kernel_size > 0 and kernel_size % 2 == 0: kernel_size += 1 # 1. Initialize tensors and masks control_points = torch.tensor(control_points, dtype=torch.float32, device=device) if not isinstance(control_points, torch.Tensor) else control_points source_control_points = control_points[0:-1:2] target_control_points = control_points[1::2] combined_source_mask = np.zeros(image_shape, dtype=np.uint8) combined_target_mask = np.zeros(image_shape, dtype=np.uint8) region_mask_binary = np.where(region_mask > 0, 1, 0).astype(np.uint8) contour_mask = np.zeros(image_shape, dtype=np.uint8) # 2. Process regions contours = cv2.findContours(region_mask_binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0] all_source_points = [] all_target_points = [] for contour in contours: if len(contour) == 0: continue # 3. Get source region points and mask source_contour = torch.from_numpy(contour[:, 0, :]).float().to(device) source_region_points, source_mask = contour_to_points_and_mask(contour[:, 0, :], image_shape) source_mask = (source_mask > 0).astype(np.uint8) if len(source_region_points) == 0: continue source_region_points = torch.from_numpy(source_region_points).float().to(device) # 4. Transform points source, target = find_control_points(source_region_points, source_control_points, target_control_points) if len(source) == 0: continue directions = target - source target_contour = interpolate_points_with_weighted_directions(source_contour, source, directions) interpolated_target = interpolate_points_with_weighted_directions(source_region_points, source, directions) # 5. Get target region points and mask target_region_points, target_mask = contour_to_points_and_mask(target_contour.cpu().int().numpy(), image_shape) target_mask = (target_mask > 0).astype(np.uint8) if len(target_region_points) == 0: continue # Draw target contour target_contour_np = target_contour.cpu().int().numpy() if kernel_size > 0: cv2.drawContours(contour_mask, [target_contour_np], -1, 1, kernel_size) target_region = torch.from_numpy(target_region_points).float().to(device) # 6. Apply reverse transformation back_directions = source_region_points - interpolated_target interpolated_source = interpolate_points_with_weighted_directions(target_region, interpolated_target, back_directions) # 7. Filter valid points valid_mask = get_points_within_image_bounds(interpolated_source, image_shape) if valid_mask.any(): all_source_points.append(interpolated_source[valid_mask]) all_target_points.append(target_region[valid_mask]) combined_source_mask = np.logical_or(combined_source_mask, source_mask).astype(np.uint8) combined_target_mask = np.logical_or(combined_target_mask, target_mask).astype(np.uint8) # 8. Handle empty case if not all_source_points: return np.zeros((0, 2), dtype=np.int32), np.zeros((0, 2), dtype=np.int32), np.zeros(image_shape, dtype=np.uint8) # 9. Finalize outputs final_source = torch.cat(all_source_points).cpu().numpy().astype(np.int32) final_target = torch.cat(all_target_points).cpu().numpy().astype(np.int32) # Create and combine masks inpaint_mask = np.logical_and(combined_source_mask, np.logical_not(combined_target_mask)).astype(np.uint8) if kernel_size > 0: kernel = np.ones((kernel_size, kernel_size), dtype=np.uint8) inpaint_mask = cv2.dilate(inpaint_mask, kernel) final_mask = np.logical_or(inpaint_mask, contour_mask).astype(np.uint8) return final_source, final_target, final_mask