File size: 2,440 Bytes
3b6b8a6 c3c8834 84db39a d17306e 01dbd9a 5522dd0 9d5060f 01dbd9a 5522dd0 65ce341 01dbd9a 62f1d54 15bcfa5 01dbd9a 9d5060f 01dbd9a 9d5060f c3c8834 01dbd9a 8bbd5d4 01dbd9a c3c8834 01dbd9a c3c8834 01dbd9a 9d5060f 01dbd9a 9d5060f 01dbd9a 9d5060f 01dbd9a 15bcfa5 01dbd9a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 |
import html
import os

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Path to the saved model directory inside the Space (relative to the
# working directory the app is launched from).
model_dir = "./campaign_bert_model/campaign_bert_model/campaign-bert-model"

# Fail fast with a clear message if the directory is missing; otherwise
# from_pretrained may fall back to treating the string as a Hub repo id and
# report a much less obvious error.
if not os.path.isdir(model_dir):
    raise RuntimeError(f"❌ Model directory not found: {model_dir}")

# Load tokenizer and model once at import time; predict() reuses these
# module-level objects for every request.
try:
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForSequenceClassification.from_pretrained(model_dir)
    model.eval()  # evaluation mode (e.g. disables dropout) for inference
except Exception as e:
    # Chain the original exception so its full traceback is preserved.
    raise RuntimeError(f"❌ Failed to load model or tokenizer: {e}") from e
# Model output index -> (tone label, suggested campaign template).
# The list position is the class id; presumably this order matches the
# label ids the model was fine-tuned with — confirm against the model config.
_TONE_TEMPLATES = [
    ("Informative", "Here are plan details tailored for your interest."),
    ("Excited", "Great news! You’re eligible for our premium plans!"),
    ("Neutral", "Explore various insurance options with us."),
    ("Persuasive", "Take the first step to secure your future today."),
    ("Empathetic", "We understand your needs—here’s how we can help."),
]
class_map = dict(enumerate(_TONE_TEMPLATES))
def predict(text):
    """Classify a client message and suggest a campaign reply.

    Uses the module-level ``tokenizer``, ``model`` and ``class_map``.

    Args:
        text: The client message from the Gradio textbox. ``None`` or
            whitespace-only input is treated as empty.

    Returns:
        A ``(prediction_html, confidence_html)`` pair of HTML strings for the
        two ``gr.HTML`` outputs. On empty input, an unexpected model output,
        or any exception, the first element holds a red warning/error and
        the second is ``""``.
    """
    try:
        # Guard against None as well as blank strings (None.strip() raises).
        if not text or not text.strip():
            return "<h3 style='color:red'>⚠️ Please enter a message.</h3>", ""
        # Tokenize input as a single-example batch of PyTorch tensors.
        inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
        # Run model inference without tracking gradients.
        with torch.no_grad():
            outputs = model(**inputs)
        logits = outputs.logits
        # Expect a 2-D (batch, num_classes) tensor whose class count matches
        # class_map; check None first so the message never touches .shape.
        if logits is None:
            return "<h3 style='color:red'>❌ Model returned no logits</h3>", ""
        if logits.ndim != 2 or logits.shape[1] != len(class_map):
            return f"<h3 style='color:red'>❌ Invalid model output shape: {logits.shape}</h3>", ""
        probs = torch.softmax(logits, dim=1)
        # .item() requires exactly one element, i.e. a batch of size 1.
        pred_class = torch.argmax(probs, dim=1).item()
        confidence = probs[0][pred_class].item()
        tone, template = class_map[pred_class]
        return (
            f"<h3 style='color:green'>Tone: {tone}</h3><p>📨 Suggested Campaign Message:<br><b>{template}</b></p>",
            f"<p>Confidence: <b>{confidence:.2%}</b></p>",
        )
    except Exception as e:
        # UI boundary: show the failure in the page instead of crashing the
        # app. Escape the message because it is injected into raw HTML.
        return f"<h3 style='color:red'>Error: {html.escape(str(e))}</h3>", ""
# Gradio UI: one text input feeding predict(), and two HTML panels that
# receive predict()'s (prediction_html, confidence_html) return pair.
client_message_input = gr.Textbox(
    label="Client Message",
    placeholder="E.g. I want to know about child education plans",
)
result_panels = [
    gr.HTML(label="Prediction"),
    gr.HTML(label="Confidence"),
]
iface = gr.Interface(
    fn=predict,
    inputs=client_message_input,
    outputs=result_panels,
    title="📢 Campaign Personalizer",
    description="Predicts message tone and template using a fine-tuned BERT model with 5 classes.",
    allow_flagging="never",
)
# Start the Gradio web app.
iface.launch()
|