Spaces:
Runtime error
Runtime error
File size: 1,501 Bytes
6c09f76 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
import json
from dataclasses import dataclass
from datetime import datetime
from typing import Any

from jinja2 import Template
@dataclass
class Question:
    """One questionnaire item together with the user's answer, if any."""

    id: int  # identifier; Session.process_questions uses it as the dict key
    text: str  # the question text
    answer_format: type  # expected answer type, e.g. int, str or list
    user_answer: Any = None  # None until an answer has been recorded
class Session:
    """A questionnaire session holding questions and the user's answers.

    ``questions`` is stored exactly as passed in. The answer-tracking methods
    expect a container of Question-like objects (anything with a mutable
    ``user_answer`` attribute): either a dict keyed by question id, as built
    by ``Session.process_questions``, or a list indexed by position.
    """

    # Maps the textual ``answer_format`` used in raw question dicts to the
    # Python type stored on each Question.
    _ANSWER_FORMATS = {"number": int, "text": str, "list": list}

    def __init__(self, questions):
        """Create a session.

        Args:
            questions: Container of questions, stored as-is. Raw question
                dicts can be converted first with ``Session.process_questions``.
        """
        # Timestamp-based id with second resolution (naive local time).
        self.session_id = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.questions = questions

    @staticmethod
    def process_questions(questions):
        """Convert raw question dicts into Question objects keyed by id.

        Args:
            questions: Iterable of dicts with ``"id"``, ``"text"`` and
                ``"answer_format"`` keys; ``answer_format`` must be one of
                ``"number"``, ``"text"`` or ``"list"``.

        Returns:
            dict mapping each question's id to a new, unanswered Question.

        Raises:
            ValueError: if an ``answer_format`` is not recognised.
            KeyError: if a dict lacks one of the required keys.
        """
        processed = {}
        for q in questions:
            fmt = q["answer_format"]
            # Guard the lookup so unhashable formats still yield ValueError.
            answer_type = (
                Session._ANSWER_FORMATS.get(fmt) if isinstance(fmt, str) else None
            )
            if answer_type is None:
                raise ValueError("Invalid answer format")
            processed[q["id"]] = Question(q["id"], q["text"], answer_type, None)
        return processed

    def answer_question(self, question_id, user_answer):
        """Record ``user_answer`` on the question stored under ``question_id``.

        ``question_id`` is used directly as the key/index into
        ``self.questions``, so an unknown id raises KeyError/IndexError.
        """
        self.questions[question_id].user_answer = user_answer

    def _iter_questions(self):
        """Return an iterator over the stored question objects."""
        # Iterating a dict yields its keys, not the Question objects.
        if isinstance(self.questions, dict):
            return iter(self.questions.values())
        return iter(self.questions)

    def get_next_question(self):
        """Return the first question that has not been answered yet.

        A question counts as answered once its ``user_answer`` is not None,
        so falsy answers such as ``0``, ``""`` or ``[]`` are respected.

        Returns:
            The first unanswered question, or ``False`` when every question
            has an answer (kept for compatibility with existing callers).
        """
        for q in self._iter_questions():
            if q.user_answer is None:
                return q
        return False

    @staticmethod
    def _json_default(obj):
        """``json.dumps`` fallback for the objects a session may hold.

        Types (e.g. ``Question.answer_format``) become their raw format name
        (``"number"``, ``"text"``, ``"list"``) or, failing that, their
        ``__name__``; Question instances become plain dicts.
        """
        if isinstance(obj, type):
            for name, answer_type in Session._ANSWER_FORMATS.items():
                if answer_type is obj:
                    return name
            return obj.__name__
        if isinstance(obj, Question):
            return {
                "id": obj.id,
                "text": obj.text,
                # A type value; json re-enters this hook to encode it.
                "answer_format": obj.answer_format,
                "user_answer": obj.user_answer,
            }
        raise TypeError(
            f"Object of type {type(obj).__name__} is not JSON serializable"
        )

    def zero_shot_prompt(self, prompt_template_path):
        """Render the Jinja2 template stored at ``prompt_template_path``.

        The template receives one variable, ``questions``: the session's
        questions serialized as a JSON string with 4-space indentation.
        """
        with open(prompt_template_path, encoding="utf-8") as f:
            template = Template(f.read())
        return template.render(
            questions=json.dumps(self.questions, indent=4, default=self._json_default)
        )
|