from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments, pipeline, DataCollatorWithPadding
from sklearn.metrics import accuracy_score, f1_score
import torch
import numpy as np
import torch.nn.functional as F
import matplotlib.pyplot as plt
from typing import List
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix
from umap import UMAP
from sklearn.preprocessing import MinMaxScaler

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")



class TransformersSequenceClassifier:
    def __init__(self,
                 model_output_dir,
                 num_labels,
                 tokenizer: AutoTokenizer,
                 id2label,
                 label2id,
                 model_checkpoint="distilbert-base-uncased"
                 ):
        self.model_output_dir = model_output_dir
        self.tokenizer = tokenizer
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_checkpoint,
            num_labels=num_labels,
            id2label=id2label,
            label2id=label2id,
        ).to(device)

    def tokenizer_batch(self, batch):
        # Tokenize a batch of texts, padding/truncating to the longest sequence in the batch
        return self.tokenizer(batch["inputs"], truncation=True, padding=True, return_tensors="pt")  # optionally: max_length=386

    def tokenize_dataset(self, dataset):
        # Map the tokenizer over the dataset and drop the raw text and pandas index columns
        return dataset.map(self.tokenizer_batch, batched=True, remove_columns=["inputs", "__index_level_0__"])

    @staticmethod
    def extract_hidden_states(batch, tokenizer, model):
        # Place model inputs on the same device as the model
        inputs = {k: v.to(device) for k, v in batch.items() if k in tokenizer.model_input_names}
        # Extract last hidden states
        with torch.no_grad():
            last_hidden_state = model(**inputs).last_hidden_state
        # Return vector for [CLS] token
        return {"hidden_state": last_hidden_state[:, 0].cpu().numpy()}
    
    @staticmethod
    def fit_umap(df_x):
        # Scale features to [0,1] range
        X_scaled = MinMaxScaler().fit_transform(df_x)
        # Initialize and fit UMAP
        mapper = UMAP(n_components=2, metric="cosine").fit(X_scaled)
        # Return the 2D embedding coordinates as a NumPy array
        return mapper.embedding_
    
    def train(self, train_dataset, eval_dataset, batch_size, epochs):
        #data_collator = DataCollatorWithPadding(tokenizer=self.tokenizer, padding='longest')
        training_args = TrainingArguments(output_dir=self.model_output_dir,
                                          num_train_epochs=epochs,
                                          learning_rate=2e-5,
                                          per_device_train_batch_size=batch_size,
                                          per_device_eval_batch_size=batch_size,
                                          weight_decay=0.01,
                                          evaluation_strategy="epoch",
                                          save_strategy='epoch',
                                          disable_tqdm=False,
                                          logging_steps=len(train_dataset)//batch_size,
                                          push_to_hub=True,
                                          load_best_model_at_end=True,
                                          log_level="error")
        self.trainer = Trainer(
            model=self.model,
            args=training_args,
            compute_metrics=self._compute_metrics,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=self.tokenizer,
            #data_collator=data_collator
        )
        self.trainer.train()
        self.trainer.push_to_hub(commit_message="Training completed!")

    @staticmethod
    def _compute_metrics(pred):
        labels = pred.label_ids
        preds = pred.predictions.argmax(-1)
        f1 = f1_score(labels, preds, average="weighted")
        acc = accuracy_score(labels, preds)
        return {"accuracy": acc, "f1": f1}

    def forward_pass_with_label(self, batch):
        # Place all input tensors on the same device as the model.
        # Note: the dataset must already be in torch format, e.g.
        # dataset.set_format("torch", columns=["input_ids", "attention_mask", "label"])
        inputs = {k: v.to(device) for k, v in batch.items()
                  if k in self.tokenizer.model_input_names}

        with torch.no_grad():
            output = self.model(**inputs)
            pred_label = torch.argmax(output.logits, axis=-1)
            loss = F.cross_entropy(output.logits, batch["label"].to(device), 
                                reduction="none")

        # Place outputs on CPU for compatibility with other dataset columns
        return {"loss": loss.cpu().numpy(), 
                "predicted_label": pred_label.cpu().numpy()}

    def compute_loss_per_pred(self, valid_dataset):
        # Compute loss values
        return valid_dataset.map(self.forward_pass_with_label, batched=True, batch_size=16)

    @staticmethod
    def plot_confusion_matrix(y_preds, y_true, label_names):
        cm = confusion_matrix(y_true, y_preds, normalize="true")
        fig, ax = plt.subplots(figsize=(6, 6))
        disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=label_names)
        disp.plot(cmap="Blues", values_format=".2f", ax=ax, colorbar=False)
        plt.title("Normalized confusion matrix")
        plt.show()

    def predict_argmax_logit(self, valid_dataset):
        # Reuse the Trainer created in train(); fall back to a bare Trainer if
        # predictions are requested before (or without) fine-tuning
        trainer = getattr(self, "trainer", None) or Trainer(model=self.model, tokenizer=self.tokenizer)
        preds_output = trainer.predict(valid_dataset)
        print(preds_output.metrics)
        y_preds = np.argmax(preds_output.predictions, axis=1)
        return y_preds
    
    @staticmethod
    def predict_pipeline(model_checkpoint, test_list: List[str]) -> List:
        pipe_classifier = pipeline("text-classification", model=model_checkpoint)
        preds = pipe_classifier(test_list, return_all_scores=True)
        return preds
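

# ---------------------------------------------------------------------------
# Example usage (illustrative sketch only, not part of the class above).
# The dataset name ("emotion"), column names, output directory, and
# hyperparameters below are assumptions chosen for demonstration; adjust them
# to your own data. Note that train() sets push_to_hub=True, which requires a
# logged-in Hugging Face account.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model_checkpoint = "distilbert-base-uncased"
    tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

    # Load a labelled text-classification dataset and rename its text column
    # to "inputs", which is what tokenizer_batch() expects.
    dataset = load_dataset("emotion").rename_column("text", "inputs")
    label_names = dataset["train"].features["label"].names

    clf = TransformersSequenceClassifier(
        model_output_dir="distilbert-finetuned-emotion",  # hypothetical output dir
        num_labels=len(label_names),
        tokenizer=tokenizer,
        id2label={i: name for i, name in enumerate(label_names)},
        label2id={name: i for i, name in enumerate(label_names)},
        model_checkpoint=model_checkpoint,
    )

    # Tokenize. batch_size=None pads every split to a single global length,
    # and this dataset has no "__index_level_0__" column, so map the tokenizer
    # directly instead of calling tokenize_dataset().
    encoded = dataset.map(clf.tokenizer_batch, batched=True, batch_size=None,
                          remove_columns=["inputs"])

    # Fine-tune, evaluate, and inspect the errors.
    clf.train(encoded["train"], encoded["validation"], batch_size=64, epochs=2)
    y_preds = clf.predict_argmax_logit(encoded["validation"])
    clf.plot_confusion_matrix(y_preds, encoded["validation"]["label"], label_names)

    # Feature-extraction sketch: the [CLS] hidden states come from a *base*
    # encoder (AutoModel), not the classification model, and UMAP projects
    # them down to 2D for visualization.
    from transformers import AutoModel
    base_model = AutoModel.from_pretrained(model_checkpoint).to(device)
    encoded.set_format("torch", columns=["input_ids", "attention_mask", "label"])
    hidden = encoded["train"].map(
        lambda batch: TransformersSequenceClassifier.extract_hidden_states(batch, tokenizer, base_model),
        batched=True, batch_size=64)
    X_2d = TransformersSequenceClassifier.fit_umap(np.array(hidden["hidden_state"]))
    print("2D UMAP embedding shape:", X_2d.shape)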