import gradio as gr
from datasets import load_dataset, get_dataset_config_names
import random
import re
from typing import List, Tuple
import logging

# Set up logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)

# Popular evaluation datasets with their configurations
EVAL_DATASETS = {
    "openai/gsm8k": {
        "name": "GSM8K - Grade School Math",
        "type": "qa",
        "config": "main",
        "question_field": "question",
        "answer_field": "answer",
        "split": "train",
    },
    "cais/mmlu": {
        "name": "MMLU - Massive Multitask Language Understanding",
        "type": "multiple_choice",
        "config": "all",
        "question_field": "question",
        "choices_field": "choices",
        "answer_field": "answer",
        "split": "test",
    },
    "allenai/ai2_arc": {
        "name": "AI2 ARC - Science Questions",
        "type": "multiple_choice",
        "config": "ARC-Challenge",
        "question_field": "question",
        "choices_field": "choices",
        "answer_field": "answerKey",
        "split": "train",
    },
    "Rowan/hellaswag": {
        "name": "HellaSwag - Commonsense NLI",
        "type": "multiple_choice",
        "question_field": "ctx",
        "choices_field": "endings",
        "answer_field": "label",
        "split": "train",
    },
    "allenai/winogrande": {
        "name": "WinoGrande - Winograd Schema",
        "type": "binary_choice",
        "config": "winogrande_xl",
        "question_field": "sentence",
        "option1_field": "option1",
        "option2_field": "option2",
        "answer_field": "answer",
        "split": "train",
    },
    "google/boolq": {
        "name": "BoolQ - Boolean Questions",
        "type": "true_false",
        "question_field": "question",
        "context_field": "passage",
        "answer_field": "answer",
        "split": "train",
    },
    "rajpurkar/squad": {
        "name": "SQuAD - Reading Comprehension",
        "type": "extractive_qa",
        "question_field": "question",
        "context_field": "context",
        "answer_field": "answers",
        "split": "train",
    },
    "allenai/piqa": {
        "name": "PIQA - Physical Reasoning",
        "type": "binary_choice",
        "question_field": "goal",
        "option1_field": "sol1",
        "option2_field": "sol2",
        "answer_field": "label",
        "split": "train",
    },
}
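# Each entry above maps a Hub dataset ID to the metadata the quiz needs: a display
# "name", the question "type" (qa, multiple_choice, binary_choice, true_false,
# extractive_qa), an optional "config" (subset), the field names used to pull the
# question, choices/options, context, and answer, and the "split" to sample from.
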
class QuizApp:
    def __init__(self):
        self.current_dataset = None
        self.current_dataset_name = None
        self.questions = []
        self.current_question_idx = 0
        self.score = 0
        self.total_questions = 0

    def load_dataset_questions(self, dataset_name: str, num_questions: int = 10):
        """Load random questions from the selected dataset."""
        try:
            config = EVAL_DATASETS[dataset_name]

            # Try to load the dataset with its config if one is specified
            try:
                if "config" in config:
                    dataset = load_dataset(
                        dataset_name, config["config"], split=config["split"]
                    )
                else:
                    dataset = load_dataset(dataset_name, split=config["split"])
            except ValueError as e:
                # If a config name is required, pick one from the available configs
                if "Config name is missing" in str(e):
                    configs = get_dataset_config_names(dataset_name)
                    # Prefer "all" if available, otherwise use the first config
                    if "all" in configs:
                        selected_config = "all"
                    else:
                        selected_config = configs[0]
                    logging.info(
                        f"Auto-selected config '{selected_config}' for {dataset_name}"
                    )
                    dataset = load_dataset(
                        dataset_name, selected_config, split=config["split"]
                    )
                else:
                    raise

            # Sample random questions
            total_examples = len(dataset)
            num_questions = min(num_questions, total_examples)
            indices = random.sample(range(total_examples), num_questions)

            self.questions = [dataset[idx] for idx in indices]
            self.current_dataset = config
            self.current_dataset_name = dataset_name
            self.current_question_idx = 0
            self.score = 0
            self.total_questions = len(self.questions)

            return True, f"Loaded {num_questions} questions from {config['name']}"
        except Exception as e:
            return False, f"Error loading dataset: {str(e)}"

    def get_current_question(self) -> Tuple[str, List[str], str]:
        """Get the current question formatted for display."""
        if not self.questions or self.current_question_idx >= len(self.questions):
            return "", [], ""

        question_data = self.questions[self.current_question_idx]
        config = self.current_dataset

        logging.info(f"\n{'=' * 60}")
        logging.info(f"Dataset: {self.current_dataset_name}")
        logging.info(
            f"Question {self.current_question_idx + 1}/{self.total_questions}"
        )
        logging.info(f"Raw question data: {repr(question_data)}")
        logging.info(f"{'=' * 60}\n")

        # Format question based on dataset type
        question_type = config["type"]

        if question_type == "multiple_choice":
            question = question_data[config["question_field"]]
            choices = question_data[config["choices_field"]]
            # ARC stores choices as a {"text": [...], "label": [...]} dict;
            # MMLU and HellaSwag store a plain list of strings
            if isinstance(choices, dict) and "text" in choices:
                choices = choices["text"]

            # Format choices with letters (A, B, C, ...)
            formatted_choices = [
                f"{chr(65 + i)}. {choice}" for i, choice in enumerate(choices)
            ]
            return question, formatted_choices, question_type

        elif question_type == "true_false":
            question = question_data[config["question_field"]]
            if "context_field" in config:
                context = question_data[config["context_field"]]
                question = f"Context: {context}\n\nQuestion: {question}"
            return question, ["True", "False"], question_type

        elif question_type == "binary_choice":
            question = question_data[config["question_field"]]
            option1 = question_data[config["option1_field"]]
            option2 = question_data[config["option2_field"]]
            return question, [f"A. {option1}", f"B. {option2}"], question_type

        elif question_type in ["qa", "extractive_qa"]:
            question = question_data[config["question_field"]]
            if "context_field" in config and config["context_field"] in question_data:
                context = question_data[config["context_field"]]
                question = f"Context: {context[:500]}...\n\nQuestion: {question}"
            return question, [], question_type

        return "", [], ""

    def format_answer(self, answer: str, dataset_name: str) -> str:
        """Format an answer string for better readability."""

        # Convert <<calculation=result>> annotations (as used in GSM8K solutions)
        # into "result (=calculation)"; anything else inside <<...>> is bracketed.
        def format_equation(match):
            equation = match.group(1)
            if "=" in equation:
                parts = equation.split("=")
                if len(parts) == 2:
                    calculation, result = parts[0], parts[1]
                    return f"{result} (={calculation})"
            return f"[{equation}]"

        answer = re.sub(r"<<([^>]+)>>", format_equation, answer)

        # Dataset-specific formatting
        if dataset_name == "openai/gsm8k":
            # Format the final answer line
            answer = answer.replace("####", "\n\nFinal Answer:")
            # Insert line breaks after sentence-ending periods for readability
            answer = re.sub(r"\. (?=[A-Z])", ".\n", answer)

        return answer

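    # Worked example (GSM8K):
    #   format_answer("She sold <<6*7=42>>42 cookies. #### 42", "openai/gsm8k")
    # returns "She sold 42 (=6*7)42 cookies. \n\nFinal Answer: 42".
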
    def check_answer(self, user_answer: str) -> Tuple[bool, str]:
        """Check if the user's answer is correct."""
        if not self.questions or self.current_question_idx >= len(self.questions):
            return False, "No question available"

        question_data = self.questions[self.current_question_idx]
        config = self.current_dataset
        question_type = config["type"]

        if question_type == "multiple_choice":
            correct_answer_idx = question_data[config["answer_field"]]
            choices = question_data[config["choices_field"]]
            # ARC stores choices as a {"text": [...], "label": [...]} dict;
            # MMLU and HellaSwag store a plain list of strings
            labels = None
            if isinstance(choices, dict) and "text" in choices:
                labels = choices.get("label")
                choices = choices["text"]

            # Normalize the answer to a 0-based index where possible: MMLU uses
            # ints, HellaSwag uses digit strings, ARC uses its own label list
            if (
                isinstance(correct_answer_idx, str)
                and labels
                and correct_answer_idx in labels
            ):
                correct_answer_idx = labels.index(correct_answer_idx)
            elif isinstance(correct_answer_idx, str) and correct_answer_idx.strip().isdigit():
                correct_answer_idx = int(correct_answer_idx)

            if isinstance(correct_answer_idx, int):
                correct_letter = chr(65 + correct_answer_idx)
            else:
                correct_letter = str(correct_answer_idx)

            user_letter = user_answer.strip().upper()[0] if user_answer else ""
            is_correct = user_letter == correct_letter

            if is_correct:
                return True, '✅ Correct!'
            else:
                correct_choice = (
                    choices[correct_answer_idx]
                    if isinstance(correct_answer_idx, int)
                    else correct_answer_idx
                )
                logging.info(f"Raw answer (multiple choice): {repr(correct_choice)}")
                formatted_answer = self.format_answer(
                    correct_choice, self.current_dataset_name
                )
                return (
                    False,
                    f'❌ Incorrect\n\nThe correct answer was {correct_letter}:\n\n{formatted_answer}',
                )

elif question_type == "true_false": | |
correct_answer = question_data[config["answer_field"]] | |
user_bool = user_answer.lower().strip() == "true" | |
is_correct = user_bool == correct_answer | |
if is_correct: | |
return True, '✅ Correct!' | |
else: | |
return ( | |
False, | |
f'❌ Incorrect\n\nThe correct answer was {correct_answer}', | |
) | |
elif question_type == "binary_choice": | |
correct_answer_idx = question_data[config["answer_field"]] | |
user_idx = 0 if user_answer.strip().upper().startswith("A") else 1 | |
is_correct = user_idx == correct_answer_idx | |
if is_correct: | |
return True, '✅ Correct!' | |
else: | |
correct_letter = "A" if correct_answer_idx == 0 else "B" | |
option_field = ( | |
config["option1_field"] | |
if correct_answer_idx == 0 | |
else config["option2_field"] | |
) | |
correct_option = question_data[option_field] | |
logging.info(f"Raw answer (binary choice): {repr(correct_option)}") | |
formatted_answer = self.format_answer( | |
correct_option, self.current_dataset_name | |
) | |
return ( | |
False, | |
f'❌ Incorrect\n\nThe correct answer was {correct_letter}:\n\n{formatted_answer}', | |
) | |
elif question_type in ["qa", "extractive_qa"]: | |
# For QA, we'll do a simple check - in real app, you'd want more sophisticated matching | |
correct_answer = question_data[config["answer_field"]] | |
if isinstance(correct_answer, dict) and "text" in correct_answer: | |
correct_answer = ( | |
correct_answer["text"][0] if correct_answer["text"] else "" | |
) | |
elif isinstance(correct_answer, list) and len(correct_answer) > 0: | |
correct_answer = ( | |
correct_answer[0]["text"] | |
if isinstance(correct_answer[0], dict) | |
else str(correct_answer[0]) | |
) | |
else: | |
correct_answer = str(correct_answer) | |
# Extract final answer for GSM8K and similar datasets | |
import re | |
# For GSM8K, extract the final answer after #### | |
if "####" in correct_answer: | |
final_answer_match = re.search(r"####\s*(.+)", correct_answer) | |
if final_answer_match: | |
final_answer = final_answer_match.group(1).strip() | |
else: | |
final_answer = correct_answer | |
else: | |
final_answer = correct_answer | |
# First check if user answer is empty | |
if not user_answer or not user_answer.strip(): | |
is_correct = False | |
else: | |
# Extract numbers from both answers for comparison | |
correct_numbers = re.findall(r"-?\d+\.?\d*", final_answer) | |
user_numbers = re.findall(r"-?\d+\.?\d*", user_answer) | |
# Check if answers match | |
is_correct = False | |
# If both have numbers, compare the numbers | |
if correct_numbers and user_numbers: | |
# Convert to float for comparison to handle decimals | |
try: | |
correct_num = float( | |
correct_numbers[-1] | |
) # Take the last number as final answer | |
user_num = float(user_numbers[-1]) # Take the last number from user | |
is_correct = ( | |
abs(correct_num - user_num) < 0.0001 | |
) # Small tolerance for float comparison | |
except ValueError: | |
# Fall back to string comparison | |
is_correct = correct_numbers[-1] == user_numbers[-1] | |
elif correct_numbers and not user_numbers: | |
# If correct answer has numbers but user answer doesn't, it's wrong | |
is_correct = False | |
else: | |
# Fall back to substring matching for non-numeric answers | |
# But ensure both strings are non-empty | |
is_correct = ( | |
user_answer.lower().strip() in correct_answer.lower() | |
or correct_answer.lower() in user_answer.lower().strip() | |
) and len(user_answer.strip()) > 0 | |
if is_correct: | |
return True, '✅ Correct!' | |
else: | |
logging.info(f"Raw answer (QA): {repr(correct_answer)}") | |
logging.info(f"Extracted final answer: {repr(final_answer)}") | |
logging.info( | |
f"Correct numbers: {correct_numbers}, User numbers: {user_numbers}" | |
) | |
formatted_answer = self.format_answer( | |
correct_answer, self.current_dataset_name | |
) | |
# Debug: log the formatted answer | |
logging.info(f"Formatted answer with LaTeX: {repr(formatted_answer)}") | |
return ( | |
False, | |
f'❌ Incorrect\n\nThe correct answer was:\n\n{formatted_answer}', | |
) | |
return False, "Unknown question type" | |
# Create global quiz app instance
quiz_app = QuizApp()


def create_dataset_display():
    """Create the dataset listing display."""
    dataset_info = []
    for dataset_id, config in EVAL_DATASETS.items():
        dataset_info.append(
            f"**{config['name']}**\n- Dataset: {dataset_id}\n- Type: {config['type']}"
        )
    return "\n\n".join(dataset_info)

def start_quiz(dataset_choice: str, num_questions: int):
    """Start a new quiz with the selected dataset."""
    # Map the display name back to the dataset ID
    dataset_id = None
    for did, config in EVAL_DATASETS.items():
        if config["name"] in dataset_choice:
            dataset_id = did
            break

    if not dataset_id:
        return (
            "Please select a dataset",
            gr.update(visible=False),  # question_display
            gr.update(visible=False),  # answer_radio
            gr.update(visible=False),  # answer_textbox
            gr.update(visible=False),  # submit_button
            gr.update(visible=False),  # progress_text
        )

    success, message = quiz_app.load_dataset_questions(dataset_id, num_questions)

    if success:
        question, choices, q_type = quiz_app.get_current_question()

        if q_type in ["multiple_choice", "true_false", "binary_choice"]:
            return (
                message,
                gr.update(value=question, visible=True),  # question_display
                gr.update(choices=choices, visible=True, value=None),  # answer_radio
                gr.update(visible=False),  # answer_textbox
                gr.update(visible=True),  # submit_button
                gr.update(
                    value=f"Question 1/{quiz_app.total_questions}", visible=True
                ),  # progress_text
            )
        else:
            return (
                message,
                gr.update(value=question, visible=True),  # question_display
                gr.update(visible=False, value=None),  # answer_radio (clear stale selection)
                gr.update(visible=True, value=""),  # answer_textbox
                gr.update(visible=True),  # submit_button
                gr.update(
                    value=f"Question 1/{quiz_app.total_questions}", visible=True
                ),  # progress_text
            )
    else:
        return (
            message,
            gr.update(visible=False),  # question_display
            gr.update(visible=False),  # answer_radio
            gr.update(visible=False),  # answer_textbox
            gr.update(visible=False),  # submit_button
            gr.update(visible=False),  # progress_text
        )

def submit_answer(answer_choice, answer_text):
    """Submit an answer and show feedback."""
    # Use the radio selection if present, otherwise the free-text answer
    answer = answer_choice if answer_choice else answer_text

    is_correct, feedback = quiz_app.check_answer(answer)
    if is_correct:
        quiz_app.score += 1

    return gr.update(value=feedback, visible=True), gr.update(visible=True)

def next_question():
    """Move to the next question."""
    quiz_app.current_question_idx += 1

    if quiz_app.current_question_idx >= quiz_app.total_questions:
        # Quiz complete
        final_score = (
            "🎉 Quiz Complete!\n\n"
            f"Your score: {quiz_app.score}/{quiz_app.total_questions} "
            f"({quiz_app.score / quiz_app.total_questions * 100:.1f}%)"
        )
        return (
            gr.update(value=final_score, visible=True),  # feedback_display
            "",  # question_display
            gr.update(visible=False),  # answer_radio
            gr.update(visible=False),  # answer_textbox
            gr.update(visible=False),  # submit_button
            gr.update(visible=False),  # next_button
            "Quiz Complete",  # progress_text
        )

    question, choices, q_type = quiz_app.get_current_question()

    if q_type in ["multiple_choice", "true_false", "binary_choice"]:
        return (
            gr.update(value="", visible=False),  # feedback_display (cleared)
            gr.update(value=question),  # question_display
            gr.update(choices=choices, visible=True, value=None),  # answer_radio
            gr.update(visible=False),  # answer_textbox
            gr.update(visible=True),  # submit_button
            gr.update(visible=False),  # next_button
            gr.update(
                value=f"Question {quiz_app.current_question_idx + 1}/{quiz_app.total_questions}"
            ),  # progress_text
        )
    else:
        return (
            gr.update(value="", visible=False),  # feedback_display (cleared)
            gr.update(value=question),  # question_display
            gr.update(visible=False, value=None),  # answer_radio (cleared)
            gr.update(visible=True, value=""),  # answer_textbox
            gr.update(visible=True),  # submit_button
            gr.update(visible=False),  # next_button
            gr.update(
                value=f"Question {quiz_app.current_question_idx + 1}/{quiz_app.total_questions}"
            ),  # progress_text
        )

# Create Gradio interface
with gr.Blocks(title="HuggingFace Evaluation Dataset Quiz") as demo:
    gr.Markdown("# 🤗 Evaluation Dataset Quiz")
    gr.Markdown(
        "Test yourself with questions from popular HuggingFace evaluation datasets!"
    )

    # Dataset Selection Section
    with gr.Row():
        dataset_dropdown = gr.Dropdown(
            choices=[config["name"] for config in EVAL_DATASETS.values()],
            label="Select Dataset",
            value=list(EVAL_DATASETS.values())[0]["name"],
        )
        num_questions_slider = gr.Slider(
            minimum=5, maximum=20, value=10, step=1, label="Number of Questions"
        )

    start_button = gr.Button("Start Quiz", variant="primary")
    status_message = gr.Textbox(label="Status", interactive=False)

    # Quiz Section - shown when quiz starts
    gr.Markdown("---")  # Separator
    progress_text = gr.Textbox(
        label="Progress", value="0/0", interactive=False, visible=False
    )
    question_display = gr.Textbox(
        label="Question", lines=5, interactive=False, visible=False
    )

    # Answer inputs (one will be visible at a time)
    answer_radio = gr.Radio(label="Select your answer", visible=False)
    answer_textbox = gr.Textbox(label="Type your answer (Raw number)", visible=False)

    submit_button = gr.Button("Submit Answer", variant="primary", visible=False)
    feedback_display = gr.Textbox(
        label="Feedback",
        visible=False,
        lines=10,
        max_lines=20,
        interactive=False,
    )
    next_button = gr.Button("Next Question", visible=False)

    # Connect events
    start_button.click(
        start_quiz,
        inputs=[dataset_dropdown, num_questions_slider],
        outputs=[
            status_message,
            question_display,
            answer_radio,
            answer_textbox,
            submit_button,
            progress_text,
        ],
    )

    submit_button.click(
        submit_answer,
        inputs=[answer_radio, answer_textbox],
        outputs=[feedback_display, next_button],
    )

    next_button.click(
        next_question,
        outputs=[
            feedback_display,
            question_display,
            answer_radio,
            answer_textbox,
            submit_button,
            next_button,
            progress_text,
        ],
    )

if __name__ == "__main__":
    demo.launch()
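# Local usage (a minimal sketch; assumes this file is saved as app.py and that
# the gradio and datasets packages are installed, e.g. `pip install gradio datasets`):
#   python app.py
# Gradio prints a local URL; each dataset is downloaded from the Hugging Face Hub
# the first time a quiz uses it.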