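"""Build a FAISS retrieval index of text embeddings for machine-generated-text
detection.

Runs the embedding model over a passage database (optionally across several
GPUs via Lightning Fabric), gathers the embeddings on rank 0, and serializes a
FAISS index together with pickled metadata dictionaries (label, is_mixed and,
for in-domain data, the writing model of each passage). Can also extend a
previously serialized index with new passages.
"""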
import argparse
import os
import pickle
import random

import numpy as np
import torch
import torch.nn.functional as F
from lightning import Fabric
from torch.utils.data import DataLoader
from tqdm import tqdm

from src.index import Indexer
from src.text_embedding import TextEmbeddingModel
from utils.load_dataset import TextDataset, load_dataset, load_outdomain_dataset
def load_pkl(path):
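    """Load and return a pickled object from `path`."""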
    with open(path, 'rb') as f:
        return pickle.load(f)
def infer(passages_dataloader, fabric, tokenizer, model, ood=False):
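    """Embed every passage in the dataloader and collect the results on rank 0.

    Embeddings are L2-normalized and de-duplicated by passage id, since the
    distributed sampler may repeat samples so that all ranks process the same
    number of batches. Non-zero ranks return empty lists. When `ood` is True,
    batches carry no `write_model` field.
    """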
    if fabric.global_rank == 0:
        passages_dataloader = tqdm(passages_dataloader, total=len(passages_dataloader))
    allids, allembeddings, alllabels, all_is_mixed, all_write_model = [], [], [], [], []
    model.model.eval()
    with torch.no_grad():
        for batch in passages_dataloader:
            if ood:
                ids, text, label, is_mixed = batch
                write_model = None
            else:
                ids, text, label, is_mixed, write_model = batch
            encoded_batch = tokenizer.batch_encode_plus(
                text,
                return_tensors="pt",
                max_length=512,
                padding="max_length",
                truncation=True,
            )
            encoded_batch = {k: v.cuda() for k, v in encoded_batch.items()}
            embeddings = model(encoded_batch)
            # Gather this batch's results from every rank; only rank 0 keeps them.
            embeddings = fabric.all_gather(embeddings).view(-1, embeddings.size(1))
            label = fabric.all_gather(label).view(-1)
            ids = fabric.all_gather(ids).view(-1)
            is_mixed = fabric.all_gather(is_mixed).view(-1)
            if not ood:
                write_model = fabric.all_gather(write_model).view(-1)
            if fabric.global_rank == 0:
                allembeddings.append(embeddings.cpu())
                allids.extend(ids.cpu().tolist())
                alllabels.extend(label.cpu().tolist())
                all_is_mixed.extend(is_mixed.cpu().tolist())
                if not ood:
                    all_write_model.extend(write_model.cpu().tolist())
    if fabric.global_rank != 0:
        return ([], [], [], []) if ood else ([], [], [], [], [])
    allembeddings = F.normalize(torch.cat(allembeddings, dim=0), dim=-1)
    # De-duplicate by passage id: with DDP, the sampler may repeat samples so
    # that every rank processes the same number of batches.
    emb_dict, label_dict, is_mixed_dict, write_model_dict = {}, {}, {}, {}
    for i, pid in enumerate(allids):
        emb_dict[pid] = allembeddings[i]
        label_dict[pid] = alllabels[i]
        is_mixed_dict[pid] = all_is_mixed[i]
        if not ood:
            write_model_dict[pid] = all_write_model[i]
    allids = list(emb_dict.keys())
    allembeddings = torch.stack([emb_dict[pid] for pid in allids], dim=0)
    alllabels = [label_dict[pid] for pid in allids]
    all_is_mixed = [is_mixed_dict[pid] for pid in allids]
    if ood:
        return allids, allembeddings.numpy(), alllabels, all_is_mixed
    all_write_model = [write_model_dict[pid] for pid in allids]
    return allids, allembeddings.numpy(), alllabels, all_is_mixed, all_write_model
def set_seed(seed):
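    """Seed torch (CPU and all GPUs), NumPy, and Python's `random` module."""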
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if using multi-GPU
    np.random.seed(seed)
    random.seed(seed)
def test(opt):
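    """Embed the passage database and serialize a fresh index plus metadata."""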
    # Use DDP only when more than one device is requested.
    if opt.device_num > 1:
        fabric = Fabric(accelerator="cuda", devices=opt.device_num, strategy="ddp")
    else:
        fabric = Fabric(accelerator="cuda", devices=opt.device_num)
    fabric.launch()
    model = TextEmbeddingModel(opt.model_name).cuda()
    # Checkpoint keys already carry the 'model.' prefix, matching the
    # TextEmbeddingModel wrapper, so the state dict loads directly.
    state_dict = torch.load(opt.model_path, map_location=model.model.device)
    model.load_state_dict(state_dict)
    tokenizer = model.tokenizer
    database = load_dataset(opt.dataset_name, opt.database_path)[opt.database_name]
    passage_dataset = TextDataset(database, need_ids=True)
    print(f"Loaded {len(passage_dataset)} passages")
    passages_dataloader = DataLoader(passage_dataset, batch_size=opt.batch_size,
                                     num_workers=opt.num_workers, pin_memory=True)
    passages_dataloader = fabric.setup_dataloaders(passages_dataloader)
    model = fabric.setup(model)
    train_ids, train_embeddings, train_labels, train_is_mixed, train_write_model = infer(
        passages_dataloader, fabric, tokenizer, model)
    fabric.barrier()
    if fabric.global_rank == 0:
        # Build a fresh index over the embeddings.
        index = Indexer(opt.embedding_dim)
        index.index_data(train_ids, train_embeddings)
        # Map each passage id to its metadata.
        label_dict, is_mixed_dict, write_model_dict = {}, {}, {}
        for i in range(len(train_ids)):
            label_dict[train_ids[i]] = train_labels[i]
            is_mixed_dict[train_ids[i]] = train_is_mixed[i]
            write_model_dict[train_ids[i]] = train_write_model[i]
        os.makedirs(opt.save_path, exist_ok=True)
        index.serialize(opt.save_path)
        # Save the metadata dictionaries alongside the index.
        with open(os.path.join(opt.save_path, 'label_dict.pkl'), 'wb') as f:
            pickle.dump(label_dict, f)
        with open(os.path.join(opt.save_path, 'is_mixed_dict.pkl'), 'wb') as f:
            pickle.dump(is_mixed_dict, f)
        with open(os.path.join(opt.save_path, 'write_model_dict.pkl'), 'wb') as f:
            pickle.dump(write_model_dict, f)
def add_to_existed_index(opt):
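    """Embed new passages and merge them into a previously serialized index.

    The old index's vectors are reconstructed and re-added, with their original
    ids, to a new index containing both old and new passages, and the metadata
    dictionaries are extended accordingly.
    """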
    if opt.device_num > 1:
        fabric = Fabric(accelerator="cuda", devices=opt.device_num, strategy="ddp")
    else:
        fabric = Fabric(accelerator="cuda", devices=opt.device_num)
    fabric.launch()
    model = TextEmbeddingModel(opt.model_name).cuda()
    # As in test(), the checkpoint keys already carry the 'model.' prefix.
    state_dict = torch.load(opt.model_path, map_location=model.model.device)
    model.load_state_dict(state_dict)
    tokenizer = model.tokenizer
    if opt.ood:
        database = load_outdomain_dataset(opt.database_path)[opt.database_name]
    else:
        database = load_dataset(opt.dataset_name, opt.database_path)[opt.database_name]
    passage_dataset = TextDataset(database, need_ids=True, out_domain=opt.ood)
    print(f"Loaded {len(passage_dataset)} passages")
    passages_dataloader = DataLoader(passage_dataset, batch_size=opt.batch_size,
                                     num_workers=opt.num_workers, pin_memory=True)
    passages_dataloader = fabric.setup_dataloaders(passages_dataloader)
    model = fabric.setup(model)
    if opt.ood:
        train_ids, train_embeddings, train_labels, train_is_mixed = infer(
            passages_dataloader, fabric, tokenizer, model, ood=True)
    else:
        train_ids, train_embeddings, train_labels, train_is_mixed, train_write_model = infer(
            passages_dataloader, fabric, tokenizer, model)
    fabric.barrier()
    if fabric.global_rank == 0:
        # Index the new embeddings, then fold in the vectors from the old index.
        new_index = Indexer(opt.embedding_dim)
        new_index.index_data(train_ids, train_embeddings)
        old_index = Indexer(opt.embedding_dim)
        old_index.deserialize_from(opt.existed_index_path)
        old_ids = old_index.index_id_to_db_id
        # Both indexes must have the same dimensionality to be merged.
        # (Optionally also assert both are faiss.IndexFlatIP.)
        assert new_index.index.d == old_index.index.d
        # Reconstruct all vectors stored in the old index and add them,
        # with their original ids, to the new one.
        vectors = old_index.index.reconstruct_n(0, old_index.index.ntotal)
        new_index.index_data(old_ids, vectors)
        os.makedirs(opt.new_save_path, exist_ok=True)
        new_index.serialize(opt.new_save_path)
        # Extend the old metadata dictionaries with the new entries and save.
        label_dict = load_pkl(os.path.join(opt.existed_index_path, 'label_dict.pkl'))
        is_mixed_dict = load_pkl(os.path.join(opt.existed_index_path, 'is_mixed_dict.pkl'))
        for i in range(len(train_ids)):
            label_dict[train_ids[i]] = train_labels[i]
            is_mixed_dict[train_ids[i]] = train_is_mixed[i]
        with open(os.path.join(opt.new_save_path, 'label_dict.pkl'), 'wb') as f:
            pickle.dump(label_dict, f)
        with open(os.path.join(opt.new_save_path, 'is_mixed_dict.pkl'), 'wb') as f:
            pickle.dump(is_mixed_dict, f)
        if not opt.ood:
            # In-domain data additionally tracks which model wrote each passage.
            write_model_dict = load_pkl(os.path.join(opt.existed_index_path, 'write_model_dict.pkl'))
            for i in range(len(train_ids)):
                write_model_dict[train_ids[i]] = train_write_model[i]
            with open(os.path.join(opt.new_save_path, 'write_model_dict.pkl'), 'wb') as f:
                pickle.dump(write_model_dict, f)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--device_num', type=int, default=1)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--num_workers', type=int, default=8)
    parser.add_argument('--embedding_dim', type=int, default=768)
    parser.add_argument("--database_path", type=str, default="data/FALCONSet", help="Path to the data")
    parser.add_argument('--dataset_name', type=str, default='falconset', help="falconset, llmdetectaive, hart")
    parser.add_argument('--database_name', type=str, default='train', help="train, valid, test, test_ood")
    parser.add_argument("--model_path", type=str, default="runs/authscan_v6/model_best.pth",
                        help="Path to the embedding model checkpoint")
    parser.add_argument('--model_name', type=str, default="FacebookAI/xlm-roberta-base", help="Model name")
    parser.add_argument("--save_path", type=str, default="/output", help="Path to save the database")
    parser.add_argument("--add_to_existed_index", type=int, default=0)
    parser.add_argument("--ood", type=int, default=0)
    parser.add_argument("--existed_index_path", type=str, default="/output", help="Path of the existing index")
    parser.add_argument("--new_save_path", type=str, default="/new_db", help="Path to save the merged database")
    parser.add_argument('--seed', type=int, default=0)
    opt = parser.parse_args()
    set_seed(opt.seed)
    if not opt.add_to_existed_index:
        test(opt)
    else:
        add_to_existed_index(opt)
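
# Example invocations (script filename and paths are illustrative):
#   Build a fresh index over the training split:
#     python build_database.py --device_num 1 --database_name train --save_path output/db
#   Merge out-of-domain passages into an existing index:
#     python build_database.py --add_to_existed_index 1 --ood 1 \
#         --existed_index_path output/db --new_save_path output/db_ood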