"""
Chat components for rendering messages and handling user interactions.
"""
import base64
import html
from typing import Any, Dict, List, Optional

import pandas as pd
import streamlit as st

from config import OUTLINE_INDIGO_USER, DARK_MODE_SLATE_AI, RETRY_BUTTON_TEXT, DOWNLOAD_BUTTON_TEXT
class ChatRenderer:
    """Render a chat transcript (user/assistant bubbles, errors, data, charts).

    Each message is a dict. The keys read by this class are ``role``
    (``"user"`` or ``"assistant"``), ``content``, and the optional
    ``is_error``, ``is_placeholder``, ``data`` and ``chart`` entries.
    """

    def __init__(self):
        """Store the avatar image sources used for the two chat participants."""
        self.user_avatar = OUTLINE_INDIGO_USER
        self.ai_avatar = DARK_MODE_SLATE_AI

    # ------------------------------------------------------------------
    # Pure HTML builders (no Streamlit calls)
    # ------------------------------------------------------------------
    @staticmethod
    def _escape(text: Any) -> str:
        """Return *text* as a string that is safe to embed in HTML."""
        return html.escape(str(text), quote=True)

    def _user_message_html(self, content: Any) -> str:
        """Build the right-aligned user bubble; *content* is HTML-escaped."""
        safe_content = self._escape(content)
        return f'''<div style="display: flex; align-items: flex-start; justify-content: flex-end; margin-bottom: 0.5em;">
    <div style="margin-right: 0.5em;">
        <img src="{self.user_avatar}" alt="User" style="width: 2.3rem; height: 2.3rem; border-radius: 50%; border: 2px solid #e3f2fd; background: #fff; object-fit: cover;" />
    </div>
    <div class="user-bubble">{safe_content}</div>
</div>'''

    def _error_banner_html(self, error_msg: Any) -> str:
        """Build the red error banner; *error_msg* is HTML-escaped."""
        safe_error = self._escape(error_msg)
        return f'<div style="background: #ffebee; color: #b71c1c; font-weight: bold; border-radius: 14px; padding: 8px 14px; max-width: 100%; word-wrap: break-word; box-shadow: 0 2px 8px rgba(0,0,0,0.15); text-align: right;">{safe_error}</div>'

    def _assistant_bubble_html(self, inner_html: Any, bubble_class: str) -> str:
        """Build the assistant bubble around *inner_html* (inserted verbatim)."""
        return f'''<div style="display: flex; align-items: flex-start; margin-bottom: 0.5em;">
    <div style="margin-right: 0.5em;">
        <img src="{self.ai_avatar}" alt="AI" style="width: 2.3rem; height: 2.3rem; border-radius: 50%; border: 2px solid #b2dfdb; background: #fff; object-fit: cover;" />
    </div>
    <div class="{bubble_class}">{inner_html}</div>
</div>'''

    # ------------------------------------------------------------------
    # Rendering entry points
    # ------------------------------------------------------------------
    def render_chat_history(self, messages: List[Dict[str, Any]]) -> None:
        """Render every message in *messages*, in order.

        User messages are followed by their error banner (with a retry
        button) when the next message is an assistant error; such error
        messages are then skipped in the assistant branch so they are not
        drawn twice. Messages with any other (or missing) role are ignored.
        """
        # NOTE(review): the opening and closing tags are emitted by separate
        # st.markdown calls, so this pair may not actually wrap the bubbles
        # rendered in between - confirm against the rendered DOM.
        st.markdown("<div class='chat-container'>", unsafe_allow_html=True)
        for i, msg in enumerate(messages):
            role = msg.get("role")
            if role == "user":
                self._render_user_message(msg, i, messages)
            elif role == "assistant":
                # Skip error messages that are already shown after user messages
                if not self._is_error_shown_after_user(i, messages):
                    self._render_assistant_message(msg, i)
        st.markdown("</div>", unsafe_allow_html=True)

    def _render_user_message(self, msg: Dict[str, Any], index: int, messages: List[Dict[str, Any]]) -> None:
        """Render one user message with its avatar, then any attached error.

        Args:
            msg: The user message; ``msg["content"]`` is escaped before display.
            index: Position of *msg* inside *messages*.
            messages: The full history, used to look at the following message.
        """
        st.markdown(self._user_message_html(msg["content"]), unsafe_allow_html=True)
        # Check if next message is an error and render retry option
        self._render_error_retry_if_needed(index, messages)

    def _render_error_retry_if_needed(self, index: int, messages: List[Dict[str, Any]]) -> None:
        """Render the error banner and retry button when the next message is an error.

        Does nothing unless ``messages[index + 1]`` exists, has the
        ``assistant`` role and a truthy ``is_error`` flag.
        """
        next_index = index + 1
        if next_index >= len(messages):
            return
        following = messages[next_index]
        if following.get("role") != "assistant" or not following.get("is_error"):
            return

        error_msg = following["content"]
        # Outer split keeps the banner in the right-hand 70% of the row.
        cols = st.columns([0.3, 0.7])
        with cols[1]:
            banner_col, button_col = st.columns([0.8, 0.2])
            with banner_col:
                st.markdown(self._error_banner_html(error_msg), unsafe_allow_html=True)
            with button_col:
                if st.button(RETRY_BUTTON_TEXT, key=f"retry_{index}"):
                    self._handle_retry(index)

    def _handle_retry(self, index: int) -> None:
        """Handle a retry click for the user message at *index*.

        Truncates ``st.session_state["messages"]`` so that message is the
        last one, appends an assistant placeholder, and reruns the script.
        """
        history = st.session_state["messages"]
        st.session_state["messages"] = history[: index + 1]
        st.session_state["messages"].append({
            "role": "assistant",
            # The leading character is the thinking-face emoji; it was
            # previously stored as mis-encoded bytes.
            "content": "\U0001F914 Thinking...",
            "is_placeholder": True,
        })
        st.rerun()

    def _is_error_shown_after_user(self, index: int, messages: List[Dict[str, Any]]) -> bool:
        """Return True when ``messages[index]`` is an error directly after a user message.

        Such errors are already drawn by ``_render_error_retry_if_needed``.
        """
        if not messages[index].get("is_error"):
            return False
        # Only an error that immediately follows a user message was drawn there.
        return index > 0 and messages[index - 1].get("role") == "user"

    def _render_assistant_message(self, msg: Dict[str, Any], index: Optional[int] = None) -> None:
        """Render an assistant message with avatar and optional data/chart.

        Args:
            msg: The assistant message dict.
            index: Position of *msg* in the history, if known; used to give
                the download button a stable widget key.
        """
        bubble_class = "error-bubble" if msg.get("is_error") else "ai-bubble"

        # Render message bubble
        if msg.get("is_placeholder"):
            self._render_thinking_message(bubble_class)
        else:
            self._render_regular_message(msg["content"], bubble_class)

        # Render data table if present
        if msg.get("data"):
            self._render_data_table(msg["data"], msg, index)

        # Render chart if present
        if msg.get("chart"):
            self._render_chart(msg["chart"])

    def _render_thinking_message(self, bubble_class: str) -> None:
        """Render the loading bubble (spinner followed by "Thinking...")."""
        st.markdown(
            self._assistant_bubble_html('<span class="spinner"></span>Thinking...', bubble_class),
            unsafe_allow_html=True,
        )

    def _render_regular_message(self, content: str, bubble_class: str) -> None:
        """Render a regular assistant message inside a bubble of *bubble_class*."""
        # NOTE(review): assistant content is inserted without escaping, as
        # before, because it may intentionally carry markup. If it can echo
        # untrusted text, escape or sanitise it upstream.
        st.markdown(
            self._assistant_bubble_html(content, bubble_class),
            unsafe_allow_html=True,
        )

    def _render_data_table(self, data: List[Dict], msg: Dict[str, Any], index: Optional[int] = None) -> None:
        """Render *data* as a table with a CSV download button.

        Args:
            data: Rows to display; passed straight to ``pd.DataFrame``.
            msg: The message the data belongs to.
            index: Position of *msg* in the history. When given it forms the
                widget key; otherwise the key falls back to ``id(msg)``.
        """
        df = pd.DataFrame(data)
        st.dataframe(df, use_container_width=True)
        csv = df.to_csv(index=False).encode("utf-8")
        # A position-based key stays the same across reruns even if the
        # message dicts are rebuilt, which an id()-based key does not guarantee.
        key_suffix = index if index is not None else id(msg)
        st.download_button(
            DOWNLOAD_BUTTON_TEXT,
            csv,
            "results.csv",
            "text/csv",
            key=f"download_csv_{key_suffix}",
        )

    def _render_chart(self, chart_data: str) -> None:
        """Decode base64-encoded image data and display it."""
        img_data = base64.b64decode(chart_data)
        # NOTE(review): newer Streamlit releases reportedly deprecate
        # use_column_width for st.image - confirm against the pinned version.
        st.image(img_data, use_column_width=True)