update from https://github.com/ArneBinder/argumentation-structure-identification/pull/529
d868d2e
verified
import argparse | |
import json | |
from typing import Any, Dict, Iterable, List, Optional, Tuple | |
import numpy as np | |
import pandas as pd | |
import plotly.express as px | |
def get_col_name(col: str) -> str:
    """Return the last non-empty element of a stringified tuple column name.

    The first and last character of ``col`` (the tuple brackets) are dropped,
    the rest is split on ``", "`` and each element loses its first and last
    character (the quotes). Empty elements are skipped; the last remaining
    element is returned, e.g. ``"('f1', 'task=are', '')"`` -> ``"task=are"``.

    Raises:
        IndexError: if no non-empty element remains.
    """
    inner = col[1:-1]
    non_empty = []
    for quoted in inner.split(", "):
        unquoted = quoted[1:-1]
        if unquoted != "":
            non_empty.append(unquoted)
    return non_empty[-1]
def get_idx_entry(s: str, keep_only_last_part: bool = False) -> Tuple[str, str]:
    """Split a ``key=value`` string into a ``(key, value)`` pair.

    Only the first ``=`` separates key and value, so the value may contain
    ``=`` itself. With ``keep_only_last_part`` the key is reduced to its last
    dot-separated component (``"model.num_beams"`` -> ``"num_beams"``).

    Raises:
        ValueError: if ``s`` contains no ``=``.
    """
    key, value = s.split("=", maxsplit=1)
    if not keep_only_last_part:
        return key, value
    return key.rsplit(".", maxsplit=1)[-1], value
def get_idx_dict(job_id: str, keep_only_last_part: bool = False) -> Dict[str, str]:
    """Parse a job id made of ``-``-separated ``key=value`` items into a dict.

    Each item is parsed with ``get_idx_entry``; if a key occurs more than once,
    the last value wins.
    """
    parsed: Dict[str, str] = {}
    for item in job_id.split("-"):
        key, value = get_idx_entry(item, keep_only_last_part=keep_only_last_part)
        parsed[key] = value
    return parsed
def unflatten_index(
    index: Iterable[str],
    keep_only_last_part: bool = False,
    dtypes: Optional[Dict[str, Any]] = None,
) -> pd.MultiIndex:
    """Turn flat job ids into a MultiIndex with one level per parsed key.

    Every job id is parsed with ``get_idx_dict`` into one row; every key becomes
    one index level.

    Args:
        index: Job ids to parse.
        keep_only_last_part: Reduce dotted keys to their last component.
        dtypes: Optional mapping from level name to dtype. Entries for names
            that do not occur among the parsed keys are ignored.

    Returns:
        A MultiIndex built from the parsed rows after ``DataFrame.convert_dtypes``.
    """
    parsed_rows = [
        get_idx_dict(job_id, keep_only_last_part=keep_only_last_part) for job_id in index
    ]
    levels_df = pd.DataFrame.from_records(parsed_rows)
    if dtypes is not None:
        present = set(levels_df.columns)
        levels_df = levels_df.astype(
            {level: dtype for level, dtype in dtypes.items() if level in present}
        )
    return pd.MultiIndex.from_frame(levels_df.convert_dtypes())
def col_to_str(col_entries: Iterable[str], names: Iterable[Optional[str]], sep: str) -> str:
    """Render one (multi-level) column key as a single string.

    Entries of named levels are rendered as ``name=entry``; entries of unnamed
    levels (name ``None``) as the bare entry. The parts are joined with ``sep``.
    As with ``zip``, surplus entries or names are ignored.

    Example: ``col_to_str(("are", "val"), ["task", None], ",") == "task=are,val"``.
    """
    # str() so that non-string level values (e.g. ints) of unnamed levels can be
    # joined too; named entries are already stringified by the f-string
    return sep.join(
        f"{name}={col_entry}" if name is not None else str(col_entry)
        for col_entry, name in zip(col_entries, names)
    )
