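"""Quiz question generation for the Gradio quiz app.

Lazily loads a small T5 model (google/flan-t5-small by default) to generate
multiple-choice questions from Wikipedia context, with template-based and
hard-coded fallback questions when model output cannot be parsed or fails
validation.
"""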
import os
import random
import json
import re
import traceback
import time

import torch
import gradio as gr
from transformers import AutoTokenizer, T5ForConditionalGeneration

from quiz_logic.wikipedia_utils import fetch_context, extract_keywords
from quiz_logic.state import (
    quiz_data,
    current_question_index,
    score,
    user_answer,
    set_user_answer,
    reset_user_answer,
    increment_score,
    increment_index,
)

os.environ["TOKENIZERS_PARALLELISM"] = "false"

cache_dir = "/tmp/hf_cache"
os.makedirs(cache_dir, exist_ok=True)

T5_MODEL_ID = os.environ.get("T5_MODEL_ID", "google/flan-t5-small")

_device = None
_t5_tokenizer = None
_t5_model = None

QUESTION_TEMPLATES = {
    "definition": "What is {concept}?",
    "function": "What is the primary purpose of {concept}?",
    "example": "Which of the following is an example of {concept}?",
    "comparison": "How does {concept1} differ from {concept2}?",
    "application": "When would you use {concept}?",
    "characteristic": "Which characteristic best describes {concept}?",
    "process": "What happens when {concept} is applied?",
    "category": "Which category does {concept} belong to?",
}

IMPROVED_PROMPT_TEMPLATES = {
    "python": """Create a multiple choice question about Python programming.
Topic: {focus}
Context: {context}
Difficulty: {level}
Generate a clear, specific question with 4 distinct options where only one is correct.
Focus on practical knowledge, syntax, or concepts.
Format:
Question: [specific question about {focus}]
A) [correct answer]
B) [plausible wrong answer]
C) [plausible wrong answer]
D) [plausible wrong answer]
Answer: A""",
    "general": """Create a multiple choice question about {topic}.
Focus area: {focus}
Context: {context}
Difficulty: {level}
Generate a clear, educational question with 4 distinct options where only one is correct.
Make the question specific and the wrong answers plausible but clearly incorrect.
Format:
Question: [specific question about {focus}]
A) [correct answer]
B) [realistic wrong answer]
C) [realistic wrong answer]
D) [realistic wrong answer]
Answer: A""",
}

FALLBACK_QUESTIONS = {
    "python": [
        {
            "question": "Which keyword is used to define a function in Python?",
            "options": ["def", "function", "define", "func"],
            "answer": "def",
            "explanation": "The 'def' keyword is used to define functions in Python.",
        },
        {
            "question": "What data type is used to store a sequence of characters in Python?",
            "options": ["str", "string", "text", "char"],
            "answer": "str",
            "explanation": "In Python, strings are represented by the 'str' data type.",
        },
        {
            "question": "Which operator is used for integer division in Python?",
            "options": ["//", "/", "%", "**"],
            "answer": "//",
            "explanation": "The '//' operator performs floor division (integer division) in Python.",
        },
    ],
    "general": [
        {
            "question": "What is the capital of France?",
            "options": ["Paris", "London", "Berlin", "Rome"],
            "answer": "Paris",
            "explanation": "Paris is the capital and largest city of France.",
        },
    ],
}

def get_device():
    global _device
    if _device is None:
        _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return _device

def get_t5_tokenizer_and_model():
    global _t5_tokenizer, _t5_model
    device = get_device()
    if _t5_tokenizer is None or _t5_model is None:
        try:
            _t5_tokenizer = AutoTokenizer.from_pretrained(T5_MODEL_ID, cache_dir=cache_dir, use_fast=True)
            _t5_model = T5ForConditionalGeneration.from_pretrained(T5_MODEL_ID, cache_dir=cache_dir)
            _t5_model.to(device)
            _t5_model.eval()
        except Exception:
            # Retry without the custom cache directory if the first load fails.
            _t5_tokenizer = AutoTokenizer.from_pretrained(T5_MODEL_ID, use_fast=True)
            _t5_model = T5ForConditionalGeneration.from_pretrained(T5_MODEL_ID)
            _t5_model.to(device)
            _t5_model.eval()
    return _t5_tokenizer, _t5_model
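
# Note: the tokenizer/model pair is loaded lazily on first use and cached in
# module globals, so repeated calls to generate_questions() reuse the same
# instance rather than reloading the checkpoint.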

def _parse_raw_question(text):
    text = text.replace("\r", "\n").strip()
    question_match = re.search(r'Question:\s*(.+?)(?=\n[A-D]\)|$)', text, re.S)
    if question_match:
        question = question_match.group(1).strip()
    else:
        lines = [line.strip() for line in text.split('\n') if line.strip()]
        if lines:
            question = lines[0]
        else:
            return None
    options = []
    option_pattern = r'^([A-D])\)\s*(.+)$'
    for line in text.split('\n'):
        match = re.match(option_pattern, line.strip())
        if match:
            options.append(match.group(2).strip())
    if len(options) < 4:
        return None
    answer_match = re.search(r'Answer:\s*([A-D])', text)
    if answer_match:
        answer_index = ord(answer_match.group(1)) - ord('A')
        if 0 <= answer_index < len(options):
            answer = options[answer_index]
        else:
            answer = options[0]
    else:
        answer = options[0]
    return {
        "question": question,
        "options": options[:4],
        "answer": answer,
        "explanation": f"The correct answer is {answer}.",
    }
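
# Illustrative sketch of the raw model output _parse_raw_question expects,
# mirroring the prompt templates above (the exact generated text will vary):
#
#   Question: Which keyword is used to define a function in Python?
#   A) def
#   B) function
#   C) define
#   D) func
#   Answer: A
#
# This would parse to options ["def", "function", "define", "func"]
# with "def" as the answer.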

def _validate_question_quality(qobj):
    if not qobj or not isinstance(qobj, dict):
        return False
    question = qobj.get("question", "")
    options = qobj.get("options", [])
    answer = qobj.get("answer", "")
    if len(question.split()) < 5:
        return False
    if len(options) != 4:
        return False
    if answer not in options:
        return False
    unique_options = set(opt.lower().strip() for opt in options)
    if len(unique_options) != 4:
        return False
    for opt in options:
        if len(opt.strip()) < 2:
            return False
    return True

def _create_smart_distractors(correct_answer, topic, context):
    distractors = []
    if "python" in topic.lower():
        python_distractors = {
            "def": ["function", "define", "method"],
            "str": ["string", "text", "varchar"],
            "int": ["integer", "number", "num"],
            "list": ["array", "vector", "sequence"],
            "dict": ["map", "hash", "object"],
            "True": ["true", "TRUE", "1"],
            "False": ["false", "FALSE", "0"],
            "None": ["null", "NULL", "nil"],
            "print": ["console.log", "echo", "write"],
            "len": ["length", "size", "count"],
        }
        if correct_answer in python_distractors:
            distractors.extend(python_distractors[correct_answer])
    words = context.lower().split()
    for word in words:
        if (word != correct_answer.lower() and
                len(word) > 2 and
                word.isalpha() and
                len(distractors) < 6):
            distractors.append(word.capitalize())
    generic_distractors = ["Option A", "Option B", "Option C", "Not applicable", "All of the above", "None of the above"]
    distractors.extend(generic_distractors)
    final_distractors = []
    for d in distractors:
        if d.lower() != correct_answer.lower() and d not in final_distractors:
            final_distractors.append(d)
        if len(final_distractors) >= 3:
            break
    return final_distractors[:3]
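
# For example, with a Python topic, _create_smart_distractors("def", "Python", context)
# draws from the curated table above first (["function", "define", "method"]),
# topping up from context words and the generic fillers only if fewer than
# three usable distractors remain.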

def _generate_structured_question(topic, focus, context, difficulty):
    topic_lower = topic.lower()
    if "python" in topic_lower:
        if "function" in focus.lower():
            return {
                "question": "Which keyword is used to define a function in Python?",
                "options": ["def", "function", "define", "func"],
                "answer": "def",
                "explanation": "The 'def' keyword is the standard way to define functions in Python.",
            }
        elif "string" in focus.lower() or "str" in focus.lower():
            return {
                "question": "What method would you use to convert a string to uppercase in Python?",
                "options": ["upper()", "toUpperCase()", "uppercase()", "UPPER()"],
                "answer": "upper()",
                "explanation": "The upper() method returns a string with all characters converted to uppercase.",
            }
        elif "list" in focus.lower():
            return {
                "question": "Which method adds an element to the end of a Python list?",
                "options": ["append()", "add()", "insert()", "push()"],
                "answer": "append()",
                "explanation": "The append() method adds a single element to the end of a list.",
            }
    template_type = random.choice(list(QUESTION_TEMPLATES.keys()))
    question_template = QUESTION_TEMPLATES[template_type]
    if "{concept}" in question_template:
        question = question_template.format(concept=focus)
    elif "{concept1}" in question_template:
        concepts = focus.split()
        if len(concepts) >= 2:
            question = question_template.format(concept1=concepts[0], concept2=concepts[1])
        else:
            question = f"What is {focus}?"
    else:
        question = f"What is {focus}?"
    correct_answer = focus.capitalize()
    distractors = _create_smart_distractors(correct_answer, topic, context)
    while len(distractors) < 3:
        distractors.append(f"Not {correct_answer}")
    options = [correct_answer] + distractors[:3]
    random.shuffle(options)
    return {
        "question": question,
        "options": options,
        "answer": correct_answer,
        "explanation": f"{correct_answer} is the correct answer based on the given context.",
    }

def generate_questions(topic, n_questions=3, difficulty="medium"):
    try:
        quiz_data.clear()
        context = fetch_context(topic) or f"This is about {topic}"
        keywords = extract_keywords(context, top_k=max(n_questions * 2, 8)) or [topic]
        tokenizer, model = get_t5_tokenizer_and_model()
        successful_questions = 0
        attempts = 0
        max_attempts = n_questions * 3
        topic_lower = topic.lower()
        prompt_template = IMPROVED_PROMPT_TEMPLATES.get("python" if "python" in topic_lower else "general")
        while successful_questions < n_questions and attempts < max_attempts:
            focus = random.choice(keywords) if keywords else topic
            # First pass: ask the T5 model for a question about this focus keyword.
            if attempts < n_questions:
                prompt = prompt_template.format(
                    topic=topic,
                    focus=focus,
                    context=context[:600],
                    level=difficulty,
                )
                inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
                inputs = {k: v.to(get_device()) for k, v in inputs.items()}
                with torch.no_grad():
                    outputs = model.generate(
                        **inputs,
                        max_new_tokens=200,
                        num_beams=5,
                        do_sample=True,
                        temperature=0.8,
                        top_p=0.9,
                        no_repeat_ngram_size=3,
                        early_stopping=True,
                    )
                generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
                parsed_question = _parse_raw_question(generated_text)
                if parsed_question and _validate_question_quality(parsed_question):
                    quiz_data.append(parsed_question)
                    successful_questions += 1
                    attempts += 1
                    continue
            # Second pass: fall back to a template-based question for this focus.
            structured_question = _generate_structured_question(topic, focus, context, difficulty)
            if _validate_question_quality(structured_question):
                quiz_data.append(structured_question)
                successful_questions += 1
            attempts += 1
        # Final safety net: pad the quiz with hard-coded fallback questions.
        while successful_questions < n_questions:
            fallback_pool = FALLBACK_QUESTIONS.get("python" if "python" in topic_lower else "general", FALLBACK_QUESTIONS["general"])
            fallback_q = random.choice(fallback_pool)
            quiz_data.append(fallback_q.copy())
            successful_questions += 1
        return len(quiz_data) > 0
    except Exception as e:
        print(f"Error generating questions: {e}")
        traceback.print_exc()
        fallback_pool = FALLBACK_QUESTIONS.get("python" if "python" in topic.lower() else "general", FALLBACK_QUESTIONS["general"])
        for i in range(min(n_questions, len(fallback_pool))):
            quiz_data.append(fallback_pool[i].copy())
        return len(quiz_data) > 0
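
# Typical usage from the app (illustrative sketch; the exact wiring lives in
# the UI callbacks below and in the Gradio entry point, which is not shown here):
#
#   if generate_questions("Python", n_questions=3, difficulty="medium"):
#       ...  # quiz_data now holds up to 3 validated question dicts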

def get_question_ui():
    reset_user_answer()
    if not quiz_data or current_question_index[0] >= len(quiz_data):
        return (
            gr.update(value="Quiz finished!"),
            gr.update(choices=[], value=None, interactive=False, visible=False),
            gr.update(value="", visible=False),
            gr.update(visible=False),
            gr.update(value=f"🎉 Quiz finished! Your score: {score[0]}/{len(quiz_data)}", visible=True),
        )
    q = quiz_data[current_question_index[0]]
    question_display = f"### Question {current_question_index[0] + 1}\n{q['question']}"
    return (
        gr.update(value=question_display, visible=True),
        gr.update(choices=q["options"], value=None, interactive=True, visible=True),
        gr.update(value="", visible=False),
        gr.update(visible=False),
        gr.update(value="", visible=False),
    )

def next_question():
    increment_index()
    if current_question_index[0] >= len(quiz_data):
        return f"🎉 Quiz finished! Your score: {score[0]}/{len(quiz_data)}"
    q = quiz_data[current_question_index[0]]
    return q["question"], q["options"]

def on_select(selected_option):
    set_user_answer(selected_option)
    q = quiz_data[current_question_index[0]]
    correct = q.get("answer")
    if isinstance(correct, int):
        correct = q["options"][correct]
    is_correct = str(selected_option).strip().lower() == str(correct).strip().lower()
    if is_correct:
        increment_score()
        feedback = "Correct! Well done!"
    else:
        feedback = f"The correct answer is: {correct}"
    last = (current_question_index[0] == len(quiz_data) - 1)
    next_label = "View scores" if last else "Next question"
    return feedback, True, next_label