File size: 12,522 Bytes
de7b716 00c0d2b 0dc341f 24cbc9a eb2147a 465850b 0dc341f eb2147a 0dc341f 465850b 0dc341f 00c0d2b 465850b eb2147a 465850b 8251e4d 8873a5c 465850b 0dc341f 465850b 0ea49b5 465850b 436586a de7b716 de400cc 465850b 3c6a020 00c0d2b 465850b 00c0d2b 465850b 00c0d2b 436586a de7b716 de400cc de7b716 de400cc eb2147a de400cc eb2147a 465850b de400cc eb2147a de400cc eb2147a 3c6a020 436586a 3c6a020 465850b eb2147a 465850b eb2147a 0dc341f 465850b 0dc341f 465850b eb2147a 465850b 2b0d981 eb2147a 0dc341f 465850b cf6133b 465850b cf6133b 0dc341f 465850b 0dc341f 0fe5446 465850b 0dc341f 465850b da1da84 ea1f759 465850b de7b716 83152ea 86a95d5 eb2147a 465850b de7b716 83152ea 00c0d2b 4600a8b 465850b 0dc341f ea1f759 465850b 86a95d5 de7b716 0dc341f 00c0d2b 83152ea 48b2398 83152ea de7b716 0dc341f 465850b 51ed47e 465850b 0fe5446 465850b 0fe5446 465850b d72e943 4020dbd 83152ea 4020dbd 465850b 0dc341f 612ce70 de7b716 0dc341f 00c0d2b 83152ea eb2147a 83152ea 3c6a020 465850b 0dc341f 465850b eb2147a 465850b 00c0d2b 3c6a020 465850b eb2147a 0dc341f 4020dbd 0dc341f 465850b fafbcc3 38fa440 fafbcc3 506c033 8d4f7ad b314c6f 506c033 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 |
import uuid
from flask import Flask, render_template, request, redirect, url_for, send_from_directory
import json
import random
import os
import string
import logging
from datetime import datetime
from huggingface_hub import login, HfApi, hf_hub_download
from statistics import mean
# Configure logging once at import time: records at INFO level and above
# are written both to the "app.log" file and to a stream handler.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("app.log"),
        logging.StreamHandler(),
    ],
)
# Module-level logger shared by every function in this file.
logger = logging.getLogger(__name__)
# Log in to the Hugging Face Hub with the token taken from the HF_TOKEN
# environment variable; when it is missing only an error is logged.
hf_token = os.getenv("HF_TOKEN")
if not hf_token:
    logger.error("HF_TOKEN not found in environment variables")
else:
    login(token=hf_token)
app = Flask(__name__)
# Prefer a key supplied through the SECRET_KEY environment variable so the
# signing secret is not fixed in the source code.  The historical literal is
# kept only as a fallback for deployments that do not set the variable.
# NOTE(review): the fallback value is public; set SECRET_KEY in production.
app.config['SECRET_KEY'] = os.environ.get('SECRET_KEY') or 'supersecretkey'
# File-based session storage: one JSON file per session id lives here.
SESSION_DIR = '/tmp/sessions'
os.makedirs(SESSION_DIR, exist_ok=True)
# Directory (relative to the working directory) that holds the HTML
# visualizations of each of the four ranked methods.
VISUALIZATION_DIRS = {
    "Text2SQL": "htmls_Text2SQL",
    "Dater": "htmls_DATER_mod2",
    "Chain-of-Table": "htmls_COT_mod",
    "Plan-of-SQLs": "htmls_POS_mod2",
}
# Short tag of each method as it appears in the metadata JSON file names and
# in the metadata keys.  Built once at import time instead of on every call,
# because get_method_dir() is invoked for every sample file.
_METHOD_DIR_BY_NAME = {
    'Text2SQL': 'Text2SQL',
    'Dater': 'DATER',
    'Chain-of-Table': 'COT',
    'Plan-of-SQLs': 'POS',
}


