import os
import argparse
import json

from tqdm import tqdm

from llava.eval.video.run_inference_video_qa import get_model_output
from llava.mm_utils import get_model_name_from_path
from llava.model.builder import load_pretrained_model


def parse_args():
    """
    Parse command-line arguments.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--model_path', help='Path to the pretrained model weights.', required=True)
    parser.add_argument('--cache_dir', help='Directory for caching downloaded model files.', required=True)
    parser.add_argument('--video_dir', help='Directory containing video files.', required=True)
    parser.add_argument('--gt_file', help='Path to the ground truth file.', required=True)
    parser.add_argument('--output_dir', help='Directory to save the model results JSON.', required=True)
    parser.add_argument('--output_name', help='Name of the file for storing results JSON.', required=True)

    parser.add_argument("--device", type=str, required=False, default='cuda:0')
    parser.add_argument('--model_base', help='Optional path to the base model.', default=None, type=str, required=False)
    parser.add_argument("--model_max_length", type=int, required=False, default=2048)

    return parser.parse_args()


def run_inference(args):
    """
    Run inference on a set of video files using the provided model.

    Args:
        args: Command-line arguments.
    """
    # Load the model, tokenizer, and video processor, then move the model to the target device.
    model_name = get_model_name_from_path(args.model_path)
    tokenizer, model, processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name)
    model = model.to(args.device)

    # Load the ground-truth samples: a JSON list where each entry has a video name and two questions.
    with open(args.gt_file) as file:
        gt_contents = json.load(file)

    # Create the output directory if it doesn't exist.
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    output_list = []  # List to store the model predictions

    video_formats = ['.mp4', '.avi', '.mov', '.mkv']

    # Iterate over each sample in the ground truth file.
    for sample in tqdm(gt_contents):
        video_name = sample['video_name']
        sample_set = sample
        question_1 = sample['Q1']
        question_2 = sample['Q2']

        try:
            # Locate the video file by trying each supported extension.
            for fmt in video_formats:
                temp_path = os.path.join(args.video_dir, f"{video_name}{fmt}")
                if os.path.exists(temp_path):
                    video_path = temp_path

                    # Run inference on the video for both questions and record the predictions.
                    output_1 = get_model_output(model, processor['video'], tokenizer, video_path, question_1, args)
                    sample_set['pred1'] = output_1

                    output_2 = get_model_output(model, processor['video'], tokenizer, video_path, question_2, args)
                    sample_set['pred2'] = output_2

                    output_list.append(sample_set)
                    break  # Stop after the first matching video format.

        except Exception as e:
            print(f"Error processing video file '{video_name}': {e}")

    # Save the collected predictions to a JSON file.
    with open(os.path.join(args.output_dir, f"{args.output_name}.json"), 'w') as file:
        json.dump(output_list, file)


if __name__ == "__main__":
    args = parse_args()
    run_inference(args)
|
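# Example invocation (a sketch: the script name and paths below are illustrative
# placeholders; only the flags themselves are defined by parse_args above):
#
#   python run_inference_benchmark.py \
#       --model_path <path/to/llava/checkpoint> \
#       --cache_dir ./cache_dir \
#       --video_dir <path/to/videos> \
#       --gt_file <path/to/ground_truth.json> \
#       --output_dir ./results \
#       --output_name preds
#
# The ground truth JSON is expected to be a list of samples, each containing
# 'video_name', 'Q1', and 'Q2'; predictions are written back as 'pred1' and 'pred2'.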