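# Streamlit app for the "Smart Review Analysis Assistant": given a customer
# review, it runs zero-shot topic classification, sentiment analysis, and
# generates a suggested customer-service reply.
# A minimal way to run it locally (assuming this file is saved as app.py and
# that streamlit, transformers, and torch are installed):
#     streamlit run app.py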
import streamlit as st
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
import torch

# Set page configuration for the app
st.set_page_config(
    page_title="Review Assistant",
    page_icon="📝",
    layout="centered"
)
# Custom CSS styling for the page
st.markdown("""
<style>
.main-header {
    font-size: 2.2rem;
    color: #1E3A8A;
    text-align: center;
    margin-bottom: 0;
    padding-bottom: 0;
}
.sub-header {
    font-size: 1rem;
    color: #6B7280;
    text-align: center;
    margin-top: 0.3rem;
    margin-bottom: 2rem;
}
.result-card {
    padding: 1.2rem;
    border-radius: 8px;
    margin-bottom: 1rem;
}
.topic-card {
    background-color: #ECFDF5;
    border-left: 4px solid #10B981;
}
.sentiment-card {
    background-color: #EFF6FF;
    border-left: 4px solid #3B82F6;
}
.reply-card {
    background-color: #F9FAFB;
    border-left: 4px solid #6B7280;
    padding: 1.5rem;
}
.result-label {
    font-weight: bold;
    margin-bottom: 0.5rem;
}
/* Changed button color from blue (#2563EB) to green (#10B981) */
.stButton>button {
    background-color: #10B981;
    color: white;
    border: none;
    padding: 0.5rem 2rem;
    border-radius: 6px;
    font-weight: 500;
}
/* Changed hover color to darker green (#059669) */
.stButton>button:hover {
    background-color: #059669;
}
.footer {
    text-align: center;
    color: #9CA3AF;
    font-size: 0.8rem;
    margin-top: 3rem;
}
.stSpinner {
    text-align: center;
}
</style>
""", unsafe_allow_html=True)
# Page headers
st.markdown("<h1 class='main-header'>Smart Review Analysis Assistant</h1>", unsafe_allow_html=True)
st.markdown("<p class='sub-header'>Topic Recognition, Sentiment Analysis, and Auto Reply in One Click</p>", unsafe_allow_html=True)
# ------- Load ML Pipelines -------
def load_pipelines():
    """
    Load all required machine learning models for review analysis:
    - Topic classifier using zero-shot classification
    - Sentiment analysis model
    - Customer reply generation model
    Returns the models and topic labels for use in the application.
    """
    # Define topic categories for classification
    topic_labels = [
        "billing", "account access", "customer service", "loans",
        "fraud", "technical issue", "credit card", "mobile app",
        "branch service", "transaction delay", "account closure", "information error"
    ]
    # Ensure compatibility with CPU-only environments
    dtype = torch.float32  # Use float32 for better CPU compatibility

    # Load topic classification model
    topic_classifier = pipeline(
        "zero-shot-classification",
        model="MoritzLaurer/deberta-v3-base-zeroshot-v1",
        torch_dtype=dtype,
    )
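    # The zero-shot pipeline scores the review against each of the candidate
    # topic labels and returns them ranked by score, so no fine-tuning on this
    # specific topic list is required.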
    # Load sentiment analysis model
    sentiment_classifier = pipeline(
        "sentiment-analysis",
        model="cardiffnlp/twitter-roberta-base-sentiment-latest",
        torch_dtype=dtype,
    )
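    # This model labels text as negative, neutral, or positive; the raw label
    # string is displayed to the user as-is in the sentiment card below.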
    # Load reply generation model
    model_name = "Leo66277/finetuned-tinyllama-customer-replies"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=dtype)

    def generate_reply(text):
        """Generate a polite customer service reply based on the given review text."""
        prompt_text = f"Please write a short, polite English customer service reply to the following customer comment:\n{text}"
        inputs = tokenizer(prompt_text, return_tensors="pt", truncation=True, max_length=512)
        # Generate a response with beam search for better quality
        with torch.no_grad():
            gen_ids = model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_length=inputs.input_ids.shape[1] + 120,
                pad_token_id=tokenizer.eos_token_id if tokenizer.eos_token_id else tokenizer.pad_token_id,
                num_beams=3,
                no_repeat_ngram_size=2,
                early_stopping=True
            )
        # Decode only the newly generated tokens and clean up the text
        reply = tokenizer.decode(gen_ids[0][inputs.input_ids.shape[1]:], skip_special_tokens=True).strip()
        reply = reply.strip('"').replace('\n', ' ').replace('  ', ' ')
        return reply

    return topic_classifier, sentiment_classifier, generate_reply, topic_labels
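
# Example of calling load_pipelines() directly (a sketch; assumes the model
# downloads above succeed and enough memory is available):
#     topic_pipe, sentiment_pipe, reply_fn, labels = load_pipelines()
#     topic_pipe("The mobile app keeps crashing", labels)  # labels ranked by score
#     sentiment_pipe("The mobile app keeps crashing")      # [{'label': ..., 'score': ...}]
#     reply_fn("The mobile app keeps crashing")            # suggested reply string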
# ------- Page Layout --------
st.markdown("### Enter a review for instant analysis")

# Example review shown by default in the text area
example_review = "BOA states on their website that holds are 2-7 days. I made a deposit, and the receipt states funds would be available in 2 days. Now, 13 days later, I am still waiting on my funds, and BOA can't give me a straight answer."

# Text input area for user reviews
user_review = st.text_area(
    "Please enter or paste a review below:",
    value=example_review,
    height=120
)
# Button to trigger analysis
if st.button("Analyze"):
    if not user_review.strip():
        # Validate user input
        st.warning("Please enter a valid review!")
    else:
        # Show a user-friendly loading message
        with st.spinner("Analyzing your review..."):
            # Load the models on first use so nothing heavy happens at startup;
            # they are cached in session state for later runs
            if "topic_pipe" not in st.session_state:
                st.session_state.topic_pipe, st.session_state.sentiment_pipe, st.session_state.reply_generator, st.session_state.topic_labels = load_pipelines()

            # Run topic classification (only the top-ranked label is shown; the
            # confidence score is intentionally not displayed)
            topic_result = st.session_state.topic_pipe(user_review, st.session_state.topic_labels, multi_label=False)
            topic = topic_result['labels'][0]

            # Run sentiment analysis (label only; the confidence score is
            # intentionally not displayed)
            sentiment_result = st.session_state.sentiment_pipe(user_review)
            sentiment = sentiment_result[0]['label']

            # Generate auto reply
            reply_text = st.session_state.reply_generator(user_review)
            # Display results in a visually appealing format
            col1, col2 = st.columns(2)
            with col1:
                st.markdown(f"<div class='result-card topic-card'><p class='result-label'>Topic:</p>{topic}</div>", unsafe_allow_html=True)
            with col2:
                st.markdown(f"<div class='result-card sentiment-card'><p class='result-label'>Sentiment:</p>{sentiment}</div>", unsafe_allow_html=True)

            # Display suggested reply
            st.markdown(f"<div class='result-card reply-card'><p class='result-label'>Auto-reply Suggestion:</p>{reply_text}</div>", unsafe_allow_html=True)

# Page footer
st.markdown("<div class='footer'>© 2024 Review AI Assistant</div>", unsafe_allow_html=True)