| | |
| |
|
| |
|
| | from collections import deque |
| | from fastai.text.all import load_learner |
| | import sys |
| | import re |
| |
|
| | |
# Generation / history tuning knobs.
MAX_HISTORY_CHARS = 800  # max characters of flattened history kept in the prompt
MAX_HISTORY_MESSAGES = 1  # max history entries retained before trimming (very short memory)
GENERATE_TOKENS = 70  # number of words requested from the language model per turn
TEMPERATURE = 0.3  # low sampling temperature -> more deterministic completions
| |
|
# Compiled once at import time; both patterns run on every bot reply.
_PLACEHOLDER_RE = re.compile(r"\{([^{}]+)\}")
_ARITHMETIC_RE = re.compile(r"[\d\s\+\-\*\/]+")


def evaluate_placeholders(text: str) -> str:
    """Evaluate ``{...}`` placeholders that contain only simple arithmetic.

    Each placeholder whose content consists solely of digits, whitespace and
    the operators ``+ - * /`` is replaced by the result of evaluating it
    (e.g. ``"{2+3}" -> "5"``; note ``/`` yields a float: ``"{8/2}" -> "4.0"``).
    Any other placeholder, or one that fails to evaluate (``"{2+}"``), is
    left untouched.
    """
    def repl(match: "re.Match") -> str:
        expr = match.group(1)
        try:
            # Whitelist guard: only digits/whitespace/+-*/ may reach eval.
            if _ARITHMETIC_RE.fullmatch(expr):
                # Empty __builtins__ blocks name access as defense in depth;
                # the character whitelist above already excludes identifiers.
                return str(eval(expr, {"__builtins__": {}}))
        except Exception:
            pass  # malformed arithmetic — keep the placeholder as-is
        return match.group(0)

    return _PLACEHOLDER_RE.sub(repl, text)
| |
|
def remove_before_first_colon(s: str) -> str:
    """Return the text after the first ``"BOT :"`` marker, or *s* unchanged if absent."""
    _, marker, tail = s.partition("BOT :")
    return tail if marker else s
def remove_before_last_colon(s: str) -> str:
    """Return the text after the last ``":"``, or *s* unchanged if there is none."""
    _, sep, tail = s.rpartition(":")
    return tail if sep else s
def _keep_through_keyword(text: str, keyword: str) -> str:
    """Truncate *text* just after the first occurrence of *keyword*; unchanged if absent."""
    index = text.find(keyword)
    if index != -1:
        return text[:index + len(keyword)]
    return text


def remove_after_user(text: str) -> str:
    """Drop everything after the first ``"USER"`` token (the token itself is kept)."""
    return _keep_through_keyword(text, "USER")


def remove_after_bot(text: str) -> str:
    """Drop everything after the first ``"BOT"`` token (the token itself is kept)."""
    return _keep_through_keyword(text, "BOT")
| |
|
def truncate(answer):
    """Clean a raw model continuation down to a single bot reply.

    Cuts the text at the first conversational stop marker, strips role
    prefixes, then removes tokenizer spacing artifacts around punctuation.
    """
    # Keep only what precedes the first stop marker.
    for stop in ("\n", "USER:", "BOT:"):
        if stop in answer:
            answer = answer.split(stop)[0]

    answer = remove_before_first_colon(answer)
    answer = remove_after_user(answer)
    answer = remove_after_bot(answer)

    # Ordered cleanup pairs — most specific role-token forms first, then
    # detokenization fixes.  Order matters and mirrors the original chain.
    cleanup = (
        (": USER", ""), (" USER", ""), ("USER", ""),
        (" !", "!"), (" .", "."), (" ,", ","),
        (": BOT", ""), (" BOT", ""), ("BOT", ""),
        (" `", "`"), (' "', '"'), (" β", "β"),
        ("do n'", "don'"), ("do nβ", "donβ"),
        (" '", "'"), (" :", ":"), (" (", "("), (" )", ")"),
        (" ?", "?"), ("Open Assistant", "Bomba-1"),
    )
    for old, new in cleanup:
        answer = answer.replace(old, new)

    return answer.strip()
| |
|
def load_models():
    """Load and return the fastai text learner used for chat generation."""
    print("π€ Loading modelsβ¦")
    learner = load_learner("model/SimpleMath.pkl")
    # Inference only: switch off dropout/batch-norm training behavior.
    learner.model.eval()
    return learner
| |
|
def main():
    """Interactive chat REPL: read user lines, generate and print bot replies.

    Exits on an empty input line or Ctrl-C.  Keeps a short rolling history
    (bounded by MAX_HISTORY_MESSAGES entries and MAX_HISTORY_CHARS characters)
    that is flattened into the generation prompt each turn.
    """
    chat_model = load_models()
    history = deque()
    print("π¬ Ready! (empty line to quit)\n")

    while True:
        try:
            user = input("USER: ").strip()
            if not user:
                break

            history.append(f"USER: {user}")
            # Trim to the newest MAX_HISTORY_MESSAGES entries.
            while len(history) > MAX_HISTORY_MESSAGES:
                history.popleft()

            # Flatten history to one line and keep only the most recent
            # MAX_HISTORY_CHARS characters as the prompt budget.
            prompt_text = " ".join(history).replace("\n", " ")
            if len(prompt_text) > MAX_HISTORY_CHARS:
                prompt_text = prompt_text[-MAX_HISTORY_CHARS:]
            prompt = f"{prompt_text} BOT: "

            generated = chat_model.predict(
                prompt,
                n_words=GENERATE_TOKENS,
                temperature=TEMPERATURE,
                min_p=0.01,
            )

            # predict() echoes the prompt; strip it off when present.
            try:
                _, raw = generated.split(prompt, 1)
            except ValueError:
                raw = generated

            raw = raw.strip()
            # If the model restated the user turn, jump to its own turn.
            if raw.upper().startswith("USER:") and "BOT:" in raw:
                raw = raw.split("BOT:", 1)[1].strip()

            answer = truncate(raw)
            answer = evaluate_placeholders(answer)
            # Re-introduce line breaks before list markers for readability.
            answer = answer.replace("-", "\n-").replace("1)", "\n1)").replace("2)", "\n2)").replace("3)", "\n3)").replace("4)", "\n4)").replace("5)", "\n5)").replace("* ", "\n* ").replace("Final", "\nFinal")
            if "Final" not in answer:
                answer = answer.replace("Result", "\nResult")
            print("BOT:", answer, "\n")
            history.append(f"BOT: {answer}")

        except KeyboardInterrupt:
            break
| |
|
# Script entry point: run the interactive chat loop when executed directly.
if __name__ == "__main__":
    main()