import numpy as np
import torch
import torch.nn as nn
import clip
from tqdm import tqdm

from ..tools.losses import get_loss_function
from ..rotation2xyz import Rotation2xyz

loss_ce = nn.CrossEntropyLoss()
loss_mse = nn.MSELoss()
cosine_sim = nn.CosineSimilarity(dim=1, eps=1e-6)

class MOTIONCLIP(nn.Module):
    def __init__(self, encoder, decoder, device, lambdas, latent_dim, outputxyz,
                 pose_rep, glob, glob_rot, translation, jointstype, vertstrans,
                 clip_lambdas=None, **kwargs):
        super().__init__()

        self.encoder = encoder
        self.decoder = decoder

        self.outputxyz = outputxyz

        self.lambdas = lambdas
        # Default to None rather than {} to avoid sharing a mutable default dict
        # across instances.
        self.clip_lambdas = {} if clip_lambdas is None else clip_lambdas

        self.latent_dim = latent_dim
        self.pose_rep = pose_rep
        self.glob = glob
        self.glob_rot = glob_rot
        self.device = device
        self.translation = translation
        self.jointstype = jointstype
        self.vertstrans = vertstrans

        self.clip_model = kwargs['clip_model']
        self.clip_training = kwargs.get('clip_training', False)
        if self.clip_training and self.clip_model:
            # .train() propagates the flag to all submodules; assigning to the
            # .training attribute directly would not.
            self.clip_model.train()
        elif self.clip_model:
            # When not fine-tuning, CLIP must stay frozen in eval mode.
            assert not self.clip_model.training

        self.losses = list(self.lambdas) + ["mixed"]

        self.rotation2xyz = Rotation2xyz(device=self.device)
        self.param2xyz = {"pose_rep": self.pose_rep,
                          "glob_rot": self.glob_rot,
                          "glob": self.glob,
                          "jointstype": self.jointstype,
                          "translation": self.translation,
                          "vertstrans": self.vertstrans}

    def rot2xyz(self, x, mask, get_rotations_back=False, **kwargs):
        # Start from the model-wide conversion defaults; callers may override
        # any of them per call.
        kargs = self.param2xyz.copy()
        kargs.update(kwargs)
        return self.rotation2xyz(x, mask, get_rotations_back=get_rotations_back, **kargs)
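
    # Example (hypothetical kwargs): override a single conversion parameter for
    # one call, e.g. the joint type:
    #   xyz = self.rot2xyz(batch["output"], batch["mask"], jointstype="vertices")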

    def compute_loss(self, batch):
        # Weighted sum of the reconstruction-style losses.
        mixed_loss = 0.
        losses = {}
        for ltype, lam in self.lambdas.items():
            loss_function = get_loss_function(ltype)
            loss = loss_function(self, batch)
            mixed_loss += loss * lam
            losses[ltype] = loss.item()

        # CLIP alignment losses (text/image domains).
        mixed_clip_loss, clip_losses = self.compute_clip_losses(batch)

        mixed_loss_with_clip = mixed_loss + mixed_clip_loss
        losses.update(clip_losses)
        # Each aggregate may be a plain float (no contributing terms) or a tensor.
        losses["mixed_without_clip"] = mixed_loss if isinstance(mixed_loss, float) \
            else mixed_loss.item()
        losses["mixed_clip_only"] = mixed_clip_loss if isinstance(mixed_clip_loss, float) \
            else mixed_clip_loss.item()
        losses["mixed_with_clip"] = mixed_loss_with_clip if isinstance(mixed_loss_with_clip, float) \
            else mixed_loss_with_clip.item()

        return mixed_loss_with_clip, losses
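
    # compute_loss expects self.lambdas as a flat {loss_name: weight} dict whose
    # keys are known to get_loss_function; hypothetical example (names and
    # weights are illustrative, not prescribed by this module):
    #   lambdas = {"rc": 100.0, "rcxyz": 100.0, "vel": 100.0}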

    def compute_clip_losses(self, batch):
        mixed_clip_loss = 0.
        clip_losses = {}

        if self.clip_training:
            # Fine-tuning CLIP: gradients flow through the CLIP encoders.
            # self.clip_training is an underscore-separated list of domains,
            # e.g. 'text' or 'image_text'.
            for d in self.clip_training.split('_'):
                if d == 'image':
                    features = self.clip_model.encode_image(batch['clip_images']).float()
                elif d == 'text':
                    texts = clip.tokenize(batch['clip_text']).to(self.device)
                    features = self.clip_model.encode_text(texts).float()
                else:
                    raise ValueError(f'Invalid clip domain [{d}]')

                features_norm = features / features.norm(dim=-1, keepdim=True)
                seq_motion_features_norm = batch["z"] / batch["z"].norm(dim=-1, keepdim=True)
                logit_scale = self.clip_model.logit_scale.exp()
                logits_per_motion = logit_scale * seq_motion_features_norm @ features_norm.t()
                logits_per_d = logits_per_motion.t()

                # Matching motion/feature pairs sit on the diagonal.
                batch_size = batch['x'].shape[0]
                ground_truth = torch.arange(batch_size, dtype=torch.long, device=self.device)

                ce_from_motion_loss = loss_ce(logits_per_motion, ground_truth)
                ce_from_d_loss = loss_ce(logits_per_d, ground_truth)
                clip_mixed_loss = (ce_from_motion_loss + ce_from_d_loss) / 2.

                clip_losses[f'{d}_ce_from_d'] = ce_from_d_loss.item()
                clip_losses[f'{d}_ce_from_motion'] = ce_from_motion_loss.item()
                clip_losses[f'{d}_mixed_ce'] = clip_mixed_loss.item()
                mixed_clip_loss += clip_mixed_loss
        else:
            # CLIP frozen: only the motion encoder is trained to align with it.
            for d in self.clip_lambdas.keys():
                if len(self.clip_lambdas[d].keys()) == 0:
                    continue
                with torch.no_grad():
                    if d == 'image':
                        features = self.clip_model.encode_image(batch['clip_images']).float()
                    elif d == 'text':
                        texts = clip.tokenize(batch['clip_text']).to(self.device)
                        features = self.clip_model.encode_text(texts).float()
                    else:
                        raise ValueError(f'Invalid clip domain [{d}]')

                features_norm = features / features.norm(dim=-1, keepdim=True)
                seq_motion_features_norm = batch["z"] / batch["z"].norm(dim=-1, keepdim=True)

                if 'ce' in self.clip_lambdas[d].keys():
                    logit_scale = self.clip_model.logit_scale.exp()
                    logits_per_motion = logit_scale * seq_motion_features_norm @ features_norm.t()
                    logits_per_d = logits_per_motion.t()

                    batch_size = batch['x'].shape[0]
                    ground_truth = torch.arange(batch_size, dtype=torch.long, device=self.device)

                    ce_from_motion_loss = loss_ce(logits_per_motion, ground_truth)
                    ce_from_d_loss = loss_ce(logits_per_d, ground_truth)
                    clip_mixed_loss = (ce_from_motion_loss + ce_from_d_loss) / 2.

                    clip_losses[f'{d}_ce_from_d'] = ce_from_d_loss.item()
                    clip_losses[f'{d}_ce_from_motion'] = ce_from_motion_loss.item()
                    clip_losses[f'{d}_mixed_ce'] = clip_mixed_loss.item()
                    mixed_clip_loss += clip_mixed_loss * self.clip_lambdas[d]['ce']

                if 'mse' in self.clip_lambdas[d].keys():
                    mse_clip_loss = loss_mse(features, batch["z"])
                    clip_losses[f'{d}_mse'] = mse_clip_loss.item()
                    mixed_clip_loss += mse_clip_loss * self.clip_lambdas[d]['mse']

                if 'cosine' in self.clip_lambdas[d].keys():
                    cos = cosine_sim(features_norm, seq_motion_features_norm)
                    cosine_loss = (1 - cos).mean()
                    clip_losses[f'{d}_cosine'] = cosine_loss.item()
                    mixed_clip_loss += cosine_loss * self.clip_lambdas[d]['cosine']

        return mixed_clip_loss, clip_losses
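
    # The 'ce' term is the symmetric CLIP objective. A minimal standalone sketch,
    # assuming N aligned (motion, feature) pairs of dimension D:
    #   z_m = motion_z / motion_z.norm(dim=-1, keepdim=True)         # [N, D]
    #   z_d = features / features.norm(dim=-1, keepdim=True)         # [N, D]
    #   logits = logit_scale * z_m @ z_d.t()                         # [N, N]
    #   gt = torch.arange(N, device=logits.device)                   # diagonal targets
    #   loss = (loss_ce(logits, gt) + loss_ce(logits.t(), gt)) / 2.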

    @staticmethod
    def lengths_to_mask(lengths):
        max_len = max(lengths)
        if isinstance(max_len, torch.Tensor):
            max_len = max_len.item()
        index = torch.arange(max_len, device=lengths.device).expand(len(lengths), max_len)
        mask = index < lengths.unsqueeze(1)
        return mask
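
    # Example: lengths_to_mask(torch.tensor([3, 1], device=device)) ->
    #   tensor([[True,  True,  True],
    #           [True, False, False]])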

    def generate_one(self, cls, duration, fact=1, xyz=False):
        y = torch.tensor([cls], dtype=torch.long, device=self.device)[None]
        lengths = torch.tensor([duration], dtype=torch.long, device=self.device)
        mask = self.lengths_to_mask(lengths)
        # Sample a single latent and scale it by fact.
        z = torch.randn(self.latent_dim, device=self.device)[None]

        batch = {"z": fact * z, "y": y, "mask": mask, "lengths": lengths}
        batch = self.decoder(batch)

        if not xyz:
            return batch["output"][0]

        output_xyz = self.rot2xyz(batch["output"], batch["mask"])
        return output_xyz[0]
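
    # Usage sketch (hypothetical class index and duration):
    #   motion = model.generate_one(cls=0, duration=60, xyz=True)  # one decoded motion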

    def generate(self, classes, durations, nspa=1,
                 is_amass=False, is_clip_features=False,
                 textual_labels=None):
        clip_dim = self.clip_model.ln_final.normalized_shape[0]
        if not is_clip_features:
            # The code below depends on clip_features/nspa being set in the
            # CLIP-features branch, so only that path is supported.
            raise NotImplementedError('generate currently supports is_clip_features=True only')

        assert len(classes.shape) == 3
        assert classes.shape[-1] == clip_dim
        clip_features = classes.reshape([-1, clip_dim])
        nspa, nats = classes.shape[:2]

        y = clip_features
        if textual_labels is not None:
            y = np.array(textual_labels).reshape([-1])

        if len(durations.shape) == 1:
            lengths = durations.to(self.device).repeat(nspa)
        else:
            lengths = durations.to(self.device).reshape(clip_features.shape[0])

        mask = self.lengths_to_mask(lengths)

        batch = {"z": clip_features, "y": y, "mask": mask, "lengths": lengths}
        batch = self.decoder(batch)

        if is_amass:
            # Overwrite the root joint with a fixed canonical orientation; the
            # tensor must live on the same device as the decoder output.
            batch['output'][:, 0] = torch.tensor([1., 0., 0., 0., -1., 0.],
                                                 device=batch['output'].device).unsqueeze(0).unsqueeze(2)

        if self.outputxyz:
            batch["output_xyz"] = self.rot2xyz(batch["output"], batch["mask"])
        elif self.pose_rep == "xyz":
            batch["output_xyz"] = batch["output"]

        return batch
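
    # Note on generate(): the conditioning tensor `classes` must hold CLIP
    # features shaped [nspa, nats, clip_dim]; they are flattened and fed to the
    # decoder directly as the latent "z", one decoded motion per feature.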

    def generate_from_embedding(self, classes, durations, nspa=1, is_amass=False, classes_gaussians=None):
        if nspa is None:
            nspa = 1

        y = classes.to(self.device).repeat(nspa)
        if len(durations.shape) == 1:
            lengths = durations.to(self.device).repeat(nspa)
        else:
            lengths = durations.to(self.device).reshape(y.shape)
        mask = self.lengths_to_mask(lengths)
        # Use the repeated labels so the number of sampled latents matches the
        # mask/lengths when nspa > 1.
        classes_np = y.cpu().detach().numpy()

        # Draw one 512-d latent per element from its class Gaussian.
        motion_samples_ = np.zeros((classes_np.shape[0], 512), dtype='float32')
        for class_label in tqdm(np.unique(classes_np), total=len(np.unique(classes_np))):
            class_mask = np.where(classes_np == class_label)[0]
            sample_mu = classes_gaussians[class_label]['mu']
            sample_var = classes_gaussians[class_label]['var']

            sample = np.random.multivariate_normal(sample_mu, sample_var, size=len(class_mask))
            motion_samples_[class_mask, :] = sample

        zz = torch.from_numpy(motion_samples_).to(self.device)

        batch = {"z": zz, "y": y, "mask": mask, "lengths": lengths}
        batch = self.decoder(batch)

        if is_amass:
            # Overwrite the root joint with a fixed canonical orientation.
            batch['output'][:, 0] = torch.tensor([1., 0., 0., 0., -1., 0.],
                                                 device=batch['output'].device).unsqueeze(0).unsqueeze(2)

        if self.outputxyz:
            batch["output_xyz"] = self.rot2xyz(batch["output"], batch["mask"])
        elif self.pose_rep == "xyz":
            batch["output_xyz"] = batch["output"]
        return batch
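
    # classes_gaussians is expected to map each class label to a Gaussian fitted
    # over the 512-d latent space:
    #   {label: {"mu": ndarray of shape (512,), "var": ndarray of shape (512, 512)}}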

    def forward(self, batch):
        if self.outputxyz:
            batch["x_xyz"] = self.rot2xyz(batch["x"], batch["mask"])
        elif self.pose_rep == "xyz":
            batch["x_xyz"] = batch["x"]

        # Encode the motion, then decode deterministically from the mean latent.
        batch.update(self.encoder(batch))
        batch["z"] = batch["mu"]
        batch.update(self.decoder(batch))

        if self.outputxyz:
            batch["output_xyz"] = self.rot2xyz(batch["output"], batch["mask"])
        elif self.pose_rep == "xyz":
            batch["output_xyz"] = batch["output"]
        return batch
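
# Minimal end-to-end usage sketch (hypothetical encoder/decoder modules and
# hyper-parameter values; clip.load is the real OpenAI CLIP loader):
#   clip_model, _ = clip.load("ViT-B/32", device=device, jit=False)
#   clip_model.eval()
#   model = MOTIONCLIP(encoder, decoder, device,
#                      lambdas={"rc": 1.0}, latent_dim=512, outputxyz=True,
#                      pose_rep="rot6d", glob=True, glob_rot=None,
#                      translation=True, jointstype="smpl", vertstrans=False,
#                      clip_lambdas={"text": {"cosine": 1.0}},
#                      clip_model=clip_model).to(device)
#   batch = model(batch)                    # adds "z", "output" (and "output_xyz")
#   loss, logs = model.compute_loss(batch)  # scalar total + per-term log dict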