def get_method_dir(method):
    """Return the short tag for *method*.

    Args:
        method: Display name of a method, e.g. ``'Chain-of-Table'``.

    Returns:
        The matching tag (e.g. ``'COT'``), or ``None`` when the method name
        is not known.
    """
    return _METHOD_DIR_BY_NAME.get(method)
# The four methods that participants are asked to rank, in display order.
METHODS = [
    "Text2SQL",
    "Dater",
    "Chain-of-Table",
    "Plan-of-SQLs",
]
def generate_session_id():
    """Return a new random (version 4) UUID in its canonical string form."""
    fresh_uuid = uuid.uuid4()
    return str(fresh_uuid)
def save_session_data(session_id, data):
    """Write *data* as JSON to the session file belonging to *session_id*.

    The file is ``<SESSION_DIR>/<session_id>.json``; an existing file is
    overwritten.  A log line is emitted after the write succeeds.
    """
    session_file = os.path.join(SESSION_DIR, f'{session_id}.json')
    with open(session_file, 'w') as handle:
        json.dump(data, handle)
    logger.info(f"Session data saved for session {session_id}")
def load_session_data(session_id):
    """Load the stored session for *session_id*.

    Args:
        session_id: Identifier whose file ``<SESSION_DIR>/<session_id>.json``
            should be read.

    Returns:
        The decoded JSON content, or ``None`` when no session file exists.

    Raises:
        json.JSONDecodeError: If the file exists but is not valid JSON.
    """
    file_path = os.path.join(SESSION_DIR, f'{session_id}.json')
    # Open directly instead of checking os.path.exists() first: the file may
    # be deleted by another request between the check and the open.
    try:
        with open(file_path, 'r') as f:
            return json.load(f)
    except FileNotFoundError:
        return None
def save_session_data_to_hf(session_id, data):
    """Upload the session *data* as a JSON file to the Hugging Face repo.

    The file name is built from the username, seed, start time and session
    id found in *data* and reduced to alphanumerics, ``_``, ``-`` and ``.``.
    The JSON is staged in a temporary file under ``/tmp`` which is removed
    afterwards, whether or not the upload succeeded.

    Args:
        session_id: Identifier of the session being stored.
        data: JSON-serializable session dictionary.

    Returns:
        None.  Failures are logged and never propagated to the caller.
    """
    temp_file_path = None
    try:
        username = data.get('username', 'unknown')
        seed = data.get('seed', 'unknown')
        start_time = data.get('start_time', datetime.now().isoformat())
        file_name = f'{username}_seed{seed}_{start_time}_{session_id}_session.json'
        # Drop every character that is not safe in a file name (this also
        # removes path separators coming from the user-supplied username).
        file_name = "".join(c for c in file_name if c.isalnum() or c in ['_', '-', '.'])
        json_data = json.dumps(data, indent=4)
        temp_file_path = f"/tmp/{file_name}"
        with open(temp_file_path, 'w') as f:
            f.write(json_data)
        api = HfApi()
        repo_path = "session_data_preference_ranking"
        api.upload_file(
            path_or_fileobj=temp_file_path,
            path_in_repo=f"{repo_path}/{file_name}",
            repo_id="luulinh90s/Tabular-LLM-Study-Data",
            repo_type="space",
        )
        logger.info(f"Session data saved for session {session_id} in Hugging Face Data Space")
    except Exception as e:
        logger.exception(f"Error saving session data for session {session_id}: {e}")
    finally:
        # Always clean up the staging file so failed uploads do not leave
        # session dumps behind in /tmp.
        if temp_file_path is not None:
            try:
                os.remove(temp_file_path)
            except FileNotFoundError:
                # The file was never created (the failure happened earlier).
                pass
            except OSError as e:
                logger.warning(f"Error removing temporary file {temp_file_path}: {e}")
