{ "cells": [ { "cell_type": "markdown", "id": "976841dc", "metadata": {}, "source": [ "## Preparación de un dataset\n", "\n", "Descargamos el dataset y lo preparamos para el entrenamiento. En el caso de ejemplo, usaremos toxic-teenage-relationships, que son frases que describen si un comporamiento es tóxico o sano. Tienen una campo de texto y un campo de etiqueta, que vale 1 si es tóxico y 0 si no lo es. Acumula 267 ejemplos de entrenamiento y 66 para testear." ] }, { "cell_type": "code", "execution_count": 1, "id": "caf72aa3", "metadata": { "scrolled": false }, "outputs": [ { "data": { "text/plain": [ "{'label': 1, 'text': 'Me mira mal por mi forma de vestir'}" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from datasets import load_dataset\n", "data_files = {\"train\": \"train.csv\", \"test\": \"test.csv\"}\n", "dataset = load_dataset(\"toxic-teenage-relationships\", data_files=data_files, sep=\";\")\n", "dataset['train'][102]" ] }, { "cell_type": "markdown", "id": "08aacc14", "metadata": {}, "source": [ "Una vez cargado el dataset, se crea un tokenizador para procesar el texto e incluir una estrategia para el padding y el truncamiento. 
To preprocess the whole dataset in a single step, we use the dataset.map method.\n", "\n" ] }, { "cell_type": "code", "execution_count": 2, "id": "4a854ead", "metadata": {}, "outputs": [], "source": [ "\n", "from transformers import AutoTokenizer\n", "# the model to use is BETO (Spanish BERT)\n", "tokenizer = AutoTokenizer.from_pretrained(\"dccuchile/bert-base-spanish-wwm-cased\")\n", "\n", "\n", "def tokenize_function(examples):\n", "    return tokenizer(examples[\"text\"], padding=\"max_length\", truncation=True)\n", "\n", "\n", "tokenized_datasets = dataset.map(tokenize_function, batched=True)\n" ] }, { "cell_type": "code", "execution_count": 3, "id": "eb5477cc", "metadata": {}, "outputs": [], "source": [ "train_dataset = tokenized_datasets[\"train\"]\n", "eval_dataset = tokenized_datasets[\"test\"]" ] }, { "cell_type": "markdown", "id": "38a6c521", "metadata": {}, "source": [ "## Fine-tuning with Trainer\n", "\n", "The Trainer class from Transformers makes it easy to train transformer models. 
The Trainer API supports many training options and features such as logging, gradient accumulation and mixed precision." ] }, { "cell_type": "code", "execution_count": 4, "id": "843f218d", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at dccuchile/bert-base-spanish-wwm-cased and are newly initialized: ['classifier.bias', 'bert.pooler.dense.weight', 'classifier.weight', 'bert.pooler.dense.bias']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] } ], "source": [ "from transformers import AutoModelForSequenceClassification\n", "\n", "# there are two categories, so we set 2 labels (0 healthy, 1 toxic)\n", "model = AutoModelForSequenceClassification.from_pretrained(\"dccuchile/bert-base-spanish-wwm-cased\", num_labels=2)\n" ] }, { "cell_type": "markdown", "id": "27be3c25", "metadata": {}, "source": [ "## Training hyperparameters\n", "\n", "Next we create a TrainingArguments instance, which holds all the hyperparameters that can be tuned. 
\n", "Empezamos con los hiperparámetros de entrenamiento por defecto, pero tendremos que ajustarlos para encontrar la configuración óptima.\n" ] }, { "cell_type": "code", "execution_count": 5, "id": "7f84ef1e", "metadata": {}, "outputs": [], "source": [ "#Para poder evitar el overfitting, voy a añadir la clase earlystopping en el momento que se observe\n", "#que la pérdida se incrementa en dos epoch\n", "from transformers import EarlyStoppingCallback\n", "early_stop=EarlyStoppingCallback(early_stopping_patience=2)" ] }, { "cell_type": "code", "execution_count": 6, "id": "f53c992d", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/mmartinez/anaconda3/envs/TFM/lib/python3.8/site-packages/transformers/optimization.py:411: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n", " warnings.warn(\n" ] } ], "source": [ "from transformers import TrainingArguments\n", "from transformers import DataCollatorWithPadding\n", "from transformers import AdamW\n", "# para controlar las métricas de evaluación durante el fine-tuning\n", "# vamos a añadir que elija el mejor modelo al final, usamos load_best_model_at_end que cogerá eval_loss para evaluar\n", "# para que se fije en el valor de loss como la mejor métrica, hay que poner greater_is_better a false.\n", "#vamos a poner el número de epoch a 10 y el del batch a 8\n", "\n", "training_args = TrainingArguments(output_dir=\"BETo-t-MMG\",\n", " num_train_epochs=10,\n", " per_device_train_batch_size=8,\n", " per_device_eval_batch_size=8,\n", " load_best_model_at_end=True,\n", " greater_is_better=False,\n", " evaluation_strategy=\"epoch\",\n", " save_strategy=\"epoch\")\n", "#optimizador \n", "optimizer=AdamW(model.parameters(),lr=5e-5)\n", "#añado el data Collator, que en este caso va a ser parte del trainer.\n", "#este es el 
indicado específicamente para tareas de clasificación de texto, agrupa y preprocesa\n", "#para que todos los ejemplos de entrada en lotes tengan la misma longitud además del tokenizdor\n", "#agrupación en lotes y creación de mapas de atención.\n", "#usando la función .map, no es estrictamente necesario pero así se combinan las características\n", "#adicionales del texto antes de pasarle el datacollator.\n", "data_collator = DataCollatorWithPadding(tokenizer)" ] }, { "cell_type": "markdown", "id": "6d604727", "metadata": {}, "source": [ "## Métricas\n", "\n", "El Trainer no evalúa automátiamentee el rendimiento, hay que pasarle una función para calcular y hacer un reporte de las métricas. En Datasets hay una función, accuracy, que se puede cargar con load_metric. \n", "Antes hay que instalar scikit-learn" ] }, { "cell_type": "code", "execution_count": 7, "id": "0ed3ddf4", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: scikit-learn in /home/mmartinez/anaconda3/envs/TFM/lib/python3.8/site-packages (1.3.0)\n", "Requirement already satisfied: numpy>=1.17.3 in /home/mmartinez/anaconda3/envs/TFM/lib/python3.8/site-packages (from scikit-learn) (1.24.3)\n", "Requirement already satisfied: scipy>=1.5.0 in /home/mmartinez/anaconda3/envs/TFM/lib/python3.8/site-packages (from scikit-learn) (1.10.1)\n", "Requirement already satisfied: joblib>=1.1.1 in /home/mmartinez/anaconda3/envs/TFM/lib/python3.8/site-packages (from scikit-learn) (1.3.1)\n", "Requirement already satisfied: threadpoolctl>=2.0.0 in /home/mmartinez/anaconda3/envs/TFM/lib/python3.8/site-packages (from scikit-learn) (3.2.0)\n", "Note: you may need to restart the kernel to use updated packages.\n" ] } ], "source": [ "pip install scikit-learn" ] }, { "cell_type": "code", "execution_count": 8, "id": "326103f5", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/tmp/ipykernel_3270586/2607597888.py:4: FutureWarning: 
load_metric is deprecated and will be removed in the next major version of datasets. Use 'evaluate.load' instead, from the new library 🤗 Evaluate: https://huggingface.co/docs/evaluate\n", " metric = load_metric(\"accuracy\")" ] } ], "source": [ "import numpy as np\n", "from datasets import load_metric\n", "\n", "metric = load_metric(\"accuracy\")" ] }, { "cell_type": "markdown", "id": "087d4b3e", "metadata": {}, "source": [ "We define the compute_metrics function to calculate the accuracy of the predictions. Transformers models return logits, so before passing anything to compute, the logits must be converted into class predictions." ] }, { "cell_type": "code", "execution_count": 9, "id": "d7b8341d", "metadata": {}, "outputs": [], "source": [ "def compute_metrics(eval_pred):\n", "    logits, labels = eval_pred\n", "    predictions = np.argmax(logits, axis=-1)\n", "    return metric.compute(predictions=predictions, references=labels)" ] }, { "cell_type": "markdown", "id": "53db268c", "metadata": {}, "source": [ "## Trainer\n", "\n", "Now it is time to create the Trainer object with the model, the training arguments, the training and test datasets, and the evaluation function:" ] }, { "cell_type": "code", "execution_count": 12, "id": "d566aded", "metadata": {}, "outputs": [], "source": [ "from transformers import Trainer\n", "trainer = Trainer(\n", "    model=model,\n", "    args=training_args,\n", "    train_dataset=train_dataset,\n", "    eval_dataset=eval_dataset,\n", "    data_collator=data_collator,\n", "    optimizers=(optimizer, None),\n", "    compute_metrics=compute_metrics,\n", "    callbacks=[early_stop],\n", ")" ] }, { "cell_type": "markdown", "id": "a31780ca", "metadata": {}, "source": [ "Fine-tuning is then run with train" ] }, { "cell_type": "code", "execution_count": 13, "id": "3e01c5fb", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "You're using a BertTokenizerFast tokenizer. 
Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n" ] }, { "data": { "text/html": [ "\n", "    <p>[102/340 01:27 < 03:27, 1.15 it/s, Epoch 3/10]</p>\n", "    <table border=\"1\" class=\"dataframe\">\n", "      <thead>\n", "        <tr><th>Epoch</th><th>Training Loss</th><th>Validation Loss</th><th>Accuracy</th></tr>\n", "      </thead>\n", "      <tbody>\n", "        <tr><td>1</td><td>No log</td><td>0.459866</td><td>0.803030</td></tr>\n", "        <tr><td>2</td><td>No log</td><td>0.649665</td><td>0.848485</td></tr>\n", "        <tr><td>3</td><td>No log</td><td>1.026334</td><td>0.787879</td></tr>\n", "      </tbody>\n", "    </table>" ], "text/plain": [ "<IPython.core.display.HTML object>" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "TrainOutput(global_step=102, training_loss=0.33487387264476104, metrics={'train_runtime': 88.4219, 'train_samples_per_second': 30.309, 'train_steps_per_second': 3.845, 'total_flos': 211541288509440.0, 'train_loss': 0.33487387264476104, 'epoch': 3.0})" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trainer.train()" ] }, { "cell_type": "markdown", "id": "417d3cd2", "metadata": {}, "source": [ "Print the loss and the accuracy" ] }, { "cell_type": "code", "execution_count": 15, "id": "d1144002", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train set results\n", "eval_loss: 0.19221480190753937\n", "eval_accuracy: 0.9440298507462687\n", "eval_runtime: 9.8909\n", "eval_samples_per_second: 27.095\n", "eval_steps_per_second: 3.437\n", "epoch: 3.0\n", "Test set results\n", "eval_loss: 0.4598655700683594\n", "eval_accuracy: 0.803030303030303\n", "eval_runtime: 2.4345\n", "eval_samples_per_second: 27.11\n", "eval_steps_per_second: 3.697\n", "epoch: 3.0\n" ] } ], "source": [ "# helper function to print the results in a more readable way\n", "def print_results(title, results):\n", "    print(title)\n", "    for key, value in results.items():\n", "        print(f\"{key}: {value}\")\n", "\n", "train_result = trainer.evaluate(train_dataset)\n", "print_results(\"Train set results\", train_result)\n", "eval_result = trainer.evaluate(eval_dataset)\n", "print_results(\"Test set results\", eval_result)" ] }, { "cell_type": "markdown", "id": "9e61a040", "metadata": {}, "source": [ "# Saving the model" ] }, { "cell_type": "markdown", "id": "4af06209", "metadata": {}, "source": [ "To save the model, we use the save_model method" ] }, { "cell_type": "code", "execution_count": 16, "id": "b93638cb", "metadata": {}, "outputs": [], "source": [ "trainer.save_model()" ] }, { "cell_type": "code", "execution_count": 17, "id": "973c4e03", "metadata": {}, "outputs": [], "source": [ "trainer.create_model_card()" ] }, { "cell_type": "code", "execution_count": null, "id": "9671b67c", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.13" } }, "nbformat": 4, "nbformat_minor": 5 }