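"""Generate per-keyframe GPT captions for EgoSchema questions.

Reads a JSONL question file, subsamples the pre-extracted keyframes for each
question, queries the GPT vision model once per frame, and writes one caption
record per frame to a JSONL answer file.
"""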
import os
import json
import base64
import random
import argparse
import natsort
from PIL import Image
from tqdm import tqdm
import torch
from torch.utils.data import Dataset, DataLoader
from src.run_gpt import run_gpt
random.seed(10)
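# API key(s) passed to run_gpt; replace the "ADD" placeholder with a real key.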
dict_api = {
"api_key":"ADD",
}
class CustomDatasetGPT(Dataset):
    """Dataset yielding base64-encoded keyframes, their paths, and timestamps for one question."""

    def __init__(self, questions, num_kf):
        self.questions = questions
        self.num_kf = num_kf

    def __getitem__(self, index):
        line = self.questions[index]
        # Subsample self.num_kf frames from the stored keyframes: split them into
        # `group` segments and keep the first `newnum_per_group` frames of each
        # segment, so the selection stays spread across the whole clip.
        group = 4
        newnum_per_group = self.num_kf // group
        oldnum_per_group = len(line["VLM_path"]) // group
        assert oldnum_per_group >= newnum_per_group, f"oldnum_per_group:{oldnum_per_group} is smaller than newnum_per_group:{newnum_per_group}"
        new_kf_paths = []
        new_kf_timelines = []
        for i in range(group):
            start_index = i * oldnum_per_group
            end_index = start_index + oldnum_per_group
            sub_kf_paths = line["VLM_path"][start_index:min(end_index, len(line["VLM_path"]))]
            sub_kf_timelines = line["VLM_timeline"][start_index:min(end_index, len(line["VLM_timeline"]))]
            new_kf_paths.extend(sub_kf_paths[:newnum_per_group])
            new_kf_timelines.extend(sub_kf_timelines[:newnum_per_group])
        kf_paths = natsort.natsorted(new_kf_paths)
        kf_timelines = natsort.natsorted(new_kf_timelines)
        # Base64-encode each selected keyframe; the opened PIL images themselves
        # are not returned.
        images = []
        images_base64 = []
        for e in kf_paths:
            images.append(Image.open(e).convert('RGB'))
            images_base64.append(encode_image(e))
        return images_base64, kf_paths, kf_timelines

    def __len__(self):
        return len(self.questions)

def encode_image(image_path):
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode('utf-8')

def create_data_loader_gpt(questions, num_kf, batch_size=1, num_workers=4):
    assert batch_size == 1, "batch_size must be 1"
    dataset = CustomDatasetGPT(questions, num_kf)
    data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False)
    return data_loader, dataset

def eval_model(args):
    base_dir, question_path, vlm, num_kf, temp = (
        args.output_dir,
        args.question_path,
        args.gptmodel,
        args.num_kf,
        args.temp,
    )
    questions = [json.loads(q) for q in open(os.path.expanduser(question_path), "r")]
    fname = question_path.split('/')[-1]
    answer_path = f"{base_dir}/egoschema/{num_kf}/{fname}"
    os.makedirs(os.path.dirname(answer_path), exist_ok=True)
    print(f"question_path:{question_path}\nanswer_path:{answer_path}")
    ans_file = open(answer_path, "w")

    data_loader, dataset = create_data_loader_gpt(questions, num_kf)
    for (base64_image, kf_paths, kf_timelines), line in tqdm(zip(data_loader, questions), total=len(questions)):
        idx = line["q_uid"]
        CA = line["CA"] if "CA" in line else None
        # Question fields (not used by the per-frame captioning prompt below).
        option0 = line['option 0']
        option1 = line['option 1']
        option2 = line['option 2']
        option3 = line['option 3']
        option4 = line['option 4']
        question = line['question']
        lenwords = "50"
        prompt = f"'C' stands for the cameraman. Describe the activity depicted in this first-person perspective image in less than {lenwords} words. In your answer, don't mention that the image is in first-person perspective, as we already know this."
        prompts = [prompt] * num_kf
        # batch_size is 1, so strip the batch dimension added by the DataLoader.
        image_paths = [e[0] for e in kf_paths]
        image_timelines = [e[0] for e in kf_timelines]
        # Caption every keyframe of this question with the GPT vision model.
        output_VLM = run_gpt(
            images=image_paths,
            texts=prompts,
            api_keys=list(dict_api.values()),
            max_tokens=2000,
            model=vlm,
            temperature=temp,
            num_threads=20,  # Tune this
            backoff_time=1 * 60,
            silent=False,
            dataset="egoschema",
            verbose=False,
        )
        output_VLM = list(output_VLM)
        # Write one caption record per keyframe.
        for j, e in enumerate(image_timelines):
            line_frame = line.copy()
            line_frame["answer"] = f"At {str(e)} seconds, {output_VLM[j]}"
            line_frame["AR-VLM_model_id"] = vlm
            line_frame["AR-VLM_prompt"] = prompts[j]
            line_frame["timeline"] = float(e)
            line_frame["frame_idx"] = j
            line_frame["image_paths"] = image_paths
            if "imgidx_kw_dict" in line_frame.keys(): line_frame.pop("imgidx_kw_dict")
            if "google_drive_id" in line_frame.keys(): line_frame.pop("google_drive_id")
            ans_file.write(json.dumps(line_frame) + "\n")

    print(f"question.\nquestion_path:{question_path}\nanswer_path:{answer_path}")
    ans_file.close()
    return "job is done"

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--output-dir", type=str)
    parser.add_argument("--question-path", type=str, default="")
    parser.add_argument("--num-kf", type=int)
    parser.add_argument("--gptmodel", type=str, default="gpt-4o")
    parser.add_argument("--temp", type=float, default=None)
    args = parser.parse_args()
    eval_model(args)