from __future__ import annotations
from pathlib import Path
import os
import yaml
import pandas as pd
import numpy as np
from huggingface_hub import HfApi
from datetime import datetime, timezone
import re
# Root directory of the project
ROOT = Path(__file__).resolve().parent.parent
# Detect Streamlit runtime
try:
    import streamlit as st
    has_streamlit = True
except ImportError:
    has_streamlit = False
# Load environment variables when running locally
if os.getenv("ENV") == "local" or not has_streamlit:
    from dotenv import load_dotenv
    load_dotenv(ROOT / ".env")
# Read Hugging Face dataset repo ID from config
with open(ROOT / "config.yaml") as f:
    cfg = yaml.safe_load(f)
REPO_ID: str = cfg["repo_id"]
# Initialize Hugging Face API client
api = HfApi()
# Direct URL for the summary CSV (kept for reference; load_summary()
# downloads the file through the Hub API instead)
CSV_URL = (
    f"https://huggingface.co/datasets/{REPO_ID}/resolve/main/subreddit_daily_summary.csv"
)
def get_secret(key: str, default=None) -> str | None:
    """Fetch a secret from environment variables, falling back to Streamlit secrets."""
    val = os.getenv(key)
    if val is None and has_streamlit:
        val = st.secrets.get(key, default)
    # Honor the default even when Streamlit is unavailable
    return val if val is not None else default
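# Example: get_secret("HF_TOKEN") checks the HF_TOKEN environment variable first,
# then st.secrets when running on Streamlit ("HF_TOKEN" here is illustrative).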
@st.cache_data(ttl=600, show_spinner=False)
def load_summary() -> pd.DataFrame:
    """Download the subreddit daily summary via the HF Hub API and return it
    as a DataFrame. Cached for 10 minutes."""
    # Download through the Hub API rather than the direct resolve URL
    local_file = api.hf_hub_download(
        repo_id=REPO_ID,
        filename="subreddit_daily_summary.csv",
        repo_type="dataset",
    )
    df = pd.read_csv(local_file, parse_dates=["date"])
    needed = {"date", "subreddit", "mean_sentiment", "community_weighted_sentiment", "count"}
    if not needed.issubset(df.columns):
        missing = needed - set(df.columns)
        raise ValueError(f"Missing columns in summary CSV: {missing}")
    return df
def _sanitize(name: str) -> str:
    """Make a subreddit name safe for filenames (replaces slashes, spaces, etc. with underscores)."""
    name = name.strip().lower()
    name = re.sub(r"[^\w\-\.]", "_", name)
    return name
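# Example (illustrative): _sanitize("r/Ask Science") -> "r_ask_science"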
@st.cache_data(show_spinner=False, ttl=60 * 60)
def load_day(date: str, subreddit: str) -> pd.DataFrame:
    """Lazily download the parquet shard for one day and subreddit.

    Args:
        date: Date string in YYYY-MM-DD format.
        subreddit: Subreddit name to filter by.

    Returns:
        DataFrame containing posts from the specified subreddit on the given day.
    """
    # Download the per-subreddit file, using the sanitized subreddit name
    safe_sub = _sanitize(subreddit)
    fname = f"data_scored_subreddit/{date}__{safe_sub}.parquet"
    local = api.hf_hub_download(REPO_ID, fname, repo_type="dataset")
    df_day = pd.read_parquet(local)
    # The file already contains only the selected subreddit; just reset the index
    return df_day.reset_index(drop=True)
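# Example (hypothetical date/subreddit): load_day("2024-01-31", "wallstreetbets")
# fetches data_scored_subreddit/2024-01-31__wallstreetbets.parquet from the dataset repo.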
def get_last_updated_hf(repo_id: str) -> datetime:
    """
    Retrieve the dataset repo's last-modified datetime via the HF Hub API.
    Returns a timezone-aware datetime in UTC.
    """
    info = api.repo_info(repo_id=repo_id, repo_type="dataset")
    dt: datetime = info.lastModified  # already a datetime object
    if dt.tzinfo is None:
        # Hub timestamps are UTC; make a naive value explicitly aware
        dt = dt.replace(tzinfo=timezone.utc)
    else:
        dt = dt.astimezone(timezone.utc)
    return dt
def get_last_updated_hf_caption() -> str:
    """
    Build a small HTML/markdown caption showing the dataset source and last update.
    Uses REPO_ID and the HF Hub API to fetch the timestamp.
    """
    # Dataset link and formatted timestamp
    dataset_url = f"https://huggingface.co/datasets/{REPO_ID}"
    last_update_dt = get_last_updated_hf(REPO_ID)
    last_update = last_update_dt.strftime("%Y-%m-%d %H:%M:%S UTC")
    # Return the small-caption HTML/markdown string
    return (
        f"<small>"
        f"Data source: <a href='{dataset_url}' target='_blank'>{REPO_ID}</a> • "
        f"Last updated: {last_update}"
        f"</small>"
    )
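# Typical use in the app (a sketch, not code from this file):
# st.markdown(get_last_updated_hf_caption(), unsafe_allow_html=True)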
def add_rolling(df: pd.DataFrame, window: int = 7) -> pd.DataFrame:
    """Add a rolling mean of community_weighted_sentiment over the given window, per subreddit."""
    out = df.copy()
    for _, grp in out.groupby("subreddit"):
        grp_sorted = grp.sort_values("date")
        roll = grp_sorted["community_weighted_sentiment"].rolling(window).mean()
        out.loc[grp_sorted.index, f"roll_{window}"] = roll
    return out
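# Usage sketch: add_rolling(load_summary(), window=7) adds a "roll_7" column
# holding each subreddit's 7-day rolling mean of community_weighted_sentiment.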
def get_subreddit_colors(subreddits: list[str]) -> dict[str, str]:
    """Provide a consistent color map for each subreddit."""
    palette = [
        "#1f77b4", "#ff7f0e", "#2ca02c", "#d62728",
        "#9467bd", "#8c564b", "#e377c2", "#7f7f7f",
    ]
    return {sub: palette[i % len(palette)] for i, sub in enumerate(sorted(subreddits))}
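# Minimal local smoke test (a sketch; assumes network access to the Hub and a
# valid repo_id in config.yaml; st.cache_data only warns when run outside Streamlit).
if __name__ == "__main__":
    summary = add_rolling(load_summary(), window=7)
    colors = get_subreddit_colors(sorted(summary["subreddit"].unique()))
    print(summary.tail())
    print(colors)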