"""Quiz generation and Gradio callbacks for the topic quiz.

Questions are built, in order of preference, by

1. prompting a FLAN-T5 model (``T5_MODEL_ID``) with context returned by
   ``quiz_logic.wikipedia_utils.fetch_context``,
2. a template-based generator (``_generate_structured_question``), and
3. the hard-coded ``FALLBACK_QUESTIONS`` bank,

and are stored in the shared ``quiz_data`` list from ``quiz_logic.state``.
"""

import copy
import json
import os
import random
import re
import time
import traceback

import torch
import gradio as gr
from transformers import AutoTokenizer, T5ForConditionalGeneration

from quiz_logic.wikipedia_utils import fetch_context, extract_keywords
from quiz_logic.state import (
    quiz_data,
    current_question_index,
    score,
    user_answer,
    set_user_answer,
    reset_user_answer,
    increment_score,
    increment_index,
)

os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Hugging Face download cache; get_t5_tokenizer_and_model() retries without
# it if loading from here fails.
cache_dir = "/tmp/hf_cache"
os.makedirs(cache_dir, exist_ok=True)

T5_MODEL_ID = os.environ.get("T5_MODEL_ID", "google/flan-t5-small")

# Lazily initialised singletons; see get_device() and get_t5_tokenizer_and_model().
_device = None
_t5_tokenizer = None
_t5_model = None

QUESTION_TEMPLATES = {
    "definition": "What is {concept}?",
    "function": "What is the primary purpose of {concept}?",
    "example": "Which of the following is an example of {concept}?",
    "comparison": "How does {concept1} differ from {concept2}?",
    "application": "When would you use {concept}?",
    "characteristic": "Which characteristic best describes {concept}?",
    "process": "What happens when {concept} is applied?",
    "category": "Which category does {concept} belong to?",
}

# Prompt templates for the T5 model; formatted with topic/focus/context/level.
IMPROVED_PROMPT_TEMPLATES = {
    "python": """Create a multiple choice question about Python programming.
Topic: {focus}
Context: {context}
Difficulty: {level}

Generate a clear, specific question with 4 distinct options where only one is correct.
Focus on practical knowledge, syntax, or concepts.

Format:
Question: [specific question about {focus}]
A) [correct answer]
B) [plausible wrong answer]
C) [plausible wrong answer]
D) [plausible wrong answer]
Answer: A""",
    "general": """Create a multiple choice question about {topic}.
Focus area: {focus}
Context: {context}
Difficulty: {level}

Generate a clear, educational question with 4 distinct options where only one is correct.
Make the question specific and the wrong answers plausible but clearly incorrect.

Format:
Question: [specific question about {focus}]
A) [correct answer]
B) [realistic wrong answer]
C) [realistic wrong answer]
D) [realistic wrong answer]
Answer: A""",
}

# Last-resort question bank, keyed like IMPROVED_PROMPT_TEMPLATES.
FALLBACK_QUESTIONS = {
    "python": [
        {
            "question": "Which keyword is used to define a function in Python?",
            "options": ["def", "function", "define", "func"],
            "answer": "def",
            "explanation": "The 'def' keyword is used to define functions in Python.",
        },
        {
            "question": "What data type is used to store a sequence of characters in Python?",
            "options": ["str", "string", "text", "char"],
            "answer": "str",
            "explanation": "In Python, strings are represented by the 'str' data type.",
        },
        {
            "question": "Which operator is used for integer division in Python?",
            "options": ["//", "/", "%", "**"],
            "answer": "//",
            "explanation": "The '//' operator performs floor division (integer division) in Python.",
        },
    ],
    "general": [
        {
            "question": "What is the capital of France?",
            "options": ["Paris", "London", "Berlin", "Rome"],
            "answer": "Paris",
            "explanation": "Paris is the capital and largest city of France.",
        }
    ],
}

