# ScientificArgumentRecommender / src/analysis/format_metric_results.py
# Source: https://github.com/ArneBinder/argumentation-structure-identification/pull/529
# (commit d868d2e)
#!/usr/bin/env python
import argparse
import json
import os
import re
from pathlib import Path

import pandas as pd

from pie_modules.utils import flatten_dict
def str2record(s: str | None, sep_parts: str = "-", sep_k_v: str = "=") -> pd.Series:
if s is None or s.strip() == "" or s == "None":
return pd.Series()
return pd.Series(dict(k_v.split(sep_k_v, 1) for k_v in s.split(sep_parts)))
def separate_path_and_id(path_and_maybe_id: str, separator: str = ":") -> tuple[str | None, str]:
parts = path_and_maybe_id.split(separator, 1)
if len(parts) == 1:
return None, parts[0]
return parts[0], parts[1]
def load_data_from_json(path: str | Path) -> pd.DataFrame:
    """Load a (possibly nested) JSON dict of equal-length value lists as a DataFrame.

    Nested keys are flattened via ``flatten_dict`` before constructing the frame,
    so each (flattened) leaf entry becomes one column.
    """
    with open(path, "r") as f:
        raw = json.load(f)
    return pd.DataFrame(flatten_dict(raw))
def main(
path: str | Path,
remove_col_prefix: str | None = None,
sparse_col_prefix: str | None = None,
tail_cols: list[str] | None = None,
sort_cols: list[str] | None = None,
split_col: str | None = None,
replace_in_col_names: list[tuple[str, str]] | None = None,
round_precision: int | None = None,
in_percent: bool = False,
common_prefix_separator: str | None = None,
column_regex_blacklist: list[str] | None = None,
column_regex_whitelist: list[str] | None = None,
format: str = "markdown",
) -> None:
if str(path).lower().endswith(".json"):
result = load_data_from_json(path)
elif str(path).lower().endswith(".txt"):
with open(path, "r") as f:
index_data = [separate_path_and_id(line.strip()) for line in f.readlines()]
data_list = []
for meta_id, meta_path in index_data:
data = load_data_from_json(os.path.join(meta_path, "job_return_value.json"))
if meta_id is not None:
job_id_prefix = meta_id.replace(",", "-")
data["job_id"] = job_id_prefix + "-" + data["job_id"].astype(str)
data = data.set_index("job_id")
data_list.append(data)
result = pd.concat(data_list, axis=1).reset_index()
else:
raise ValueError("Unsupported file format. Please provide a .json or .txt file.")
if remove_col_prefix is not None:
result.columns = result.columns.str.replace(r"^" + remove_col_prefix, "", regex=True)
if sparse_col_prefix is not None:
# get all columns that contain just one not-nan value
# number_of_non_nan_values = len(df) - df.isna().sum()
# df_sparse = df.loc[:, number_of_non_nan_values == 1]
sparse_cols = [col for col in result.columns if col.startswith(sparse_col_prefix)]
other_cols = [col for col in result.columns if col not in sparse_cols]
value_col = f"{sparse_col_prefix}value"
name_col = f"{sparse_col_prefix}name"
result = result.melt(
id_vars=other_cols, value_vars=sparse_cols, var_name=name_col, value_name=value_col
).dropna(
subset=[value_col]
) # keep rows with a value
# strip the "f1-" prefix, leaving just the numeric threshold
result[name_col] = result[name_col].str.replace(r"^" + sparse_col_prefix, "", regex=True)
# convert the column to numeric (if possible)
try:
result[name_col] = pd.to_numeric(result[name_col])
except ValueError:
# if it fails, just keep it as a string
pass
if split_col is not None:
new_frame = result[split_col].apply(str2record)
result = pd.concat([result.drop(columns=[split_col]), new_frame], axis=1)
if in_percent:
float_columns = result.select_dtypes(include=["float64", "float32"]).columns
result[float_columns] = result[float_columns] * 100
if round_precision is not None:
# round all columns to the given precision
result = result.round(round_precision)
if common_prefix_separator is not None:
# remove common prefix from values in all string columns
obj_columns = result.select_dtypes(include=["object"]).columns
for obj_col in obj_columns:
# get the common prefix
common_prefix = os.path.commonprefix(result[obj_col].dropna().astype(str).tolist())
# find last occurrence of the common_prefix_separator
last_occurrence = common_prefix.rfind(common_prefix_separator)
if last_occurrence != -1:
# truncate the common prefix after the last occurrence of the separator
common_prefix = common_prefix[: last_occurrence + len(common_prefix_separator)]
# remove the common prefix (including the separator) from the column
result[obj_col] = result[obj_col].str.replace(r"^" + common_prefix, "", regex=True)
# sort columns to get a deterministic order
result = result.sort_index(axis=1)
if tail_cols is not None:
front_cols = [c for c in result.columns if c not in tail_cols]
result = result[front_cols + tail_cols]
if sort_cols is not None:
result = result.sort_values(sort_cols)
# also move the sort columns to the front
result = result[sort_cols + [c for c in result.columns if c not in sort_cols]]
if column_regex_blacklist is not None:
# remove columns that match any of the regex patterns in the blacklist
for pattern in column_regex_blacklist:
result = result.loc[:, ~result.columns.str.contains(pattern, regex=True)]
if column_regex_whitelist is not None:
# keep only columns that match any of the regex patterns in the whitelist
result = result.loc[
:, result.columns.str.contains("|".join(column_regex_whitelist), regex=True)
]
if replace_in_col_names is not None:
for old_value, new_value in replace_in_col_names:
result.columns = result.columns.str.replace(old_value, new_value, regex=False)
if format == "markdown":
result_str = result.to_markdown(index=False)
elif format == "csv":
result_str = result.to_csv(index=False)
elif format == "tsv":
result_str = result.to_csv(index=False, sep="\t")
elif format == "json":
result_str = result.to_json(orient="records", lines=True)
else:
raise ValueError(
f"Unsupported format: {format}. Supported formats are: markdown, csv, json."
)
print(result_str)
if __name__ == "__main__":
"""
Example usage:
python src/analysis/format_metric_results.py \
logs/document_evaluation/multiruns/default/2025-05-21_11-59-19/job_return_value.json \
--remove-col-prefix train/ \
--sparse-col-prefix f1- \
--split-col job_id \
--tail-cols num_positives num_total \
--sort-cols experiment model \
--round-precision 4
"""
parser = argparse.ArgumentParser(
description="Process a JSON file containing metric results (from multirun) and print as Markdown table."
)
parser.add_argument(
"path",
type=str,
help="Path to the JSON file to process. The JSON file is expected to contain "
"a (maybe nested) dictionary where each leave entry is a list of values with "
"the same length.",
)
parser.add_argument(
"--remove-col-prefix",
type=str,
default=None,
help="Prefix to remove from column names.",
)
parser.add_argument(
"--sparse-col-prefix",
type=str,
default=None,
help="Prefix of sparse columns. All sparse columns will be melted into "
"two columns: <prefix>name and <prefix>value. The name column will "
"be converted to numeric if possible.",
)
parser.add_argument(
"--split-col",
type=str,
default=None,
help="Column to split into multiple columns. The format of the "
"column entries is expected to be: <key_1>=<value_a>-<key_2>=<value_b>-...",
)
parser.add_argument(
"--tail-cols",
type=str,
nargs="+",
default=None,
help="Columns to move to the end.",
)
parser.add_argument(
"--sort-cols",
type=str,
nargs="+",
default=None,
help="Columns to sort by (they will be moved to the front).",
)
parser.add_argument(
"--replace-in-col-names",
type=lambda s: s.split(":", 1),
nargs="+",
default=None,
help='List of strings in the format "<old_value>:<new_value>" to replace substrings in column names.',
)
parser.add_argument(
"--round-precision",
type=int,
default=None,
help="Number of decimal places to round to.",
)
parser.add_argument(
"--in-percent",
action="store_true",
default=False,
help="If set, all float columns will be multiplied by 100 to convert them to percentages.",
)
parser.add_argument(
"--common-prefix-separator",
type=str,
default=None,
help="For all string columns, remove the common prefix up to the last occurrence of this separator.",
)
parser.add_argument(
"--column-regex-blacklist",
type=str,
nargs="+",
default=None,
help="List of regex patterns to match column names. "
"Columns that match any of the patterns will be removed.",
)
parser.add_argument(
"--column-regex-whitelist",
type=str,
nargs="+",
default=None,
help="List of regex patterns to match column names. "
"Only columns that match any of the patterns will be kept.",
)
parser.add_argument(
"--format",
type=str,
default="markdown",
choices=["markdown", "csv", "tsv", "json"],
help="Format to print the result in. Supported formats are: markdown, csv, json.",
)
kwargs = vars(parser.parse_args())
main(**kwargs)