# luulinh90s's picture
# update
# 465850b
import uuid
from flask import Flask, render_template, request, redirect, url_for, send_from_directory
import json
import random
import os
import string
import logging
from datetime import datetime
from huggingface_hub import login, HfApi, hf_hub_download
from statistics import mean
# Configure logging once at import time: every record is written both to
# app.log (in the current working directory) and to the console stream.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(
    level=logging.INFO,
    format=_LOG_FORMAT,
    handlers=[logging.FileHandler("app.log"), logging.StreamHandler()],
)
logger = logging.getLogger(__name__)
# Log in to the Hugging Face Hub when HF_TOKEN is set. HfApi() in
# save_session_data_to_hf is created without an explicit token, so uploads
# presumably depend on this login having succeeded.
hf_token = os.environ.get("HF_TOKEN")
if not hf_token:
    logger.error("HF_TOKEN not found in environment variables")
else:
    login(token=hf_token)
app = Flask(__name__)
# Flask uses SECRET_KEY to sign session cookies, so it should not live in
# source control. Read it from the environment; the former hard-coded value
# is kept only as a fallback so existing deployments still start.
_secret_key = os.environ.get('SECRET_KEY')
if not _secret_key:
    logger.warning("SECRET_KEY not set in environment; using insecure default")
    _secret_key = 'supersecretkey'
app.config['SECRET_KEY'] = _secret_key
# Each participant session is persisted as <session_id>.json in this folder;
# create it up front so later writes never hit a missing directory.
SESSION_DIR = '/tmp/sessions'
os.makedirs(SESSION_DIR, mode=0o777, exist_ok=True)
# Folder (relative to the working directory) holding each ranked method's
# pre-rendered HTML visualizations. load_samples_for_all_methods reads the
# TP/TN/FP/FN subfolders of each one.
VISUALIZATION_DIRS = {
    "Text2SQL": "htmls_Text2SQL",
    "Dater": "htmls_DATER_mod2",
    "Chain-of-Table": "htmls_COT_mod",
    "Plan-of-SQLs": "htmls_POS_mod2",
}
# Short code per method. It is used to build metadata file names
# (Tabular_LLMs_human_study_vis_6_<code>.json) and metadata keys
# (<code>_test-<index>.html). Built once at import instead of on every call.
_METHOD_DIR_CODES = {
    'Text2SQL': 'Text2SQL',
    'Dater': 'DATER',
    'Chain-of-Table': 'COT',
    'Plan-of-SQLs': 'POS',
}


def get_method_dir(method):
    """Return the short file/directory code for *method*, or None if it is unknown."""
    return _METHOD_DIR_CODES.get(method)
# The methods participants rank. This order is used when iterating methods
# and is passed to the experiment template.
METHODS = [
    "Text2SQL",
    "Dater",
    "Chain-of-Table",
    "Plan-of-SQLs",
]
def generate_session_id():
    """Return a new random (version 4) UUID string identifying a participant session."""
    new_id = uuid.uuid4()
    return str(new_id)
def save_session_data(session_id, data):
    """Persist *data* as JSON in ``SESSION_DIR/<session_id>.json``.

    The JSON is written to a uniquely named temporary file in the same
    directory and then moved over the target with ``os.replace``. A reader
    (``load_session_data``) therefore never sees a half-written file. A crash
    or serialization error mid-write leaves the previous snapshot intact
    instead of a truncated one.

    Raises:
        Whatever ``json.dump`` or the file operations raise; the temporary
        file is cleaned up before re-raising.
    """
    file_path = os.path.join(SESSION_DIR, f'{session_id}.json')
    # Unique per call so two concurrent writers for one session don't share a temp file.
    tmp_path = f'{file_path}.{uuid.uuid4().hex}.tmp'
    try:
        with open(tmp_path, 'w') as f:
            json.dump(data, f)
        os.replace(tmp_path, file_path)
    except Exception:
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
    logger.info(f"Session data saved for session {session_id}")
def load_session_data(session_id):
    """Return the session dict stored in ``SESSION_DIR/<session_id>.json``.

    Returns None when no file exists for *session_id*. Decoding errors from
    ``json.load`` propagate to the caller.
    """
    session_file = os.path.join(SESSION_DIR, f'{session_id}.json')
    if not os.path.exists(session_file):
        return None
    with open(session_file, 'r') as handle:
        return json.load(handle)