# Hand-picked wrong answers for common Python names, used by
# _create_smart_distractors when the topic mentions Python.
_PYTHON_DISTRACTORS = {
    "def": ["function", "define", "method"],
    "str": ["string", "text", "varchar"],
    "int": ["integer", "number", "num"],
    "list": ["array", "vector", "sequence"],
    "dict": ["map", "hash", "object"],
    "True": ["true", "TRUE", "1"],
    "False": ["false", "FALSE", "0"],
    "None": ["null", "NULL", "nil"],
    "print": ["console.log", "echo", "write"],
    "len": ["length", "size", "count"],
}

_GENERIC_DISTRACTORS = [
    "Option A",
    "Option B",
    "Option C",
    "Not applicable",
    "All of the above",
    "None of the above",
]

# An option marker such as "A)" must sit at the start of the text or after
# whitespace (newline or space), so ")" inside ordinary words is not a marker.
_OPTION_MARKER_RES = {
    letter: re.compile(rf"(?:^|(?<=\s)){letter}\)") for letter in "ABCD"
}
_ANSWER_RE = re.compile(r"Answer:\s*([A-D])")


def get_device():
    """Return (and cache) the torch device: CUDA when available, else CPU."""
    global _device
    if _device is None:
        _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return _device


def _load_t5(device, **pretrained_kwargs):
    """Load the T5 tokenizer and model, move the model to *device*, set eval mode."""
    tokenizer = AutoTokenizer.from_pretrained(T5_MODEL_ID, use_fast=True, **pretrained_kwargs)
    model = T5ForConditionalGeneration.from_pretrained(T5_MODEL_ID, **pretrained_kwargs)
    model.to(device)
    model.eval()
    return tokenizer, model


def get_t5_tokenizer_and_model():
    """Return the cached ``(tokenizer, model)`` pair, loading it on first use.

    Loading is tried with ``cache_dir`` first and retried with the library's
    default cache location if that fails; an error on the retry propagates.
    The globals are only assigned once both objects loaded successfully, so a
    failed load never leaves a half-initialised pair behind.
    """
    global _t5_tokenizer, _t5_model
    device = get_device()
    if _t5_tokenizer is None or _t5_model is None:
        try:
            tokenizer, model = _load_t5(device, cache_dir=cache_dir)
        except Exception:
            tokenizer, model = _load_t5(device)
        _t5_tokenizer, _t5_model = tokenizer, model
    return _t5_tokenizer, _t5_model


def _parse_raw_question(text):
    """Parse generated text into a question dict, or return ``None``.

    Both the multi-line layout requested by the prompt and a single-line form
    are accepted. The single-line form matters because T5's SentencePiece
    vocabulary has no newline token, so decoded T5 output is one line::

        Question: ... A) ... B) ... C) ... D) ... Answer: B

    Returns:
        ``{"question", "options", "answer", "explanation"}`` with exactly four
        options, where ``answer`` is the option named by ``Answer: <letter>``
        (option A when no answer letter is present). ``None`` when the
        ``A)``..``D)`` markers are not all found in order, an option is empty,
        or no question text precedes the options.
    """
    if not text:
        return None
    text = text.replace("\r", "\n").strip()

    # Locate A), B), C), D) in order so a marker-like token inside the
    # question text cannot be taken for a later option.
    markers = []
    search_from = 0
    for letter in "ABCD":
        match = _OPTION_MARKER_RES[letter].search(text, search_from)
        if match is None:
            return None
        markers.append(match)
        search_from = match.end()

    preamble = text[: markers[0].start()]
    label_at = preamble.find("Question:")
    if label_at != -1:
        question = preamble[label_at + len("Question:"):]
    else:
        lines = [line for line in preamble.split("\n") if line.strip()]
        question = lines[0] if lines else ""
    question = " ".join(question.split())
    if not question:
        return None

    options = []
    for position, marker in enumerate(markers):
        stop = markers[position + 1].start() if position + 1 < len(markers) else len(text)
        segment = text[marker.end():stop]
        # Drop a trailing "Answer: X" and, in the multi-line layout, anything
        # after the option's own line.
        segment = segment.split("Answer:", 1)[0]
        option = segment.strip().split("\n", 1)[0].strip()
        if not option:
            return None
        options.append(option)

    answer_match = _ANSWER_RE.search(text)
    answer_index = ord(answer_match.group(1)) - ord("A") if answer_match else 0
    answer = options[answer_index]

    return {
        "question": question,
        "options": options,
        "answer": answer,
        "explanation": f"The correct answer is {answer}.",
    }


