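"""Quiz generation logic for the Wikipedia quiz Space.

Lazily loads a FLAN-T5 model, turns Wikipedia context for a topic into
multiple-choice questions, and exposes the callbacks that drive the
question/answer flow through the shared state in quiz_logic.state.
"""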
import os
import random
import traceback
import torch
from quiz_logic.wikipedia_utils import fetch_context, extract_keywords
from quiz_logic.state import (
    quiz_data,
    current_question_index,
    score,
    user_answer,
    set_user_answer,
    reset_user_answer,
    increment_score,
    increment_index,
)

cache_dir = "/tmp/hf_cache"
os.makedirs(cache_dir, exist_ok=True)
T5_MODEL_ID = os.environ.get("T5_MODEL_ID", "google/flan-t5-base") # Upgraded to base for better quality
_device = torch.device("cpu")
_t5_tokenizer = None
_t5_model = None

def get_device():
    """Return the torch device used for inference (CPU-only Space)."""
    return _device

def get_t5_tokenizer_and_model():
    """Lazily load and cache the T5 tokenizer and model (CPU, eval mode)."""
    global _t5_tokenizer, _t5_model
    if _t5_tokenizer is not None and _t5_model is not None:
        return _t5_tokenizer, _t5_model
    try:
        from transformers import T5Tokenizer, T5ForConditionalGeneration
        _t5_tokenizer = T5Tokenizer.from_pretrained(T5_MODEL_ID, cache_dir=cache_dir)
        _t5_model = T5ForConditionalGeneration.from_pretrained(T5_MODEL_ID, cache_dir=cache_dir)
        _t5_model.to(get_device())
        _t5_model.eval()
        return _t5_tokenizer, _t5_model
    except Exception:
        last_tb = traceback.format_exc()
        msg = (
            f"Failed to load tokenizer/model '{T5_MODEL_ID}'. Ensure cache_dir '{cache_dir}' is writable and the model name is correct.\n"
            f"Original error:\n{last_tb}"
        )
        raise RuntimeError(msg)

_prompt_examples = (
    "Example:\n"
    "Q: What is the capital of France? | Paris | Lyon | Marseille | Nice | Paris\n\n"
    "Q: Which planet is known as the Red Planet? | Venus | Mars | Jupiter | Saturn | Mars\n\n"
    "Q: Who wrote Romeo and Juliet? | Shakespeare | Dickens | Tolkien | Austen | Shakespeare\n\n"
)

def _parse_generated(text):
    """Parse model output into (question, options, answer); fall back to line-based parsing."""
    text = text.strip()
    if "|" in text:
        parts = [p.strip() for p in text.split("|")]
        if len(parts) >= 6:
            question = parts[0]
            opts = parts[1:5]
            answer = parts[5]
            return question, opts, answer
        if len(parts) >= 2:
            question = parts[0]
            opts = parts[1:-1] if len(parts) > 2 else parts[1:]
            answer = parts[-1] if len(parts) > 2 else (opts[0] if opts else "")
            return question, opts, answer
    lines = [l.strip() for l in text.splitlines() if l.strip()]
    if lines:
        first = lines[0]
        if "?" in first:
            question = first
            opts = []
            answer = ""
            for l in lines[1:]:
                if l.lower().startswith("answer:"):
                    answer = l.split(":", 1)[1].strip()
                else:
                    # Strip list markers like "- " or "1. " from option lines.
                    cleaned = l.strip().lstrip("-").lstrip("0123456789. ").strip()
                    if cleaned:
                        opts.append(cleaned)
            if not opts and "|" in first:
                parts = [p.strip() for p in first.split("|")]
                question = parts[0]
                opts = parts[1:]
            return question, opts, answer
    return text, [], ""
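
# For reference, the happy-path format from _prompt_examples parses like this:
#   _parse_generated("What is 2+2? | 3 | 4 | 5 | 6 | 4")
#   -> ("What is 2+2?", ["3", "4", "5", "6"], "4")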

def generate_questions(topic, n_questions=3, difficulty="medium"):
    """Build n_questions multiple-choice questions about `topic` from Wikipedia context."""
    reset_user_answer()
    quiz_data.clear()
    score[0] = 0
    current_question_index[0] = 0
    try:
        tokenizer, model = get_t5_tokenizer_and_model()
        context = fetch_context(topic)
        answers = extract_keywords(context, top_k=max(n_questions, 8))
        # Basic stop-word filter to improve keyword quality (expand in wikipedia_utils if needed)
        stop_words = {"a", "an", "the", "in", "on", "at", "to", "of", "and", "or", "for", "with", "is", "it", "that", "this"}
        answers = [kw for kw in answers if kw.lower() not in stop_words and len(kw) > 2]
        if not answers:
            return []
        k = min(n_questions, len(answers))
        sampled_answers = random.sample(answers, k=k)
        # Adjust prompt based on difficulty
        diff_prompt = {
            "easy": "Generate a simple multiple-choice question for beginners.",
            "medium": "Generate a standard multiple-choice question.",
            "hard": "Generate a challenging multiple-choice question with subtle distractors."
        }.get(difficulty, "Generate a multiple-choice question.")
        for ans in sampled_answers:
            truncated_context = context if len(context) <= 2000 else context[:2000]
            prompt = (
                _prompt_examples
                + f"Now {diff_prompt} The correct answer must be '{ans}'. Use this exact format:\n"
                "Question? | option1 | option2 | option3 | option4 | answer\n"
                "Make the question diverse and directly related to the topic. Do not use 'passage' or generic questions.\n\n"
                f"Context: {truncated_context}\n"
                f"Correct Answer: {ans}\n"
            )
            inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
            inputs = {k: v.to(get_device()) for k, v in inputs.items()}
            outputs = model.generate(
                **inputs,
                max_new_tokens=150,
                num_beams=5,  # increased for better quality
                do_sample=True,  # required for temperature/top_p to take effect
                temperature=0.8,  # slight increase for diversity
                top_p=0.95,
                no_repeat_ngram_size=3,
                early_stopping=True,
            )
            decoded = tokenizer.decode(outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
            question, opts, answer = _parse_generated(decoded)
            if not question.strip().endswith("?"):
                question += "?"  # ensure it reads as a question
            if not opts:
                # Fallback: build options from extracted keywords when parsing fails.
                distractors_pool = [a for a in answers if a != ans]
                distractors = random.sample(distractors_pool, k=min(3, len(distractors_pool))) if distractors_pool else []
                opts = [ans] + distractors
                question = f"What is a key concept in {topic} related to '{ans}'?"  # better fallback question
            if answer == "" and opts:
                answer = opts[0] if ans not in opts else ans
            if answer not in opts:
                if ans in opts:
                    answer = ans
                else:
                    if len(opts) < 4:
                        extra = [a for a in answers if a not in opts and a != ans]
                        for e in extra[:4 - len(opts)]:
                            opts.append(e)
                    if ans not in opts:
                        opts[0] = ans
                    answer = ans
            # Trim/pad to exactly four options without dropping the correct answer
            # (answer is guaranteed to be in opts at this point).
            opts.remove(answer)
            opts = [answer] + opts[:3]
            while len(opts) < 4:
                opts.append(ans)  # padding may duplicate when distractors run out
            random.shuffle(opts)
            quiz_data.append({
                "question": question,
                "options": opts,
                "answer": answer,
            })
        return quiz_data
    except Exception:
        tb = traceback.format_exc()
        raise RuntimeError(tb)

def get_question_ui():
    """Return (markdown_text, options) for the current question, or None when no quiz is loaded."""
    reset_user_answer()
    if not quiz_data:
        return None
    q = quiz_data[current_question_index[0]]
    question_display = f"### Question {current_question_index[0] + 1}\n{q['question']}\n\n**Choose your answer:**"
    return question_display, q["options"]

def on_select(option):
    """Record the selected option; return (feedback, lock_flag, next_button_label)."""
    if option is None:
        # Nothing selected: keep the same tuple shape as the return below.
        return "", False, "Next Question"
    set_user_answer(option)
    correct = quiz_data[current_question_index[0]]["answer"]
    feedback = "✅ Correct!" if option == correct else f"❌ Incorrect. Correct answer: {correct}"
    is_last = current_question_index[0] == len(quiz_data) - 1
    next_label = "View Score" if is_last else "Next Question"
    return feedback, False, next_label

def next_question():
    """Score the recorded answer, advance the index, and return the next question UI or the final score."""
    if user_answer[0] is None:
        return None
    if user_answer[0] == quiz_data[current_question_index[0]]["answer"]:
        increment_score()
    increment_index()
    reset_user_answer()
    if current_question_index[0] >= len(quiz_data):
        return f"🎉 Quiz finished! Your score: {score[0]}/{len(quiz_data)}"
    else:
        return get_question_ui()
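

# Hypothetical local smoke test, not part of the Space itself (assumptions:
# run outside Gradio, network access for fetch_context, enough memory for
# FLAN-T5; the topic is an arbitrary example). It walks the quiz flow end to
# end by always picking the first option.
if __name__ == "__main__":
    generate_questions("Python (programming language)", n_questions=2)
    while True:
        ui = get_question_ui()
        if ui is None:
            break
        display, options = ui
        print(display)
        feedback, _, _ = on_select(options[0])
        print(feedback)
        result = next_question()
        if isinstance(result, str):  # final score message
            print(result)
            break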