def save_session_data_to_hf(session_id, data):
    """Upload one session's data as JSON to the study's Hugging Face Space repo.

    The file goes to ``session_data_preference_ranking/<file_name>`` in the
    ``luulinh90s/Tabular-LLM-Study-Data`` Space. ``file_name`` is
    ``<username>_seed<seed>_<start_time>_<session_id>_session.json`` with
    every character other than letters, digits, ``_``, ``-`` and ``.``
    removed, so user-supplied values cannot inject path separators.

    Errors are logged, not raised. The JSON is first written to
    ``/tmp/<file_name>``. That copy is deleted only after a successful
    upload, so on failure it stays on disk as a recovery copy and its path is
    logged.

    Returns:
        bool: True if the upload succeeded, False otherwise.
    """
    temp_file_path = None
    try:
        username = data.get('username', 'unknown')
        seed = data.get('seed', 'unknown')
        start_time = data.get('start_time', datetime.now().isoformat())
        raw_name = f'{username}_seed{seed}_{start_time}_{session_id}_session.json'
        file_name = "".join(c for c in raw_name if c.isalnum() or c in ['_', '-', '.'])
        # Serialize before touching the filesystem so a bad payload writes nothing.
        json_data = json.dumps(data, indent=4)

        temp_file_path = f"/tmp/{file_name}"
        with open(temp_file_path, 'w') as f:
            f.write(json_data)

        api = HfApi()
        repo_path = "session_data_preference_ranking"
        api.upload_file(
            path_or_fileobj=temp_file_path,
            path_in_repo=f"{repo_path}/{file_name}",
            repo_id="luulinh90s/Tabular-LLM-Study-Data",
            repo_type="space",
        )
    except Exception as e:
        logger.exception(f"Error saving session data for session {session_id}: {e}")
        if temp_file_path and os.path.exists(temp_file_path):
            logger.error(f"Local copy of session {session_id} kept at {temp_file_path}")
        return False

    # The upload already succeeded; failing to clean up the temp file is not a data error.
    try:
        os.remove(temp_file_path)
    except OSError as e:
        logger.warning(f"Could not remove temporary file {temp_file_path}: {e}")
    logger.info(f"Session data saved for session {session_id} in Hugging Face Data Space")
    return True
def load_samples_for_all_methods(metadata_files):
    """Group the visualization samples that exist for every method in METHODS.

    For each method, every file in ``VISUALIZATION_DIRS[method]/<category>``
    (category one of TP, TN, FP, FN) becomes a sample dict with keys
    ``'category'``, ``'file'`` and ``'metadata'``. Metadata comes from
    ``metadata_files[method]`` under the key ``<code>_test-<index>.html``,
    where ``<code>`` is ``get_method_dir(method)`` and ``<index>`` is the part
    of the file name between the first ``-`` and the following ``.``. Missing
    metadata yields ``{}``. If a file name occurs in several categories of one
    method, the first category (in TP, TN, FP, FN order) wins.

    Args:
        metadata_files: mapping of method name -> metadata dict keyed as above.

    Returns:
        A list of ``{method: sample}`` dicts, one per file name present for
        every method, ordered by file name. The order is deterministic, so a
        later seeded ``random`` selection over it is reproducible across
        processes; plain set iteration order is not, because str hashing is
        randomized per process.
    """
    categories = ["TP", "TN", "FP", "FN"]
    samples_by_method = {}

    for method in METHODS:
        method_dir = VISUALIZATION_DIRS[method]
        method_samples = {}  # file name -> sample dict
        for category in categories:
            try:
                # os.listdir returns entries in arbitrary order; sort for stability.
                files = sorted(os.listdir(f'{method_dir}/{category}'))
                method_metadata = metadata_files[method]
            except Exception as e:
                logger.error(f"Error loading samples for method {method}, category {category}: {e}")
                continue

            for file in files:
                try:
                    index = file.split('-')[1].split('.')[0]
                except IndexError:
                    # Skip just this file; one odd name must not drop the whole category.
                    logger.warning(f"Skipping file with unexpected name {file!r} in {method_dir}/{category}")
                    continue
                metadata_key = f"{get_method_dir(method)}_test-{index}.html"
                method_samples.setdefault(file, {
                    'category': category,
                    'file': file,
                    'metadata': method_metadata.get(metadata_key, {}),
                })
        samples_by_method[method] = method_samples

    if not samples_by_method:
        return []

    # Only file names that every method has a visualization for.
    common_files = set.intersection(*(set(samples) for samples in samples_by_method.values()))
    return [
        {method: samples_by_method[method][file_name] for method in METHODS}
        for file_name in sorted(common_files)
    ]
