File size: 4,332 Bytes
6755a2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# -*- encoding: utf-8 -*-
'''
@File    :   text_feat_extractor.py
@Time    :   2021/08/26 10:46:15
@Author  :   Chuhao Jin
@Email   :   jinchuhao@ruc.edu.cn
'''

# here put the import lib

import os
import sys
import pickle
import argparse

base_dir = os.path.abspath(os.path.dirname(__file__))
sys.path.append(base_dir)
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import numpy as np
from transformers import AutoTokenizer

from utils import getLanMask
from utils.config import cfg_from_yaml_file, cfg
from models.vl_model import *
from tqdm import tqdm
import pdb
import json


class TextModel(nn.Module):
    """Text branch of the vision-language model: encode tokens, then L2-normalize.

    Only the text encoder is instantiated. It is registered under
    ``self.learnable['textencoder']`` so that ``textencoder.*`` entries from a
    checkpoint's ``learnable`` dict can be loaded into ``self.learnable``.

    Args:
        model_cfg: the ``MODEL`` section of the config, forwarded unchanged to
            ``TextLearnableEncoder``.
    """

    def __init__(self, model_cfg):
        super().__init__()

        self.model_cfg = model_cfg

        self.learnable = nn.ModuleDict()
        self.learnable['textencoder'] = TextLearnableEncoder(model_cfg)

    def forward(self, texts, maskTexts):
        """Encode a batch of token ids and return unit-length feature vectors.

        Args:
            texts: token-id tensor passed straight to the text encoder.
            maskTexts: mask tensor passed straight to the text encoder.

        Returns:
            The encoder output, L2-normalized along its last dimension.
        """
        textFea = self.learnable['textencoder'](texts, maskTexts)  # <bsz, img_dim>
        # Use nn.functional explicitly: `F` is not imported in this module and
        # would otherwise only resolve if a wildcard import happened to leak it.
        return nn.functional.normalize(textFea, p=2, dim=-1)


class TextFeatureExtractor:
    """Load the text branch of a checkpoint and embed single strings with it.

    Args:
        cfg_file: path to the YAML config, passed to ``cfg_from_yaml_file``
            together with the module-level ``cfg``.
        model_weights: path to a checkpoint whose ``"learnable"`` dict holds the
            ``textencoder.*`` weights (other entries are ignored).
        gpu_id: index of the CUDA device that hosts the model and its inputs.
        tokenizer_path: name or directory given to
            ``AutoTokenizer.from_pretrained``. The default is relative to the
            current working directory.
    """

    def __init__(self, cfg_file, model_weights, gpu_id=0,
                 tokenizer_path='./hfl/chinese-bert-wwm-ext'):
        self.gpu_id = gpu_id
        self.cfg_file = cfg_file
        self.cfg = cfg_from_yaml_file(self.cfg_file, cfg)
        # Resolve a relative ENCODER path against this script's directory
        # (os.path.join leaves an absolute ENCODER path unchanged).
        self.cfg.MODEL.ENCODER = os.path.join(base_dir, self.cfg.MODEL.ENCODER)
        self.text_model = TextModel(model_cfg=self.cfg.MODEL).cuda(self.gpu_id)

        # Load on the CPU: load_state_dict copies the selected tensors into the
        # model's CUDA parameters, so the unused (non-text) checkpoint entries
        # never occupy GPU memory.
        checkpoint = torch.load(model_weights, map_location='cpu')
        self.text_model.learnable.load_state_dict(
            self._select_text_encoder_weights(checkpoint))
        self.text_model.eval()

        self.text_transform = AutoTokenizer.from_pretrained(tokenizer_path)

    @staticmethod
    def _select_text_encoder_weights(checkpoint):
        """Return the entries of ``checkpoint["learnable"]`` whose key contains ``"textencoder."``."""
        return {key: value for key, value in checkpoint["learnable"].items()
                if "textencoder." in key}

    def extract(self, text_input):
        """Embed one string with the text encoder.

        Args:
            text_input: a single string to encode, or ``None``.

        Returns:
            ``None`` when *text_input* is ``None``. Otherwise a numpy array of the
            L2-normalized feature, with a leading batch axis of length 1.
        """
        if text_input is None:
            return None

        max_len = self.cfg.MODEL.MAX_TEXT_LEN
        text_info = self.text_transform(text_input, padding='max_length', truncation=True,
                                        max_length=max_len, return_tensors='pt')
        # Flatten, then add the batch axis back. A single string thus becomes a
        # (1, max_len) batch; a list of strings would be wrongly concatenated.
        texts = text_info.input_ids.reshape(-1).unsqueeze(0)
        # Count of non-padding tokens, as a length-1 batch.
        text_lens = torch.sum(text_info.attention_mask).unsqueeze(0)
        with torch.no_grad():
            # Use self.cfg so the mask length matches the tokenizer's max_length.
            text_mask = getLanMask(text_lens, max_len).cuda(self.gpu_id)
            text_fea = self.text_model(texts.cuda(self.gpu_id), text_mask)
        return text_fea.cpu().numpy()


def annotate_clips(clip_data, extractor):
    """Attach a text-semantics feature to every clip of *clip_data*, in place.

    Each clip gets ``clip["multi_factor"] = {"semantics": ...}``. The value is
    ``extractor.extract(original_text)`` with its leading batch axis squeezed
    and converted to a list. It is ``None`` when the clip has no non-empty
    ``"original_text"``.

    Args:
        clip_data: a dict with a ``"clips"`` list of dicts.
        extractor: any object with an ``extract(text)`` method returning a
            numpy array whose first axis has length 1.

    Returns:
        The same *clip_data* object, for convenience.
    """
    for clip in clip_data["clips"]:
        clip["multi_factor"] = {"semantics": None}
        text = clip.get("original_text")
        if text:
            fea = extractor.extract(text)
            clip["multi_factor"]["semantics"] = fea.squeeze(axis=0).tolist()
    return clip_data


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Both paths are mandatory: with a None default the script either crashed
    # (os.path.exists(None)) or silently processed the current directory.
    parser.add_argument('--txt_path', type=str, required=True,
                        help='directory of input JSON files, each with a "clips" list')
    parser.add_argument('--feat_save_dir', type=str, required=True,
                        help='directory where the annotated JSON files are written')
    parser.add_argument('--cfg_file', type=str, default='cfg/test_xyb.yml')
    parser.add_argument('--brivl_checkpoint', type=str,
                        default='/innovation_cfs/mmatch/infguo/weights/BriVL-1.0-5500w.pth')
    args = parser.parse_args()

    cfg_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), args.cfg_file)
    extractor = TextFeatureExtractor(cfg_file, args.brivl_checkpoint)

    os.makedirs(args.feat_save_dir, exist_ok=True)

    for name in sorted(os.listdir(args.txt_path)):
        # Decode as UTF-8 explicitly; json.load() no longer accepts an
        # `encoding` keyword (removed in Python 3.9).
        with open(os.path.join(args.txt_path, name), encoding='utf-8') as fp:
            clip_data = json.load(fp)
        annotate_clips(clip_data, extractor)
        with open(os.path.join(args.feat_save_dir, name), "w", encoding="utf-8") as fp:
            json.dump(clip_data, fp, ensure_ascii=False, indent=4)