\n",
"['_MutableMapping__marker', '__abstractmethods__', '__class__', '__class_getitem__', '__contains__', '__copy__', '__delattr__', '__delitem__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattr__', '__getattribute__', '__getitem__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__ior__', '__iter__', '__le__', '__len__', '__lt__', '__module__', '__ne__', '__new__', '__or__', '__reduce__', '__reduce_ex__', '__repr__', '__reversed__', '__ror__', '__setattr__', '__setitem__', '__setstate__', '__sizeof__', '__slots__', '__str__', '__subclasshook__', '__weakref__', '_abc_impl', '_encodings', '_n_sequences', 'char_to_token', 'char_to_word', 'clear', 'convert_to_tensors', 'copy', 'data', 'encodings', 'fromkeys', 'get', 'is_fast', 'items', 'keys', 'n_sequences', 'pop', 'popitem', 'sequence_ids', 'setdefault', 'to', 'token_to_chars', 'token_to_sequence', 'token_to_word', 'tokens', 'update', 'values', 'word_ids', 'word_to_chars', 'word_to_tokens', 'words']\n",
"2\n"
]
}
],
"source": [
"print(type(train_encodings))\n",
"print(dir(train_encodings))\n",
"print(len(train_encodings))"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"def get_labels(y: List[List[int]]):\n",
" labels = np.zeros((len(y), len(tag_to_index)))\n",
" for i in tqdm(range(len(y))):\n",
" labels[i, y[i]] = 1\n",
" return labels.tolist()"
]
},
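{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal standalone sanity check of the multi-hot encoding above (toy values; `toy_num_tags` stands in for `len(tag_to_index)`):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Toy version of the loop in get_labels: row i gets ones at the indices in toy_y[i].\n",
"toy_num_tags = 4  # stand-in for len(tag_to_index)\n",
"toy_y = [[0, 2], [3]]\n",
"toy_labels = np.zeros((len(toy_y), toy_num_tags))\n",
"for i in range(len(toy_y)):\n",
"    toy_labels[i, toy_y[i]] = 1\n",
"print(toy_labels.tolist())  # [[1.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]"
]
},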
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 28700/28700 [00:00<00:00, 388780.42it/s]\n",
"100%|██████████| 8200/8200 [00:00<00:00, 223262.03it/s]\n",
"100%|██████████| 4100/4100 [00:00<00:00, 165215.75it/s]\n"
]
}
],
"source": [
"labels_train = get_labels(y_train)\n",
"labels_val = get_labels(y_val)\n",
"labels_test = get_labels(y_test)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"train_encodings['labels'] = labels_train\n",
"val_encodings['labels'] = labels_val\n",
"test_encodings['labels'] = labels_test"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Я использовал пример отсюда чтобы понимать, какой нужен формат данных https://github.com/NielsRogge/Transformers-Tutorials/blob/master/BERT/Fine_tuning_BERT_(and_friends)_for_multi_label_text_classification.ipynb**"
]
},
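{
"cell_type": "markdown",
"metadata": {},
"source": [
"A sketch of that layout with hypothetical toy values: each key maps to a list with one entry per example, and `labels` is a multi-hot float vector per example."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustration only (made-up token ids and a 2-tag label space).\n",
"toy_encodings = {\n",
"    'input_ids': [[101, 7592, 102], [101, 2088, 102]],\n",
"    'attention_mask': [[1, 1, 1], [1, 1, 1]],\n",
"    'labels': [[0.0, 1.0], [1.0, 0.0]],  # multi-hot floats, one slot per tag\n",
"}"
]
},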
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"train_dataset = Dataset.from_dict(train_encodings)\n",
"val_dataset = Dataset.from_dict(val_encodings)\n",
"test_dataset = Dataset.from_dict(test_encodings)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
]
}
],
"source": [
"model = AutoModelForSequenceClassification.from_pretrained(\n",
" USED_MODEL, \n",
" problem_type=\"multi_label_classification\", \n",
" num_labels=len(tag_to_index),\n",
" id2label=index_to_tag,\n",
" label2id=tag_to_index\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"batch_size = 8\n",
"metric_name = \"f1\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/jarakcyc/.virtualenvs/Tricks/lib/python3.10/site-packages/transformers/training_args.py:1611: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
" warnings.warn(\n"
]
}
],
"source": [
"args = TrainingArguments(\n",
" output_dir=f'train-{USED_MODEL}-baseline',\n",
" evaluation_strategy=\"epoch\",\n",
" save_strategy=\"epoch\",\n",
" learning_rate=2e-5,\n",
" per_device_train_batch_size=batch_size,\n",
" per_device_eval_batch_size=batch_size,\n",
" num_train_epochs=5,\n",
" weight_decay=0.01,\n",
" load_best_model_at_end=True,\n",
" metric_for_best_model=metric_name,\n",
" push_to_hub=False\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics import f1_score, roc_auc_score, accuracy_score\n",
"from transformers import EvalPrediction\n",
"import torch\n",
" \n",
"# source: https://jesusleal.io/2021/04/21/Longformer-multilabel-classification/\n",
"def multi_label_metrics(predictions, labels, threshold=0.5):\n",
" # first, apply sigmoid on predictions which are of shape (batch_size, num_labels)\n",
" sigmoid = torch.nn.Sigmoid()\n",
" probs = sigmoid(torch.Tensor(predictions))\n",
" # next, use threshold to turn them into integer predictions\n",
" y_pred = np.zeros(probs.shape)\n",
" y_pred[np.where(probs >= threshold)] = 1\n",
" # finally, compute metrics\n",
" y_true = labels\n",
" f1_micro_average = f1_score(y_true=y_true, y_pred=y_pred, average='micro')\n",
" roc_auc = roc_auc_score(y_true, y_pred, average = 'micro')\n",
" accuracy = accuracy_score(y_true, y_pred)\n",
" # return as dictionary\n",
" metrics = {'f1': f1_micro_average,\n",
" 'roc_auc': roc_auc,\n",
" 'accuracy': accuracy}\n",
" return metrics\n",
"\n",
"def compute_metrics(p: EvalPrediction):\n",
" preds = p.predictions[0] if isinstance(p.predictions, \n",
" tuple) else p.predictions\n",
" result = multi_label_metrics(\n",
" predictions=preds, \n",
" labels=p.label_ids)\n",
" return result"
]
},
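{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check of `multi_label_metrics` on made-up logits (2 examples, 3 labels); the thresholded predictions match the labels exactly, so every metric should come out as 1.0."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"toy_logits = np.array([[2.0, -2.0, 0.5], [-1.0, 3.0, -3.0]])  # sigmoid > 0.5 at the positive entries\n",
"toy_label_matrix = np.array([[1, 0, 1], [0, 1, 0]])\n",
"print(multi_label_metrics(toy_logits, toy_label_matrix))  # expect f1 = roc_auc = accuracy = 1.0"
]
},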
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"train_dataset.set_format(\"torch\")\n",
"test_dataset.set_format(\"torch\")"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_571129/1751307119.py:1: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. Use `processing_class` instead.\n",
" trainer = Trainer(\n"
]
}
],
"source": [
"trainer = Trainer(\n",
" model,\n",
" args,\n",
" train_dataset=train_dataset,\n",
" eval_dataset=val_dataset,\n",
" tokenizer=tokenizer,\n",
" compute_metrics=compute_metrics\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
"
\n",
" [17940/17940 32:06, Epoch 5/5]\n",
"
\n",
" \n",
" \n",
" \n",
" Epoch | \n",
" Training Loss | \n",
" Validation Loss | \n",
" F1 | \n",
" Roc Auc | \n",
" Accuracy | \n",
"
\n",
" \n",
" \n",
" \n",
" 1 | \n",
" 0.024000 | \n",
" 0.022899 | \n",
" 0.652954 | \n",
" 0.770167 | \n",
" 0.410366 | \n",
"
\n",
" \n",
" 2 | \n",
" 0.020400 | \n",
" 0.020730 | \n",
" 0.673765 | \n",
" 0.785226 | \n",
" 0.426829 | \n",
"
\n",
" \n",
" 3 | \n",
" 0.017900 | \n",
" 0.019692 | \n",
" 0.700292 | \n",
" 0.812313 | \n",
" 0.425000 | \n",
"
\n",
" \n",
" 4 | \n",
" 0.016100 | \n",
" 0.019695 | \n",
" 0.701593 | \n",
" 0.812366 | \n",
" 0.433171 | \n",
"
\n",
" \n",
" 5 | \n",
" 0.014800 | \n",
" 0.019767 | \n",
" 0.701193 | \n",
" 0.812710 | \n",
" 0.431707 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"TrainOutput(global_step=17940, training_loss=0.02238395190159214, metrics={'train_runtime': 1927.2238, 'train_samples_per_second': 74.459, 'train_steps_per_second': 9.309, 'total_flos': 1.906093867776e+16, 'train_loss': 0.02238395190159214, 'epoch': 5.0})"
]
},
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"trainer.train()"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'eval_loss': 0.019695421680808067,\n",
" 'eval_f1': 0.7015928686248721,\n",
" 'eval_roc_auc': 0.8123655228058703,\n",
" 'eval_accuracy': 0.43317073170731707,\n",
" 'eval_runtime': 34.8656,\n",
" 'eval_samples_per_second': 235.189,\n",
" 'eval_steps_per_second': 29.399,\n",
" 'epoch': 5.0}"
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"trainer.evaluate(eval_dataset=val_dataset)"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'eval_loss': 0.019682902842760086,\n",
" 'eval_f1': 0.6966158423205653,\n",
" 'eval_roc_auc': 0.8081637343174538,\n",
" 'eval_accuracy': 0.4370731707317073,\n",
" 'eval_runtime': 16.5771,\n",
" 'eval_samples_per_second': 247.329,\n",
" 'eval_steps_per_second': 30.946,\n",
" 'epoch': 5.0}"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"trainer.evaluate(eval_dataset=test_dataset)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Исходная задача у нас звучала как \"хотим увидеть топ-95%* тематик, отсортированных по убыванию вероятности\", где под тематиками имелись ввиду категории (физика, биология и так далее)\n",
"\n",
"Будем делать следующее:\n",
"- наша модель выдает логиты тегов\n",
"- посчитаем с их помощью вероятность каждого тега, считая сумму вероятностей равной 1\n",
"- посчитаем вероятность категории как сумму вероятностей тегов\n",
"- выведем требуемые топ-95% тематик"
]
},
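{
"cell_type": "markdown",
"metadata": {},
"source": [
"A toy sketch of the aggregation scheme (hypothetical tags and categories; each tag index maps to a parent category):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Made-up example: 4 tags belonging to 2 categories.\n",
"toy_tag_probs = np.array([0.9, 0.1, 0.5, 0.5])  # per-tag sigmoid outputs\n",
"toy_tag_probs = toy_tag_probs / toy_tag_probs.sum()  # renormalize so the probabilities sum to 1\n",
"toy_tag_to_category = {0: 'math', 1: 'math', 2: 'physics', 3: 'physics'}\n",
"toy_category_probs = {}\n",
"for tag_index, p in enumerate(toy_tag_probs):\n",
"    category = toy_tag_to_category[tag_index]\n",
"    toy_category_probs[category] = toy_category_probs.get(category, 0.0) + float(p)\n",
"print(toy_category_probs)  # {'math': 0.5, 'physics': 0.5}"
]
},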
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"model = AutoModelForSequenceClassification.from_pretrained(\n",
" \"train_distilbert-base-cased/checkpoint-17940\", \n",
" problem_type=\"multi_label_classification\", \n",
" num_labels=len(tag_to_index),\n",
" id2label=index_to_tag,\n",
" label2id=tag_to_index\n",
").to(torch.device('cuda'))"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"SequenceClassifierOutput(loss=None, logits=tensor([[-1.3623, -5.3834, -3.3988, -3.4555, -3.7096, -4.5285, -5.1323, -2.3077,\n",
" -3.6645, -4.6847, -4.2481, -5.0417, -3.5121, -2.7808, -5.9767, -4.8864,\n",
" -5.6730, -4.6838, -3.8588, -5.2819, -3.9295, -2.7704, 0.4331, -4.5505,\n",
" -5.2648, -4.9248, -4.2074, -3.4895, -3.2717, -5.2713, -5.7536, -7.2749,\n",
" -4.8728, -5.2606, -4.5935, -4.7103, -5.4628, -5.4589, -5.3678, -3.5648,\n",
" -5.1455, -8.8455, -9.1583, -6.4358, -4.7737, -4.7821, -8.9264, -5.8790,\n",
" -4.7536, -5.4549, -5.3879, -6.1918, -4.1667, -7.1828, -7.3235, -5.4470,\n",
" -4.6688, -4.7201, -6.2949, -7.5401, -6.6242, -6.1022, -5.5325, -3.1546,\n",
" -9.4200, -5.2060, -5.3880, -6.8743, -3.3176, -7.2654, -7.4301, -3.0929,\n",
" -3.2351, -9.0408, -5.4315, -6.3230, -9.5853, -5.7075, -3.6443, -5.5524,\n",
" -6.0723, -6.0414, -7.3201, -3.9738, -5.5964, -4.0455, -5.2017, -5.8061,\n",
" -7.8401, -7.5268, -7.4576, -4.4483, -6.4790, -5.9085, -6.8822, -5.4498,\n",
" -6.7494, -6.1449, -5.9297, -6.4985, -5.0379, -4.9914, -5.5201, -7.9075,\n",
" -8.7653, -6.6116, -6.6643, -9.3863, -4.9038, -7.6509, -9.0117, -9.1193,\n",
" -5.3166, -5.4046, -8.3876, -4.9028, -3.5257, -8.9734, -6.1487, -8.1408,\n",
" -5.3014, -6.5494, -6.8383, -4.8011, -5.2831, -8.7708, -7.5039, -5.3957,\n",
" -7.3326, -3.6551, -4.9892, -5.9366, -5.2093, -5.2362, -5.0462, -6.5469,\n",
" -4.9182, -4.4108, -7.1632, -5.9481, -5.3291, -6.4517, -5.6950, -8.7276,\n",
" -5.7762, -8.9848, -7.3795, -5.4210, -5.6845, -2.9447, -3.6166, -3.6258,\n",
" -1.4417, -5.6568, -3.5869]], device='cuda:0',\n",
" grad_fn=), hidden_states=None, attentions=None)"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model(**{key: torch.tensor(value).to(model.device).unsqueeze(0) for key, value in tokenize_function('Maths is cool $ In our article we prove that maths is the coolest subject at school').items()})"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"@torch.no_grad\n",
"def get_category_probs_dict(model, title: str, summary: str) -> Dict[str, float]:\n",
" text = f'{title} $ {summary}'\n",
" tags_logits = model(**{key: torch.tensor(value).to(model.device).unsqueeze(0) for key, value in tokenize_function(text).items()}).logits\n",
" sigmoid = torch.nn.Sigmoid()\n",
" tags_probs = sigmoid(tags_logits.squeeze().cpu()).numpy()\n",
" tags_probs /= tags_probs.sum()\n",
" category_probs_dict = {category: 0.0 for category in set(arxiv_topics_df['category'])}\n",
" for index in range(len(index_to_tag)):\n",
" category_probs_dict[tag_to_category[index_to_tag[index]]] += float(tags_probs[index])\n",
" return category_probs_dict"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"def get_most_probable_keys(probs_dict: Dict[str, float], target_probability: float, print_probabilities: bool) -> List[str]:\n",
" current_p = 0\n",
" probs_list = sorted([(value, key) for key, value in probs_dict.items()])[::-1]\n",
" current_index = 0\n",
" answer = []\n",
" while current_p <= target_probability:\n",
" current_p += probs_list[current_index][0]\n",
" if not print_probabilities:\n",
" answer.append(probs_list[current_index][1])\n",
" else:\n",
" answer.append(f'{probs_list[current_index][1]} ({probs_list[current_index][0]})')\n",
" current_index += 1\n",
" if current_index >= len(probs_list):\n",
" break\n",
" return answer"
]
},
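{
"cell_type": "markdown",
"metadata": {},
"source": [
"Example usage on made-up category probabilities: keys are accumulated until the running total exceeds the target probability."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"toy_probs = {'physics': 0.6, 'math': 0.3, 'biology': 0.1}\n",
"# 0.6 + 0.3 <= 0.95, so 'biology' is taken as well before the loop stops.\n",
"print(get_most_probable_keys(toy_probs, target_probability=0.95, print_probabilities=False))"
]
},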
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Теперь нужно как-то сохранить модель, чтобы потом можно было её использовать в huggingface space"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"model.safetensors: 100%|██████████| 264M/264M [00:31<00:00, 8.47MB/s] \n"
]
},
{
"data": {
"text/plain": [
"CommitInfo(commit_url='https://huggingface.co/bumchik2/train_distilbert-base-cased-tags-classification-simple/commit/98a87d7c96e0647dd557a9d47be03ddd30e0c964', commit_message='Upload DistilBertForSequenceClassification', commit_description='', oid='98a87d7c96e0647dd557a9d47be03ddd30e0c964', pr_url=None, repo_url=RepoUrl('https://huggingface.co/bumchik2/train_distilbert-base-cased-tags-classification-simple', endpoint='https://huggingface.co', repo_type='model', repo_id='bumchik2/train_distilbert-base-cased-tags-classification-simple'), pr_revision=None, pr_num=None)"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.push_to_hub(\"bumchik2/train_distilbert-base-cased-tags-classification-simple\")"
]
},
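{
"cell_type": "markdown",
"metadata": {},
"source": [
"A Space will also need the matching tokenizer. A minimal sketch, assuming the `tokenizer` object from earlier in the notebook and the same repo id:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Push the tokenizer alongside the model so the Space can load both.\n",
"tokenizer.push_to_hub(\"bumchik2/train_distilbert-base-cased-tags-classification-simple\")"
]
},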
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Теперь я смогу загружать свою модель оттуда"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"model = AutoModelForSequenceClassification.from_pretrained(\n",
" \"bumchik2/train_distilbert-base-cased-tags-classification-simple\", \n",
" problem_type=\"multi_label_classification\", \n",
" num_labels=len(tag_to_index),\n",
" id2label=index_to_tag,\n",
" label2id=tag_to_index\n",
").to(torch.device('cuda'))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Tricks",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}