# update from https://github.com/ArneBinder/argumentation-structure-identification/pull/529
# commit d868d2e (verified)
#!/usr/bin/env python | |
import argparse
import json
import os
import re
from pathlib import Path

import pandas as pd

from pie_modules.utils import flatten_dict
def str2record(s: str | None, sep_parts: str = "-", sep_k_v: str = "=") -> pd.Series: | |
if s is None or s.strip() == "" or s == "None": | |
return pd.Series() | |
return pd.Series(dict(k_v.split(sep_k_v, 1) for k_v in s.split(sep_parts))) | |
def separate_path_and_id(path_and_maybe_id: str, separator: str = ":") -> tuple[str | None, str]: | |
parts = path_and_maybe_id.split(separator, 1) | |
if len(parts) == 1: | |
return None, parts[0] | |
return parts[0], parts[1] | |
def load_data_from_json(path: str | Path) -> pd.DataFrame:
    """Read a (possibly nested) JSON dict and return it as a flat DataFrame.

    Nested keys are flattened via ``flatten_dict``; each leaf entry is
    expected to be a list of values (one per row) of equal length.
    """
    raw = json.loads(Path(path).read_text())
    return pd.DataFrame(flatten_dict(raw))
def main(
    path: str | Path,
    remove_col_prefix: str | None = None,
    sparse_col_prefix: str | None = None,
    tail_cols: list[str] | None = None,
    sort_cols: list[str] | None = None,
    split_col: str | None = None,
    replace_in_col_names: list[tuple[str, str]] | None = None,
    round_precision: int | None = None,
    in_percent: bool = False,
    common_prefix_separator: str | None = None,
    column_regex_blacklist: list[str] | None = None,
    column_regex_whitelist: list[str] | None = None,
    format: str = "markdown",
) -> None:
    """Load metric results, post-process them, and print them as a table.

    Args:
        path: Either a ``.json`` file (a maybe-nested dict of equal-length
            value lists) or a ``.txt`` index file where each line has the form
            ``[<id-prefix>:]<run-directory>``; for each line the file
            ``<run-directory>/job_return_value.json`` is loaded and the results
            are concatenated column-wise, aligned on their "job_id" column.
        remove_col_prefix: Prefix to strip from all column names (interpreted
            as a regex anchored at the start).
        sparse_col_prefix: Prefix of sparse columns; all columns starting with
            it are melted into two columns "<prefix>name" / "<prefix>value",
            keeping only rows with a non-nan value.
        tail_cols: Columns to move to the end.
        sort_cols: Columns to sort the rows by; also moved to the front.
        split_col: Column whose entries of the form "k1=v1-k2=v2-..." are
            expanded into one column per key (see ``str2record``).
        replace_in_col_names: List of ``(old, new)`` literal substring
            replacements applied to column names.
        round_precision: Number of decimal places to round all columns to.
        in_percent: If True, multiply all float columns by 100.
        common_prefix_separator: For each string column, strip the common
            prefix of its values, truncated after the last occurrence of this
            separator (if present).
        column_regex_blacklist: Drop columns matching any of these regexes.
        column_regex_whitelist: Keep only columns matching any of these regexes.
        format: Output format: "markdown", "csv", "tsv", or "json".

    Raises:
        ValueError: If ``path`` has an unsupported extension or ``format`` is
            not one of the supported formats.
    """
    if str(path).lower().endswith(".json"):
        result = load_data_from_json(path)
    elif str(path).lower().endswith(".txt"):
        # each line: "[<id-prefix>:]<run-dir>" pointing to a multirun output
        with open(path, "r") as f:
            index_data = [separate_path_and_id(line.strip()) for line in f.readlines()]
        data_list = []
        for meta_id, meta_path in index_data:
            data = load_data_from_json(os.path.join(meta_path, "job_return_value.json"))
            if meta_id is not None:
                # prepend the (sanitized) id prefix to disambiguate job ids
                job_id_prefix = meta_id.replace(",", "-")
                data["job_id"] = job_id_prefix + "-" + data["job_id"].astype(str)
            data = data.set_index("job_id")
            data_list.append(data)
        # align the individual runs on job_id and restore it as a column
        result = pd.concat(data_list, axis=1).reset_index()
    else:
        raise ValueError("Unsupported file format. Please provide a .json or .txt file.")
    if remove_col_prefix is not None:
        # NOTE: the prefix is used as a start-anchored regex; metacharacters
        # in it are interpreted, not matched literally
        result.columns = result.columns.str.replace(r"^" + remove_col_prefix, "", regex=True)
    if sparse_col_prefix is not None:
        # melt all columns starting with the prefix into a name/value column
        # pair, keeping only the rows that actually carry a value
        sparse_cols = [col for col in result.columns if col.startswith(sparse_col_prefix)]
        other_cols = [col for col in result.columns if col not in sparse_cols]
        value_col = f"{sparse_col_prefix}value"
        name_col = f"{sparse_col_prefix}name"
        result = result.melt(
            id_vars=other_cols, value_vars=sparse_cols, var_name=name_col, value_name=value_col
        ).dropna(
            subset=[value_col]
        )  # keep rows with a value
        # strip the prefix (e.g. "f1-"), leaving just the remainder
        result[name_col] = result[name_col].str.replace(r"^" + sparse_col_prefix, "", regex=True)
        # convert the column to numeric (if possible)
        try:
            result[name_col] = pd.to_numeric(result[name_col])
        except ValueError:
            # if it fails, just keep it as a string
            pass
    if split_col is not None:
        # expand "k1=v1-k2=v2-..." entries into one column per key
        new_frame = result[split_col].apply(str2record)
        result = pd.concat([result.drop(columns=[split_col]), new_frame], axis=1)
    if in_percent:
        float_columns = result.select_dtypes(include=["float64", "float32"]).columns
        result[float_columns] = result[float_columns] * 100
    if round_precision is not None:
        # round all columns to the given precision
        result = result.round(round_precision)
    if common_prefix_separator is not None:
        # remove common prefix from values in all string columns
        obj_columns = result.select_dtypes(include=["object"]).columns
        for obj_col in obj_columns:
            # get the common prefix of all non-nan values
            common_prefix = os.path.commonprefix(result[obj_col].dropna().astype(str).tolist())
            # find last occurrence of the common_prefix_separator
            last_occurrence = common_prefix.rfind(common_prefix_separator)
            if last_occurrence != -1:
                # truncate the common prefix after the last occurrence of the separator
                common_prefix = common_prefix[: last_occurrence + len(common_prefix_separator)]
            # remove the common prefix (including the separator) from the column;
            # escape it so data containing regex metacharacters (".", "(", ...)
            # is still matched literally
            result[obj_col] = result[obj_col].str.replace(
                r"^" + re.escape(common_prefix), "", regex=True
            )
    # sort columns to get a deterministic order
    result = result.sort_index(axis=1)
    if tail_cols is not None:
        front_cols = [c for c in result.columns if c not in tail_cols]
        result = result[front_cols + tail_cols]
    if sort_cols is not None:
        result = result.sort_values(sort_cols)
        # also move the sort columns to the front
        result = result[sort_cols + [c for c in result.columns if c not in sort_cols]]
    if column_regex_blacklist is not None:
        # remove columns that match any of the regex patterns in the blacklist
        for pattern in column_regex_blacklist:
            result = result.loc[:, ~result.columns.str.contains(pattern, regex=True)]
    if column_regex_whitelist is not None:
        # keep only columns that match any of the regex patterns in the whitelist
        result = result.loc[
            :, result.columns.str.contains("|".join(column_regex_whitelist), regex=True)
        ]
    if replace_in_col_names is not None:
        for old_value, new_value in replace_in_col_names:
            result.columns = result.columns.str.replace(old_value, new_value, regex=False)
    if format == "markdown":
        result_str = result.to_markdown(index=False)
    elif format == "csv":
        result_str = result.to_csv(index=False)
    elif format == "tsv":
        result_str = result.to_csv(index=False, sep="\t")
    elif format == "json":
        result_str = result.to_json(orient="records", lines=True)
    else:
        # "tsv" was missing from the original message although it is supported
        raise ValueError(
            f"Unsupported format: {format}. Supported formats are: markdown, csv, tsv, json."
        )
    print(result_str)
