import random
import re
from collections import Counter

import torch
from datasets import load_dataset
from sklearn.preprocessing import LabelEncoder
from transformers import AutoTokenizer


# ====== Dataset Loading ======
def load_emotion_dataset(split="train"):
    return load_dataset("dair-ai/emotion", split=split)


def encode_labels(dataset):
    # Fit a LabelEncoder on all labels, then map each example to its encoded id.
    le = LabelEncoder()
    all_labels = [example["label"] for example in dataset]
    le.fit(all_labels)
    dataset = dataset.map(lambda x: {"label": le.transform([x["label"]])[0]})
    return dataset, le


# ====== Tokenizer for RNN/LSTM ======
def simple_tokenizer(text):
    text = text.lower()
    text = re.sub(r"[^a-z0-9\s]", "", text)  # Remove special characters
    return text.split()


# ====== Vocab Builder for RNN/LSTM ======
def build_vocab(dataset, min_freq=2):
    counter = Counter()
    for example in dataset:
        tokens = simple_tokenizer(example["text"])
        counter.update(tokens)

    # Reserve ids 0 and 1 for the padding and unknown tokens.
    vocab = {"<pad>": 0, "<unk>": 1}
    idx = 2
    for word, freq in counter.items():
        if freq >= min_freq:
            vocab[word] = idx
            idx += 1
    return vocab


# ====== Collate Function for RNN/LSTM ======
def collate_fn_rnn(batch, vocab, max_length=32, partial_prob=0.0):
    texts = [item["text"] for item in batch]
    labels = [item["label"] for item in batch]

    all_input_ids = []
    for text in texts:
        tokens = simple_tokenizer(text)

        # 🔥 Randomly truncate the token sequence with some probability
        if random.random() < partial_prob and len(tokens) > 5:
            # Keep between 30% and 70% of the tokens
            cutoff = random.randint(int(len(tokens) * 0.3), int(len(tokens) * 0.7))
            tokens = tokens[:cutoff]

        ids = [vocab.get(token, vocab["<unk>"]) for token in tokens]

        # Pad or truncate to a fixed length
        if len(ids) < max_length:
            ids += [vocab["<pad>"]] * (max_length - len(ids))
        else:
            ids = ids[:max_length]

        all_input_ids.append(ids)

    input_ids = torch.tensor(all_input_ids)
    labels = torch.tensor(labels)
    return input_ids, labels


# ====== Collate Function for Transformer ======
def collate_fn_transformer(batch, tokenizer, max_length=128, partial_prob=0.5):
    texts = []
    labels = []
    for item in batch:
        text = item["text"]
        tokens = text.split()

        # 🔥 Random truncation
        if random.random() < partial_prob and len(tokens) > 5:
            cutoff = random.randint(int(len(tokens) * 0.3), int(len(tokens) * 0.7))
            tokens = tokens[:cutoff]
            text = " ".join(tokens)

        texts.append(text)
        labels.append(item["label"])

    encoding = tokenizer(
        texts,
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    encoding["labels"] = torch.tensor(labels)
    return encoding
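

# --- Usage sketch (not part of the original module) ---
# A minimal example of wiring these helpers into DataLoaders. The checkpoint name
# ("bert-base-uncased"), batch size, and the use of functools.partial are assumptions
# made for illustration; they are not taken from the original code.
if __name__ == "__main__":
    from functools import partial

    from torch.utils.data import DataLoader

    train_ds = load_emotion_dataset("train")
    train_ds, label_encoder = encode_labels(train_ds)

    # RNN/LSTM path: build a vocab, then bind it into the collate function.
    vocab = build_vocab(train_ds)
    rnn_loader = DataLoader(
        train_ds,
        batch_size=32,
        shuffle=True,
        collate_fn=partial(collate_fn_rnn, vocab=vocab, max_length=32),
    )

    # Transformer path: bind a pretrained tokenizer into the collate function.
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    transformer_loader = DataLoader(
        train_ds,
        batch_size=32,
        shuffle=True,
        collate_fn=partial(collate_fn_transformer, tokenizer=tokenizer, max_length=128),
    )

    input_ids, labels = next(iter(rnn_loader))
    print(input_ids.shape, labels.shape)  # e.g. torch.Size([32, 32]) torch.Size([32])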