|
import json |
|
import math |
|
import os |
|
import random |
|
|
|
import pandas as pd |
|
import streamlit as st |
|
from huggingface_hub import HfApi |
|
|
|
# --- App + data-source configuration ---------------------------------------
st.set_page_config(page_title="Knesset Plenums Dataset Preview", layout="wide")

# Resolve the dataset repo id from the Space environment:
# SPACE_AUTHOR_NAME (set by HF Spaces) wins, then REPO_OWNER, then "ivrit-ai".
fallback_dataset_repo_owner = os.environ.get("REPO_OWNER", "ivrit-ai")
dataset_repo_owner = os.environ.get("SPACE_AUTHOR_NAME", fallback_dataset_repo_owner)
dataset_repo_name = os.environ.get("DATASET_REPO_NAME", "knesset-plenums")
repo_id = f"{dataset_repo_owner}/{dataset_repo_name}"

# Authenticated client; HF_TOKEN must exist in Streamlit secrets
# (st.secrets[...] raises if missing — the app cannot run anonymously).
hf_api = HfApi(token=st.secrets["HF_TOKEN"])

# Download the dataset-wide manifest listing every plenum sample.
manifest_file = hf_api.hf_hub_download(repo_id, "manifest.csv", repo_type="dataset")

# NOTE(review): assumes manifest.csv has at least the columns
# "duration", "plenum_id", "plenum_date" — used further down; confirm schema.
manifest_df = pd.read_csv(manifest_file)
|
|
|
|
|
# Restrict the preview to sessions shorter than two hours (7200 s) and
# pre-compute a human-friendly duration.
filtered_samples = manifest_df[manifest_df["duration"] < 7200].copy()
filtered_samples["duration_hours"] = filtered_samples["duration"] / 3600

# Selectbox labels ("<date> - (<h> hours)") mapped to plenum id strings.
# Duplicate labels resolve last-one-wins, matching dict-assignment semantics.
sample_options = {
    f"{record['plenum_date']} - ({round(record['duration_hours'], 1)} hours)": str(
        record["plenum_id"]
    )
    for _, record in filtered_samples.iterrows()
}
|
|
|
|
|
# Pre-select plenum 81733 when present; otherwise fall back to the first
# available option (or None when there are no options at all).
default_sample_id = "81733"
option_labels = list(sample_options.keys())

default_option = None
for label, plenum in sample_options.items():
    if plenum == default_sample_id:
        default_option = label
        break
if default_option is None and option_labels:
    default_option = option_labels[0]

selected_option = st.sidebar.selectbox(
    "Select a plenum sample:",
    options=option_labels,
    index=option_labels.index(default_option) if default_option else 0,
)
|
|
|
|
|
# Id of the plenum the user picked in the sidebar.
sample_plenum_id = sample_options[selected_option]
# Repo-relative paths for this sample. NOTE(review): the audio/metadata/aligned
# paths below are not referenced anywhere in this file — load_sample_data()
# rebuilds identical paths internally; only sample_raw_text_repo_path is used
# (read as a global inside load_sample_data).
sample_audio_file_repo_path = f"{sample_plenum_id}/audio.m4a"
sample_metadata_file_repo_path = f"{sample_plenum_id}/metadata.json"
sample_aligned_file_repo_path = f"{sample_plenum_id}/transcript.aligned.json"
sample_raw_text_repo_path = f"{sample_plenum_id}/raw.transcript.txt"

st.title(f"Knesset Plenum ID: {sample_plenum_id}")
st.markdown(
    "Please refer to the main dataset card for more details. [ivrit.ai/knesset-plenums](https://huggingface.co/datasets/ivrit-ai/knesset-plenums)"
    "\n\nThis preview shows a small subset (the smallest samples) of the dataset."
)
|
|
|
|
|
|
|
@st.cache_data
def load_sample_data(repo_id, plenum_id):
    """Download one plenum's files from the dataset repo and return local paths.

    Args:
        repo_id: Hugging Face dataset repo id ("owner/name").
        plenum_id: Plenum folder name inside the repo.

    Returns:
        Tuple of local cached file paths:
        (audio, metadata json, aligned-transcript json, raw transcript text).
    """
    # Derive every repo path from the arguments. The previous version read the
    # raw-text path from the module-global sample_raw_text_repo_path; since
    # st.cache_data keys the cache only on the function arguments, depending on
    # a global is fragile — build it from plenum_id instead.
    audio_path = f"{plenum_id}/audio.m4a"
    metadata_path = f"{plenum_id}/metadata.json"
    transcript_path = f"{plenum_id}/transcript.aligned.json"
    raw_text_path = f"{plenum_id}/raw.transcript.txt"

    audio_file = hf_api.hf_hub_download(repo_id, audio_path, repo_type="dataset")
    metadata_file = hf_api.hf_hub_download(repo_id, metadata_path, repo_type="dataset")
    transcript_file = hf_api.hf_hub_download(
        repo_id, transcript_path, repo_type="dataset"
    )
    raw_transcript_text_file = hf_api.hf_hub_download(
        repo_id, raw_text_path, repo_type="dataset"
    )

    return audio_file, metadata_file, transcript_file, raw_transcript_text_file