def flatten_index(index: pd.MultiIndex, names: Optional[List[Optional[str]]] = None) -> pd.Index:
    """Flatten a MultiIndex into a flat Index of comma-separated strings.

    Each entry is rendered with ``col_to_str(entry, names, sep=",")``: levels
    with a name become ``name=value``, unnamed levels (``None``) the bare value.

    Args:
        index: The MultiIndex to flatten.
        names: One name (or ``None``) per level. If not given (or empty), the
            index's own level names are used.

    Returns:
        A flat ``pd.Index`` of strings.

    Raises:
        ValueError: if the number of names does not match the number of levels.
    """
    level_names = list(names) if names else list(index.names)
    # `index.names` is never None (unnamed levels show up as None entries), so the
    # meaningful check is one name per level; otherwise `zip` inside `col_to_str`
    # would silently drop trailing levels.
    if len(level_names) != index.nlevels:
        raise ValueError(
            f"names ({len(level_names)}) must have one entry per index level ({index.nlevels})"
        )
    return pd.Index([col_to_str(entry, names=level_names, sep=",") for entry in index])
def prepare_quality_and_throughput_dfs(
    metric_data_path: str,
    job_return_value_path: str,
    char_total: int,
    index_dtypes: Optional[Dict[str, Any]] = None,
    job_id_prefix: Optional[str] = None,
) -> Tuple[pd.DataFrame, pd.Series]:
    """Load F1 scores and prediction times of one multirun and align them per job.

    Args:
        metric_data_path: JSON file with the metric data. Its column names are
            stringified tuples (parsed with ``get_col_name``). Each cell of the
            ``"f1"`` column holds one list of scores with one entry per job; all
            other columns identify the metric setup.
        job_return_value_path: JSON file with (at least) the keys ``"job_id"`` and
            ``"prediction_time"``, each holding one entry per job.
        char_total: Number of characters used to convert prediction time into
            throughput (``char_total / prediction_time``, reported in 1k chars/s).
        index_dtypes: Optional dtypes for the index levels parsed from the job ids.
        job_id_prefix: Optional ``key=value`` item(s) prepended to every job id,
            e.g. to tell several multiruns apart once they are concatenated. Must
            not contain ``-``, which separates job-id items.

    Returns:
        ``(f1_df, k_chars_per_s)``: ``f1_df`` has one row per job and one column per
        metric setup; ``k_chars_per_s`` (named ``"1k_chars_per_s"``) holds the
        throughput per job. Both share the MultiIndex built from the job ids.

    Raises:
        ValueError: if the number of F1 scores per setup differs from the number
            of jobs.
    """
    with open(metric_data_path, encoding="utf-8") as f:
        data = json.load(f)
    df = pd.DataFrame.from_dict(data)
    df.columns = [get_col_name(col) for col in df.columns]
    # expand each list of per-job scores into columns and transpose, so that
    # rows are jobs and columns are metric setups
    f1_series = df.set_index([col for col in df.columns if col != "f1"])["f1"]
    f1_df = f1_series.apply(lambda x: pd.Series(x)).T

    with open(job_return_value_path, encoding="utf-8") as f:
        job_return_value = json.load(f)
    job_ids = job_return_value["job_id"]
    if job_id_prefix is not None:
        # job ids are split into items on "-" (see get_idx_dict), so the prefix must be
        # joined with that same separator to become an item of its own
        job_ids = [
            f"{job_id_prefix}-{job_id}" if job_id.strip() != "" else job_id_prefix
            for job_id in job_ids
        ]
    index = unflatten_index(
        job_ids,
        keep_only_last_part=True,
        dtypes=index_dtypes,
    )
    prediction_time_series = pd.Series(
        job_return_value["prediction_time"], index=index, name="prediction_time"
    )
    if len(f1_df) != len(prediction_time_series):
        raise ValueError(
            f"number of F1 scores per setup ({len(f1_df)}) in {metric_data_path} does not "
            f"match the number of jobs ({len(prediction_time_series)}) in {job_return_value_path}"
        )
    f1_df.index = prediction_time_series.index
    # chars per second, divided by 1000 -> thousands of chars per second
    k_chars_per_s = char_total / (prediction_time_series * 1000)
    k_chars_per_s.name = "1k_chars_per_s"
    return f1_df, k_chars_per_s
