import torch
from scipy.spatial import Delaunay

def pca_reduce_to_2d(points):
    """Project points onto their top two principal components."""
    centered = points - points.mean(dim=0)
    # torch.svd is deprecated; torch.linalg.svd returns V^T directly
    _, _, vh = torch.linalg.svd(centered, full_matrices=False)
    return points @ vh[:2].T

@torch.no_grad()
def compute_delaunay(points):
    """Compute Delaunay triangulation of points"""
    points_2d = pca_reduce_to_2d(points)
    return Delaunay(points_2d.cpu().numpy()).simplices
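
# Illustrative sanity check (not part of the original module; it inlines the
# same PCA projection as pca_reduce_to_2d so it runs standalone): project four
# nearly coplanar 3-D points to 2-D and triangulate; scipy's Delaunay returns
# an (n_triangles, 3) array of vertex indices.
def _delaunay_example():
    import torch
    from scipy.spatial import Delaunay
    pts = torch.tensor([[0.0, 0.0, 0.0],
                        [1.0, 0.0, 0.0],
                        [0.0, 1.0, 0.0],
                        [1.0, 1.0, 0.1]])
    # Top-2 right singular vectors of the centered cloud span the best 2-D fit
    _, _, vh = torch.linalg.svd(pts - pts.mean(dim=0), full_matrices=False)
    tri = Delaunay((pts @ vh[:2].T).numpy()).simplices
    return tuple(tri.shape)
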
def compute_riemann_curvature_loss(points, simplices=None):
    """
    Calculate loss based on approximated Riemann curvature.
    
    The loss measures deviations from uniform metric tensors across simplices,
    which approximates variations in Riemann curvature.
    """
    if simplices is None:
        simplices = compute_delaunay(points)
        
    ideal_det = torch.tensor(1.0, device=points.device, dtype=points.dtype)
    
    # Process all simplices in parallel
    simplices_tensor = torch.as_tensor(simplices, device=points.device, dtype=torch.long)
    
    # Extract points that form each simplex
    simplex_points = points[simplices_tensor]
    
    # Calculate edge vectors from the first point of each simplex
    edges = simplex_points[:, 1:] - simplex_points[:, 0].unsqueeze(1)
    
    # Compute metric tensors (Gram matrices) for each simplex
    metric_tensors = torch.matmul(edges, edges.transpose(1, 2))
    
    # Calculate determinants (related to volume distortion)
    dets = torch.linalg.det(metric_tensors)
    
    # Penalize deviations from a constant (unit) determinant, skipping
    # degenerate simplices with non-positive Gram determinant
    valid_dets = dets[dets > 0]
    if valid_dets.numel() == 0:
        return torch.zeros((), device=points.device, dtype=points.dtype)
    return torch.mean((valid_dets - ideal_det)**2)
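
# Why the Gram determinant tracks volume distortion (illustrative sketch, not
# part of the original module): for a triangle with edge matrix E, the Gram
# matrix G = E @ E.T satisfies det(G) = (2 * area)^2, so a mesh of triangles
# of area 1/2 has det(G) = 1 everywhere and hence zero curvature loss above.
def _gram_det_identity_demo():
    import torch
    edges = torch.tensor([[1.0, 0.0], [0.0, 1.0]])  # right triangle, unit legs
    gram = edges @ edges.T
    det = torch.linalg.det(gram)
    area = 0.5 * torch.abs(torch.linalg.det(edges))  # triangle area from edges
    assert torch.isclose(det, (2.0 * area) ** 2)
    return float(det)
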


def compute_axis_align_loss(data):
    """ Encourage axis alignment by minimizing off-diagonal elements in the covariance matrix """
    n, d = data.shape
    centered_data = data - data.mean(dim=0)  # Center the data
    cov_matrix = (centered_data.T @ centered_data) / n  # Compute covariance matrix
    
    eye = torch.eye(d, device=data.device)
    return torch.nn.functional.smooth_l1_loss(cov_matrix, eye)  # Smooth L1 (Huber-style) distance to the identity
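
# Illustrative check (mirrors compute_axis_align_loss above so it runs
# standalone; the toy data is an assumption): an axis-aligned point cloud
# yields a smaller loss than the same cloud rotated onto a diagonal, since
# rotation introduces off-diagonal covariance.
def _axis_align_demo():
    import torch
    import torch.nn.functional as F

    def loss(data):
        n, d = data.shape
        centered = data - data.mean(dim=0)
        cov = (centered.T @ centered) / n
        return F.smooth_l1_loss(cov, torch.eye(d))

    axis_aligned = torch.tensor([[1.0, 0.0], [-1.0, 0.0]])
    diagonal = torch.tensor([[1.0, 1.0], [-1.0, -1.0]])
    return float(loss(axis_aligned)), float(loss(diagonal))
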


def compute_repulsion_loss(points):
    """Penalize crowding via the mean inverse distance over distinct point pairs."""
    dist_matrix = torch.cdist(points, points)
    # Drop the diagonal so points are not repelled from themselves
    mask = torch.eye(points.shape[0], device=points.device, dtype=torch.bool)
    non_diag = dist_matrix[~mask]
    return torch.mean(1.0 / (non_diag + 1e-8))
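
# Illustrative check (mirrors compute_repulsion_loss above so it runs
# standalone): moving two points closer together increases the repulsion,
# since the loss is the mean inverse pairwise distance over distinct pairs.
def _repulsion_demo():
    import torch

    def loss(points):
        dist = torch.cdist(points, points)
        mask = torch.eye(points.shape[0], dtype=torch.bool)
        return torch.mean(1.0 / (dist[~mask] + 1e-8))

    close = torch.tensor([[0.0, 0.0], [0.1, 0.0]])
    far = torch.tensor([[0.0, 0.0], [1.0, 0.0]])
    return float(loss(close)), float(loss(far))
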
    
def compute_boundary_loss(points, domain_min=-1, domain_max=1):
    """Penalize coordinates that fall outside [domain_min, domain_max]."""
    return torch.mean(torch.relu(domain_min - points)) + \
           torch.mean(torch.relu(points - domain_max))
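

# Illustrative check (mirrors compute_boundary_loss above so it runs
# standalone): points inside [-1, 1] incur zero penalty, while a coordinate
# past a boundary is penalized linearly by its overshoot, averaged over all
# coordinates.
def _boundary_demo():
    import torch

    def loss(points, lo=-1.0, hi=1.0):
        return torch.mean(torch.relu(lo - points)) + \
               torch.mean(torch.relu(points - hi))

    inside = torch.tensor([[0.0, 0.0]])
    outside = torch.tensor([[1.5, 0.0]])  # overshoots hi by 0.5 in one of two coords
    return float(loss(inside)), float(loss(outside))
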