{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/jarakcyc/.virtualenvs/Tricks/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ "from transformers import pipeline\n", "import json\n", "import pandas as pd\n", "from sklearn.model_selection import train_test_split\n", "from transformers import DistilBertTokenizer\n", "from tqdm import tqdm\n", "import re\n", "from datasets import Dataset\n", "from transformers import AutoModelForSequenceClassification\n", "import torch\n", "import numpy as np\n", "from typing import Dict\n", "from transformers import AutoModel\n", "from typing import List\n", "from transformers import TrainingArguments, Trainer\n", "from collections import defaultdict" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "USED_MODEL = \"distilbert-base-cased\"" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "def read_json(json_filename):\n", " with open(json_filename, 'r') as f:\n", " return json.loads(f.read())\n", "\n", "\n", "def save_json(json_object, json_filename, indent=4):\n", " with open(json_filename, 'w') as f:\n", " json.dump(json_object, f, separators=(',', ':'), indent=indent)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Данные берем отсюда: https://www.kaggle.com/datasets/neelshah18/arxivdataset**" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "arxiv_data = read_json('arxivData.json')" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'author': \"[{'name': 'Ahmed Osman'}, {'name': 'Wojciech Samek'}]\",\n", " 'day': 1,\n", " 'id': '1802.00209v1',\n", " 'link': \"[{'rel': 'alternate', 'href': 'http://arxiv.org/abs/1802.00209v1', 'type': 'text/html'}, {'rel': 'related', 'href': 'http://arxiv.org/pdf/1802.00209v1', 'type': 'application/pdf', 'title': 'pdf'}]\",\n", " 'month': 2,\n", " 'summary': 'We propose an architecture for VQA which utilizes recurrent layers to\\ngenerate visual and textual attention. The memory characteristic of the\\nproposed recurrent attention units offers a rich joint embedding of visual and\\ntextual features and enables the model to reason relations between several\\nparts of the image and question. Our single model outperforms the first place\\nwinner on the VQA 1.0 dataset, performs within margin to the current\\nstate-of-the-art ensemble model. We also experiment with replacing attention\\nmechanisms in other state-of-the-art models with our implementation and show\\nincreased accuracy. In both cases, our recurrent attention mechanism improves\\nperformance in tasks requiring sequential or relational reasoning on the VQA\\ndataset.',\n", " 'tag': \"[{'term': 'cs.AI', 'scheme': 'http://arxiv.org/schemas/atom', 'label': None}, {'term': 'cs.CL', 'scheme': 'http://arxiv.org/schemas/atom', 'label': None}, {'term': 'cs.CV', 'scheme': 'http://arxiv.org/schemas/atom', 'label': None}, {'term': 'cs.NE', 'scheme': 'http://arxiv.org/schemas/atom', 'label': None}, {'term': 'stat.ML', 'scheme': 'http://arxiv.org/schemas/atom', 'label': None}]\",\n", " 'title': 'Dual Recurrent Attention Units for Visual Question Answering',\n", " 'year': 2018}" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "arxiv_data[0]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Хотим по названию статьи + abstract выдавать наиболее вероятную тематику статьи, скажем, физика, биология или computer science** " ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "155\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
tagtopiccategory
0cs.AIArtificial IntelligenceComputer Science
1cs.ARHardware ArchitectureComputer Science
2cs.CCComputational ComplexityComputer Science
3cs.CEComputational Engineering, Finance, and ScienceComputer Science
4cs.CGComputational GeometryComputer Science
\n", "
" ], "text/plain": [ " tag topic category\n", "0 cs.AI Artificial Intelligence Computer Science\n", "1 cs.AR Hardware Architecture Computer Science\n", "2 cs.CC Computational Complexity Computer Science\n", "3 cs.CE Computational Engineering, Finance, and Science Computer Science\n", "4 cs.CG Computational Geometry Computer Science" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Manually prepared dataframe with arxiv topics\n", "arxiv_topics_df = pd.read_csv('arxiv_topics.csv')\n", "print(len(arxiv_topics_df))\n", "arxiv_topics_df.head(5)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "tag_to_index = {}\n", "tag_to_category = {}\n", "for i, row in arxiv_topics_df.iterrows():\n", " tag_to_index[row['tag']] = i\n", " tag_to_category[row['tag']] = row['category']\n", "index_to_tag = {value: key for key, value in tag_to_index.items()}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Готовим данные к обучению**" ] }, { "cell_type": "code", "execution_count": 49, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 41000/41000 [00:01<00:00, 33941.59it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Среднее число категорий в одной статье: 1.3301219512195122\n", "Среднее число тегов в одной статье: 1.8489024390243902\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "def is_valid_tag(tag: str) -> bool:\n", " return tag in tag_to_index\n", "\n", "total_categories_count = 0\n", "total_tags_count = 0\n", "records = []\n", "for arxiv_record in tqdm(arxiv_data):\n", " record = {\n", " 'title': arxiv_record['title'],\n", " 'summary': arxiv_record['summary'],\n", " 'title_and_summary': arxiv_record['title'] + ' $ ' + arxiv_record['summary'],\n", " 'tags': sorted([current_tag['term'] for current_tag in eval(arxiv_record['tag']) if is_valid_tag(current_tag['term'])], key=lambda x: tag_to_index[x])\n", " }\n", " categories = set(tag_to_category[tag] for tag in record['tags'])\n", " total_categories_count += len(categories)\n", " total_tags_count += len(record['tags'])\n", " record['tags_indices'] = [tag_to_index[tag] for tag in record['tags']]\n", " assert len(record['tags']) > 0\n", " records.append(record)\n", "\n", "print(f'Среднее число категорий в одной статье: {total_categories_count / len(arxiv_data)}')\n", "print(f'Среднее число тегов в одной статье: {total_tags_count / len(arxiv_data)}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Как видим, перед нами задача мультибинарной классификации.\n", "\n", "Тегов у одной статьи бывает много, это понятно, но и категорий тоже бывает много. То есть, условно статья может быть посвящена и физике и биологии одновременно.\n", "\n", "Попробуем обучить модель определять теги - так она потенциально может сохранить в себе больше информации, чем если ее обучить определять категории (которых гораздо меньше)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Соединяем title и summary используя символ `$` - он редкий, при этом его знает токенайзер, поэтому не придется с ним дополнительно возиться**" ] }, { "cell_type": "code", "execution_count": 50, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "41000\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
titlesummarytitle_and_summarytagstags_indices
0Dual Recurrent Attention Units for Visual Ques...We propose an architecture for VQA which utili...Dual Recurrent Attention Units for Visual Ques...[cs.AI, cs.CL, cs.CV, cs.NE, stat.ML][0, 5, 7, 28, 152]
1Sequential Short-Text Classification with Recu...Recent approaches based on artificial neural n...Sequential Short-Text Classification with Recu...[cs.AI, cs.CL, cs.LG, cs.NE, stat.ML][0, 5, 22, 28, 152]
2Multiresolution Recurrent Neural Networks: An ...We introduce the multiresolution recurrent neu...Multiresolution Recurrent Neural Networks: An ...[cs.AI, cs.CL, cs.LG, cs.NE, stat.ML][0, 5, 22, 28, 152]
3Learning what to share between loosely related...Multi-task learning is motivated by the observ...Learning what to share between loosely related...[cs.AI, cs.CL, cs.LG, cs.NE, stat.ML][0, 5, 22, 28, 152]
4A Deep Reinforcement Learning ChatbotWe present MILABOT: a deep reinforcement learn...A Deep Reinforcement Learning Chatbot $ We pre...[cs.AI, cs.CL, cs.LG, cs.NE, stat.ML][0, 5, 22, 28, 152]
\n", "
" ], "text/plain": [ " title \\\n", "0 Dual Recurrent Attention Units for Visual Ques... \n", "1 Sequential Short-Text Classification with Recu... \n", "2 Multiresolution Recurrent Neural Networks: An ... \n", "3 Learning what to share between loosely related... \n", "4 A Deep Reinforcement Learning Chatbot \n", "\n", " summary \\\n", "0 We propose an architecture for VQA which utili... \n", "1 Recent approaches based on artificial neural n... \n", "2 We introduce the multiresolution recurrent neu... \n", "3 Multi-task learning is motivated by the observ... \n", "4 We present MILABOT: a deep reinforcement learn... \n", "\n", " title_and_summary \\\n", "0 Dual Recurrent Attention Units for Visual Ques... \n", "1 Sequential Short-Text Classification with Recu... \n", "2 Multiresolution Recurrent Neural Networks: An ... \n", "3 Learning what to share between loosely related... \n", "4 A Deep Reinforcement Learning Chatbot $ We pre... \n", "\n", " tags tags_indices \n", "0 [cs.AI, cs.CL, cs.CV, cs.NE, stat.ML] [0, 5, 7, 28, 152] \n", "1 [cs.AI, cs.CL, cs.LG, cs.NE, stat.ML] [0, 5, 22, 28, 152] \n", "2 [cs.AI, cs.CL, cs.LG, cs.NE, stat.ML] [0, 5, 22, 28, 152] \n", "3 [cs.AI, cs.CL, cs.LG, cs.NE, stat.ML] [0, 5, 22, 28, 152] \n", "4 [cs.AI, cs.CL, cs.LG, cs.NE, stat.ML] [0, 5, 22, 28, 152] " ] }, "execution_count": 50, "metadata": {}, "output_type": "execute_result" } ], "source": [ "full_data_df = pd.DataFrame(records)\n", "print(len(full_data_df))\n", "full_data_df.head(5)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Посмотрим на распределение тегов и категорий в данных" ] }, { "cell_type": "code", "execution_count": 57, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "defaultdict(, {'Statistics': 10618, 'Computer Science': 39251, 'Physics': 1208, 'Mathematics': 2263, 'Quantitative Biology': 896, 'Electrical Engineering and Systems Science': 220, 'Quantitative Finance': 66, 'Economics': 13})\n", "defaultdict(, {'cs.AI': 10481, 'cs.CL': 6417, 'cs.CV': 13902, 'cs.NE': 3819, 'stat.ML': 10326, 'cs.LG': 13735, 'physics.soc-ph': 293, 'stat.AP': 360, 'cs.RO': 973, 'cs.SE': 180, 'cs.MA': 268, 'math.OC': 1020, 'cs.IR': 1443, 'cond-mat.dis-nn': 126, 'stat.ME': 458, 'physics.chem-ph': 16, 'cs.DC': 404, 'stat.CO': 260, 'q-bio.NC': 513, 'cs.GT': 318, 'cs.MM': 345, 'cs.CG': 94, 'cs.CR': 411, 'cs.HC': 434, 'cs.GL': 10, 'eess.AS': 89, 'cs.SD': 389, 'math.DS': 49, 'cs.GR': 225, 'math.NA': 172, 'cs.CY': 376, 'physics.data-an': 187, 'math.ST': 336, 'stat.TH': 336, 'cs.IT': 543, 'math.IT': 543, 'quant-ph': 142, 'astro-ph.GA': 6, 'astro-ph.IM': 76, 'cs.SI': 639, 'cs.DB': 327, 'cs.LO': 643, 'nlin.AO': 119, 'cs.PF': 35, 'cs.ET': 85, 'eess.IV': 85, 'cs.AR': 52, 'cs.SY': 270, 'cs.CC': 196, 'q-bio.BM': 30, 'q-bio.QM': 232, 'cs.NI': 137, 'cs.DS': 570, 'cond-mat.stat-mech': 84, 'cs.NA': 253, 'cs.DM': 101, 'eess.SP': 52, 'cs.MS': 66, 'physics.med-ph': 81, 'physics.optics': 60, 'q-fin.CP': 14, 'cs.FL': 50, 'cs.SC': 24, 'q-fin.EC': 5, 'q-fin.TR': 9, 'cond-mat.mes-hall': 14, 'math.PR': 144, 'q-fin.RM': 3, 'nlin.CD': 29, 'cs.CE': 285, 'math.AT': 13, 'stat.OT': 8, 'physics.ao-ph': 19, 'math.SP': 7, 'cs.PL': 128, 'math.AP': 13, 'math.FA': 43, 'gr-qc': 6, 'physics.geo-ph': 14, 'q-bio.TO': 8, 'physics.comp-ph': 34, 'cs.DL': 139, 'math.CO': 33, 'physics.flu-dyn': 3, 'math.MG': 9, 'astro-ph.EP': 4, 'q-bio.CB': 5, 'hep-th': 6, 'math.RA': 11, 'astro-ph.CO': 10, 'cond-mat.mtrl-sci': 12, 'q-fin.ST': 15, 'q-bio.GN': 50, 'hep-ex': 9, 'nlin.CG': 18, 'nlin.PS': 3, 'math.HO': 8, 'q-fin.GN': 13, 'math.LO': 37, 'math.CT': 26, 'q-bio.PE': 84, 'astro-ph.SR': 9, 'q-fin.PM': 12, 'physics.bio-ph': 34, 'math.AG': 21, 'cs.OH': 11, 'math.DG': 17, 'astro-ph.HE': 4, 'econ.EM': 13, 'math.QA': 2, 'q-bio.SC': 3, 'math.GM': 3, 'q-bio.MN': 26, 'math.GT': 5, 'math.AC': 3, 'math.CA': 6, 'cond-mat.str-el': 5, 'math.GN': 4, 'hep-ph': 6, 'cond-mat.supr-con': 4, 'q-bio.OT': 5, 'nucl-th': 2, 'physics.ins-det': 9, 'hep-lat': 3, 'physics.app-ph': 1, 'math.RT': 3, 'math.MP': 4, 'math-ph': 4, 'physics.class-ph': 2, 'q-fin.PR': 1, 'physics.space-ph': 2, 'physics.gen-ph': 1, 'cond-mat.other': 2, 'math.GR': 4, 'nucl-ex': 3, 'cond-mat.quant-gas': 1, 'math.OA': 2, 'physics.hist-ph': 4, 'math.NT': 1, 'cs.OS': 2, 'cond-mat.soft': 2, 'physics.pop-ph': 1, 'math.CV': 1})\n" ] } ], "source": [ "tag_to_count = defaultdict(int)\n", "category_to_count = defaultdict(int)\n", "for i, row in full_data_df.iterrows():\n", " found_categories = set()\n", " for tag in row['tags']:\n", " tag_to_count[tag] += 1\n", " found_categories.add(tag_to_category[tag])\n", " for category in found_categories:\n", " category_to_count[category] += 1\n", "print(category_to_count)\n", "print(tag_to_count)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Как видим, Computer science встречается очень часто. А, например, экономика - совсем редко**\n", "\n", "**Это по-хорошему нужно учесть, но в рамках данного ноутбука мы это делать не будем**" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "text_data = list(full_data_df['title_and_summary'])\n", "tags_indices = list(full_data_df['tags_indices'])" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "28700 8200 4100\n" ] } ], "source": [ "X_train_val, X_test, y_train_val, y_test = train_test_split(text_data, tags_indices, test_size=0.1, random_state=42)\n", "X_train, X_val, y_train, y_val = train_test_split(X_train_val, y_train_val, test_size=2/9, random_state=42)\n", "print(len(X_train), len(X_val), len(X_test))\n", "# Train is 70%, val is 20%, test is 10%" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "tokenizer = DistilBertTokenizer.from_pretrained(USED_MODEL)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "def tokenize_function(text):\n", " return tokenizer(text, padding=\"max_length\", truncation=True)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Dual Recurrent Attention Units for Visual Question Answering $ We propose an architecture for VQA which utilizes recurrent layers to\n", "generate visual and textual attention. The memory characteristic of the\n", "proposed recurrent attention units offers a rich joint embedding of visual and\n", "textual features and enables the model to reason relations between several\n", "parts of the image and question. Our single model outperforms the first place\n", "winner on the VQA 1.0 dataset, performs within margin to the current\n", "state-of-the-art ensemble model. We also experiment with replacing attention\n", "mechanisms in other state-of-the-art models with our implementation and show\n", "increased accuracy. In both cases, our recurrent attention mechanism improves\n", "performance in tasks requiring sequential or relational reasoning on the VQA\n", "dataset.\n" ] }, { "data": { "text/plain": [ "{'input_ids': [101, 27791, 11336, 21754, 1335, 5208, 2116, 21687, 1111, 12071, 22171, 26018, 1158, 109, 1284, 17794, 1126, 4220, 1111, 159, 4880, 1592, 1134, 24242, 1231, 21754, 8798, 1106, 9509, 5173, 1105, 3087, 4746, 2209, 119, 1109, 2962, 7987, 1104, 1103, 3000, 1231, 21754, 2209, 2338, 3272, 170, 3987, 4091, 9712, 4774, 3408, 1104, 5173, 1105, 3087, 4746, 1956, 1105, 13267, 1103, 2235, 1106, 2255, 4125, 1206, 1317, 2192, 1104, 1103, 3077, 1105, 2304, 119, 3458, 1423, 2235, 1149, 3365, 13199, 1116, 1103, 1148, 1282, 2981, 1113, 1103, 159, 4880, 1592, 122, 119, 121, 2233, 9388, 117, 10383, 1439, 7464, 1106, 1103, 1954, 1352, 118, 1104, 118, 1103, 118, 1893, 9525, 2235, 119, 1284, 1145, 7886, 1114, 5861, 2209, 10748, 1107, 1168, 1352, 118, 1104, 118, 1103, 118, 1893, 3584, 1114, 1412, 7249, 1105, 1437, 2569, 10893, 119, 1130, 1241, 2740, 117, 1412, 1231, 21754, 2209, 6978, 4607, 1116, 2099, 1107, 8249, 8753, 14516, 21967, 1137, 6796, 1348, 14417, 1113, 1103, 159, 4880, 1592, 2233, 9388, 119, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "print(text_data[0])\n", "tokenize_function(text_data[0])" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "train_encodings = tokenize_function(X_train)\n", "val_encodings = tokenize_function(X_val)\n", "test_encodings = tokenize_function(X_test)" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "['_MutableMapping__marker', '__abstractmethods__', '__class__', '__class_getitem__', '__contains__', '__copy__', '__delattr__', '__delitem__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattr__', '__getattribute__', '__getitem__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__ior__', '__iter__', '__le__', '__len__', '__lt__', '__module__', '__ne__', '__new__', '__or__', '__reduce__', '__reduce_ex__', '__repr__', '__reversed__', '__ror__', '__setattr__', '__setitem__', '__setstate__', '__sizeof__', '__slots__', '__str__', '__subclasshook__', '__weakref__', '_abc_impl', '_encodings', '_n_sequences', 'char_to_token', 'char_to_word', 'clear', 'convert_to_tensors', 'copy', 'data', 'encodings', 'fromkeys', 'get', 'is_fast', 'items', 'keys', 'n_sequences', 'pop', 'popitem', 'sequence_ids', 'setdefault', 'to', 'token_to_chars', 'token_to_sequence', 'token_to_word', 'tokens', 'update', 'values', 'word_ids', 'word_to_chars', 'word_to_tokens', 'words']\n", "2\n" ] } ], "source": [ "print(type(train_encodings))\n", "print(dir(train_encodings))\n", "print(len(train_encodings))" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "def get_labels(y: List[List[int]]):\n", " labels = np.zeros((len(y), len(tag_to_index)))\n", " for i in tqdm(range(len(y))):\n", " labels[i, y[i]] = 1\n", " return labels.tolist()" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 28700/28700 [00:00<00:00, 388780.42it/s]\n", "100%|██████████| 8200/8200 [00:00<00:00, 223262.03it/s]\n", "100%|██████████| 4100/4100 [00:00<00:00, 165215.75it/s]\n" ] } ], "source": [ "labels_train = get_labels(y_train)\n", "labels_val = get_labels(y_val)\n", "labels_test = get_labels(y_test)" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "train_encodings['labels'] = labels_train\n", "val_encodings['labels'] = labels_val\n", "test_encodings['labels'] = labels_test" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Я использовал пример отсюда чтобы понимать, какой нужен формат данных https://github.com/NielsRogge/Transformers-Tutorials/blob/master/BERT/Fine_tuning_BERT_(and_friends)_for_multi_label_text_classification.ipynb**" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "train_dataset = Dataset.from_dict(train_encodings)\n", "val_dataset = Dataset.from_dict(val_encodings)\n", "test_dataset = Dataset.from_dict(test_encodings)" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] } ], "source": [ "model = AutoModelForSequenceClassification.from_pretrained(\n", " USED_MODEL, \n", " problem_type=\"multi_label_classification\", \n", " num_labels=len(tag_to_index),\n", " id2label=index_to_tag,\n", " label2id=tag_to_index\n", ")" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "batch_size = 8\n", "metric_name = \"f1\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/jarakcyc/.virtualenvs/Tricks/lib/python3.10/site-packages/transformers/training_args.py:1611: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n", " warnings.warn(\n" ] } ], "source": [ "args = TrainingArguments(\n", " output_dir=f'train-{USED_MODEL}-baseline',\n", " evaluation_strategy=\"epoch\",\n", " save_strategy=\"epoch\",\n", " learning_rate=2e-5,\n", " per_device_train_batch_size=batch_size,\n", " per_device_eval_batch_size=batch_size,\n", " num_train_epochs=5,\n", " weight_decay=0.01,\n", " load_best_model_at_end=True,\n", " metric_for_best_model=metric_name,\n", " push_to_hub=False\n", ")" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [], "source": [ "from sklearn.metrics import f1_score, roc_auc_score, accuracy_score\n", "from transformers import EvalPrediction\n", "import torch\n", " \n", "# source: https://jesusleal.io/2021/04/21/Longformer-multilabel-classification/\n", "def multi_label_metrics(predictions, labels, threshold=0.5):\n", " # first, apply sigmoid on predictions which are of shape (batch_size, num_labels)\n", " sigmoid = torch.nn.Sigmoid()\n", " probs = sigmoid(torch.Tensor(predictions))\n", " # next, use threshold to turn them into integer predictions\n", " y_pred = np.zeros(probs.shape)\n", " y_pred[np.where(probs >= threshold)] = 1\n", " # finally, compute metrics\n", " y_true = labels\n", " f1_micro_average = f1_score(y_true=y_true, y_pred=y_pred, average='micro')\n", " roc_auc = roc_auc_score(y_true, y_pred, average = 'micro')\n", " accuracy = accuracy_score(y_true, y_pred)\n", " # return as dictionary\n", " metrics = {'f1': f1_micro_average,\n", " 'roc_auc': roc_auc,\n", " 'accuracy': accuracy}\n", " return metrics\n", "\n", "def compute_metrics(p: EvalPrediction):\n", " preds = p.predictions[0] if isinstance(p.predictions, \n", " tuple) else p.predictions\n", " result = multi_label_metrics(\n", " predictions=preds, \n", " labels=p.label_ids)\n", " return result" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [], "source": [ "train_dataset.set_format(\"torch\")\n", "test_dataset.set_format(\"torch\")" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/tmp/ipykernel_571129/1751307119.py:1: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. Use `processing_class` instead.\n", " trainer = Trainer(\n" ] } ], "source": [ "trainer = Trainer(\n", " model,\n", " args,\n", " train_dataset=train_dataset,\n", " eval_dataset=val_dataset,\n", " tokenizer=tokenizer,\n", " compute_metrics=compute_metrics\n", ")" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
\n", " \n", " \n", " [17940/17940 32:06, Epoch 5/5]\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
EpochTraining LossValidation LossF1Roc AucAccuracy
10.0240000.0228990.6529540.7701670.410366
20.0204000.0207300.6737650.7852260.426829
30.0179000.0196920.7002920.8123130.425000
40.0161000.0196950.7015930.8123660.433171
50.0148000.0197670.7011930.8127100.431707

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "TrainOutput(global_step=17940, training_loss=0.02238395190159214, metrics={'train_runtime': 1927.2238, 'train_samples_per_second': 74.459, 'train_steps_per_second': 9.309, 'total_flos': 1.906093867776e+16, 'train_loss': 0.02238395190159214, 'epoch': 5.0})" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trainer.train()" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'eval_loss': 0.019695421680808067,\n", " 'eval_f1': 0.7015928686248721,\n", " 'eval_roc_auc': 0.8123655228058703,\n", " 'eval_accuracy': 0.43317073170731707,\n", " 'eval_runtime': 34.8656,\n", " 'eval_samples_per_second': 235.189,\n", " 'eval_steps_per_second': 29.399,\n", " 'epoch': 5.0}" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trainer.evaluate(eval_dataset=val_dataset)" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'eval_loss': 0.019682902842760086,\n", " 'eval_f1': 0.6966158423205653,\n", " 'eval_roc_auc': 0.8081637343174538,\n", " 'eval_accuracy': 0.4370731707317073,\n", " 'eval_runtime': 16.5771,\n", " 'eval_samples_per_second': 247.329,\n", " 'eval_steps_per_second': 30.946,\n", " 'epoch': 5.0}" ] }, "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trainer.evaluate(eval_dataset=test_dataset)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Исходная задача у нас звучала как \"хотим увидеть топ-95%* тематик, отсортированных по убыванию вероятности\", где под тематиками имелись ввиду категории (физика, биология и так далее)\n", "\n", "Будем делать следующее:\n", "- наша модель выдает логиты тегов\n", "- посчитаем с их помощью вероятность каждого тега, считая сумму вероятностей равной 1\n", "- посчитаем вероятность категории как сумму вероятностей тегов\n", "- выведем требуемые топ-95% тематик" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "model = AutoModelForSequenceClassification.from_pretrained(\n", " \"train_distilbert-base-cased/checkpoint-17940\", \n", " problem_type=\"multi_label_classification\", \n", " num_labels=len(tag_to_index),\n", " id2label=index_to_tag,\n", " label2id=tag_to_index\n", ").to(torch.device('cuda'))" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "SequenceClassifierOutput(loss=None, logits=tensor([[-1.3623, -5.3834, -3.3988, -3.4555, -3.7096, -4.5285, -5.1323, -2.3077,\n", " -3.6645, -4.6847, -4.2481, -5.0417, -3.5121, -2.7808, -5.9767, -4.8864,\n", " -5.6730, -4.6838, -3.8588, -5.2819, -3.9295, -2.7704, 0.4331, -4.5505,\n", " -5.2648, -4.9248, -4.2074, -3.4895, -3.2717, -5.2713, -5.7536, -7.2749,\n", " -4.8728, -5.2606, -4.5935, -4.7103, -5.4628, -5.4589, -5.3678, -3.5648,\n", " -5.1455, -8.8455, -9.1583, -6.4358, -4.7737, -4.7821, -8.9264, -5.8790,\n", " -4.7536, -5.4549, -5.3879, -6.1918, -4.1667, -7.1828, -7.3235, -5.4470,\n", " -4.6688, -4.7201, -6.2949, -7.5401, -6.6242, -6.1022, -5.5325, -3.1546,\n", " -9.4200, -5.2060, -5.3880, -6.8743, -3.3176, -7.2654, -7.4301, -3.0929,\n", " -3.2351, -9.0408, -5.4315, -6.3230, -9.5853, -5.7075, -3.6443, -5.5524,\n", " -6.0723, -6.0414, -7.3201, -3.9738, -5.5964, -4.0455, -5.2017, -5.8061,\n", " -7.8401, -7.5268, -7.4576, -4.4483, -6.4790, -5.9085, -6.8822, -5.4498,\n", " -6.7494, -6.1449, -5.9297, -6.4985, -5.0379, -4.9914, -5.5201, -7.9075,\n", " -8.7653, -6.6116, -6.6643, -9.3863, -4.9038, -7.6509, -9.0117, -9.1193,\n", " -5.3166, -5.4046, -8.3876, -4.9028, -3.5257, -8.9734, -6.1487, -8.1408,\n", " -5.3014, -6.5494, -6.8383, -4.8011, -5.2831, -8.7708, -7.5039, -5.3957,\n", " -7.3326, -3.6551, -4.9892, -5.9366, -5.2093, -5.2362, -5.0462, -6.5469,\n", " -4.9182, -4.4108, -7.1632, -5.9481, -5.3291, -6.4517, -5.6950, -8.7276,\n", " -5.7762, -8.9848, -7.3795, -5.4210, -5.6845, -2.9447, -3.6166, -3.6258,\n", " -1.4417, -5.6568, -3.5869]], device='cuda:0',\n", " grad_fn=), hidden_states=None, attentions=None)" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model(**{key: torch.tensor(value).to(model.device).unsqueeze(0) for key, value in tokenize_function('Maths is cool $ In our article we prove that maths is the coolest subject at school').items()})" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "@torch.no_grad\n", "def get_category_probs_dict(model, title: str, summary: str) -> Dict[str, float]:\n", " text = f'{title} $ {summary}'\n", " tags_logits = model(**{key: torch.tensor(value).to(model.device).unsqueeze(0) for key, value in tokenize_function(text).items()}).logits\n", " sigmoid = torch.nn.Sigmoid()\n", " tags_probs = sigmoid(tags_logits.squeeze().cpu()).numpy()\n", " tags_probs /= tags_probs.sum()\n", " category_probs_dict = {category: 0.0 for category in set(arxiv_topics_df['category'])}\n", " for index in range(len(index_to_tag)):\n", " category_probs_dict[tag_to_category[index_to_tag[index]]] += float(tags_probs[index])\n", " return category_probs_dict" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "def get_most_probable_keys(probs_dict: Dict[str, float], target_probability: float, print_probabilities: bool) -> List[str]:\n", " current_p = 0\n", " probs_list = sorted([(value, key) for key, value in probs_dict.items()])[::-1]\n", " current_index = 0\n", " answer = []\n", " while current_p <= target_probability:\n", " current_p += probs_list[current_index][0]\n", " if not print_probabilities:\n", " answer.append(probs_list[current_index][1])\n", " else:\n", " answer.append(f'{probs_list[current_index][1]} ({probs_list[current_index][0]})')\n", " current_index += 1\n", " if current_index >= len(probs_list):\n", " break\n", " return answer" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Теперь нужно как-то сохранить модель, чтобы потом можно было её использовать в huggingface space" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "model.safetensors: 100%|██████████| 264M/264M [00:31<00:00, 8.47MB/s] \n" ] }, { "data": { "text/plain": [ "CommitInfo(commit_url='https://huggingface.co/bumchik2/train_distilbert-base-cased-tags-classification-simple/commit/98a87d7c96e0647dd557a9d47be03ddd30e0c964', commit_message='Upload DistilBertForSequenceClassification', commit_description='', oid='98a87d7c96e0647dd557a9d47be03ddd30e0c964', pr_url=None, repo_url=RepoUrl('https://huggingface.co/bumchik2/train_distilbert-base-cased-tags-classification-simple', endpoint='https://huggingface.co', repo_type='model', repo_id='bumchik2/train_distilbert-base-cased-tags-classification-simple'), pr_revision=None, pr_num=None)" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.push_to_hub(\"bumchik2/train_distilbert-base-cased-tags-classification-simple\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Теперь я смогу загружать свою модель оттуда" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "model = AutoModelForSequenceClassification.from_pretrained(\n", " \"bumchik2/train_distilbert-base-cased-tags-classification-simple\", \n", " problem_type=\"multi_label_classification\", \n", " num_labels=len(tag_to_index),\n", " id2label=index_to_tag,\n", " label2id=tag_to_index\n", ").to(torch.device('cuda'))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Tricks", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 2 }