def _validate_question_quality(qobj):
    """Return True if *qobj* looks like a usable multiple-choice question.

    Requires a string question of at least 5 words and exactly 4 string
    options. The options must be unique (case-insensitively, after stripping),
    each at least 2 characters long, and one of them must equal ``answer``.
    Malformed input yields False rather than raising.
    """
    if not qobj or not isinstance(qobj, dict):
        return False
    question = qobj.get("question", "")
    options = qobj.get("options") or []
    answer = qobj.get("answer", "")

    if not isinstance(question, str) or len(question.split()) < 5:
        return False
    if not isinstance(options, (list, tuple)) or len(options) != 4:
        return False
    if not all(isinstance(opt, str) for opt in options):
        return False
    if answer not in options:
        return False
    if len({opt.lower().strip() for opt in options}) != 4:
        return False
    return all(len(opt.strip()) >= 2 for opt in options)


def _create_smart_distractors(correct_answer, topic, context):
    """Return up to three wrong answers for *correct_answer*.

    Candidates are drawn in this order:

    1. Hand-picked Python alternatives, when *topic* mentions Python. The
       lookup is case-insensitive because callers pass capitalised focus words.
    2. Capitalised alphabetic words (longer than 2 characters) from *context*.
       This step stops once 6 candidates exist.
    3. Generic filler answers.

    Candidates equal to *correct_answer*, or to an earlier candidate, are
    skipped. Both comparisons ignore case, so "String" and "string" never both
    appear; such a pair would fail _validate_question_quality.
    """
    correct_key = correct_answer.lower()
    candidates = []

    if "python" in topic.lower():
        for key, alternatives in _PYTHON_DISTRACTORS.items():
            if key.lower() == correct_key:
                candidates.extend(alternatives)
                break

    for word in (context or "").lower().split():
        if len(candidates) >= 6:
            break
        if word != correct_key and len(word) > 2 and word.isalpha():
            candidates.append(word.capitalize())

    candidates.extend(_GENERIC_DISTRACTORS)

    seen = {correct_key}
    distractors = []
    for candidate in candidates:
        key = candidate.lower()
        if key in seen:
            continue
        seen.add(key)
        distractors.append(candidate)
        if len(distractors) == 3:
            break
    return distractors


def _generate_structured_question(topic, focus, context, difficulty):
    """Build a question without the model.

    For Python topics whose *focus* mentions functions, strings or lists, a
    fixed hand-written question is returned. Otherwise a random
    QUESTION_TEMPLATES entry is filled with *focus*. In that case the correct
    answer is ``focus.capitalize()``, the distractors come from
    _create_smart_distractors, and the options are shuffled. *difficulty* is
    accepted for interface symmetry but is not used.
    """
    focus_lower = focus.lower()
    if "python" in topic.lower():
        if "function" in focus_lower:
            return {
                "question": "Which keyword is used to define a function in Python?",
                "options": ["def", "function", "define", "func"],
                "answer": "def",
                "explanation": "The 'def' keyword is the standard way to define functions in Python.",
            }
        elif "string" in focus_lower or "str" in focus_lower:
            return {
                "question": "What method would you use to convert a string to uppercase in Python?",
                "options": ["upper()", "toUpperCase()", "uppercase()", "UPPER()"],
                "answer": "upper()",
                "explanation": "The upper() method returns a string with all characters converted to uppercase.",
            }
        elif "list" in focus_lower:
            return {
                "question": "Which method adds an element to the end of a Python list?",
                "options": ["append()", "add()", "insert()", "push()"],
                "answer": "append()",
                "explanation": "The append() method adds a single element to the end of a list.",
            }

    template_type = random.choice(list(QUESTION_TEMPLATES.keys()))
    question_template = QUESTION_TEMPLATES[template_type]
    if "{concept}" in question_template:
        question = question_template.format(concept=focus)
    elif "{concept1}" in question_template:
        # The comparison template needs two concepts; take the first two words.
        concepts = focus.split()
        if len(concepts) >= 2:
            question = question_template.format(concept1=concepts[0], concept2=concepts[1])
        else:
            question = f"What is {focus}?"
    else:
        question = f"What is {focus}?"

    correct_answer = focus.capitalize()
    distractors = _create_smart_distractors(correct_answer, topic, context)
    while len(distractors) < 3:
        distractors.append(f"Not {correct_answer}")

    options = [correct_answer] + distractors[:3]
    random.shuffle(options)
    return {
        "question": question,
        "options": options,
        "answer": correct_answer,
        "explanation": f"{correct_answer} is the correct answer based on the given context.",
    }


