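# NOTE(review): the three lines below look like residue from a hosting-site
# file header (author caption, commit message, short commit hash) rather than
# module content; kept as comments so the file parses. Confirm and remove.
# gabrielaltay's picture
# update
# ac2020e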
"""Document formatting utilities for LegisQA"""
from collections import defaultdict
import json
from langchain.schema import Document
import streamlit as st
from legisqa_local.utils.text import get_congress_gov_url, get_sponsor_url, escape_markdown
def group_docs(docs) -> list[tuple[str, list[Document]]]:
    """Bucket docs by ``legis_id`` and order both the buckets and their contents.

    Every doc is placed in the bucket named by ``doc.metadata["legis_id"]``.
    Within a bucket, docs are ordered by ``doc.metadata["start_index"]``
    (ascending). Buckets are ordered by how many docs they hold (largest
    first); buckets of equal size are ordered by ``legis_id`` so the result
    is deterministic.

    Args:
        docs: iterable of objects exposing a ``metadata`` mapping with
            ``"legis_id"`` and ``"start_index"`` entries.

    Returns:
        A list of ``(legis_id, docs_sorted_by_start_index)`` tuples, from the
        bucket with the most docs to the bucket with the fewest.

    Raises:
        KeyError: if a doc's metadata lacks ``"legis_id"`` or ``"start_index"``.
    """

    def _chunk_offset(chunk):
        return chunk.metadata["start_index"]

    def _bucket_rank(item):
        bill_id, chunks = item
        # negative size => bigger buckets first; bill_id breaks ties
        return (-len(chunks), bill_id)

    # collect chunks per bill
    buckets = defaultdict(list)
    for chunk in docs:
        buckets[chunk.metadata["legis_id"]].append(chunk)

    # order the chunks of each bill by their position in the source text
    for chunks in buckets.values():
        chunks.sort(key=_chunk_offset)

    # order the bills themselves
    return sorted(buckets.items(), key=_bucket_rank)
def format_docs(docs: list[Document]) -> str:
    """Serialize retrieved docs to a JSON string for use as RAG context.

    The docs are grouped with ``group_docs``; each group becomes one JSON
    object whose header fields (``legis_id``, ``title``, ``introduced_date``,
    ``sponsor``) are read from the metadata of the group's first doc and whose
    ``snippets`` list holds the ``page_content`` of every doc in the group, in
    group order.

    Args:
        docs: documents to serialize.

    Returns:
        A JSON array (indented by 4 spaces) with one object per group.
    """

    def _summarize(chunks):
        # header fields come from the first chunk of the group
        head = chunks[0].metadata
        return {
            "legis_id": head["legis_id"],
            "title": head["title"],
            "introduced_date": head["introduced_date"],
            "sponsor": head["sponsor_full_name"],
            "snippets": [chunk.page_content for chunk in chunks],
        }

    payload = [_summarize(chunks) for _, chunks in group_docs(docs)]
    return json.dumps(payload, indent=4)
def render_doc_grp(legis_id: str, doc_grp: list[Document]):
    """Draw one Streamlit expander for the chunks of a single piece of legislation.

    The expander label is built from the first doc's metadata: chunk count,
    ``legis_id``, title, a congress.gov link, and a sponsor link. The expander
    body lists each chunk prefixed with its ``start_index``, separated by
    ``...`` lines, after passing the joined text through ``escape_markdown``.

    Args:
        legis_id: identifier of the group. Not read by this function; the
            label uses ``doc_grp[0].metadata["legis_id"]`` instead.
        doc_grp: non-empty list of docs belonging to the same legislation.
    """
    meta = doc_grp[0].metadata

    bill_url = get_congress_gov_url(
        meta["congress_num"],
        meta["legis_type"],
        meta["legis_num"],
    )
    bill_link = f"[congress.gov]({bill_url})"

    # gather label pieces in the same order the label displays them
    chunk_count = len(doc_grp)
    bill_id = meta["legis_id"]
    title = meta["title"]
    sponsor_name = meta["sponsor_full_name"]
    sponsor_id = meta["sponsor_bioguide_id"]
    sponsor_url = get_sponsor_url(meta["sponsor_bioguide_id"])

    label = (
        f"{chunk_count} chunks from {bill_id}"
        f"\n\n{title}"
        f"\n\n{bill_link}"
        f"\n\n[{sponsor_name} ({sponsor_id}) ]({sponsor_url})"
    )

    # one entry per chunk, tagged with its integer start offset
    snippets = []
    for chunk in doc_grp:
        offset = int(chunk.metadata["start_index"])
        snippets.append(f"[start_index={offset}] " + chunk.page_content)

    with st.expander(label):
        st.write(escape_markdown("\n\n...\n\n".join(snippets)))
def render_retrieved_chunks(docs: list[Document], tag: str | None = None):
    """Draw a bordered Streamlit container listing every retrieved chunk.

    Inside the container a heading is written first (including ``tag`` in
    parentheses when one is given), followed by one expander per legislation
    group, in the order produced by ``group_docs``.

    Args:
        docs: retrieved documents to display.
        tag: optional label appended to the heading; omitted when ``None``.
    """
    with st.container(border=True):
        grouped = group_docs(docs)

        heading = "Retrieved Chunks" if tag is None else f"Retrieved Chunks ({tag})"
        st.write(f"{heading}\n\nleft click to expand, right click to follow links")

        for bill_id, chunks in grouped:
            render_doc_grp(bill_id, chunks)