|
import warnings |
|
import streamlit as st |
|
warnings.filterwarnings("ignore", category=UserWarning, module="streamlit") |
|
import pandas as pd |
|
import json |
|
import os |
|
from huggingface_hub import HfApi, login |
|
from streamlit_cookies_manager import EncryptedCookieManager |
|
|
|
# Configure the browser tab (title, icon) and use the full-width layout before
# any other Streamlit output is produced by this script.
st.set_page_config(
    layout="wide",
    page_icon="👋",
    page_title="Holistic AI - ML Verticals papers",
)
|
|
|
# Maps the sidebar "paper type" choice to the JSON file holding its grouped results.
_PAPER_TYPE_FILES = {
    "Metrics": "grouped_metrics.json",
    "Mitigators": "grouped_mitigators.json",
}

# Inline CSS applied to every header and body cell of the paper tables.
_CELL_STYLE = "border: 1px solid #ddd; padding: 8px;"


def _list_namespaces(repo_path):
    """Return the non-hidden sub-directories of ``repo_path``, shortest name first.

    Plain files (e.g. a README at the repo root) are excluded because each
    namespace is expected to be a directory containing the grouped JSON files.
    Names of equal length are ordered alphabetically so the list is stable.
    """
    names = [
        name
        for name in os.listdir(repo_path)
        if not name.startswith(".") and os.path.isdir(os.path.join(repo_path, name))
    ]
    return sorted(names, key=lambda name: (len(name), name))


def _sort_by_date_desc(records):
    """Return ``records`` ordered by ``record['metadata']['date']``, newest first.

    Dates are parsed with ``pd.to_datetime`` (unparseable values raise);
    missing dates (NaT) sort last, as with ``sort_values`` defaults.
    An empty input returns an empty list instead of failing.
    """
    if not records:
        return []
    dates = pd.to_datetime(pd.Series([record["metadata"]["date"] for record in records]))
    # The Series has a RangeIndex, so the sorted index is the list of positions.
    return [records[pos] for pos in dates.sort_values(ascending=False).index]


def _group_html(group, records):
    """Build the HTML card for one group: heading, recommendation text and paper table.

    NOTE(review): values are interpolated without HTML escaping; this assumes the
    dataset content is trusted (it is rendered with ``unsafe_allow_html=True``).
    """
    rows = "".join(
        f'<tr><td style="{_CELL_STYLE}">{index}</td>'
        # presumably metadata['id'] is the paper URL — it is used as the link target.
        f'<td style="{_CELL_STYLE}"><a href="{paper["metadata"]["id"]}" target="_blank">{paper["title"]}</a></td>'
        f'<td style="{_CELL_STYLE}">{paper["metadata"]["date"]}</td></tr>'
        for index, paper in enumerate(records, start=1)
    )

    # Titles look like "<prefix>: <heading>"; show only the part after the first
    # colon, and fall back to the whole title when there is no colon.
    _, colon, rest = group["title"].partition(":")
    heading = rest.strip() if colon else group["title"].strip()

    # Emitted flush-left with no blank lines so the Markdown renderer sees one
    # contiguous HTML block.
    return "\n".join([
        '<div style="border: 1px solid #ccc; padding: 10px; margin: 10px 0; border-radius: 5px; width: 100%;">',
        f"<p><b>{heading}</b></p>",
        f"<p>{group['recommendation']}</p>",
        '<table style="width: 100%; border-collapse: collapse;">',
        "<thead>",
        "<tr>",
        f'<th style="{_CELL_STYLE}">Index</th>',
        f'<th style="{_CELL_STYLE}">Paper</th>',
        # NOTE(review): labelled "Year" but shows metadata['date'] verbatim — confirm format.
        f'<th style="{_CELL_STYLE}">Year</th>',
        "</tr>",
        "</thead>",
        f"<tbody>{rows}</tbody>",
        "</table>",
        "</div>",
    ])


