# NOTE(review): stray "Spaces: Sleeping" Hugging Face Spaces status text was
# accidentally pasted at the top of this module; commented out so the file parses.
import os | |
import random | |
import traceback | |
import torch | |
from quiz_logic.wikipedia_utils import fetch_context, extract_keywords | |
from quiz_logic.state import ( | |
quiz_data, | |
current_question_index, | |
score, | |
user_answer, | |
set_user_answer, | |
reset_user_answer, | |
increment_score, | |
increment_index, | |
) | |
cache_dir = "/tmp/hf_cache" | |
os.makedirs(cache_dir, exist_ok=True) | |
T5_MODEL_ID = os.environ.get("T5_MODEL_ID", "google/flan-t5-base") # Upgraded to base for better quality | |
_device = torch.device("cpu") | |
_t5_tokenizer = None | |
_t5_model = None | |
def get_device():
    """Return the torch device used for inference (always the module-level CPU device)."""
    return _device
def get_t5_tokenizer_and_model():
    """Lazily load, cache, and return the (tokenizer, model) pair.

    The pair is loaded once into the module-level singletons and reused on
    every subsequent call.  The model is moved to the CPU device and put in
    eval mode.

    Raises:
        RuntimeError: if loading fails, with the original traceback embedded
            in the message.
    """
    global _t5_tokenizer, _t5_model
    if _t5_tokenizer is None or _t5_model is None:
        try:
            # Imported lazily so the module can be imported without transformers
            # being available until a quiz is actually generated.
            from transformers import T5Tokenizer, T5ForConditionalGeneration

            _t5_tokenizer = T5Tokenizer.from_pretrained(T5_MODEL_ID, cache_dir=cache_dir)
            _t5_model = T5ForConditionalGeneration.from_pretrained(T5_MODEL_ID, cache_dir=cache_dir)
            _t5_model.to(get_device())
            _t5_model.eval()
        except Exception:
            detail = traceback.format_exc()
            raise RuntimeError(
                f"Failed to load tokenizer/model '{T5_MODEL_ID}'. Ensure cache_dir '{cache_dir}' is writable and model name is correct.\n"
                f"Original error:\n{detail}"
            )
    return _t5_tokenizer, _t5_model
