#!/usr/bin/env python
import argparse
import json
import os
import re
from pathlib import Path

import pandas as pd
from pie_modules.utils import flatten_dict


def str2record(s: str | None, sep_parts: str = "-", sep_k_v: str = "=") -> pd.Series:
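    """Parse a string like "experiment=foo-model=bar" into a pandas Series
    (here: experiment -> "foo", model -> "bar"). Keys must not contain the
    key-value separator and values must not contain the part separator.
    """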
    if s is None or s.strip() == "" or s == "None":
        return pd.Series()
    return pd.Series(dict(k_v.split(sep_k_v, 1) for k_v in s.split(sep_parts)))


def separate_path_and_id(path_and_maybe_id: str, separator: str = ":") -> tuple[str | None, str]:
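    """Split an index entry of the form "<id>:<path>" into (id, path); a bare
    "<path>" without separator yields (None, path).
    """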
    parts = path_and_maybe_id.split(separator, 1)
    if len(parts) == 1:
        return None, parts[0]
    return parts[0], parts[1]


def load_data_from_json(path: str | Path) -> pd.DataFrame:
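    """Load a JSON file that holds a (maybe nested) dictionary of equal-length
    value lists and flatten it into a DataFrame with one column per leaf entry.
    """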
    with open(path, "r") as f:
        data_json = json.load(f)
        data_flat = flatten_dict(data_json)
    return pd.DataFrame(data_flat)


def main(
    path: str | Path,
    remove_col_prefix: str | None = None,
    sparse_col_prefix: str | None = None,
    tail_cols: list[str] | None = None,
    sort_cols: list[str] | None = None,
    split_col: str | None = None,
    replace_in_col_names: list[tuple[str, str]] | None = None,
    round_precision: int | None = None,
    in_percent: bool = False,
    common_prefix_separator: str | None = None,
    column_regex_blacklist: list[str] | None = None,
    column_regex_whitelist: list[str] | None = None,
    format: str = "markdown",
) -> None:

    if str(path).lower().endswith(".json"):
        result = load_data_from_json(path)
    elif str(path).lower().endswith(".txt"):
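        # each non-empty line of the index names one run, either as
        # "<id>:<run_dir>" or just "<run_dir>"; every run_dir is expected to
        # contain a job_return_value.json file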
        with open(path, "r") as f:
            # skip blank lines (e.g. a trailing newline at the end of the file)
            index_data = [separate_path_and_id(line.strip()) for line in f if line.strip()]
        data_list = []
        for meta_id, meta_path in index_data:
            data = load_data_from_json(os.path.join(meta_path, "job_return_value.json"))
            if meta_id is not None:
                job_id_prefix = meta_id.replace(",", "-")
                data["job_id"] = job_id_prefix + "-" + data["job_id"].astype(str)
            data = data.set_index("job_id")
            data_list.append(data)
        result = pd.concat(data_list, axis=1).reset_index()
    else:
        raise ValueError("Unsupported file format. Please provide a .json or .txt file.")

    if remove_col_prefix is not None:
        result.columns = result.columns.str.replace(
            r"^" + re.escape(remove_col_prefix), "", regex=True
        )

    if sparse_col_prefix is not None:
        # melt all columns that share the prefix into two columns:
        # <prefix>name (the original column suffix) and <prefix>value
        sparse_cols = [col for col in result.columns if col.startswith(sparse_col_prefix)]
        other_cols = [col for col in result.columns if col not in sparse_cols]

        value_col = f"{sparse_col_prefix}value"
        name_col = f"{sparse_col_prefix}name"
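        # example (hypothetical thresholds): with prefix "f1-", the columns
        # f1-0.5 and f1-0.9 are melted into f1-name (0.5 / 0.9) and f1-value
        # (holding the respective metric scores)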
        result = result.melt(
            id_vars=other_cols, value_vars=sparse_cols, var_name=name_col, value_name=value_col
        ).dropna(
            subset=[value_col]
        )  # keep rows with a value
        # strip the "f1-" prefix, leaving just the numeric threshold
        result[name_col] = result[name_col].str.replace(r"^" + sparse_col_prefix, "", regex=True)
        # convert the column to numeric (if possible)
        try:
            result[name_col] = pd.to_numeric(result[name_col])
        except ValueError:
            # if it fails, just keep it as a string
            pass

    if split_col is not None:
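        # example (hypothetical keys): a job_id entry like
        # "experiment=foo-model=bar" is expanded into two new columns,
        # experiment and model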
        new_frame = result[split_col].apply(str2record)
        result = pd.concat([result.drop(columns=[split_col]), new_frame], axis=1)

    if in_percent:
        float_columns = result.select_dtypes(include=["float64", "float32"]).columns
        result[float_columns] = result[float_columns] * 100

    if round_precision is not None:
        # round all columns to the given precision
        result = result.round(round_precision)

    if common_prefix_separator is not None:
        # remove common prefix from values in all string columns
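        # example: with separator "/", the values "logs/runA/x" and
        # "logs/runA/y" share the common prefix "logs/runA/" and are
        # shortened to "x" and "y"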
        obj_columns = result.select_dtypes(include=["object"]).columns
        for obj_col in obj_columns:
            # get the common prefix
            common_prefix = os.path.commonprefix(result[obj_col].dropna().astype(str).tolist())
            # find last occurrence of the common_prefix_separator
            last_occurrence = common_prefix.rfind(common_prefix_separator)
            if last_occurrence != -1:
                # truncate the common prefix after the last occurrence of the separator
                common_prefix = common_prefix[: last_occurrence + len(common_prefix_separator)]
                # remove the common prefix (including the separator) from the column;
                # escape it, since data-derived prefixes often contain regex metacharacters
                result[obj_col] = result[obj_col].str.replace(
                    r"^" + re.escape(common_prefix), "", regex=True
                )

    # sort columns to get a deterministic order
    result = result.sort_index(axis=1)

    if tail_cols is not None:
        front_cols = [c for c in result.columns if c not in tail_cols]
        result = result[front_cols + tail_cols]

    if sort_cols is not None:
        result = result.sort_values(sort_cols)
        # also move the sort columns to the front
        result = result[sort_cols + [c for c in result.columns if c not in sort_cols]]

    if column_regex_blacklist is not None:
        # remove columns that match any of the regex patterns in the blacklist
        for pattern in column_regex_blacklist:
            result = result.loc[:, ~result.columns.str.contains(pattern, regex=True)]

    if column_regex_whitelist is not None:
        # keep only columns that match any of the regex patterns in the whitelist
        result = result.loc[
            :, result.columns.str.contains("|".join(column_regex_whitelist), regex=True)
        ]

    if replace_in_col_names is not None:
        for old_value, new_value in replace_in_col_names:
            result.columns = result.columns.str.replace(old_value, new_value, regex=False)

    if format == "markdown":
        result_str = result.to_markdown(index=False)
    elif format == "csv":
        result_str = result.to_csv(index=False)
    elif format == "tsv":
        result_str = result.to_csv(index=False, sep="\t")
    elif format == "json":
        result_str = result.to_json(orient="records", lines=True)
    else:
        raise ValueError(
            f"Unsupported format: {format}. Supported formats are: markdown, csv, json."
        )

    print(result_str)


if __name__ == "__main__":
    """
    Example usage:

    python src/analysis/format_metric_results.py \
    logs/document_evaluation/multiruns/default/2025-05-21_11-59-19/job_return_value.json \
    --remove-col-prefix train/ \
    --sparse-col-prefix f1- \
    --split-col job_id \
    --tail-cols num_positives num_total \
    --sort-cols experiment model \
    --round-precision 4
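
    To aggregate several result files, pass a .txt index file instead of the
    JSON file, with one line per run in the form "<id>:<run_dir>" (each
    run_dir containing a job_return_value.json); the id gets prepended to
    that run's job_id values.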
    """

    parser = argparse.ArgumentParser(
        description="Process metric results from a multirun (a single JSON file or a "
        ".txt index of several runs) and print them as a table in the requested format."
    )
    parser.add_argument(
        "path",
        type=str,
        help="Path to the file to process. A .json file is expected to contain "
        "a (maybe nested) dictionary where each leaf entry is a list of values "
        "with the same length. Alternatively, a .txt index file with one "
        '"<id>:<run_dir>" (or plain "<run_dir>") entry per line can be used '
        "to aggregate multiple such result files.",
    )
    parser.add_argument(
        "--remove-col-prefix",
        type=str,
        default=None,
        help="Prefix to remove from column names.",
    )
    parser.add_argument(
        "--sparse-col-prefix",
        type=str,
        default=None,
        help="Prefix of sparse columns. All sparse columns will be melted into "
        "two columns: <prefix>name and <prefix>value. The name column will "
        "be converted to numeric if possible.",
    )

    parser.add_argument(
        "--split-col",
        type=str,
        default=None,
        help="Column to split into multiple columns. The format of the "
        "column entries is expected to be: <key_1>=<value_a>-<key_2>=<value_b>-...",
    )
    parser.add_argument(
        "--tail-cols",
        type=str,
        nargs="+",
        default=None,
        help="Columns to move to the end.",
    )
    parser.add_argument(
        "--sort-cols",
        type=str,
        nargs="+",
        default=None,
        help="Columns to sort by (they will be moved to the front).",
    )
    parser.add_argument(
        "--replace-in-col-names",
        type=lambda s: s.split(":", 1),
        nargs="+",
        default=None,
        help='List of strings in the format "<old_value>:<new_value>" to replace substrings in column names.',
    )
    parser.add_argument(
        "--round-precision",
        type=int,
        default=None,
        help="Number of decimal places to round to.",
    )
    parser.add_argument(
        "--in-percent",
        action="store_true",
        default=False,
        help="If set, all float columns will be multiplied by 100 to convert them to percentages.",
    )
    parser.add_argument(
        "--common-prefix-separator",
        type=str,
        default=None,
        help="For all string columns, remove the common prefix up to the last occurrence of this separator.",
    )
    parser.add_argument(
        "--column-regex-blacklist",
        type=str,
        nargs="+",
        default=None,
        help="List of regex patterns to match column names. "
        "Columns that match any of the patterns will be removed.",
    )
    parser.add_argument(
        "--column-regex-whitelist",
        type=str,
        nargs="+",
        default=None,
        help="List of regex patterns to match column names. "
        "Only columns that match any of the patterns will be kept.",
    )
    parser.add_argument(
        "--format",
        type=str,
        default="markdown",
        choices=["markdown", "csv", "tsv", "json"],
        help="Format to print the result in. Supported formats are: markdown, csv, json.",
    )

    kwargs = vars(parser.parse_args())
    main(**kwargs)