def select_balanced_samples(samples):
    """Choose up to ten sample groups, balanced between TP/FP and TN/FN.

    A group's category is read from its first method's sample. With at least
    five groups in each of the TP/FP and TN/FN buckets, five are drawn from
    each and the ten are shuffled together. Otherwise a warning is logged and
    up to ten groups are drawn uniformly from *samples*. All draws use the
    global ``random`` module, so the result follows its current seed. Any
    error is logged and an empty list is returned.
    """
    try:
        positives = []
        negatives = []
        for group in samples:
            label = next(iter(group.values()))['category']
            if label in ('TP', 'FP'):
                positives.append(group)
            elif label in ('TN', 'FN'):
                negatives.append(group)

        if len(positives) < 5 or len(negatives) < 5:
            logger.warning(
                f"Not enough samples for balanced selection. TP+FP: {len(positives)}, TN+FN: {len(negatives)}")
            return random.sample(samples, min(10, len(samples)))

        chosen = random.sample(positives, 5) + random.sample(negatives, 5)
        random.shuffle(chosen)
        return chosen
    except Exception:
        logger.exception("Error selecting balanced samples")
        return []
@app.route('/')
def root():
    """Send visitors arriving at the site root to the consent page."""
    consent_url = url_for('consent')
    return redirect(consent_url)
@app.route('/consent', methods=['GET', 'POST'])
def consent():
    """Render the consent page; once the form is POSTed, continue to the introduction."""
    if request.method != 'POST':
        return render_template('consent.html')
    return redirect(url_for('introduction'))
@app.route('/introduction')
def introduction():
    """Render the study introduction page."""
    template_name = 'introduction.html'
    return render_template(template_name)
@app.route('/attribution')
def attribution():
    """Render the attribution page."""
    template_name = 'attribution.html'
    return render_template(template_name)
@app.route('/index', methods=['GET', 'POST'])
def index():
    """Render the sign-up form (GET) or start a new ranking session (POST).

    A POST needs a non-empty ``username`` and an integer ``seed``. The seed
    is applied to the global ``random`` module before the per-method metadata
    files ``Tabular_LLMs_human_study_vis_6_<code>.json`` are loaded and the
    sample groups are selected; this is intended to make the same seed pick
    the same samples. The new session is stored with save_session_data and
    the participant is redirected to the first experiment page. Input
    problems and unexpected errors re-render the form with an error message.
    """
    if request.method != 'POST':
        return render_template('index.html')

    username = request.form.get('username')
    seed = request.form.get('seed')
    if not username or not seed:
        return render_template('index.html', error="Please fill in all fields.")

    # A non-numeric seed is a user input mistake; say so explicitly instead of
    # letting it reach the generic "An error occurred" handler below.
    try:
        seed = int(seed)
    except ValueError:
        return render_template('index.html', error="Seed must be a whole number.")

    try:
        # NOTE(review): random.seed() reseeds the process-wide generator, so
        # concurrently handled requests can interleave draws and break
        # reproducibility. A per-request random.Random(seed) would require
        # select_balanced_samples to accept a generator.
        random.seed(seed)

        metadata_files = {}
        for method in METHODS:
            json_file = f'Tabular_LLMs_human_study_vis_6_{get_method_dir(method)}.json'
            with open(json_file, 'r') as f:
                metadata_files[method] = json.load(f)

        all_samples = load_samples_for_all_methods(metadata_files)
        selected_samples = select_balanced_samples(all_samples)
        if not selected_samples:
            return render_template('index.html', error="No common samples were found")

        session_id = generate_session_id()
        session_data = {
            'username': username,
            'seed': str(seed),
            'selected_samples': selected_samples,
            'current_index': 0,
            'responses': [],
            'start_time': datetime.now().isoformat(),
            'session_id': session_id,
        }
        save_session_data(session_id, session_data)
        return redirect(url_for('experiment', session_id=session_id))
    except Exception as e:
        logger.exception(f"Error in index route: {e}")
        return render_template('index.html', error="An error occurred. Please try again.")