_prompt_examples = ( | |
"Example:\n" | |
"Q: What is the capital of France? | Paris | Lyon | Marseille | Nice | Paris\n\n" | |
"Q: Which planet is known as the Red Planet? | Venus | Mars | Jupiter | Saturn | Mars\n\n" | |
"Q: Who wrote Romeo and Juliet? | Shakespeare | Dickens | Tolkien | Austen | Shakespeare\n\n" | |
) | |
def _parse_generated(text): | |
text = text.strip() | |
if "|" in text: | |
parts = [p.strip() for p in text.split("|")] | |
if len(parts) >= 6: | |
question = parts[0] | |
opts = parts[1:5] | |
answer = parts[5] | |
return question, opts, answer | |
if len(parts) >= 2: | |
question = parts[0] | |
opts = parts[1:-1] if len(parts) > 2 else parts[1:] | |
answer = parts[-1] if len(parts) > 2 else (opts[0] if opts else "") | |
return question, opts, answer | |
lines = [l.strip() for l in text.splitlines() if l.strip()] | |
if lines: | |
first = lines[0] | |
if "?" in first: | |
question = first | |
opts = [] | |
answer = "" | |
for l in lines[1:]: | |
if l.lower().startswith("answer:"): | |
answer = l.split(":", 1)[1].strip() | |
else: | |
cleaned = l.strip().lstrip("-").lstrip("0123456789. ").strip() | |
if cleaned: | |
opts.append(cleaned) | |
if not opts and "|" in first: | |
parts = [p.strip() for p in first.split("|")] | |
question = parts[0] | |
opts = parts[1:] | |
return question, opts, answer | |
return text, [], "" | |
def generate_questions(topic, n_questions=3, difficulty="medium"):
    """Generate a multiple-choice quiz about *topic* and store it in quiz_data.

    Resets the shared quiz state, fetches context (via fetch_context), extracts
    candidate answer keywords, and prompts the T5 model to write one question
    per sampled keyword in the "Q? | o1 | o2 | o3 | o4 | answer" format.

    Args:
        topic: Subject to quiz on.
        n_questions: Number of questions to attempt (capped by usable keywords).
        difficulty: "easy" | "medium" | "hard"; any other value falls back to
            a generic prompt.

    Returns:
        The populated quiz_data list; empty list when no keywords survive
        filtering.

    Raises:
        RuntimeError: wrapping any underlying failure (model load, fetch,
            generation) with the original traceback in the message.
    """
    reset_user_answer()
    quiz_data.clear()
    score[0] = 0
    current_question_index[0] = 0
    try:
        tokenizer, model = get_t5_tokenizer_and_model()
        context = fetch_context(topic)
        answers = extract_keywords(context, top_k=max(n_questions, 8))
        # Basic stop-word filter to improve keyword quality (expand in
        # wikipedia_utils if needed).
        stop_words = {"a", "an", "the", "in", "on", "at", "to", "of", "and",
                      "or", "for", "with", "is", "it", "that", "this"}
        answers = [kw for kw in answers if kw.lower() not in stop_words and len(kw) > 2]
        if not answers:
            return []
        sampled_answers = random.sample(answers, k=min(n_questions, len(answers)))
        # Adjust prompt based on difficulty.
        diff_prompt = {
            "easy": "Generate a simple multiple-choice question for beginners.",
            "medium": "Generate a standard multiple-choice question.",
            "hard": "Generate a challenging multiple-choice question with subtle distractors."
        }.get(difficulty, "Generate a multiple-choice question.")
        # Hoisted out of the loop: the context is invariant across questions.
        truncated_context = context[:2000]
        for ans in sampled_answers:
            prompt = (
                _prompt_examples
                + f"Now {diff_prompt} The correct answer must be '{ans}'. Use this exact format:\n"
                "Question? | option1 | option2 | option3 | option4 | answer\n"
                "Make the question diverse and directly related to the topic. Do not use 'passage' or generic questions.\n\n"
                f"Context: {truncated_context}\n"
                f"Correct Answer: {ans}\n"
            )
            inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
            inputs = {k: v.to(get_device()) for k, v in inputs.items()}
            # BUGFIX: temperature/top_p were previously passed without
            # do_sample=True, so beam search ignored them (and transformers
            # warns about it); they have been dropped for deterministic beams.
            outputs = model.generate(
                **inputs,
                max_new_tokens=150,
                num_beams=5,
                no_repeat_ngram_size=3,
                early_stopping=True,
            )
            decoded = tokenizer.decode(outputs[0], skip_special_tokens=True,
                                       clean_up_tokenization_spaces=True)
            question, opts, answer = _parse_generated(decoded)
            if not question.strip().endswith("?"):
                question += "?"  # Ensure it's phrased as a question.
            if not opts:
                # Model gave no options: build them from the keyword pool.
                pool = [a for a in answers if a != ans]
                opts = [ans] + (random.sample(pool, k=min(3, len(pool))) if pool else [])
                question = f"What is a key concept in {topic} related to '{ans}'?"
            if answer == "" and opts:
                answer = ans if ans in opts else opts[0]
            if answer not in opts:
                # Prefer the intended keyword when the model drifted.
                answer = ans
                if ans not in opts:
                    opts.insert(0, ans)
            # BUGFIX: build exactly four *unique* options with the answer
            # guaranteed to survive.  The old code truncated with opts[:4]
            # (which could drop the answer) and padded with [ans] * k
            # (which duplicated the correct answer as distractors).
            final_opts = [answer]
            for o in opts:
                if o not in final_opts and len(final_opts) < 4:
                    final_opts.append(o)
            for extra in answers:
                if len(final_opts) >= 4:
                    break
                if extra not in final_opts:
                    final_opts.append(extra)
            # Last resort when there aren't enough distinct keywords.
            generic_fill = ["None of the above", "All of the above", "Not mentioned in the context"]
            for filler in generic_fill:
                if len(final_opts) >= 4:
                    break
                if filler not in final_opts:
                    final_opts.append(filler)
            opts = final_opts
            random.shuffle(opts)
            quiz_data.append({
                "question": question,
                "options": opts,
                "answer": answer
            })
        return quiz_data
    except Exception as exc:
        # BUGFIX: chain the original exception instead of discarding it, while
        # keeping the full traceback in the message for the UI layer.
        raise RuntimeError(traceback.format_exc()) from exc
def get_question_ui():
    """Build the display payload for the current question.

    Clears any previously selected answer, then returns a
    (markdown_text, options) pair for the question at the current index,
    or None when no quiz has been generated yet.
    """
    reset_user_answer()
    if not quiz_data:
        return None
    entry = quiz_data[current_question_index[0]]
    number = current_question_index[0] + 1
    markdown = f"### Question {number}\n{entry['question']}\n\n**Choose your answer:**"
    return markdown, entry["options"]
def on_select(option):
    """Record the user's selected option and build the feedback strings.

    Returns (feedback, False, next_button_label) after a selection.
    NOTE(review): the early-exit path returns ("", None, False) — a different
    shape in positions 2 and 3 than the normal path; presumably consumed by
    Gradio component updates — confirm against the UI wiring.
    """
    if option is None:
        return "", None, False
    set_user_answer(option)
    correct = quiz_data[current_question_index[0]]["answer"]
    if option == correct:
        feedback = "✅ Correct!"
    else:
        feedback = f"❌ Incorrect. Correct answer: {correct}"
    on_last_question = current_question_index[0] == len(quiz_data) - 1
    next_label = "View Score" if on_last_question else "Next Question"
    return feedback, False, next_label
def next_question():
    """Advance the quiz by one question.

    Returns None when nothing has been answered yet, the finish message
    string after the last question, or the next question's UI payload
    (from get_question_ui) otherwise.  Scores the just-answered question
    before advancing.
    """
    if user_answer[0] is None:
        return None
    current = quiz_data[current_question_index[0]]
    if user_answer[0] == current["answer"]:
        increment_score()
    increment_index()
    reset_user_answer()
    if current_question_index[0] < len(quiz_data):
        return get_question_ui()
    return f"🎉 Quiz finished! Your score: {score[0]}/{len(quiz_data)}"