from flask import Flask, request, jsonify, send_file, send_from_directory
from flask_cors import CORS
import pandas as pd
import torch
import os
from datetime import datetime
from tqdm import tqdm
import logging
from functools import lru_cache
from typing import Optional, List, Dict, Any
# Force a headless matplotlib/Qt setup before any plotting utilities are imported
os.environ["MPLBACKEND"] = "Agg"
os.environ["QT_QPA_PLATFORM"] = "offscreen"
logging.basicConfig(level=logging.INFO)
from utils.sampling import rank_sample
try:
from transformers import TextDataset, DataCollatorForLanguageModeling, Trainer, TrainingArguments
print("βœ“ transformers training components imported")
except Exception as e:
print(f"βœ— transformers training import failed: {e}")
def finetune(*args, **kwargs):
print("Warning: Transformers training components not available, skipping fine-tuning")
return None
# 🤗 datasets
try:
from datasets import (
load_dataset,
load_dataset_builder,
get_dataset_config_names,
get_dataset_split_names,
Features,
)
print("βœ“ datasets imported")
except Exception as e:
print(f"βœ— datasets import failed: {e}")
raise
from utils.utils import (
    _ensure_plot_saved,
    generate_topk_samples,
    evaluate_generated_outputs,
    load_model_and_tokenizer,
    generate_counterfactual_augmentations,
)
print("βœ“ utils imported")
app = Flask(__name__)
CORS(app)
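# In-process caches: loaded models keyed by name, plus the most recent dataset sample
# and evaluated generation results. All of this is lost when the process restarts.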
_MODELS = {}
_CURRENT_DATASET = None
_GENERATION_RESULTS = None
@app.route('/tmp/<path:filename>')
def serve_data(filename):
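    """Serve a generated artifact (plot PNG or CSV) from /tmp with a best-effort MIME type."""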
    from flask import Response
    print(f"[Static] Requested file: {filename}")
    data_dir = os.path.abspath('/tmp')
    # Resolve the requested path and refuse anything that escapes /tmp (path traversal guard)
    file_path = os.path.abspath(os.path.join(data_dir, filename))
    if not file_path.startswith(data_dir + os.sep):
        return "Invalid path", 403
    print(f"[Static] Full path: {file_path}")
    print(f"[Static] File exists: {os.path.exists(file_path)}")
if not os.path.exists(file_path):
return "File not found", 404
try:
with open(file_path, 'rb') as f:
file_data = f.read()
if filename.endswith('.png'):
mimetype = 'image/png'
elif filename.endswith('.jpg') or filename.endswith('.jpeg'):
mimetype = 'image/jpeg'
elif filename.endswith('.csv'):
mimetype = 'text/csv'
else:
mimetype = 'application/octet-stream'
print(f"[Static] Serving {len(file_data)} bytes as {mimetype}")
return Response(file_data, mimetype=mimetype)
except Exception as e:
print(f"[Static] Error reading file: {e}")
return f"Error reading file: {str(e)}", 500
@app.route('/debug/files', methods=['GET'])
def debug_files():
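    """Debug helper: list the files currently present in the server's output directory."""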
try:
        # The pipeline writes its outputs directly to /tmp, so list that directory
        data_dir = os.path.abspath('/tmp')
if not os.path.exists(data_dir):
return jsonify({"error": "Data directory not found", "path": data_dir})
files = []
for f in os.listdir(data_dir):
file_path = os.path.join(data_dir, f)
files.append({
"name": f,
"path": file_path,
"exists": os.path.exists(file_path),
"size": os.path.getsize(file_path) if os.path.exists(file_path) else 0
})
return jsonify({
"data_directory": data_dir,
"files": files
})
except Exception as e:
return jsonify({"error": str(e)})
def get_model(model_name: str):
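    """Return (tokenizer, model, device) for `model_name`, loading it once and caching it in _MODELS."""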
if model_name in _MODELS:
print(f"Using cached model: {model_name}")
return _MODELS[model_name]
print(f"Loading new model: {model_name}")
tokenizer, model, device = load_model_and_tokenizer(model_name)
_MODELS[model_name] = (tokenizer, model, device)
return tokenizer, model, device
@app.route('/health', methods=['GET'])
def health_check():
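    """Liveness probe reporting loaded models and whether dataset/generation state is cached."""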
return jsonify({
"status": "healthy",
"timestamp": datetime.now().isoformat(),
"loaded_models": list(_MODELS.keys()),
"dataset_loaded": _CURRENT_DATASET is not None,
"generation_results_available": _GENERATION_RESULTS is not None
})
def _flatten_features(feats, prefix: str = "") -> List[str]:
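    """Recursively flatten (possibly nested) dataset features into dotted column names."""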
cols: List[str] = []
try:
        # Features objects and plain dicts both expose .items()
        items = feats.items()
except Exception:
try:
return list(feats.keys())
except Exception:
return cols
for name, sub in items:
full = f"{prefix}.{name}" if prefix else name
try:
if isinstance(sub, (Features, dict)):
cols += _flatten_features(sub, prefix=full)
else:
cols.append(full)
except Exception:
cols.append(full)
return cols
@lru_cache(maxsize=256)
def _get_dataset_fields_cached(dataset_id: str, config: Optional[str], split: str) -> List[str]:
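    """Return the flattened field names of a dataset split, trying the dataset builder first and falling back to streaming a single row."""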
try:
builder = load_dataset_builder(dataset_id, name=config)
feats = builder.info.features
fields = _flatten_features(feats)
return sorted(set(fields))
except Exception as e_builder:
try:
ds = load_dataset(dataset_id, name=config, split=split, streaming=True)
first = next(iter(ds.take(1)), None)
if first is None:
return []
fields = list(first.keys())
return sorted(set(fields))
except Exception as e_stream:
raise RuntimeError(f"builder_error={e_builder}; streaming_error={e_stream}")
@app.route('/dataset/fields', methods=['GET'])
def dataset_fields():
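    """GET /dataset/fields: return the flattened column names for a Hugging Face dataset id/config/split."""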
dataset_id = request.args.get('id')
cfg = request.args.get('config')
split = request.args.get('split', 'train')
if not dataset_id:
return jsonify({"error": "Missing required query param 'id'"}), 400
try:
fields = _get_dataset_fields_cached(dataset_id, cfg, split)
return jsonify({
"fields": fields,
"datasetId": dataset_id,
"config": cfg,
"split": split,
"source": "huggingface-builder" if fields else "unknown"
})
except Exception as e:
return jsonify({
"error": "Failed to fetch dataset fields",
"datasetId": dataset_id,
"config": cfg,
"split": split,
"detail": str(e)
}), 400
@app.route('/dataset/meta', methods=['GET'])
def dataset_meta():
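    """GET /dataset/meta: list the available configs and splits for a Hugging Face dataset id."""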
dataset_id = request.args.get('id')
if not dataset_id:
return jsonify({"error": "Missing required query param 'id'"}), 400
try:
configs = get_dataset_config_names(dataset_id)
except Exception as e:
configs = []
logging.warning(f"get_dataset_config_names failed for {dataset_id}: {e}")
splits: List[str] = []
try:
if configs:
try:
b0 = load_dataset_builder(dataset_id, name=configs[0])
splits = sorted(list(b0.info.splits) or [])
except Exception:
splits = get_dataset_split_names(dataset_id, configs[0])
else:
try:
b = load_dataset_builder(dataset_id)
splits = sorted(list(b.info.splits) or [])
except Exception:
splits = get_dataset_split_names(dataset_id)
except Exception as e:
logging.warning(f"get splits failed for {dataset_id}: {e}")
splits = []
return jsonify({
"datasetId": dataset_id,
"configs": configs,
"splits": splits
})
@app.route('/dataset/field-stats', methods=['GET'])
def dataset_field_stats():
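    """GET /dataset/field-stats: value counts for `field` (optionally broken down by `subfield`) over up to 50,000 streamed rows."""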
dataset_id = request.args.get('id')
cfg = request.args.get('config')
split = request.args.get('split', 'train')
field = request.args.get('field')
subfield = request.args.get('subfield')
if not dataset_id or not field:
return jsonify({"error": "Missing required query params 'id' or 'field'"}), 400
try:
ds = load_dataset(dataset_id, name=cfg, split=split, streaming=True)
max_rows = 50000
counter: Dict[str, Any] = {}
print(f"[field-stats] Computing stats for '{field}'" + (f" β†’ '{subfield}'" if subfield else ""))
for i, row in enumerate(ds):
if i >= max_rows:
break
main_val = row.get(field)
if main_val is None:
continue
if subfield:
sub_val = row.get(subfield)
if sub_val is None:
continue
counter.setdefault(main_val, {})
counter[main_val][sub_val] = counter[main_val].get(sub_val, 0) + 1
else:
counter[main_val] = counter.get(main_val, 0) + 1
return jsonify({
"field": field,
"subfield": subfield,
"datasetId": dataset_id,
"config": cfg,
"split": split,
"counts": counter
})
except Exception as e:
return jsonify({
"error": f"Failed to compute field stats: {str(e)}",
"datasetId": dataset_id,
"config": cfg,
"split": split,
"field": field,
"subfield": subfield
}), 500
def _parse_selected_groups_from_config(config: dict) -> List[str]:
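    """Normalize config['selectedCfFields'] into an order-preserving, de-duplicated list of group names, keeping only the part after the last '/'."""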
raw = config.get('selectedCfFields', []) or []
out: List[str] = []
for s in raw:
s = (s or "").strip()
if not s:
continue
if "/" in s:
out.append(s.split("/")[-1])
else:
out.append(s)
seen = set()
uniq = []
for x in out:
if x not in seen:
uniq.append(x)
seen.add(x)
return uniq
def stratified_sample_by_category(df: pd.DataFrame, category_col: str, groups: List[str], total_n: Optional[int]) -> pd.DataFrame:
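    """Draw roughly total_n rows split as evenly as possible across `groups`, topping up any shortfall from the remaining pool. Returns df unchanged when total_n is None or <= 0."""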
if total_n is None or total_n <= 0:
return df
groups_present = [g for g in groups if g in df[category_col].unique()]
if not groups_present:
return df.sample(n=min(total_n, len(df)), random_state=42)
base_each = max(1, total_n // max(1, len(groups_present)))
remainder = max(0, total_n - base_each * len(groups_present))
parts = []
for g in groups_present:
gdf = df[df[category_col] == g]
need = min(base_each, len(gdf))
if need > 0:
parts.append(gdf.sample(n=need, random_state=42))
i = 0
while remainder > 0 and len(df) > 0:
g = groups_present[i % len(groups_present)]
gdf = df[df[category_col] == g]
if len(gdf) > 0:
            parts.append(gdf.sample(n=1, random_state=42 + remainder))
remainder -= 1
i += 1
    # Keep the original index so rows already sampled can be excluded from the top-up pool
    out = pd.concat(parts) if parts else pd.DataFrame(columns=df.columns)
    if len(out) < total_n and len(df) > len(out):
        rest = min(total_n - len(out), len(df) - len(out))
        pool = df.drop(out.index, errors="ignore")
        if len(pool) > 0 and rest > 0:
            out = pd.concat([out, pool.sample(n=min(rest, len(pool)), random_state=777)])
    return out.reset_index(drop=True)
def _pairwise_max_abs_diff(means: Dict[str, float]) -> float:
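    """Maximum absolute difference between any pair of group means (0.0 with fewer than two groups)."""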
from itertools import combinations
keys = list(means.keys())
if len(keys) < 2:
return 0.0
diffs = [abs(means[a] - means[b]) for a, b in combinations(keys, 2)]
return float(max(diffs)) if diffs else 0.0
def _mean_by_cat(df: pd.DataFrame, cats: List[str], score_col: str = "sentiment_score") -> Dict[str, float]:
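    """Per-category mean of `score_col`, computed only for the listed categories that actually appear in `df`."""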
out: Dict[str, float] = {}
for c in cats:
sub = df[df["category"] == c]
if len(sub) > 0:
out[c] = float(sub[score_col].mean())
return out
@app.route('/pipeline', methods=['POST'])
def run_pipeline():
"""Run the complete pipeline with frontend JobConfig format"""
data = request.get_json() or {}
config = data.get('config', data) or {}
print("[DEBUG] Received config:", config)
dataset_id = config.get('dataset') or "AmazonScience/bold"
model_name = config.get('languageModel', 'openai-community/gpt2')
top_k = int(config.get('k', 5))
dataset_limit_raw = config.get('datasetLimit')
dataset_limit = int(dataset_limit_raw) if dataset_limit_raw is not None else None
num_cf_per_row = int(config.get('numCounterfactuals') or 3)
tau = float(config.get('tau', 0.1))
iterations = int(config.get('iterations', 1000))
metric_target = config.get('metrictarget')
try:
results = {}
global _CURRENT_DATASET, _GENERATION_RESULTS
print("Pipeline Step 1: Loading data...")
ds = load_dataset(dataset_id, split="train")
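        # NOTE: the column selection below assumes a BOLD-style schema
        # (domain / name / category / prompts / wikipedia); other datasets would need a different mapping.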
df_full = pd.DataFrame(ds)[["domain", "name", "category", "prompts", "wikipedia"]].copy()
selected_groups = _parse_selected_groups_from_config(config)
present_all = sorted(df_full["category"].dropna().unique().tolist())
if selected_groups:
selected_groups = [g for g in selected_groups if g in present_all]
if len(selected_groups) < 2:
print(f"[Filter] Requested groups not enough in dataset (have {selected_groups}); fallback to ALL categories")
selected_groups = []
else:
print("[Filter] No groups requested from frontend; will use categories present after generation.")
df_pool = df_full[df_full["category"].isin(selected_groups)].copy() if selected_groups else df_full.copy()
df = stratified_sample_by_category(
df=df_pool,
category_col="category",
groups=selected_groups if selected_groups else sorted(df_pool["category"].unique().tolist()),
total_n=dataset_limit
)
print(f"[Pool] pool_size={len(df_pool)}, sampled={len(df)}")
print(f"[Pool] categories in pool: {sorted(df_pool['category'].unique().tolist())}")
print(f"[Pool] categories in sample: {sorted(df['category'].unique().tolist())}")
_CURRENT_DATASET = df
results['data_loaded'] = len(df)
print(f"Dataset loaded: {len(df)} rows")
print("Pipeline Step 2: Loading model...")
tokenizer, model, device = get_model(model_name)
results['model_loaded'] = model_name
print(f"Pipeline Step 3: Generating samples for {len(df)} entries...")
generation_results = generate_topk_samples(model, _CURRENT_DATASET, tokenizer, device, top_k=top_k)
task = config.get('classificationTask', 'sentiment')
tox_choice = config.get('toxicityModelChoice', 'detoxify')
evaluated_results = evaluate_generated_outputs(
generation_results, device,
task=task,
toxicity_model_choice=tox_choice
)
_GENERATION_RESULTS = evaluated_results
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
os.makedirs("/tmp", exist_ok=True)
output_file = f"/tmp/pipeline_generation_{timestamp}.csv"
evaluated_results.to_csv(output_file, index=False)
results['generation_file'] = output_file
results['generation_samples'] = len(evaluated_results)
print("Pipeline Step 3.5: Counterfactual augmentation...")
augmented_results = generate_counterfactual_augmentations(
evaluated_results,
text_col="generated",
name_col="name",
category_col="category",
num_cf_per_row=num_cf_per_row
)
augmented_file = f"/tmp/pipeline_generation_cf_augmented_{timestamp}.csv"
augmented_results.to_csv(augmented_file, index=False)
results['counterfactual_file'] = augmented_file
results['counterfactual_added'] = len(augmented_results) - len(evaluated_results)
results['counterfactual_total'] = len(augmented_results)
present_after_gen = sorted(evaluated_results["category"].dropna().unique().tolist())
if not selected_groups:
selected_groups_used = present_after_gen
else:
selected_groups_used = [g for g in selected_groups if g in present_after_gen]
if len(selected_groups_used) < 2:
print(f"[Sampling] After generation only {selected_groups_used} present; expanding to all present categories")
selected_groups_used = present_after_gen
print(f"[Sampling] Using groups: {selected_groups_used}")
print("Debug: Checking data before sampling...")
print(f"Total evaluated results: {len(evaluated_results)}")
print(f"Categories in data: {present_after_gen}")
print(f"Names in data: {evaluated_results['name'].unique()}")
for cat in selected_groups_used:
cat_count = int((evaluated_results["category"] == cat).sum())
print(f"Category '{cat}': {cat_count} samples")
print(f"Pipeline Step 4: Rank sampling on original evaluated results...(iterations={iterations}, temp={tau})")
try:
best_sent_subset = rank_sample(evaluated_results, num_samples=iterations, temp=tau, target_value=metric_target)
except (ValueError, IndexError) as e:
print(f"Sampling failed: {e}")
mid_point = len(evaluated_results) // 2
best_sent_subset = evaluated_results.iloc[:mid_point].copy()
sent_file = f"/tmp/pipeline_sent_subset_{timestamp}.csv"
best_sent_subset.to_csv(sent_file, index=False)
print(f"Pipeline Step 5: Rank sampling on CF-augmented results...(iterations={iterations}, temp={tau})")
try:
cf_best_sent_subset = rank_sample(augmented_results, num_samples=iterations, temp=tau, target_value=metric_target)
except (ValueError, IndexError) as e:
print(f"CF Sampling failed: {e}")
mid_point = len(augmented_results) // 2
cf_best_sent_subset = augmented_results.iloc[:mid_point].copy()
cf_sent_file = f"/tmp/pipeline_cf_sent_subset_{timestamp}.csv"
cf_best_sent_subset.to_csv(cf_sent_file, index=False)
orig_means = _mean_by_cat(best_sent_subset, selected_groups_used)
final_mean_diff = _pairwise_max_abs_diff(orig_means)
cf_means = _mean_by_cat(cf_best_sent_subset, selected_groups_used)
cf_final_mean_diff = _pairwise_max_abs_diff(cf_means)
print("Pipeline Step 6: Plotting distributions...")
def _safe(s: str) -> str:
import re
return re.sub(r"[^A-Za-z0-9_.-]+", "_", s)
orig_sent_title = _safe(f"{timestamp}_original_distribution")
cf_sent_title = _safe(f"{timestamp}_cf_distribution")
score_col = None
for c in [
"sentiment_score", "regard_score", "toxicity_score",
"stereotype_gender_score", "stereotype_religion_score",
"stereotype_profession_score", "stereotype_race_score",
"personality_score",
]:
if c in best_sent_subset.columns:
score_col = c
break
if score_col is None:
raise KeyError(f"No score column found. Available: {list(best_sent_subset.columns)}")
orig_path = _ensure_plot_saved(
best_sent_subset, score_col, orig_sent_title,
group_col="category", target=metric_target
)
cf_path = _ensure_plot_saved(
cf_best_sent_subset, score_col, cf_sent_title,
group_col="category", target=metric_target
)
print("[Plot check exists]", orig_path, os.path.exists(orig_path))
print("[Plot check exists]", cf_path, os.path.exists(cf_path))
results['plots'] = {
'original_sentiment': f"/tmp/{orig_sent_title}.png",
'counterfactual_sentiment': f"/tmp/{cf_sent_title}.png",
}
print("[Plot urls]", results['plots'])
if config.get("enableFineTuning"):
print("Pipeline Step 7: Fine-tuning enabled, starting training...")
ft_cfg = config.get("finetuneParams", {}) or {}
epochs = int(ft_cfg.get("epochs", 3))
batch_size = int(ft_cfg.get("batchSize", 8))
lr = float(ft_cfg.get("learningRate", 5e-5))
input_csv = augmented_file
ft_output_dir = f"/tmp/ft_{timestamp}"
os.makedirs(ft_output_dir, exist_ok=True)
try:
from utils.finetune import finetune_gpt2_from_csv
finetune_gpt2_from_csv(
csv_path=input_csv,
output_dir=ft_output_dir,
epochs=epochs,
batch_size=batch_size,
lr=lr
)
print(f"[Fine-tune] Saved fine-tuned model to {ft_output_dir}")
results["finetuned_model_dir"] = ft_output_dir
zip_base = f"/tmp/ft_{timestamp}"
import shutil
zip_path = shutil.make_archive(zip_base, 'zip', ft_output_dir)
results["finetuned_model_zip"] = f"/tmp/{os.path.basename(zip_path)}"
except Exception as fe:
print(f"[Fine-tune] Failed: {fe}")
results["finetuned_model_error"] = str(fe)
results.update({
'sampling_method': 'rank_sentiment_only',
'used_groups': selected_groups_used,
'sentiment_subset_file': sent_file,
'cf_sentiment_subset_file': cf_sent_file,
'sentiment_subset_size': len(best_sent_subset),
'cf_sentiment_subset_size': len(cf_best_sent_subset),
'config_used': config,
'metrics': {
'finalMeanDiff': final_mean_diff,
'cfFinalMeanDiff': cf_final_mean_diff,
'reductionPct': (0.0 if final_mean_diff == 0 else max(0.0, (final_mean_diff - cf_final_mean_diff) / abs(final_mean_diff) * 100.0)),
'stableCoverage': 100.0
}
})
return jsonify({
"status": "success",
"message": "Complete pipeline executed successfully (with counterfactual augmentation)",
"results": results,
"timestamp": timestamp
})
except Exception as e:
print(f"Error in pipeline: {str(e)}")
return jsonify({
"status": "error",
"message": f"Pipeline failed: {str(e)}"
}), 500
if __name__ == '__main__':
os.makedirs("/tmp", exist_ok=True)
print("Starting minimal Flask server...")
print("Available endpoints:")
print(" GET /health - Health check")
print(" GET /dataset/fields?id=<hf_id>[&config=...][&split=...] - List dataset fields")
print(" GET /dataset/field-stats?id=...&field=... - Get value distribution of a field")
print(" GET /dataset/meta?id=<hf_id> - List configs/splits")
print(" POST /pipeline - Run complete pipeline")
app.run(host='0.0.0.0', port=5001, debug=True, threaded=True)