Spaces:
Running
Running
import streamlit as st | |
from utils.retriever import retrieve_paragraphs | |
from utils.generator import build_messages, _call_llm | |
from utils.utils import meetings_list, countries_list, projects_list | |
import ast | |
import time | |
import asyncio | |
import re | |
import logging | |
logging.basicConfig(level=logging.INFO)
# st.set_page_config(layout="wide")

# Full-width blue title banner rendered as raw HTML.
st.markdown(
"""
<style>
.full-width-banner {
width: 100vw;
position: relative;
left: -50vw;
right: -50vw;
margin-left: 50%;
margin-right: 50%;
background-color: #0071BC; /* UN Blue */
padding: 15px 0;
text-align: center;
box-shadow: 0 2px 4px rgba(0,0,0,0.05);
z-index: 1000;
}
</style>
<div class="full-width-banner">
<h1 style="color:white; margin:0;">Montreal AI Decisions (MVP)</h1>
</div>
""",
    unsafe_allow_html=True,
)

# Global styling for source cards, widget labels and dropdowns.
# NOTE(review): Streamlit's internal class names / data-testids (e.g.
# .streamlit-expanderContent) differ between releases; confirm these selectors
# still match the installed Streamlit version.
st.markdown(
"""
<style>
/* Fix content overflow in expanders and all custom containers */
.streamlit-expanderContent, .source-block {
max-width: 700px;
word-wrap: break-word;
overflow-wrap: break-word;
white-space: pre-wrap;
font-size: 16px;
}
/* Force label size on text input and selectboxes */
label[data-testid="stWidgetLabel"] {
font-size: 20px !important;
font-weight: 600 !important;
color: #000000 !important;
}
/* Optional: Adjust placeholder font size */
input[type="text"]::placeholder {
font-size: 18px !important;
}
/* Optional: Adjust the selected option inside the dropdown */
div[role="combobox"] * {
font-size: 18px !important;
}
</style>
""",
    unsafe_allow_html=True,
)

# Add vertical spacing between banner and help text
st.markdown("<div style='margin-top: 40px;'></div>", unsafe_allow_html=True)

# Help text (static). Inside raw HTML a plain newline collapses to a single
# space, so the intended line break between the two sentences is an explicit
# <br>.
st.markdown(
    "<p style='text-align: left; font-weight: 600; margin-bottom: 1rem;'>"
    "Welcome to your chatbot research assistant. It helps you find and "
    "summarize specific decisions and annexes, and can also answer general "
    "questions about the text. For transparency, it provides references with "
    "links to the relevant decisions and annexes, so you can easily verify "
    "the sources.<br>"
    "While this chatbot was developed with care, we recommend double-checking "
    "the links to gain a deeper understanding of the material. </p>",
    unsafe_allow_html=True,
)

# Add vertical spacing between help text and question input
st.markdown("<div style='margin-top: 25px;'></div>", unsafe_allow_html=True)
########### Function for getting response #######################
def chat_response(query, filter_metadata=None):
    """Retrieve context for ``query`` and ask the LLM to answer it.

    Args:
        query: The user's question.
        filter_metadata: Optional metadata filter, forwarded unchanged to
            ``retrieve_paragraphs``. ``None`` means no filtering.

    Returns:
        A two-tuple ``(answer, context_retrieved)``. ``context_retrieved`` is
        the list of retrieved documents; each is a dict read for its
        ``'answer'`` text. On any failure ``answer`` is an
        ``"Error processing request: ..."`` message and ``context_retrieved``
        is an empty list. Returning a tuple keeps callers that unpack two
        values from crashing.
    """
    try:
        retrieved_paragraphs = retrieve_paragraphs(query, filter_metadata=filter_metadata)
        # retrieve_paragraphs hands back the text form of a Python literal.
        # ast.literal_eval parses that literal without executing code. An
        # already-parsed list is accepted as-is.
        if isinstance(retrieved_paragraphs, str):
            context_retrieved = ast.literal_eval(retrieved_paragraphs)
        else:
            context_retrieved = retrieved_paragraphs
        # Only the passage text (no metadata) is sent to the LLM.
        context_retrieved_lst = [doc['answer'] for doc in context_retrieved]
        logging.info("Context Retrieval done")
        logging.info("Content %s", context_retrieved)
        messages = build_messages(query, context_retrieved_lst)
        # _call_llm is a coroutine; asyncio.run drives it to completion here.
        # asyncio.run raises RuntimeError if an event loop is already running
        # in this thread.
        answer = asyncio.run(_call_llm(messages))
        return answer, context_retrieved
    except Exception as e:
        # Top-level UI boundary: log the traceback and surface the message.
        logging.exception("chat_response failed")
        error_message = f"Error processing request: {e}"
        return error_message, []
############## UI related functions #####################
def reset_page():
    """Send pagination back to page 1.

    Registered as the ``on_change`` callback of the query input and the
    filter selectboxes, so any new search starts from the first page.
    """
    st.session_state.page = 1
