Spaces:
No application file
No application file
File size: 4,962 Bytes
bb46cbe |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
import os
import numpy as np
import random
from tqdm import tqdm
import argparse
import torch
import sys
import glob
import csv
import torch.utils.data as data
import pandas as pd
import re
import pdb
# Put the parent of this script's directory on the import path.
script_dir = os.path.dirname(os.path.realpath(__file__))
sys.path.append(os.path.abspath(script_dir + '/' + '..'))

# Command-line interface: five optional path arguments (all default to None)
# plus the GPU selector.
parser = argparse.ArgumentParser()
for path_flag in (
    '--vid_csv_path',
    '--image_dir',
    '--text_dir',
    '--lyric_txt_path',
    '--save_csv',
):
    parser.add_argument(path_flag, type=str, default=None)
parser.add_argument('--gpu', type=str, default='3')
args = parser.parse_args()

# Forward the --gpu value to CUDA through its environment variable.
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

# Example invocation:
# python video2img_retrieval.py --vid_csv_path ../vids.csv --image_dir ../feats --text_dir ../../BriVL-git/BriVL_code_inference/feat/text_vivo_5w_songids/ --lyric_txt_path /data/home/sxhong/tools/get_lyric/data/vivo_top5w_songid_lyrics.lyric --save_csv test.csv
def get_ImgFeat(img_feats):
    """Load per-image feature files and stack them into a single tensor.

    Args:
        img_feats: Iterable of paths to ``.npy`` files, one feature array per
            image. All arrays must have the same shape so they can be stacked.

    Returns:
        A ``(img_dict, img)`` tuple. ``img_dict`` maps each row index of
        ``img`` to the path that row was loaded from. ``img`` is a float64
        tensor with one leading row per input file; it is moved to the GPU
        when CUDA is available and left on the CPU otherwise.

    Raises:
        ValueError: If ``img_feats`` is empty.
    """
    img_feats = list(img_feats)
    if not img_feats:
        # Previously this fell through to an unbound local variable.
        raise ValueError('get_ImgFeat: no image feature files were given')

    img_dict = dict(enumerate(img_feats))
    # Stack once instead of re-concatenating (and re-copying) for every file.
    np_img = np.stack(
        [np.load(path).astype(np.float64) for path in img_feats], axis=0
    )
    img = torch.from_numpy(np_img)
    if torch.cuda.is_available():
        img = img.cuda()
    return img_dict, img
class Text_Dataset(data.Dataset):
    """Dataset over the ``.npy`` text-feature files found in one directory.

    Each item is a ``(songid, text_feat)`` pair: ``songid`` is the file name
    up to its first ``.`` and ``text_feat`` is the array stored in that file.
    """

    def __init__(self, text_dir):
        """Collect the ``*.npy`` paths located directly inside ``text_dir``."""
        # glob yields entries in a filesystem-dependent order; sorting makes
        # the index -> song mapping reproducible between runs.
        self.text_feats = sorted(glob.glob(os.path.join(text_dir, '*.npy')))

    def __len__(self):
        """Return the number of feature files found."""
        return len(self.text_feats)

    def __getitem__(self, index):
        """Return ``(songid, text_feat)`` for the file at ``index``."""
        text_path = self.text_feats[index]
        # basename honours the platform's path separator, unlike split('/').
        songid = os.path.basename(text_path).split('.')[0]
        text_feat = np.load(text_path)
        return songid, text_feat
def get_TextFeat(text_dir):
    """Load every text feature under ``text_dir`` into a single tensor.

    Args:
        text_dir: Directory holding the ``.npy`` text-feature files read by
            ``Text_Dataset``.

    Returns:
        A ``(all_songids, text)`` tuple. ``all_songids`` lists the song ids in
        load order and ``text`` holds the features concatenated along dim 0
        in that same order, with dim 1 squeezed out when it has size 1. The
        tensor is moved to the GPU when CUDA is available and left on the CPU
        otherwise.

    Raises:
        ValueError: If ``text_dir`` contains no feature files.
    """
    dataset = Text_Dataset(text_dir)
    dataloader = data.DataLoader(dataset, batch_size=5000, shuffle=False)

    all_songids = []
    batches = []
    for songid_, text_feats in dataloader:
        all_songids.extend(songid_)
        batches.append(text_feats)

    if not batches:
        # Previously this fell through to an unbound local variable.
        raise ValueError('get_TextFeat: no .npy text features found in %r' % (text_dir,))

    # Concatenate once rather than growing the tensor batch by batch.
    text = torch.cat(batches, 0).squeeze(dim=1)
    if torch.cuda.is_available():
        text = text.cuda()
    return all_songids, text
