File size: 4,962 Bytes
bb46cbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import os
import numpy as np
import random
from tqdm import tqdm
import argparse
import torch
import sys 
import glob
import csv
import torch.utils.data as data
import pandas as pd
import re
import pdb
# Make the directory above this script importable (presumably for
# project-local modules; nothing visible in this script imports from it).
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(os.path.realpath(__file__)), '..')))

parser = argparse.ArgumentParser()
# Every path below is used unconditionally by the script, so they are required:
# a missing flag should fail here with a usage message rather than much later
# with an unrelated error such as os.path.join(None, ...).
parser.add_argument('--vid_csv_path', type=str, required=True,
                    help="CSV file with a 'vid' column listing the videos to process")
parser.add_argument('--image_dir', type=str, required=True,
                    help="directory of per-frame image features named '<vid>_*.npy'")
parser.add_argument('--text_dir', type=str, required=True,
                    help="directory of per-song text features named '<songid>.npy'")
parser.add_argument('--lyric_txt_path', type=str, required=True,
                    help='lyric file with lines of the form <songid>,..."<lyric>"')
parser.add_argument('--save_csv', type=str, required=True,
                    help='path of the output CSV')
parser.add_argument('--gpu', type=str, default='3',
                    help='value assigned to CUDA_VISIBLE_DEVICES')
args = parser.parse_args()
# torch initialises CUDA lazily, so restricting the visible devices here still
# takes effect even though torch was already imported above.
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu


# Example usage:
# python video2img_retrieval.py --vid_csv_path ../vids.csv --image_dir ../feats --text_dir ../../BriVL-git/BriVL_code_inference/feat/text_vivo_5w_songids/ --lyric_txt_path /data/home/sxhong/tools/get_lyric/data/vivo_top5w_songid_lyrics.lyric --save_csv test.csv


def get_ImgFeat(img_feats, device='cuda'):
    """Load per-frame image feature files and stack them into one tensor.

    Args:
        img_feats: iterable of paths to ``.npy`` feature files. All files must
            hold arrays of the same shape.
        device: torch device the stacked tensor is moved to. The default
            ``'cuda'`` is the current CUDA device, i.e. the old ``.cuda()``.

    Returns:
        ``(img_dict, img)``: ``img_dict`` maps row index -> source path, and
        ``img`` is a float64 tensor whose row ``i`` is the array loaded from
        ``img_dict[i]`` (a new leading axis of length ``len(img_feats)``).

    Raises:
        ValueError: if ``img_feats`` is empty, since there is nothing to stack.
    """
    img_feats = list(img_feats)
    if not img_feats:
        raise ValueError('get_ImgFeat: no image feature files given')
    img_dict = dict(enumerate(img_feats))
    # Load everything, then stack once; concatenating inside the loop re-copies
    # the growing array on every iteration (quadratic in the number of frames).
    np_img = np.stack([np.load(path).astype(np.float64) for path in img_feats], axis=0)
    img = torch.from_numpy(np_img).to(device)
    return img_dict, img

      
class Text_Dataset(data.Dataset):
    """Map-style dataset over the ``*.npy`` text feature files of a directory.

    Item ``i`` is ``(songid, text_feat)``: ``songid`` is the file's base name up
    to its first ``'.'`` and ``text_feat`` is the array stored in the file.
    """

    def __init__(self, text_dir):
        # Sorted so item order (and therefore tie-breaking downstream) does not
        # depend on the arbitrary order in which the filesystem lists files.
        self.text_feats = sorted(glob.glob(os.path.join(text_dir, '*.npy')))

    def __len__(self):
        return len(self.text_feats)

    def __getitem__(self, index):
        text_path = self.text_feats[index]
        # basename instead of split('/') so the id is correct on every OS.
        songid = os.path.basename(text_path).split('.')[0]
        text_feat = np.load(text_path)
        return songid, text_feat

def get_TextFeat(text_dir, device='cuda', batch_size=5000):
    """Load every text feature file in ``text_dir`` into one tensor.

    Args:
        text_dir: directory of ``<songid>.npy`` text feature files.
        device: torch device of the returned tensor. The default ``'cuda'`` is
            the current CUDA device, i.e. the old ``.cuda()``.
        batch_size: number of files loaded per DataLoader batch.

    Returns:
        ``(all_songids, text)``: ``all_songids[i]`` is the song id of row ``i``
        of ``text``. Dimension 1 is squeezed away when it has size 1 —
        presumably each file stores a ``(1, D)`` array, giving ``text`` shape
        ``(N, D)``; TODO confirm against the feature extractor.

    Raises:
        ValueError: if ``text_dir`` contains no ``.npy`` files.
    """
    dataset = Text_Dataset(text_dir)
    if len(dataset) == 0:
        raise ValueError('get_TextFeat: no .npy files found in {!r}'.format(text_dir))
    dataloader = data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
    all_songids = []
    batches = []
    for songid_batch, feat_batch in dataloader:
        all_songids.extend(songid_batch)
        batches.append(feat_batch)
    # Concatenate once at the end instead of re-growing the tensor per batch.
    text = torch.cat(batches, 0).squeeze(dim=1)
    return all_songids, text.to(device)