def program():
    """Render the paper browser for an authenticated user.

    Downloads the Hugging Face dataset snapshot, lets the user pick a namespace,
    a paper type (Metrics / Mitigators) and a task in the sidebar, then renders
    one card per recommendation group with its papers sorted newest first.
    """
    st.title("Papers")

    dataset_name = "holistic-ai/mitigation_ml_bias_strategies"
    token = os.getenv("HF_TOKEN")

    api = HfApi()
    # login(None) falls back to an interactive prompt that a Streamlit server
    # cannot answer, so only authenticate when a token is configured.
    if token:
        login(token)

    repo_path = api.snapshot_download(repo_id=dataset_name, repo_type="dataset")
    namespaces = _list_namespaces(repo_path)

    st.sidebar.title("Namespaces")
    selected_namespace = st.sidebar.selectbox("Select Namespace", namespaces)
    selected_paper_type = st.sidebar.selectbox("Select Paper Type", list(_PAPER_TYPE_FILES))

    if not selected_namespace:
        return

    data_path = os.path.join(repo_path, selected_namespace, _PAPER_TYPE_FILES[selected_paper_type])
    with open(data_path, encoding="utf-8") as file:
        data = json.load(file)

    st.sidebar.title("Tasks")
    selected_task = st.sidebar.selectbox("Select a Task", list(data))
    if not selected_task:
        return

    st.header(selected_task)
    results = data[selected_task]
    # Keyed by str(id) because group ids are parsed out of a comma-separated string.
    records_by_id = {str(record["id"]): record for record in results["recommendations"]}

    for group in results["groups"]:
        # Drop empty entries produced by stray or trailing commas.
        ids = [part.strip() for part in group["ids"].split(",")]
        ids = [paper_id for paper_id in ids if paper_id]

        missing = [paper_id for paper_id in ids if paper_id not in records_by_id]
        if missing:
            st.warning(f"Unknown recommendation ids in group {group['title']!r}: {', '.join(missing)}")

        records = _sort_by_date_desc([records_by_id[paper_id] for paper_id in ids if paper_id in records_by_id])
        st.markdown(_group_html(group, records), unsafe_allow_html=True)
|
|
|
# Shared password checked by the login form in main(); read once per script run.
SECRET_KEY = os.getenv('SECRET_KEY')

# Password handed to the encrypted cookie manager. Passing None (variable unset)
# would presumably only fail later inside the library with an opaque error, so
# stop early with a clear message instead.
_COOKIES_PASSWORD = os.getenv('COOKIES_PASSWORD')
if not _COOKIES_PASSWORD:
    st.error("COOKIES_PASSWORD is not configured; cannot start the app.")
    st.stop()

cookies = EncryptedCookieManager(
    prefix="login",
    password=_COOKIES_PASSWORD,
)

# Halt this run until the cookie manager reports ready (presumably once the
# browser component has returned the existing cookies).
if not cookies.ready():
    st.stop()
|
|
|
def main():
    """Show the password gate, or the paper browser once the session is authenticated.

    Authentication is remembered in the encrypted ``authenticated`` cookie. On a
    correct password the cookie is set and the script is rerun so the browser
    view is rendered.
    """
    import hmac  # local import: only needed for the password comparison below

    st.title("Holistic AI - ML Papers")

    if cookies.get("authenticated"):
        program()
        return

    user_key = st.text_input("Password:", type="password")

    if st.button("Login"):
        # An unset or empty SECRET_KEY must never admit an empty password, and a
        # constant-time comparison reduces timing side channels. Bytes are used
        # because compare_digest rejects non-ASCII str arguments.
        if SECRET_KEY and hmac.compare_digest(user_key.encode("utf-8"), SECRET_KEY.encode("utf-8")):
            cookies["authenticated"] = "True"
            # st.rerun replaced the deprecated st.experimental_rerun; support both.
            rerun = getattr(st, "rerun", None) or st.experimental_rerun
            rerun()
        else:
            st.error("Access not granted. Incorrect Password.")


if __name__ == "__main__":
    main()