@app.route('/experiment/<session_id>', methods=['GET', 'POST'])
def experiment(session_id):
    """Show, or record the ranking for, the session's current sample group.

    GET renders ``experiment.html`` with the sample's ``statement`` (from its
    metadata) and one visualization URL per method for
    ``selected_samples[current_index]``.

    POST expects one form field per method in METHODS, each holding a rank
    in ``1..len(METHODS)``, with every method given a distinct rank. Valid
    rankings are appended to ``responses``, ``current_index`` advances, and
    the browser is redirected back here.

    Unknown sessions are sent to the index page and finished sessions to the
    completion page. Responds 400 for missing, non-numeric, out-of-range or
    duplicate ranks, and 500 for unexpected errors.
    """
    try:
        session_data = load_session_data(session_id)
        if not session_data:
            return redirect(url_for('index'))

        selected_samples = session_data['selected_samples']
        current_index = session_data['current_index']
        if current_index >= len(selected_samples):
            return redirect(url_for('completed', session_id=session_id))

        if request.method == 'POST':
            num_methods = len(METHODS)
            invalid_msg = f"Invalid rankings. Please use numbers 1-{num_methods}."
            rankings = {}
            for method in METHODS:
                try:
                    rankings[method] = int(request.form.get(method))
                except (TypeError, ValueError):
                    # int(None) (field missing) raises TypeError and int('x') raises
                    # ValueError. Both are client errors, so answer 400, not 500.
                    return invalid_msg, 400
            if not all(1 <= rank <= num_methods for rank in rankings.values()):
                return invalid_msg, 400
            if len(set(rankings.values())) != num_methods:
                return "Each method must have a unique rank.", 400

            session_data['responses'].append({
                'sample_id': current_index,
                'rankings': rankings,
            })
            session_data['current_index'] += 1
            save_session_data(session_id, session_data)
            return redirect(url_for('experiment', session_id=session_id))

        sample_group = selected_samples[current_index]
        visualizations = {
            method: url_for(
                'send_visualization',
                filename=f"{VISUALIZATION_DIRS[method]}/{sample['category']}/{sample['file']}",
            )
            for method, sample in sample_group.items()
        }
        # Groups are formed from a shared file name across methods, so the
        # statement is assumed identical and is read from the first method.
        sample_metadata = next(iter(sample_group.values()))['metadata']
        statement = sample_metadata.get('statement', '')

        return render_template('experiment.html',
                               sample_id=current_index,
                               statement=statement,
                               visualizations=visualizations,
                               methods=METHODS,
                               session_id=session_id)
    except Exception as e:
        logger.exception(f"An error occurred in the experiment route: {e}")
        return "An error occurred", 500
@app.route('/completed/<session_id>')
def completed(session_id):
    """Finalize a finished session and show each method's average rank.

    The following happens only once every selected sample has been ranked:
    - ``end_time`` and ``average_rankings`` are stamped on the session.
    - The session is uploaded with save_session_data_to_hf.
    - The local session file is removed.
    - ``completed.html`` is rendered with the averages and the methods sorted
      by ascending average rank.

    Unknown sessions redirect to the index page. Unfinished sessions redirect
    back to the experiment. Unexpected errors return 500.
    """
    try:
        session_data = load_session_data(session_id)
        if not session_data:
            return redirect(url_for('index'))

        # Visiting this URL before the last sample is ranked used to upload
        # partial data and delete the local session file, losing progress.
        if session_data['current_index'] < len(session_data['selected_samples']):
            return redirect(url_for('experiment', session_id=session_id))

        session_data['end_time'] = datetime.now().isoformat()
        responses = session_data['responses']

        # statistics.mean raises StatisticsError on empty input.
        if responses:
            average_rankings = {
                method: mean(r['rankings'][method] for r in responses)
                for method in METHODS
            }
        else:
            average_rankings = {}

        # Lowest average rank first.
        sorted_methods = sorted(average_rankings.items(), key=lambda item: item[1])
        session_data['average_rankings'] = average_rankings

        save_session_data_to_hf(session_id, session_data)

        try:
            os.remove(os.path.join(SESSION_DIR, f'{session_id}.json'))
        except Exception as e:
            logger.warning(f"Error removing session file: {e}")

        return render_template(
            'completed.html',
            average_rankings=average_rankings,
            sorted_methods=sorted_methods,
        )
    except Exception as e:
        logger.exception(f"An error occurred in the completed route: {e}")
        return "An error occurred", 500
@app.route('/visualizations/<path:filename>')
def send_visualization(filename):
    """Serve one pre-rendered visualization file.

    *filename* is resolved against the current working directory and must
    fall inside one of the VISUALIZATION_DIRS folders. Anything else returns
    403; this covers ``..`` escapes, absolute paths, and other files in the
    working directory such as app.log. A path that is not an existing
    regular file returns 404.
    """
    base_dir = os.path.abspath(os.getcwd())
    file_path = os.path.normpath(os.path.join(base_dir, filename))

    # The old startswith(base_dir) test was a string-prefix check: "../<cwd>x/..."
    # normalizes to a sibling such as "/appx/..." which still starts with "/app".
    # Compare whole path components instead. Also only allow the visualization
    # folders, because the working directory holds app.log, which records session ids.
    allowed_roots = [os.path.join(base_dir, d) for d in VISUALIZATION_DIRS.values()]
    if not any(os.path.commonpath([root, file_path]) == root for root in allowed_roots):
        return "Access denied", 403
    if not os.path.isfile(file_path):
        return "File not found", 404

    return send_from_directory(os.path.dirname(file_path), os.path.basename(file_path))
if __name__ == "__main__":
    # Never enable debug mode by default. Combined with host="0.0.0.0" it
    # exposes the Werkzeug interactive debugger, which can execute arbitrary
    # Python, to anyone who can reach the port. Opt in with FLASK_DEBUG=1
    # for local development only.
    debug_mode = os.environ.get("FLASK_DEBUG", "").lower() in ("1", "true", "yes")
    app.run(host="0.0.0.0", port=7860, debug=debug_mode)