def get_pareto_front_mask(df: pd.DataFrame, x_col: str, y_col: str) -> pd.Series:
    """
    Return a boolean mask marking the rows that lie on the Pareto front.

    Both ``x_col`` and ``y_col`` are maximized. A point A dominates point B if
    A is at least as large as B in both objectives and strictly larger in at
    least one of them; a row is on the front if no other row dominates it.
    Identical points therefore never dominate each other. Comparisons involving
    NaN are false, so a NaN point neither dominates nor is dominated.

    Parameters
    ----------
    df : pd.DataFrame
        DataFrame containing the data points.
    x_col : str
        Name of the first objective column (maximize).
    y_col : str
        Name of the second objective column (maximize).

    Returns
    -------
    pd.Series
        Boolean Series aligned with ``df.index``; True means the row is on the front.
    """
    points = df[[x_col, y_col]].to_numpy()

    def _dominates(a, b) -> bool:
        # a dominates b: no worse in both objectives, strictly better in one
        return a[0] >= b[0] and a[1] >= b[1] and (a[0] > b[0] or a[1] > b[1])

    on_front = []
    for i, candidate in enumerate(points):
        dominated = any(
            _dominates(other, candidate) for j, other in enumerate(points) if j != i
        )
        on_front.append(not dominated)
    return pd.Series(on_front, index=df.index, dtype=bool)
def _parse_index_key_value(k_v: str, index_dtypes: Dict[str, Any]) -> Tuple[str, Any]:
    """Split a ``key=value`` index filter; cast the value with ``index_dtypes[key]`` if known.

    Only the first ``=`` separates key and value (consistent with ``get_idx_entry``).
    """
    k, v = k_v.split("=", 1)
    if k in index_dtypes:
        return k, index_dtypes[k](v)
    return k, v


