# smart-quiz-api / quiz_logic /generator.py
# NZLouislu's picture
# Update quiz_logic/generator.py
# 52616f0 verified
import os
import random
import traceback
import torch
from quiz_logic.wikipedia_utils import fetch_context, extract_keywords
from quiz_logic.state import (
quiz_data,
current_question_index,
score,
user_answer,
set_user_answer,
reset_user_answer,
increment_score,
increment_index,
)
# Directory passed as ``cache_dir`` when the tokenizer/model are loaded;
# created at import time if it does not exist yet.
cache_dir = "/tmp/hf_cache"
os.makedirs(cache_dir, exist_ok=True)

# Checkpoint to load. The T5_MODEL_ID environment variable overrides the
# default, which was upgraded to the "base" size for better quality.
T5_MODEL_ID = os.getenv("T5_MODEL_ID", "google/flan-t5-base")

# Device returned by get_device(); fixed to CPU.
_device = torch.device("cpu")

# Filled in on first use by get_t5_tokenizer_and_model().
_t5_tokenizer = _t5_model = None
def get_device():
    """Return the module-level ``torch.device`` stored in ``_device``."""
    return _device
def get_t5_tokenizer_and_model():
    """Return the shared ``(tokenizer, model)`` pair, loading it on first use.

    The pair is cached in the module globals ``_t5_tokenizer`` and
    ``_t5_model``. The globals are only assigned once loading, the move to
    ``get_device()`` and the switch to eval mode have all succeeded, so a
    failed attempt never leaves a half-initialised model cached; the next
    call simply retries.

    Returns:
        A ``(tokenizer, model)`` tuple.

    Raises:
        RuntimeError: If the tokenizer or model cannot be loaded. The message
            contains the original traceback and the original exception is
            attached as ``__cause__``.
    """
    global _t5_tokenizer, _t5_model
    if _t5_tokenizer is not None and _t5_model is not None:
        return _t5_tokenizer, _t5_model

    try:
        # Imported here rather than at module level, so transformers is only
        # needed once a model is actually requested.
        from transformers import T5Tokenizer, T5ForConditionalGeneration

        tokenizer = T5Tokenizer.from_pretrained(T5_MODEL_ID, cache_dir=cache_dir)
        model = T5ForConditionalGeneration.from_pretrained(T5_MODEL_ID, cache_dir=cache_dir)
        model.to(get_device())
        model.eval()
    except Exception as exc:
        last_tb = traceback.format_exc()
        msg = (
            f"Failed to load tokenizer/model '{T5_MODEL_ID}'. Ensure cache_dir '{cache_dir}' is writable and model name is correct.\n"
            f"Original error:\n{last_tb}"
        )
        raise RuntimeError(msg) from exc

    # Publish only after everything above succeeded.
    _t5_tokenizer, _t5_model = tokenizer, model
    return _t5_tokenizer, _t5_model
