import os
import pickle
import random
import faiss
from src.index import Indexer
import torch
import torch.nn.functional as F
import numpy as np
from torch.utils.data import DataLoader
from lightning import Fabric
from tqdm import tqdm
import argparse
from src.text_embedding import TextEmbeddingModel
from utils.load_dataset import load_dataset, TextDataset, load_outdomain_dataset


def load_pkl(path):
    """Load and return the object pickled at ``path``.

    NOTE(review): ``pickle.load`` can execute arbitrary code; only call this on
    files produced by this script (trusted index directories).
    """
    with open(path, 'rb') as f:
        return pickle.load(f)


def _save_pkl(obj, path):
    """Pickle ``obj`` to ``path``, overwriting any existing file."""
    with open(path, 'wb') as f:
        pickle.dump(obj, f)


def _dedup_by_id(ids, embeddings, columns):
    """Collapse rows that share an id.

    Ids keep their first-seen order, while the values of the *last* row seen for
    each id win (plain dict-assignment semantics).

    Args:
        ids: list of ids, one per row of ``embeddings``.
        embeddings: 2-D tensor whose rows align with ``ids``.
        columns: list of per-row metadata lists, each aligned with ``ids``.

    Returns:
        ``(unique_ids, embeddings_for_unique_ids, deduplicated_columns)``.
    """
    last_position = {}
    for position, key in enumerate(ids):
        last_position[key] = position
    unique_ids = list(last_position)
    positions = list(last_position.values())
    deduped_columns = [[column[p] for p in positions] for column in columns]
    return unique_ids, embeddings[positions], deduped_columns


def infer(passages_dataloder, fabric, tokenizer, model, ood=False, max_length=512):
    """Embed every passage in ``passages_dataloder`` and gather results on rank 0.

    Each batch must be ``(ids, text, label, is_mixed)`` when ``ood`` is true and
    ``(ids, text, label, is_mixed, write_model)`` otherwise. Texts are tokenized
    (padded/truncated to ``max_length``), embedded with ``model``, and every
    tensor is all-gathered across ranks.

    Returns:
        On rank 0: ``(ids, embeddings, labels, is_mixed)`` for ``ood`` or
        ``(ids, embeddings, labels, is_mixed, write_model)`` otherwise, where
        ``embeddings`` is a 2-D numpy array L2-normalized along the last axis and
        duplicate ids (which the all-gather can produce) are collapsed.
        On other ranks: a tuple of empty lists with the same arity.

    Raises:
        ValueError: if a batch has the wrong number of fields, or rank 0
            received no batches at all.
    """
    if fabric.global_rank == 0:
        passages_dataloder = tqdm(passages_dataloder, total=len(passages_dataloder))

    # Metadata fields following (ids, text): label, is_mixed[, write_model].
    n_meta = 2 if ood else 3
    all_ids, all_embeddings = [], []
    all_meta = [[] for _ in range(n_meta)]

    # Put the whole module (not only the inner encoder) in eval mode so no
    # dropout anywhere in the wrapper perturbs the embeddings.
    model.eval()
    with torch.no_grad():
        for batch in passages_dataloder:
            ids, text, *meta = batch
            if len(meta) != n_meta:
                raise ValueError(
                    f"expected {n_meta + 2} fields per batch, got {len(meta) + 2}"
                )
            encoded_batch = tokenizer.batch_encode_plus(
                text,
                return_tensors="pt",
                max_length=max_length,
                padding="max_length",
                truncation=True,
            )
            encoded_batch = {k: v.cuda() for k, v in encoded_batch.items()}
            embeddings = model(encoded_batch)

            # All ranks execute the same collectives in the same order.
            embeddings = fabric.all_gather(embeddings).view(-1, embeddings.size(1))
            ids = fabric.all_gather(ids).view(-1)
            meta = [fabric.all_gather(field).view(-1) for field in meta]

            if fabric.global_rank == 0:
                all_embeddings.append(embeddings.cpu())
                all_ids.extend(ids.cpu().tolist())
                for column, field in zip(all_meta, meta):
                    column.extend(field.cpu().tolist())

    if fabric.global_rank != 0:
        return tuple([] for _ in range(2 + n_meta))

    if not all_embeddings:
        raise ValueError("infer() received no batches; nothing to embed")

    normalized = F.normalize(torch.cat(all_embeddings, dim=0), dim=-1)
    unique_ids, unique_embeddings, meta_columns = _dedup_by_id(all_ids, normalized, all_meta)
    return (unique_ids, unique_embeddings.numpy(), *meta_columns)


