import json
import math
import os
import random

import altair as alt
import pandas as pd
import streamlit as st
from huggingface_hub import HfApi

st.set_page_config(page_title="Knesset Plenums Dataset Preview", layout="wide")

fallback_dataset_repo_owner = os.environ.get("REPO_OWNER", "ivrit-ai")
dataset_repo_owner = os.environ.get("SPACE_AUTHOR_NAME", fallback_dataset_repo_owner)
dataset_repo_name = os.environ.get("DATASET_REPO_NAME", "knesset-plenums")
repo_id = f"{dataset_repo_owner}/{dataset_repo_name}"

hf_api = HfApi(token=st.secrets["HF_TOKEN"])

manifest_file = hf_api.hf_hub_download(repo_id, "manifest.csv", repo_type="dataset")
manifest_df = pd.read_csv(manifest_file)

# Keep only samples with a duration under 7200 seconds (2 hours)
filtered_samples = manifest_df[manifest_df["duration"] < 7200].copy()

# Convert duration from seconds to hours for display
filtered_samples["duration_hours"] = filtered_samples["duration"] / 3600

# Create display options for the dropdown
sample_options = {}
for _, row in filtered_samples.iterrows():
    plenum_id = str(row["plenum_id"])
    plenum_date = row["plenum_date"]
    hours = round(row["duration_hours"], 1)
    display_text = f"{plenum_date} - ({hours} hours)"
    sample_options[display_text] = plenum_id

# Default to plenum ID 81733 if available, otherwise use the first sample
default_sample_id = "81733"
default_option = next(
    (k for k, v in sample_options.items() if v == default_sample_id),
    next(iter(sample_options.keys())) if sample_options else None,
)

# Create the dropdown for sample selection
selected_option = st.sidebar.selectbox(
    "Select a plenum sample:",
    options=list(sample_options.keys()),
    index=list(sample_options.keys()).index(default_option) if default_option else 0,
)

# Get the selected plenum ID and the per-plenum file paths inside the dataset repo
sample_plenum_id = sample_options[selected_option]

sample_audio_file_repo_path = f"{sample_plenum_id}/audio.m4a"
sample_metadata_file_repo_path = f"{sample_plenum_id}/metadata.json"
sample_aligned_file_repo_path = f"{sample_plenum_id}/transcript.aligned.json"
sample_raw_text_repo_path = f"{sample_plenum_id}/raw.transcript.txt"

# Display the title with the selected plenum ID
st.title(f"Knesset Plenum ID: {sample_plenum_id}")
st.markdown(
    "Please refer to the main dataset card for more details: "
    "[ivrit.ai/knesset-plenums](https://huggingface.co/datasets/ivrit-ai/knesset-plenums)"
    "\n\nThis preview shows a small subset (the smallest samples) of the dataset."
)


# Cache the sample data download so files are only fetched again when the sample changes
@st.cache_data
def load_sample_data(repo_id, plenum_id):
    """Download the audio, metadata, aligned transcript, and raw transcript files for a plenum."""
    # Derive every repo path from the plenum_id argument so the cache keys on it correctly
    audio_path = f"{plenum_id}/audio.m4a"
    metadata_path = f"{plenum_id}/metadata.json"
    transcript_path = f"{plenum_id}/transcript.aligned.json"
    raw_transcript_path = f"{plenum_id}/raw.transcript.txt"

    audio_file = hf_api.hf_hub_download(repo_id, audio_path, repo_type="dataset")
    metadata_file = hf_api.hf_hub_download(repo_id, metadata_path, repo_type="dataset")
    transcript_file = hf_api.hf_hub_download(
        repo_id, transcript_path, repo_type="dataset"
    )
    raw_transcript_text_file = hf_api.hf_hub_download(
        repo_id, raw_transcript_path, repo_type="dataset"
    )

    return audio_file, metadata_file, transcript_file, raw_transcript_text_file


# Load the sample data for the selected plenum
(
    sample_audio_file,
    sample_metadata_file,
    sample_transcript_aligned_file,
    sample_raw_transcript_text_file,
) = load_sample_data(repo_id, sample_plenum_id)

# Parse the metadata file of this sample to get the list of all segments.
with open(sample_metadata_file, "r") as f:
    sample_metadata = json.load(f)

# Each per-segment quality score is a dict with the structure:
# {
#     "start": 3527.26,
#     "end": 3531.53,
#     "probability": 0.9309
# }
segments_quality_scores = sample_metadata["per_segment_quality_scores"]
segments_quality_scores_df = pd.DataFrame(segments_quality_scores)
segments_quality_scores_df["segment_id"] = segments_quality_scores_df.index

with open(sample_transcript_aligned_file, "r") as f:
    sample_transcript_aligned = json.load(f)
transcript_segments = sample_transcript_aligned["segments"]

with open(sample_raw_transcript_text_file, "r") as f:
    sample_raw_text = f.read()

col_main, col_aux = st.columns([2, 3])

# Selectable table of per-segment quality scores; selecting a row triggers a rerun
event = col_main.dataframe(
    segments_quality_scores_df,
    on_select="rerun",
    hide_index=True,
    selection_mode=["single-row"],
    column_config={
        "probability": st.column_config.ProgressColumn(
            label="Quality Score",
            width="medium",
            format="percent",
            min_value=0,
            max_value=1,
        )
    },
)

# Initialize session state for the default selection if it doesn't exist
if "default_selection" not in st.session_state:
    st.session_state.default_selection = random.randint(
        0, min(49, len(segments_quality_scores_df) - 1)
    )

# If a row is selected, use it; otherwise fall back to the default random selection
if event and event.selection and event.selection["rows"]:
    row_idx = event.selection["rows"][0]
else:
    row_idx = st.session_state.default_selection

df_row = segments_quality_scores_df.iloc[row_idx]
segment_id = int(df_row["segment_id"])
selected_segment = segments_quality_scores[segment_id]

with col_main:
    st.write(f"Selected segment: {selected_segment}")
    start_at = selected_segment["start"]
    end_at = selected_segment["end"]
    st.audio(
        sample_audio_file,
        start_time=math.floor(start_at),
        end_time=math.ceil(end_at),
        autoplay=True,
    )
    transcript_segment = transcript_segments[segment_id]
    st.caption(f'{transcript_segment["text"]}', unsafe_allow_html=True)
    st.divider()
    st.caption(
        f"Note: the audio player starts at {math.floor(start_at)} seconds and ends at "
        f"{math.ceil(end_at)} seconds (rounded down/up) since whole seconds are the "
        "player's resolution; the actual segment boundaries are more precise."
    )

with col_aux:
    # Chart of segment quality vs. start time
    st.subheader("Segment Quality Over Time")

    # Prepare data for the chart
    chart_data = segments_quality_scores_df.copy()
    chart_data = chart_data.sort_values(by="start")

    # Base scatter plot with all segments
    base_chart = (
        alt.Chart(chart_data)
        .mark_circle(size=20)
        .encode(
            x=alt.X("start:Q", title="Start Time (seconds)"),
            y=alt.Y(
                "probability:Q", title="Quality Score", scale=alt.Scale(domain=[0, 1])
            ),
            tooltip=["start", "end", "probability"],
        )
    )

    # Highlight the currently selected segment in red
    selected_point = pd.DataFrame(
        [
            {
                "start": selected_segment["start"],
                "probability": selected_segment["probability"],
            }
        ]
    )
    highlight = (
        alt.Chart(selected_point)
        .mark_circle(size=120, color="red")
        .encode(x="start:Q", y="probability:Q")
    )

    # Combine and display the charts
    combined_chart = base_chart + highlight
    st.altair_chart(combined_chart, use_container_width=True)

with st.expander("Raw Transcript Text", expanded=False):
    st.text_area(
        "Raw Transcript Text",
        value=sample_raw_text,
        height=300,
        label_visibility="collapsed",
        disabled=True,
    )

with st.expander("Sample Metadata", expanded=False):
    st.json(sample_metadata)
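
# A minimal sketch of how to run this preview locally (the filename and secrets
# location below are assumptions, not taken from this repo): save this script as
# app.py, install streamlit, pandas, altair, and huggingface_hub, put an HF_TOKEN
# entry in .streamlit/secrets.toml, and launch it with `streamlit run app.py`.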