# Few-shot examples prepended to every prompt built in generate_questions().
# Each example line has the shape "Q: <question> | <four options> | <answer>".
_prompt_examples = (
"Example:\n"
"Q: What is the capital of France? | Paris | Lyon | Marseille | Nice | Paris\n\n"
"Q: Which planet is known as the Red Planet? | Venus | Mars | Jupiter | Saturn | Mars\n\n"
"Q: Who wrote Romeo and Juliet? | Shakespeare | Dickens | Tolkien | Austen | Shakespeare\n\n"
)
def _parse_generated(text):
text = text.strip()
if "|" in text:
parts = [p.strip() for p in text.split("|")]
if len(parts) >= 6:
question = parts[0]
opts = parts[1:5]
answer = parts[5]
return question, opts, answer
if len(parts) >= 2:
question = parts[0]
opts = parts[1:-1] if len(parts) > 2 else parts[1:]
answer = parts[-1] if len(parts) > 2 else (opts[0] if opts else "")
return question, opts, answer
lines = [l.strip() for l in text.splitlines() if l.strip()]
if lines:
first = lines[0]
if "?" in first:
question = first
opts = []
answer = ""
for l in lines[1:]:
if l.lower().startswith("answer:"):
answer = l.split(":", 1)[1].strip()
else:
cleaned = l.strip().lstrip("-").lstrip("0123456789. ").strip()
if cleaned:
opts.append(cleaned)
if not opts and "|" in first:
parts = [p.strip() for p in first.split("|")]
question = parts[0]
opts = parts[1:]
return question, opts, answer
return text, [], ""
def _finalize_options(ans, opts, answer, keyword_pool, n_options=4):
    """Return ``(options, answer)`` with unique options that contain the answer.

    Args:
        ans: Keyword the question was generated for.
        opts: Candidate options (parsed from the model output or fallback).
        answer: Answer parsed from the model output; may be empty.
        keyword_pool: Keywords used to top the list up to ``n_options``.
        n_options: Maximum number of options to return.

    Returns:
        ``(options, answer)``. ``options`` has no duplicates or blanks, always
        includes ``answer`` and holds at most ``n_options`` entries (fewer
        only when ``keyword_pool`` cannot supply enough distinct fillers).
    """
    # dict.fromkeys removes duplicates while keeping the original order.
    options = list(dict.fromkeys(opt for opt in opts if opt))

    if not answer:
        # No answer parsed: prefer the target keyword when it is on offer,
        # otherwise treat the first option as the answer.
        answer = ans if (ans in options or not options) else options[0]
    elif answer not in options:
        # The parsed answer is not among the options; use the target keyword.
        answer = ans

    # Place the answer first so truncation can never drop it; the caller
    # shuffles afterwards.
    distractors = [opt for opt in options if opt != answer]
    options = [answer] + distractors[: n_options - 1]

    # Top up with unused keywords instead of repeating the answer.
    for keyword in keyword_pool:
        if len(options) >= n_options:
            break
        if keyword not in options:
            options.append(keyword)
    return options, answer


