from __future__ import annotations
from pathlib import Path
import os
import yaml
import pandas as pd
import numpy as np
from huggingface_hub import HfApi
from datetime import datetime, timezone
import re

# Root directory of the project
ROOT = Path(__file__).resolve().parent.parent

# Detect Streamlit runtime; fall back to a no-op stub so the st.cache_data
# decorators below still work when Streamlit is not installed
try:
    import streamlit as st
    has_streamlit = True
except ImportError:
    has_streamlit = False

    class st:  # minimal stand-in exposing only what this module uses
        @staticmethod
        def cache_data(*args, **kwargs):
            return lambda fn: fn

# Load environment variables when running locally
if os.getenv("ENV") == "local" or not has_streamlit:
    from dotenv import load_dotenv
    load_dotenv(ROOT / ".env")

# Read Hugging Face dataset repo ID from config
with open(ROOT / "config.yaml") as f:
    cfg = yaml.safe_load(f)
REPO_ID: str = cfg["repo_id"]
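
# For reference, config.yaml is expected to hold at least the dataset repo ID
# (the value below is illustrative, not the real one):
#
#     repo_id: your-username/your-dataset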

# Initialize Hugging Face API client
api = HfApi()

# Direct URL for the summary CSV (kept for reference; load_summary below
# fetches via the Hub API instead)
CSV_URL = (
    f"https://huggingface.co/datasets/{REPO_ID}/resolve/main/subreddit_daily_summary.csv"
)


def get_secret(key: str, default=None) -> str | None:
    """Fetch a secret from environment variables or Streamlit secrets."""
    val = os.getenv(key)
    if val is None and has_streamlit:
        val = st.secrets.get(key, default)
    return val
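
# Usage sketch for get_secret: "HF_TOKEN" is a hypothetical key name used only
# for illustration; substitute whatever your deployment actually stores.
#
#     token = get_secret("HF_TOKEN")
#     if token is None:
#         raise RuntimeError("HF_TOKEN not set in env or Streamlit secrets")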


@st.cache_data(ttl=600, show_spinner=False)
def load_summary() -> pd.DataFrame:
    """Download the subreddit daily summary via the HF Hub API and return it as a DataFrame. Cached for 10 minutes."""
    # Use HF Hub API to download the file instead of direct URL
    local_file = api.hf_hub_download(
        repo_id=REPO_ID,
        filename="subreddit_daily_summary.csv",
        repo_type="dataset"
    )
    df = pd.read_csv(local_file, parse_dates=["date"])
    needed = {"date", "subreddit", "mean_sentiment", "community_weighted_sentiment", "count"}
    if not needed.issubset(df.columns):
        missing = needed - set(df.columns)
        raise ValueError(f"Missing columns in summary CSV: {missing}")
    return df
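
# Usage sketch: the column names below are the ones validated in `needed`;
# the subreddit value is illustrative.
#
#     summary = load_summary()
#     one_sub = summary[summary["subreddit"] == "python"]
#     recent = one_sub.sort_values("date").tail(30)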


def _sanitize(name: str) -> str:
    """
    Make subreddit safe for filenames (removes slashes, spaces, etc.).
    """
    name = name.strip().lower()
    name = re.sub(r"[^\w\-\.]", "_", name)
    return name
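
# Illustrative mappings produced by the regex above:
#     _sanitize("r/Machine Learning")  -> "r_machine_learning"
#     _sanitize("AskReddit")           -> "askreddit"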


@st.cache_data(show_spinner=False, ttl=60*60)
def load_day(date: str, subreddit: str) -> pd.DataFrame:
    """Lazy-download the parquet shard for one YYYY-MM-DD and return df slice.
    
    Args:
        date: Date string in YYYY-MM-DD format
        subreddit: Subreddit name to filter by
        
    Returns:
        DataFrame containing posts from the specified subreddit on the given day
    """
    # Download the subreddit-specific file using sanitized subreddit
    safe_sub = _sanitize(subreddit)
    fname = f"data_scored_subreddit/{date}__{safe_sub}.parquet"
    local = api.hf_hub_download(REPO_ID, fname, repo_type="dataset")
    df_day = pd.read_parquet(local)
    # File contains only the selected subreddit; reset index
    return df_day.reset_index(drop=True)
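
# Usage sketch: both argument values are illustrative; the corresponding shard
# data_scored_subreddit/2024-01-15__python.parquet must exist in the dataset.
#
#     posts = load_day("2024-01-15", "python")
#     print(len(posts), "posts")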


def get_last_updated_hf(repo_id: str) -> datetime:
    """
    Retrieve the dataset repo's last modified datetime via HF Hub API.
    Returns a timezone-aware datetime in UTC.
    """
    info = api.repo_info(repo_id=repo_id, repo_type="dataset")
    dt: datetime = info.last_modified  # datetime object returned by the Hub
    if dt.tzinfo is None:
        dt = dt.replace(tzinfo=timezone.utc)  # Hub timestamps are UTC
    else:
        dt = dt.astimezone(timezone.utc)
    return dt


def get_last_updated_hf_caption() -> str:
    """
    Build a markdown-formatted caption string showing the dataset source and last update.
    Uses REPO_ID and the HF Hub API to fetch the timestamp.
    """
    # Generate dataset link and timestamp
    dataset_url = f"https://huggingface.co/datasets/{REPO_ID}"
    last_update_dt = get_last_updated_hf(REPO_ID)
    last_update = last_update_dt.strftime("%Y-%m-%d %H:%M:%S UTC")

    # Return the small-caption HTML/markdown string
    return (
        f"<small>"
        f"Data source: <a href='{dataset_url}' target='_blank'>{REPO_ID}</a> &bull; "
        f"Last updated: {last_update}"
        f"</small>"
    )
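
# Usage sketch: the caption is raw HTML, so render it with
# unsafe_allow_html=True (st.markdown escapes tags by default).
#
#     st.markdown(get_last_updated_hf_caption(), unsafe_allow_html=True)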


def add_rolling(df: pd.DataFrame, window: int = 7) -> pd.DataFrame:
    """Add a per-subreddit rolling mean of community_weighted_sentiment as a roll_{window} column."""
    out = df.copy()
    for _, grp in out.groupby("subreddit"):
        grp_sorted = grp.sort_values("date")
        roll = grp_sorted["community_weighted_sentiment"].rolling(window).mean()
        # Assign via the original index so values line up with rows in `out`
        out.loc[grp_sorted.index, f"roll_{window}"] = roll
    return out
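
# Usage sketch: with the default 7-day window, the first 6 rows per subreddit
# come back NaN because pandas' rolling() defaults to min_periods=window.
#
#     smoothed = add_rolling(load_summary(), window=7)
#     smoothed[["date", "subreddit", "roll_7"]].dropna().head()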


def get_subreddit_colors(subreddits: list[str]) -> dict[str, str]:
    """Provide a consistent color map for each subreddit."""
    palette = [
        "#1f77b4", "#ff7f0e", "#2ca02c", "#d62728",
        "#9467bd", "#8c564b", "#e377c2", "#7f7f7f",
    ]
    return {sub: palette[i % len(palette)] for i, sub in enumerate(sorted(subreddits))}
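

if __name__ == "__main__":
    # Minimal offline smoke test (a sketch with made-up sample data): exercises
    # only the pure helpers, so it needs no network access or HF credentials,
    # though importing the module still requires config.yaml to be present.
    demo = pd.DataFrame(
        {
            "date": pd.date_range("2024-01-01", periods=6),
            "subreddit": ["python"] * 6,
            "community_weighted_sentiment": np.linspace(-0.5, 0.5, 6),
        }
    )
    print(_sanitize("r/Machine Learning"))  # r_machine_learning
    print(add_rolling(demo, window=3)["roll_3"].round(3).tolist())
    print(get_subreddit_colors(["python", "rust"]))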