from flask import Flask, request, jsonify, send_file, send_from_directory from flask_cors import CORS import pandas as pd import torch import os from datetime import datetime from tqdm import tqdm import logging from functools import lru_cache from typing import Optional, List, Dict, Any from utils.utils import _ensure_plot_saved os.environ["MPLBACKEND"] = "Agg" os.environ["QT_QPA_PLATFORM"] = "offscreen" logging.basicConfig(level=logging.INFO) from utils.sampling import rank_sample try: from transformers import TextDataset, DataCollatorForLanguageModeling, Trainer, TrainingArguments print("✓ transformers training components imported") except Exception as e: print(f"✗ transformers training import failed: {e}") def finetune(*args, **kwargs): print("Warning: Transformers training components not available, skipping fine-tuning") return None # 🤗 datasets try: from datasets import ( load_dataset, load_dataset_builder, get_dataset_config_names, get_dataset_split_names, Features, ) print("✓ datasets imported") except Exception as e: print(f"✗ datasets import failed: {e}") raise from utils.utils import ( generate_topk_samples, evaluate_generated_outputs, load_model_and_tokenizer, generate_counterfactual_augmentations, ) print("✓ utils imported") app = Flask(__name__) CORS(app) _MODELS = {} _CURRENT_DATASET = None _GENERATION_RESULTS = None @app.route('/tmp/') def serve_data(filename): import os from flask import Response print(f"[Static] Requested file: {filename}") data_dir = os.path.abspath('/tmp') file_path = os.path.join(data_dir, filename) print(f"[Static] Full path: {file_path}") print(f"[Static] File exists: {os.path.exists(file_path)}") if not os.path.exists(file_path): return "File not found", 404 try: with open(file_path, 'rb') as f: file_data = f.read() if filename.endswith('.png'): mimetype = 'image/png' elif filename.endswith('.jpg') or filename.endswith('.jpeg'): mimetype = 'image/jpeg' elif filename.endswith('.csv'): mimetype = 'text/csv' else: mimetype = 'application/octet-stream' print(f"[Static] Serving {len(file_data)} bytes as {mimetype}") return Response(file_data, mimetype=mimetype) except Exception as e: print(f"[Static] Error reading file: {e}") return f"Error reading file: {str(e)}", 500 @app.route('/debug/files', methods=['GET']) def debug_files(): try: data_dir = os.path.abspath('/tmp/data') if not os.path.exists(data_dir): return jsonify({"error": "Data directory not found", "path": data_dir}) files = [] for f in os.listdir(data_dir): file_path = os.path.join(data_dir, f) files.append({ "name": f, "path": file_path, "exists": os.path.exists(file_path), "size": os.path.getsize(file_path) if os.path.exists(file_path) else 0 }) return jsonify({ "data_directory": data_dir, "files": files }) except Exception as e: return jsonify({"error": str(e)}) def get_model(model_name: str): if model_name in _MODELS: print(f"Using cached model: {model_name}") return _MODELS[model_name] print(f"Loading new model: {model_name}") tokenizer, model, device = load_model_and_tokenizer(model_name) _MODELS[model_name] = (tokenizer, model, device) return tokenizer, model, device @app.route('/health', methods=['GET']) def health_check(): return jsonify({ "status": "healthy", "timestamp": datetime.now().isoformat(), "loaded_models": list(_MODELS.keys()), "dataset_loaded": _CURRENT_DATASET is not None, "generation_results_available": _GENERATION_RESULTS is not None }) def _flatten_features(feats, prefix: str = "") -> List[str]: cols: List[str] = [] try: items = feats.items() if isinstance(feats, (Features, dict)) else feats.items() except Exception: try: return list(feats.keys()) except Exception: return cols for name, sub in items: full = f"{prefix}.{name}" if prefix else name try: if isinstance(sub, (Features, dict)): cols += _flatten_features(sub, prefix=full) else: cols.append(full) except Exception: cols.append(full) return cols @lru_cache(maxsize=256) def _get_dataset_fields_cached(dataset_id: str, config: Optional[str], split: str) -> List[str]: try: builder = load_dataset_builder(dataset_id, name=config) feats = builder.info.features fields = _flatten_features(feats) return sorted(set(fields)) except Exception as e_builder: try: ds = load_dataset(dataset_id, name=config, split=split, streaming=True) first = next(iter(ds.take(1)), None) if first is None: return [] fields = list(first.keys()) return sorted(set(fields)) except Exception as e_stream: raise RuntimeError(f"builder_error={e_builder}; streaming_error={e_stream}") @app.route('/dataset/fields', methods=['GET']) def dataset_fields(): dataset_id = request.args.get('id') cfg = request.args.get('config') split = request.args.get('split', 'train') if not dataset_id: return jsonify({"error": "Missing required query param 'id'"}), 400 try: fields = _get_dataset_fields_cached(dataset_id, cfg, split) return jsonify({ "fields": fields, "datasetId": dataset_id, "config": cfg, "split": split, "source": "huggingface-builder" if fields else "unknown" }) except Exception as e: return jsonify({ "error": "Failed to fetch dataset fields", "datasetId": dataset_id, "config": cfg, "split": split, "detail": str(e) }), 400 @app.route('/dataset/meta', methods=['GET']) def dataset_meta(): dataset_id = request.args.get('id') if not dataset_id: return jsonify({"error": "Missing required query param 'id'"}), 400 try: configs = get_dataset_config_names(dataset_id) except Exception as e: configs = [] logging.warning(f"get_dataset_config_names failed for {dataset_id}: {e}") splits: List[str] = [] try: if configs: try: b0 = load_dataset_builder(dataset_id, name=configs[0]) splits = sorted(list(b0.info.splits) or []) except Exception: splits = get_dataset_split_names(dataset_id, configs[0]) else: try: b = load_dataset_builder(dataset_id) splits = sorted(list(b.info.splits) or []) except Exception: splits = get_dataset_split_names(dataset_id) except Exception as e: logging.warning(f"get splits failed for {dataset_id}: {e}") splits = [] return jsonify({ "datasetId": dataset_id, "configs": configs, "splits": splits }) @app.route('/dataset/field-stats', methods=['GET']) def dataset_field_stats(): dataset_id = request.args.get('id') cfg = request.args.get('config') split = request.args.get('split', 'train') field = request.args.get('field') subfield = request.args.get('subfield') if not dataset_id or not field: return jsonify({"error": "Missing required query params 'id' or 'field'"}), 400 try: ds = load_dataset(dataset_id, name=cfg, split=split, streaming=True) max_rows = 50000 counter: Dict[str, Any] = {} print(f"[field-stats] Computing stats for '{field}'" + (f" → '{subfield}'" if subfield else "")) for i, row in enumerate(ds): if i >= max_rows: break main_val = row.get(field) if main_val is None: continue if subfield: sub_val = row.get(subfield) if sub_val is None: continue counter.setdefault(main_val, {}) counter[main_val][sub_val] = counter[main_val].get(sub_val, 0) + 1 else: counter[main_val] = counter.get(main_val, 0) + 1 return jsonify({ "field": field, "subfield": subfield, "datasetId": dataset_id, "config": cfg, "split": split, "counts": counter }) except Exception as e: return jsonify({ "error": f"Failed to compute field stats: {str(e)}", "datasetId": dataset_id, "config": cfg, "split": split, "field": field, "subfield": subfield }), 500 def _parse_selected_groups_from_config(config: dict) -> List[str]: raw = config.get('selectedCfFields', []) or [] out: List[str] = [] for s in raw: s = (s or "").strip() if not s: continue if "/" in s: out.append(s.split("/")[-1]) else: out.append(s) seen = set() uniq = [] for x in out: if x not in seen: uniq.append(x) seen.add(x) return uniq def stratified_sample_by_category(df: pd.DataFrame, category_col: str, groups: List[str], total_n: Optional[int]) -> pd.DataFrame: if total_n is None or total_n <= 0: return df groups_present = [g for g in groups if g in df[category_col].unique()] if not groups_present: return df.sample(n=min(total_n, len(df)), random_state=42) base_each = max(1, total_n // max(1, len(groups_present))) remainder = max(0, total_n - base_each * len(groups_present)) parts = [] for g in groups_present: gdf = df[df[category_col] == g] need = min(base_each, len(gdf)) if need > 0: parts.append(gdf.sample(n=need, random_state=42)) i = 0 while remainder > 0 and len(df) > 0: g = groups_present[i % len(groups_present)] gdf = df[df[category_col] == g] if len(gdf) > 0: parts.append(gdf.sample(n=1, replace=(len(gdf) < 1), random_state=42 + remainder)) remainder -= 1 i += 1 out = pd.concat(parts, ignore_index=True) if parts else pd.DataFrame(columns=df.columns) if len(out) < total_n and len(df) > len(out): rest = min(total_n - len(out), len(df) - len(out)) pool = df.drop(out.index, errors="ignore") if len(pool) > 0 and rest > 0: out = pd.concat([out, pool.sample(n=min(rest, len(pool)), random_state=777)], ignore_index=True) return out def _pairwise_max_abs_diff(means: Dict[str, float]) -> float: from itertools import combinations keys = list(means.keys()) if len(keys) < 2: return 0.0 diffs = [abs(means[a] - means[b]) for a, b in combinations(keys, 2)] return float(max(diffs)) if diffs else 0.0 def _mean_by_cat(df: pd.DataFrame, cats: List[str], score_col: str = "sentiment_score") -> Dict[str, float]: out: Dict[str, float] = {} for c in cats: sub = df[df["category"] == c] if len(sub) > 0: out[c] = float(sub[score_col].mean()) return out @app.route('/pipeline', methods=['POST']) def run_pipeline(): """Run the complete pipeline with frontend JobConfig format""" data = request.get_json() or {} config = data.get('config', data) or {} print("[DEBUG] Received config:", config) dataset_id = config.get('dataset') or "AmazonScience/bold" model_name = config.get('languageModel', 'openai-community/gpt2') top_k = int(config.get('k', 5)) dataset_limit_raw = config.get('datasetLimit') dataset_limit = int(dataset_limit_raw) if dataset_limit_raw is not None else None num_cf_per_row = int(config.get('numCounterfactuals') or 3) tau = float(config.get('tau', 0.1)) iterations = int(config.get('iterations', 1000)) metric_target = config.get('metrictarget') try: results = {} global _CURRENT_DATASET, _GENERATION_RESULTS print("Pipeline Step 1: Loading data...") ds = load_dataset(dataset_id, split="train") df_full = pd.DataFrame(ds)[["domain", "name", "category", "prompts", "wikipedia"]].copy() selected_groups = _parse_selected_groups_from_config(config) present_all = sorted(df_full["category"].dropna().unique().tolist()) if selected_groups: selected_groups = [g for g in selected_groups if g in present_all] if len(selected_groups) < 2: print(f"[Filter] Requested groups not enough in dataset (have {selected_groups}); fallback to ALL categories") selected_groups = [] else: print("[Filter] No groups requested from frontend; will use categories present after generation.") df_pool = df_full[df_full["category"].isin(selected_groups)].copy() if selected_groups else df_full.copy() df = stratified_sample_by_category( df=df_pool, category_col="category", groups=selected_groups if selected_groups else sorted(df_pool["category"].unique().tolist()), total_n=dataset_limit ) print(f"[Pool] pool_size={len(df_pool)}, sampled={len(df)}") print(f"[Pool] categories in pool: {sorted(df_pool['category'].unique().tolist())}") print(f"[Pool] categories in sample: {sorted(df['category'].unique().tolist())}") _CURRENT_DATASET = df results['data_loaded'] = len(df) print(f"Dataset loaded: {len(df)} rows") print("Pipeline Step 2: Loading model...") tokenizer, model, device = get_model(model_name) results['model_loaded'] = model_name print(f"Pipeline Step 3: Generating samples for {len(df)} entries...") generation_results = generate_topk_samples(model, _CURRENT_DATASET, tokenizer, device, top_k=top_k) task = config.get('classificationTask', 'sentiment') tox_choice = config.get('toxicityModelChoice', 'detoxify') evaluated_results = evaluate_generated_outputs( generation_results, device, task=task, toxicity_model_choice=tox_choice ) _GENERATION_RESULTS = evaluated_results timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") os.makedirs("/tmp", exist_ok=True) output_file = f"/tmp/pipeline_generation_{timestamp}.csv" evaluated_results.to_csv(output_file, index=False) results['generation_file'] = output_file results['generation_samples'] = len(evaluated_results) print("Pipeline Step 3.5: Counterfactual augmentation...") augmented_results = generate_counterfactual_augmentations( evaluated_results, text_col="generated", name_col="name", category_col="category", num_cf_per_row=num_cf_per_row ) augmented_file = f"/tmp/pipeline_generation_cf_augmented_{timestamp}.csv" augmented_results.to_csv(augmented_file, index=False) results['counterfactual_file'] = augmented_file results['counterfactual_added'] = len(augmented_results) - len(evaluated_results) results['counterfactual_total'] = len(augmented_results) present_after_gen = sorted(evaluated_results["category"].dropna().unique().tolist()) if not selected_groups: selected_groups_used = present_after_gen else: selected_groups_used = [g for g in selected_groups if g in present_after_gen] if len(selected_groups_used) < 2: print(f"[Sampling] After generation only {selected_groups_used} present; expanding to all present categories") selected_groups_used = present_after_gen print(f"[Sampling] Using groups: {selected_groups_used}") print("Debug: Checking data before sampling...") print(f"Total evaluated results: {len(evaluated_results)}") print(f"Categories in data: {present_after_gen}") print(f"Names in data: {evaluated_results['name'].unique()}") for cat in selected_groups_used: cat_count = int((evaluated_results["category"] == cat).sum()) print(f"Category '{cat}': {cat_count} samples") print(f"Pipeline Step 4: Rank sampling on original evaluated results...(iterations={iterations}, temp={tau})") try: best_sent_subset = rank_sample(evaluated_results, num_samples=iterations, temp=tau, target_value=metric_target) except (ValueError, IndexError) as e: print(f"Sampling failed: {e}") mid_point = len(evaluated_results) // 2 best_sent_subset = evaluated_results.iloc[:mid_point].copy() sent_file = f"/tmp/pipeline_sent_subset_{timestamp}.csv" best_sent_subset.to_csv(sent_file, index=False) print(f"Pipeline Step 5: Rank sampling on CF-augmented results...(iterations={iterations}, temp={tau})") try: cf_best_sent_subset = rank_sample(augmented_results, num_samples=iterations, temp=tau, target_value=metric_target) except (ValueError, IndexError) as e: print(f"CF Sampling failed: {e}") mid_point = len(augmented_results) // 2 cf_best_sent_subset = augmented_results.iloc[:mid_point].copy() cf_sent_file = f"/tmp/pipeline_cf_sent_subset_{timestamp}.csv" cf_best_sent_subset.to_csv(cf_sent_file, index=False) orig_means = _mean_by_cat(best_sent_subset, selected_groups_used) final_mean_diff = _pairwise_max_abs_diff(orig_means) cf_means = _mean_by_cat(cf_best_sent_subset, selected_groups_used) cf_final_mean_diff = _pairwise_max_abs_diff(cf_means) print("Pipeline Step 6: Plotting distributions...") def _safe(s: str) -> str: import re return re.sub(r"[^A-Za-z0-9_.-]+", "_", s) orig_sent_title = _safe(f"{timestamp}_original_distribution") cf_sent_title = _safe(f"{timestamp}_cf_distribution") score_col = None for c in [ "sentiment_score", "regard_score", "toxicity_score", "stereotype_gender_score", "stereotype_religion_score", "stereotype_profession_score", "stereotype_race_score", "personality_score", ]: if c in best_sent_subset.columns: score_col = c break if score_col is None: raise KeyError(f"No score column found. Available: {list(best_sent_subset.columns)}") orig_path = _ensure_plot_saved( best_sent_subset, score_col, orig_sent_title, group_col="category", target=metric_target ) cf_path = _ensure_plot_saved( cf_best_sent_subset, score_col, cf_sent_title, group_col="category", target=metric_target ) print("[Plot check exists]", orig_path, os.path.exists(orig_path)) print("[Plot check exists]", cf_path, os.path.exists(cf_path)) results['plots'] = { 'original_sentiment': f"/tmp/{orig_sent_title}.png", 'counterfactual_sentiment': f"/tmp/{cf_sent_title}.png", } print("[Plot urls]", results['plots']) if config.get("enableFineTuning"): print("Pipeline Step 7: Fine-tuning enabled, starting training...") ft_cfg = config.get("finetuneParams", {}) or {} epochs = int(ft_cfg.get("epochs", 3)) batch_size = int(ft_cfg.get("batchSize", 8)) lr = float(ft_cfg.get("learningRate", 5e-5)) input_csv = augmented_file ft_output_dir = f"/tmp/ft_{timestamp}" os.makedirs(ft_output_dir, exist_ok=True) try: from utils.finetune import finetune_gpt2_from_csv finetune_gpt2_from_csv( csv_path=input_csv, output_dir=ft_output_dir, epochs=epochs, batch_size=batch_size, lr=lr ) print(f"[Fine-tune] Saved fine-tuned model to {ft_output_dir}") results["finetuned_model_dir"] = ft_output_dir zip_base = f"/tmp/ft_{timestamp}" import shutil zip_path = shutil.make_archive(zip_base, 'zip', ft_output_dir) results["finetuned_model_zip"] = f"/tmp/{os.path.basename(zip_path)}" except Exception as fe: print(f"[Fine-tune] Failed: {fe}") results["finetuned_model_error"] = str(fe) results.update({ 'sampling_method': 'rank_sentiment_only', 'used_groups': selected_groups_used, 'sentiment_subset_file': sent_file, 'cf_sentiment_subset_file': cf_sent_file, 'sentiment_subset_size': len(best_sent_subset), 'cf_sentiment_subset_size': len(cf_best_sent_subset), 'config_used': config, 'metrics': { 'finalMeanDiff': final_mean_diff, 'cfFinalMeanDiff': cf_final_mean_diff, 'reductionPct': (0.0 if final_mean_diff == 0 else max(0.0, (final_mean_diff - cf_final_mean_diff) / abs(final_mean_diff) * 100.0)), 'stableCoverage': 100.0 } }) return jsonify({ "status": "success", "message": "Complete pipeline executed successfully (with counterfactual augmentation)", "results": results, "timestamp": timestamp }) except Exception as e: print(f"Error in pipeline: {str(e)}") return jsonify({ "status": "error", "message": f"Pipeline failed: {str(e)}" }), 500 if __name__ == '__main__': os.makedirs("/tmp", exist_ok=True) print("Starting minimal Flask server...") print("Available endpoints:") print(" GET /health - Health check") print(" GET /dataset/fields?id=[&config=...][&split=...] - List dataset fields") print(" GET /dataset/field-stats?id=...&field=... - Get value distribution of a field") print(" GET /dataset/meta?id= - List configs/splits") print(" POST /pipeline - Run complete pipeline") app.run(host='0.0.0.0', port=5001, debug=True, threaded=True)