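"""Demo script for the TARA model.

Walks through six steps: model loading, video encoding and captioning, text
encoding, video-text similarity, a negation-aware image retrieval example, and
composed video retrieval. Expects the sample media under ./assets.

Usage:
    python demo_usage.py --model_path <path_to_model_directory>
"""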
import torch
from termcolor import colored

from modeling_tara import TARA, read_frames_decord, read_images_decord

import warnings
warnings.filterwarnings("ignore")


def main(model_path: str = "."):
    print(colored("=" * 60, 'yellow'))
    print(colored("TARA Model Demo", 'yellow', attrs=['bold']))
    print(colored("=" * 60, 'yellow'))

    # Load model from current directory
    print(colored("\n[1/6] Loading model...", 'cyan'))
    model = TARA.from_pretrained(
        model_path,  # Load from current directory
        device_map='auto',
        torch_dtype=torch.bfloat16,
    )
    n_params = sum(p.numel() for p in model.model.parameters())
    print(colored("✓ Model loaded successfully!", 'green'))
    print(f"Number of parameters: {round(n_params / 1e9, 3)}B")
    print("-" * 100)
    # Encode a sample video
    print(colored("\n[2/6] Testing video encoding and captioning ...", 'cyan'))
    video_path = "./assets/folding_paper.mp4"
    try:
        video_tensor = read_frames_decord(video_path, num_frames=16)
        video_tensor = video_tensor.unsqueeze(0)
        video_tensor = video_tensor.to(model.model.device)
        with torch.no_grad():
            video_emb = model.encode_vision(video_tensor).cpu().squeeze(0).float()
            # Get caption for the video
            video_caption = model.describe(video_tensor)[0]
        print(colored("✓ Video encoded successfully!", 'green'))
        print(f"Video shape: {video_tensor.shape}")  # torch.Size([1, 16, 3, 240, 426])
        print(f"Video embedding shape: {video_emb.shape}")  # torch.Size([4096])
        print(colored(f"Video caption: {video_caption}", 'magenta'))
    except FileNotFoundError:
        print(colored(f"✗ Video file not found: {video_path}", 'red'))
        print(colored(" Please add a video file or update the path in demo_usage.py", 'yellow'))
        video_emb = None
    print("-" * 100)
    # Encode sample texts
    print(colored("\n[3/6] Testing text encoding...", 'cyan'))
    text = ['someone is folding a paper', 'cutting a paper', 'someone is unfolding a paper']
    # NOTE: It can also take a single string
    with torch.no_grad():
        text_emb = model.encode_text(text).cpu().float()
    print(colored("✓ Text encoded successfully!", 'green'))
    print(f"Text: {text}")
    print(f"Text embedding shape: {text_emb.shape}")  # torch.Size([3, 4096])

    # Compute similarities if video was encoded
    if video_emb is not None:
        print(colored("\n[4/6] Computing video-text similarities...", 'cyan'))
        similarities = torch.cosine_similarity(
            video_emb.unsqueeze(0).unsqueeze(0),  # [1, 1, 4096]
            text_emb.unsqueeze(0),  # [1, 3, 4096]
            dim=-1
        )
        print(colored("✓ Similarities computed!", 'green'))
        for i, txt in enumerate(text):
            print(f" '{txt}': {similarities[0, i].item():.4f}")
    print("-" * 100)
    # Negation example: a negation in the text query should result in
    # retrieval of images that do not contain the negated object
    print(colored("\n[5/6] Testing negation example...", 'cyan'))
    image_paths = [
        './assets/cat.png',
        './assets/dog+cat.png',
    ]
    image_tensors = read_images_decord(image_paths)
    with torch.no_grad():
        image_embs = model.encode_vision(image_tensors.to(model.model.device)).cpu().float()
    image_embs = torch.nn.functional.normalize(image_embs, dim=-1)
    print(f"Image embedding shape: {image_embs.shape}")

    texts = ['an image of a cat but there is no dog in it']
    with torch.no_grad():
        text_embs = model.encode_text(texts).cpu().float()
    text_embs = torch.nn.functional.normalize(text_embs, dim=-1)
    print("Text query: ", texts)
    sim = text_embs @ image_embs.t()
    print(f"Text-Image similarity: {sim}")
    print("- " * 50)

    texts = ['an image of a cat and a dog together']
    with torch.no_grad():
        text_embs = model.encode_text(texts).cpu().float()
    text_embs = torch.nn.functional.normalize(text_embs, dim=-1)
    print("Text query: ", texts)
    sim = text_embs @ image_embs.t()
    print(f"Text-Image similarity: {sim}")
    print("-" * 100)
    # Composed video retrieval example
    print(colored("\n[6/6] Testing composed video retrieval...", 'cyan'))
    # source_video_path = './assets/source-27375787.mp4'
    # target_video_path = './assets/target-27387901.mp4'
    # edit_text = "Make the billboard blank"
    source_video_path = "./assets/5369546.mp4"
    target_video_path = "./assets/1006630957.mp4"
    edit_text = "make the tree lit up"
    source_video_tensor = read_frames_decord(source_video_path, num_frames=4)
    target_video_tensor = read_frames_decord(target_video_path, num_frames=16)
    with torch.no_grad():
        source_video_emb = model.encode_vision(source_video_tensor.unsqueeze(0), edit_text).cpu().squeeze(0).float()
        source_video_emb = torch.nn.functional.normalize(source_video_emb, dim=-1)
        target_video_emb = model.encode_vision(target_video_tensor.unsqueeze(0)).cpu().squeeze(0).float()
        target_video_emb = torch.nn.functional.normalize(target_video_emb, dim=-1)
    sim_with_edit = source_video_emb @ target_video_emb.t()
    print(f"Source-Target similarity with edit: {sim_with_edit}")

    print(colored("\n" + "=" * 60, 'yellow'))
    print(colored("Demo completed successfully! 🎉", 'green', attrs=['bold']))
    print(colored("=" * 60, 'yellow'))
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, default=".")
    args = parser.parse_args()
    main(args.model_path)