# Yoad: fix fallback owner name (commit ead6402)
import json
import math
import os
import random
import pandas as pd
import streamlit as st
from huggingface_hub import HfApi
st.set_page_config(layout="wide", page_title="Knesset Plenums Dataset Preview")

# Resolve the dataset repository to preview from the environment.
# SPACE_AUTHOR_NAME wins over REPO_OWNER, which itself defaults to "ivrit-ai".
fallback_dataset_repo_owner = os.getenv("REPO_OWNER", "ivrit-ai")
dataset_repo_owner = os.getenv("SPACE_AUTHOR_NAME", fallback_dataset_repo_owner)
dataset_repo_name = os.getenv("DATASET_REPO_NAME", "knesset-plenums")
repo_id = "/".join((dataset_repo_owner, dataset_repo_name))

# Hub client authenticated with the token stored in the app's secrets.
hf_api = HfApi(token=st.secrets["HF_TOKEN"])

# Fetch the dataset manifest and load it into a DataFrame.
manifest_file = hf_api.hf_hub_download(
    repo_id=repo_id,
    filename="manifest.csv",
    repo_type="dataset",
)
manifest_df = pd.read_csv(manifest_file)
# Only samples shorter than this many seconds (2 hours) are offered.
MAX_PREVIEW_DURATION_SECONDS = 7200
# Plenum preselected in the dropdown when it is present in the filtered manifest.
default_sample_id = "81733"

filtered_samples = manifest_df[
    manifest_df["duration"] < MAX_PREVIEW_DURATION_SECONDS
].copy()
# Duration converted from seconds to hours, used for the dropdown labels.
filtered_samples["duration_hours"] = filtered_samples["duration"] / 3600

# Map each dropdown label to its plenum ID. The columns are iterated directly
# (instead of iterrows) so every value keeps the dtype of its own column.
sample_options = {}
for raw_plenum_id, plenum_date, duration_hours in zip(
    filtered_samples["plenum_id"],
    filtered_samples["plenum_date"],
    filtered_samples["duration_hours"],
):
    plenum_id = str(raw_plenum_id)
    hours = round(duration_hours, 1)
    display_text = f"{plenum_date} - ({hours} hours)"
    if display_text in sample_options:
        # Two plenums with the same date and rounded duration would share a
        # label and the later one would silently replace the earlier one;
        # keep both by tagging the later label with its plenum ID.
        display_text = f"{display_text} [{plenum_id}]"
    sample_options[display_text] = plenum_id

# Label of the default sample if available, otherwise the first label,
# or None when there are no samples at all.
default_option = next(
    (label for label, pid in sample_options.items() if pid == default_sample_id),
    next(iter(sample_options), None),
)
# With no eligible sample there is nothing to preview, and looking up the
# selected label below would raise KeyError; report it and end this run.
if not sample_options:
    st.error("No plenum samples are available to preview.")
    st.stop()

# Sidebar dropdown for choosing the sample, preselecting the default option.
option_labels = list(sample_options)
selected_option = st.sidebar.selectbox(
    "Select a plenum sample:",
    options=option_labels,
    index=option_labels.index(default_option) if default_option else 0,
)

# Plenum ID behind the selected label, and the repo paths of its files.
sample_plenum_id = sample_options[selected_option]
sample_audio_file_repo_path = f"{sample_plenum_id}/audio.m4a"
sample_metadata_file_repo_path = f"{sample_plenum_id}/metadata.json"
sample_aligned_file_repo_path = f"{sample_plenum_id}/transcript.aligned.json"
sample_raw_text_repo_path = f"{sample_plenum_id}/raw.transcript.txt"

# Page title showing the selected plenum ID, followed by a short intro.
st.title(f"Knesset Plenum ID: {sample_plenum_id}")
st.markdown(
    "Please refer to the main dataset card for more details. [ivrit.ai/knesset-plenums](https://huggingface.co/datasets/ivrit-ai/knesset-plenums)"
    "\n\nThis preview shows a small subset (the smallest samples) of the dataset."
)
# Cache the sample data loading to only reload when the sample changes
@st.cache_data
def load_sample_data(repo_id, plenum_id):
    """Download the files of one plenum sample and return their local paths.

    Every repo path is derived from ``plenum_id`` so that the result depends
    only on the function arguments and not on module-level state.

    Args:
        repo_id: ID of the dataset repository on the Hugging Face Hub.
        plenum_id: ID of the plenum; used as the folder name inside the repo.

    Returns:
        A tuple ``(audio_file, metadata_file, transcript_file,
        raw_transcript_text_file)`` with the values returned by
        ``hf_hub_download`` for each file.
    """
    repo_paths = (
        f"{plenum_id}/audio.m4a",
        f"{plenum_id}/metadata.json",
        f"{plenum_id}/transcript.aligned.json",
        f"{plenum_id}/raw.transcript.txt",
    )
    audio_file, metadata_file, transcript_file, raw_transcript_text_file = (
        hf_api.hf_hub_download(repo_id, repo_path, repo_type="dataset")
        for repo_path in repo_paths
    )
    return audio_file, metadata_file, transcript_file, raw_transcript_text_file
