Spaces:
Sleeping
Sleeping
| import os | |
| import random | |
| import time | |
| from datasets import load_dataset | |
| from openai import OpenAI | |
| import pandas as pd | |
| import streamlit as st | |
# Render the app with Streamlit's "wide" page layout.
st.set_page_config(layout="wide")

# Short legislation type code -> path segment used when building
# congress.gov bill URLs (see get_congress_gov_url).
CONGRESS_GOV_TYPE_MAP = {
    # House measures
    "hconres": "house-concurrent-resolution",
    "hjres": "house-joint-resolution",
    "hr": "house-bill",
    "hres": "house-resolution",
    # Senate measures
    "s": "senate-bill",
    "sconres": "senate-concurrent-resolution",
    "sjres": "senate-joint-resolution",
    "sres": "senate-resolution",
}
@st.cache_data
def get_data(pinned_legis_id="118-s-3207", sample_size=100):
    """Load the bill dataset and return a small working subset.

    All splits of the ``hyperdemocracy/us-congress`` ``unified_v1`` dataset
    are concatenated, a ``text`` column is derived from the first entry of
    each row's ``textversions``, and rows without text are dropped.

    The result is cached with ``st.cache_data`` so that script reruns reuse
    the same frame instead of reloading the dataset and drawing a different
    random sample (which would change the bill list under the user).

    Args:
        pinned_legis_id: ``legis_id`` of a bill that is always included and
            placed first in the result.
        sample_size: Maximum number of additional, randomly sampled bills.

    Returns:
        A ``pandas.DataFrame`` with the pinned bill row(s) followed by the
        random sample.
    """
    splits = load_dataset("hyperdemocracy/us-congress", "unified_v1")
    bills = pd.concat([split.to_pandas() for split in splits.values()])

    # Use the first text version when present; rows with none get "" and
    # are filtered out just below.
    bills["text"] = bills["textversions"].apply(
        lambda versions: versions[0]["text_v1"] if len(versions) > 0 else ""
    )
    bills = bills[bills["text"].str.len() > 0]

    is_pinned = bills["legis_id"] == pinned_legis_id
    pinned = bills[is_pinned]
    # Sample only from the remaining bills so the pinned bill cannot appear
    # twice, and cap n so a small dataset does not raise ValueError.
    others = bills[~is_pinned]
    sampled = others.sample(n=min(sample_size, len(others)))
    return pd.concat([pinned, sampled])
# Translation table mapping every Markdown-special character to its
# backslash-escaped form. The backslash itself is included so that
# backslashes already present in the text are escaped too.
_MD_ESCAPE_TABLE = str.maketrans(
    {char: "\\" + char for char in "\\`*_{}[]()#+-.!$"}
)


def escape_markdown(text):
    """Return *text* with Markdown-special characters backslash-escaped.

    Escaped characters are: backslash, backtick, ``*``, ``_``, ``{``, ``}``,
    ``[``, ``]``, ``(``, ``)``, ``#``, ``+``, ``-``, ``.``, ``!`` and ``$``.

    Args:
        text: The string to escape.

    Returns:
        A new string in which each special character is preceded by a
        backslash; all other characters are unchanged.
    """
    # A single translate pass escapes each original character exactly once,
    # so the backslashes it inserts are never escaped a second time.
    return text.translate(_MD_ESCAPE_TABLE)
def get_sponsor_url(bioguide_id):
    """Build the bioguide.congress.gov biography URL for *bioguide_id*."""
    base_url = "https://bioguide.congress.gov/search/bio"
    return "{}/{}".format(base_url, bioguide_id)
def _ordinal_suffix(number):
    """Return the English ordinal suffix ("st", "nd", "rd" or "th") for *number*."""
    number = abs(int(number))
    # 11, 12 and 13 (and 111, 112, ...) take "th" despite ending in 1/2/3.
    if 11 <= number % 100 <= 13:
        return "th"
    return {1: "st", 2: "nd", 3: "rd"}.get(number % 10, "th")


def get_congress_gov_url(congress_num, legis_type, legis_num):
    """Build the congress.gov URL for a piece of legislation.

    Args:
        congress_num: Congress number (e.g. ``118``); must be convertible
            with ``int()``.
        legis_type: Short type code; must be a key of
            ``CONGRESS_GOV_TYPE_MAP`` (e.g. ``"hr"``, ``"sjres"``).
        legis_num: Number of the bill or resolution within its type.

    Returns:
        A URL of the form
        ``https://www.congress.gov/bill/<N><suffix>-congress/<type>/<num>``.

    Raises:
        KeyError: If *legis_type* is not in ``CONGRESS_GOV_TYPE_MAP``.
    """
    type_slug = CONGRESS_GOV_TYPE_MAP[legis_type]
    # Use the correct ordinal ("101st", "102nd", "103rd", "118th") rather
    # than always appending "th".
    suffix = _ordinal_suffix(congress_num)
    return (
        f"https://www.congress.gov/bill/"
        f"{congress_num}{suffix}-congress/{type_slug}/{legis_num}"
    )
