| | import os |
| | import re |
| | import subprocess |
| | import tempfile |
| | import multiprocessing |
| | from collections import Counter |
| | from contextlib import contextmanager |
| | from dataclasses import dataclass |
| |
|
| |
|
class PythonREPL:
    """Run a snippet of Python source in a fresh ``python3`` subprocess.

    Calling an instance returns ``(success, output)``: on success ``output``
    is the stripped stdout of the script; on failure it is a condensed
    version of the traceback printed on stderr, or a timeout message.
    """

    def __init__(self, timeout=5):
        # Wall-clock limit, in seconds, for a single script run.
        self.timeout = timeout

    @staticmethod
    def _run_code(temp_file_path, timeout=None):
        """Execute the script at *temp_file_path* with ``python3``.

        Args:
            temp_file_path: Path of the script to run.
            timeout: Optional limit in seconds. ``None`` waits indefinitely.

        Returns:
            ``(True, stdout)`` when the interpreter exits with status 0,
            otherwise ``(False, condensed_stderr)``.

        Raises:
            subprocess.TimeoutExpired: if *timeout* elapses. ``subprocess.run``
                kills the child before raising, so no interpreter is left
                running in the background.
        """
        result = subprocess.run(
            ["python3", temp_file_path],
            capture_output=True,
            check=False,
            text=True,
            timeout=timeout,
        )
        if result.returncode == 0:
            return True, result.stdout.strip()

        lines = result.stderr.strip().split("\n")
        kept = []
        want_next = False
        for line in lines:
            if "Traceback" in line:
                kept.append(line)
            elif line == lines[-1]:
                # The final line carries the exception type and message.
                kept.append(line)
            elif temp_file_path in line:
                # Frame header that points into the temporary script. Everything
                # from the start of the quoted path onwards is removed so the
                # temp location is not exposed.
                # NOTE(review): this also drops the line number; the emitted
                # text is kept as-is because callers may depend on the format.
                start = line.index('"/') + 1 if '"/' in line else 0
                kept.append(line.replace(line[start:], ""))
                want_next = True
            elif want_next:
                # The source line that follows the frame header.
                kept.append(line)
                want_next = False
        return False, "\n".join(kept).strip()

    def __call__(self, query):
        """Run *query* and return ``(success, output)``.

        ``math``, ``numpy`` (as ``np``) and ``sympy`` (as ``sp``) imports are
        prepended to the script. If the last line contains no ``print(`` call
        it is wrapped in one (after cutting anything from the first ``#``), so
        a trailing expression is echoed.
        """
        query = "import math\nimport numpy as np\nimport sympy as sp\n" + query
        lines = query.strip().split("\n")
        if "print(" not in lines[-1]:
            if "#" in lines[-1]:
                lines[-1] = lines[-1].split("#")[0]
            lines[-1] = "print(" + lines[-1] + ")"
        source = "\n".join(lines)

        with tempfile.TemporaryDirectory() as temp_dir:
            temp_file_path = os.path.join(temp_dir, "tmp.py")
            with open(temp_file_path, "w", encoding="utf-8") as f:
                f.write(source)

            # The timeout is enforced by subprocess.run itself so that the
            # child interpreter is killed when it expires; terminating a
            # wrapper process would leave the script running as an orphan.
            try:
                return self._run_code(temp_file_path, timeout=self.timeout)
            except subprocess.TimeoutExpired:
                return False, f"Timed out after {self.timeout} seconds."
| |
|
| |
|
def execute_completion(executor, completion, return_status, last_code_block):
    """Run the fenced ``python`` code blocks found in *completion*.

    Args:
        executor: Callable taking a code string and returning
            ``(success, output)``.
        completion: Text that may contain fenced python code blocks.
        return_status: When true, return ``(output, success)``; otherwise
            return only ``output``.
        last_code_block: When true, only the final code block is run;
            otherwise every block is run in order.

    Returns:
        The stripped output of the last processed block (with its success
        flag when *return_status* is true). If *completion* contains no code
        block, *completion* itself is returned (paired with ``False`` when
        *return_status* is true). When *return_status* is false, the output of
        a failed execution is replaced by an empty string.
    """
    executions = re.findall(r"```python(.*?)```", completion, re.DOTALL)
    if not executions:
        # Parenthesised on purpose: without the parentheses the conditional
        # binds only to the second tuple element and a tuple is always returned.
        return (completion, False) if return_status else completion
    if last_code_block:
        executions = executions[-1:]

    output, success = "", False
    for code in executions:
        # Plain substring screen for modules that must not be used; a block
        # that mentions one is reported as failed and is NOT executed.
        banned = next((lib for lib in ("subprocess", "venv") if lib in code), None)
        if banned is not None:
            output, success = f"{banned} is not allowed", False
            continue

        success = False
        try:
            success, output = executor(code)
        except TimeoutError as e:
            print("Code timed out")
            output = e
        if not success and not return_status:
            # Callers that do not ask for the status get no error text.
            output = ""

    output = str(output).strip()
    if return_status:
        return output, success
    return output
| |
|
| |
|
def postprocess_completion(text, return_status, last_code_block):
    """Execute the code blocks in *text* with a fresh ``PythonREPL``.

    A new ``PythonREPL`` (default settings) is created for the call and
    handed to ``execute_completion``; that function's result is returned
    unchanged.
    """
    return execute_completion(
        PythonREPL(),
        text,
        return_status=return_status,
        last_code_block=last_code_block,
    )
| |
|
| |
|
def get_majority_vote(answers):
    """Return the most frequent element of *answers*, or ``0`` if it is empty.

    When several elements share the highest count, the one encountered
    first in *answers* is returned.
    """
    if len(answers) == 0:
        return 0
    tally = Counter(answers)
    (winner, _count), = tally.most_common(1)
    return winner