smog / src /utils /demo.py
vonexel's picture
add: src
fe64bad verified
import sys
sys.path.append('')
import argparse
import os
import torch
import torch.nn as nn
import numpy as np
from src.utils.get_model_and_data import get_model_and_data
from src.parser.visualize import parser
from src.utils.misc import load_model_wo_clip
from tqdm import tqdm
from torch.utils.data import DataLoader
from src.utils.tensors import collate
import clip
from src.models.get_model import get_model
cosine_sim = nn.CosineSimilarity(dim=1, eps=1e-6)
if __name__ == '__main__':
parameters, folder, checkpointname, epoch = parser()
data_split = 'vald' # Hardcoded
parameters[
'datapath'] = '/disk2/briangordon/NEW_MOTION_CLIP/MotionClip/ACTOR/PUBLIC_AMASS_DIR/amass_30fps_legacy_clip_images_v02_db.pt' # FIXME - hardcoded
clip_model, clip_preprocess = clip.load("ViT-B/32", device=parameters['device'],
jit=False) # Must set jit=False for training
clip.model.convert_weights(clip_model) # Actually this line is unnecessary since clip by default already on float16
# model = get_model(parameters, clip_model)
model, datasets = get_model_and_data(parameters, split=data_split)
dataset = datasets["train"]
print("Restore weights..")
checkpointpath = os.path.join(folder, checkpointname)
target_path = os.path.join(os.path.dirname(checkpointpath), f'motion_rep_{data_split}_{epoch}.npz')
state_dict = torch.load(checkpointpath, map_location=parameters["device"])
# model.load_state_dict(state_dict)
load_model_wo_clip(model, state_dict)
iterator = DataLoader(dataset, batch_size=2, # parameters["batch_size"],
shuffle=False, num_workers=8, collate_fn=collate)
keep_keys = ['x', 'z', 'clip_text', 'clip_path']
buf = {}
filename = 'generated_014_fixed'
modi_motion = np.load(f'/disk2/briangordon/good_examples/{filename}.npy', allow_pickle=True)
modi_motion = modi_motion.astype('float32')
modi_motion = torch.from_numpy(modi_motion).to(model.device)
lenghts = torch.zeros(1, device=model.device) + 64
mask = torch.ones((1, 64), dtype=torch.bool, device=model.device)
y = torch.zeros(1, dtype=torch.long, device=model.device)
batch = {
'x': modi_motion,
'mask': mask,
'lenghts': lenghts,
'y': y
}
batch.update(model.encoder(batch))
batch["z"] = batch["mu"]
# Encode text with clip
texts = clip.tokenize(['walk', 'run', 'lie', 'sit', 'swim']).to(parameters['device'])
text_features = clip_model.encode_text(texts).float()
# normalized features motion & text
features_norm = text_features / text_features.norm(dim=-1, keepdim=True)
seq_motion_features_norm = batch["z"] / batch["z"].norm(dim=-1, keepdim=True)
cos = cosine_sim(features_norm, seq_motion_features_norm)
print(cos)
input('look')
with torch.no_grad():
for i, batch in tqdm(enumerate(iterator), desc="Computing batch"):
# print('batch', {k: type(v) for k,v in batch.items()})
for key in batch.keys():
if torch.is_tensor(batch[key]):
batch[key] = batch[key].to(parameters['device'])
# batch = {key: val.to(parameters['device']) for key, val in batch.items()}
# print('batch', {k: v.shape for k,v in batch.items()})
print(f'x: {batch["x"].shape}')
print(f'mask: {batch["mask"].shape}')
print(f'lengths: {batch["lengths"].shape}')
print(f'y: {batch["y"].shape}')
# print(f'x: {batch["x"][0, :, :, 1]}')
# print('batch', {k: v for k,v in batch.items()})
input('look')
if model.outputxyz:
batch["x_xyz"] = model.rot2xyz(batch["x"], batch["mask"])
elif model.pose_rep == "xyz":
batch["x_xyz"] = batch["x"]
# Encode Motion - Encoded motion will be under "z" key.
batch.update(model.encoder(batch))
batch["z"] = batch["mu"]
# Encode text with clip
texts = clip.tokenize(batch['clip_text']).to(parameters['device'])
text_features = clip_model.encode_text(texts).float()
# normalized features motion & text
features_norm = text_features / text_features.norm(dim=-1, keepdim=True)
seq_motion_features_norm = batch["z"] / batch["z"].norm(dim=-1, keepdim=True)
cos = cosine_sim(features_norm, seq_motion_features_norm)
cosine_loss = (1 - cos).mean()
# batch = model(batch)
# print('batch', {k: v.shape for k,v in batch.items()})
""" LOGIC TO SAVE OUTPUTS TO NPZ FILE - not required 100%"""
if len(buf) == 0:
for k in keep_keys:
_to_write = batch[k].cpu().numpy() if torch.is_tensor(batch[k]) else np.array(batch[k])
buf[k] = _to_write
else:
for k in keep_keys:
_to_write = batch[k].cpu().numpy() if torch.is_tensor(batch[k]) else np.array(batch[k])
buf[k] = np.concatenate((buf[k], _to_write), axis=0)
print('buf', {k: v.shape for k, v in buf.items()})
# print('clip_text', buf['clip_text'])
# print('clip_path', buf['clip_path'])
# FIXME - for now we need just a sample - hence, adding an early stop
if i == 5:
break
print(f'Saving {target_path}')
np.savez(target_path, buf)