if __name__ == "__main__":
    """
    Example usage:
    python src/analysis/format_metric_results.py \
        logs/document_evaluation/multiruns/default/2025-05-21_11-59-19/job_return_value.json \
        --remove-col-prefix train/ \
        --sparse-col-prefix f1- \
        --split-col job_id \
        --tail-cols num_positives num_total \
        --sort-cols experiment model \
        --round-precision 4
    """
    parser = argparse.ArgumentParser(
        description="Process a JSON file containing metric results (from multirun) and print as Markdown table."
    )
    parser.add_argument(
        "path",
        type=str,
        help="Path to the JSON file to process. The JSON file is expected to contain "
        "a (maybe nested) dictionary where each leaf entry is a list of values with "
        "the same length.",
    )
    parser.add_argument(
        "--remove-col-prefix",
        type=str,
        default=None,
        help="Prefix to remove from column names.",
    )
    parser.add_argument(
        "--sparse-col-prefix",
        type=str,
        default=None,
        help="Prefix of sparse columns. All sparse columns will be melted into "
        "two columns: <prefix>name and <prefix>value. The name column will "
        "be converted to numeric if possible.",
    )
    parser.add_argument(
        "--split-col",
        type=str,
        default=None,
        help="Column to split into multiple columns. The format of the "
        "column entries is expected to be: <key_1>=<value_a>-<key_2>=<value_b>-...",
    )
    parser.add_argument(
        "--tail-cols",
        type=str,
        nargs="+",
        default=None,
        help="Columns to move to the end.",
    )
    parser.add_argument(
        "--sort-cols",
        type=str,
        nargs="+",
        default=None,
        help="Columns to sort by (they will be moved to the front).",
    )
    parser.add_argument(
        "--replace-in-col-names",
        type=lambda s: s.split(":", 1),
        nargs="+",
        default=None,
        help='List of strings in the format "<old_value>:<new_value>" to replace substrings in column names.',
    )
    parser.add_argument(
        "--round-precision",
        type=int,
        default=None,
        help="Number of decimal places to round to.",
    )
    parser.add_argument(
        "--in-percent",
        action="store_true",
        default=False,
        help="If set, all float columns will be multiplied by 100 to convert them to percentages.",
    )
    parser.add_argument(
        "--common-prefix-separator",
        type=str,
        default=None,
        help="For all string columns, remove the common prefix up to the last occurrence of this separator.",
    )
    parser.add_argument(
        "--column-regex-blacklist",
        type=str,
        nargs="+",
        default=None,
        help="List of regex patterns to match column names. "
        "Columns that match any of the patterns will be removed.",
    )
    parser.add_argument(
        "--column-regex-whitelist",
        type=str,
        nargs="+",
        default=None,
        help="List of regex patterns to match column names. "
        "Only columns that match any of the patterns will be kept.",
    )
    parser.add_argument(
        "--format",
        type=str,
        default="markdown",
        choices=["markdown", "csv", "tsv", "json"],
        # help text previously omitted "tsv" although it is an accepted choice
        help="Format to print the result in. Supported formats are: markdown, csv, tsv, json.",
    )
    kwargs = vars(parser.parse_args())
    main(**kwargs)