def main(
    job_return_value_path_test: List[str],
    job_return_value_path_val: List[str],
    metric_data_path_test: List[str],
    metric_data_path_val: List[str],
    char_total_test: int,
    char_total_val: int,
    job_id_prefixes: Optional[List[str]] = None,
    metric_filters: Optional[List[str]] = None,
    index_filters: Optional[List[str]] = None,
    index_blacklist: Optional[List[str]] = None,
    label_mapping: Optional[Dict[str, str]] = None,
    plot_method: str = "line",  # can be "scatter" or "line"
    pareto_front: bool = False,
    show_as: str = "figure",
    columns: Optional[List[str]] = None,
    color_column: Optional[str] = None,
):
    """Show prediction quality (F1) against throughput for inference parameter sweeps.

    For each split ("test" and "val"), the i-th metric data file is paired with the
    i-th job-return-value file and, if given, the i-th entry of ``job_id_prefixes``.
    The data of both splits is combined so that every (metric setup, split) pair
    becomes one column, indexed by the parameters parsed from the job ids plus the
    throughput level ``"1k_chars_per_s"``.

    Args:
        job_return_value_path_test: Job-return-value JSON files for the test split.
        job_return_value_path_val: Job-return-value JSON files for the val split.
        metric_data_path_test: Metric data JSON files for the test split.
        metric_data_path_val: Metric data JSON files for the val split.
        char_total_test: Character count used to compute test throughput.
        char_total_val: Character count used to compute val throughput.
        job_id_prefixes: Optional prefix per run (same for both splits), prepended
            to its job ids; ``""`` means no prefix.
        metric_filters: Substrings; only columns containing each of them are kept,
            and the substrings are removed from the column names.
        index_filters: ``key=value`` entries; only rows with that index value are
            kept (and the level is dropped from the index).
        index_blacklist: ``key=value`` entries; rows with that index value are dropped.
        label_mapping: Display names used in the legend title and plot title.
        plot_method: Name of the plotly express function to use ("line" or "scatter").
        pareto_front: Keep only the Pareto-optimal points (max. throughput and F1)
            per column.
        show_as: "figure" (plotly), "markdown" or "json" (printed to stdout).
        columns: Optional subset of columns to show.
        color_column: Optional column used for colour coding in the figure.

    Raises:
        ValueError: on mismatching numbers of input files / prefixes or an unknown
            ``show_as`` value.
    """
    label_mapping = label_mapping or {}
    # dtypes for index levels parsed from the job ids (other levels are not cast explicitly)
    index_dtypes: Dict[str, Any] = {
        "max_argument_distance": int,
        "max_length": int,
        "num_beams": int,
    }

    # combine input data for test and val
    char_total = {"test": char_total_test, "val": char_total_val}
    metric_data_path = {"test": metric_data_path_test, "val": metric_data_path_val}
    job_return_value_path = {"test": job_return_value_path_test, "val": job_return_value_path_val}

    if job_id_prefixes is not None:
        # prefixes are assigned by position, so every split needs exactly one per run
        for split, paths in job_return_value_path.items():
            if len(job_id_prefixes) != len(paths):
                raise ValueError(
                    f"job_id_prefixes ({len(job_id_prefixes)}) and "
                    f"job_return_value_path_{split} ({len(paths)}) "
                    f"must have the same length"
                )

    # prepare dataframes
    f1_df_list: Dict[str, List[pd.DataFrame]] = {"test": [], "val": []}
    k_chars_per_s_list: Dict[str, List[pd.Series]] = {"test": [], "val": []}
    for split in metric_data_path:
        if len(metric_data_path[split]) != len(job_return_value_path[split]):
            raise ValueError(
                f"metric_data_path[{split}] ({len(metric_data_path[split])}) and "
                f"job_return_value_path[{split}] ({len(job_return_value_path[split])}) "
                f"must have the same length"
            )
        prefixes: List[Optional[str]]
        if job_id_prefixes is not None:
            # replace empty strings with None (i.e. no prefix)
            prefixes = [prefix if prefix != "" else None for prefix in job_id_prefixes]
        else:
            # one entry per run of *this* split, so zip below never drops runs
            prefixes = [None] * len(metric_data_path[split])
        for current_metric_data_path, current_job_return_value_path, job_id_prefix in zip(
            metric_data_path[split], job_return_value_path[split], prefixes
        ):
            current_f1_df, current_k_chars_per_s = prepare_quality_and_throughput_dfs(
                current_metric_data_path,
                current_job_return_value_path,
                char_total=char_total[split],
                job_id_prefix=job_id_prefix,
                index_dtypes=index_dtypes,
            )
            f1_df_list[split].append(current_f1_df)
            k_chars_per_s_list[split].append(current_k_chars_per_s)
    f1_df_dict = {split: pd.concat(dfs, axis=0) for split, dfs in f1_df_list.items()}
    k_chars_per_s_dict = {
        split: pd.concat(series, axis=0) for split, series in k_chars_per_s_list.items()
    }

    # combine dataframes for test and val
    f1_df = pd.concat(f1_df_dict, names=["split"] + f1_df_dict["test"].index.names)
    # wrap entries of a flat (single-level) column index in a tuple so that col_to_str
    # does not iterate over the characters of the column name
    f1_df.columns = [
        col_to_str(
            col if isinstance(col, tuple) else (col,), names=f1_df.columns.names, sep=","
        )
        for col in f1_df.columns
    ]
    k_chars_per_s = pd.concat(
        k_chars_per_s_dict,
        names=["split"] + k_chars_per_s_dict["test"].index.names,
    )

    # combine quality and throughput data: throughput becomes an index level (the x
    # values) and the split is moved from the index into the columns
    df_plot = pd.concat([f1_df, k_chars_per_s], axis=1)
    df_plot = (
        df_plot.reset_index()
        .set_index(list(f1_df.index.names) + [k_chars_per_s.name])
        .unstack("split")
    )
    df_plot.columns = flatten_index(df_plot.columns, names=[None, "split"])

    # remove all columns that are not needed
    if metric_filters is not None:
        for fil in metric_filters:
            df_plot.drop(columns=[col for col in df_plot.columns if fil not in col], inplace=True)
            df_plot.columns = [col.replace(fil, "") for col in df_plot.columns]
    # tidy the column names: drop empty comma-separated parts (e.g. left by the filters)
    df_plot.columns = [
        ",".join([part for part in col.split(",") if part != ""]) for col in df_plot.columns
    ]

    if index_filters is not None:
        for k_v in index_filters:
            k, v = _parse_index_key_value(k_v, index_dtypes)
            df_plot = df_plot.xs(v, level=k, axis=0)
    if index_blacklist is not None:
        for k_v in index_blacklist:
            k, v = _parse_index_key_value(k_v, index_dtypes)
            df_plot = df_plot.drop(v, level=k, axis=0)

    if columns is not None:
        df_plot = df_plot[columns]

    x = "1k_chars_per_s"
    y = df_plot.columns
    if pareto_front:
        for col in y:
            # compute the front over the non-missing points of this column and set all
            # points that are not on it to NaN
            current_data = df_plot[col].dropna().reset_index(x).copy()
            pareto_front_mask = get_pareto_front_mask(current_data, x_col=x, y_col=col)
            current_data.loc[~pareto_front_mask, col] = np.nan
            current_data_reset = current_data.reset_index().set_index(df_plot.index.names)
            df_plot[col] = current_data_reset[col]
    # remove nan rows
    df_plot = df_plot.dropna(how="all")

    # plot
    # Create a custom color sequence (concatenating multiple palettes if needed)
    custom_colors = px.colors.qualitative.Dark24 + px.colors.qualitative.Light24
    # all index levels except the x values are used as point labels / hover data
    text_cols = list(df_plot.index.names)
    text_cols.remove(x)
    df_plot_reset = df_plot.reset_index()
    if len(text_cols) > 1:
        df_plot_reset[",".join(text_cols)] = (
            df_plot_reset[text_cols].astype(str).agg(", ".join, axis=1)
        )
    # None if no label levels are left (e.g. all were removed by index filters)
    text_col = ",".join(text_cols) if text_cols else None
    if show_as == "figure":
        _plot_method = getattr(px, plot_method)
        df_plot_sorted = df_plot_reset.sort_values(by=x)
        fig = _plot_method(
            df_plot_sorted,
            x=x,
            y=y,
            text=text_col if plot_method != "scatter" else None,
            color=color_column,
            color_discrete_sequence=custom_colors,
            hover_data=text_cols,
        )
        # connect the lines across missing values (e.g. points removed by the Pareto filter)
        fig.update_traces(connectgaps=True)
        legend_title = "Evaluation Setup"
        if metric_filters:
            whitelist_filters_mapped = [label_mapping.get(fil, fil) for fil in metric_filters]
            legend_title += f" ({', '.join(whitelist_filters_mapped)})"
        text_cols_mapped = [label_mapping.get(col, col) for col in text_cols]
        title = f"Impact of {', '.join(text_cols_mapped)} on Prediction Quality and Throughput"
        if index_filters:
            index_filters_mapped = [label_mapping.get(fil, fil) for fil in index_filters]
            title += f" ({', '.join(index_filters_mapped)})"
        if pareto_front:
            title += " (Pareto Front)"
        fig.update_layout(
            xaxis_title="Throughput (1k chars/s)",
            yaxis_title="Quality (F1)",
            title=title,
            # horizontal title position as a fraction of the width (0.5 would center it)
            title_x=0.2,
            # black title
            title_font=dict(color="black"),
            # change legend title
            legend_title=legend_title,
            font_family="Computer Modern",
            # white background
            plot_bgcolor="white",
            paper_bgcolor="white",
        )
        update_axes_kwargs = dict(
            tickfont=dict(color="black"),
            title_font=dict(color="black"),
            ticks="inside",  # ensure tick markers are drawn
            tickcolor="black",
            tickwidth=1,
            ticklen=10,
            linecolor="black",
            # show grid
            gridcolor="lightgray",
        )
        fig.update_yaxes(**update_axes_kwargs)
        fig.update_xaxes(**update_axes_kwargs)
        fig.show()
    elif show_as == "markdown":
        # Print the DataFrame as a Markdown table
        print(df_plot_reset.to_markdown(index=False, floatfmt=".4f"))
    elif show_as == "json":
        # Print the DataFrame as a JSON object
        print(df_plot_reset.to_json(orient="columns", indent=4))
    else:
        raise ValueError(f"Unknown show_as value: {show_as}. Use 'figure', 'markdown' or 'json'.")