def load_samples_for_all_methods(metadata_files):
    """Collect the visualization samples that exist for every method.

    For each method in ``METHODS`` the category sub-directories ``TP``,
    ``TN``, ``FP`` and ``FN`` of its visualization directory are listed and
    every file is paired with its metadata entry.  Only files whose name is
    present for all methods are returned.

    Args:
        metadata_files: Mapping of method name to that method's metadata
            dict, keyed by ``"<method tag>_test-<index>.html"``.

    Returns:
        A list of sample groups ordered by file name.  Each group maps every
        method name to a dict with the keys ``'category'``, ``'file'`` and
        ``'metadata'``.  The list is empty when nothing is shared.
    """
    categories = ("TP", "TN", "FP", "FN")
    # method name -> {file name -> sample dict}
    samples_by_method = {}

    for method in METHODS:
        samples_by_file = {}
        for category in categories:
            try:
                method_metadata = metadata_files[method]
                file_names = sorted(os.listdir(f'{VISUALIZATION_DIRS[method]}/{category}'))
            except (OSError, KeyError) as e:
                logger.error(f"Error loading samples for method {method}, category {category}: {e}")
                continue

            for file in file_names:
                try:
                    index = file.split('-')[1].split('.')[0]
                except IndexError:
                    # A stray file without the "<prefix>-<index>.<ext>" shape
                    # must not discard the remaining files of the category.
                    logger.warning(
                        f"Skipping unexpected file {file} for method {method}, category {category}")
                    continue
                if file in samples_by_file:
                    # Keep the first occurrence when a name shows up in
                    # more than one category.
                    continue
                metadata_key = f"{get_method_dir(method)}_test-{index}.html"
                samples_by_file[file] = {
                    'category': category,
                    'file': file,
                    'metadata': method_metadata.get(metadata_key, {}),
                }
        samples_by_method[method] = samples_by_file

    if not samples_by_method:
        return []

    common_files = set.intersection(*(set(files) for files in samples_by_method.values()))

    # Iterate in sorted order: set order varies between interpreter runs
    # (string hash randomization), which would make the later seeded random
    # selection pick different samples for the same seed.
    return [
        {method: samples_by_method[method][file_name] for method in METHODS}
        for file_name in sorted(common_files)
    ]
def select_balanced_samples(samples, per_group=5):
    """Pick a random, category-balanced subset of sample groups.

    Samples are split into a TP/FP group and a TN/FN group.  When both groups
    hold at least *per_group* entries, *per_group* samples are drawn from
    each and the result is shuffled.  Otherwise up to ``2 * per_group``
    samples are drawn from the whole list.

    Args:
        samples: List of sample groups; each group maps a method name to a
            dict containing a ``'category'`` key.
        per_group: Number of samples to draw from each of the two groups.

    Returns:
        The list of selected sample groups, or ``[]`` if selection failed.
    """
    try:
        tp_fp_samples = []
        tn_fn_samples = []
        for sample in samples:
            # The category is read from the first method of the group only.
            # NOTE(review): this assumes all methods share one category for
            # a given sample; confirm against the data layout.
            category = next(iter(sample.values()))['category']
            if category in ('TP', 'FP'):
                tp_fp_samples.append(sample)
            elif category in ('TN', 'FN'):
                tn_fn_samples.append(sample)

        if len(tp_fp_samples) >= per_group and len(tn_fn_samples) >= per_group:
            # Draw order (TP/FP, then TN/FN, then shuffle) is kept fixed so a
            # seeded global RNG yields the same selection every time.
            selected_samples = random.sample(tp_fp_samples, per_group)
            selected_samples += random.sample(tn_fn_samples, per_group)
            random.shuffle(selected_samples)
        else:
            logger.warning(
                f"Not enough samples for balanced selection. TP+FP: {len(tp_fp_samples)}, TN+FN: {len(tn_fn_samples)}")
            selected_samples = random.sample(samples, min(2 * per_group, len(samples)))
        return selected_samples
    except Exception:
        # Best effort: the caller treats an empty list as "nothing found".
        logger.exception("Error selecting balanced samples")
        return []
@app.route('/')
def root():
    """Send visitors of the site root on to the consent page."""
    consent_url = url_for('consent')
    return redirect(consent_url)
@app.route('/consent', methods=['GET', 'POST'])
def consent():
    """Render the consent page; a POST continues to the introduction."""
    if request.method != 'POST':
        return render_template('consent.html')
    return redirect(url_for('introduction'))
