import streamlit as st import os import torch import torch import torch.nn as nn import pandas as pd import numpy as np import plotly.express as px import plotly.graph_objects as go from datetime import datetime from transformers import AutoTokenizer, AutoModel from preprocessing import clean_doc import warnings warnings.filterwarnings('ignore') st.set_page_config( page_title="ViSoBERT Emotion Recognition", page_icon="😊", layout="wide", initial_sidebar_state="expanded" ) # CSS st.markdown(""" """, unsafe_allow_html=True) # ==================== MODEL DEFINITIONS ==================== class ViSoBERTEmotionClassifier(nn.Module): def __init__(self, model_name, num_classes=7, dropout_rate=0.3): super(ViSoBERTEmotionClassifier, self).__init__() # Load ViSoBERT model self.visobert = AutoModel.from_pretrained(model_name) # Classifier layers self.dropout = nn.Dropout(dropout_rate) self.classifier = nn.Sequential( nn.Linear(self.visobert.config.hidden_size, 512), nn.ReLU(), nn.Dropout(dropout_rate), nn.Linear(512, num_classes) ) def forward(self, input_ids, attention_mask): outputs = self.visobert( input_ids=input_ids, attention_mask=attention_mask ) pooled_output = outputs.pooler_output output = self.dropout(pooled_output) logits = self.classifier(output) return logits # ==================== CONSTANTS ==================== emotion_labels = { 0: "Vui vẻ", 1: "Buồn bã", 2: "Tức giận", 3: "Sợ hãi", 4: "Ngạc nhiên", 5: "Kinh tởm", 6: "Khác" } emotion_colors = { "Vui vẻ": "#FFD700", "Buồn bã": "#4169E1", "Tức giận": "#DC143C", "Sợ hãi": "#800080", "Ngạc nhiên": "#FF8C00", "Kinh tởm": "#228B22", "Khác": "#808080" } emotion_emojis = { "Vui vẻ": "😊", "Buồn bã": "😢", "Tức giận": "😠", "Sợ hãi": "😨", "Ngạc nhiên": "😲", "Kinh tởm": "🤢", "Khác": "😐" } # ==================== CACHING FUNCTIONS ==================== @st.cache_resource def load_model_and_tokenizer(): """Load model và tokenizer với caching""" try: device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model_name = "uitnlp/visobert" # Load tokenizer tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False) # Load model model = ViSoBERTEmotionClassifier(model_name, num_classes=7) # Load trained weights model_path = 'best_visobert_emotion_model.pth' checkpoint = torch.load(model_path, map_location=device) model.load_state_dict(checkpoint) model.to(device) model.eval() return model, tokenizer, device, checkpoint except Exception as e: st.error(f"Lỗi khi load model: {e}") return None, None, None, None # ==================== PREDICTION FUNCTIONS ==================== def predict_emotion(model, tokenizer, device, text, max_length=256): """Dự đoán cảm xúc cho văn bản""" try: # Tiền xử lý văn bản text = clean_doc(text) # Tokenization encoding = tokenizer( text, truncation=True, padding='max_length', max_length=max_length, return_tensors='pt' ) input_ids = encoding['input_ids'].to(device) attention_mask = encoding['attention_mask'].to(device) with torch.no_grad(): logits = model(input_ids, attention_mask) probabilities = torch.softmax(logits, dim=1) predicted_class = torch.argmax(logits, dim=1).item() confidence = probabilities[0][predicted_class].item() return { 'emotion': emotion_labels[predicted_class], 'confidence': confidence, 'probabilities': {emotion_labels[i]: prob.item() for i, prob in enumerate(probabilities[0])} } except Exception as e: st.error(f"Lỗi khi dự đoán: {e}") return None def create_probability_chart(probabilities): """Tạo biểu đồ xác suất cho các cảm xúc""" emotions = list(probabilities.keys()) probs = list(probabilities.values()) colors = [emotion_colors[emotion] for emotion in emotions] fig = px.bar( x=emotions, y=probs, color=emotions, color_discrete_map=emotion_colors, title="Phân bố xác suất các cảm xúc", labels={'x': 'Cảm xúc', 'y': 'Xác suất'} ) fig.update_layout( showlegend=False, height=400, xaxis_tickangle=-45 ) return fig # ==================== MAIN APP ==================== def main(): # Header st.markdown("""

