import os
import random
import traceback

import torch

from quiz_logic.wikipedia_utils import fetch_context, extract_keywords
from quiz_logic.state import (
    quiz_data,
    current_question_index,
    score,
    user_answer,
    set_user_answer,
    reset_user_answer,
    increment_score,
    increment_index,
)

cache_dir = "/tmp/hf_cache"
os.makedirs(cache_dir, exist_ok=True)

T5_MODEL_ID = os.environ.get("T5_MODEL_ID", "google/flan-t5-base")  # Upgraded to base for better quality

_device = torch.device("cpu")
_t5_tokenizer = None
_t5_model = None


def get_device():
    return _device


def get_t5_tokenizer_and_model():
    """Lazily load the T5 tokenizer and model, caching them in module globals."""
    global _t5_tokenizer, _t5_model
    if _t5_tokenizer is not None and _t5_model is not None:
        return _t5_tokenizer, _t5_model
    last_tb = ""
    try:
        from transformers import T5Tokenizer, T5ForConditionalGeneration

        _t5_tokenizer = T5Tokenizer.from_pretrained(T5_MODEL_ID, cache_dir=cache_dir)
        _t5_model = T5ForConditionalGeneration.from_pretrained(T5_MODEL_ID, cache_dir=cache_dir)
        _t5_model.to(get_device())
        _t5_model.eval()
        return _t5_tokenizer, _t5_model
    except Exception:
        last_tb = traceback.format_exc()
        msg = (
            f"Failed to load tokenizer/model '{T5_MODEL_ID}'. "
            f"Ensure cache_dir '{cache_dir}' is writable and the model name is correct.\n"
            f"Original error:\n{last_tb}"
        )
        raise RuntimeError(msg)


_prompt_examples = (
    "Example:\n"
    "Q: What is the capital of France? | Paris | Lyon | Marseille | Nice | Paris\n\n"
    "Q: Which planet is known as the Red Planet? | Venus | Mars | Jupiter | Saturn | Mars\n\n"
    "Q: Who wrote Romeo and Juliet? | Shakespeare | Dickens | Tolkien | Austen | Shakespeare\n\n"
)


def _parse_generated(text):
    """Parse model output into (question, options, answer).

    Primary format: "Question? | option1 | option2 | option3 | option4 | answer",
    with fallbacks for line-based or partial outputs.
    """
    text = text.strip()
    if "|" in text:
        parts = [p.strip() for p in text.split("|")]
        if len(parts) >= 6:
            question = parts[0]
            opts = parts[1:5]
            answer = parts[5]
            return question, opts, answer
        if len(parts) >= 2:
            question = parts[0]
            opts = parts[1:-1] if len(parts) > 2 else parts[1:]
            answer = parts[-1] if len(parts) > 2 else (opts[0] if opts else "")
            return question, opts, answer
    # Fallback: treat the first line as the question and later lines as options/answer.
    lines = [l.strip() for l in text.splitlines() if l.strip()]
    if lines:
        first = lines[0]
        if "?" in first:
            question = first
            opts = []
            answer = ""
            for l in lines[1:]:
                if l.lower().startswith("answer:"):
                    answer = l.split(":", 1)[1].strip()
                else:
                    cleaned = l.strip().lstrip("-").lstrip("0123456789. ").strip()
                    if cleaned:
                        opts.append(cleaned)
            if not opts and "|" in first:
                parts = [p.strip() for p in first.split("|")]
                question = parts[0]
                opts = parts[1:]
            return question, opts, answer
    return text, [], ""


def generate_questions(topic, n_questions=3, difficulty="medium"):
    """Build quiz_data for a topic by generating one question per sampled keyword."""
    reset_user_answer()
    quiz_data.clear()
    score[0] = 0
    current_question_index[0] = 0
    try:
        tokenizer, model = get_t5_tokenizer_and_model()
        context = fetch_context(topic)
        answers = extract_keywords(context, top_k=max(n_questions, 8))
        # Basic stop-word filter to improve keyword quality (expand in wikipedia_utils if needed)
        stop_words = {
            "a", "an", "the", "in", "on", "at", "to", "of", "and", "or",
            "for", "with", "is", "it", "that", "this",
        }
        answers = [kw for kw in answers if kw.lower() not in stop_words and len(kw) > 2]
        if not answers:
            return []
        k = min(n_questions, len(answers))
        sampled_answers = random.sample(answers, k=k)

        # Adjust prompt based on difficulty
        diff_prompt = {
            "easy": "Generate a simple multiple-choice question for beginners.",
            "medium": "Generate a standard multiple-choice question.",
            "hard": "Generate a challenging multiple-choice question with subtle distractors.",
        }.get(difficulty, "Generate a multiple-choice question.")

        for ans in sampled_answers:
            truncated_context = context if len(context) <= 2000 else context[:2000]
            prompt = (
                _prompt_examples
                + f"Now {diff_prompt} The correct answer must be '{ans}'. Use this exact format:\n"
                "Question? | option1 | option2 | option3 | option4 | answer\n"
                "Make the question diverse and directly related to the topic. Do not use 'passage' or generic questions.\n\n"
                f"Context: {truncated_context}\n"
                f"Correct Answer: {ans}\n"
            )
            inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
            inputs = {k: v.to(get_device()) for k, v in inputs.items()}
            outputs = model.generate(
                **inputs,
                max_new_tokens=150,
                num_beams=5,  # Increased for better quality
                temperature=0.8,  # Slight increase for diversity
                top_p=0.95,
                no_repeat_ngram_size=3,
                early_stopping=True,
            )
            decoded = tokenizer.decode(outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
            question, opts, answer = _parse_generated(decoded)
            if not question.strip().endswith("?"):
                question += "?"  # Ensure it's a question
            if not opts:
                # No options parsed: fall back to the keyword plus random distractors.
                distractors_pool = [a for a in answers if a != ans]
                distractors = random.sample(distractors_pool, k=min(3, len(distractors_pool))) if distractors_pool else []
                opts = [ans] + distractors
                question = f"What is a key concept in {topic} related to '{ans}'?"  # Better fallback question
            if answer == "" and opts:
                answer = opts[0] if ans not in opts else ans
            if answer not in opts:
                if ans in opts:
                    answer = ans
                else:
                    if len(opts) < 4:
                        extra = [a for a in answers if a not in opts and a != ans]
                        for e in extra[:4 - len(opts)]:
                            opts.append(e)
                    if ans not in opts:
                        opts[0] = ans
                    answer = ans
            # Normalize to exactly four options, padding with the answer if needed.
            opts = opts[:4] if len(opts) >= 4 else (opts + [ans] * (4 - len(opts)))
            random.shuffle(opts)
            quiz_data.append({
                "question": question,
                "options": opts,
                "answer": answer,
            })
        return quiz_data
    except Exception:
        tb = traceback.format_exc()
        raise RuntimeError(tb)


def get_question_ui():
    reset_user_answer()
    if not quiz_data:
        return None
    q = quiz_data[current_question_index[0]]
    question_display = f"### Question {current_question_index[0] + 1}\n{q['question']}\n\n**Choose your answer:**"
    return question_display, q["options"]


def on_select(option):
    if option is None:
        return "", None, False
    set_user_answer(option)
    correct = quiz_data[current_question_index[0]]["answer"]
    feedback = "✅ Correct!" if option == correct else f"❌ Incorrect. Correct answer: {correct}"
    is_last = current_question_index[0] == len(quiz_data) - 1
    next_label = "View Score" if is_last else "Next Question"
    return feedback, False, next_label


def next_question():
    if user_answer[0] is None:
        return None
    if user_answer[0] == quiz_data[current_question_index[0]]["answer"]:
        increment_score()
    increment_index()
    reset_user_answer()
    if current_question_index[0] >= len(quiz_data):
        return f"🎉 Quiz finished! Your score: {score[0]}/{len(quiz_data)}"
    else:
        return get_question_ui()
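

# Usage sketch (assumption: this __main__ block is illustrative only and not part of the
# application wiring, which normally drives these functions from a UI layer). It shows the
# intended call order: generate_questions -> get_question_ui -> on_select -> next_question.
# The topic string is a placeholder; running it requires network access for fetch_context
# and for the first download of the T5 model.
if __name__ == "__main__":
    questions = generate_questions("Alan Turing", n_questions=2, difficulty="easy")
    while questions and current_question_index[0] < len(quiz_data):
        display, options = get_question_ui()
        print(display)
        for i, opt in enumerate(options, start=1):
            print(f"{i}. {opt}")
        choice = options[int(input("Pick an option number: ")) - 1]
        feedback, _, _ = on_select(choice)
        print(feedback)
        result = next_question()
        if isinstance(result, str):
            # next_question returns the final score message once the quiz is finished.
            print(result)
            break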