import os import numpy as np import random from tqdm import tqdm import argparse import torch import sys import glob import csv import torch.utils.data as data import pandas as pd import re import pdb sys.path.append(os.path.abspath(os.path.dirname(os.path.realpath(__file__))+'/'+'..')) parser = argparse.ArgumentParser() parser.add_argument('--vid_csv_path', type=str, default=None) parser.add_argument('--image_dir', type=str, default=None) parser.add_argument('--text_dir', type=str, default=None) parser.add_argument('--lyric_txt_path', type=str, default=None) parser.add_argument('--save_csv', type=str, default=None) parser.add_argument('--gpu', type=str, default='3') args = parser.parse_args() os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu # python video2img_retrieval.py --vid_csv_path ../vids.csv --image_dir ../feats --text_dir ../../BriVL-git/BriVL_code_inference/feat/text_vivo_5w_songids/ --lyric_txt_path /data/home/sxhong/tools/get_lyric/data/vivo_top5w_songid_lyrics.lyric --save_csv test.csv def get_ImgFeat(img_feats): # 获取imgs_feat img_dict = {} for i, np_img_path in enumerate(img_feats): img_dict[i] = np_img_path if i == 0: np_img = np.expand_dims(np.load(np_img_path).astype(np.float64), axis=0) else: np_img = np.concatenate((np_img, np.expand_dims(np.load(np_img_path).astype(np.float64), axis=0)), axis=0) img = torch.from_numpy(np_img).cuda() return img_dict, img class Text_Dataset(data.Dataset): def __init__(self, text_dir): self.text_feats = glob.glob(os.path.join(text_dir, '*.npy')) def __len__(self): return len(self.text_feats) def __getitem__(self, index): text_path = self.text_feats[index] songid = text_path.split('/')[-1].split('.')[0] text_feat = np.load(text_path) return songid, text_feat def get_TextFeat(text_dir): # pdb.set_trace() # 获取texts_feat dataset = Text_Dataset(text_dir) dataloader = data.DataLoader(dataset, batch_size=5000, shuffle=False) all_songids = [] for i, (songid_, text_feats) in enumerate(dataloader): all_songids.extend(list(songid_)) if i == 0: text = text_feats else: text = torch.cat((text, text_feats), 0) text = text.squeeze(dim=1) text = text.cuda() return all_songids, text def get_lyric(lyric_txt_path): # songid对应的歌词 text_dict = {} for line in open(lyric_txt_path): try: songid, text_query = line.split(',')[0], line.split('"')[1] text_dict[songid] = text_query except: pass return text_dict vids = pd.read_csv(args.vid_csv_path, dtype=str)['vid'].to_list() all_songids, text = get_TextFeat(args.text_dir) text_dict = get_lyric(args.lyric_txt_path) vid2songid = {} for vid in vids: img_feat_paths = glob.glob(os.path.join(args.image_dir, str(vid)+'_'+'*.npy')) img_dict, img = get_ImgFeat(img_feat_paths) N_img = img.shape[0] N_text = text.shape[0] scores = torch.zeros((N_img, N_text), dtype=torch.float32).cuda() print('Pair-to-pair: calculating scores') for i in tqdm(range(N_img)): # row: image col: text scores[i, :] = torch.sum(img[i] * text, -1) songid2score = {} songids = [] songids2scores = [] for i, score in enumerate(scores): indices = torch.argsort(score, descending=True) songids = [] for idx in indices: if len(songids) <= 80: idx = int(idx.cpu().numpy()) query_text = text_dict[all_songids[idx]] query_text = query_text.split(',') query_text = (query_text[len(query_text) // 2]).replace(' ','').replace('*','') for exist_songid in songids: key_text = (text_dict[exist_songid]).replace(' ','').replace('*','') if re.findall(query_text, key_text): if float(songid2score[exist_songid]) < float(score[idx].cpu().numpy()): songid2score.pop(exist_songid) songid2score[all_songids[idx]] = str(score[idx].cpu().numpy()) break songid2score[all_songids[idx]] = str(score[idx].cpu().numpy()) else: break sorted_songid2score = sorted(songid2score.items(), key=lambda x:float(x[1]), reverse=True) select_songids = ', '.join([songid for songid, score in sorted_songid2score[:100]]) correspond_scores = ', '.join([score for songid, score in sorted_songid2score[:100]]) vid2songid[vid] = (select_songids, correspond_scores) data = [] for vid, values in vid2songid.items(): songids, scores = values data.append([vid, songids, scores]) df = pd.DataFrame(data, columns=['img', 'songids', 'scores']) df.to_csv(args.save_csv, index=False)