# smog/src/models/modeltype/motionclip.py
import numpy as np
import torch
import torch.nn as nn
import clip
from tqdm import tqdm

from ..tools.losses import get_loss_function
from ..rotation2xyz import Rotation2xyz

loss_ce = nn.CrossEntropyLoss()
loss_mse = nn.MSELoss()
cosine_sim = nn.CosineSimilarity(dim=1, eps=1e-6)
class MOTIONCLIP(nn.Module):
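    """Motion auto-encoder whose latent space is aligned with CLIP embeddings.

    The encoder maps a motion batch to a latent code ``z`` and the decoder
    reconstructs the motion from it; ``compute_clip_losses`` pulls ``z``
    towards the CLIP text and/or image features of the same samples.
    """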
def __init__(self, encoder, decoder, device, lambdas, latent_dim, outputxyz,
pose_rep, glob, glob_rot, translation, jointstype, vertstrans, clip_lambdas={}, **kwargs):
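        """Store the encoder/decoder, loss weights and CLIP configuration.

        ``kwargs`` must provide ``clip_model`` (a loaded CLIP model) and may
        provide ``clip_training``, a ``'_'``-separated string of CLIP domains
        ('text', 'image') to fine-tune together with the motion encoder.
        """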
super().__init__()
self.encoder = encoder
self.decoder = decoder
self.outputxyz = outputxyz
self.lambdas = lambdas
self.clip_lambdas = clip_lambdas
self.latent_dim = latent_dim
self.pose_rep = pose_rep
self.glob = glob
self.glob_rot = glob_rot
self.device = device
self.translation = translation
self.jointstype = jointstype
self.vertstrans = vertstrans
self.clip_model = kwargs['clip_model']
self.clip_training = kwargs.get('clip_training', False)
        if self.clip_training and self.clip_model:
            self.clip_model.training = True
        elif self.clip_model:
            assert not self.clip_model.training  # make sure CLIP stays frozen
self.losses = list(self.lambdas) + ["mixed"]
self.rotation2xyz = Rotation2xyz(device=self.device)
self.param2xyz = {"pose_rep": self.pose_rep,
"glob_rot": self.glob_rot,
"glob": self.glob,
"jointstype": self.jointstype,
"translation": self.translation,
"vertstrans": self.vertstrans}
def rot2xyz(self, x, mask, get_rotations_back=False, **kwargs):
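        """Convert rotation-based poses ``x`` to xyz joints, merging ``self.param2xyz`` with ``kwargs``."""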
kargs = self.param2xyz.copy()
kargs.update(kwargs)
return self.rotation2xyz(x, mask, get_rotations_back=get_rotations_back, **kargs)
def compute_loss(self, batch):
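        """Compute the weighted reconstruction losses plus the CLIP alignment losses.

        Returns the total loss to optimize and a dict of detached scalar
        values for logging.
        """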
# compute all losses other than clip
mixed_loss = 0.
losses = {}
for ltype, lam in self.lambdas.items():
loss_function = get_loss_function(ltype)
loss = loss_function(self, batch)
mixed_loss += loss * lam
losses[ltype] = loss.item()
# compute clip losses
mixed_clip_loss, clip_losses = self.compute_clip_losses(batch)
# mix and add clip losses
mixed_loss_with_clip = mixed_loss + mixed_clip_loss # this is the ultimate loss to optimize, combining ALL losses
losses.update(clip_losses)
losses["mixed_without_clip"] = mixed_loss.item()
losses["mixed_clip_only"] = mixed_clip_loss if isinstance(mixed_clip_loss, float) else mixed_clip_loss.item()
losses["mixed_with_clip"] = mixed_loss_with_clip if isinstance(mixed_loss_with_clip,
float) else mixed_loss_with_clip.item()
return mixed_loss_with_clip, losses
def compute_clip_losses(self, batch):
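        """Compute the CLIP alignment terms between ``batch['z']`` and CLIP features.

        When ``self.clip_training`` is set, a symmetric contrastive
        cross-entropy is computed for each trained domain; otherwise CLIP is
        kept frozen and the losses listed in ``self.clip_lambdas`` ('ce',
        'mse', 'cosine') are weighted by their lambdas.
        """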
mixed_clip_loss = 0.
clip_losses = {}
if self.clip_training:
for d in self.clip_training.split('_'):
if d == 'image':
features = self.clip_model.encode_image(
batch['clip_images']).float() # preprocess is done in dataloader
                elif d == 'text':
                    texts = clip.tokenize(batch['clip_text']).to(self.device)
                    features = self.clip_model.encode_text(texts).float()
                else:
                    raise ValueError(f'Invalid clip domain [{d}]')
# normalized features
features_norm = features / features.norm(dim=-1, keepdim=True)
seq_motion_features_norm = batch["z"] / batch["z"].norm(dim=-1, keepdim=True)
logit_scale = self.clip_model.logit_scale.exp()
logits_per_motion = logit_scale * seq_motion_features_norm @ features_norm.t()
logits_per_d = logits_per_motion.t()
batch_size = batch['x'].shape[0]
ground_truth = torch.arange(batch_size, dtype=torch.long, device=self.device)
ce_from_motion_loss = loss_ce(logits_per_motion, ground_truth)
ce_from_d_loss = loss_ce(logits_per_d, ground_truth)
clip_mixed_loss = (ce_from_motion_loss + ce_from_d_loss) / 2.
clip_losses[f'{d}_ce_from_d'] = ce_from_d_loss.item()
clip_losses[f'{d}_ce_from_motion'] = ce_from_motion_loss.item()
clip_losses[f'{d}_mixed_ce'] = clip_mixed_loss.item()
mixed_clip_loss += clip_mixed_loss
else:
for d in self.clip_lambdas.keys():
if len(self.clip_lambdas[d].keys()) == 0:
continue
with torch.no_grad():
if d == 'image':
features = self.clip_model.encode_image(
batch['clip_images']).float() # preprocess is done in dataloader
elif d == 'text':
texts = clip.tokenize(batch['clip_text']).to(self.device)
features = self.clip_model.encode_text(texts).float()
else:
raise ValueError(f'Invalid clip domain [{d}]')
# normalized features
features_norm = features / features.norm(dim=-1, keepdim=True)
seq_motion_features_norm = batch["z"] / batch["z"].norm(dim=-1, keepdim=True)
if 'ce' in self.clip_lambdas[d].keys():
logit_scale = self.clip_model.logit_scale.exp()
logits_per_motion = logit_scale * seq_motion_features_norm @ features_norm.t()
logits_per_d = logits_per_motion.t()
batch_size = batch['x'].shape[0]
ground_truth = torch.arange(batch_size, dtype=torch.long, device=self.device)
ce_from_motion_loss = loss_ce(logits_per_motion, ground_truth)
ce_from_d_loss = loss_ce(logits_per_d, ground_truth)
clip_mixed_loss = (ce_from_motion_loss + ce_from_d_loss) / 2.
clip_losses[f'{d}_ce_from_d'] = ce_from_d_loss.item()
clip_losses[f'{d}_ce_from_motion'] = ce_from_motion_loss.item()
clip_losses[f'{d}_mixed_ce'] = clip_mixed_loss.item()
mixed_clip_loss += clip_mixed_loss * self.clip_lambdas[d]['ce']
if 'mse' in self.clip_lambdas[d].keys():
mse_clip_loss = loss_mse(features, batch["z"])
clip_losses[f'{d}_mse'] = mse_clip_loss.item()
mixed_clip_loss += mse_clip_loss * self.clip_lambdas[d]['mse']
if 'cosine' in self.clip_lambdas[d].keys():
cos = cosine_sim(features_norm, seq_motion_features_norm)
cosine_loss = (1 - cos).mean()
clip_losses[f'{d}_cosine'] = cosine_loss.item()
mixed_clip_loss += cosine_loss * self.clip_lambdas[d]['cosine']
return mixed_clip_loss, clip_losses
@staticmethod
def lengths_to_mask(lengths):
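        """Build a boolean mask of shape (len(lengths), max(lengths)) that is True on valid frames."""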
max_len = max(lengths)
if isinstance(max_len, torch.Tensor):
max_len = max_len.item()
index = torch.arange(max_len, device=lengths.device).expand(len(lengths), max_len)
mask = index < lengths.unsqueeze(1)
return mask
def generate_one(self, cls, duration, fact=1, xyz=False):
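        """Decode a single motion of class ``cls`` and length ``duration`` from a random latent.

        ``fact`` scales the sampled latent vector; rotations are returned by
        default, xyz joint positions when ``xyz=True``.
        """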
y = torch.tensor([cls], dtype=int, device=self.device)[None]
lengths = torch.tensor([duration], dtype=int, device=self.device)
mask = self.lengths_to_mask(lengths)
z = torch.randn(self.latent_dim, device=self.device)[None]
batch = {"z": fact * z, "y": y, "mask": mask, "lengths": lengths}
batch = self.decoder(batch)
if not xyz:
return batch["output"][0]
output_xyz = self.rot2xyz(batch["output"], batch["mask"])
return output_xyz[0]
    def generate(self, classes, durations, nspa=1,
                 is_amass=False, is_clip_features=False, textual_labels=None):
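        """Decode motions directly from pre-computed CLIP features.

        ``classes`` is expected to hold CLIP features of shape
        (nspa, nats, clip_dim) with ``is_clip_features=True``; each feature is
        used as the latent code ``z`` of the decoder.  A minimal usage sketch
        (assuming ``text_features`` of shape (1, n_texts, clip_dim) and a
        matching ``durations`` tensor are already prepared):

            batch = model.generate(text_features, durations, is_clip_features=True)
        """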
clip_dim = self.clip_model.ln_final.normalized_shape[0]
if is_clip_features:
# assumed dims: classes [nspa, nats, 512]
assert len(classes.shape) == 3
assert classes.shape[-1] == clip_dim
clip_features = classes.reshape([-1, clip_dim])
nspa, nats = classes.shape[:2]
            y = clip_features
if textual_labels is not None:
y = np.array(textual_labels).reshape([-1])
if len(durations.shape) == 1:
lengths = durations.to(self.device).repeat(nspa)
else:
lengths = durations.to(self.device).reshape(clip_features.shape[0])
mask = self.lengths_to_mask(lengths)
batch = {"z": clip_features, # fact*z,
"y": y,
"mask": mask, "lengths": lengths}
if not is_clip_features:
batch['y'] = y
batch = self.decoder(batch)
if is_amass: # lose global orientation for amass dataset
batch['output'][:, 0] = torch.tensor([1, 0, 0, 0, -1, 0]).unsqueeze(0).unsqueeze(2)
if self.outputxyz:
batch["output_xyz"] = self.rot2xyz(batch["output"], batch["mask"])
elif self.pose_rep == "xyz":
batch["output_xyz"] = batch["output"]
return batch
def generate_from_embedding(self, classes, durations, nspa=1, is_amass=False, classes_gaussians=None):
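        """Decode motions by sampling latent codes from per-class Gaussians.

        For every label in ``classes``, a latent vector is drawn from the
        matching ``classes_gaussians`` entry (``'mu'``, ``'var'``) and decoded.
        """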
if nspa is None:
nspa = 1
nats = len(classes)
y = classes.to(self.device).repeat(nspa) # (view(nspa, nats))
if len(durations.shape) == 1:
lengths = durations.to(self.device).repeat(nspa)
else:
lengths = durations.to(self.device).reshape(y.shape)
mask = self.lengths_to_mask(lengths)
classes_np = classes.cpu().detach().numpy()
motion_samples_ = np.zeros((classes_np.shape[0], 512), dtype='float32')
        unique_classes = np.unique(classes_np)
        for class_label in tqdm(unique_classes, total=len(unique_classes)):
class_mask = np.where(classes_np == class_label)[0]
sample_mu = classes_gaussians[class_label]['mu']
sample_var = classes_gaussians[class_label]['var']
sample = np.random.multivariate_normal(sample_mu, sample_var, size=len(class_mask))
motion_samples_[class_mask, :] = sample
zz = torch.from_numpy(motion_samples_).to(self.device)
batch = {"z": zz,
"y": y, "mask": mask, "lengths": lengths}
batch = self.decoder(batch)
if is_amass: # lose global orientation for amass dataset
batch['output'][:, 0] = torch.tensor([1, 0, 0, 0, -1, 0]).unsqueeze(0).unsqueeze(2)
if self.outputxyz:
batch["output_xyz"] = self.rot2xyz(batch["output"], batch["mask"])
elif self.pose_rep == "xyz":
batch["output_xyz"] = batch["output"]
return batch
def forward(self, batch):
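        """Encode the input motion, use the encoder's ``mu`` as the latent ``z``, and decode it.

        Adds ``x_xyz`` / ``output_xyz`` joint positions to the batch when
        ``self.outputxyz`` is set or the pose representation is already xyz.
        """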
if self.outputxyz:
batch["x_xyz"] = self.rot2xyz(batch["x"], batch["mask"])
elif self.pose_rep == "xyz":
batch["x_xyz"] = batch["x"]
# encode
batch.update(self.encoder(batch))
batch["z"] = batch["mu"]
# decode
batch.update(self.decoder(batch))
# if we want to output xyz
if self.outputxyz:
batch["output_xyz"] = self.rot2xyz(batch["output"], batch["mask"])
elif self.pose_rep == "xyz":
batch["output_xyz"] = batch["output"]
return batch