import glob
import json
import os
from dataclasses import dataclass

from src.display.formatting import make_hyperlink
from src.display.utils import AutoEvalColumn


@dataclass
class EvalResult:
    """Represents one full evaluation. Built from a combination of the result and request file for a given run."""

    method_name: str
    method_url: str
    model_name: str
    model_url: str
    with_hint: bool
    attempts: int
    fast_pass_count: int
    full_pass_count: int
    full_pass_count_crash: int
    full_pass_count_hang: int
    full_pass_count_miscompilation: int
    build_count: int
    build_failure_count: int
    mttr: float  # mean time to repair across bugs that pass the full check
    sample_count: float  # average number of checks (fast + full) per fully fixed bug
    fixed_bug_ids: list[str]
    fixed_bug_ids_fast: list[str]
    patches: dict[str, str]  # bug_id -> patch text annotated with check outcomes

    @classmethod
    def init_from_json_file(cls, json_filepath):
        """Inits the result from the specific model result file"""
        with open(json_filepath) as fp:
            data = json.load(fp)

        method_name = data.get("method_name", "")
        method_url = data.get("method_url", "")
        model_name = data.get("base_model_name", "")
        model_url = data.get("base_model_url", "")
        with_hint = data.get("with_hint", False)

        fixes = data.get("fixes", [])
        attempts = len(fixes)

        # Aggregate per-fix statistics across all attempted fixes in the result file.
        fast_pass_count = 0
        full_pass_count = 0
        full_pass_count_cat = {}  # full-check passes broken down by bug_type
        build_count = 0
        build_failure_count = 0
        ttr_sum = 0
        fixed_bug_ids = []
        fixed_bug_ids_fast = []
        sample_count = 0
        patches = dict()
        for fix in fixes:
            bug_type = fix.get("bug_type", "")
            if fix.get("fast_check_pass", False):
                fast_pass_count += 1
                fixed_bug_ids_fast.append(fix.get("bug_id", ""))
            if fix.get("full_check_pass", False):
                # Successful fixes contribute to the per-category counts,
                # time-to-repair sum, and sample statistics.
                full_pass_count += 1
                full_pass_count_cat[bug_type] = full_pass_count_cat.get(bug_type, 0) + 1
                ttr_sum += fix.get("wall_time", 0)
                fixed_bug_ids.append(fix.get("bug_id", ""))
                sample_count += fix.get("fast_check_count", 0) + fix.get("full_check_count", 0)
            # Build statistics are accumulated over every attempted fix.
            build_count += fix.get("build_count", 0)
            build_failure_count += fix.get("build_failure_count", 0)

            # Prefix each patch with its check outcomes so it can be displayed verbatim.
            patch = ""
            patch += f"// Fast check: {fix.get('fast_check_pass', False)}\n"
            patch += f"// Full check: {fix.get('full_check_pass', False)}\n"
            patch += fix.get("patch", "")
            patches[fix.get("bug_id", "")] = patch

        return cls(
            method_name=method_name,
            method_url=method_url,
            model_name=model_name,
            model_url=model_url,
            with_hint=with_hint,
            attempts=attempts,
            fast_pass_count=fast_pass_count,
            full_pass_count=full_pass_count,
            full_pass_count_crash=full_pass_count_cat.get("crash", 0),
            full_pass_count_hang=full_pass_count_cat.get("hang", 0),
            full_pass_count_miscompilation=full_pass_count_cat.get("miscompilation", 0),
            build_count=build_count,
            build_failure_count=build_failure_count,
            mttr=round(ttr_sum / full_pass_count / 60, 1) if full_pass_count > 0 else 0,
            fixed_bug_ids=fixed_bug_ids,
            fixed_bug_ids_fast=fixed_bug_ids_fast,
            sample_count=round(sample_count / full_pass_count, 1) if full_pass_count > 0 else 0,
            patches=patches,
        )

    def to_dict(self, total_issues):
        """Converts the Eval Result to a dict compatible with our dataframe display"""
        data_dict = {
            AutoEvalColumn.method_name.name: make_hyperlink(self.method_url, self.method_name),
            AutoEvalColumn.model_name.name: make_hyperlink(self.model_url, self.model_name),
            AutoEvalColumn.with_hint.name: "w/ hint" if self.with_hint else "w/o hint",
            AutoEvalColumn.score.name: round(self.full_pass_count * 100.0 / total_issues, 1),
            # Guard the ratio and build-rate divisions so an empty result file cannot divide by zero.
            AutoEvalColumn.ratio.name: round(self.full_pass_count * 100.0 / self.attempts, 1)
            if self.attempts > 0
            else 0,
            AutoEvalColumn.attempts.name: self.attempts,
            AutoEvalColumn.fast_pass_count.name: self.fast_pass_count,
            AutoEvalColumn.full_pass_count.name: self.full_pass_count,
            AutoEvalColumn.full_pass_count_crash.name: self.full_pass_count_crash,
            AutoEvalColumn.full_pass_count_hang.name: self.full_pass_count_hang,
            AutoEvalColumn.full_pass_count_miscompilation.name: self.full_pass_count_miscompilation,
            AutoEvalColumn.build_success_rate.name: round(
                (self.build_count - self.build_failure_count) * 100.0 / self.build_count, 1
            )
            if self.build_count > 0
            else 0,
            AutoEvalColumn.mttr.name: self.mttr,
            "fixed_bug_ids": self.fixed_bug_ids,
            "fixed_bug_ids_fast": self.fixed_bug_ids_fast,
            "method_id": self.method_name + "(" + self.model_name + ")",
            "patches": self.patches,
            AutoEvalColumn.sample_count.name: self.sample_count,
        }
        return data_dict


def get_raw_eval_results(requests_path: str) -> list[EvalResult]:
    """From the path of the results folder root, extract all needed info for results"""
    results = []
    for root, _, files in os.walk(requests_path):
        for file in files:
            if file.endswith(".json"):
                results.append(EvalResult.init_from_json_file(os.path.join(root, file)))
    return results
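

# Usage sketch (illustrative only, not part of the leaderboard app): load every result
# file under a results directory and build the display dataframe. The "./eval_results"
# path and the total issue count of 100 are hypothetical placeholders; the real app
# supplies its own values when it renders the leaderboard.
if __name__ == "__main__":
    import pandas as pd

    raw_results = get_raw_eval_results("./eval_results")
    rows = [result.to_dict(total_issues=100) for result in raw_results]
    print(pd.DataFrame(rows))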