import itertools
import json
import os
import sys
from collections import Counter, defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path

import fire
import jsonlines
import numpy as np
import tqdm

# sys.path entries must be strings, so convert the Path objects explicitly.
sys.path.extend(
    [
        str(Path(__file__).parent.parent),
        str(Path(__file__).parent.parent / "execution_engine"),
    ]
)

from api_comm import APICommunication
from exec_outcome import ExecOutcome
from yaml import safe_load


def estimate_pass_at_k(
    num_samples: int | list[int] | np.ndarray,
    num_correct: list[int] | np.ndarray,
    k: int,
) -> np.ndarray:
    """Estimates pass@k of each problem and returns the estimates in an array."""

    def estimator(n: int, c: int, k: int) -> float:
        """Calculates 1 - comb(n - c, k) / comb(n, k)."""
        if n - c < k:
            return 1.0
        return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))

    if isinstance(num_samples, int):
        num_samples_it = itertools.repeat(num_samples, len(num_correct))
    else:
        assert len(num_samples) == len(num_correct)
        num_samples_it = iter(num_samples)

    return np.array(
        [estimator(int(n), int(c), k) for n, c in zip(num_samples_it, num_correct)]
    )
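
# Worked example of the estimator above: with 10 samples per problem, 3 passing
# for the first problem and none for the second, pass@5 is
# 1 - C(7, 5) / C(10, 5) ~= 0.9167 for the first problem and 0.0 for the second:
#
#   >>> estimate_pass_at_k(10, [3, 0], 5)
#   array([0.91666667, 0.        ])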

def evaluate_functional_correctness(
    sample_file: str,
    k: list[int] = [1, 10, 100],
    n_workers: int = 4,
    limits_by_lang: dict = {},
    compile_n_execute_args_by_lang: dict = {},
    eval_result_file: str | None = None,
    unittest_file: str = "unittest_db.json",
    execeval_url: str = "http://localhost:5000",
    block_network: bool = True,
    stop_on_first_fail: bool = True,
    use_sanitizer: bool = False,
):
    """
    Evaluates the functional correctness of the generated samples and writes the
    per-sample results to `eval_result_file` (by default
    f"{sample_file.split('.')[0]}-evaluated.jsonl").
    """
    if eval_result_file is None:
        eval_result_file = f"{sample_file.split('.')[0]}-evaluated.jsonl"

    with open(unittest_file) as ut_rp:
        unittest_db = json.load(ut_rp)

    # Check the generated samples against the test suites.
    with APICommunication(execeval_url) as execeval:
        execute_code = execeval.execute_code
        supported_langs = {r["runtime_name"] for r in execeval.get_runtimes()}

        with ThreadPoolExecutor(max_workers=n_workers) as executor:
            futures = []
            completion_id = Counter()
            n_samples = 0
            results = defaultdict(list)

            with jsonlines.open(sample_file) as sample_rp:
                for idx, sample in tqdm.tqdm(
                    enumerate(sample_rp), desc="Reading samples"
                ):
                    src_uid = sample["src_uid"]
                    source_code = sample["source_code"]
                    task_id = sample["task_id"]
                    lang = sample["lang"]
                    if src_uid not in unittest_db:
                        continue
                    unittests = unittest_db[src_uid]
                    if len(unittests) == 0:
                        continue
                    if lang not in supported_langs:
                        continue
                    args = (
                        lang,
                        source_code,
                        unittests,
                        limits_by_lang[lang],
                        block_network,
                        stop_on_first_fail,
                        use_sanitizer,
                        compile_n_execute_args_by_lang.get(lang, {}).get("compile_cmd"),
                        compile_n_execute_args_by_lang.get(lang, {}).get(
                            "compile_flags"
                        ),
                        compile_n_execute_args_by_lang.get(lang, {}).get("execute_cmd"),
                        compile_n_execute_args_by_lang.get(lang, {}).get(
                            "execute_flags"
                        ),
                        idx,
                        task_id,
                    )
                    future = executor.submit(execute_code, *args)
                    futures.append(future)
                    completion_id[task_id] += 1
                    n_samples += 1

            print("Running test suites...")
            for idx, future in tqdm.tqdm(
                enumerate(as_completed(futures)),
                desc="Test running",
                total=len(futures),
            ):
                result = future.result()
                unittests, sample_idx, task_id = result
                if not isinstance(unittests, list) and "error" in unittests:
                    # TODO: log this properly instead of printing.
                    print("ERROR: ", unittests["error"])
                    continue
                results[task_id].append((sample_idx, unittests))

    print("Calculating pass@k...")
    total, correct = [], []
    for result in results.values():
        result.sort()
        passed = [
            all(x["exec_outcome"] == ExecOutcome.PASSED.value for x in r[1])
            for r in result
        ]
        total.append(len(passed))
        correct.append(sum(passed))

    total = np.array(total)
    correct = np.array(correct)

    ks = k
    pass_at_k = {
        f"pass@{k}": estimate_pass_at_k(total, correct, k).mean()
        for k in ks
        if (total >= k).all()
    }

    # Finally, save the per-sample results in one file.
    def combine_results():
        with jsonlines.open(sample_file) as sample_rp:
            cnt = 0
            for idx, sample in enumerate(sample_rp):
                cnt += 1
                if sample["lang"] not in supported_langs:
                    continue
                task_id = sample["task_id"]
                if len(results[task_id]) == 0:
                    continue
                if results[task_id][0][0] > idx:
                    continue
                result = results[task_id].pop(0)
                sample["unittests"] = result[1]
                # The sample's outcome is the first non-PASSED unit-test outcome,
                # or PASSED if every unit test passed.
                _exec_outcomes = [
                    r["exec_outcome"]
                    for r in result[1]
                    if r["exec_outcome"] != ExecOutcome.PASSED.value
                ] + [ExecOutcome.PASSED.value]
                sample["exec_outcome"] = _exec_outcomes[0]
                yield sample

    print(f"Writing results to {eval_result_file}...")
    with jsonlines.open(eval_result_file, "w") as result_wp:
        for result in tqdm.tqdm(combine_results(), total=n_samples):
            result_wp.write(result)

    return pass_at_k
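
# Note on the expected input: each line of `sample_file` is a JSON object with at
# least the fields "task_id", "src_uid", "lang", and "source_code". "src_uid" must
# be a key of the unit-test database and "lang" must match a runtime name reported
# by the ExecEval server (execeval.get_runtimes()); samples that do not are skipped.
# A minimal, hypothetical record (all values are placeholders):
#
#   {"task_id": "p00001_0", "src_uid": "abc123", "lang": "Python 3",
#    "source_code": "print(int(input()) * 2)"}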

def entry_point(
    sample_file: str,
    k: str | list | tuple = "1,2,5,10",
    n_workers: int = 4,
    compile_n_execute_args_by_lang_cfg_file: str | None = None,
    limits_by_lang_cfg_file: str | None = None,
    unittest_file: str = "unittest_db.json",
    execeval_url: str = "http://localhost:5000",
    block_network: bool = True,
    stop_on_first_fail: bool = True,
    use_sanitizer: bool = False,
):
    """
    Evaluates the functional correctness of the generated samples and writes the
    per-sample results to f"{sample_file.split('.')[0]}-evaluated.jsonl".

    [TODO] `compile_n_execute_args_by_lang_cfg_file` and `limits_by_lang_cfg_file`
    are assumed to be YAML files; consider config.yaml for the compile/execute
    args and resource_limits.py for limits_by_lang.
    """
    limits_by_lang, compile_n_execute_args_by_lang = None, {}

    if limits_by_lang_cfg_file is None:
        limits_by_lang_cfg_file = "limits_by_lang.yaml"
    if not os.path.exists(limits_by_lang_cfg_file):
        print(
            "Resource limit defaults are required for all runtimes; provide the "
            "path to the default 'limits_by_lang.yaml' or to a modified copy."
        )
        sys.exit(-1)
    with open(limits_by_lang_cfg_file) as limit_cfg_rp:
        limits_by_lang = safe_load(limit_cfg_rp)

    if compile_n_execute_args_by_lang_cfg_file is not None and os.path.exists(
        compile_n_execute_args_by_lang_cfg_file
    ):
        with open(
            compile_n_execute_args_by_lang_cfg_file
        ) as compile_n_execute_args_by_lang_rp:
            compile_n_execute_args_by_lang = safe_load(
                compile_n_execute_args_by_lang_rp
            )

    ks = list(map(int, k.split(","))) if isinstance(k, str) else list(k)

    results = evaluate_functional_correctness(
        sample_file,
        ks,
        n_workers,
        block_network=block_network,
        limits_by_lang=limits_by_lang,
        compile_n_execute_args_by_lang=compile_n_execute_args_by_lang,
        unittest_file=unittest_file,
        execeval_url=execeval_url,
        stop_on_first_fail=stop_on_first_fail,
        use_sanitizer=use_sanitizer,
    )
    print(results)


def main():
    fire.Fire(entry_point)


if __name__ == "__main__":
    sys.exit(main())
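
# Example invocation (illustrative; "evaluate.py" stands for this script's actual
# file name and the data files are placeholders for your own paths):
#
#   python evaluate.py \
#       --sample_file samples.jsonl \
#       --k "1,5,10" \
#       --unittest_file unittest_db.json \
#       --execeval_url http://localhost:5000
#
# fire maps each parameter of entry_point to a command-line flag of the same
# name, so the remaining options (n_workers, block_network, use_sanitizer, ...)
# can be passed the same way.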