def _fallback_pool(topic):
    """Return the FALLBACK_QUESTIONS list matching *topic*."""
    key = "python" if "python" in topic.lower() else "general"
    return FALLBACK_QUESTIONS.get(key, FALLBACK_QUESTIONS["general"])


def _append_unique(question, seen_questions):
    """Append *question* to ``quiz_data`` unless its text was already used.

    *seen_questions* holds whitespace/case-normalised question texts and is
    updated in place. Returns True if the question was appended.
    """
    key = " ".join(question["question"].lower().split())
    if key in seen_questions:
        return False
    seen_questions.add(key)
    quiz_data.append(question)
    return True


def _fill_from_fallback(topic, n_questions, seen_questions, allow_repeats):
    """Top ``quiz_data`` up to *n_questions* entries from the fallback bank.

    Unused bank questions are added first, in random order. When
    *allow_repeats* is true and the bank is too small, random entries are
    repeated until the quiz is full. Entries are deep-copied so callers
    mutating a question (e.g. its options list) cannot alter the bank.
    """
    pool = _fallback_pool(topic)
    for candidate in random.sample(pool, len(pool)):
        if len(quiz_data) >= n_questions:
            return
        _append_unique(copy.deepcopy(candidate), seen_questions)
    while allow_repeats and pool and len(quiz_data) < n_questions:
        quiz_data.append(copy.deepcopy(random.choice(pool)))


def _generate_t5_question(tokenizer, model, prompt_template, topic, focus, context, difficulty):
    """Run one T5 generation for *focus* and return the parsed question or None."""
    prompt = prompt_template.format(
        topic=topic,
        focus=focus,
        context=context[:600],
        level=difficulty,
    )
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
    device = get_device()
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=200,
            num_beams=5,
            do_sample=True,
            temperature=0.8,
            top_p=0.9,
            no_repeat_ngram_size=3,
            early_stopping=True,
        )
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return _parse_raw_question(generated_text)


def generate_questions(topic, n_questions=3, difficulty="medium"):
    """Replace ``quiz_data`` with up to *n_questions* questions about *topic*.

    The first *n_questions* attempts try the T5 model. Whenever the model's
    output is unusable, and for all later attempts, the structured generator
    is used instead, for at most ``3 * n_questions`` attempts in total.
    Questions whose text is already in the quiz are skipped. Any remaining
    slots are filled from FALLBACK_QUESTIONS. If anything raises, the error is
    printed and the quiz is topped up from the fallback bank, never beyond
    *n_questions*.

    Args:
        topic: subject passed to ``fetch_context``; ``None`` is treated as "".
        n_questions: target question count. Floats (e.g. from UI sliders) are
            truncated to int; a non-numeric value falls back to the default 3.
        difficulty: free-text difficulty inserted into the T5 prompt.

    Returns:
        True if ``quiz_data`` holds at least one question afterwards.
    """
    topic = "" if topic is None else str(topic)
    try:
        n_questions = max(int(n_questions), 0)
    except (TypeError, ValueError):
        n_questions = 3
    seen_questions = set()
    is_python = "python" in topic.lower()

    try:
        quiz_data.clear()
        context = fetch_context(topic) or f"This is about {topic}"
        keywords = extract_keywords(context, top_k=max(n_questions * 2, 8)) or [topic]
        tokenizer, model = get_t5_tokenizer_and_model()
        prompt_template = IMPROVED_PROMPT_TEMPLATES["python" if is_python else "general"]

        attempts = 0
        max_attempts = n_questions * 3
        while len(quiz_data) < n_questions and attempts < max_attempts:
            focus = random.choice(keywords) if keywords else topic
            use_model = attempts < n_questions
            attempts += 1

            if use_model:
                candidate = _generate_t5_question(
                    tokenizer, model, prompt_template, topic, focus, context, difficulty
                )
                if _validate_question_quality(candidate) and _append_unique(candidate, seen_questions):
                    continue

            # Model output unusable (or model budget spent): use the template generator.
            structured = _generate_structured_question(topic, focus, context, difficulty)
            if _validate_question_quality(structured):
                _append_unique(structured, seen_questions)

        _fill_from_fallback(topic, n_questions, seen_questions, allow_repeats=True)
        return len(quiz_data) > 0
    except Exception as e:
        print(f"Error generating questions: {e}")
        traceback.print_exc()
        _fill_from_fallback(topic, n_questions, seen_questions, allow_repeats=False)
        return len(quiz_data) > 0