if __name__ == "__main__":
    # Example usage 1 (pipeline model, data from data source:
    # https://github.com/ArneBinder/pie-document-level/issues/388#issuecomment-2752829257):
    #
    # python src/analysis/show_inference_params_on_quality_and_throughput.py \
    #   --job-return-value-path-test logs/prediction/multiruns/default/2025-03-26_01-31-05/job_return_value.json \
    #   --job-return-value-path-val logs/prediction/multiruns/default/2025-03-26_16-49-36/job_return_value.json \
    #   --metric-data-path-test data/evaluation/argumentation_structure/inference_pipeline_test.json \
    #   --metric-data-path-val data/evaluation/argumentation_structure/inference_pipeline_validation.json \
    #   --metric-filters task=are discont_comp=true split=val
    #
    # Example usage 2 (joint model, data from:
    # https://github.com/ArneBinder/pie-document-level/issues/390#issuecomment-2759888004):
    #
    # python src/analysis/show_inference_params_on_quality_and_throughput.py \
    #   --job-return-value-path-test logs/prediction/multiruns/default/2025-03-28_01-34-07/job_return_value.json \
    #   --job-return-value-path-val logs/prediction/multiruns/default/2025-03-28_02-57-00/job_return_value.json \
    #   --metric-data-path-test data/evaluation/argumentation_structure/inference_joint_test.json \
    #   --metric-data-path-val data/evaluation/argumentation_structure/inference_joint_validation.json \
    #   --metric-filters task=are discont_comp=true split=val \
    #   --plot-method scatter
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--job-return-value-path-test",
        type=str,
        nargs="+",
        required=True,
    )
    parser.add_argument(
        "--job-return-value-path-val",
        type=str,
        nargs="+",
        required=True,
    )
    parser.add_argument(
        "--metric-data-path-test",
        type=str,
        nargs="+",
        required=True,
    )
    parser.add_argument(
        "--metric-data-path-val",
        type=str,
        nargs="+",
        required=True,
    )
    parser.add_argument(
        "--job-id-prefixes",
        type=str,
        nargs="*",
        default=None,
        help="Optional 'key=value' prefix per run (same order as the input files), "
        "prepended to its job ids; use '' for no prefix",
    )
    # character counts used to compute the throughput (char_total / prediction_time)
    parser.add_argument(
        "--char-total-test",
        type=int,
        default=383154,
        help="Number of characters used to compute the test throughput (default: 383154)",
    )
    parser.add_argument(
        "--char-total-val",
        type=int,
        default=182794,
        help="Number of characters used to compute the val throughput (default: 182794)",
    )
    parser.add_argument(
        "--plot-method",
        type=str,
        default="line",
        choices=["scatter", "line"],
        help="Plot method to use (default: line)",
    )
    parser.add_argument(
        "--color-column",
        type=str,
        default=None,
        help="Column to use for colour coding (default: None)",
    )
    parser.add_argument(
        "--metric-filters",
        type=str,
        nargs="*",
        default=None,
        help="Filters to apply to the metric data in the format 'key=value'",
    )
    parser.add_argument(
        "--index-filters",
        type=str,
        nargs="*",
        default=None,
        help="Filters to apply to the index data in the format 'key=value'",
    )
    parser.add_argument(
        "--index-blacklist",
        type=str,
        nargs="*",
        default=None,
        help="Blacklist to apply to the index data in the format 'key=value'",
    )
    parser.add_argument(
        "--columns",
        type=str,
        nargs="*",
        default=None,
        help="Columns to plot (default: all)",
    )
    parser.add_argument(
        "--pareto-front",
        action="store_true",
        help="Whether to show only the pareto front",
    )
    parser.add_argument(
        "--show-as",
        type=str,
        default="figure",
        choices=["figure", "markdown", "json"],
        help="How to show the results (default: figure)",
    )
    # kwargs includes char_total_test / char_total_val from the options above
    kwargs = vars(parser.parse_args())
    main(
        label_mapping={
            "max_argument_distance": "Max. Argument Distance",
            "max_length": "Max. Length",
            "num_beams": "Num. Beams",
            "task=are": "ARE",
            "discont_comp=true": "Discont. Comp.",
            "split=val": "Validation Split",
        },
        **kwargs,
    )