|
import logging |
|
import math |
|
import os |
|
import sys |
|
from dataclasses import dataclass, field |
|
from typing import Optional, Union, List, Dict, Tuple |
|
import torch |
|
import collections |
|
import random |
|
|
|
from datasets import load_dataset |
|
|
|
import transformers |
|
from transformers import ( |
|
CONFIG_MAPPING, |
|
MODEL_FOR_MASKED_LM_MAPPING, |
|
AutoConfig, |
|
AutoModelForMaskedLM, |
|
AutoModelForSequenceClassification, |
|
AutoTokenizer, |
|
DataCollatorForLanguageModeling, |
|
DataCollatorWithPadding, |
|
HfArgumentParser, |
|
Trainer, |
|
TrainingArguments, |
|
default_data_collator, |
|
set_seed, |
|
EvalPrediction, |
|
BertModel, |
|
BertForPreTraining, |
|
RobertaModel |
|
) |
|
from transformers.tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTrainedTokenizerBase |
|
from transformers.trainer_utils import is_main_process |
|
from transformers.data.data_collator import DataCollatorForLanguageModeling |
|
from transformers.file_utils import cached_property, torch_required, is_torch_available, is_torch_tpu_available |
|
from simcse.models import RobertaForCL, BertForCL |
|
from simcse.trainers import CLTrainer |
|
|
|
# Module-level logger; configured in main() via logging.basicConfig.
logger = logging.getLogger(__name__)

# Config classes that have a masked-LM head; used only to build the list of
# valid values shown in the --model_type help string.
MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
|
|
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
    """

    # --- Standard Hugging Face model/config/tokenizer arguments ---

    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "The model checkpoint for weights initialization."
            "Don't set if you want to train a model from scratch."
        },
    )
    model_type: Optional[str] = field(
        default=None,
        metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    use_auth_token: bool = field(
        default=False,
        metadata={
            "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
            "with private models)."
        },
    )

    # --- Contrastive-learning arguments (forwarded to RobertaForCL/BertForCL
    # via the `model_args=` kwarg in main()) ---

    # Softmax temperature for the contrastive loss.
    temp: float = field(
        default=0.05,
        metadata={
            "help": "Temperature for softmax."
        }
    )
    # Which sentence representation to use as the embedding.
    pooler_type: str = field(
        default="cls",
        metadata={
            "help": "What kind of pooler to use (cls, cls_before_pooler, avg, avg_top2, avg_first_last)."
        }
    )
    # Logit offset applied to hard negatives (3-column supervised data only).
    hard_negative_weight: float = field(
        default=0,
        metadata={
            "help": "The **logit** of weight for hard negatives (only effective if hard negatives are used)."
        }
    )
    # Enables the auxiliary masked-LM objective alongside contrastive learning.
    do_mlm: bool = field(
        default=False,
        metadata={
            "help": "Whether to use MLM auxiliary objective."
        }
    )
    # Loss weight of the auxiliary MLM objective (ignored unless --do_mlm).
    mlm_weight: float = field(
        default=0.1,
        metadata={
            "help": "Weight for MLM auxiliary objective (only effective if --do_mlm)."
        }
    )
    # If True, the MLP pooler head is applied during training only.
    mlp_only_train: bool = field(
        default=False,
        metadata={
            "help": "Use MLP only during training"
        }
    )
|
|
|
|
|
@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """

    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    validation_split_percentage: Optional[int] = field(
        default=5,
        metadata={
            "help": "The percentage of the train set used as validation set in case there's no validation split"
        },
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )

    # SimCSE-specific data arguments.
    train_file: Optional[str] = field(
        default=None,
        metadata={"help": "The training data file (.txt or .csv)."}
    )
    max_seq_length: Optional[int] = field(
        default=32,
        metadata={
            "help": "The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated."
        },
    )
    pad_to_max_length: bool = field(
        default=False,
        metadata={
            "help": "Whether to pad all samples to `max_seq_length`. "
            "If False, will pad the samples dynamically when batching to the maximum length in the batch."
        },
    )
    mlm_probability: float = field(
        default=0.15,
        metadata={"help": "Ratio of tokens to mask for MLM (only effective if --do_mlm)"}
    )

    def __post_init__(self):
        """Validate the data arguments right after parsing.

        Raises:
            ValueError: if neither a dataset name nor a training file is
                given, or if the training file has an unsupported extension.
        """
        # Fix: the original also checked `self.validation_file`, which is not
        # a declared field of this dataclass, so running without --train_file
        # raised AttributeError instead of the intended ValueError.
        if self.dataset_name is None and self.train_file is None:
            raise ValueError("Need either a dataset name or a training/validation file.")
        if self.train_file is not None:
            extension = self.train_file.split(".")[-1]
            # Raise explicitly instead of `assert`, which is silently
            # stripped when Python runs with -O.
            if extension not in ["csv", "json", "txt"]:
                raise ValueError("`train_file` should be a csv, a json or a txt file.")
