import pickle

import gradio as gr
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer

from models.rnn import RNNClassifier
from models.lstm import LSTMClassifier
from models.transformer import TransformerClassifier
from utility import simple_tokenizer

# =========================
# Load models and vocab
# =========================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_name = "prajjwal1/bert-tiny"


def load_vocab():
    with open("pretrained_models/vocab.pkl", "rb") as f:
        return pickle.load(f)


def load_models(vocab_size, output_dim=6, padding_idx=0):
    rnn_model = RNNClassifier(vocab_size, 128, 128, output_dim, padding_idx)
    rnn_model.load_state_dict(torch.load("pretrained_models/best_rnn.pt", map_location=device))
    rnn_model = rnn_model.to(device)
    rnn_model.eval()

    lstm_model = LSTMClassifier(vocab_size, 128, 128, output_dim, padding_idx)
    lstm_model.load_state_dict(torch.load("pretrained_models/best_lstm.pt", map_location=device))
    lstm_model = lstm_model.to(device)
    lstm_model.eval()

    transformer_model = TransformerClassifier(model_name, output_dim)
    transformer_model.load_state_dict(torch.load("pretrained_models/best_transformer.pt", map_location=device))
    transformer_model = transformer_model.to(device)
    transformer_model.eval()

    return rnn_model, lstm_model, transformer_model


vocab = load_vocab()
tokenizer = AutoTokenizer.from_pretrained(model_name)
rnn_model, lstm_model, transformer_model = load_models(len(vocab))

emotions = ["anger", "fear", "joy", "love", "sadness", "surprise"]


@torch.no_grad()
def predict(model, text, model_type, vocab, tokenizer=None, max_length=32):
    """Return a probability distribution over the six emotions for `text`."""
    if model_type in ["rnn", "lstm"]:
        # Match collate_fn_rnn but with no random truncation:
        # tokenize, map OOV tokens to <unk>, then pad/truncate to max_length.
        tokens = simple_tokenizer(text)
        ids = [vocab.get(token, vocab["<unk>"]) for token in tokens]
        if len(ids) < max_length:
            ids += [vocab["<pad>"]] * (max_length - len(ids))
        else:
            ids = ids[:max_length]
        input_ids = torch.tensor([ids], dtype=torch.long).to(device)
        outputs = model(input_ids)
    else:
        # Match collate_fn_transformer but with no partial_prob
        encoding = tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=128,
            return_tensors="pt",
        )
        input_ids = encoding["input_ids"].to(device)
        attention_mask = encoding["attention_mask"].to(device)
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)

    probs = F.softmax(outputs, dim=-1)
    return probs.squeeze(0).cpu().numpy()


# =========================
# Gradio App
# =========================
def emotion_typeahead(text):
    """Run all three models on the current textbox contents."""
    text = text.strip()
    if len(text) <= 2:
        # Too little context to predict from; clear all three labels.
        return {}, {}, {}

    rnn_probs = predict(rnn_model, text, "rnn", vocab)
    lstm_probs = predict(lstm_model, text, "lstm", vocab)
    transformer_probs = predict(transformer_model, text, "transformer", vocab, tokenizer)

    rnn_dict = {emo: float(prob) for emo, prob in zip(emotions, rnn_probs)}
    lstm_dict = {emo: float(prob) for emo, prob in zip(emotions, lstm_probs)}
    transformer_dict = {emo: float(prob) for emo, prob in zip(emotions, transformer_probs)}
    return rnn_dict, lstm_dict, transformer_dict


with gr.Blocks() as demo:
    gr.Markdown("## 🎯 Emotion Typeahead Predictor (RNN, LSTM, Transformer)")
    text_input = gr.Textbox(label="Type your sentence here...")
    with gr.Row():
        rnn_output = gr.Label(label="🧠 RNN Prediction")
        lstm_output = gr.Label(label="🧠 LSTM Prediction")
        transformer_output = gr.Label(label="🧠 Transformer Prediction")
    # Re-run predictions on every keystroke (typeahead behaviour).
    text_input.change(
        emotion_typeahead,
        inputs=text_input,
        outputs=[rnn_output, lstm_output, transformer_output],
    )

if __name__ == "__main__":
    demo.launch()
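# ---------------------------------------------------------------------
# Usage (a sketch; the script name is an assumption, not fixed by this
# file): run it directly and open the local URL Gradio prints, which
# defaults to http://127.0.0.1:7860.
#
#   python app.py
#
# For a quick smoke test without the UI, predict() can also be called
# directly once the models are loaded; the example text below is
# illustrative only:
#
#   probs = predict(lstm_model, "i feel happy today", "lstm", vocab)
#   print(dict(zip(emotions, probs)))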