Spaces:
Sleeping
Sleeping
Commit
·
ac2020e
1
Parent(s):
e812ccd
update
Browse files- .gitignore +1 -0
- src/legisqa_local/components/__init__.py +1 -0
- src/legisqa_local/components/display.py +63 -0
- src/legisqa_local/components/forms.py +92 -0
- src/legisqa_local/components/sidebar.py +59 -0
- src/legisqa_local/config/__init__.py +1 -0
- src/legisqa_local/config/models.py +41 -0
- src/legisqa_local/config/settings.py +39 -0
- src/legisqa_local/core/__init__.py +1 -0
- src/legisqa_local/core/embeddings.py +14 -0
- src/legisqa_local/core/llm.py +51 -0
- src/legisqa_local/core/rag.py +56 -0
- src/legisqa_local/core/vectorstore.py +31 -0
- src/legisqa_local/tabs/__init__.py +1 -0
- src/legisqa_local/tabs/base.py +16 -0
- src/legisqa_local/tabs/guide_tab.py +52 -0
- src/legisqa_local/tabs/rag_sbs_tab.py +81 -0
- src/legisqa_local/tabs/rag_tab.py +57 -0
- src/legisqa_local/utils/__init__.py +1 -0
- src/legisqa_local/utils/formatting.py +109 -0
- src/legisqa_local/utils/text.py +55 -0
- src/legisqa_local/utils/usage.py +47 -0
.gitignore
CHANGED
|
@@ -8,5 +8,6 @@ __pycache__/
|
|
| 8 |
.venv/
|
| 9 |
*.log
|
| 10 |
.python-version
|
|
|
|
| 11 |
chromadb
|
| 12 |
chromadb/
|
|
|
|
| 8 |
.venv/
|
| 9 |
*.log
|
| 10 |
.python-version
|
| 11 |
+
src/legisqa_local.egg-info/
|
| 12 |
chromadb
|
| 13 |
chromadb/
|
src/legisqa_local/components/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Reusable UI components for the Streamlit interface"""
|
src/legisqa_local/components/display.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Display components for LegisQA"""
|
| 2 |
+
|
| 3 |
+
import streamlit as st
|
| 4 |
+
from legisqa_local.utils.text import escape_markdown, replace_legis_ids_with_urls
|
| 5 |
+
from legisqa_local.utils.usage import display_api_usage
|
| 6 |
+
from legisqa_local.utils.formatting import render_retrieved_chunks
|
| 7 |
+
from legisqa_local.config.models import PROVIDER_MODELS
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def render_example_queries():
    """Show a collapsible list of sample queries, each in its own code fence."""
    example_queries = [
        "What are the themes around artificial intelligence?",
        "Write a well cited 3 paragraph essay on food insecurity.",
        "Create a table summarizing major climate change ideas with columns legis_id, title, idea.",
        "Write an action plan to keep social security solvent.",
        "Suggest reforms that would benefit the Medicaid program.",
    ]
    # Each query is wrapped in a fenced block, with blank lines separating the blocks.
    fenced_blocks = "".join(f"\n```\n{example}\n```\n" for example in example_queries)
    examples_markdown = "\n" + fenced_blocks + "\n"

    with st.expander("Example Queries"):
        st.write(examples_markdown)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def render_response(
    response: dict,
    model_info: dict,
    provider: str,
    should_escape_markdown: bool,
    should_add_legis_urls: bool,
    tag: str | None = None,
):
    """Display one RAG response, followed by its API usage and retrieved chunks.

    Args:
        response: Mapping with an "aimessage" entry (object exposing `.content`)
            and a "docs" entry holding the retrieved documents.
        model_info: Entry from the provider model table, forwarded to the
            usage display.
        provider: Provider name, forwarded to the usage display.
        should_escape_markdown: Pass the text through `escape_markdown` first.
        should_add_legis_urls: Pass the text through
            `replace_legis_ids_with_urls` (after escaping, when both are set).
        tag: Optional label appended to the heading and forwarded to the
            usage and chunk renderers.
    """
    ai_message = response["aimessage"]

    text = ai_message.content
    if should_escape_markdown:
        text = escape_markdown(text)
    if should_add_legis_urls:
        text = replace_legis_ids_with_urls(text)

    heading = "Response" if tag is None else f"Response ({tag})"
    with st.container(border=True):
        st.write(heading)
        st.info(text)

    display_api_usage(ai_message, model_info, provider, tag=tag)
    render_retrieved_chunks(response["docs"], tag=tag)
