#!/usr/bin/env python
"""
Summarise scored shards into one daily_summary.csv
CLI examples
------------
# Summarize data for a specific date
python -m reddit_analysis.summarizer.summarize --date 2025-04-20
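# Re-summarize a date, replacing any existing rows for it
python -m reddit_analysis.summarizer.summarize --date 2025-04-20 --overwrite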
"""
from __future__ import annotations
import argparse
from datetime import date
from pathlib import Path
from typing import Optional, List, Dict, Any, Set, Tuple
import pandas as pd
from huggingface_hub import hf_hub_download, HfApi
from reddit_analysis.config_utils import setup_config
from reddit_analysis.summarizer.aggregator import summary_from_df
# --------------------------------------------------------------------------- #
# Utilities #
# --------------------------------------------------------------------------- #
class FileManager:
"""Wrapper class for simple local file I/O that can be mocked for testing."""
def __init__(self, base_dir: Path):
self.base_dir = base_dir
self.base_dir.mkdir(parents=True, exist_ok=True)
# ---------- CSV helpers ------------------------------------------------- #
def read_csv(self, path: Path) -> pd.DataFrame:
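        """Read `path` into a DataFrame; if the file is missing or empty, return an empty frame with the daily-summary columns."""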
if not path.exists() or path.stat().st_size == 0:
return pd.DataFrame(
columns=["date", "subreddit",
"mean_sentiment", "community_weighted_sentiment", "count"]
)
return pd.read_csv(path)
def write_csv(self, df: pd.DataFrame, path: Path) -> Path:
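        """Write `df` to `path` without the row index and return `path`."""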
df.to_csv(path, index=False)
return path
# ---------- Parquet helper --------------------------------------------- #
@staticmethod
def read_parquet(path: Path) -> pd.DataFrame:
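        """Read a Parquet file at `path` into a DataFrame."""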
return pd.read_parquet(path)
class HuggingFaceManager:
"""Thin wrapper around Hugging Face Hub file ops (mock‑friendly)."""
def __init__(self, token: str, repo_id: str, repo_type: str = "dataset"):
self.token = token
self.repo_id = repo_id
self.repo_type = repo_type
self.api = HfApi(token=token)
def download_file(self, path_in_repo: str) -> Path:
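        """Download `path_in_repo` from the configured repo and return the local cached path."""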
return Path(
hf_hub_download(
repo_id=self.repo_id,
repo_type=self.repo_type,
filename=path_in_repo,
token=self.token
)
)
def upload_file(self, local_path: str, path_in_repo: str):
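        """Upload the file at `local_path` to `path_in_repo` in the configured repo."""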
self.api.upload_file(
path_or_fileobj=local_path,
path_in_repo=path_in_repo,
repo_id=self.repo_id,
repo_type=self.repo_type,
token=self.token
)
def list_files(self, prefix: str) -> List[str]:
"""List files in the HF repo filtered by prefix."""
files = self.api.list_repo_files(
repo_id=self.repo_id,
repo_type=self.repo_type
)
return [f for f in files if f.startswith(prefix)]
# --------------------------------------------------------------------------- #
# Core manager #
# --------------------------------------------------------------------------- #
class SummaryManager:
def __init__(
self,
cfg: Dict[str, Any],
file_manager: Optional[FileManager] = None,
hf_manager: Optional[HuggingFaceManager] = None
):
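        """
        `cfg` is the dict returned by `setup_config()` (keys: "config",
        "secrets", "paths"); `file_manager` and `hf_manager` may be injected
        so the class can be exercised with mocks in tests.
        """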
self.config = cfg["config"]
self.secrets = cfg["secrets"]
self.paths = cfg["paths"]
# I/O helpers
self.file_manager = file_manager or FileManager(self.paths["root"])
self.hf_manager = hf_manager or HuggingFaceManager(
token=self.secrets["HF_TOKEN"],
repo_id=self.config["repo_id"],
repo_type=self.config.get("repo_type", "dataset"),
)
# Cache path for the combined summary file on disk
self.local_summary_path: Path = self.paths["summary_file"]
# --------------------------------------------------------------------- #
# Remote summary helpers #
# --------------------------------------------------------------------- #
def _load_remote_summary(self) -> pd.DataFrame:
"""
Ensure `daily_summary.csv` is present locally by downloading the
latest version from HF Hub (if it exists) and return it as a DataFrame.
"""
remote_name = self.paths["summary_file"].name
try:
cached_path = self.hf_manager.download_file(remote_name)
except Exception:
# first run – file doesn't exist yet on the Hub
return pd.DataFrame(
columns=["date", "subreddit",
"mean_sentiment", "community_weighted_sentiment", "count"]
)
return pd.read_csv(cached_path)
def _save_and_push_summary(self, df: pd.DataFrame):
"""Persist the updated summary both locally and back to HF Hub."""
self.file_manager.write_csv(df, self.local_summary_path)
self.hf_manager.upload_file(str(self.local_summary_path),
self.local_summary_path.name)
# --------------------------------------------------------------------- #
# Public helpers #
# --------------------------------------------------------------------- #
def get_processed_combinations(self) -> Set[Tuple[date, str]]:
"""
Return a set of (date, subreddit) pairs that are *already* present
in the remote summary so we can de‑duplicate.
"""
df_summary = self._load_remote_summary()
if df_summary.empty:
return set()
df_summary["date"] = pd.to_datetime(df_summary["date"]).dt.date
return {
(row["date"], row["subreddit"])
for _, row in df_summary.iterrows()
}
# --------------------------------------------------------------------- #
# Main workflow #
# --------------------------------------------------------------------- #
def process_date(self, date_str: str, overwrite: bool = False) -> None:
"""Download scored data for `date_str`, aggregate, and append/upload."""
# ---------- Pull scored shards for the given date ------------------ #
prefix = f"{self.paths['hf_scored_dir']}/{date_str}__"
# List all remote shards
try:
all_files = self.hf_manager.list_files(self.paths['hf_scored_dir'])
except Exception as err:
print(f"Error: could not list scored shards in {self.paths['hf_scored_dir']}: {err}")
return
# Filter to shards matching this date
try:
shards = [fn for fn in all_files if fn.startswith(prefix) and fn.endswith('.parquet')]
except TypeError:
# fall back in case list_files returned a non-iterable (e.g., a mock)
shards = [all_files]
if not shards:
print(f"No scored shards found for {date_str} under {self.paths['hf_scored_dir']}")
return
# Download and concatenate all shards
dfs: List[pd.DataFrame] = []
for shard in shards:
try:
local_path = self.hf_manager.download_file(shard)
except Exception as err:
print(f"Error: could not download scored shard {shard}: {err}")
return
dfs.append(self.file_manager.read_parquet(local_path))
df_day = pd.concat(dfs, ignore_index=True)
# sanity‑check
required_cols = {"retrieved_at", "subreddit", "sentiment", "score"}
        if not required_cols.issubset(df_day.columns):
            missing = required_cols - set(df_day.columns)
            raise ValueError(f"Scored shards for {date_str} are missing columns {missing}")
# ---------- Aggregate ------------------------------------------------ #
df_summary_day = summary_from_df(df_day)
# ---------- De‑duplication / overwrite ------------------------------ #
existing_pairs = self.get_processed_combinations()
if not overwrite:
df_summary_day = df_summary_day[
~df_summary_day.apply(
lambda r: (r["date"], r["subreddit"]) in existing_pairs,
axis=1,
)
]
if df_summary_day.empty:
print("Nothing new to summarise for this date.")
return
# ---------- Combine with historical summary ------------------------- #
df_summary = self._load_remote_summary()
if overwrite:
df_summary = df_summary[df_summary["date"] != date_str]
# Remove weighted_sentiment column if it exists
if "weighted_sentiment" in df_summary.columns:
df_summary = df_summary.drop(columns=["weighted_sentiment"])
df_out = (
pd.concat([df_summary, df_summary_day], ignore_index=True)
if not df_summary.empty
else df_summary_day
)
df_out["date"] = pd.to_datetime(df_out["date"]).dt.date
df_out.sort_values(["date", "subreddit"], inplace=True)
# Ensure the weighted_sentiment column is dropped from final output
if "weighted_sentiment" in df_out.columns:
df_out = df_out.drop(columns=["weighted_sentiment"])
# Round floating point columns to 4 decimal places
if "mean_sentiment" in df_out.columns:
df_out["mean_sentiment"] = df_out["mean_sentiment"].round(4)
if "community_weighted_sentiment" in df_out.columns:
df_out["community_weighted_sentiment"] = df_out["community_weighted_sentiment"].round(4)
# ---------- Save & upload ------------------------------------------- #
self._save_and_push_summary(df_out)
print(f"Updated {self.local_summary_path.name}{len(df_out)} rows")
# --------------------------------------------------------------------------- #
# CLI entry‑point #
# --------------------------------------------------------------------------- #
def main(date_str: str, overwrite: bool = False) -> None:
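    """Validate `date_str` (YYYY-MM-DD), build the config, and summarise that date."""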
if not date_str:
raise ValueError("--date is required (YYYY-MM-DD)")
# Confirm valid date
try:
date.fromisoformat(date_str)
    except ValueError:
        raise ValueError(f"Invalid date: {date_str} (expected YYYY-MM-DD)") from None
cfg = setup_config()
SummaryManager(cfg).process_date(date_str, overwrite)
if __name__ == "__main__":
from reddit_analysis.common_metrics import run_with_metrics
parser = argparse.ArgumentParser(
description="Summarize scored Reddit data for a specific date."
)
parser.add_argument("--date", required=True,
help="YYYY-MM-DD date to process")
parser.add_argument("--overwrite", action="store_true",
help="Replace any existing rows for this date")
args = parser.parse_args()
run_with_metrics("summarize", main, args.date, args.overwrite)