|
import sys
|
|
|
|
sys.path.append('')
|
|
import argparse
|
|
import os
|
|
import torch
|
|
import torch.nn as nn
|
|
import numpy as np
|
|
from src.utils.get_model_and_data import get_model_and_data
|
|
from src.parser.visualize import parser
|
|
from src.utils.misc import load_model_wo_clip
|
|
from tqdm import tqdm
|
|
from torch.utils.data import DataLoader
|
|
from src.utils.tensors import collate
|
|
import clip
|
|
from src.models.get_model import get_model
|
|
|
|
# Cosine similarity over feature dim; eps guards against zero-norm vectors.
cosine_sim = nn.CosineSimilarity(dim=1, eps=1e-6)


if __name__ == '__main__':
    # Debug / inspection script: embeds motions with the trained motion encoder,
    # embeds text prompts with CLIP, prints their cosine similarities, and dumps
    # a few batches of motion representations to an .npz next to the checkpoint.
    parameters, folder, checkpointname, epoch = parser()

    data_split = 'vald'
    # NOTE(review): hard-coded absolute dataset path — machine-specific.
    parameters[
        'datapath'] = '/disk2/briangordon/NEW_MOTION_CLIP/MotionClip/ACTOR/PUBLIC_AMASS_DIR/amass_30fps_legacy_clip_images_v02_db.pt'

    clip_model, clip_preprocess = clip.load("ViT-B/32", device=parameters['device'],
                                            jit=False)
    clip.model.convert_weights(clip_model)  # fp16 weights, as in CLIP's own training utils

    model, datasets = get_model_and_data(parameters, split=data_split)
    dataset = datasets["train"]

    print("Restore weights..")
    checkpointpath = os.path.join(folder, checkpointname)
    target_path = os.path.join(os.path.dirname(checkpointpath), f'motion_rep_{data_split}_{epoch}.npz')
    state_dict = torch.load(checkpointpath, map_location=parameters["device"])

    load_model_wo_clip(model, state_dict)

    iterator = DataLoader(dataset, batch_size=2,
                          shuffle=False, num_workers=8, collate_fn=collate)

    # Per-batch fields accumulated into `buf` and saved to the .npz at the end.
    keep_keys = ['x', 'z', 'clip_text', 'clip_path']
    buf = {}

    # --- One-off sanity check on a pre-generated motion -----------------------
    filename = 'generated_014_fixed'
    # BUGFIX: the .npy name was hard-coded instead of using `filename` (which was
    # otherwise dead). Presumably the intent was to select the file above — confirm.
    modi_motion = np.load(f'/disk2/briangordon/good_examples/{filename}.npy', allow_pickle=True)
    modi_motion = modi_motion.astype('float32')
    modi_motion = torch.from_numpy(modi_motion).to(model.device)
    # Fixed-length 64-frame sequence, single sample, dummy label.
    lengths = torch.zeros(1, device=model.device) + 64
    mask = torch.ones((1, 64), dtype=torch.bool, device=model.device)
    y = torch.zeros(1, dtype=torch.long, device=model.device)

    batch = {
        'x': modi_motion,
        'mask': mask,
        # BUGFIX: key was misspelled 'lenghts'; the collated batches below use
        # 'lengths'. The encoder appears to rely on 'x'/'mask' here — TODO confirm
        # it never read the misspelled key.
        'lengths': lengths,
        'y': y
    }

    # Inference only — no_grad avoids building autograd graphs.
    with torch.no_grad():
        batch.update(model.encoder(batch))
        batch["z"] = batch["mu"]  # use the posterior mean as the motion embedding

        texts = clip.tokenize(['walk', 'run', 'lie', 'sit', 'swim']).to(parameters['device'])
        text_features = clip_model.encode_text(texts).float()

    # L2-normalize both sides; cosine_sim then broadcasts the single motion
    # embedding against each of the 5 text embeddings.
    features_norm = text_features / text_features.norm(dim=-1, keepdim=True)
    seq_motion_features_norm = batch["z"] / batch["z"].norm(dim=-1, keepdim=True)
    cos = cosine_sim(features_norm, seq_motion_features_norm)

    print(cos)
    input('look')  # debug pause — press Enter to continue

    # --- Dataset pass: embed a few batches and accumulate representations -----
    with torch.no_grad():
        for i, batch in tqdm(enumerate(iterator), desc="Computing batch"):

            # Move every tensor field to the target device.
            for key in batch.keys():
                if torch.is_tensor(batch[key]):
                    batch[key] = batch[key].to(parameters['device'])

            print(f'x: {batch["x"].shape}')
            print(f'mask: {batch["mask"].shape}')
            print(f'lengths: {batch["lengths"].shape}')
            print(f'y: {batch["y"].shape}')

            input('look')  # debug pause — press Enter to continue

            # Provide xyz coordinates when the model is configured to output them.
            if model.outputxyz:
                batch["x_xyz"] = model.rot2xyz(batch["x"], batch["mask"])
            elif model.pose_rep == "xyz":
                batch["x_xyz"] = batch["x"]

            batch.update(model.encoder(batch))
            batch["z"] = batch["mu"]

            texts = clip.tokenize(batch['clip_text']).to(parameters['device'])
            text_features = clip_model.encode_text(texts).float()

            features_norm = text_features / text_features.norm(dim=-1, keepdim=True)
            seq_motion_features_norm = batch["z"] / batch["z"].norm(dim=-1, keepdim=True)
            cos = cosine_sim(features_norm, seq_motion_features_norm)
            # Computed for inspection only; not used further in this script.
            cosine_loss = (1 - cos).mean()

            """ LOGIC TO SAVE OUTPUTS TO NPZ FILE - not required 100%"""
            # Append this batch's kept fields to the running buffers
            # (first batch initializes them; later batches concatenate).
            for k in keep_keys:
                _to_write = batch[k].cpu().numpy() if torch.is_tensor(batch[k]) else np.array(batch[k])
                if k in buf:
                    buf[k] = np.concatenate((buf[k], _to_write), axis=0)
                else:
                    buf[k] = _to_write
            print('buf', {k: v.shape for k, v in buf.items()})

            # Inspection run: only the first 6 batches are processed.
            if i == 5:
                break

    print(f'Saving {target_path}')
    # BUGFIX: np.savez(target_path, buf) would pickle the whole dict into a single
    # 0-d object array named 'arr_0'; **buf stores one named array per key.
    np.savez(target_path, **buf)
|
|