@app.route('/introduction')
def introduction():
    """Render the introduction page."""
    page = render_template('introduction.html')
    return page
@app.route('/attribution')
def attribution():
    """Render the attribution page."""
    page = render_template('attribution.html')
    return page
@app.route('/index', methods=['GET', 'POST'])
def index():
    """Show the start form and, on POST, create a new ranking session.

    A valid POST seeds the global random generator with the submitted seed,
    loads the metadata of every method, selects the samples to rank, stores
    a new session file and redirects to the experiment page.  Any problem
    re-renders the form with an error message.
    """
    if request.method != 'POST':
        return render_template('index.html')

    # Treat a whitespace-only username or seed as missing.
    username = (request.form.get('username') or '').strip()
    seed = (request.form.get('seed') or '').strip()
    if not username or not seed:
        return render_template('index.html', error="Please fill in all fields.")

    # Report a malformed seed explicitly instead of through the generic
    # error handler below.
    try:
        seed = int(seed)
    except ValueError:
        return render_template('index.html', error="Seed must be an integer.")

    try:
        # NOTE(review): this seeds the process-wide RNG, so concurrent
        # requests can influence each other's sample selection.
        random.seed(seed)

        # Load metadata for all methods
        metadata_files = {}
        for method in METHODS:
            json_file = f'Tabular_LLMs_human_study_vis_6_{get_method_dir(method)}.json'
            with open(json_file, 'r') as f:
                metadata_files[method] = json.load(f)

        # Load and select samples
        all_samples = load_samples_for_all_methods(metadata_files)
        selected_samples = select_balanced_samples(all_samples)
        if not selected_samples:
            return render_template('index.html', error="No common samples were found")

        # Create session
        session_id = generate_session_id()
        session_data = {
            'username': username,
            'seed': str(seed),
            'selected_samples': selected_samples,
            'current_index': 0,
            'responses': [],
            'start_time': datetime.now().isoformat(),
            'session_id': session_id,
        }
        save_session_data(session_id, session_data)
        return redirect(url_for('experiment', session_id=session_id))
    except Exception as e:
        # Top-level boundary of the request: log and show a generic message.
        logger.exception(f"Error in index route: {e}")
        return render_template('index.html', error="An error occurred. Please try again.")
@app.route('/experiment/<session_id>', methods=['GET', 'POST'])
def experiment(session_id):
    """Show the current sample of a session and record its rankings.

    GET renders the visualizations of every method for the current sample.
    POST validates the submitted ranks (one distinct rank from 1 to the
    number of methods per method), appends them to the session and advances
    to the next sample.  Unknown sessions are redirected to the start page,
    finished sessions to the completion page.
    """
    try:
        session_data = load_session_data(session_id)
        if not session_data:
            return redirect(url_for('index'))

        selected_samples = session_data['selected_samples']
        current_index = session_data['current_index']
        if current_index >= len(selected_samples):
            return redirect(url_for('completed', session_id=session_id))

        if request.method == 'POST':
            method_count = len(METHODS)
            invalid_message = f"Invalid rankings. Please use numbers 1-{method_count}."
            # A missing field yields None (TypeError) and non-numeric text a
            # ValueError; both are client errors, not server errors.
            try:
                rankings = {method: int(request.form.get(method)) for method in METHODS}
            except (TypeError, ValueError):
                return invalid_message, 400
            if not all(1 <= rank <= method_count for rank in rankings.values()):
                return invalid_message, 400
            if len(set(rankings.values())) != method_count:
                return "Each method must have a unique rank.", 400

            session_data['responses'].append({
                'sample_id': current_index,
                'rankings': rankings,
            })
            session_data['current_index'] += 1
            save_session_data(session_id, session_data)
            return redirect(url_for('experiment', session_id=session_id))

        # Get current sample group and prepare visualizations
        sample_group = selected_samples[current_index]
        visualizations = {
            method: url_for(
                'send_visualization',
                filename=f"{VISUALIZATION_DIRS[method]}/{sample['category']}/{sample['file']}")
            for method, sample in sample_group.items()
        }

        # The statement is read from the metadata of the first method only.
        sample_metadata = next(iter(sample_group.values()))['metadata']
        statement = sample_metadata.get('statement', '')

        return render_template('experiment.html',
                               sample_id=current_index,
                               statement=statement,
                               visualizations=visualizations,
                               methods=METHODS,
                               session_id=session_id)
    except Exception as e:
        # Top-level boundary of the request: log and answer with a 500.
        logger.exception(f"An error occurred in the experiment route: {e}")
        return "An error occurred", 500
