# SentiNet/utility.py
import random
import re
from collections import Counter

import torch
from datasets import load_dataset
from sklearn.preprocessing import LabelEncoder
from transformers import AutoTokenizer

# ====== Dataset Loading ======
def load_emotion_dataset(split="train"):
    """Load a split of the dair-ai/emotion dataset from the Hugging Face Hub."""
    return load_dataset("dair-ai/emotion", split=split)


def encode_labels(dataset):
    """Fit a LabelEncoder on all labels and remap every example's label field."""
    le = LabelEncoder()
    all_labels = [example["label"] for example in dataset]
    le.fit(all_labels)
    dataset = dataset.map(lambda x: {"label": le.transform([x["label"]])[0]})
    return dataset, le
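
# Usage sketch (illustrative; runs only when this file is executed directly).
# dair-ai/emotion already stores labels as integers 0-5, so LabelEncoder is
# effectively an identity mapping here; it mainly provides a stable
# `le.classes_` ordering for downstream code.
if __name__ == "__main__":
    train_ds, label_encoder = encode_labels(load_emotion_dataset("train"))
    print(label_encoder.classes_)
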
# ====== Tokenizer for RNN/LSTM ======
def simple_tokenizer(text):
    """Lowercase the text, strip non-alphanumeric characters, split on whitespace."""
    text = text.lower()
    text = re.sub(r"[^a-z0-9\s]", "", text)  # remove special characters
    return text.split()
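
# Quick behavior check (a sketch, not a test suite): punctuation is deleted
# rather than split on, so "I'm happy!" tokenizes to ["im", "happy"].
if __name__ == "__main__":
    assert simple_tokenizer("I'm happy!") == ["im", "happy"]
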
# ====== Vocab Builder for RNN/LSTM ======
def build_vocab(dataset, min_freq=2):
    """Map each word seen at least min_freq times to an id; 0/1 are <PAD>/<UNK>."""
    counter = Counter()
    for example in dataset:
        tokens = simple_tokenizer(example["text"])
        counter.update(tokens)
    vocab = {"<PAD>": 0, "<UNK>": 1}
    idx = 2
    for word, freq in counter.items():
        if freq >= min_freq:
            vocab[word] = idx
            idx += 1
    return vocab
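
# Sketch: build the vocabulary from the train split. With the default
# min_freq=2, words that appear only once fall back to <UNK> at encoding time
# instead of receiving their own id.
if __name__ == "__main__":
    vocab = build_vocab(load_emotion_dataset("train"))
    print(len(vocab))  # vocabulary size, including <PAD> and <UNK>
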
# ====== Collate Function for RNN/LSTM ======
def collate_fn_rnn(batch, vocab, max_length=32, partial_prob=0.0):
    """Tokenize, optionally truncate, then pad or clip each text to max_length ids."""
    texts = [item["text"] for item in batch]
    labels = [item["label"] for item in batch]
    all_input_ids = []
    for text in texts:
        tokens = simple_tokenizer(text)
        # 🔥 Randomly truncate tokens with probability partial_prob
        if random.random() < partial_prob and len(tokens) > 5:
            # Keep between 30% and 70% of the tokens
            cutoff = random.randint(int(len(tokens) * 0.3), int(len(tokens) * 0.7))
            tokens = tokens[:cutoff]
        ids = [vocab.get(token, vocab["<UNK>"]) for token in tokens]
        if len(ids) < max_length:
            ids += [vocab["<PAD>"]] * (max_length - len(ids))  # pad to max_length
        else:
            ids = ids[:max_length]  # clip to max_length
        all_input_ids.append(ids)
    input_ids = torch.tensor(all_input_ids)
    labels = torch.tensor(labels)
    return input_ids, labels
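
# Minimal batching sketch for the RNN/LSTM path, using a tiny in-memory batch
# rather than the real dataset. functools.partial binds the extra collate
# arguments so torch's DataLoader can call the function with just the batch.
if __name__ == "__main__":
    from functools import partial
    from torch.utils.data import DataLoader

    toy_data = [
        {"text": "i feel happy today", "label": 0},
        {"text": "this is so sad", "label": 1},
    ]
    toy_vocab = build_vocab(toy_data, min_freq=1)
    loader = DataLoader(
        toy_data,
        batch_size=2,
        collate_fn=partial(collate_fn_rnn, vocab=toy_vocab, max_length=8),
    )
    input_ids, labels = next(iter(loader))
    print(input_ids.shape, labels)  # torch.Size([2, 8]) tensor([0, 1])
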
# ====== Collate Function for Transformer ======
def collate_fn_transformer(batch, tokenizer, max_length=128, partial_prob=0.5):
    """Optionally truncate raw texts, then batch-encode them with a HF tokenizer."""
    texts = []
    labels = []
    for item in batch:
        text = item["text"]
        tokens = text.split()
        # 🔥 Random truncation, mirroring the RNN collate function
        if random.random() < partial_prob and len(tokens) > 5:
            cutoff = random.randint(int(len(tokens) * 0.3), int(len(tokens) * 0.7))
            tokens = tokens[:cutoff]
            text = " ".join(tokens)
        texts.append(text)
        labels.append(item["label"])
    encoding = tokenizer(
        texts,
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    encoding["labels"] = torch.tensor(labels)
    return encoding
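
# Sketch of the transformer path. "bert-base-uncased" is an illustrative
# checkpoint choice, not one mandated by this file; any AutoTokenizer-
# compatible model name would do. Note that truncation only fires for texts
# longer than 5 tokens, so this toy batch passes through unchanged.
if __name__ == "__main__":
    from functools import partial
    from torch.utils.data import DataLoader

    hf_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    toy_data = [
        {"text": "i feel happy today", "label": 0},
        {"text": "this is so sad", "label": 1},
    ]
    loader = DataLoader(
        toy_data,
        batch_size=2,
        collate_fn=partial(collate_fn_transformer, tokenizer=hf_tokenizer, max_length=16),
    )
    encoding = next(iter(loader))
    print(encoding["input_ids"].shape, encoding["labels"])  # torch.Size([2, 16]) tensor([0, 1])
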