|
src/legisqa_local/components/forms.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Form components for configuration in LegisQA"""
|
| 2 |
+
|
| 3 |
+
import streamlit as st
|
| 4 |
+
from legisqa_local.config.models import PROVIDER_MODELS, CONGRESS_NUMBERS, SPONSOR_PARTIES
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def get_generative_config(key_prefix: str) -> dict:
    """Draw the generation settings widgets and return their current values.

    Args:
        key_prefix: Prefix for every widget key (joined with "|") so the same
            form can be rendered more than once on a page.

    Returns:
        Dict with "provider", "model_name", "temperature",
        "max_output_tokens", "should_escape_markdown" and
        "should_add_legis_urls".
    """

    def widget_key(field: str) -> str:
        return f"{key_prefix}|{field}"

    provider = st.selectbox(
        label="provider",
        options=PROVIDER_MODELS.keys(),
        key=widget_key("provider"),
    )
    # The model choices depend on the provider picked above.
    model_name = st.selectbox(
        label="model_name",
        options=PROVIDER_MODELS[provider],
        key=widget_key("model_name"),
    )
    temperature = st.slider(
        "temperature",
        min_value=0.0,
        max_value=2.0,
        value=0.0,
        key=widget_key("temperature"),
    )
    max_output_tokens = st.slider(
        "max_output_tokens",
        min_value=8192,
        max_value=16_384,
        key=widget_key("max_output_tokens"),
    )
    should_escape_markdown = st.checkbox(
        "should_escape_markdown",
        value=False,
        key=widget_key("should_escape_markdown"),
    )
    should_add_legis_urls = st.checkbox(
        "should_add_legis_urls",
        value=True,
        key=widget_key("should_add_legis_urls"),
    )

    return {
        "provider": provider,
        "model_name": model_name,
        "temperature": temperature,
        "max_output_tokens": max_output_tokens,
        "should_escape_markdown": should_escape_markdown,
        "should_add_legis_urls": should_add_legis_urls,
    }
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def get_retrieval_config(key_prefix: str) -> dict:
    """Draw the retrieval settings widgets and return their current values.

    Args:
        key_prefix: Prefix for every widget key (joined with "|").

    Returns:
        Dict with "n_ret_docs", "filter_legis_id", "filter_bioguide_id",
        "filter_congress_nums" and "filter_sponsor_parties".
    """

    def widget_key(field: str) -> str:
        return f"{key_prefix}|{field}"

    n_ret_docs = st.slider(
        "Number of chunks to retrieve",
        min_value=1,
        max_value=32,
        value=8,
        key=widget_key("n_ret_docs"),
    )
    legis_id = st.text_input(
        "Bill ID (e.g. 118-s-2293)", key=widget_key("filter_legis_id")
    )
    bioguide_id = st.text_input(
        "Bioguide ID (e.g. R000595)", key=widget_key("filter_bioguide_id")
    )
    # Pre-select the last two entries of the congress list.
    congress_nums = st.multiselect(
        "Congress Numbers",
        CONGRESS_NUMBERS,
        default=CONGRESS_NUMBERS[-2:],
        key=widget_key("filter_congress_nums"),
    )
    sponsor_parties = st.multiselect(
        "Sponsor Party",
        SPONSOR_PARTIES,
        default=SPONSOR_PARTIES,
        key=widget_key("filter_sponsor_parties"),
    )

    return {
        "n_ret_docs": n_ret_docs,
        "filter_legis_id": legis_id,
        "filter_bioguide_id": bioguide_id,
        "filter_congress_nums": congress_nums,
        "filter_sponsor_parties": sponsor_parties,
    }
|
src/legisqa_local/components/sidebar.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Sidebar components for LegisQA"""
|
| 2 |
+
|
| 3 |
+
import streamlit as st
|
| 4 |
+
import os
|
| 5 |
+
from legisqa_local.config.settings import get_chroma_config
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def render_chromadb_status():
    """Report in the sidebar whether the configured ChromaDB directory exists.

    Reads the persist directory and collection name from `get_chroma_config()`.
    Any exception raised while doing so is shown as a configuration error
    instead of propagating, so the sidebar always renders.
    """
    st.subheader("🗄️ Vector Database")

    try:
        config = get_chroma_config()
        chromadb_path = config["persist_directory"]

        if os.path.exists(chromadb_path):
            # normpath removes "./" prefixes and trailing separators, so the
            # basename is the database directory itself rather than its parent
            # (the default "./chromadb" used to be displayed as ".../.").
            dir_label = os.path.basename(os.path.normpath(chromadb_path)) or chromadb_path
            st.success("✅ ChromaDB Ready")
            st.caption("📊 Using pre-existing database")
            st.caption(f"📁 Collection: {config['collection_name']}")
            st.caption(f"📁 Path: .../{dir_label}")
        else:
            st.error("❌ ChromaDB Not Found")
            st.caption(f"Expected path: {chromadb_path}")
            st.caption("Please check the database path")
    except Exception as e:
        # Keep the caption short, but only add an ellipsis when text was cut.
        message = str(e)
        if len(message) > 50:
            message = f"{message[:50]}..."
        st.error("❌ ChromaDB Configuration Error")
        st.caption(f"Error: {message}")
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def render_outreach_links():
    """Write one sidebar subheader per external project the app relies on."""
    nomic_base_url = "https://atlas.nomic.ai/data/gabrielhyperdemocracy"
    nomic_map_name = "us-congressional-legislation-s1024o256nomic-1"

    # (emoji shortcode, role, link text, url) in display order
    links = [
        ("world_map", "Visualize", "nomic atlas", f"{nomic_base_url}/{nomic_map_name}/map"),
        ("hugging_face", "Raw", "huggingface datasets", "https://huggingface.co/hyperdemocracy"),
        ("card_file_box", "Vector DB", "chromadb", "https://www.trychroma.com/"),
        ("pancakes", "Inference", "together.ai", "https://www.together.ai/"),
        ("eyeglasses", "Inference", "google-gemini", "https://ai.google.dev/gemini-api"),
        ("hut", "Inference", "anthropic", "https://www.anthropic.com/api"),
        ("sparkles", "Inference", "openai", "https://platform.openai.com/docs/overview"),
        ("parrot", "Orchestration", "langchain", "https://www.langchain.com/"),
    ]

    for emoji, role, link_text, url in links:
        st.subheader(f":{emoji}: {role} [{link_text}]({url})")
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def render_sidebar():
    """Draw every sidebar section, each inside its own bordered container."""
    sections = (render_chromadb_status, render_outreach_links)
    for render_section in sections:
        with st.container(border=True):
            render_section()