def get_lyric(lyric_txt_path):
    """Parse the lyric file into a ``{songid: lyric}`` dictionary.

    For each line, the song id is the text before the first comma and the
    lyric is the text between the first and second double quote. Lines that
    contain no double quote are skipped. When a song id occurs more than once,
    the last occurrence wins.

    Args:
        lyric_txt_path: Path to the lyric text file.

    Returns:
        Dict mapping song id strings to lyric strings.
    """
    text_dict = {}
    # The context manager guarantees the file handle is closed.
    with open(lyric_txt_path) as lyric_file:
        for line in lyric_file:
            quoted = line.split('"')
            if len(quoted) < 2:
                # No quoted lyric on this line; skip it instead of relying on
                # a catch-all exception handler.
                continue
            songid = line.split(',', 1)[0]
            text_dict[songid] = quoted[1]
    return text_dict
# Main retrieval loop: score every song's text feature against every image
# feature of each video and keep the best-scoring songs per video.
# NOTE(review): the nesting below was reconstructed from a copy that had lost
# its indentation -- confirm against the original file.

# Video ids, read as strings from the 'vid' column of the input CSV.
vids = pd.read_csv(args.vid_csv_path, dtype=str)['vid'].to_list()
# Song ids with their text features, and the songid -> lyric lookup.
all_songids, text = get_TextFeat(args.text_dir)
text_dict = get_lyric(args.lyric_txt_path)
# vid -> (comma-joined song ids, comma-joined scores)
vid2songid = {}
for vid in vids:
    # Image features of this video: files named '<vid>_*.npy' in image_dir.
    img_feat_paths = glob.glob(os.path.join(args.image_dir, str(vid)+'_'+'*.npy'))
    img_dict, img = get_ImgFeat(img_feat_paths)
    N_img = img.shape[0]
    N_text = text.shape[0]
    scores = torch.zeros((N_img, N_text), dtype=torch.float32).cuda()
    print('Pair-to-pair: calculating scores')
    for i in tqdm(range(N_img)): # row: image col: text
        # Elementwise product of image i with every text row, summed over the
        # last dim (a per-song dot product if features are 1-D -- TODO confirm).
        scores[i, :] = torch.sum(img[i] * text, -1)
    songid2score = {}
    songids = []
    songids2scores = []  # NOTE(review): not read anywhere in the code shown.
    for i, score in enumerate(scores):
        # Song indices for image i, highest score first.
        indices = torch.argsort(score, descending=True)
        songids = []
        for idx in indices:
            # NOTE(review): nothing visible here appends to `songids`, so this
            # condition is always true, the `for exist_songid` loop below never
            # runs, and the `else: break` is never reached -- confirm whether
            # a `songids.append(...)` is missing.
            if len(songids) <= 80:
                idx = int(idx.cpu().numpy())
                # Raises KeyError if this song id has no entry in text_dict.
                query_text = text_dict[all_songids[idx]]
                query_text = query_text.split(',')
                # Use the middle comma-separated piece of the lyric, with
                # spaces and '*' removed, as the duplicate-detection query.
                query_text = (query_text[len(query_text) // 2]).replace(' ','').replace('*','')
                for exist_songid in songids:
                    key_text = (text_dict[exist_songid]).replace(' ','').replace('*','')
                    # NOTE(review): query_text is passed as a regex pattern, so
                    # lyric characters such as '(' or '?' are interpreted as
                    # regex syntax rather than matched literally.
                    if re.findall(query_text, key_text):
                        # Replace the already-kept song when the new one scores higher.
                        if float(songid2score[exist_songid]) < float(score[idx].cpu().numpy()):
                            songid2score.pop(exist_songid)
                            songid2score[all_songids[idx]] = str(score[idx].cpu().numpy())
                        break
                # NOTE(review): placement of this statement is a best guess. As
                # placed, each later image overwrites the score stored for the
                # same song by an earlier image.
                songid2score[all_songids[idx]] = str(score[idx].cpu().numpy())
            else:
                break
    # Rank the collected songs by score (stored as strings) and keep the top 100.
    sorted_songid2score = sorted(songid2score.items(), key=lambda x:float(x[1]), reverse=True)
    select_songids = ', '.join([songid for songid, score in sorted_songid2score[:100]])
    correspond_scores = ', '.join([score for songid, score in sorted_songid2score[:100]])
    vid2songid[vid] = (select_songids, correspond_scores)
# Flatten the per-video results into one row per video and write them as CSV.
# The rows are kept in `rows` rather than `data`, so the `torch.utils.data`
# alias imported as `data` is no longer overwritten.
rows = [[vid, songids, scores] for vid, (songids, scores) in vid2songid.items()]
df = pd.DataFrame(rows, columns=['img', 'songids', 'scores'])
df.to_csv(args.save_csv, index=False)