import random
import re
from collections import Counter

import torch
from datasets import load_dataset
from sklearn.preprocessing import LabelEncoder
from transformers import AutoTokenizer

# ====== Dataset Loading ======
def load_emotion_dataset(split="train"):
    return load_dataset("dair-ai/emotion", split=split)

def encode_labels(dataset):
    # Labels in dair-ai/emotion are already integer class ids, so this mapping is
    # effectively a no-op here; the LabelEncoder is kept so the same helper also
    # works for datasets with string labels.
    le = LabelEncoder()
    all_labels = [example["label"] for example in dataset]
    le.fit(all_labels)
    dataset = dataset.map(lambda x: {"label": le.transform([x["label"]])[0]})
    return dataset, le

# ====== Tokenizer for RNN/LSTM ======
def simple_tokenizer(text):
    text = text.lower()
    text = re.sub(r"[^a-z0-9\s]", "", text)  # Remove special characters
    return text.split()

# ====== Vocab Builder for RNN/LSTM ======
def build_vocab(dataset, min_freq=2):
    counter = Counter()
    for example in dataset:
        tokens = simple_tokenizer(example["text"])
        counter.update(tokens)
    # Indices 0 and 1 are reserved for padding and out-of-vocabulary tokens
    vocab = {"<PAD>": 0, "<UNK>": 1}
    idx = 2
    for word, freq in counter.items():
        if freq >= min_freq:
            vocab[word] = idx
            idx += 1
    return vocab

# ====== Collate Function for RNN/LSTM ======
def collate_fn_rnn(batch, vocab, max_length=32, partial_prob=0.0):
    texts = [item["text"] for item in batch]
    labels = [item["label"] for item in batch]
    all_input_ids = []
    for text in texts:
        tokens = simple_tokenizer(text)
        # Randomly truncate tokens with probability `partial_prob`
        if random.random() < partial_prob and len(tokens) > 5:
            # Keep between 30% and 70% of the tokens
            cutoff = random.randint(int(len(tokens) * 0.3), int(len(tokens) * 0.7))
            tokens = tokens[:cutoff]
        ids = [vocab.get(token, vocab["<UNK>"]) for token in tokens]
        # Pad or truncate to a fixed length
        if len(ids) < max_length:
            ids += [vocab["<PAD>"]] * (max_length - len(ids))
        else:
            ids = ids[:max_length]
        all_input_ids.append(ids)
    input_ids = torch.tensor(all_input_ids)
    labels = torch.tensor(labels)
    return input_ids, labels
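
# A minimal usage sketch (an addition, not part of the original pipeline): binding
# vocab/max_length with functools.partial lets the collate function be handed
# directly to a DataLoader. The helper name `make_rnn_loader` and the default
# batch size are assumptions for illustration.
from functools import partial
from torch.utils.data import DataLoader

def make_rnn_loader(dataset, vocab, batch_size=32, max_length=32, partial_prob=0.0, shuffle=True):
    collate = partial(collate_fn_rnn, vocab=vocab, max_length=max_length, partial_prob=partial_prob)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate)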

# ====== Collate Function for Transformer ======
def collate_fn_transformer(batch, tokenizer, max_length=128, partial_prob=0.5):
    texts = []
    labels = []
    for item in batch:
        text = item["text"]
        tokens = text.split()
        # Randomly truncate on whitespace words before subword tokenization
        if random.random() < partial_prob and len(tokens) > 5:
            cutoff = random.randint(int(len(tokens) * 0.3), int(len(tokens) * 0.7))
            tokens = tokens[:cutoff]
            text = " ".join(tokens)
        texts.append(text)
        labels.append(item["label"])
    encoding = tokenizer(texts, padding="max_length", truncation=True, max_length=max_length, return_tensors="pt")
    encoding["labels"] = torch.tensor(labels)
    return encoding
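
# A minimal end-to-end sketch (an addition, not part of the original pipeline),
# reusing the DataLoader/partial imports from the sketch above: it loads the
# train split, builds the vocab, and constructs loaders for both model families.
# The checkpoint name "bert-base-uncased" and the helper name
# `make_transformer_loader` are assumptions for illustration.
def make_transformer_loader(dataset, tokenizer, batch_size=16, max_length=128, partial_prob=0.5, shuffle=True):
    collate = partial(collate_fn_transformer, tokenizer=tokenizer, max_length=max_length, partial_prob=partial_prob)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate)

if __name__ == "__main__":
    train_ds = load_emotion_dataset("train")
    train_ds, label_encoder = encode_labels(train_ds)
    vocab = build_vocab(train_ds)
    rnn_loader = make_rnn_loader(train_ds, vocab)
    hf_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # assumed checkpoint
    transformer_loader = make_transformer_loader(train_ds, hf_tokenizer)
    rnn_inputs, rnn_labels = next(iter(rnn_loader))
    print("RNN batch:", rnn_inputs.shape, rnn_labels.shape)
    batch = next(iter(transformer_loader))
    print("Transformer batch:", batch["input_ids"].shape, batch["labels"].shape)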