import torch


# from https://github.com/napsternxg/pytorch-practice/blob/master/Pytorch%20-%20MMD%20VAE.ipynb
def compute_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Return the pairwise Gaussian kernel matrix between rows of *x* and *y*.

    Entry ``(i, j)`` is ``exp(-mean_k((x[i, k] - y[j, k]) ** 2) / dim)``,
    where ``dim = x.size(1)``.

    Args:
        x: 2-D tensor of shape ``(x_size, dim)``.
        y: 2-D tensor of shape ``(y_size, dim)``.

    Returns:
        Tensor of shape ``(x_size, y_size)``.
    """
    x_size = x.size(0)
    y_size = y.size(0)
    dim = x.size(1)

    # Insert singleton axes so every row of x can be paired with every row of y.
    x = x.unsqueeze(1)  # (x_size, 1, dim)
    y = y.unsqueeze(0)  # (1, y_size, dim)
    tiled_x = x.expand(x_size, y_size, dim)
    tiled_y = y.expand(x_size, y_size, dim)

    # The squared distance is averaged over features and then divided by dim
    # again, i.e. the exponent is -||x_i - y_j||^2 / dim^2.
    kernel_input = (tiled_x - tiled_y).pow(2).mean(2) / float(dim)
    return torch.exp(-kernel_input)  # (x_size, y_size)


def compute_mmd(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Return the squared maximum mean discrepancy between samples *x* and *y*.

    Computes ``mean(k(x, x)) + mean(k(y, y)) - 2 * mean(k(x, y))`` with the
    kernel from :func:`compute_kernel`. The means run over all pairs,
    including each sample paired with itself.

    Args:
        x: 2-D tensor of shape ``(x_size, dim)``.
        y: 2-D tensor of shape ``(y_size, dim)``.

    Returns:
        Zero-dimensional tensor holding the MMD estimate.
    """
    x_kernel = compute_kernel(x, x)
    y_kernel = compute_kernel(y, y)
    xy_kernel = compute_kernel(x, y)
    mmd = x_kernel.mean() + y_kernel.mean() - 2 * xy_kernel.mean()
    return mmd