File size: 3,567 Bytes
ac2020e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
"""Document formatting utilities for LegisQA"""

from collections import defaultdict
import json

from langchain.schema import Document
import streamlit as st

from legisqa_local.utils.text import get_congress_gov_url, get_sponsor_url, escape_markdown


def group_docs(docs) -> list[tuple[str, list[Document]]]:
    """Bucket docs by legis_id and rank the buckets by size.

    Each doc is placed in the bucket named by its ``metadata["legis_id"]``.
    Within a bucket the docs are ordered by ``metadata["start_index"]``
    (ascending). The buckets themselves are ordered by how many docs they
    hold (largest first), with ties broken by legis_id (ascending) so the
    result is deterministic.

    Returns:
        A list of ``(legis_id, docs)`` tuples, e.g.
        [
            (legis_id, start_index sorted docs),  # bucket with the most docs
            ...
            (legis_id, start_index sorted docs),  # bucket with the fewest docs
        ]
        An empty ``docs`` iterable yields an empty list.
    """
    buckets = defaultdict(list)
    for doc in docs:
        buckets[doc.metadata["legis_id"]].append(doc)

    # One (legis_id, ordered chunks) pair per bucket; sorted() is stable, so
    # chunks sharing a start_index keep their original relative order.
    grouped = [
        (bucket_id, sorted(chunks, key=lambda chunk: chunk.metadata["start_index"]))
        for bucket_id, chunks in buckets.items()
    ]

    # Negated length gives a descending size order; legis_id breaks ties.
    grouped.sort(key=lambda pair: (-len(pair[1]), pair[0]))
    return grouped


def format_docs(docs: list[Document]) -> str:
    """Serialize docs into a JSON string for use as RAG context.

    The docs are grouped with ``group_docs``; every group becomes one JSON
    object whose bill-level fields (legis_id, title, introduced_date,
    sponsor) are read from the metadata of the group's first doc and whose
    ``snippets`` list holds the ``page_content`` of every doc in the group.

    Returns:
        A JSON array (indented by 4 spaces) with one object per group.
    """
    records = []
    for _, chunks in group_docs(docs):
        # Bill-level fields are taken from the first chunk of the group.
        head = chunks[0].metadata
        records.append(
            {
                "legis_id": head["legis_id"],
                "title": head["title"],
                "introduced_date": head["introduced_date"],
                "sponsor": head["sponsor_full_name"],
                "snippets": [chunk.page_content for chunk in chunks],
            }
        )
    return json.dumps(records, indent=4)


def render_doc_grp(legis_id: str, doc_grp: list[Document]):
    """Render one legislation's chunks inside a Streamlit expander.

    The expander label is built from the first doc's metadata: the chunk
    count, legis_id, title, a congress.gov link, and a sponsor link. The
    expander body lists every chunk, each prefixed with its start_index and
    separated by an ellipsis line, passed through ``escape_markdown``.

    Args:
        legis_id: Identifier of the group. Not read by this function; the
            label uses the legis_id stored in the first doc's metadata.
        doc_grp: Docs belonging to the same legislation. Must be non-empty,
            since the first element is indexed directly.
    """
    meta = doc_grp[0].metadata

    bill_url = get_congress_gov_url(
        meta["congress_num"],
        meta["legis_type"],
        meta["legis_num"],
    )
    bill_link = f"[congress.gov]({bill_url})"

    label_template = "{} chunks from {}\n\n{}\n\n{}\n\n[{} ({}) ]({})"
    label = label_template.format(
        len(doc_grp),
        meta["legis_id"],
        meta["title"],
        bill_link,
        meta["sponsor_full_name"],
        meta["sponsor_bioguide_id"],
        get_sponsor_url(meta["sponsor_bioguide_id"]),
    )

    # Prefix each chunk with its integer offset.
    snippets = []
    for chunk in doc_grp:
        offset = int(chunk.metadata["start_index"])
        snippets.append("[start_index={}] ".format(offset) + chunk.page_content)

    with st.expander(label):
        st.write(escape_markdown("\n\n...\n\n".join(snippets)))


def render_retrieved_chunks(docs: list[Document], tag: str | None = None):
    """Render every retrieved chunk, grouped by legislation, in a bordered box.

    Inside a bordered Streamlit container this writes a heading (including
    ``tag`` in parentheses when one is given) and then one expander per
    group, in the order returned by ``group_docs``.

    Args:
        docs: Retrieved docs to display.
        tag: Optional label appended to the heading; omitted when None.
    """
    with st.container(border=True):
        grouped = group_docs(docs)

        heading = (
            "Retrieved Chunks\n\nleft click to expand, right click to follow links"
            if tag is None
            else f"Retrieved Chunks ({tag})\n\nleft click to expand, right click to follow links"
        )
        st.write(heading)

        for group_id, chunks in grouped:
            render_doc_grp(group_id, chunks)