File size: 1,501 Bytes
6c09f76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import json
from dataclasses import dataclass
from datetime import datetime
from typing import Any

from jinja2 import Template


@dataclass
class Question:
    """A single questionnaire question together with the user's answer."""

    id: int  # question identifier; Session indexes questions by this value
    text: str  # the question text
    answer_format: type  # Python type the answer should have (e.g. int, str, list)
    user_answer: Any = None  # None means no answer has been recorded yet


class Session:
    """A questionnaire session: holds questions, records answers, builds prompts.

    Questions are normalized into ``Question`` objects keyed by their id, so
    answers can be recorded by id and the next unanswered question found.
    """

    # Maps the ``answer_format`` labels accepted in input dicts to Python types.
    _ANSWER_FORMATS = {"number": int, "text": str, "list": list}
    # Reverse mapping, used to turn a Question back into JSON-friendly data.
    _FORMAT_LABELS = {py_type: label for label, py_type in _ANSWER_FORMATS.items()}

    def __init__(self, questions):
        """Create a session.

        Args:
            questions: Iterable of question dicts with ``"id"``, ``"text"`` and
                ``"answer_format"`` keys, or already-built ``Question`` objects.
        """
        # Local-time timestamp with second resolution, e.g. "20240131_142501".
        self.session_id = datetime.now().strftime("%Y%m%d_%H%M%S")
        # Normalize up front: answer_question/get_next_question need objects
        # with a ``user_answer`` attribute, addressable by question id.
        self.questions = self.process_questions(questions)

    @staticmethod
    def process_questions(questions):
        """Convert question specs into a ``{id: Question}`` dict (input order kept).

        Dict items are converted with ``answer_format`` mapped from
        "number"/"text"/"list" to int/str/list. Non-dict items are assumed to be
        ``Question`` objects already and are stored as-is under their ``id``.

        Raises:
            ValueError: if a dict's ``answer_format`` is not a known label.
            KeyError: if a dict lacks ``"answer_format"``, ``"id"`` or ``"text"``.
        """
        processed = {}
        for q in questions:
            if not isinstance(q, dict):
                processed[q.id] = q
                continue
            label = q["answer_format"]
            try:
                py_type = Session._ANSWER_FORMATS[label]
            except (KeyError, TypeError):  # unknown or unhashable label
                raise ValueError(f"Invalid answer format: {label!r}") from None
            processed[q["id"]] = Question(q["id"], q["text"], py_type, None)
        return processed

    def answer_question(self, question_id, user_answer):
        """Record ``user_answer`` for the question whose id is ``question_id``.

        Raises:
            KeyError: if no question has that id.
        """
        self.questions[question_id].user_answer = user_answer

    def get_next_question(self):
        """Return the first question (in input order) that has no answer yet.

        A question counts as answered once ``user_answer`` is not None, so falsy
        answers such as 0, "" or [] are still treated as answers.

        Returns:
            The next unanswered question, or ``False`` when all are answered.
        """
        for q in self.questions.values():
            if q.user_answer is None:
                return q
        return False

    @staticmethod
    def _question_to_json_dict(question):
        """Return a JSON-serializable dict with a question's id, text and format."""
        fmt = question.answer_format
        # json cannot serialize Python types, so emit the label ("number", ...);
        # fall back to the type's name for formats outside the known labels.
        label = Session._FORMAT_LABELS.get(fmt, getattr(fmt, "__name__", str(fmt)))
        return {"id": question.id, "text": question.text, "answer_format": label}

    def zero_shot_prompt(self, prompt_template_path):
        """Render the Jinja2 template stored at ``prompt_template_path``.

        The template receives a ``questions`` variable: a JSON string
        (4-space indent) listing each question's id, text and answer-format
        label, in input order.
        """
        with open(prompt_template_path, encoding="utf-8") as f:
            template_str = f.read()
        payload = [self._question_to_json_dict(q) for q in self.questions.values()]
        return Template(template_str).render(questions=json.dumps(payload, indent=4))