File size: 2,669 Bytes
dc80a97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import argparse
import gradio as gr
from goldfish_lv import GoldFish_LV 
from theme import minigptlv_style
import time
def str2bool(v):
    """Convert a command-line token into a bool, for use as an argparse ``type``.

    A value that is already a ``bool`` is returned unchanged. Strings are
    lower-cased and matched against the accepted truthy spellings
    ('yes', 'true', 't', 'y', '1') and falsy spellings ('no', 'false', 'f',
    'n', '0').

    Raises:
        argparse.ArgumentTypeError: if the string matches neither set.
    """
    if isinstance(v, bool):
        return v
    # Single lookup table mapping every accepted spelling to its bool value.
    spellings = {
        **dict.fromkeys(('yes', 'true', 't', 'y', '1'), True),
        **dict.fromkeys(('no', 'false', 'f', 'n', '0'), False),
    }
    verdict = spellings.get(v.lower())
    if verdict is None:
        raise argparse.ArgumentTypeError('Boolean value expected.')
    return verdict

def get_arguments(argv=None):
    """Define the inference command-line options and parse them.

    Args:
        argv: Optional sequence of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, i.e. the original
            behaviour; passing a list allows programmatic use and testing.

    Returns:
        argparse.Namespace with the parsed options. Note that ``--cfg-path``
        is exposed as the attribute ``cfg_path``.

    Raises:
        SystemExit: raised by argparse on invalid arguments or ``--help``.
    """
    parser = argparse.ArgumentParser(description="Inference parameters")
    parser.add_argument("--cfg-path", default="test_configs/llama2_test_config.yaml",
                        help="Path to the model config YAML")
    parser.add_argument("--neighbours", type=int, default=3,
                        help="Number of neighbours passed to the inference call")
    parser.add_argument("--ckpt", type=str, default="checkpoints/video_llama_checkpoint_last.pth",
                        help="Path to the model checkpoint")
    parser.add_argument("--add_subtitles", action='store_true',
                        help="Pass has_subtitles=True to the inference call")
    # NOTE(review): --max_new_tokens, --lora_r, --lora_alpha and --options are
    # not read in this script; presumably GoldFish_LV(args) consumes them — confirm.
    parser.add_argument("--max_new_tokens", type=int, default=512)
    parser.add_argument("--use_openai_embedding", type=str2bool, default=False,
                        help="Use OpenAI embeddings (yes/no, true/false, 1/0)")
    parser.add_argument("--batch_size", type=int, default=2, help="Batch size for short video clips")
    parser.add_argument("--lora_r", type=int, default=64)
    parser.add_argument("--lora_alpha", type=int, default=16)
    parser.add_argument("--video_path", type=str, default="path for video.mp4",
                        help="Path to the video file or youtube url")
    parser.add_argument("--question", type=str, default="Why rachel is wearing a wedding dress?",
                        help="Question to ask about the video")
    parser.add_argument("--options", nargs="+")
    return parser.parse_args(argv)

def download_video(youtube_url, model=None):
    """Prepare a video via the model's ``process_video_url`` and return its path.

    Args:
        youtube_url: URL (or path) handed unchanged to ``process_video_url``.
        model: Object exposing ``process_video_url``. Defaults to the
            module-level ``goldfish_lv``, which is only created when this file
            runs as ``__main__``; pass a model explicitly when importing this
            module, otherwise the fallback raises NameError.

    Returns:
        Whatever ``model.process_video_url`` returns (the processed video path).
    """
    if model is None:
        model = goldfish_lv  # global bound in the __main__ block only
    return model.process_video_url(youtube_url)

def process_video(video_path, has_subtitles, instruction="", number_of_neighbours=-1, model=None):
    """Run the model's inference on a video and return only the prediction.

    Args:
        video_path: Path to the (already processed) video.
        has_subtitles: Forwarded as the subtitles flag of ``inference``.
        instruction: Question / instruction text forwarded to ``inference``.
        number_of_neighbours: Forwarded to ``inference``. The meaning of the
            default ``-1`` is decided by ``inference`` — presumably "use the
            model's default"; confirm in goldfish_lv.
        model: Object exposing ``inference``. Defaults to the module-level
            ``goldfish_lv``, which is only created when this file runs as
            ``__main__``; pass a model explicitly when importing this module.

    Returns:
        The ``"pred"`` entry of the dict returned by ``model.inference``.

    Raises:
        KeyError: if the inference result has no ``"pred"`` key.
    """
    if model is None:
        model = goldfish_lv  # global bound in the __main__ block only
    result = model.inference(video_path, has_subtitles, instruction, number_of_neighbours)
    return result["pred"]

def return_video_path(youtube_url):
    """Map a YouTube URL to its expected local file ``workspace/tmp/<id>.mp4``.

    Accepted forms (http/https, scheme optional):
    ``youtube.com/watch?v=<id>``, ``youtube.com/{shorts,embed,live}/<id>``
    and ``youtu.be/<id>``, on the www/m/music sub-domains as well. Extra query
    parameters and fragments (``&t=42``, ``#t=5``) are ignored.

    Raises:
        ValueError: if the URL is not a recognised YouTube video URL, or the
            extracted id contains characters outside ``[A-Za-z0-9_-]`` (which
            also prevents path traversal in the returned file path).
    """
    import re
    from urllib.parse import parse_qs, urlparse

    url = youtube_url.strip()
    if "://" not in url:
        # Allow scheme-less input such as "youtube.com/watch?v=...".
        url = "https://" + url
    parsed = urlparse(url)
    host = parsed.hostname or ""  # urlparse lower-cases the hostname

    video_id = ""
    if host == "youtu.be":
        video_id = parsed.path.lstrip("/").split("/")[0]
    elif host in {"youtube.com", "www.youtube.com", "m.youtube.com", "music.youtube.com"}:
        if parsed.path == "/watch":
            video_id = parse_qs(parsed.query).get("v", [""])[0]
        elif parsed.path.startswith(("/shorts/", "/embed/", "/live/")):
            video_id = parsed.path.split("/")[2]

    if not video_id or not re.fullmatch(r"[A-Za-z0-9_-]+", video_id):
        raise ValueError("Invalid YouTube URL provided.")
    return os.path.join("workspace", "tmp", f"{video_id}.mp4")

# NOTE(review): arguments are parsed at import time, so importing this module
# parses (and may exit on) the importing process's sys.argv. Kept at module
# level because other code may read the module-level ``args``.
args = get_arguments()

if __name__ == "__main__":
    load_start = time.perf_counter()  # monotonic clock for elapsed time
    print("using openai: ", args.use_openai_embedding)
    # Module-level global; the helper functions above fall back to it.
    goldfish_lv = GoldFish_LV(args)
    load_end = time.perf_counter()
    print("Time taken to load model: ", load_end - load_start)
    # Prepare the input (local path or YouTube URL), then ask the question.
    processed_video_path = goldfish_lv.process_video_url(args.video_path)
    pred = process_video(processed_video_path, args.add_subtitles, args.question, args.neighbours)
    print("Question answer: ", pred)
    # Measured from the end of model loading, so it includes video preparation.
    print("Time taken for inference: ", time.perf_counter() - load_end)