File size: 2,440 Bytes
3b6b8a6
c3c8834
 
84db39a
d17306e
01dbd9a
 
 
 
5522dd0
9d5060f
 
01dbd9a
5522dd0
 
65ce341
01dbd9a
 
 
 
 
 
 
62f1d54
15bcfa5
01dbd9a
9d5060f
01dbd9a
 
 
 
 
 
 
9d5060f
 
c3c8834
01dbd9a
8bbd5d4
01dbd9a
 
c3c8834
01dbd9a
 
 
c3c8834
01dbd9a
9d5060f
01dbd9a
 
 
 
9d5060f
01dbd9a
 
9d5060f
01dbd9a
 
 
 
 
 
 
 
 
 
 
 
15bcfa5
01dbd9a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import html
import os

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Directory holding the fine-tuned checkpoint (tokenizer files + model weights)
# inside the Space. Resolve it relative to this script instead of the process's
# current working directory, so the app still finds the model when it is
# launched from a different directory.
model_dir = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "campaign_bert_model",
    "campaign_bert_model",
    "campaign-bert-model",
)

# Load tokenizer and model once at import time; every request reuses them.
try:
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForSequenceClassification.from_pretrained(model_dir)
    # Put the module in evaluation mode (disables training-only layers such
    # as dropout) because it is only used for inference here.
    model.eval()
except Exception as e:
    # Fail fast: the UI cannot work without a model. Chain the original
    # exception so its full traceback is preserved for debugging.
    raise RuntimeError(f"❌ Failed to load model or tokenizer: {e}") from e

# Lookup table: predicted class id -> (tone label, suggested campaign template).
# Each entry's position in the list below becomes its class id (0..4).
# NOTE(review): this order presumably mirrors the label ids used when the model
# was fine-tuned — confirm against the training config / model.config.id2label.
class_map = dict(
    enumerate(
        [
            ("Informative", "Here are plan details tailored for your interest."),
            ("Excited", "Great news! You’re eligible for our premium plans!"),
            ("Neutral", "Explore various insurance options with us."),
            ("Persuasive", "Take the first step to secure your future today."),
            ("Empathetic", "We understand your needs—here’s how we can help."),
        ]
    )
)

def predict(text):
    """Classify a client message and suggest a matching campaign template.

    Args:
        text: Message typed into the Gradio textbox. ``None`` or a
            whitespace-only string is treated as empty input.

    Returns:
        A 2-tuple of HTML strings ``(prediction_html, confidence_html)`` that
        feed the two ``gr.HTML`` output panes. On empty input, on an
        unexpected model output shape, or on any other failure, the first
        element carries a red warning/error and the second is ``""``.
    """
    try:
        # A cleared textbox may arrive as None; report it as empty input
        # rather than failing on ``None.strip()``.
        if text is None or not text.strip():
            return "<h3 style='color:red'>⚠️ Please enter a message.</h3>", ""

        # Tokenize the single message into PyTorch tensors; overly long
        # inputs are truncated to the model's maximum length.
        inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)

        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            outputs = model(**inputs)

        logits = outputs.logits

        # Guard against a checkpoint whose classification head does not match
        # class_map. Check None separately so the message never dereferences it,
        # and check rank before indexing shape[1].
        if logits is None:
            return "<h3 style='color:red'>❌ Invalid model output shape: None</h3>", ""
        if logits.dim() != 2 or logits.shape[1] != len(class_map):
            return f"<h3 style='color:red'>❌ Invalid model output shape: {logits.shape}</h3>", ""

        probs = torch.softmax(logits, dim=1)
        # Batch size is 1 here, so .item() yields the single predicted id.
        pred_class = torch.argmax(probs, dim=1).item()
        confidence = probs[0, pred_class].item()

        tone, template = class_map[pred_class]

        return (
            f"<h3 style='color:green'>Tone: {tone}</h3><p>📨 Suggested Campaign Message:<br><b>{template}</b></p>",
            f"<p>Confidence: <b>{confidence:.2%}</b></p>",
        )

    except Exception as e:
        # UI boundary: show the failure in the page instead of crashing the
        # request. The message is rendered as HTML, so escape it — it may echo
        # fragments of the user's input.
        return f"<h3 style='color:red'>Error: {html.escape(str(e))}</h3>", ""

# Gradio UI: one free-text input feeding `predict`, whose two returned HTML
# strings populate the "Prediction" and "Confidence" panes respectively.
_message_box = gr.Textbox(
    label="Client Message",
    placeholder="E.g. I want to know about child education plans",
)
_result_panes = [
    gr.HTML(label="Prediction"),
    gr.HTML(label="Confidence"),
]

iface = gr.Interface(
    fn=predict,
    inputs=_message_box,
    outputs=_result_panes,
    title="📢 Campaign Personalizer",
    description="Predicts message tone and template using a fine-tuned BERT model with 5 classes.",
    allow_flagging="never",
)

# Start the web server (runs at import time, as in the original script).
iface.launch()