def get_lyric(lyric_txt_path, encoding='utf-8'):
    """Read a songid -> lyric mapping from a text file.

    For each line, the song id is everything before the first comma and the
    lyric is the text between the first two double quotes, i.e. lines shaped
    like ``<songid>,..."<lyric>"...``. Lines containing no double quote are
    skipped.

    Args:
        lyric_txt_path: path to the lyric file.
        encoding: text encoding of the file. UTF-8 is assumed by default —
            pass the actual encoding if the file differs.

    Returns:
        dict mapping songid (str) to lyric (str). When a songid appears on
        several lines, the last one wins.
    """
    text_dict = {}
    with open(lyric_txt_path, encoding=encoding) as f:
        for line in f:
            parts = line.split('"')
            if len(parts) < 2:
                continue  # no quoted lyric on this line
            text_dict[line.split(',')[0]] = parts[1]
    return text_dict


# Per frame keep at most this many distinct songs; per video write out at most
# this many songs.
PER_IMAGE_TOP_K = 80
PER_VIDEO_TOP_K = 100


def _normalize_lyric(lyric):
    """Remove spaces and '*' so lyric comparisons ignore formatting noise."""
    return lyric.replace(' ', '').replace('*', '')


def _select_for_image(score_row, all_songids, text_dict, k=PER_IMAGE_TOP_K):
    """Pick up to ``k`` distinct songs for one image frame, best score first.

    A candidate counts as a duplicate of a song already picked for this frame
    (e.g. another release of the same song) and is skipped when the middle
    comma-separated segment of its lyric occurs verbatim in that song's lyric.
    Songs with no lyric entry are never treated as duplicates.

    Args:
        score_row: 1-D CPU tensor; ``score_row[j]`` is the similarity between
            this frame and ``all_songids[j]``.
        all_songids: song ids in the same order as ``score_row``.
        text_dict: songid -> lyric mapping as returned by ``get_lyric``.
        k: maximum number of songs to return.

    Returns:
        List of ``(songid, score)`` pairs in descending score order, where
        ``score`` is the numpy scalar taken from ``score_row``.
    """
    row = score_row.numpy()
    picked = []
    picked_lyrics = []
    for idx in torch.argsort(score_row, descending=True).tolist():
        if len(picked) >= k:
            break
        songid = all_songids[idx]
        lyric = text_dict.get(songid, '')
        segments = lyric.split(',')
        query = _normalize_lyric(segments[len(segments) // 2])
        # Plain substring test: lyrics are data, not regex patterns (characters
        # such as '(' or '?' would make re.findall crash or mismatch).
        if query and any(query in key for key in picked_lyrics):
            continue
        picked.append((songid, row[idx]))
        picked_lyrics.append(_normalize_lyric(lyric))
    return picked


vids = pd.read_csv(args.vid_csv_path, dtype=str)['vid'].to_list()
all_songids, text = get_TextFeat(args.text_dir)
text_dict = get_lyric(args.lyric_txt_path)
# One row per song, flattened so all frames of a video are scored with a single
# matmul. Assumes every image and text feature flattens to the same length D —
# TODO confirm against the feature extractor.
text_mat = text.reshape(text.shape[0], -1).double()
vid2songid = {}

for vid in vids:
    img_feat_paths = glob.glob(os.path.join(args.image_dir, str(vid) + '_' + '*.npy'))
    if not img_feat_paths:
        # Nothing to score; keep an empty row so the video still appears in the output.
        print('No image features for vid {}, skipping'.format(vid))
        vid2songid[vid] = ('', '')
        continue
    _, img = get_ImgFeat(img_feat_paths)

    print('Pair-to-pair: calculating scores')
    # row: image, col: text. Dot products in float64, stored as float32, and
    # moved to the CPU once instead of syncing with the GPU per element.
    scores = (img.reshape(img.shape[0], -1).double() @ text_mat.T).float().cpu()

    # Best score each song reached on any frame of this video.
    songid2score = {}
    for score_row in tqdm(scores):
        for songid, score in _select_for_image(score_row, all_songids, text_dict):
            if songid not in songid2score or score > songid2score[songid]:
                songid2score[songid] = score

    ranked = sorted(songid2score.items(), key=lambda x: float(x[1]), reverse=True)[:PER_VIDEO_TOP_K]
    select_songids = ', '.join(songid for songid, _ in ranked)
    correspond_scores = ', '.join(str(score) for _, score in ranked)
    vid2songid[vid] = (select_songids, correspond_scores)

# One CSV row per video: the video id, then the selected song ids and their
# scores as comma-separated strings in matching order. The list is named
# `rows` rather than `data` so it does not shadow the `torch.utils.data`
# module imported as `data` at the top of the file.
rows = [[vid, songids, scores] for vid, (songids, scores) in vid2songid.items()]
df = pd.DataFrame(rows, columns=['img', 'songids', 'scores'])
df.to_csv(args.save_csv, index=False)