def show_bill(bdict):
    """Render a bill's metadata, summary and full text with Streamlit.

    Args:
        bdict: Bill record as a dict. Reads the keys ``congress_num``,
            ``legis_type``, ``legis_num``, ``legis_id``, ``text`` and
            ``billstatus_json``; from the latter it reads ``sponsors``,
            ``title``, ``introduced_date``, ``policy_area``, ``subjects``
            and ``summaries``.

    Raises:
        KeyError: If a required key is missing, or ``legis_type`` is not a
            known type code.
    """
    status = bdict["billstatus_json"]
    text = bdict["text"]

    bill_url = get_congress_gov_url(
        bdict["congress_num"],
        bdict["legis_type"],
        bdict["legis_num"],
    )

    st.header("Metadata")
    st.write("**Bill ID**: [{}]({})".format(bdict["legis_id"], bill_url))

    # Guard against bills without a sponsor instead of failing on [0].
    # len() is used rather than truthiness, matching the summaries check
    # below (presumably these sequences can be numpy arrays, whose truth
    # value is ambiguous -- TODO confirm).
    sponsors = status["sponsors"]
    if sponsors is not None and len(sponsors) > 0:
        primary_sponsor = sponsors[0]
        sponsor_url = get_sponsor_url(primary_sponsor["bioguide_id"])
        st.write(
            "**Sponsor**: [{}]({})".format(primary_sponsor["full_name"], sponsor_url)
        )
    else:
        st.write("**Sponsor**: Not Available")

    st.write("**Title**: {}".format(status["title"]))
    st.write("**Introduced**: {}".format(status["introduced_date"]))
    st.write("**Policy Area**: {}".format(status["policy_area"]))
    st.write("**Subjects**: {}".format(status["subjects"]))
    st.write("**Character Count**: {}".format(len(text)))
    # Rough estimate of one token per four characters; integer division
    # keeps the displayed count a whole number.
    st.write("**Estimated Tokens**: {}".format(len(text) // 4))

    st.header("Summary")
    summaries = status["summaries"]
    if summaries is not None and len(summaries) > 0:
        # The whole first summary entry is written as-is.
        # NOTE(review): the entry appears to carry its body under a "text"
        # key, possibly as HTML -- confirm before rendering only that field.
        st.write(summaries[0])
    else:
        st.write("Not Available")

    st.header("Text")
    # Escape Markdown control characters so the raw bill text is shown
    # literally rather than being interpreted as formatting.
    st.markdown(escape_markdown(text))
# ---------------------------------------------------------------------------
# Session-state defaults (only set on the first run of a session).
# ---------------------------------------------------------------------------
for _state_key, _state_default in (
    ("messages", []),
    ("openai_model", "gpt-3.5-turbo-0125"),
    ("openai_api_key", ""),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default

df = get_data()

# ---------------------------------------------------------------------------
# Sidebar: configuration, bill selection and debug views.
# ---------------------------------------------------------------------------
with st.sidebar:
    # Placeholder at the top of the sidebar, filled later with a warning
    # when the API key is missing or malformed.
    error_header = st.empty()

    st.header("Configuration")
    st.text_input(
        label="OpenAI API Key:",
        help="Required for OpenAI Models",
        type="password",
        key="openai_api_key",
    )
    st.write("You can create an OpenAI API key [here](https://platform.openai.com/account/api-keys)")

    MODELS = ["gpt-3.5-turbo-0125", "gpt-4-0125-preview"]
    st.selectbox("Model Name", MODELS, key="openai_model")

    LEGIS_IDS = df["legis_id"].to_list()
    st.selectbox("Legis ID", LEGIS_IDS, key="legis_id")
    bdict = df[df["legis_id"] == st.session_state["legis_id"]].iloc[0].to_dict()

    # The system prompt below embeds the selected bill's text, so a chat
    # history about a different bill would be sent with the wrong context.
    # Start a fresh conversation whenever the selected bill changes.
    if (
        "messages_legis_id" not in st.session_state
        or st.session_state["messages_legis_id"] != st.session_state["legis_id"]
    ):
        st.session_state["messages"] = []
        st.session_state["messages_legis_id"] = st.session_state["legis_id"]

    if st.button("Clear Messages"):
        st.session_state["messages"] = []

    st.header("Debug")
    with st.expander("Show Messages"):
        st.write(st.session_state["messages"])
    with st.expander("Show Bill Dictionary"):
        st.write(bdict)

# System prompt: instructions followed by the full text of the selected bill.
system_message = {
    "role": "system",
    "content": "You are a helpful legislative question answering assistant. Use the following legislative text to help answer user questions.\n\n" + bdict["text"],
}

# ---------------------------------------------------------------------------
# Main area: bill details and chat.
# ---------------------------------------------------------------------------
with st.expander("Show Bill Details"):
    with st.container(height=600):
        show_bill(bdict)

# Replay the conversation so far.
for message in st.session_state["messages"]:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

if prompt := st.chat_input("How can I help you understand this bill?"):
    api_key = st.session_state["openai_api_key"]

    # Validate the key before echoing the prompt; st.stop() ends this run.
    if api_key == "":
        error_header.warning("Enter API key to chat")
        st.stop()
    if not api_key.startswith("sk-"):
        error_header.warning("Enter valid API key to chat")
        st.stop()

    client = OpenAI(api_key=api_key)

    with st.chat_message("user"):
        st.markdown(prompt)
    st.session_state["messages"].append({"role": "user", "content": prompt})

    with st.chat_message("assistant"):
        # Send the system prompt plus the whole history, keeping only the
        # role/content fields of each stored message.
        stream = client.chat.completions.create(
            model=st.session_state["openai_model"],
            messages=[system_message] + [
                {"role": msg["role"], "content": msg["content"]}
                for msg in st.session_state["messages"]
            ],
            temperature=0.0,
            stream=True,
        )
        response = st.write_stream(stream)
    st.session_state["messages"].append({"role": "assistant", "content": response})