import pickle

import gradio as gr
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer

from models.rnn import RNNClassifier
from models.lstm import LSTMClassifier
from models.transformer import TransformerClassifier
from utility import simple_tokenizer
# =========================
# Load models and vocab
# =========================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_name = "prajjwal1/bert-tiny"


def load_vocab():
    with open("pretrained_models/vocab.pkl", "rb") as f:
        return pickle.load(f)
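
# NOTE (assumption): the pickled vocab is expected to be a dict mapping
# token -> integer id, with at least the special entries "<PAD>" and
# "<UNK>", since predict() below indexes it that way.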


def load_models(vocab_size, output_dim=6, padding_idx=0):
    rnn_model = RNNClassifier(vocab_size, 128, 128, output_dim, padding_idx)
    rnn_model.load_state_dict(torch.load("pretrained_models/best_rnn.pt", map_location=device))
    rnn_model = rnn_model.to(device)
    rnn_model.eval()

    lstm_model = LSTMClassifier(vocab_size, 128, 128, output_dim, padding_idx)
    lstm_model.load_state_dict(torch.load("pretrained_models/best_lstm.pt", map_location=device))
    lstm_model = lstm_model.to(device)
    lstm_model.eval()

    transformer_model = TransformerClassifier(model_name, output_dim)
    transformer_model.load_state_dict(torch.load("pretrained_models/best_transformer.pt", map_location=device))
    transformer_model = transformer_model.to(device)
    transformer_model.eval()

    return rnn_model, lstm_model, transformer_model

vocab = load_vocab()
tokenizer = AutoTokenizer.from_pretrained(model_name)
rnn_model, lstm_model, transformer_model = load_models(len(vocab))

emotions = ["anger", "fear", "joy", "love", "sadness", "surprise"]
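# NOTE (assumption): this label order is taken to match the integer class
# ids the three checkpoints were trained with; if the training label
# mapping differs, the probabilities below attach to the wrong names.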


@torch.no_grad()  # inference only: skip autograd tracking on every keystroke
def predict(model, text, model_type, vocab, tokenizer=None, max_length=32):
    if model_type in ["rnn", "lstm"]:
        # Match collate_fn_rnn but with no random truncation
        tokens = simple_tokenizer(text)
        ids = [vocab.get(token, vocab["<UNK>"]) for token in tokens]
        if len(ids) < max_length:
            ids += [vocab["<PAD>"]] * (max_length - len(ids))
        else:
            ids = ids[:max_length]
        input_ids = torch.tensor([ids], dtype=torch.long).to(device)
        outputs = model(input_ids)
    else:
        # Match collate_fn_transformer but with no partial_prob
        encoding = tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=128,
            return_tensors="pt",
        )
        input_ids = encoding["input_ids"].to(device)
        attention_mask = encoding["attention_mask"].to(device)
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    probs = F.softmax(outputs, dim=-1)
    return probs.squeeze().detach().cpu().numpy()
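
# Example (sketch): a standalone call, assuming the models and vocab
# loaded above:
#   probs = predict(rnn_model, "i feel great today", "rnn", vocab)
#   print(dict(zip(emotions, probs)))
# For the transformer model, pass the HF tokenizer as well:
#   probs = predict(transformer_model, "i feel great today", "transformer", vocab, tokenizer)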

# =========================
# Gradio App
# =========================
def emotion_typeahead(text):
    if len(text.strip()) <= 2:
        return {}, {}, {}
    rnn_probs = predict(rnn_model, text.strip(), "rnn", vocab)
    lstm_probs = predict(lstm_model, text.strip(), "lstm", vocab)
    transformer_probs = predict(transformer_model, text.strip(), "transformer", vocab, tokenizer)
    rnn_dict = {emo: float(prob) for emo, prob in zip(emotions, rnn_probs)}
    lstm_dict = {emo: float(prob) for emo, prob in zip(emotions, lstm_probs)}
    transformer_dict = {emo: float(prob) for emo, prob in zip(emotions, transformer_probs)}
    return rnn_dict, lstm_dict, transformer_dict


with gr.Blocks() as demo:
    gr.Markdown("## 🎯 Emotion Typeahead Predictor (RNN, LSTM, Transformer)")
    text_input = gr.Textbox(label="Type your sentence here...")
    with gr.Row():
        rnn_output = gr.Label(label="🧠 RNN Prediction")
        lstm_output = gr.Label(label="🧠 LSTM Prediction")
        transformer_output = gr.Label(label="🧠 Transformer Prediction")
    text_input.change(emotion_typeahead, inputs=text_input, outputs=[rnn_output, lstm_output, transformer_output])

demo.launch()
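
# To run locally (assuming this file is saved as app.py and the
# pretrained_models/ checkpoints plus the local models/ and utility
# modules are present):
#   pip install gradio torch transformers
#   python app.py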