def contruct_metadata_filter():
    """Translate the filter selectboxes into a retriever metadata filter.

    Reads the ``meetings_filter``, ``country_filter`` and ``project_filter``
    session-state entries. Every entry not set to ``'All'`` is mapped onto
    its metadata field.

    Returns:
        dict: The filter, possibly empty when every selectbox is ``'All'``.
    """
    # TODO: the country and project filters should eventually accept lists.
    widget_to_field = (
        ('meetings_filter', 'meeting_id'),
        ('country_filter', 'Countries'),
        ('project_filter', 'Projects'),
    )
    filter_metadata = {
        field: st.session_state[widget_key]
        for widget_key, field in widget_to_field
        if st.session_state[widget_key] != 'All'
    }
    logging.info(f"contructed metadata_filter {filter_metadata}")
    return filter_metadata
def _split_preview(text, max_words=90):
    """Split ``text`` into a word preview and the untouched remainder.

    Returns ``(preview_words, remainder)``:
    - ``preview_words`` is a list of at most ``max_words`` whitespace-separated
      words.
    - ``remainder`` is the rest of ``text`` after those words, with leading
      whitespace dropped and the original spacing otherwise preserved. It is
      ``""`` when nothing is left.
    """
    # With maxsplit=max_words, str.split returns each of the first max_words
    # words separately. If more text follows, it adds one final element
    # holding the unsplit rest. The remainder is therefore exact even when
    # words are separated by newlines or repeated spaces.
    parts = text.split(maxsplit=max_words)
    if len(parts) > max_words:
        return parts[:max_words], parts[max_words]
    return parts, ""


def render_sources(chunks, query):
    """Render each retrieved chunk as a numbered "source" card.

    Args:
        chunks: Iterable of retrieved documents (dicts). Each is read for:
            - an ``'answer'`` text;
            - an optional ``'answer_metadata'`` dict with
              ``'Decision Number'``, ``'Agencies'`` and ``'country'`` entries.
        query: The user's question; unused here, kept for caller
            compatibility.

    The first 90 words of each chunk are shown inline. Any further text is
    placed in a "Show more" expander.
    """
    import html  # function-scope import keeps this block self-contained

    st.subheader("Sources")
    st.write("======================================")
    for idx, doc in enumerate(chunks, start=1):
        meta = doc.get('answer_metadata') or {}
        # These values end up in raw HTML (unsafe_allow_html=True), so escape
        # them. Document text containing '<' or '&' would otherwise break the
        # markup or inject tags.
        title = html.escape(str(meta.get('Decision Number', 'Unknown Project')))
        agencies = html.escape(str(meta.get('Agencies', 'Unknown Agencies')))
        # NOTE(review): contruct_metadata_filter filters on a 'Countries'
        # field; confirm 'country' is the key actually present here.
        country = html.escape(str(meta.get('country', 'Unknown Country')))
        snippet = doc.get('answer') or ''
        preview, remainder = _split_preview(snippet)
        # Build the card line by line with no source indentation. Markdown
        # treats lines indented 4+ spaces as a code block, and that would show
        # the raw tags.
        card = "\n".join([
            '<div class="source-block">',
            f"<h4>{idx}. {title}</h4>",
            f"<p><strong>Agencies:</strong> {agencies} | <strong>Country:</strong> {country}</p>",
            f"<p>{html.escape(' '.join(preview))}</p>",
            "</div>",
        ])
        st.markdown(card, unsafe_allow_html=True)
        if remainder:
            with st.expander("Show more"):
                st.markdown(
                    f"<div class='source-block'>{html.escape(remainder)}</div>",
                    unsafe_allow_html=True,
                )
        st.divider()
# Initialise widget-backed session state once per session, before the filter
# helpers read it.
for key in ('meetings_filter', 'country_filter', 'project_filter'):
    if key not in st.session_state:
        st.session_state[key] = 'All'
if 'page' not in st.session_state:
    st.session_state['page'] = 1

col_query, col_about = st.columns([8, 2])
# 10.1. Question input
with col_query:
    query = st.text_input(
        label="Enter your question:",
        key="query",
        on_change=reset_page,
    )

# 10.2. Filter widgets (col4 is currently unused)
col1, col2, col3, col4 = st.columns(4)
with col1:
    meetings = sorted(meetings_list)
    st.selectbox(
        "Meeting",
        options=['All'] + meetings,
        key='meetings_filter',
        on_change=reset_page,
    )
with col2:
    countries = sorted(countries_list)
    st.selectbox(
        "Country",
        options=['All'] + countries,
        key='country_filter',
        on_change=reset_page,
    )
with col3:
    projects = sorted(projects_list)
    st.selectbox(
        "Projects",
        options=['All'] + projects,
        key='project_filter',
        on_change=reset_page,
    )

# Only run search & display if user has entered something
if not query.strip():
    st.info("Please enter a question to see results.")
    st.stop()  # ends this script run; nothing below executes

filter_metadata = contruct_metadata_filter()
if filter_metadata:
    logging.info("calling with metadata filter")
else:
    logging.info("calling without metadata filter")
# An empty filter is passed as None, i.e. the same as calling without one.
result = chat_response(query, filter_metadata or None)
# chat_response may report failure as a bare error string rather than an
# (answer, sources) tuple. Unpacking that string would raise ValueError, so
# show it instead.
if isinstance(result, str):
    st.error(result)
    st.stop()
answer, context_retrieved = result
st.write(answer)
render_sources(context_retrieved, query)