|
|
|
|
|
@dataclass
class OurTrainingArguments(TrainingArguments):
    # Extension of transformers' TrainingArguments: adds one SimCSE-specific
    # flag and overrides the device-setup routine (copied from the
    # transformers source so it can be customized).

    # When True, CLTrainer also evaluates SentEval transfer-task dev sets
    # during validation (see trainer.evaluate(eval_senteval_transfer=...)).
    eval_transfer: bool = field(
        default=False,
        metadata={"help": "Evaluate transfer task dev sets (in validation)."}
    )

    @cached_property
    @torch_required
    def _setup_devices(self) -> "torch.device":
        """Pick the torch device (CPU / TPU / single- or multi-GPU) and set
        ``self._n_gpu``; cached after the first call via @cached_property."""
        logger.info("PyTorch: setting up devices")
        if self.no_cuda:
            # Explicitly requested CPU-only run.
            device = torch.device("cpu")
            self._n_gpu = 0
        elif is_torch_tpu_available():
            import torch_xla.core.xla_model as xm
            device = xm.xla_device()
            self._n_gpu = 0
        elif self.local_rank == -1:
            # Single-process run: take the first visible GPU if any.
            # _n_gpu counts all visible GPUs, which makes the Trainer wrap
            # the model in nn.DataParallel when it is > 1.
            device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
            self._n_gpu = torch.cuda.device_count()
        else:
            # Distributed run (one process per GPU), launched via
            # torch.distributed / deepspeed with local_rank >= 0.
            if self.deepspeed:
                # NOTE(review): this relative import was copied verbatim from
                # the transformers source tree; from this top-level script it
                # would raise ImportError — confirm before using --deepspeed.
                from .integrations import is_deepspeed_available

                if not is_deepspeed_available():
                    raise ImportError("--deepspeed requires deepspeed: `pip install deepspeed`.")
                import deepspeed

                deepspeed.init_distributed()
            else:
                torch.distributed.init_process_group(backend="nccl")
            device = torch.device("cuda", self.local_rank)
            self._n_gpu = 1

        if device.type == "cuda":
            torch.cuda.set_device(device)

        return device
