import os
import random
import re
import traceback
import torch
from quiz_logic.wikipedia_utils import fetch_context, extract_keywords
from quiz_logic.state import (
    quiz_data,
    current_question_index,
    score,
    user_answer,
    set_user_answer,
    reset_user_answer,
    increment_score,
    increment_index,
)
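
# quiz_logic.state holds score, current_question_index and user_answer in
# mutable containers indexed at [0], so the handlers below can update them
# in place across module boundaries.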

cache_dir = "/tmp/hf_cache"
os.makedirs(cache_dir, exist_ok=True)

T5_MODEL_ID = os.environ.get("T5_MODEL_ID", "google/flan-t5-base")  # flan-t5-base: better question quality than -small at a modest size cost

_device = torch.device("cpu")
_t5_tokenizer = None
_t5_model = None

def get_device():
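    """Return the torch.device used for inference (CPU in this deployment)."""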
    return _device

def get_t5_tokenizer_and_model():
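    """Lazily load and cache the T5 tokenizer and model as module singletons.

    The first call downloads weights into cache_dir; later calls reuse the
    cached objects. Raises RuntimeError with the original traceback on failure.
    """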
    global _t5_tokenizer, _t5_model
    if _t5_tokenizer is not None and _t5_model is not None:
        return _t5_tokenizer, _t5_model
    try:
        from transformers import T5Tokenizer, T5ForConditionalGeneration
        _t5_tokenizer = T5Tokenizer.from_pretrained(T5_MODEL_ID, cache_dir=cache_dir)
        _t5_model = T5ForConditionalGeneration.from_pretrained(T5_MODEL_ID, cache_dir=cache_dir)
        _t5_model.to(get_device())
        _t5_model.eval()
        return _t5_tokenizer, _t5_model
    except Exception as exc:
        msg = (
            f"Failed to load tokenizer/model '{T5_MODEL_ID}'. Ensure cache_dir "
            f"'{cache_dir}' is writable and the model name is correct.\n"
            f"Original error:\n{traceback.format_exc()}"
        )
        raise RuntimeError(msg) from exc

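# Few-shot examples that teach the model the pipe-delimited format expected
# by _parse_generated below.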
_prompt_examples = (
    "Example:\n"
    "Q: What is the capital of France? | Paris | Lyon | Marseille | Nice | Paris\n\n"
    "Q: Which planet is known as the Red Planet? | Venus | Mars | Jupiter | Saturn | Mars\n\n"
    "Q: Who wrote Romeo and Juliet? | Shakespeare | Dickens | Tolkien | Austen | Shakespeare\n\n"
)

def _parse_generated(text):
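    """Parse raw model output into a (question, options, answer) triple.

    Prefers the pipe-delimited target format
    ("Question? | opt1 | opt2 | opt3 | opt4 | answer"), falling back to a
    line-based layout with an "Answer:" marker. Returns empty options and
    answer when nothing recognizable is found.
    """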
    text = text.strip()
    if "|" in text:
        parts = [p.strip() for p in text.split("|")]
        if len(parts) >= 6:
            question = parts[0]
            opts = parts[1:5]
            answer = parts[5]
            return question, opts, answer
        if len(parts) >= 2:
            question = parts[0]
            opts = parts[1:-1] if len(parts) > 2 else parts[1:]
            answer = parts[-1] if len(parts) > 2 else (opts[0] if opts else "")
            return question, opts, answer
    lines = [l.strip() for l in text.splitlines() if l.strip()]
    if lines:
        first = lines[0]
        if "?" in first:
            question = first
            opts = []
            answer = ""
            for l in lines[1:]:
                if l.lower().startswith("answer:"):
                    answer = l.split(":", 1)[1].strip()
                else:
                    # Strip list markers ("-", "*", "1.", "2)") without eating
                    # leading characters of the option itself (e.g. "3D printing").
                    cleaned = re.sub(r"^\s*(?:[-*•]|\d+[.)])\s*", "", l).strip()
                    if cleaned:
                        opts.append(cleaned)
            return question, opts, answer
    return text, [], ""

def generate_questions(topic, n_questions=3, difficulty="medium"):
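    """Build an n-question multiple-choice quiz about `topic` into quiz_data.

    Keywords extracted from the fetched Wikipedia context become the correct
    answers; the T5 model writes one question per sampled keyword, with
    `difficulty` steering the prompt wording. Returns quiz_data (empty if no
    usable keywords were found).
    """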
    reset_user_answer()
    quiz_data.clear()
    score[0] = 0
    current_question_index[0] = 0
    try:
        tokenizer, model = get_t5_tokenizer_and_model()
        context = fetch_context(topic)
        answers = extract_keywords(context, top_k=max(n_questions, 8))
        
        # Basic stop-word filter to improve keyword quality (expand in wikipedia_utils if needed)
        stop_words = {"a", "an", "the", "in", "on", "at", "to", "of", "and", "or", "for", "with", "is", "it", "that", "this"}
        answers = [kw for kw in answers if kw.lower() not in stop_words and len(kw) > 2]
        
        if not answers:
            return []
        k = min(n_questions, len(answers))
        sampled_answers = random.sample(answers, k=k)
        
        # Adjust prompt based on difficulty
        diff_prompt = {
            "easy": "Generate a simple multiple-choice question for beginners.",
            "medium": "Generate a standard multiple-choice question.",
            "hard": "Generate a challenging multiple-choice question with subtle distractors."
        }.get(difficulty, "Generate a multiple-choice question.")
        
        for ans in sampled_answers:
            truncated_context = context[:2000]  # Slicing is a no-op when the context is already short
            prompt = (
                _prompt_examples
                + f"Now {diff_prompt} The correct answer must be '{ans}'. Use this exact format:\n"
                "Question? | option1 | option2 | option3 | option4 | answer\n"
                "Make the question diverse and directly related to the topic. Do not use 'passage' or generic questions.\n\n"
                f"Context: {truncated_context}\n"
                f"Correct Answer: {ans}\n"
            )
            inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
            inputs = {k: v.to(get_device()) for k, v in inputs.items()}
            outputs = model.generate(
                **inputs,
                max_new_tokens=150,
                num_beams=5,
                do_sample=True,  # Without this, temperature/top_p below are silently ignored
                temperature=0.8,
                top_p=0.95,
                no_repeat_ngram_size=3,
                early_stopping=True,
            )
            decoded = tokenizer.decode(outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
            question, opts, answer = _parse_generated(decoded)
            question = question.strip()
            if not question.endswith("?"):
                question = question.rstrip(".!") + "?"  # Normalize into question form
            if not opts:
                distractors_pool = [a for a in answers if a != ans]
                distractors = random.sample(distractors_pool, k=min(3, len(distractors_pool))) if distractors_pool else []
                opts = [ans] + distractors
                question = f"What is a key concept in {topic} related to '{ans}'?"  # Better fallback question
            if answer == "" and opts:
                answer = opts[0] if ans not in opts else ans
            if answer not in opts:
                if ans in opts:
                    answer = ans
                else:
                    if len(opts) < 4:
                        extra = [a for a in answers if a not in opts and a != ans]
                        for e in extra[:4 - len(opts)]:
                            opts.append(e)
                    if ans not in opts:
                        opts[0] = ans
                    answer = ans
            # Deduplicate (the model may repeat an option), then top up to four,
            # preferring unused keywords; pad with the answer only as a last resort.
            opts = list(dict.fromkeys(opts))
            filler = [a for a in answers if a not in opts] + [ans] * 4
            opts = (opts + filler)[:4]
            if answer not in opts:
                opts[-1] = answer  # Keep the correct answer after truncation
            random.shuffle(opts)
            quiz_data.append({
                "question": question,
                "options": opts,
                "answer": answer
            })
        return quiz_data
    except Exception as exc:
        # Re-raise with the full traceback so the UI layer can surface it.
        raise RuntimeError(traceback.format_exc()) from exc

def get_question_ui():
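    """Return (question_markdown, options) for the current question, or None if no quiz is loaded."""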
    reset_user_answer()
    if not quiz_data:
        return None
    q = quiz_data[current_question_index[0]]
    question_display = f"### Question {current_question_index[0] + 1}\n{q['question']}\n\n**Choose your answer:**"
    return question_display, q["options"]

def on_select(option):
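    """Record the chosen option and return a (feedback, flag, next_button_label) tuple for the UI."""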
    if option is None:
        # Mirror the tuple shape of the answered branch: (feedback, flag, next label).
        return "", False, "Next Question"
    set_user_answer(option)
    correct = quiz_data[current_question_index[0]]["answer"]
    feedback = "✅ Correct!" if option == correct else f"❌ Incorrect. Correct answer: {correct}"
    is_last = current_question_index[0] == len(quiz_data) - 1
    next_label = "View Score" if is_last else "Next Question"
    return feedback, False, next_label

def next_question():
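    """Score the pending answer, advance, and return the next question UI or the final score message."""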
    if user_answer[0] is None:
        return None
    if user_answer[0] == quiz_data[current_question_index[0]]["answer"]:
        increment_score()
    increment_index()
    reset_user_answer()
    if current_question_index[0] >= len(quiz_data):
        return f"🎉 Quiz finished! Your score: {score[0]}/{len(quiz_data)}"
    else:
        return get_question_ui()
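
if __name__ == "__main__":
    # Minimal manual smoke test: a sketch, not part of the app wiring.
    # Assumes network access for the first model download and that
    # quiz_logic.wikipedia_utils can reach Wikipedia for the context fetch.
    for item in generate_questions("Photosynthesis", n_questions=2):
        print(item["question"])
        for opt in item["options"]:
            print("  -", opt)
        print("Answer:", item["answer"], "\n")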