|
from flask import Flask, request, jsonify, send_file, send_from_directory |
|
from flask_cors import CORS |
|
import pandas as pd |
|
import torch |
|
import os |
|
from datetime import datetime |
|
from tqdm import tqdm |
|
import logging |
|
from functools import lru_cache |
|
from typing import Optional, List, Dict, Any |
|
# Force a headless matplotlib/Qt setup before importing plotting helpers,
# so MPLBACKEND=Agg is already in effect when matplotlib gets imported.
os.environ["MPLBACKEND"] = "Agg"
os.environ["QT_QPA_PLATFORM"] = "offscreen"

from utils.utils import _ensure_plot_saved
|
|
|
logging.basicConfig(level=logging.INFO) |
|
|
|
from utils.sampling import rank_sample |
|
try: |
|
from transformers import TextDataset, DataCollatorForLanguageModeling, Trainer, TrainingArguments |
|
print("β transformers training components imported") |
|
except Exception as e: |
|
print(f"β transformers training import failed: {e}") |
|
def finetune(*args, **kwargs): |
|
print("Warning: Transformers training components not available, skipping fine-tuning") |
|
return None |
|
|
|
|
|
try: |
|
from datasets import ( |
|
load_dataset, |
|
load_dataset_builder, |
|
get_dataset_config_names, |
|
get_dataset_split_names, |
|
Features, |
|
) |
|
print("β datasets imported") |
|
except Exception as e: |
|
print(f"β datasets import failed: {e}") |
|
raise |
|
|
|
from utils.utils import ( |
|
generate_topk_samples, |
|
evaluate_generated_outputs, |
|
load_model_and_tokenizer, |
|
generate_counterfactual_augmentations, |
|
) |
|
print("β utils imported") |
|
|
|
app = Flask(__name__) |
|
CORS(app) |
|
|
|
_MODELS = {} |
|
_CURRENT_DATASET = None |
|
_GENERATION_RESULTS = None |
|
|
|
@app.route('/tmp/<path:filename>') |
|
def serve_data(filename): |
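    """Serve a file from /tmp, guessing the MIME type from its extension (plots, CSVs, archives)."""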
|
    from flask import Response

    print(f"[Static] Requested file: {filename}")

    data_dir = os.path.abspath('/tmp')
    file_path = os.path.abspath(os.path.join(data_dir, filename))

    print(f"[Static] Full path: {file_path}")
    print(f"[Static] File exists: {os.path.exists(file_path)}")

    # Reject paths that resolve outside /tmp (e.g. requests containing "..").
    if not file_path.startswith(data_dir + os.sep):
        return "Invalid path", 400

    if not os.path.exists(file_path):
        return "File not found", 404
|
|
|
try: |
|
with open(file_path, 'rb') as f: |
|
file_data = f.read() |
|
|
|
if filename.endswith('.png'): |
|
mimetype = 'image/png' |
|
elif filename.endswith('.jpg') or filename.endswith('.jpeg'): |
|
mimetype = 'image/jpeg' |
|
elif filename.endswith('.csv'): |
|
mimetype = 'text/csv' |
|
else: |
|
mimetype = 'application/octet-stream' |
|
|
|
print(f"[Static] Serving {len(file_data)} bytes as {mimetype}") |
|
|
|
return Response(file_data, mimetype=mimetype) |
|
|
|
except Exception as e: |
|
print(f"[Static] Error reading file: {e}") |
|
return f"Error reading file: {str(e)}", 500 |
|
|
|
@app.route('/debug/files', methods=['GET']) |
|
def debug_files(): |
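    """List files under /tmp/data (name, path, existence, size) as JSON for debugging."""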
|
try: |
|
data_dir = os.path.abspath('/tmp/data') |
|
if not os.path.exists(data_dir): |
|
return jsonify({"error": "Data directory not found", "path": data_dir}) |
|
|
|
files = [] |
|
for f in os.listdir(data_dir): |
|
file_path = os.path.join(data_dir, f) |
|
files.append({ |
|
"name": f, |
|
"path": file_path, |
|
"exists": os.path.exists(file_path), |
|
"size": os.path.getsize(file_path) if os.path.exists(file_path) else 0 |
|
}) |
|
|
|
return jsonify({ |
|
"data_directory": data_dir, |
|
"files": files |
|
}) |
|
except Exception as e: |
|
return jsonify({"error": str(e)}) |
|
|
|
def get_model(model_name: str): |
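    """Return (tokenizer, model, device) for model_name, loading it once and caching it in _MODELS."""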
|
if model_name in _MODELS: |
|
print(f"Using cached model: {model_name}") |
|
return _MODELS[model_name] |
|
print(f"Loading new model: {model_name}") |
|
tokenizer, model, device = load_model_and_tokenizer(model_name) |
|
_MODELS[model_name] = (tokenizer, model, device) |
|
return tokenizer, model, device |
|
|
|
|
|
@app.route('/health', methods=['GET']) |
|
def health_check(): |
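    """Report server status, loaded models, and whether a dataset sample and generation results are in memory."""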
|
return jsonify({ |
|
"status": "healthy", |
|
"timestamp": datetime.now().isoformat(), |
|
"loaded_models": list(_MODELS.keys()), |
|
"dataset_loaded": _CURRENT_DATASET is not None, |
|
"generation_results_available": _GENERATION_RESULTS is not None |
|
}) |
|
|
|
|
|
def _flatten_features(feats, prefix: str = "") -> List[str]: |
|
cols: List[str] = [] |
|
try: |
|
        items = feats.items()  # Features and plain dicts both expose .items()
|
except Exception: |
|
try: |
|
return list(feats.keys()) |
|
except Exception: |
|
return cols |
|
for name, sub in items: |
|
full = f"{prefix}.{name}" if prefix else name |
|
try: |
|
if isinstance(sub, (Features, dict)): |
|
cols += _flatten_features(sub, prefix=full) |
|
else: |
|
cols.append(full) |
|
except Exception: |
|
cols.append(full) |
|
return cols |
|
|
|
@lru_cache(maxsize=256) |
|
def _get_dataset_fields_cached(dataset_id: str, config: Optional[str], split: str) -> List[str]: |
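    """Resolve a dataset's field names via its builder metadata, falling back to streaming a single row."""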
|
try: |
|
builder = load_dataset_builder(dataset_id, name=config) |
|
feats = builder.info.features |
|
fields = _flatten_features(feats) |
|
return sorted(set(fields)) |
|
except Exception as e_builder: |
|
try: |
|
ds = load_dataset(dataset_id, name=config, split=split, streaming=True) |
|
first = next(iter(ds.take(1)), None) |
|
if first is None: |
|
return [] |
|
fields = list(first.keys()) |
|
return sorted(set(fields)) |
|
except Exception as e_stream: |
|
raise RuntimeError(f"builder_error={e_builder}; streaming_error={e_stream}") |
|
|
|
@app.route('/dataset/fields', methods=['GET']) |
|
def dataset_fields(): |
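    """GET /dataset/fields: return the flattened field names for a Hugging Face dataset id/config/split."""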
|
dataset_id = request.args.get('id') |
|
cfg = request.args.get('config') |
|
split = request.args.get('split', 'train') |
|
if not dataset_id: |
|
return jsonify({"error": "Missing required query param 'id'"}), 400 |
|
try: |
|
fields = _get_dataset_fields_cached(dataset_id, cfg, split) |
|
return jsonify({ |
|
"fields": fields, |
|
"datasetId": dataset_id, |
|
"config": cfg, |
|
"split": split, |
|
"source": "huggingface-builder" if fields else "unknown" |
|
}) |
|
except Exception as e: |
|
return jsonify({ |
|
"error": "Failed to fetch dataset fields", |
|
"datasetId": dataset_id, |
|
"config": cfg, |
|
"split": split, |
|
"detail": str(e) |
|
}), 400 |
|
|
|
@app.route('/dataset/meta', methods=['GET']) |
|
def dataset_meta(): |
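    """GET /dataset/meta: return the available configs and splits for a Hugging Face dataset."""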
|
dataset_id = request.args.get('id') |
|
if not dataset_id: |
|
return jsonify({"error": "Missing required query param 'id'"}), 400 |
|
try: |
|
configs = get_dataset_config_names(dataset_id) |
|
except Exception as e: |
|
configs = [] |
|
logging.warning(f"get_dataset_config_names failed for {dataset_id}: {e}") |
|
splits: List[str] = [] |
|
try: |
|
if configs: |
|
try: |
|
b0 = load_dataset_builder(dataset_id, name=configs[0]) |
|
splits = sorted(list(b0.info.splits) or []) |
|
except Exception: |
|
splits = get_dataset_split_names(dataset_id, configs[0]) |
|
else: |
|
try: |
|
b = load_dataset_builder(dataset_id) |
|
splits = sorted(list(b.info.splits) or []) |
|
except Exception: |
|
splits = get_dataset_split_names(dataset_id) |
|
except Exception as e: |
|
logging.warning(f"get splits failed for {dataset_id}: {e}") |
|
splits = [] |
|
return jsonify({ |
|
"datasetId": dataset_id, |
|
"configs": configs, |
|
"splits": splits |
|
}) |
|
|
|
@app.route('/dataset/field-stats', methods=['GET']) |
|
def dataset_field_stats(): |
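    """GET /dataset/field-stats: stream up to 50,000 rows and count values of a field, optionally broken down by a subfield."""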
|
dataset_id = request.args.get('id') |
|
cfg = request.args.get('config') |
|
split = request.args.get('split', 'train') |
|
field = request.args.get('field') |
|
subfield = request.args.get('subfield') |
|
if not dataset_id or not field: |
|
return jsonify({"error": "Missing required query params 'id' or 'field'"}), 400 |
|
try: |
|
ds = load_dataset(dataset_id, name=cfg, split=split, streaming=True) |
|
max_rows = 50000 |
|
counter: Dict[str, Any] = {} |
|
print(f"[field-stats] Computing stats for '{field}'" + (f" β '{subfield}'" if subfield else "")) |
|
for i, row in enumerate(ds): |
|
if i >= max_rows: |
|
break |
|
main_val = row.get(field) |
|
if main_val is None: |
|
continue |
|
if subfield: |
|
sub_val = row.get(subfield) |
|
if sub_val is None: |
|
continue |
|
counter.setdefault(main_val, {}) |
|
counter[main_val][sub_val] = counter[main_val].get(sub_val, 0) + 1 |
|
else: |
|
counter[main_val] = counter.get(main_val, 0) + 1 |
|
return jsonify({ |
|
"field": field, |
|
"subfield": subfield, |
|
"datasetId": dataset_id, |
|
"config": cfg, |
|
"split": split, |
|
"counts": counter |
|
}) |
|
except Exception as e: |
|
return jsonify({ |
|
"error": f"Failed to compute field stats: {str(e)}", |
|
"datasetId": dataset_id, |
|
"config": cfg, |
|
"split": split, |
|
"field": field, |
|
"subfield": subfield |
|
}), 500 |
|
|
|
def _parse_selected_groups_from_config(config: dict) -> List[str]: |
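    """Extract group names from config['selectedCfFields'], keeping only the part after any '/', deduplicated in order."""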
|
raw = config.get('selectedCfFields', []) or [] |
|
out: List[str] = [] |
|
for s in raw: |
|
s = (s or "").strip() |
|
if not s: |
|
continue |
|
if "/" in s: |
|
out.append(s.split("/")[-1]) |
|
else: |
|
out.append(s) |
|
seen = set() |
|
uniq = [] |
|
for x in out: |
|
if x not in seen: |
|
uniq.append(x) |
|
seen.add(x) |
|
return uniq |
|
|
|
def stratified_sample_by_category(df: pd.DataFrame, category_col: str, groups: List[str], total_n: Optional[int]) -> pd.DataFrame: |
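    """Sample about total_n rows split evenly across the given category groups, topping up from the rest of the pool if some groups are small."""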
|
if total_n is None or total_n <= 0: |
|
return df |
|
|
|
groups_present = [g for g in groups if g in df[category_col].unique()] |
|
if not groups_present: |
|
return df.sample(n=min(total_n, len(df)), random_state=42) |
|
|
|
base_each = max(1, total_n // max(1, len(groups_present))) |
|
remainder = max(0, total_n - base_each * len(groups_present)) |
|
|
|
parts = [] |
|
for g in groups_present: |
|
gdf = df[df[category_col] == g] |
|
need = min(base_each, len(gdf)) |
|
if need > 0: |
|
parts.append(gdf.sample(n=need, random_state=42)) |
|
|
|
i = 0 |
|
while remainder > 0 and len(df) > 0: |
|
g = groups_present[i % len(groups_present)] |
|
gdf = df[df[category_col] == g] |
|
if len(gdf) > 0: |
|
            parts.append(gdf.sample(n=1, random_state=42 + remainder))
|
remainder -= 1 |
|
i += 1 |
|
|
|
    # Keep original row indices so already-sampled rows can be excluded when topping up.
    out = pd.concat(parts) if parts else pd.DataFrame(columns=df.columns)
    if len(out) < total_n and len(df) > len(out):
        rest = min(total_n - len(out), len(df) - len(out))
        pool = df.loc[~df.index.isin(out.index)]
        if len(pool) > 0 and rest > 0:
            out = pd.concat([out, pool.sample(n=min(rest, len(pool)), random_state=777)])
    return out.reset_index(drop=True)
|
|
|
def _pairwise_max_abs_diff(means: Dict[str, float]) -> float: |
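    """Return the largest absolute difference between any two group means (0.0 with fewer than two groups)."""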
|
from itertools import combinations |
|
keys = list(means.keys()) |
|
if len(keys) < 2: |
|
return 0.0 |
|
diffs = [abs(means[a] - means[b]) for a, b in combinations(keys, 2)] |
|
return float(max(diffs)) if diffs else 0.0 |
|
|
|
def _mean_by_cat(df: pd.DataFrame, cats: List[str], score_col: str = "sentiment_score") -> Dict[str, float]: |
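    """Compute the mean of score_col for each category in cats, skipping categories with no rows."""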
|
out: Dict[str, float] = {} |
|
for c in cats: |
|
sub = df[df["category"] == c] |
|
if len(sub) > 0: |
|
out[c] = float(sub[score_col].mean()) |
|
return out |
|
|
|
@app.route('/pipeline', methods=['POST']) |
|
def run_pipeline(): |
|
"""Run the complete pipeline with frontend JobConfig format""" |
|
data = request.get_json() or {} |
|
config = data.get('config', data) or {} |
|
print("[DEBUG] Received config:", config) |
|
|
|
dataset_id = config.get('dataset') or "AmazonScience/bold" |
|
model_name = config.get('languageModel', 'openai-community/gpt2') |
|
top_k = int(config.get('k', 5)) |
|
dataset_limit_raw = config.get('datasetLimit') |
|
dataset_limit = int(dataset_limit_raw) if dataset_limit_raw is not None else None |
|
num_cf_per_row = int(config.get('numCounterfactuals') or 3) |
|
tau = float(config.get('tau', 0.1)) |
|
iterations = int(config.get('iterations', 1000)) |
|
    # Accept either 'metricTarget' (matching the other camelCase config keys) or 'metrictarget'.
    metric_target = config.get('metricTarget', config.get('metrictarget'))
|
|
|
try: |
|
results = {} |
|
global _CURRENT_DATASET, _GENERATION_RESULTS |
|
|
|
print("Pipeline Step 1: Loading data...") |
|
ds = load_dataset(dataset_id, split="train") |
|
df_full = pd.DataFrame(ds)[["domain", "name", "category", "prompts", "wikipedia"]].copy() |
|
|
|
selected_groups = _parse_selected_groups_from_config(config) |
|
present_all = sorted(df_full["category"].dropna().unique().tolist()) |
|
|
|
if selected_groups: |
|
selected_groups = [g for g in selected_groups if g in present_all] |
|
if len(selected_groups) < 2: |
|
print(f"[Filter] Requested groups not enough in dataset (have {selected_groups}); fallback to ALL categories") |
|
selected_groups = [] |
|
else: |
|
print("[Filter] No groups requested from frontend; will use categories present after generation.") |
|
|
|
df_pool = df_full[df_full["category"].isin(selected_groups)].copy() if selected_groups else df_full.copy() |
|
|
|
df = stratified_sample_by_category( |
|
df=df_pool, |
|
category_col="category", |
|
groups=selected_groups if selected_groups else sorted(df_pool["category"].unique().tolist()), |
|
total_n=dataset_limit |
|
) |
|
|
|
print(f"[Pool] pool_size={len(df_pool)}, sampled={len(df)}") |
|
print(f"[Pool] categories in pool: {sorted(df_pool['category'].unique().tolist())}") |
|
print(f"[Pool] categories in sample: {sorted(df['category'].unique().tolist())}") |
|
|
|
_CURRENT_DATASET = df |
|
results['data_loaded'] = len(df) |
|
print(f"Dataset loaded: {len(df)} rows") |
|
|
|
print("Pipeline Step 2: Loading model...") |
|
tokenizer, model, device = get_model(model_name) |
|
results['model_loaded'] = model_name |
|
|
|
print(f"Pipeline Step 3: Generating samples for {len(df)} entries...") |
|
generation_results = generate_topk_samples(model, _CURRENT_DATASET, tokenizer, device, top_k=top_k) |
|
task = config.get('classificationTask', 'sentiment') |
|
tox_choice = config.get('toxicityModelChoice', 'detoxify') |
|
|
|
evaluated_results = evaluate_generated_outputs( |
|
generation_results, device, |
|
task=task, |
|
toxicity_model_choice=tox_choice |
|
) |
|
_GENERATION_RESULTS = evaluated_results |
|
|
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") |
|
os.makedirs("/tmp", exist_ok=True) |
|
output_file = f"/tmp/pipeline_generation_{timestamp}.csv" |
|
evaluated_results.to_csv(output_file, index=False) |
|
results['generation_file'] = output_file |
|
results['generation_samples'] = len(evaluated_results) |
|
print("Pipeline Step 3.5: Counterfactual augmentation...") |
|
augmented_results = generate_counterfactual_augmentations( |
|
evaluated_results, |
|
text_col="generated", |
|
name_col="name", |
|
category_col="category", |
|
num_cf_per_row=num_cf_per_row |
|
) |
|
augmented_file = f"/tmp/pipeline_generation_cf_augmented_{timestamp}.csv" |
|
augmented_results.to_csv(augmented_file, index=False) |
|
results['counterfactual_file'] = augmented_file |
|
results['counterfactual_added'] = len(augmented_results) - len(evaluated_results) |
|
results['counterfactual_total'] = len(augmented_results) |
|
|
|
present_after_gen = sorted(evaluated_results["category"].dropna().unique().tolist()) |
|
if not selected_groups: |
|
selected_groups_used = present_after_gen |
|
else: |
|
selected_groups_used = [g for g in selected_groups if g in present_after_gen] |
|
if len(selected_groups_used) < 2: |
|
print(f"[Sampling] After generation only {selected_groups_used} present; expanding to all present categories") |
|
selected_groups_used = present_after_gen |
|
|
|
print(f"[Sampling] Using groups: {selected_groups_used}") |
|
|
|
print("Debug: Checking data before sampling...") |
|
print(f"Total evaluated results: {len(evaluated_results)}") |
|
print(f"Categories in data: {present_after_gen}") |
|
print(f"Names in data: {evaluated_results['name'].unique()}") |
|
|
|
for cat in selected_groups_used: |
|
cat_count = int((evaluated_results["category"] == cat).sum()) |
|
print(f"Category '{cat}': {cat_count} samples") |
|
|
|
print(f"Pipeline Step 4: Rank sampling on original evaluated results...(iterations={iterations}, temp={tau})") |
|
try: |
|
best_sent_subset = rank_sample(evaluated_results, num_samples=iterations, temp=tau, target_value=metric_target) |
|
except (ValueError, IndexError) as e: |
|
print(f"Sampling failed: {e}") |
|
mid_point = len(evaluated_results) // 2 |
|
best_sent_subset = evaluated_results.iloc[:mid_point].copy() |
|
|
|
sent_file = f"/tmp/pipeline_sent_subset_{timestamp}.csv" |
|
best_sent_subset.to_csv(sent_file, index=False) |
|
|
|
print(f"Pipeline Step 5: Rank sampling on CF-augmented results...(iterations={iterations}, temp={tau})") |
|
try: |
|
cf_best_sent_subset = rank_sample(augmented_results, num_samples=iterations, temp=tau, target_value=metric_target) |
|
except (ValueError, IndexError) as e: |
|
print(f"CF Sampling failed: {e}") |
|
mid_point = len(augmented_results) // 2 |
|
cf_best_sent_subset = augmented_results.iloc[:mid_point].copy() |
|
|
|
cf_sent_file = f"/tmp/pipeline_cf_sent_subset_{timestamp}.csv" |
|
cf_best_sent_subset.to_csv(cf_sent_file, index=False) |
|
|
|
        # Pick the score column produced by the evaluation step; it depends on the selected task,
        # so resolve it before computing per-group means.
        score_col = None
        for c in [
            "sentiment_score", "regard_score", "toxicity_score",
            "stereotype_gender_score", "stereotype_religion_score",
            "stereotype_profession_score", "stereotype_race_score",
            "personality_score",
        ]:
            if c in best_sent_subset.columns:
                score_col = c
                break
        if score_col is None:
            raise KeyError(f"No score column found. Available: {list(best_sent_subset.columns)}")

        orig_means = _mean_by_cat(best_sent_subset, selected_groups_used, score_col=score_col)
        final_mean_diff = _pairwise_max_abs_diff(orig_means)

        cf_means = _mean_by_cat(cf_best_sent_subset, selected_groups_used, score_col=score_col)
        cf_final_mean_diff = _pairwise_max_abs_diff(cf_means)

        print("Pipeline Step 6: Plotting distributions...")

        def _safe(s: str) -> str:
            import re
            return re.sub(r"[^A-Za-z0-9_.-]+", "_", s)

        orig_sent_title = _safe(f"{timestamp}_original_distribution")
        cf_sent_title = _safe(f"{timestamp}_cf_distribution")
|
|
|
orig_path = _ensure_plot_saved( |
|
best_sent_subset, score_col, orig_sent_title, |
|
group_col="category", target=metric_target |
|
) |
|
cf_path = _ensure_plot_saved( |
|
cf_best_sent_subset, score_col, cf_sent_title, |
|
group_col="category", target=metric_target |
|
) |
|
print("[Plot check exists]", orig_path, os.path.exists(orig_path)) |
|
print("[Plot check exists]", cf_path, os.path.exists(cf_path)) |
|
|
|
results['plots'] = { |
|
'original_sentiment': f"/tmp/{orig_sent_title}.png", |
|
'counterfactual_sentiment': f"/tmp/{cf_sent_title}.png", |
|
} |
|
|
|
print("[Plot urls]", results['plots']) |
|
|
|
if config.get("enableFineTuning"): |
|
print("Pipeline Step 7: Fine-tuning enabled, starting training...") |
|
|
|
ft_cfg = config.get("finetuneParams", {}) or {} |
|
epochs = int(ft_cfg.get("epochs", 3)) |
|
batch_size = int(ft_cfg.get("batchSize", 8)) |
|
lr = float(ft_cfg.get("learningRate", 5e-5)) |
|
|
|
input_csv = augmented_file |
|
ft_output_dir = f"/tmp/ft_{timestamp}" |
|
os.makedirs(ft_output_dir, exist_ok=True) |
|
|
|
try: |
|
from utils.finetune import finetune_gpt2_from_csv |
|
finetune_gpt2_from_csv( |
|
csv_path=input_csv, |
|
output_dir=ft_output_dir, |
|
epochs=epochs, |
|
batch_size=batch_size, |
|
lr=lr |
|
) |
|
print(f"[Fine-tune] Saved fine-tuned model to {ft_output_dir}") |
|
results["finetuned_model_dir"] = ft_output_dir |
|
zip_base = f"/tmp/ft_{timestamp}" |
|
import shutil |
|
zip_path = shutil.make_archive(zip_base, 'zip', ft_output_dir) |
|
results["finetuned_model_zip"] = f"/tmp/{os.path.basename(zip_path)}" |
|
except Exception as fe: |
|
print(f"[Fine-tune] Failed: {fe}") |
|
results["finetuned_model_error"] = str(fe) |
|
|
|
|
|
results.update({ |
|
'sampling_method': 'rank_sentiment_only', |
|
'used_groups': selected_groups_used, |
|
'sentiment_subset_file': sent_file, |
|
'cf_sentiment_subset_file': cf_sent_file, |
|
'sentiment_subset_size': len(best_sent_subset), |
|
'cf_sentiment_subset_size': len(cf_best_sent_subset), |
|
'config_used': config, |
|
'metrics': { |
|
'finalMeanDiff': final_mean_diff, |
|
'cfFinalMeanDiff': cf_final_mean_diff, |
|
'reductionPct': (0.0 if final_mean_diff == 0 else max(0.0, (final_mean_diff - cf_final_mean_diff) / abs(final_mean_diff) * 100.0)), |
|
'stableCoverage': 100.0 |
|
} |
|
}) |
|
|
|
return jsonify({ |
|
"status": "success", |
|
"message": "Complete pipeline executed successfully (with counterfactual augmentation)", |
|
"results": results, |
|
"timestamp": timestamp |
|
}) |
|
|
|
except Exception as e: |
|
print(f"Error in pipeline: {str(e)}") |
|
return jsonify({ |
|
"status": "error", |
|
"message": f"Pipeline failed: {str(e)}" |
|
}), 500 |
|
|
|
|
|
if __name__ == '__main__': |
|
os.makedirs("/tmp", exist_ok=True) |
|
print("Starting minimal Flask server...") |
|
print("Available endpoints:") |
|
print(" GET /health - Health check") |
|
print(" GET /dataset/fields?id=<hf_id>[&config=...][&split=...] - List dataset fields") |
|
print(" GET /dataset/field-stats?id=...&field=... - Get value distribution of a field") |
|
print(" GET /dataset/meta?id=<hf_id> - List configs/splits") |
|
print(" POST /pipeline - Run complete pipeline") |
|
app.run(host='0.0.0.0', port=5001, debug=True, threaded=True) |
|
|