def get_question_ui():
    """Return five ``gr.update`` objects rendering the current question.

    The user's answer is reset first. When the quiz is empty or finished, the
    first output shows "Quiz finished!" and the fifth shows the final score;
    the option list and the other two outputs are hidden. Otherwise the first
    output shows the numbered question and the second offers its options.
    NOTE(review): the third/fourth outputs are presumably the feedback text
    and next button — confirm against the UI wiring.
    """
    reset_user_answer()
    if not quiz_data or current_question_index[0] >= len(quiz_data):
        return (
            gr.update(value="Quiz finished!"),
            gr.update(choices=[], value=None, interactive=False, visible=False),
            gr.update(value="", visible=False),
            gr.update(visible=False),
            gr.update(value=f"🎉 Quiz finished! Your score: {score[0]}/{len(quiz_data)}", visible=True),
        )
    q = quiz_data[current_question_index[0]]
    question_display = f"### Question {current_question_index[0] + 1}\n{q['question']}"
    return (
        gr.update(value=question_display, visible=True),
        gr.update(choices=q["options"], value=None, interactive=True, visible=True),
        gr.update(value="", visible=False),
        gr.update(visible=False),
        gr.update(value="", visible=False),
    )


def next_question():
    """Advance to the next question.

    Returns the final-score message (a str) once past the last question,
    otherwise a ``(question_text, options)`` tuple; callers must handle both
    shapes.
    """
    increment_index()
    if current_question_index[0] >= len(quiz_data):
        return f"🎉 Quiz finished! Your score: {score[0]}/{len(quiz_data)}"
    q = quiz_data[current_question_index[0]]
    return q["question"], q["options"]


def on_select(selected_option):
    """Grade *selected_option* against the current question.

    Returns ``(feedback, answered, next_label)``. ``next_label`` is
    "View scores" on the last question and "Next question" otherwise. A
    ``None`` selection (e.g. the ``value=None`` reset done by get_question_ui,
    if this is bound to a change event) or a call after the quiz has ended is
    not graded. It returns ``("", False, next_label)`` instead of revealing
    the answer or raising IndexError.

    NOTE(review): every correct call increments the score; if the UI allows
    re-selecting after answering, one question can be scored twice — confirm
    the options are locked after the first answer.
    """
    index = current_question_index[0]
    if selected_option is None or not 0 <= index < len(quiz_data):
        last = index >= len(quiz_data) - 1
        return "", False, ("View scores" if last else "Next question")

    set_user_answer(selected_option)
    q = quiz_data[index]
    correct = q.get("answer")
    if isinstance(correct, int):
        # Some question sources may store the answer as an option index.
        correct = q["options"][correct]

    is_correct = str(selected_option).strip().lower() == str(correct).strip().lower()
    if is_correct:
        increment_score()
        feedback = "Correct! Well done!"
    else:
        feedback = f"The correct answer is: {correct}"

    last = index == len(quiz_data) - 1
    next_label = "View scores" if last else "Next question"
    return feedback, True, next_label