""" Functions for compuing Procrustes alignment and reconstruction error Parts of the code are adapted from https://github.com/akanazawa/hmr """ from __future__ import absolute_import from __future__ import division from __future__ import print_function import numpy as np def compute_similarity_transform(S1, S2): """Computes a similarity transform (sR, t) that takes a set of 3D points S1 (3 x N) closest to a set of 3D points S2, where R is an 3x3 rotation matrix, t 3x1 translation, s scale. i.e. solves the orthogonal Procrutes problem. """ transposed = False if S1.shape[0] != 3 and S1.shape[0] != 2: S1 = S1.T S2 = S2.T transposed = True assert(S2.shape[1] == S1.shape[1]) # 1. Remove mean. mu1 = S1.mean(axis=1, keepdims=True) mu2 = S2.mean(axis=1, keepdims=True) X1 = S1 - mu1 X2 = S2 - mu2 # 2. Compute variance of X1 used for scale. var1 = np.sum(X1**2) # 3. The outer product of X1 and X2. K = X1.dot(X2.T) # 4. Solution that Maximizes trace(R'K) is R=U*V', where U, V are # singular vectors of K. U, s, Vh = np.linalg.svd(K) V = Vh.T # Construct Z that fixes the orientation of R to get det(R)=1. Z = np.eye(U.shape[0]) Z[-1, -1] *= np.sign(np.linalg.det(U.dot(V.T))) # Construct R. R = V.dot(Z.dot(U.T)) # 5. Recover scale. scale = np.trace(R.dot(K)) / var1 # 6. Recover translation. t = mu2 - scale*(R.dot(mu1)) # 7. Error: S1_hat = scale*R.dot(S1) + t if transposed: S1_hat = S1_hat.T return S1_hat def compute_similarity_transform_batch(S1, S2): """Batched version of compute_similarity_transform.""" S1_hat = np.zeros_like(S1) for i in range(S1.shape[0]): S1_hat[i] = compute_similarity_transform(S1[i], S2[i]) return S1_hat def reconstruction_error(S1, S2, reduction='mean'): """Do Procrustes alignment and compute reconstruction error.""" S1_hat = compute_similarity_transform_batch(S1, S2) re = np.sqrt( ((S1_hat - S2)** 2).sum(axis=-1)).mean(axis=-1) if reduction == 'mean': re = re.mean() elif reduction == 'sum': re = re.sum() return re def reconstruction_error_v2(S1, S2, J24_TO_J14, reduction='mean'): """Do Procrustes alignment and compute reconstruction error.""" S1_hat = compute_similarity_transform_batch(S1, S2) S1_hat = S1_hat[:,J24_TO_J14,:] S2 = S2[:,J24_TO_J14,:] re = np.sqrt( ((S1_hat - S2)** 2).sum(axis=-1)).mean(axis=-1) if reduction == 'mean': re = re.mean() elif reduction == 'sum': re = re.sum() return re def get_alignMesh(S1, S2, reduction='mean'): """Do Procrustes alignment and compute reconstruction error.""" S1_hat = compute_similarity_transform_batch(S1, S2) re = np.sqrt( ((S1_hat - S2)** 2).sum(axis=-1)).mean(axis=-1) if reduction == 'mean': re = re.mean() elif reduction == 'sum': re = re.sum() return re, S1_hat, S2