|
|
|
|
|
def main():
    """Train (and optionally evaluate) a SimCSE contrastive sentence-embedding
    model on top of a BERT/RoBERTa checkpoint."""
    # Parse arguments: a single .json argv is treated as a file holding all
    # parameters; otherwise parse the command line.
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, OurTrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Refuse to silently clobber a non-empty output directory.
    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty."
            "Use --overwrite_output_dir to overcome."
        )

    # Logging: verbose (INFO) only on the main process, WARN elsewhere.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if is_main_process(training_args.local_rank) else logging.WARN,
    )

    # Emit a per-process summary at WARNING level so it shows on every rank.
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
        + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )

    if is_main_process(training_args.local_rank):
        transformers.utils.logging.set_verbosity_info()
        transformers.utils.logging.enable_default_handler()
        transformers.utils.logging.enable_explicit_format()
    logger.info("Training/evaluation parameters %s", training_args)

    # Set seed before initializing the model for reproducibility.
    set_seed(training_args.seed)

    # Load the training data with the datasets library. ".txt" files use the
    # generic "text" loader; ".csv" files use a tab delimiter when the file
    # name suggests TSV.
    # NOTE(review): this assumes --train_file is set; a run with only
    # --dataset_name would crash on the split below — confirm intended usage.
    data_files = {}
    if data_args.train_file is not None:
        data_files["train"] = data_args.train_file
    extension = data_args.train_file.split(".")[-1]
    if extension == "txt":
        extension = "text"
    if extension == "csv":
        datasets = load_dataset(extension, data_files=data_files, cache_dir="./data/", delimiter="\t" if "tsv" in data_args.train_file else ",")
    else:
        datasets = load_dataset(extension, data_files=data_files, cache_dir="./data/")

    # Load config: explicit --config_name wins, then the model checkpoint,
    # then a from-scratch config based on --model_type.
    config_kwargs = {
        "cache_dir": model_args.cache_dir,
        "revision": model_args.model_revision,
        "use_auth_token": True if model_args.use_auth_token else None,
    }
    if model_args.config_name:
        config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs)
    elif model_args.model_name_or_path:
        config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
    else:
        config = CONFIG_MAPPING[model_args.model_type]()
        logger.warning("You are instantiating a new config instance from scratch.")

    # Load tokenizer the same way; training a tokenizer from scratch is
    # explicitly unsupported here.
    tokenizer_kwargs = {
        "cache_dir": model_args.cache_dir,
        "use_fast": model_args.use_fast_tokenizer,
        "revision": model_args.model_revision,
        "use_auth_token": True if model_args.use_auth_token else None,
    }
    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, **tokenizer_kwargs)
    elif model_args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, **tokenizer_kwargs)
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script."
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    # Instantiate the contrastive-learning model; dispatch on a substring of
    # the checkpoint name (only RoBERTa/BERT backbones are supported).
    if model_args.model_name_or_path:
        if 'roberta' in model_args.model_name_or_path:
            model = RobertaForCL.from_pretrained(
                model_args.model_name_or_path,
                from_tf=bool(".ckpt" in model_args.model_name_or_path),
                config=config,
                cache_dir=model_args.cache_dir,
                revision=model_args.model_revision,
                use_auth_token=True if model_args.use_auth_token else None,
                model_args=model_args
            )
        elif 'bert' in model_args.model_name_or_path:
            model = BertForCL.from_pretrained(
                model_args.model_name_or_path,
                from_tf=bool(".ckpt" in model_args.model_name_or_path),
                config=config,
                cache_dir=model_args.cache_dir,
                revision=model_args.model_revision,
                use_auth_token=True if model_args.use_auth_token else None,
                model_args=model_args
            )
            if model_args.do_mlm:
                # Initialize the auxiliary MLM head from the pretrained BERT
                # prediction-head weights.
                pretrained_model = BertForPreTraining.from_pretrained(model_args.model_name_or_path)
                model.lm_head.load_state_dict(pretrained_model.cls.predictions.state_dict())
        else:
            raise NotImplementedError
    else:
        raise NotImplementedError
        # NOTE(review): the two lines below are unreachable (dead code after
        # the raise); training from scratch is effectively unsupported.
        logger.info("Training new model from scratch")
        model = AutoModelForMaskedLM.from_config(config)

    model.resize_token_embeddings(len(tokenizer))

    # The number of columns in the training file determines the mode:
    #   1 column  -> unsupervised: each sentence is paired with itself;
    #   2 columns -> supervised positive pairs (sent0, sent1);
    #   3 columns -> supervised with hard negatives (sent0, sent1, hard_neg).
    column_names = datasets["train"].column_names
    sent2_cname = None
    if len(column_names) == 2:
        # Pair datasets.
        sent0_cname = column_names[0]
        sent1_cname = column_names[1]
    elif len(column_names) == 3:
        # Pair datasets with hard negatives.
        sent0_cname = column_names[0]
        sent1_cname = column_names[1]
        sent2_cname = column_names[2]
    elif len(column_names) == 1:
        # Unsupervised datasets.
        sent0_cname = column_names[0]
        sent1_cname = column_names[0]
    else:
        raise NotImplementedError

    def prepare_features(examples):
        # Tokenize a batch: the two (or three) sentence columns are
        # concatenated into one flat list, tokenized in a single call, then
        # regrouped so each example holds [sent0, sent1(, hard_neg)] encodings.
        total = len(examples[sent0_cname])

        # Replace missing values so the tokenizer does not fail on None.
        for idx in range(total):
            if examples[sent0_cname][idx] is None:
                examples[sent0_cname][idx] = " "
            if examples[sent1_cname][idx] is None:
                examples[sent1_cname][idx] = " "

        sentences = examples[sent0_cname] + examples[sent1_cname]

        # Hard negatives, if present, form a third segment of the flat list.
        if sent2_cname is not None:
            for idx in range(total):
                if examples[sent2_cname][idx] is None:
                    examples[sent2_cname][idx] = " "
            sentences += examples[sent2_cname]

        sent_features = tokenizer(
            sentences,
            max_length=data_args.max_seq_length,
            truncation=True,
            padding="max_length" if data_args.pad_to_max_length else False,
        )

        # Regroup: example i collects positions i, i+total(, i+2*total).
        features = {}
        if sent2_cname is not None:
            for key in sent_features:
                features[key] = [[sent_features[key][i], sent_features[key][i+total], sent_features[key][i+total*2]] for i in range(total)]
        else:
            for key in sent_features:
                features[key] = [[sent_features[key][i], sent_features[key][i+total]] for i in range(total)]

        return features

    if training_args.do_train:
        train_dataset = datasets["train"].map(
            prepare_features,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
        )

    # Collator: pads each example's sentence group and, if --do_mlm, also
    # builds masked-LM inputs/labels for the auxiliary objective.
    @dataclass
    class OurDataCollatorWithPadding:

        tokenizer: PreTrainedTokenizerBase
        padding: Union[bool, str, PaddingStrategy] = True
        max_length: Optional[int] = None
        pad_to_multiple_of: Optional[int] = None
        mlm: bool = True
        # Default comes from the enclosing main()'s data_args closure.
        mlm_probability: float = data_args.mlm_probability

        def __call__(self, features: List[Dict[str, Union[List[int], List[List[int]], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
            special_keys = ['input_ids', 'attention_mask', 'token_type_ids', 'mlm_input_ids', 'mlm_labels']
            bs = len(features)
            if bs > 0:
                num_sent = len(features[0]['input_ids'])
            else:
                # Empty batch: nothing to collate.
                return
            # Flatten (bs, num_sent) into bs*num_sent rows so tokenizer.pad
            # can process them; reshaped back to (bs, num_sent, ...) below.
            flat_features = []
            for feature in features:
                for i in range(num_sent):
                    flat_features.append({k: feature[k][i] if k in special_keys else feature[k] for k in feature})

            batch = self.tokenizer.pad(
                flat_features,
                padding=self.padding,
                max_length=self.max_length,
                pad_to_multiple_of=self.pad_to_multiple_of,
                return_tensors="pt",
            )
            # NOTE: uses model_args.do_mlm from the enclosing closure, not the
            # collator's own `mlm` field.
            if model_args.do_mlm:
                batch["mlm_input_ids"], batch["mlm_labels"] = self.mask_tokens(batch["input_ids"])

            # Sentence-level keys keep shape (bs, num_sent, len); any other
            # key collapses to the first sentence's value per example.
            batch = {k: batch[k].view(bs, num_sent, -1) if k in special_keys else batch[k].view(bs, num_sent, -1)[:, 0] for k in batch}

            if "label" in batch:
                batch["labels"] = batch["label"]
                del batch["label"]
            if "label_ids" in batch:
                batch["labels"] = batch["label_ids"]
                del batch["label_ids"]

            return batch

        def mask_tokens(
            self, inputs: torch.Tensor, special_tokens_mask: Optional[torch.Tensor] = None
        ) -> Tuple[torch.Tensor, torch.Tensor]:
            """
            Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
            """
            inputs = inputs.clone()
            labels = inputs.clone()
            # Sample tokens to mask with probability `mlm_probability`,
            # never selecting special tokens.
            probability_matrix = torch.full(labels.shape, self.mlm_probability)
            if special_tokens_mask is None:
                special_tokens_mask = [
                    self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
                ]
                special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
            else:
                special_tokens_mask = special_tokens_mask.bool()

            probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
            masked_indices = torch.bernoulli(probability_matrix).bool()
            # Only masked positions contribute to the loss (-100 is ignored).
            labels[~masked_indices] = -100

            # 80% of masked positions become [MASK].
            indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
            inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

            # Half of the remaining 20% (i.e. 10%) become a random token;
            # the final 10% keep the original token.
            indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
            random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
            inputs[indices_random] = random_words[indices_random]

            return inputs, labels

    data_collator = default_data_collator if data_args.pad_to_max_length else OurDataCollatorWithPadding(tokenizer)

    trainer = CLTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )
    # CLTrainer reads the contrastive-learning hyperparameters from here.
    trainer.model_args = model_args

    # Training.
    if training_args.do_train:
        # Resume from checkpoint only when a local directory was given.
        model_path = (
            model_args.model_name_or_path
            if (model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path))
            else None
        )
        train_result = trainer.train(model_path=model_path)
        trainer.save_model()

        output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
        if trainer.is_world_process_zero():
            with open(output_train_file, "w") as writer:
                logger.info("***** Train results *****")
                for key, value in sorted(train_result.metrics.items()):
                    logger.info(f" {key} = {value}")
                    writer.write(f"{key} = {value}\n")

            # Save trainer state separately: save_model stores only the model
            # weights and tokenizer.
            trainer.state.save_to_json(os.path.join(training_args.output_dir, "trainer_state.json"))

    # Evaluation (includes SentEval transfer tasks).
    results = {}
    if training_args.do_eval:
        logger.info("*** Evaluate ***")
        results = trainer.evaluate(eval_senteval_transfer=True)

        output_eval_file = os.path.join(training_args.output_dir, "eval_results.txt")
        if trainer.is_world_process_zero():
            with open(output_eval_file, "w") as writer:
                logger.info("***** Eval results *****")
                for key, value in sorted(results.items()):
                    logger.info(f" {key} = {value}")
                    writer.write(f"{key} = {value}\n")

    return results
|
|
|
def _mp_fn(index):
    # Entry point for TPU multiprocessing (e.g. via xla_spawn.py); `index` is
    # the per-process ordinal supplied by the spawner and is unused here.
    main()
|
|
|
|
|
# Standard script entry guard.
if __name__ == "__main__":
    main()
|
|