|
src/legisqa_local/config/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Configuration module for LegisQA"""
|
src/legisqa_local/config/models.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Model configurations for different LLM providers"""
|
| 2 |
+
|
| 3 |
+
CONGRESS_NUMBERS = [113, 114, 115, 116, 117, 118, 119]
|
| 4 |
+
SPONSOR_PARTIES = ["D", "R", "L", "I"]
|
| 5 |
+
|
| 6 |
+
OPENAI_CHAT_MODELS = {
|
| 7 |
+
"gpt-5-nano": {"cost": {"pmi": 0.05, "pmo": 0.40}},
|
| 8 |
+
"gpt-5-mini": {"cost": {"pmi": 0.25, "pmo": 2.00}},
|
| 9 |
+
"gpt-5": {"cost": {"pmi": 1.25, "pmo": 10.0}},
|
| 10 |
+
"gpt-4o-mini": {"cost": {"pmi": 0.15, "pmo": 0.60}},
|
| 11 |
+
"gpt-4o": {"cost": {"pmi": 2.50, "pmo": 10.0}},
|
| 12 |
+
}
|
| 13 |
+
|
| 14 |
+
ANTHROPIC_CHAT_MODELS = {
|
| 15 |
+
"claude-3-5-haiku-20241022": {"cost": {"pmi": 0.80, "pmo": 4.00}},
|
| 16 |
+
"claude-sonnet-4-20250514": {"cost": {"pmi": 3.0, "pmo": 15.0}},
|
| 17 |
+
"claude-opus-4-1-20250805": {"cost": {"pmi": 15.0, "pmo": 75.0}},
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
TOGETHER_CHAT_MODELS = {
|
| 21 |
+
"openai/gpt-oss-20b": {"cost": {"pmi": 0.05, "pmo": 0.20}},
|
| 22 |
+
"meta-llama/Llama-3.3-70B-Instruct-Turbo-Free": {"cost": {"pmi": 0.00, "pmo": 0.00}},
|
| 23 |
+
"meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo": {"cost": {"pmi": 0.18, "pmo": 0.18}},
|
| 24 |
+
"meta-llama/Llama-3.3-70B-Instruct-Turbo": {"cost": {"pmi": 0.88, "pmo": 0.88}},
|
| 25 |
+
"meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo": {"cost": {"pmi": 3.50, "pmo": 3.50}},
|
| 26 |
+
"Qwen/Qwen3-235B-A22B-Thinking-2507": {"cost": {"pmi": 0.65, "pmo": 3.00}},
|
| 27 |
+
"moonshotai/Kimi-K2-Instruct": {"cost": {"pmi": 1.00, "pmo": 3.00}},
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
GOOGLE_CHAT_MODELS = {
|
| 31 |
+
"gemini-2.5-flash-lite": {"cost": {"pmi": 0.10, "pmo": 0.40}},
|
| 32 |
+
"gemini-2.5-flash": {"cost": {"pmi": 0.30, "pmo": 2.50}},
|
| 33 |
+
"gemini-2.5-pro": {"cost": {"pmi": 1.25, "pmo": 10.0}},
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
PROVIDER_MODELS = {
|
| 37 |
+
"OpenAI": OPENAI_CHAT_MODELS,
|
| 38 |
+
"Anthropic": ANTHROPIC_CHAT_MODELS,
|
| 39 |
+
"Together": TOGETHER_CHAT_MODELS,
|
| 40 |
+
"Google": GOOGLE_CHAT_MODELS,
|
| 41 |
+
}
|
src/legisqa_local/config/settings.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Application settings and configuration"""
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import streamlit as st
|
| 5 |
+
|
| 6 |
+
# Streamlit configuration
|
| 7 |
+
STREAMLIT_CONFIG = {
|
| 8 |
+
"layout": "wide",
|
| 9 |
+
"page_title": "LegisQA"
|
| 10 |
+
}
|
| 11 |
+
|
| 12 |
+
def get_secret(key: str, default=None):
    """Look up a secret in Streamlit secrets, then in environment variables.

    Args:
        key: Secret name as used in Streamlit secrets (e.g. "openai_api_key").
        default: Value returned when the secret is found nowhere.

    Returns:
        The Streamlit secret if present; otherwise the environment variable
        named exactly `key`; otherwise the one named `key.upper()`; otherwise
        `default`.
    """
    try:
        # Try Streamlit secrets first (for local development)
        return st.secrets[key]
    except (KeyError, FileNotFoundError):
        pass

    # Fall back to environment variables (for Docker/HF Spaces). Variable
    # names are case-sensitive and conventionally upper-case, while callers
    # pass lower-case keys, so try both spellings.
    for env_name in (key, key.upper()):
        value = os.getenv(env_name)
        if value is not None:
            return value
    return default
|
| 20 |
+
|
| 21 |
+
# Environment variables setup
|
| 22 |
+
def setup_environment():
    """Export the LangChain tracing and tokenizer settings as environment variables.

    The API key and project come from `get_secret`; when that yields nothing,
    a value already present in the process environment is kept rather than
    being overwritten with an empty string. Tracing is switched on only when
    an API key is actually available.
    """
    api_key = get_secret("langchain_api_key", "") or os.environ.get("LANGCHAIN_API_KEY", "")
    project = (
        get_secret("langchain_project", "")
        or os.environ.get("LANGCHAIN_PROJECT", "")
        or "legisqa-local"
    )
    # NOTE(review): LANGSMITH_API_KEY is assumed to be an accepted alternative
    # key variable for the tracing client -- confirm against its docs.
    tracing_enabled = bool(api_key or os.environ.get("LANGSMITH_API_KEY"))

    # os.environ only accepts strings; secrets files may hold other types.
    os.environ["LANGCHAIN_API_KEY"] = str(api_key)
    os.environ["LANGCHAIN_TRACING_V2"] = "true" if tracing_enabled else "false"
    os.environ["LANGCHAIN_PROJECT"] = str(project)
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 28 |
+
|
| 29 |
+
# ChromaDB configuration
|
| 30 |
+
def get_chroma_config():
    """Return the ChromaDB location settings read from the environment.

    Returns:
        Dict with "persist_directory" (CHROMA_PERSIST_DIRECTORY, default
        "./chromadb") and "collection_name" (CHROMA_COLLECTION_NAME,
        default "usc").
    """
    persist_directory = os.environ.get("CHROMA_PERSIST_DIRECTORY", "./chromadb")
    collection_name = os.environ.get("CHROMA_COLLECTION_NAME", "usc")
    return dict(
        persist_directory=persist_directory,
        collection_name=collection_name,
    )
|
| 36 |
+
|
| 37 |
+
# Embedding model configuration
|
| 38 |
+
EMBEDDING_MODEL = "sentence-transformers/static-retrieval-mrl-en-v1"
|
| 39 |
+
EMBEDDING_DEVICE = "cpu"
|
src/legisqa_local/core/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Core business logic for LegisQA"""
|
src/legisqa_local/core/embeddings.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Embedding functionality for LegisQA"""
|
| 2 |
+
|
| 3 |
+
from langchain_huggingface import HuggingFaceEmbeddings
|
| 4 |
+
from legisqa_local.config.settings import EMBEDDING_MODEL, EMBEDDING_DEVICE
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def load_embeddings():
    """Build the HuggingFace embedding function from the configured model and device."""
    return HuggingFaceEmbeddings(
        model_name=EMBEDDING_MODEL,
        model_kwargs={"device": EMBEDDING_DEVICE},
    )
|
src/legisqa_local/core/llm.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""LLM provider implementations for LegisQA"""
|
| 2 |
+
|
| 3 |
+
import streamlit as st
|
| 4 |
+
from langchain_openai import ChatOpenAI
|
| 5 |
+
from langchain_anthropic import ChatAnthropic
|
| 6 |
+
from langchain_together import ChatTogether
|
| 7 |
+
from langchain_google_genai import ChatGoogleGenerativeAI
|
| 8 |
+
from legisqa_local.config.settings import get_secret
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def get_llm(gen_config: dict):
    """Instantiate the chat model selected in the generation config.

    Args:
        gen_config: Dict with "provider", "model_name", "temperature" and
            "max_output_tokens".

    Returns:
        A chat model instance for the chosen provider.

    Raises:
        ValueError: If the provider is not one of "OpenAI", "Anthropic",
            "Together" or "Google".
    """
    provider = gen_config["provider"]

    if provider == "OpenAI":
        return ChatOpenAI(
            model=gen_config["model_name"],
            temperature=gen_config["temperature"],
            api_key=get_secret("openai_api_key"),
            max_tokens=gen_config["max_output_tokens"],
        )

    if provider == "Anthropic":
        return ChatAnthropic(
            model_name=gen_config["model_name"],
            temperature=gen_config["temperature"],
            api_key=get_secret("anthropic_api_key"),
            max_tokens_to_sample=gen_config["max_output_tokens"],
        )

    if provider == "Together":
        return ChatTogether(
            model=gen_config["model_name"],
            temperature=gen_config["temperature"],
            max_tokens=gen_config["max_output_tokens"],
            api_key=get_secret("together_api_key"),
        )

    if provider == "Google":
        return ChatGoogleGenerativeAI(
            model=gen_config["model_name"],
            temperature=gen_config["temperature"],
            api_key=get_secret("google_api_key"),
            max_output_tokens=gen_config["max_output_tokens"],
        )

    raise ValueError(f"Unknown provider: {provider}")
|
src/legisqa_local/core/rag.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""RAG (Retrieval-Augmented Generation) chain implementation"""
|
| 2 |
+
|
| 3 |
+
from langchain_core.prompts import ChatPromptTemplate
|
| 4 |
+
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
|
| 5 |
+
|
| 6 |
+
from legisqa_local.core.llm import get_llm
|
| 7 |
+
from legisqa_local.core.vectorstore import load_vectorstore, get_vectorstore_filter
|
| 8 |
+
from legisqa_local.utils.formatting import format_docs
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def create_rag_chain(llm, retriever):
    """Assemble the retrieval-augmented generation chain.

    The chain takes the raw query, runs `retriever` on it, formats the
    retrieved documents with `format_docs`, and sends the filled-in prompt to
    `llm`. Its output mapping carries "docs", "query", "context" and
    "aimessage".
    """
    instructions = (
        "You are an expert legislative analyst. "
        "Use the following excerpts from US congressional legislation to respond to the user's query. "
        "The excerpts are formatted as a JSON list. "
        'Each JSON object has "legis_id", "title", "introduced_date", "sponsor", and "snippets" keys. '
        "If a snippet is useful in writing part of your response, then cite the "
        '"legis_id", "title", "introduced_date", and "sponsor" in the response. '
        'When citing legis_id, use the same format as the excerpts (e.g. "116-hr-125"). '
        "If you don't know how to respond, just tell the user."
    )
    # Sections of the prompt are separated by a single blank line.
    template_text = "\n\n".join(
        [
            instructions,
            "---",
            "Congressional Legislation Excerpts:",
            "{context}",
            "---",
            "Query: {query}",
        ]
    )
    prompt = ChatPromptTemplate.from_messages([("human", template_text)])

    # Retrieve and pass the query through side by side, then add the
    # formatted context and finally the model's answer.
    gather_inputs = RunnableParallel(
        {
            "docs": retriever,
            "query": RunnablePassthrough(),
        }
    )
    with_context = gather_inputs.assign(
        context=lambda state: format_docs(state["docs"])
    )
    return with_context.assign(aimessage=prompt | llm)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def process_query(gen_config: dict, ret_config: dict, query: str):
    """Answer `query` with the RAG chain built from the given configs.

    Args:
        gen_config: Generation settings passed to `get_llm`.
        ret_config: Retrieval settings; "n_ret_docs" sets how many chunks are
            fetched, and the filter fields are turned into a vectorstore filter.
        query: The user's question.

    Returns:
        The mapping produced by invoking the RAG chain.
    """
    vectorstore = load_vectorstore()
    llm = get_llm(gen_config)
    vs_filter = get_vectorstore_filter(ret_config)

    # ChromaDB uses 'filter' parameter in search_kwargs; omit it when empty
    filter_kwargs = {"filter": vs_filter} if vs_filter else {}
    search_kwargs = {"k": ret_config["n_ret_docs"], **filter_kwargs}

    retriever = vectorstore.as_retriever(search_kwargs=search_kwargs)
    return create_rag_chain(llm, retriever).invoke(query)
|
src/legisqa_local/core/vectorstore.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Vector store operations for LegisQA"""
|
| 2 |
+
|
| 3 |
+
from langchain_chroma import Chroma
|
| 4 |
+
from legisqa_local.core.embeddings import load_embeddings
|
| 5 |
+
from legisqa_local.config.settings import get_chroma_config
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def load_vectorstore():
    """Open the ChromaDB collection named in the configuration.

    Uses the persist directory and collection name from `get_chroma_config()`
    together with the embedding function from `load_embeddings()`.
    """
    settings = get_chroma_config()
    embedding_function = load_embeddings()
    return Chroma(
        persist_directory=settings["persist_directory"],
        collection_name=settings["collection_name"],
        embedding_function=embedding_function,
    )
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def get_vectorstore_filter(ret_config: dict) -> dict:
    """Build the metadata filter for ChromaDB queries.

    Args:
        ret_config: Retrieval settings; reads "filter_legis_id" (string, blank
            means no bill filter) and "filter_congress_nums" (list, empty
            means no congress filter).

    Returns:
        `{}` when no filter applies, the single condition when exactly one
        applies, or `{"$and": [...]}` when both apply.
    """
    conditions = []

    # Ignore surrounding whitespace typed into the bill ID box.
    legis_id = ret_config["filter_legis_id"].strip()
    if legis_id:
        conditions.append({"legis_id": legis_id})

    congress_nums = ret_config["filter_congress_nums"]
    if congress_nums:
        conditions.append({"congress_num": {"$in": list(congress_nums)}})

    # NOTE(review): the retrieval form also collects "filter_bioguide_id" and
    # "filter_sponsor_parties", but they are not applied here; the metadata
    # field names they would map to are not visible from this module.

    if not conditions:
        return {}
    if len(conditions) == 1:
        return conditions[0]
    # Chroma rejects a filter with several top-level keys; multiple
    # conditions have to be combined under one logical operator.
    return {"$and": conditions}
|
src/legisqa_local/tabs/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Tab implementations for the Streamlit interface"""
|
src/legisqa_local/tabs/base.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Base tab interface for LegisQA"""
|
| 2 |
+
|
| 3 |
+
from abc import ABC, abstractmethod
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class BaseTab(ABC):
    """Abstract parent of every tab shown in the app.

    Stores the tab's display name and the prefix used for its widget keys;
    subclasses supply `render`.
    """

    def __init__(self, name: str, key_prefix: str):
        self.name, self.key_prefix = name, key_prefix

    @abstractmethod
    def render(self):
        """Draw the tab's content; must be overridden by subclasses."""
|
src/legisqa_local/tabs/guide_tab.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Guide tab implementation"""
|
| 2 |
+
|
| 3 |
+
import streamlit as st
|
| 4 |
+
from legisqa_local.tabs.base import BaseTab
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class GuideTab(BaseTab):
    """Static tab that explains how to use the application."""

    # Markdown body of the guide, written out verbatim by render().
    _GUIDE_MARKDOWN = """

# LegisQA Guide

Welcome to LegisQA! This tool allows you to query congressional legislation using natural language.

## How to Use

1. **Choose a Tab**: Select between single RAG queries or side-by-side comparisons
2. **Enter Your Query**: Ask questions about congressional legislation in natural language
3. **Configure Settings**: Adjust model and retrieval parameters as needed
4. **Submit**: Click submit and wait for the AI to generate a response

## Example Queries

- "What are the main themes around artificial intelligence legislation?"
- "Write a summary of recent climate change bills"
- "Create a table of healthcare reform proposals"
- "What bills address social security reform?"

## Features

- **Multiple LLM Providers**: OpenAI, Anthropic, Together.ai, Google
- **Flexible Retrieval**: Filter by congress number, bill ID, sponsor party
- **Citation Support**: Responses include links to original legislation
- **Cost Tracking**: Monitor API usage and costs
- **Side-by-Side**: Compare responses from different models

## Tips

- Be specific in your queries for better results
- Use the retrieval filters to narrow down the search space
- Try different models to compare response quality
- Check the retrieved chunks to understand the source material

"""

    def __init__(self):
        super().__init__("Guide", "guide")

    def render(self):
        """Write the guide markdown to the page."""
        st.write(self._GUIDE_MARKDOWN)
|
src/legisqa_local/tabs/rag_sbs_tab.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Side-by-side RAG tab implementation"""
|
| 2 |
+
|
| 3 |
+
import streamlit as st
|
| 4 |
+
from legisqa_local.tabs.base import BaseTab
|
| 5 |
+
from legisqa_local.components.forms import get_generative_config, get_retrieval_config
|
| 6 |
+
from legisqa_local.components.display import render_response
|
| 7 |
+
from legisqa_local.core.rag import process_query
|
| 8 |
+
from legisqa_local.config.models import PROVIDER_MODELS
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class RAGSideBySideTab(BaseTab):
    """Tab that runs one query through two independently configured RAG setups."""

    def __init__(self):
        super().__init__("RAG (side-by-side)", "query_rag_sbs")

    def render(self):
        """Render the query form, both config panels and both responses.

        Responses are kept in `st.session_state` so they survive reruns. The
        provider and model that produced each response are stored next to it,
        so a later change of the selectors does not relabel an old response.
        """
        SS = st.session_state
        grp_names = {"grp1": "Group 1", "grp2": "Group 2"}

        with st.form(f"{self.key_prefix}|query_form"):
            query = st.text_area(
                "Enter a query that can be answered with congressional legislation:"
            )
            cols = st.columns(2)
            with cols[0]:
                query_submitted = st.form_submit_button("Submit")
            with cols[1]:
                status_placeholder = st.empty()

        # One config column per group, built in a single loop.
        gen_configs = {}
        ret_configs = {}
        for grp, column in zip(grp_names, st.columns(2)):
            with column:
                st.header(grp_names[grp])
                key_prefix = f"{self.key_prefix}|{grp}"
                with st.expander("Generative Config"):
                    gen_configs[grp] = get_generative_config(key_prefix)
                with st.expander("Retrieval Config"):
                    ret_configs[grp] = get_retrieval_config(key_prefix)

        # Do not spend retrieval and LLM calls on a blank query.
        run_query = query_submitted and bool(query.strip())
        if query_submitted and not run_query:
            status_placeholder.warning("Please enter a query before submitting.")

        for grp, column in zip(grp_names, st.columns(2)):
            with column:
                key_prefix = f"{self.key_prefix}|{grp}"
                rkey = f"{key_prefix}|response"
                mkey = f"{key_prefix}|response_model"
                gen_config = gen_configs[grp]

                if run_query:
                    with status_placeholder:
                        with st.spinner(f"generating response for {grp_names[grp]}"):
                            SS[rkey] = process_query(
                                gen_config,
                                ret_configs[grp],
                                query,
                            )
                    # Remember which provider/model produced this response.
                    SS[mkey] = (gen_config["provider"], gen_config["model_name"])

                if response := SS.get(rkey):
                    # Fall back to the current selection for responses stored
                    # before the provider/model was being recorded.
                    provider, model_name = SS.get(
                        mkey, (gen_config["provider"], gen_config["model_name"])
                    )
                    render_response(
                        response,
                        PROVIDER_MODELS[provider][model_name],
                        provider,
                        gen_config["should_escape_markdown"],
                        gen_config["should_add_legis_urls"],
                        tag=grp_names[grp],
                    )
|
src/legisqa_local/tabs/rag_tab.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Single RAG tab implementation"""
|
| 2 |
+
|
| 3 |
+
import streamlit as st
|
| 4 |
+
from legisqa_local.tabs.base import BaseTab
|
| 5 |
+
from legisqa_local.components.forms import get_generative_config, get_retrieval_config
|
| 6 |
+
from legisqa_local.components.display import render_example_queries, render_response
|
| 7 |
+
from legisqa_local.core.rag import process_query
|
| 8 |
+
from legisqa_local.config.models import PROVIDER_MODELS
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class RAGTab(BaseTab):
    """Tab for running a single RAG query."""

    def __init__(self):
        super().__init__("RAG", "query_rag")

    def render(self):
        """Render the example queries, query form, config panels and response.

        The response is kept in `st.session_state` so it survives reruns. The
        provider and model that produced it are stored alongside, so changing
        the selectors afterwards does not relabel the stored response.
        """
        SS = st.session_state
        render_example_queries()

        with st.form(f"{self.key_prefix}|query_form"):
            query = st.text_area(
                "Enter a query that can be answered with congressional legislation:"
            )
            cols = st.columns(2)
            with cols[0]:
                query_submitted = st.form_submit_button("Submit")
            with cols[1]:
                status_placeholder = st.empty()

        col1, col2 = st.columns(2)
        with col1:
            with st.expander("Generative Config"):
                gen_config = get_generative_config(self.key_prefix)
        with col2:
            with st.expander("Retrieval Config"):
                ret_config = get_retrieval_config(self.key_prefix)

        rkey = f"{self.key_prefix}|response"
        mkey = f"{self.key_prefix}|response_model"

        if query_submitted:
            if query.strip():
                with status_placeholder:
                    with st.spinner("generating response"):
                        SS[rkey] = process_query(gen_config, ret_config, query)
                # Remember which provider/model produced this response.
                SS[mkey] = (gen_config["provider"], gen_config["model_name"])
            else:
                # Do not spend retrieval and LLM calls on a blank query.
                status_placeholder.warning("Please enter a query before submitting.")

        if response := SS.get(rkey):
            # Fall back to the current selection for responses stored before
            # the provider/model was being recorded.
            provider, model_name = SS.get(
                mkey, (gen_config["provider"], gen_config["model_name"])
            )
            render_response(
                response,
                PROVIDER_MODELS[provider][model_name],
                provider,
                gen_config["should_escape_markdown"],
                gen_config["should_add_legis_urls"],
            )

            with st.expander("Debug"):
                st.write(response)
|
src/legisqa_local/utils/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Utility functions and helpers"""
|
src/legisqa_local/utils/formatting.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Document formatting utilities for LegisQA"""
|
| 2 |
+
|
| 3 |
+
from collections import defaultdict
|
| 4 |
+
import json
|
| 5 |
+
|
| 6 |
+
from langchain.schema import Document
|
| 7 |
+
import streamlit as st
|
| 8 |
+
|
| 9 |
+
from legisqa_local.utils.text import get_congress_gov_url, get_sponsor_url, escape_markdown
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def group_docs(docs) -> list[tuple[str, list[Document]]]:
    """Bucket docs by ``legis_id`` and order both the buckets and their contents.

    Within a bucket, docs are ordered by ``metadata["start_index"]`` (ascending).
    Buckets are ordered by how many docs they hold (largest first); ties are
    broken by ``legis_id`` so the result is deterministic.

    Returns:
        A list of ``(legis_id, docs)`` tuples, e.g.::

            [
                (legis_id, start_index sorted docs),  # bucket with the most docs
                ...
                (legis_id, start_index sorted docs),  # bucket with the fewest docs
            ]
    """
    buckets = defaultdict(list)
    for doc in docs:
        buckets[doc.metadata["legis_id"]].append(doc)

    # Build the (legis_id, sorted chunks) pairs in one pass.
    ordered = [
        (legis_id, sorted(members, key=lambda d: d.metadata["start_index"]))
        for legis_id, members in buckets.items()
    ]

    # Most chunks first; legis_id as the tie-breaker.
    ordered.sort(key=lambda pair: (-len(pair[1]), pair[0]))
    return ordered
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def format_docs(docs: list[Document]) -> str:
    """Serialise retrieved docs as a JSON string for use as RAG context.

    Docs are grouped with ``group_docs``; each group becomes one JSON object
    whose bill-level fields are read from the group's first doc and whose
    ``snippets`` list holds the ``page_content`` of every doc in the group.
    """
    records = []
    for _, chunks in group_docs(docs):
        # Bill-level metadata is taken from the first chunk of the group.
        meta = chunks[0].metadata
        records.append(
            {
                "legis_id": meta["legis_id"],
                "title": meta["title"],
                "introduced_date": meta["introduced_date"],
                "sponsor": meta["sponsor_full_name"],
                "snippets": [chunk.page_content for chunk in chunks],
            }
        )
    return json.dumps(records, indent=4)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def render_doc_grp(legis_id: str, doc_grp: list[Document]):
    """Render one group of chunks that come from the same piece of legislation.

    The expander label shows the chunk count, legis id, title, a congress.gov
    link and a sponsor link (all read from the first doc's metadata). The
    expander body lists every chunk prefixed with its ``start_index``, with
    markdown special characters escaped.

    Args:
        legis_id: Id of the group; not used by the body, which reads the id
            from the first doc's metadata instead.
        doc_grp: Non-empty list of docs belonging to the same legislation.
    """
    meta = doc_grp[0].metadata

    bill_url = get_congress_gov_url(
        meta["congress_num"],
        meta["legis_type"],
        meta["legis_num"],
    )

    label = "{} chunks from {}\n\n{}\n\n{}\n\n[{} ({}) ]({})".format(
        len(doc_grp),
        meta["legis_id"],
        meta["title"],
        f"[congress.gov]({bill_url})",
        meta["sponsor_full_name"],
        meta["sponsor_bioguide_id"],
        get_sponsor_url(meta["sponsor_bioguide_id"]),
    )

    chunk_texts = []
    for doc in doc_grp:
        prefix = "[start_index={}] ".format(int(doc.metadata["start_index"]))
        chunk_texts.append(prefix + doc.page_content)

    with st.expander(label):
        st.write(escape_markdown("\n\n...\n\n".join(chunk_texts)))
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def render_retrieved_chunks(docs: list[Document], tag: str | None = None):
    """Render every retrieved chunk, grouped by legislation, in a bordered box.

    Args:
        docs: Retrieved documents to display.
        tag: Optional label appended to the heading in parentheses.
    """
    with st.container(border=True):
        grouped = group_docs(docs)

        heading = "Retrieved Chunks" if tag is None else f"Retrieved Chunks ({tag})"
        st.write(f"{heading}\n\nleft click to expand, right click to follow links")

        for group_id, members in grouped:
            render_doc_grp(group_id, members)
|
src/legisqa_local/utils/text.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Text processing utilities for LegisQA"""
|
| 2 |
+
|
| 3 |
+
import re
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
# Maps a short legislation-type code (the middle part of a legis id such as
# "118-hr-1234") to the path segment that get_congress_gov_url places in a
# congress.gov bill URL.
CONGRESS_GOV_TYPE_MAP = {
    "hconres": "house-concurrent-resolution",
    "hjres": "house-joint-resolution",
    "hr": "house-bill",
    "hres": "house-resolution",
    "s": "senate-bill",
    "sconres": "senate-concurrent-resolution",
    "sjres": "senate-joint-resolution",
    "sres": "senate-resolution",
}
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def escape_markdown(text: str) -> str:
    """Return *text* with every markdown special character backslash-escaped.

    Each of the characters ``\\ ` * _ { } [ ] ( ) # + - . ! $`` is prefixed with
    a single backslash; all other characters are passed through unchanged.
    """
    specials = r"\`*_{}[]()#+-.!$"
    # One translation pass: each special char maps to "\" + itself.
    escape_table = str.maketrans({ch: "\\" + ch for ch in specials})
    return text.translate(escape_table)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def get_sponsor_url(bioguide_id: str) -> str:
    """Return the bioguide.congress.gov biography URL for *bioguide_id*."""
    base_url = "https://bioguide.congress.gov/search/bio"
    return f"{base_url}/{bioguide_id}"
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def get_congress_gov_url(congress_num: int, legis_type: str, legis_num: int) -> str:
    """Generate the Congress.gov URL for a piece of legislation.

    Args:
        congress_num: Congress number (anything ``int()`` accepts, e.g. 118 or "118").
        legis_type: Short type code; must be a key of ``CONGRESS_GOV_TYPE_MAP``.
        legis_num: Bill/resolution number (anything ``int()`` accepts).

    Returns:
        A URL of the form
        ``https://www.congress.gov/bill/<ordinal>-congress/<type>/<number>``.

    Raises:
        KeyError: If *legis_type* is not in ``CONGRESS_GOV_TYPE_MAP``.
        ValueError: If *congress_num* or *legis_num* cannot be converted to int.
    """
    lt = CONGRESS_GOV_TYPE_MAP[legis_type]
    congress = int(congress_num)

    # English ordinal suffix: 11-13 (and the rest of the teens) always take
    # "th"; otherwise the last digit decides (101st, 102nd, 103rd, 113th, ...).
    if 10 <= congress % 100 <= 20:
        suffix = "th"
    else:
        suffix = {1: "st", 2: "nd", 3: "rd"}.get(congress % 10, "th")

    return f"https://www.congress.gov/bill/{congress}{suffix}-congress/{lt}/{int(legis_num)}"
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def legis_id_to_link(legis_id: str) -> str:
    """Turn a ``<congress>-<type>-<number>`` legislation id into a Congress.gov URL.

    Raises:
        ValueError: If *legis_id* does not split into exactly three
            ``-``-separated parts.
    """
    congress_part, type_part, number_part = legis_id.split("-")
    return get_congress_gov_url(
        congress_num=congress_part,
        legis_type=type_part,
        legis_num=number_part,
    )
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def legis_id_match_to_link(matchobj):
    """Convert a regex match of a legislation id into a markdown link.

    Intended as the replacement callback for ``re.sub``.

    Args:
        matchobj: ``re.Match`` whose full match is a legislation id such as
            ``"118-hr-1234"``.

    Returns:
        ``"[<id>](<congress.gov url>)"``, or the matched text unchanged when the
        id's type code is not a known legislation type.
    """
    mstring = matchobj.group(0)
    try:
        url = legis_id_to_link(mstring)
    except KeyError:
        # The id pattern accepts any lowercase type code; an unknown one has no
        # Congress.gov URL, so leave the text as-is instead of aborting the
        # whole substitution.
        return mstring
    return f"[{mstring}]({url})"
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def replace_legis_ids_with_urls(text: str) -> str:
    """Return *text* with legislation ids rewritten as markdown links.

    An id is ``113``-``118``, a dash, a lowercase type code, a dash and a
    1-5 digit number; each occurrence is replaced by the value of
    ``legis_id_match_to_link`` for that match.
    """
    legis_id_regex = re.compile("11[345678]-[a-z]+-\\d{1,5}")
    return legis_id_regex.sub(legis_id_match_to_link, text)
|
src/legisqa_local/utils/usage.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Usage tracking utilities for LegisQA"""
|
| 2 |
+
|
| 3 |
+
import streamlit as st
|
| 4 |
+
from langchain_core.messages import AIMessage
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def get_token_usage_for_provider(aimessage: AIMessage, model_info: dict, provider: str):
    """Compute token counts and estimated cost for one model response.

    Args:
        aimessage: Message whose ``usage_metadata`` carries ``input_tokens`` and
            ``output_tokens``. If ``usage_metadata`` is missing/``None`` (or a
            count is absent), that count is treated as 0.
        model_info: Model description; ``model_info["cost"]["pmi"]`` and
            ``["pmo"]`` are multiplied by ``tokens * 1e-6``, i.e. they are
            prices per million input/output tokens.
        provider: Unused; kept for interface compatibility.

    Returns:
        Dict with ``input_tokens``, ``output_tokens`` and ``cost``.
    """
    # Not every backend reports usage; fall back to zero counts rather than
    # failing on a None subscript.
    usage = aimessage.usage_metadata or {}
    input_tokens = usage.get("input_tokens", 0)
    output_tokens = usage.get("output_tokens", 0)

    pricing = model_info["cost"]
    cost = (
        input_tokens * 1e-6 * pricing["pmi"]
        + output_tokens * 1e-6 * pricing["pmo"]
    )
    return {
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
        "cost": cost,
    }
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def get_token_usage(aimessage: AIMessage, model_info: dict, provider: str):
    """Return token usage and cost for *aimessage*.

    Thin wrapper: every provider shares the same calculation, so this simply
    delegates to ``get_token_usage_for_provider`` with the same arguments.
    """
    usage = get_token_usage_for_provider(aimessage, model_info, provider)
    return usage
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def display_api_usage(
    aimessage: AIMessage, model_info: dict, provider: str, tag: str | None = None
):
    """Show token counts, cost and raw message metadata in a bordered container.

    Args:
        aimessage: Model response to report on.
        model_info: Model description passed through to ``get_token_usage``.
        provider: Provider name passed through to ``get_token_usage``.
        tag: Optional label appended to the heading in parentheses.
    """
    with st.container(border=True):
        heading = "API Usage" if tag is None else f"API Usage ({tag})"
        st.write(heading)

        usage = get_token_usage(aimessage, model_info, provider)
        input_col, output_col, cost_col = st.columns(3)
        with input_col:
            st.metric("Input Tokens", usage["input_tokens"])
        with output_col:
            st.metric("Output Tokens", usage["output_tokens"])
        with cost_col:
            st.metric("Cost", f"${usage['cost']:.4f}")

        with st.expander("AIMessage Metadata"):
            # Everything except the (potentially long) message content.
            metadata = dict(aimessage.dict())
            metadata.pop("content", None)
            st.write(metadata)
|