from typing import List, Optional

import numpy as np
import pandas as pd
from tqdm import tqdm


def rank_sample(
    df: pd.DataFrame,
    name_col: str = "name",
    category_col: str = "category",
    sentiment_col: str = "sentiment_score",
    groups: Optional[List[str]] = None,
    num_samples: int = 1000,
    temp: float = 1.0,
    target_value: float = 0.5,
) -> pd.DataFrame:
    """Pick one row per name so that category means of sentiment are balanced.

    For every distinct value of ``name_col`` one row is drawn at random.
    Rows whose sentiment is closer to ``target_value`` get a better (lower)
    rank inside their name and therefore a higher softmax probability.
    The draw is repeated ``num_samples`` times and the subset whose
    per-category sentiment means have the smallest spread (max minus min)
    is returned.

    Args:
        df: Input data; it is copied, never modified in place.
        name_col: Column identifying the entity to sample one row for.
        category_col: Column holding the group label of each row.
        sentiment_col: Numeric column whose per-group means are balanced.
        groups: Optional category values to restrict the data to. Ignored
            (with a warning) when fewer than two of them occur in the data.
        num_samples: Number of random subsets to try.
        temp: Softmax temperature; values at or below ``1e-8`` (or falsy)
            are clamped to ``1e-8``, which makes the draw effectively
            greedy.
        target_value: Sentiment value the per-row deviation is measured
            from.

    Returns:
        The best sampled subset, one row per name, carrying the original
        index labels plus the helper columns ``sentiment_deviation`` and
        ``sentiment_rank``. If the data holds fewer than two categories,
        or no sample was valid, ``df.groupby(name_col).first()`` with the
        index reset is returned instead.

    Raises:
        ValueError: If any of the three required columns is missing.

    Note:
        Draws use the global ``np.random`` state; seed it for
        reproducible results.
    """
    df = df.copy()

    for col in (name_col, category_col, sentiment_col):
        if col not in df.columns:
            raise ValueError(f"Column '{col}' not found in DataFrame")

    df = df.dropna(subset=[name_col, category_col, sentiment_col])

    if groups:
        available_groups = df[category_col].unique()
        valid_groups = [g for g in groups if g in available_groups]
        if len(valid_groups) < 2:
            print(f"Warning: Only {len(valid_groups)} groups available from {groups}")
            groups = None
        else:
            groups = valid_groups
            df = df[df[category_col].isin(groups)].copy()

    final_groups = df[category_col].unique()
    if len(final_groups) < 2:
        print(f"Error: Only {len(final_groups)} groups in data, need at least 2")
        # NOTE: GroupBy.first() takes the first non-null value per column,
        # so optional columns may be filled from different rows of a name.
        return df.groupby(name_col).first().reset_index()

    print(f"Sampling with groups: {sorted(final_groups)}")
    print(f"Target value for deviation calculation: {target_value}")

    # Assign raw arrays so a non-unique index never triggers label alignment.
    df["sentiment_deviation"] = (df[sentiment_col] - target_value).abs().to_numpy()
    df["sentiment_rank"] = (
        df.groupby(name_col)["sentiment_deviation"]
        .rank(method="first", ascending=True)
        .to_numpy()
    )

    def softmax_weights(ranks: np.ndarray, temp: float) -> np.ndarray:
        """Map ranks to probabilities; a lower rank gets a larger weight."""
        t = float(temp) if temp and temp > 1e-8 else 1e-8
        x = -ranks / t
        x = x - np.max(x)  # shift for numerical stability before exp()
        exps = np.exp(x)
        s = exps.sum()
        if np.isfinite(s) and s > 0:
            return exps / s
        return np.ones_like(exps) / len(exps)

    def objective_max_pairwise_diff(frame: pd.DataFrame) -> float:
        """Largest gap between any two category means (inf if < 2 groups)."""
        g = frame.groupby(category_col)[sentiment_col].mean().dropna()
        if len(g) < 2:
            return np.inf
        vals = g.to_numpy(dtype=float)
        # The largest pairwise absolute difference is simply max - min.
        return float(vals.max() - vals.min())

    best_subset = None
    best_obj = np.inf
    valid_samples = 0

    unique_names = df[name_col].nunique()
    print(f"Total unique names: {unique_names}")

    # Ranks and temperature never change between samples, so the per-name
    # row positions and draw probabilities are computed once. Positions
    # (not index labels) are used so duplicate index labels cannot pull in
    # the wrong rows.
    positional = df[[name_col, "sentiment_rank"]].reset_index(drop=True)
    candidates = [
        (
            rows.index.to_numpy(),
            softmax_weights(rows["sentiment_rank"].to_numpy(dtype=float), temp=temp),
        )
        for _, rows in positional.groupby(name_col)
    ]

    for i in tqdm(range(num_samples), desc="Sampling"):
        # Deliberate best-effort: a failing sample is reported and skipped.
        try:
            picked = [
                positions[np.random.choice(len(positions), p=weights)]
                for positions, weights in candidates
            ]
            if not picked:
                continue

            subset = df.iloc[picked]

            if subset[category_col].nunique() < 2:
                continue

            obj = objective_max_pairwise_diff(subset)

            if np.isfinite(obj):
                valid_samples += 1
                if obj < best_obj:
                    best_obj = obj
                    best_subset = subset

                if valid_samples % 100 == 0 or valid_samples <= 10:
                    group_means = subset.groupby(category_col)[sentiment_col].mean()
                    print(f"Sample {valid_samples}: obj={obj:.4f}, groups={dict(group_means)}")

        except Exception as e:
            print(f"Error in sample {i}: {e}")
            continue

    print(f"Valid samples: {valid_samples}/{num_samples}")
    print(f"Best objective: {best_obj:.4f}")

    if best_subset is None or len(best_subset) == 0:
        print("Warning: No valid samples found, returning fallback subset")
        best_subset = df.groupby(name_col).first().reset_index()

    final_group_counts = best_subset[category_col].value_counts()
    print(f"Final subset group distribution: {dict(final_group_counts)}")

    return best_subset