#!/usr/bin/env python
import argparse
import json
import os
import re
from pathlib import Path

import pandas as pd
from pie_modules.utils import flatten_dict


def str2record(s: str | None, sep_parts: str = "-", sep_k_v: str = "=") -> pd.Series:
    """Parse a string like "<key_1>=<value_a>-<key_2>=<value_b>" into a pandas Series."""
    if s is None or s.strip() == "" or s == "None":
        return pd.Series(dtype=object)
    return pd.Series(dict(k_v.split(sep_k_v, 1) for k_v in s.split(sep_parts)))
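
# Example (hypothetical values): str2record("experiment=ner-model=bert") yields a
# Series with index ["experiment", "model"] and values ["ner", "bert"].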


def separate_path_and_id(path_and_maybe_id: str, separator: str = ":") -> tuple[str | None, str]:
    """Split an entry of the form "<id>:<path>" into (id, path); if no separator is
    present, treat the whole entry as a path and return (None, path)."""
    parts = path_and_maybe_id.split(separator, 1)
    if len(parts) == 1:
        return None, parts[0]
    return parts[0], parts[1]
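
# Example (hypothetical values): separate_path_and_id("run-a:logs/2025") returns
# ("run-a", "logs/2025"), while separate_path_and_id("logs/2025") returns (None, "logs/2025").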


def load_data_from_json(path: str | Path) -> pd.DataFrame:
    """Load a (possibly nested) JSON dict whose leaf entries are equal-length lists
    and flatten it into a DataFrame with one column per leaf."""
    with open(path, "r") as f:
        data_json = json.load(f)
    data_flat = flatten_dict(data_json)
    return pd.DataFrame(data_flat)


def main(
    path: str | Path,
    remove_col_prefix: str | None = None,
    sparse_col_prefix: str | None = None,
    tail_cols: list[str] | None = None,
    sort_cols: list[str] | None = None,
    split_col: str | None = None,
    replace_in_col_names: list[tuple[str, str]] | None = None,
    round_precision: int | None = None,
    in_percent: bool = False,
    common_prefix_separator: str | None = None,
    column_regex_blacklist: list[str] | None = None,
    column_regex_whitelist: list[str] | None = None,
    format: str = "markdown",
) -> None:
    if str(path).lower().endswith(".json"):
        result = load_data_from_json(path)
    elif str(path).lower().endswith(".txt"):
        # each non-empty line is expected to be "<id>:<path>" or just "<path>", where
        # <path> points to a directory containing a job_return_value.json
        with open(path, "r") as f:
            index_data = [separate_path_and_id(line.strip()) for line in f if line.strip()]
        data_list = []
        for meta_id, meta_path in index_data:
            data = load_data_from_json(os.path.join(meta_path, "job_return_value.json"))
            if meta_id is not None:
                job_id_prefix = meta_id.replace(",", "-")
                data["job_id"] = job_id_prefix + "-" + data["job_id"].astype(str)
            data = data.set_index("job_id")
            data_list.append(data)
        # concatenate the per-file frames side by side, aligned on the job_id index
        result = pd.concat(data_list, axis=1).reset_index()
    else:
        raise ValueError("Unsupported file format. Please provide a .json or .txt file.")
    if remove_col_prefix is not None:
        # escape the prefix since it is plain text, not a regex pattern
        result.columns = result.columns.str.replace(
            r"^" + re.escape(remove_col_prefix), "", regex=True
        )
    if sparse_col_prefix is not None:
        # melt all columns that share the sparse prefix into two columns:
        # <prefix>name (the former column name) and <prefix>value (the cell value)
        sparse_cols = [col for col in result.columns if col.startswith(sparse_col_prefix)]
        other_cols = [col for col in result.columns if col not in sparse_cols]
        value_col = f"{sparse_col_prefix}value"
        name_col = f"{sparse_col_prefix}name"
        result = result.melt(
            id_vars=other_cols, value_vars=sparse_cols, var_name=name_col, value_name=value_col
        ).dropna(subset=[value_col])  # keep only rows that actually have a value
        # strip the sparse prefix from the former column names (e.g. "f1-0.9" -> "0.9")
        result[name_col] = result[name_col].str.replace(
            r"^" + re.escape(sparse_col_prefix), "", regex=True
        )
        # convert the name column to numeric (if possible)
        try:
            result[name_col] = pd.to_numeric(result[name_col])
        except ValueError:
            # if conversion fails, keep the names as strings
            pass
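    # For illustration (hypothetical columns): with --sparse-col-prefix f1-, the columns
    # f1-0.9 and f1-0.95 are melted into rows whose f1-name entry is 0.9 or 0.95 and
    # whose f1-value entry holds the corresponding score; all other columns act as id_vars.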
    if split_col is not None:
        new_frame = result[split_col].apply(str2record)
        result = pd.concat([result.drop(columns=[split_col]), new_frame], axis=1)
    if in_percent:
        float_columns = result.select_dtypes(include=["float64", "float32"]).columns
        result[float_columns] = result[float_columns] * 100
    if round_precision is not None:
        # round all columns to the given precision
        result = result.round(round_precision)
    if common_prefix_separator is not None:
        # remove the common prefix from the values in all string columns
        obj_columns = result.select_dtypes(include=["object"]).columns
        for obj_col in obj_columns:
            # get the common prefix of all (non-nan) values in the column
            common_prefix = os.path.commonprefix(result[obj_col].dropna().astype(str).tolist())
            # find the last occurrence of the common_prefix_separator
            last_occurrence = common_prefix.rfind(common_prefix_separator)
            if last_occurrence != -1:
                # truncate the common prefix after the last occurrence of the separator
                common_prefix = common_prefix[: last_occurrence + len(common_prefix_separator)]
            # remove the common prefix (including the separator) from the column; escape it
            # because it is plain text, not a regex pattern
            result[obj_col] = result[obj_col].str.replace(
                r"^" + re.escape(common_prefix), "", regex=True
            )
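    # For illustration (hypothetical values): with --common-prefix-separator "/", the
    # values "logs/exp1/bert" and "logs/exp1/roberta" share the common prefix
    # "logs/exp1/", so they are shortened to "bert" and "roberta".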
    # sort columns to get a deterministic order
    result = result.sort_index(axis=1)
    if tail_cols is not None:
        front_cols = [c for c in result.columns if c not in tail_cols]
        result = result[front_cols + tail_cols]
    if sort_cols is not None:
        result = result.sort_values(sort_cols)
        # also move the sort columns to the front
        result = result[sort_cols + [c for c in result.columns if c not in sort_cols]]
    if column_regex_blacklist is not None:
        # remove columns that match any of the regex patterns in the blacklist
        for pattern in column_regex_blacklist:
            result = result.loc[:, ~result.columns.str.contains(pattern, regex=True)]
    if column_regex_whitelist is not None:
        # keep only columns that match any of the regex patterns in the whitelist
        result = result.loc[
            :, result.columns.str.contains("|".join(column_regex_whitelist), regex=True)
        ]
    if replace_in_col_names is not None:
        for old_value, new_value in replace_in_col_names:
            result.columns = result.columns.str.replace(old_value, new_value, regex=False)
    if format == "markdown":
        result_str = result.to_markdown(index=False)
    elif format == "csv":
        result_str = result.to_csv(index=False)
    elif format == "tsv":
        result_str = result.to_csv(index=False, sep="\t")
    elif format == "json":
        result_str = result.to_json(orient="records", lines=True)
    else:
        raise ValueError(
            f"Unsupported format: {format}. Supported formats are: markdown, csv, tsv, json."
        )
    print(result_str)


if __name__ == "__main__":
    """
    Example usage:
    python src/analysis/format_metric_results.py \
        logs/document_evaluation/multiruns/default/2025-05-21_11-59-19/job_return_value.json \
        --remove-col-prefix train/ \
        --sparse-col-prefix f1- \
        --split-col job_id \
        --tail-cols num_positives num_total \
        --sort-cols experiment model \
        --round-precision 4
    """
    parser = argparse.ArgumentParser(
        description="Process a JSON file containing metric results (from a multirun) and "
        "print them in the selected format (Markdown table by default)."
    )
    parser.add_argument(
        "path",
        type=str,
        help="Path to the JSON file to process. The JSON file is expected to contain "
        "a (maybe nested) dictionary where each leaf entry is a list of values of "
        "the same length. Alternatively, a .txt index file with one <id>:<path> or "
        "<path> entry per line can be provided.",
    )
    parser.add_argument(
        "--remove-col-prefix",
        type=str,
        default=None,
        help="Prefix to remove from column names.",
    )
    parser.add_argument(
        "--sparse-col-prefix",
        type=str,
        default=None,
        help="Prefix of sparse columns. All sparse columns will be melted into "
        "two columns: <prefix>name and <prefix>value. The name column will "
        "be converted to numeric if possible.",
    )
    parser.add_argument(
        "--split-col",
        type=str,
        default=None,
        help="Column to split into multiple columns. The format of the "
        "column entries is expected to be: <key_1>=<value_a>-<key_2>=<value_b>-...",
    )
    parser.add_argument(
        "--tail-cols",
        type=str,
        nargs="+",
        default=None,
        help="Columns to move to the end.",
    )
    parser.add_argument(
        "--sort-cols",
        type=str,
        nargs="+",
        default=None,
        help="Columns to sort by (they will be moved to the front).",
    )
    parser.add_argument(
        "--replace-in-col-names",
        type=lambda s: s.split(":", 1),
        nargs="+",
        default=None,
        help='List of strings in the format "<old_value>:<new_value>" to replace substrings in column names.',
    )
    parser.add_argument(
        "--round-precision",
        type=int,
        default=None,
        help="Number of decimal places to round to.",
    )
    parser.add_argument(
        "--in-percent",
        action="store_true",
        default=False,
        help="If set, all float columns will be multiplied by 100 to convert them to percentages.",
    )
    parser.add_argument(
        "--common-prefix-separator",
        type=str,
        default=None,
        help="For all string columns, remove the common prefix up to the last occurrence of this separator.",
    )
    parser.add_argument(
        "--column-regex-blacklist",
        type=str,
        nargs="+",
        default=None,
        help="List of regex patterns to match column names. "
        "Columns that match any of the patterns will be removed.",
    )
    parser.add_argument(
        "--column-regex-whitelist",
        type=str,
        nargs="+",
        default=None,
        help="List of regex patterns to match column names. "
        "Only columns that match any of the patterns will be kept.",
    )
    parser.add_argument(
        "--format",
        type=str,
        default="markdown",
        choices=["markdown", "csv", "tsv", "json"],
        help="Format to print the result in. Supported formats are: markdown, csv, tsv, json.",
    )
    kwargs = vars(parser.parse_args())
    main(**kwargs)