# -*- encoding: utf-8 -*- ''' @File : text_feat_extractor.py @Time : 2021/08/26 10:46:15 @Author : Chuhao Jin @Email : jinchuhao@ruc.edu.cn ''' # here put the import lib import os import sys import pickle import argparse base_dir = os.path.abspath(os.path.dirname(__file__)) sys.path.append(base_dir) import torch import torch.nn as nn import torchvision.transforms as transforms import numpy as np from transformers import AutoTokenizer from utils import getLanMask from utils.config import cfg_from_yaml_file, cfg from models.vl_model import * from tqdm import tqdm import pdb import json class TextModel(nn.Module): def __init__(self, model_cfg): super(TextModel, self).__init__() self.model_cfg = model_cfg self.learnable = nn.ModuleDict() self.learnable['textencoder'] = TextLearnableEncoder(model_cfg) def forward(self, texts, maskTexts): textFea = self.learnable['textencoder'](texts, maskTexts) # textFea = F.normalize(textFea, p=2, dim=-1) return textFea class TextFeatureExtractor: def __init__(self, cfg_file, model_weights, gpu_id=0): self.gpu_id = gpu_id self.cfg_file = cfg_file self.cfg = cfg_from_yaml_file(self.cfg_file, cfg) self.cfg.MODEL.ENCODER = os.path.join(base_dir, self.cfg.MODEL.ENCODER) self.text_model = TextModel(model_cfg=self.cfg.MODEL) self.text_model = self.text_model.cuda(self.gpu_id) model_component = torch.load(model_weights, map_location=torch.device('cuda:{}'.format(self.gpu_id))) text_model_component = {} for key in model_component["learnable"].keys(): if "textencoder." in key: text_model_component[key] = model_component["learnable"][key] self.text_model.learnable.load_state_dict(text_model_component) self.text_model.eval() self.text_transform = AutoTokenizer.from_pretrained('./hfl/chinese-bert-wwm-ext') def extract(self, text_input): if text_input is None: return None else: text_info = self.text_transform(text_input, padding='max_length', truncation=True, max_length=self.cfg.MODEL.MAX_TEXT_LEN, return_tensors='pt') text = text_info.input_ids.reshape(-1) text_len = torch.sum(text_info.attention_mask) with torch.no_grad(): texts = text.unsqueeze(0) text_lens = text_len.unsqueeze(0) textMask = getLanMask(text_lens, cfg.MODEL.MAX_TEXT_LEN) textMask = textMask.cuda(self.gpu_id) texts = texts.cuda(self.gpu_id) text_lens = text_lens.cuda(self.gpu_id) text_fea = self.text_model(texts, textMask) text_fea = text_fea.cpu().numpy() return text_fea if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--txt_path', type=str, default=None) parser.add_argument('--feat_save_dir', type=str, default=None) parser.add_argument('--cfg_file', type=str, default='cfg/test_xyb.yml') parser.add_argument('--brivl_checkpoint', type=str, default='/innovation_cfs/mmatch/infguo/weights/BriVL-1.0-5500w.pth') args = parser.parse_args() cfg_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), args.cfg_file) model_weights = args.brivl_checkpoint vfe = TextFeatureExtractor(cfg_file, model_weights) save_dir = args.feat_save_dir if not os.path.exists(args.feat_save_dir): os.makedirs(args.feat_save_dir) for i in os.listdir(args.txt_path): clip_data = json.load(open(os.path.join(args.txt_path, i)), encoding='UTF-8') for clip in clip_data["clips"]: clip["multi_factor"] = {"semantics": None} if "original_text" in clip and clip["original_text"] and len(clip["original_text"]) > 0: text = clip["original_text"] fea = vfe.extract(text) fea = fea.squeeze(axis=0).tolist() clip["multi_factor"]["semantics"] = fea with open(os.path.join(args.feat_save_dir, i), "w", encoding="utf-8") as fp: json.dump(clip_data, fp, ensure_ascii=False, indent=4)