🤖 ViSoBERT Emotion Recognition

Phân tích cảm xúc văn bản tiếng Việt với mô hình ViSoBERT

""", unsafe_allow_html=True) # Load model with st.spinner("Đang tải model..."): model, tokenizer, device, checkpoint = load_model_and_tokenizer() if model is None: st.error("Không thể tải model. Vui lòng kiểm tra đường dẫn file model.") return # Sidebar - Model Info with st.sidebar: st.markdown("### 📊 Thông tin Model") if checkpoint: st.markdown(f""" """, unsafe_allow_html=True) st.markdown("### 🎯 Các loại cảm xúc") for emotion, emoji in emotion_emojis.items(): st.markdown(f"{emoji} **{emotion}**") # Main content col1, col2 = st.columns([2, 1]) with col1: st.markdown("### 📝 Nhập văn bản để phân tích") # Text input methods input_method = st.radio( "Chọn cách nhập:", ["Nhập trực tiếp", "Upload file (CSV/Excel)"] ) text_input = "" batch_analysis = False df_to_analyze = None selected_column = None if input_method == "Nhập trực tiếp": text_input = st.text_area( "Văn bản:", height=150, placeholder="Ví dụ: Hôm nay tôi rất vui vì được gặp bạn bè...", help="Nhập văn bản tiếng Việt để phân tích cảm xúc" ) else: st.markdown("#### 📂 Upload file dữ liệu") uploaded_file = st.file_uploader( "Chọn file CSV hoặc Excel", type=["csv", "xlsx", "xls"], help="File phải chứa cột 'text' hoặc tương tự với nội dung văn bản" ) if uploaded_file is not None: try: # Đọc file if uploaded_file.name.endswith(".csv"): df = pd.read_csv(uploaded_file) else: df = pd.read_excel(uploaded_file) st.success(f"✅ Đã tải thành công file với {len(df)} dòng dữ liệu") # Hiển thị preview with st.expander("👀 Xem trước dữ liệu", expanded=False): st.dataframe(df.head(10)) # Chọn cột chứa text text_columns = [col for col in df.columns if df[col].dtype == 'object'] if text_columns: selected_column = st.selectbox( "Chọn cột chứa văn bản:", text_columns, help="Chọn cột chứa nội dung văn bản cần phân tích" ) # Chọn chế độ phân tích analysis_mode = st.radio( "Chọn chế độ phân tích:", ["Phân tích từng câu", "Phân tích toàn bộ file"], help="Chọn phân tích từng câu hoặc phân tích tất cả dữ liệu trong file" ) if selected_column in df.columns: # Loại bỏ các dòng trống valid_texts = df[selected_column].dropna() if len(valid_texts) > 0: if analysis_mode == "Phân tích từng câu": # Chế độ phân tích đơn lẻ selected_index = st.selectbox( "Chọn câu để phân tích:", range(len(valid_texts)), format_func=lambda x: f"Dòng {x+1}: {str(valid_texts.iloc[x])[:100]}{'...' if len(str(valid_texts.iloc[x])) > 100 else ''}" ) selected_text = str(valid_texts.iloc[selected_index]) text_input = st.text_area( "Văn bản được chọn (có thể chỉnh sửa):", value=selected_text, height=100 ) else: # Chế độ phân tích batch batch_analysis = True df_to_analyze = df[df[selected_column].notna()].copy() st.info(f"🔄 Sẽ phân tích {len(df_to_analyze)} văn bản trong file") # Hiển thị sample with st.expander("📝 Xem mẫu dữ liệu sẽ phân tích"): sample_df = df_to_analyze[[selected_column]].head(5) st.dataframe(sample_df) else: st.warning("⚠️ Không tìm thấy dữ liệu văn bản hợp lệ trong cột đã chọn.") else: st.error("❌ File không chứa cột văn bản. Vui lòng kiểm tra lại định dạng file.") except Exception as e: st.error(f"❌ Lỗi khi đọc file: {str(e)}") st.info("💡 Đảm bảo file có định dạng đúng và chứa dữ liệu văn bản.") else: st.info("📤 Vui lòng chọn file để tải lên") # Predict button st.markdown("---") predict_button = st.button( "🎯 Phân tích cảm xúc", type="primary", use_container_width=True, disabled=(not text_input.strip() and not batch_analysis) ) # Results section if predict_button and (text_input.strip() or batch_analysis): if batch_analysis and df_to_analyze is not None: # Batch analysis st.markdown("### 🔄 Phân tích toàn bộ file") # Progress bar progress_bar = st.progress(0) status_text = st.empty() results = [] total_texts = len(df_to_analyze) for idx, row in df_to_analyze.iterrows(): text = str(row[selected_column]) if text.strip(): status_text.text(f'Đang phân tích văn bản {len(results)+1}/{total_texts}...') result = predict_emotion(model, tokenizer, device, text) if result: results.append({ 'index': idx, 'text': text, 'emotion': result['emotion'], 'confidence': result['confidence'], 'probabilities': result['probabilities'] }) progress_bar.progress((len(results)) / total_texts) status_text.text('✅ Hoàn thành phân tích!') progress_bar.progress(1.0) if results: # Create results dataframe results_df = pd.DataFrame([ { 'STT': i+1, 'Văn bản': r['text'][:100] + '...' if len(r['text']) > 100 else r['text'], 'Cảm xúc': emotion_emojis[r['emotion']] + ' ' + r['emotion'], 'Độ tin cậy': f"{r['confidence']:.2%}" } for i, r in enumerate(results) ]) # Display results table st.markdown("#### 📊 Kết quả phân tích") st.dataframe(results_df, use_container_width=True, hide_index=True) # Statistics emotion_counts = {} for r in results: emotion = r['emotion'] emotion_counts[emotion] = emotion_counts.get(emotion, 0) + 1 # Emotion distribution chart st.markdown("#### 📈 Phân bố cảm xúc") col_chart1, col_chart2 = st.columns(2) with col_chart1: # Pie chart fig_pie = px.pie( values=list(emotion_counts.values()), names=[emotion_emojis[e] + ' ' + e for e in emotion_counts.keys()], title="Tỷ lệ các cảm xúc", color=names, color_discrete_map={emotion_emojis[e] + ' ' + e: emotion_colors[e] for e in emotion_counts.keys()} ) st.plotly_chart(fig_pie, use_container_width=True) with col_chart2: # Bar chart emotions_for_bar = [emotion_emojis[e] + ' ' + e for e in emotion_counts.keys()] fig_bar = px.bar( x=emotions_for_bar, y=list(emotion_counts.values()), title="Số lượng theo cảm xúc", color=emotions_for_bar, color_discrete_map={emotion_emojis[e] + ' ' + e: emotion_colors[e] for e in emotion_counts.keys()} ) fig_bar.update_layout(showlegend=False, xaxis_tickangle=-45) st.plotly_chart(fig_bar, use_container_width=True) # Summary statistics st.markdown("#### 📝 Thống kê tổng quan") col_stat1, col_stat2, col_stat3, col_stat4 = st.columns(4) with col_stat1: st.metric("Tổng số văn bản", len(results)) with col_stat2: avg_confidence = np.mean([r['confidence'] for r in results]) st.metric("Độ tin cậy TB", f"{avg_confidence:.2%}") with col_stat3: most_common = max(emotion_counts.items(), key=lambda x: x[1]) st.metric("Cảm xúc phổ biến nhất", f"{emotion_emojis[most_common[0]]} {most_common[0]}") with col_stat4: high_confidence = len([r for r in results if r['confidence'] > 0.7]) st.metric("Dự đoán tin cậy cao", f"{high_confidence}/{len(results)}") # Download results st.markdown("#### 💾 Tải kết quả") # Prepare detailed results for download detailed_results = [] for r in results: row = { 'text': r['text'], 'predicted_emotion': r['emotion'], 'confidence': r['confidence'] } # Add probability for each emotion for emotion, prob in r['probabilities'].items(): row[f'prob_{emotion}'] = prob detailed_results.append(row) detailed_df = pd.DataFrame(detailed_results) # Convert to CSV csv = detailed_df.to_csv(index=False) st.download_button( label="📥 Tải kết quả (CSV)", data=csv, file_name=f"emotion_analysis_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv", mime="text/csv" ) # Store batch results for sidebar st.session_state.batch_results = results elif text_input.strip(): # Single text analysis with st.spinner("🔄 Đang phân tích cảm xúc..."): result = predict_emotion(model, tokenizer, device, text_input) if result: # Main prediction result emotion = result['emotion'] confidence = result['confidence'] emoji = emotion_emojis[emotion] st.markdown(f"""