def set_seed(seed):
    """Seed torch (CPU and every CUDA device), NumPy and Python's ``random``."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
    np.random.seed(seed)  # Numpy module.
    random.seed(seed)  # Python random module.


def _setup_fabric(device_num):
    """Create and launch a CUDA Fabric; use the DDP strategy only for >1 device."""
    if device_num > 1:
        fabric = Fabric(accelerator="cuda", devices=device_num, strategy='ddp')
    else:
        fabric = Fabric(accelerator="cuda", devices=device_num)
    fabric.launch()
    return fabric


def _load_model(opt):
    """Build ``TextEmbeddingModel(opt.model_name)`` on CUDA and load ``opt.model_path``.

    The checkpoint's keys are loaded unchanged (``load_state_dict`` defaults to
    strict matching), so they must match the model's own state-dict keys.
    """
    model = TextEmbeddingModel(opt.model_name).cuda()
    state_dict = torch.load(opt.model_path, map_location=model.model.device)
    model.load_state_dict(state_dict)
    return model


def _make_loader(opt, fabric, passage_dataset):
    """Print the dataset size and wrap it in a Fabric-prepared DataLoader."""
    print(len(passage_dataset))
    loader = DataLoader(passage_dataset, batch_size=opt.batch_size,
                        num_workers=opt.num_workers, pin_memory=True)
    return fabric.setup_dataloaders(loader)


def test(opt):
    """Embed ``opt.database_name`` and write a fresh index plus metadata to ``opt.save_path``.

    Rank 0 serializes the index and pickles ``label_dict.pkl``,
    ``is_mixed_dict.pkl`` and ``write_model_dict.pkl`` (each keyed by passage id).
    """
    fabric = _setup_fabric(opt.device_num)
    model = _load_model(opt)
    tokenizer = model.tokenizer

    database = load_dataset(opt.dataset_name, opt.database_path)[opt.database_name]
    passages_dataloder = _make_loader(opt, fabric, TextDataset(database, need_ids=True))
    model = fabric.setup(model)

    train_ids, train_embeddings, train_labels, train_is_mixed, train_write_model = infer(
        passages_dataloder, fabric, tokenizer, model)
    fabric.barrier()
    if fabric.global_rank != 0:
        return

    index = Indexer(opt.embedding_dim)
    index.index_data(train_ids, train_embeddings)

    label_dict = dict(zip(train_ids, train_labels))
    is_mixed_dict = dict(zip(train_ids, train_is_mixed))
    write_model_dict = dict(zip(train_ids, train_write_model))

    os.makedirs(opt.save_path, exist_ok=True)
    index.serialize(opt.save_path)
    _save_pkl(label_dict, os.path.join(opt.save_path, 'label_dict.pkl'))
    _save_pkl(is_mixed_dict, os.path.join(opt.save_path, 'is_mixed_dict.pkl'))
    _save_pkl(write_model_dict, os.path.join(opt.save_path, 'write_model_dict.pkl'))


def add_to_existed_index(opt):
    """Embed ``opt.database_name`` and merge it with the index at ``opt.existed_index_path``.

    The merged index and updated metadata pickles are written to
    ``opt.new_save_path``. With ``opt.ood`` only ``label_dict.pkl`` and
    ``is_mixed_dict.pkl`` are read and written; otherwise ``write_model_dict.pkl``
    is handled as well.

    Raises:
        ValueError: if the new and existing indexes differ in dimensionality.
    """
    fabric = _setup_fabric(opt.device_num)
    model = _load_model(opt)
    tokenizer = model.tokenizer

    if opt.ood:
        database = load_outdomain_dataset(opt.database_path)[opt.database_name]
    else:
        database = load_dataset(opt.dataset_name, opt.database_path)[opt.database_name]
    passage_dataset = TextDataset(database, need_ids=True, out_domain=opt.ood)
    passages_dataloder = _make_loader(opt, fabric, passage_dataset)
    model = fabric.setup(model)

    # train_meta is [labels, is_mixed] for ood, [labels, is_mixed, write_model] otherwise.
    train_ids, train_embeddings, *train_meta = infer(
        passages_dataloder, fabric, tokenizer, model, ood=bool(opt.ood))
    fabric.barrier()
    if fabric.global_rank != 0:
        return

    new_index = Indexer(opt.embedding_dim)
    new_index.index_data(train_ids, train_embeddings)

    old_index = Indexer(opt.embedding_dim)
    old_index.deserialize_from(opt.existed_index_path)
    old_ids = old_index.index_id_to_db_id

    if new_index.index.d != old_index.index.d:
        raise ValueError(
            f"index dimensionality mismatch: new={new_index.index.d}, "
            f"existing={old_index.index.d}"
        )

    # Copy every stored vector out of the old index and append it to the new one.
    vectors = old_index.index.reconstruct_n(0, old_index.index.ntotal)
    new_index.index_data(old_ids, vectors)

    os.makedirs(opt.new_save_path, exist_ok=True)
    new_index.serialize(opt.new_save_path)

    meta_files = ['label_dict.pkl', 'is_mixed_dict.pkl']
    if not opt.ood:
        meta_files.append('write_model_dict.pkl')
    # Load everything first so a missing file fails before anything is written.
    merged = [load_pkl(os.path.join(opt.existed_index_path, name)) for name in meta_files]
    for mapping, column in zip(merged, train_meta):
        mapping.update(zip(train_ids, column))
    for name, mapping in zip(meta_files, merged):
        _save_pkl(mapping, os.path.join(opt.new_save_path, name))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--device_num', type=int, default=1)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--num_workers', type=int, default=8)
    parser.add_argument('--embedding_dim', type=int, default=768)
    parser.add_argument("--database_path", type=str, default="data/FALCONSet", help="Path to the data")
    parser.add_argument('--dataset_name', type=str, default='falconset', help="falconset, llmdetectaive, hart")
    parser.add_argument('--database_name', type=str, default='train', help="train,valid,test,test_ood")
    parser.add_argument("--model_path", type=str, default="runs/authscan_v6/model_best.pth",
                        help="Path to the embedding model checkpoint")
    parser.add_argument('--model_name', type=str, default="FacebookAI/xlm-roberta-base", help="Model name")
    parser.add_argument("--save_path", type=str, default="/output", help="Path to save the database")
    parser.add_argument("--add_to_existed_index", type=int, default=0)
    parser.add_argument("--ood", type=int, default=0)
    parser.add_argument("--existed_index_path", type=str, default="/output", help="Path of existed index")
    parser.add_argument("--new_save_path", type=str, default="/new_db", help="Path to save the database")
    parser.add_argument('--seed', type=int, default=0)
    opt = parser.parse_args()

    set_seed(opt.seed)
    if not opt.add_to_existed_index:
        test(opt)
    else:
        add_to_existed_index(opt)