def generate_questions(topic, n_questions=3, difficulty="medium"):
    """Generate multiple-choice questions about ``topic`` into ``quiz_data``.

    Resets the shared quiz state (user answer, ``quiz_data``, ``score[0]``,
    ``current_question_index[0]``), fetches context for the topic, extracts
    keywords, and asks the T5 model for one question per sampled keyword.
    When the model output has no usable options (or no question text), a
    generic question and/or keyword-based options are used instead.

    Args:
        topic: Topic passed to ``fetch_context``.
        n_questions: Maximum number of questions to generate.
        difficulty: ``"easy"``, ``"medium"`` or ``"hard"``; any other value
            selects a neutral instruction.

    Returns:
        The shared ``quiz_data`` list of dicts with keys ``"question"``,
        ``"options"`` (up to four unique strings, shuffled) and ``"answer"``;
        or a new empty list when no usable keywords were extracted.

    Raises:
        RuntimeError: On any failure; the message is the formatted traceback
            and the original exception is attached as ``__cause__``.
    """
    reset_user_answer()
    quiz_data.clear()
    score[0] = 0
    current_question_index[0] = 0
    try:
        tokenizer, model = get_t5_tokenizer_and_model()
        context = fetch_context(topic)
        answers = extract_keywords(context, top_k=max(n_questions, 8))
        # Basic stop-word filter to improve keyword quality (expand in wikipedia_utils if needed)
        stop_words = {"a", "an", "the", "in", "on", "at", "to", "of", "and", "or", "for", "with", "is", "it", "that", "this"}
        # dict.fromkeys also removes repeated keywords so the same answer is
        # not sampled twice.
        answers = list(dict.fromkeys(
            kw for kw in answers if kw.lower() not in stop_words and len(kw) > 2
        ))
        if not answers:
            return []

        # max(0, ...) keeps random.sample from raising on a negative n_questions.
        sample_size = max(0, min(n_questions, len(answers)))
        sampled_answers = random.sample(answers, k=sample_size)

        # Adjust prompt based on difficulty
        diff_prompt = {
            "easy": "Generate a simple multiple-choice question for beginners.",
            "medium": "Generate a standard multiple-choice question.",
            "hard": "Generate a challenging multiple-choice question with subtle distractors.",
        }.get(difficulty, "Generate a multiple-choice question.")

        # Loop-invariant values, computed once.
        truncated_context = context[:2000]
        device = get_device()

        for ans in sampled_answers:
            fallback_question = f"What is a key concept in {topic} related to '{ans}'?"
            prompt = (
                _prompt_examples
                + f"Now {diff_prompt} The correct answer must be '{ans}'. Use this exact format:\n"
                "Question? | option1 | option2 | option3 | option4 | answer\n"
                "Make the question diverse and directly related to the topic. Do not use 'passage' or generic questions.\n\n"
                f"Context: {truncated_context}\n"
                f"Correct Answer: {ans}\n"
            )
            inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
            inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
            # NOTE: temperature/top_p were dropped. generate() only applies
            # them when do_sample=True, so with plain beam search they had no
            # effect. Re-add them together with do_sample=True if sampling
            # diversity is wanted.
            outputs = model.generate(
                **inputs,
                max_new_tokens=150,
                num_beams=5,  # Increased for better quality
                no_repeat_ngram_size=3,
                early_stopping=True,
            )
            decoded = tokenizer.decode(outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
            question, opts, answer = _parse_generated(decoded)

            question = question.strip()
            opts = [opt for opt in opts if opt]
            if not opts:
                # Nothing usable from the model: build options from keywords.
                distractors_pool = [a for a in answers if a != ans]
                opts = [ans] + random.sample(distractors_pool, k=min(3, len(distractors_pool)))
                question = fallback_question
            elif not question:
                question = fallback_question
            elif not question.endswith("?"):
                question += "?"  # Ensure it's a question

            opts, answer = _finalize_options(ans, opts, answer, answers)
            random.shuffle(opts)
            quiz_data.append({
                "question": question,
                "options": opts,
                "answer": answer,
            })
        return quiz_data
    except Exception as exc:
        tb = traceback.format_exc()
        raise RuntimeError(tb) from exc
def get_question_ui():
    """Return the display text and options for the current question.

    Clears the stored user answer first.

    Returns:
        ``(question_display, options)`` where ``question_display`` is a
        Markdown string and ``options`` is the question's option list, or
        ``None`` when there is nothing to show: either the quiz is empty or
        the current index is past the last question (the state
        ``next_question`` checks for when the quiz finishes).
    """
    reset_user_answer()
    index = current_question_index[0]
    # Guard the lookup so a call after the final question returns None
    # instead of raising IndexError.
    if not quiz_data or index >= len(quiz_data):
        return None
    q = quiz_data[index]
    question_display = f"### Question {index + 1}\n{q['question']}\n\n**Choose your answer:**"
    return question_display, q["options"]
def on_select(option):
    """Record the chosen option and build feedback for the current question.

    Args:
        option: The selected option, or ``None`` when nothing is selected.

    Returns:
        ``("", None, False)`` when ``option`` is ``None``; otherwise
        ``(feedback, False, next_label)`` where ``next_label`` is
        ``"View Score"`` on the last question and ``"Next Question"``
        otherwise.

    NOTE(review): the two return paths put different kinds of values in the
    second and third positions - confirm against the UI wiring.
    """
    if option is None:
        return "", None, False

    set_user_answer(option)
    position = current_question_index[0]
    expected = quiz_data[position]["answer"]

    if option == expected:
        feedback = "✅ Correct!"
    else:
        feedback = f"❌ Incorrect. Correct answer: {expected}"

    if position == len(quiz_data) - 1:
        next_label = "View Score"
    else:
        next_label = "Next Question"
    return feedback, False, next_label
def next_question():
    """Score the stored answer and move on to the next question.

    Returns:
        ``None`` when no answer has been stored yet; the final score message
        once the index has reached the end of ``quiz_data``; otherwise
        whatever ``get_question_ui()`` returns for the next question.
    """
    chosen = user_answer[0]
    if chosen is None:
        return None

    current = quiz_data[current_question_index[0]]
    if chosen == current["answer"]:
        increment_score()

    increment_index()
    reset_user_answer()

    if current_question_index[0] < len(quiz_data):
        return get_question_ui()
    return f"🎉 Quiz finished! Your score: {score[0]}/{len(quiz_data)}"