{emoji} {emotion}

Độ tin cậy: {confidence:.2%}

""", unsafe_allow_html=True) # Detailed results st.markdown("### 📈 Phân tích chi tiết") # Probability chart fig_bar = create_probability_chart(result['probabilities']) st.plotly_chart(fig_bar, use_container_width=True) # Probability table prob_df = pd.DataFrame([ { 'Cảm xúc': emotion_emojis[emo] + " " + emo, 'Xác suất': f"{prob:.4f}", 'Phần trăm': f"{prob:.2%}" } for emo, prob in result['probabilities'].items() ]).sort_values('Phần trăm', ascending=False) st.dataframe( prob_df, use_container_width=True, hide_index=True ) # Store result for sidebar st.session_state.current_result = result st.session_state.current_text = text_input elif predict_button and not text_input.strip() and not batch_analysis: st.warning("⚠️ Vui lòng nhập văn bản hoặc chọn từ file để phân tích!") with col2: st.markdown("### 📊 Thống kê") # Text statistics if text_input.strip(): st.markdown("#### 📝 Thông tin văn bản") text_stats = { "Số từ": len(text_input.split()), "Số ký tự": len(text_input), "Số câu": len([s for s in text_input.split('.') if s.strip()]) } for stat, value in text_stats.items(): st.metric(stat, value) # Batch analysis summary if hasattr(st.session_state, 'batch_results') and st.session_state.batch_results: st.markdown("#### 📊 Tóm tắt phân tích file") batch_results = st.session_state.batch_results # Quick stats total_analyzed = len(batch_results) avg_confidence = np.mean([r['confidence'] for r in batch_results]) st.metric("Đã phân tích", f"{total_analyzed} văn bản") st.metric("Độ tin cậy TB", f"{avg_confidence:.2%}") # Top emotions emotion_counts = {} for r in batch_results: emotion = r['emotion'] emotion_counts[emotion] = emotion_counts.get(emotion, 0) + 1 st.markdown("**Top cảm xúc:**") sorted_emotions = sorted(emotion_counts.items(), key=lambda x: x[1], reverse=True) for emotion, count in sorted_emotions[:3]: percentage = (count / total_analyzed) * 100 st.write(f"{emotion_emojis[emotion]} {emotion}: {count} ({percentage:.1f}%)") # History section if 'prediction_history' not in st.session_state: st.session_state.prediction_history = [] # Add to history when prediction is made if (hasattr(st.session_state, 'current_result') and hasattr(st.session_state, 'current_text')): # Check if this prediction is already in history current_time = datetime.now().strftime("%H:%M:%S") current_text_short = (st.session_state.current_text[:50] + "..." if len(st.session_state.current_text) > 50 else st.session_state.current_text) # Add to history if not duplicate if (not st.session_state.prediction_history or st.session_state.prediction_history[-1]['text'] != current_text_short): st.session_state.prediction_history.append({ 'time': current_time, 'text': current_text_short, 'emotion': st.session_state.current_result['emotion'], 'confidence': st.session_state.current_result['confidence'] }) # Keep only last 5 predictions if len(st.session_state.prediction_history) > 5: st.session_state.prediction_history.pop(0) # Display history if st.session_state.prediction_history: st.markdown("#### 📚 Lịch sử phân tích") for i, pred in enumerate(reversed(st.session_state.prediction_history)): emoji = emotion_emojis[pred['emotion']] st.markdown(f"""
{pred['time']}
{emoji} {pred['emotion']} ({pred['confidence']:.2%})
"{pred['text']}"
""", unsafe_allow_html=True) # Clear history button if st.button("🗑️ Xóa lịch sử", use_container_width=True): st.session_state.prediction_history = [] st.rerun() # Footer st.markdown("---") st.markdown("""

🚀 Phát triển bởi Nhóm AI - HVNH

📚 Sử dụng mô hình uitnlp/visobert với Focal Loss

""", unsafe_allow_html=True) if __name__ == "__main__": main()