|
|
|
|
|
|
|
# Download (or fetch from the st.cache_data cache) all four files for the
# selected plenum.
(
    sample_audio_file,
    sample_metadata_file,
    sample_transcript_aligned_file,
    sample_raw_transcript_text_file,
) = load_sample_data(repo_id, sample_plenum_id)

# Sample-level metadata (JSON).
with open(sample_metadata_file, "r") as f:
    sample_metadata = json.load(f)

# Per-segment quality scores; presumably each entry has at least
# "start", "end" and "probability" keys (used below) — TODO confirm schema.
segments_quality_scores = sample_metadata["per_segment_quality_scores"]
segments_quality_scores_df = pd.DataFrame(segments_quality_scores)
# Keep the positional index as an explicit column so a selected dataframe row
# can be mapped back to its segment.
segments_quality_scores_df["segment_id"] = segments_quality_scores_df.index

# Aligned transcript: segments parallel to the quality-score list
# (indexed by the same segment_id below).
with open(sample_transcript_aligned_file, "r") as f:
    sample_transcript_aligned = json.load(f)
transcript_segments = sample_transcript_aligned["segments"]

# Raw (unaligned) transcript text, shown in an expander at the bottom.
with open(sample_raw_transcript_text_file, "r") as f:
    sample_raw_text = f.read()
|
|
|
# Two-column layout: segment table + player on the left, charts on the right.
col_main, col_aux = st.columns([2, 3])

# Selectable table of segment quality scores. on_select="rerun" makes the
# script re-execute with the selection available in `event`.
event = col_main.dataframe(
    segments_quality_scores_df,
    on_select="rerun",
    hide_index=True,
    selection_mode=["single-row"],
    column_config={
        # Render the 0..1 probability as a percentage progress bar.
        "probability": st.column_config.ProgressColumn(
            label="Quality Score",
            width="medium",
            format="percent",
            min_value=0,
            max_value=1,
        )
    },
)
|
|
|
|
|
|
|
# Pick a random segment (from the first 50 at most) once per session so the
# page shows something before the user clicks a row.
# NOTE(review): an empty quality-score list would make randint(0, -1) raise —
# presumably every sample has at least one segment; confirm upstream.
if "default_selection" not in st.session_state:
    upper_bound = min(49, len(segments_quality_scores_df) - 1)
    st.session_state.default_selection = random.randint(0, upper_bound)

# Prefer the user's table selection; otherwise fall back to the session default.
selected_rows = event.selection["rows"] if event and event.selection else None
row_idx = selected_rows[0] if selected_rows else st.session_state.default_selection

df_row = segments_quality_scores_df.iloc[row_idx]
segment_id = int(df_row["segment_id"])
selected_segment = segments_quality_scores[segment_id]
start_time = selected_segment["start"]
end_time = selected_segment["end"]
|
|
|
with col_main:
    st.write(f"Selected segment: {selected_segment}")

    # The st.audio player only accepts whole seconds, so widen the window
    # outward: floor the start and ceil the end.
    start_at = selected_segment["start"]
    end_at = selected_segment["end"]
    playback_start = math.floor(start_at)
    playback_end = math.ceil(end_at)

    st.audio(
        sample_audio_file,
        start_time=playback_start,
        end_time=playback_end,
        autoplay=True,
    )

    # Transcript text is Hebrew — render right-to-left.
    transcript_segment = transcript_segments[segment_id]
    st.caption(f'<div dir="rtl">{transcript_segment["text"]}</div>', unsafe_allow_html=True)

    st.divider()
    st.caption(
        f"Note: The audio will start at {playback_start} seconds and end at {playback_end} seconds (rounded up/down) since this is the resolution of the player, actual segments are more accurate."
    )
|
|
|
|
|
with col_aux:
    st.subheader("Segment Quality Over Time")

    # Local import kept at the top of the block; altair ships with streamlit.
    # (The original also re-imported pandas here, which is already imported at
    # module top — removed as redundant.)
    import altair as alt

    # Segments sorted by start time for a left-to-right timeline.
    chart_data = segments_quality_scores_df.copy()
    chart_data = chart_data.sort_values(by="start")

    # Scatter of per-segment quality score vs. segment start time.
    base_chart = alt.Chart(chart_data).mark_circle(size=20).encode(
        x=alt.X('start:Q', title='Start Time (seconds)'),
        y=alt.Y('probability:Q', title='Quality Score', scale=alt.Scale(domain=[0, 1])),
        tooltip=['start', 'end', 'probability']
    )

    # Overlay the currently selected segment as a larger red point.
    selected_point = pd.DataFrame([{
        'start': selected_segment['start'],
        'probability': selected_segment['probability']
    }])
    highlight = alt.Chart(selected_point).mark_circle(size=120, color='red').encode(
        x='start:Q',
        y='probability:Q'
    )

    st.altair_chart(base_chart + highlight, use_container_width=True)
|
|
|
# Collapsed-by-default expanders with the raw inputs, for inspection.
with st.expander("Raw Transcript Text", expanded=False):
    # Read-only text area; the visible expander header doubles as the label.
    st.text_area(
        "Raw Transcript Text",
        value=sample_raw_text,
        height=300,
        label_visibility="collapsed",
        disabled=True,
    )

with st.expander("Sample Metadata", expanded=False):
    # Interactive JSON tree of the full metadata document.
    st.json(sample_metadata)
|
|