# Load the sample data for the selected plenum
(
    sample_audio_file,
    sample_metadata_file,
    sample_transcript_aligned_file,
    sample_raw_transcript_text_file,
) = load_sample_data(repo_id, sample_plenum_id)

# All text files below are opened explicitly as UTF-8 so that decoding does
# not depend on the platform's default encoding.

# Parse the metadata file of this sample to get the list of all segments.
with open(sample_metadata_file, "r", encoding="utf-8") as f:
    sample_metadata = json.load(f)

# each segment is a dict with the structure:
# {
#   "start": 3527.26,
#   "end": 3531.53,
#   "probability": 0.9309
# },
segments_quality_scores = sample_metadata["per_segment_quality_scores"]
segments_quality_scores_df = pd.DataFrame(segments_quality_scores)
# Keep the original position of every segment as an explicit column.
segments_quality_scores_df["segment_id"] = segments_quality_scores_df.index

# Aligned transcript: only its "segments" list is used further down.
with open(sample_transcript_aligned_file, "r", encoding="utf-8") as f:
    sample_transcript_aligned = json.load(f)
transcript_segments = sample_transcript_aligned["segments"]

# Raw transcript text, read whole.
with open(sample_raw_transcript_text_file, "r", encoding="utf-8") as f:
    sample_raw_text = f.read()
col_main, col_aux = st.columns([2, 3])

# Table of segments with single-row selection; selecting a row reruns the script.
event = col_main.dataframe(
    segments_quality_scores_df,
    on_select="rerun",
    hide_index=True,
    selection_mode=["single-row"],
    column_config={
        "probability": st.column_config.ProgressColumn(
            label="Quality Score",
            width="medium",
            format="percent",
            min_value=0,
            max_value=1,
        )
    },
)

segment_count = len(segments_quality_scores_df)
if segment_count == 0:
    # Nothing to select or play; this also avoids random.randint(0, -1) below.
    col_main.warning("This sample has no segments to preview.")
    st.stop()

# Pick a random default row among the first 50 once and keep it in session state.
if "default_selection" not in st.session_state:
    st.session_state.default_selection = random.randint(
        0, min(49, segment_count - 1)
    )

# Use the row selected in the table if there is one, otherwise the default.
if event and event.selection and event.selection["rows"]:
    row_idx = event.selection["rows"][0]
else:
    # The stored default may have been drawn for a sample with more segments
    # than the current one; clamp it so the lookup below stays in range.
    row_idx = min(st.session_state.default_selection, segment_count - 1)

df_row = segments_quality_scores_df.iloc[row_idx]
segment_id = int(df_row["segment_id"])
selected_segment = segments_quality_scores[segment_id]
start_time = start_at = selected_segment["start"]
end_time = end_at = selected_segment["end"]

# Player boundaries in whole seconds, widened outward so the segment is not cut.
player_start = math.floor(start_at)
player_end = math.ceil(end_at)

with col_main:
    st.write(f"Selected segment: {selected_segment}")
    st.audio(
        sample_audio_file,
        start_time=player_start,
        end_time=player_end,
        autoplay=True,
    )
    # The transcript segment is looked up by the same index as the quality score.
    transcript_segment = transcript_segments[segment_id]
    # NOTE(review): the text is injected as raw HTML; consider html.escape if
    # transcripts may contain markup characters.
    st.caption(f'<div dir="rtl">{transcript_segment["text"]}</div>', unsafe_allow_html=True)
    st.divider()
    st.caption(
        f"Note: The audio will start at {player_start} seconds and end at {player_end} seconds (rounded up/down) since this is the resolution of the player, actual segments are more accurate."
    )
with col_aux:
    # Chart of quality score versus segment start time.
    st.subheader("Segment Quality Over Time")

    # Kept local to the chart code, as in the original script.
    import altair as alt

    # sort_values returns a new DataFrame, so no explicit copy is needed.
    chart_data = segments_quality_scores_df.sort_values(by="start")

    # Base layer: one small circle per segment.
    base_chart = (
        alt.Chart(chart_data)
        .mark_circle(size=20)
        .encode(
            x=alt.X("start:Q", title="Start Time (seconds)"),
            y=alt.Y(
                "probability:Q",
                title="Quality Score",
                scale=alt.Scale(domain=[0, 1]),
            ),
            tooltip=["start", "end", "probability"],
        )
    )

    # Highlight layer: a single larger red circle for the selected segment.
    selected_point = pd.DataFrame(
        [
            {
                "start": selected_segment["start"],
                "probability": selected_segment["probability"],
            }
        ]
    )
    highlight = (
        alt.Chart(selected_point)
        .mark_circle(size=120, color="red")
        .encode(x="start:Q", y="probability:Q")
    )

    # Overlay both layers and render the result.
    combined_chart = base_chart + highlight
    st.altair_chart(combined_chart, use_container_width=True)

    # NOTE(review): the original indentation was lost; the two expanders are
    # assumed to belong to the aux column - confirm against the deployed app.
    with st.expander("Raw Transcript Text", expanded=False):
        st.text_area(
            "Raw Transcript Text",
            value=sample_raw_text,
            height=300,
            label_visibility="collapsed",
            disabled=True,
        )

    with st.expander("Sample Metadata", expanded=False):
        st.json(sample_metadata)