@app.route('/completed/<session_id>')
def completed(session_id):
    """Finish a session: compute average ranks, upload and show the result.

    The average rank of every method over all recorded responses is stored
    in the session, the session is uploaded through
    ``save_session_data_to_hf`` and the local session file is removed.
    Unknown sessions are redirected to the start page; sessions without any
    response are sent back to the experiment (or the start page when no
    sample is pending).
    """
    try:
        session_data = load_session_data(session_id)
        if not session_data:
            return redirect(url_for('index'))

        responses = session_data['responses']
        if not responses:
            # mean() raises StatisticsError on empty data, so a session that
            # has not recorded any ranking cannot be summarized yet.
            has_pending = session_data['current_index'] < len(session_data['selected_samples'])
            if has_pending:
                return redirect(url_for('experiment', session_id=session_id))
            return redirect(url_for('index'))

        session_data['end_time'] = datetime.now().isoformat()

        # Calculate average ranking for each method
        average_rankings = {
            method: mean(r['rankings'][method] for r in responses)
            for method in METHODS
        }

        # Sort methods by average ranking (ascending)
        sorted_methods = sorted(average_rankings.items(), key=lambda item: item[1])

        session_data['average_rankings'] = average_rankings
        save_session_data_to_hf(session_id, session_data)

        # Clean up local session file
        # NOTE(review): the upload helper logs and swallows its own errors,
        # so the local copy is deleted even when the upload did not succeed.
        try:
            os.remove(os.path.join(SESSION_DIR, f'{session_id}.json'))
        except Exception as e:
            logger.warning(f"Error removing session file: {e}")

        return render_template(
            'completed.html',
            average_rankings=average_rankings,
            sorted_methods=sorted_methods,
        )
    except Exception as e:
        # Top-level boundary of the request: log and answer with a 500.
        logger.exception(f"An error occurred in the completed route: {e}")
        return "An error occurred", 500
@app.route('/visualizations/<path:filename>')
def send_visualization(filename):
    """Serve a file located below the current working directory.

    Args:
        filename: Path of the requested file relative to the working
            directory (taken from the URL, therefore untrusted).

    Returns:
        The file response, ``403`` when the path escapes the working
        directory, or ``404`` when it does not name an existing file.
    """
    base_dir = os.path.abspath(os.getcwd())
    file_path = os.path.normpath(os.path.join(base_dir, filename))

    # Compare whole path components: a plain startswith() check would also
    # accept sibling directories such as "/app2" for the base "/app".
    try:
        inside_base = os.path.commonpath([base_dir, file_path]) == base_dir
    except ValueError:
        # Raised e.g. for paths on different drives; never inside the base.
        inside_base = False
    if not inside_base:
        return "Access denied", 403

    # Directories exist as well but cannot be sent as a file.
    if not os.path.isfile(file_path):
        return "File not found", 404

    # NOTE(review): every file under the working directory is reachable
    # through this route, not only the visualization directories.
    directory = os.path.dirname(file_path)
    file_name = os.path.basename(file_path)
    return send_from_directory(directory, file_name)
if __name__ == "__main__":
    # The server listens on all interfaces, so the interactive debugger must
    # not be on by default: it allows arbitrary code execution for anyone who
    # can reach the port.  Opt in explicitly with FLASK_DEBUG=1 when needed.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").strip().lower() in ("1", "true", "yes")
    app.run(host="0.0.0.0", port=7860, debug=debug_enabled)