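"""Visualize the impact of inference parameters on prediction quality (F1) and
throughput (thousands of characters per second).

The script combines metric results and job return values (prediction times)
for the test and validation splits into a single table and shows it as a
Plotly figure, a Markdown table, or JSON output.
"""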
import argparse
import json
from typing import Any, Dict, Iterable, List, Optional, Tuple
import numpy as np
import pandas as pd
import plotly.express as px
def get_col_name(col: str) -> str:
    """Extract the last non-empty entry from a stringified tuple column name."""
    parts = [part[1:-1] for part in col[1:-1].split(", ") if part[1:-1] != ""]
    return parts[-1]
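# Example (illustrative input; the metric JSON presumably stores tuple column
# names as strings): get_col_name("('f1', 'micro')") == "micro"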
def get_idx_entry(s: str, keep_only_last_part: bool = False) -> Tuple[str, str]:
    """Split a "key=value" string into a (key, value) pair, optionally reducing
    a dotted key to its last segment."""
    k, v = s.split("=", 1)
    if keep_only_last_part:
        k = k.split(".")[-1]
    return k, v
def get_idx_dict(job_id: str, keep_only_last_part: bool = False) -> Dict[str, str]:
    """Parse a job id of the form "key1=value1-key2=value2-..." into a dict."""
    return dict(
        get_idx_entry(part, keep_only_last_part=keep_only_last_part) for part in job_id.split("-")
    )
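# Example (hypothetical job id in the "key=value-key=value" format parsed above):
#   get_idx_dict("model.num_beams=4-max_length=512", keep_only_last_part=True)
#   == {"num_beams": "4", "max_length": "512"}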
def unflatten_index(
    index: Iterable[str],
    keep_only_last_part: bool = False,
    dtypes: Optional[Dict[str, Any]] = None,
) -> pd.MultiIndex:
    """Convert an iterable of "key=value-..." job ids into a MultiIndex with one
    level per key, optionally casting levels to the given dtypes."""
    as_df = pd.DataFrame.from_records(
        [get_idx_dict(idx, keep_only_last_part=keep_only_last_part) for idx in index]
    )
    if dtypes is not None:
        dtypes_valid = {col: dtype for col, dtype in dtypes.items() if col in as_df.columns}
        as_df = as_df.astype(dtypes_valid)
    return pd.MultiIndex.from_frame(as_df.convert_dtypes())
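# Example (hypothetical job ids): unflatten_index(
#     ["num_beams=1-max_length=128", "num_beams=4-max_length=512"],
#     dtypes={"num_beams": int, "max_length": int},
# ) yields a MultiIndex with the integer levels "num_beams" and "max_length".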
def col_to_str(col_entries: Iterable[str], names: Iterable[Optional[str]], sep: str) -> str:
    """Join multi-level column entries into a single string, rendering each
    entry as "name=entry" where a level name is available."""
    return sep.join(
        [
            f"{name}={col_entry}" if name is not None else col_entry
            for col_entry, name in zip(col_entries, names)
        ]
    )
def flatten_index(index: pd.MultiIndex, names: Optional[List[Optional[str]]] = None) -> pd.Index:
    """Flatten a MultiIndex into a plain Index of sep-joined strings."""
    names = names or index.names
    if names is None:
        raise ValueError("names must be provided if index has no names")
    return pd.Index([col_to_str(col, names=names, sep=",") for col in index])
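# Example: for a two-level column index with an entry ("f1", "test") and names
# [None, "split"] (the combination used below in main), flatten_index produces
# the flat entry "f1,split=test".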
def prepare_quality_and_throughput_dfs(
metric_data_path: str,
job_return_value_path: str,
char_total: int,
index_dtypes: Optional[Dict[str, Any]] = None,
job_id_prefix: Optional[str] = None,
) -> Tuple[pd.DataFrame, pd.Series]:
with open(metric_data_path) as f:
data = json.load(f)
    # "data" holds the output of the metric evaluation command (use only the
    # last of the output lines!)
    df = pd.DataFrame.from_dict(data)
    # column names arrive as stringified tuples; keep only their last entry
    df.columns = [get_col_name(col) for col in df.columns]
    # index by all non-f1 columns and expand the f1 entries (one per job) into rows
    f1_series = df.set_index([col for col in df.columns if col != "f1"])["f1"]
    f1_df = f1_series.apply(lambda x: pd.Series(x)).T
with open(job_return_value_path) as f:
job_return_value = json.load(f)
job_ids = job_return_value["job_id"]
    if job_id_prefix is not None:
        # prepend the prefix as an additional "key=value" part; job ids use "-"
        # as the part separator (see get_idx_dict)
        job_ids = [
            f"{job_id_prefix}-{job_id}" if job_id.strip() != "" else job_id_prefix
            for job_id in job_ids
        ]
index = unflatten_index(
job_ids,
keep_only_last_part=True,
dtypes=index_dtypes,
)
prediction_time_series = pd.Series(
job_return_value["prediction_time"], index=index, name="prediction_time"
)
f1_df.index = prediction_time_series.index
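    # throughput in thousands of characters per second: char_total characters
    # predicted in prediction_time seconds, e.g. 383,154 chars in 60 s
    # -> 383154 / (60 * 1000) ~= 6.4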
k_chars_per_s = char_total / (prediction_time_series * 1000)
k_chars_per_s.name = "1k_chars_per_s"
return f1_df, k_chars_per_s
def get_pareto_front_mask(df: pd.DataFrame, x_col: str, y_col: str) -> pd.Series:
"""
Return a boolean mask indicating which rows belong to the Pareto front.
In this version, we assume you want to maximize both x_col and y_col.
A point A is said to dominate point B if:
A[x_col] >= B[x_col] AND
A[y_col] >= B[y_col] AND
at least one is strictly greater.
Then B is not on the Pareto front.
Parameters
----------
df : pd.DataFrame
DataFrame containing the data points.
x_col : str
Name of the column to treat as the first objective (maximize).
y_col : str
Name of the column to treat as the second objective (maximize).
Returns
-------
pd.Series
A boolean Series (aligned with df.index) where True means
the row is on the Pareto front.
"""
# Extract the relevant columns as a NumPy array for speed.
data = df[[x_col, y_col]].values
n = len(data)
is_dominated = np.zeros(n, dtype=bool)
for i in range(n):
# If it's already marked dominated, skip checks
if is_dominated[i]:
continue
for j in range(n):
if i == j:
continue
# Check if j dominates i
if (
data[j, 0] >= data[i, 0]
and data[j, 1] >= data[i, 1]
and (data[j, 0] > data[i, 0] or data[j, 1] > data[i, 1])
):
is_dominated[i] = True
break
# Return True for points not dominated by any other
return pd.Series(~is_dominated, index=df.index)
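# Example (toy data): among the points (1, 3), (2, 2), (3, 1) and (1, 1), the
# first three are mutually non-dominated and form the Pareto front, while
# (1, 1) is dominated by all of them:
#   df = pd.DataFrame({"x": [1, 2, 3, 1], "y": [3, 2, 1, 1]})
#   get_pareto_front_mask(df, "x", "y").tolist()  # [True, True, True, False]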
def main(
job_return_value_path_test: List[str],
job_return_value_path_val: List[str],
metric_data_path_test: List[str],
metric_data_path_val: List[str],
char_total_test: int,
char_total_val: int,
job_id_prefixes: Optional[List[str]] = None,
metric_filters: Optional[List[str]] = None,
index_filters: Optional[List[str]] = None,
index_blacklist: Optional[List[str]] = None,
label_mapping: Optional[Dict[str, str]] = None,
plot_method: str = "line", # can be "scatter" or "line"
pareto_front: bool = False,
show_as: str = "figure",
columns: Optional[List[str]] = None,
color_column: Optional[str] = None,
):
label_mapping = label_mapping or {}
if job_id_prefixes is not None:
if len(job_id_prefixes) != len(job_return_value_path_test):
raise ValueError(
f"job_id_prefixes ({len(job_id_prefixes)}) and "
f"job_return_value_path_test ({len(job_return_value_path_test)}) "
f"must have the same length"
)
# replace empty strings with None
job_id_prefixes_with_none = [
job_id_prefix if job_id_prefix != "" else None for job_id_prefix in job_id_prefixes
]
else:
job_id_prefixes_with_none = [None] * len(job_return_value_path_test)
# combine input data for test and val
char_total = {"test": char_total_test, "val": char_total_val}
metric_data_path = {"test": metric_data_path_test, "val": metric_data_path_val}
job_return_value_path = {"test": job_return_value_path_test, "val": job_return_value_path_val}
# prepare dataframes
common_kwargs = dict(
index_dtypes={
"max_argument_distance": int,
"max_length": int,
"num_beams": int,
}
)
f1_df_list: Dict[str, List[pd.DataFrame]] = {"test": [], "val": []}
k_chars_per_s_list: Dict[str, List[pd.Series]] = {"test": [], "val": []}
for split in metric_data_path:
if len(metric_data_path[split]) != len(job_return_value_path[split]):
raise ValueError(
f"metric_data_path[{split}] ({len(metric_data_path[split])}) and "
f"job_return_value_path[{split}] ({len(job_return_value_path[split])}) "
f"must have the same length"
)
for current_metric_data_path, current_job_return_value_path, job_id_prefix in zip(
metric_data_path[split], job_return_value_path[split], job_id_prefixes_with_none
):
current_f1_df, current_k_chars_per_s = prepare_quality_and_throughput_dfs(
current_metric_data_path,
current_job_return_value_path,
char_total=char_total[split],
job_id_prefix=job_id_prefix,
**common_kwargs,
)
f1_df_list[split].append(current_f1_df)
k_chars_per_s_list[split].append(current_k_chars_per_s)
f1_df_dict = {split: pd.concat(f1_df_list[split], axis=0) for split in f1_df_list}
k_chars_per_s_dict = {
split: pd.concat(k_chars_per_s_list[split], axis=0) for split in k_chars_per_s_list
}
# combine dataframes for test and val
f1_df = pd.concat(f1_df_dict, names=["split"] + f1_df_dict["test"].index.names)
f1_df.columns = [col_to_str(col, names=f1_df.columns.names, sep=",") for col in f1_df.columns]
k_chars_per_s = pd.concat(
k_chars_per_s_dict,
names=["split"] + k_chars_per_s_dict["test"].index.names,
)
# combine quality and throughput data
df_plot = pd.concat([f1_df, k_chars_per_s], axis=1)
df_plot = (
df_plot.reset_index()
.set_index(list(f1_df.index.names) + [k_chars_per_s.name])
.unstack("split")
)
df_plot.columns = flatten_index(df_plot.columns, names=[None, "split"])
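    # each metric column is now duplicated per split and named like
    # "<metric>,split=test" / "<metric>,split=val"; the throughput values
    # ("1k_chars_per_s") remain in the index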
# remove all columns that are not needed
if metric_filters is not None:
for fil in metric_filters:
df_plot.drop(columns=[col for col in df_plot.columns if fil not in col], inplace=True)
df_plot.columns = [col.replace(fil, "") for col in df_plot.columns]
# flatten the columns
df_plot.columns = [
",".join([part for part in col.split(",") if part != ""]) for col in df_plot.columns
]
v: Any
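    # filters use the "key=value" format, e.g. index_filters=["num_beams=4"]
    # keeps only rows with num_beams == 4 (xs removes the level), while
    # index_blacklist=["max_length=64"] drops the matching rows but keeps the level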
if index_filters is not None:
for k_v in index_filters:
k, v = k_v.split("=")
if k in common_kwargs["index_dtypes"]:
v = common_kwargs["index_dtypes"][k](v)
df_plot = df_plot.xs(v, level=k, axis=0)
if index_blacklist is not None:
for k_v in index_blacklist:
k, v = k_v.split("=")
if k in common_kwargs["index_dtypes"]:
v = common_kwargs["index_dtypes"][k](v)
df_plot = df_plot.drop(v, level=k, axis=0)
if columns is not None:
df_plot = df_plot[columns]
x = "1k_chars_per_s"
y = df_plot.columns
if pareto_front:
for col in y:
current_data = df_plot[col].dropna().reset_index(x).copy()
pareto_front_mask = get_pareto_front_mask(current_data, x_col=x, y_col=col)
current_data.loc[~pareto_front_mask, col] = np.nan
current_data_reset = current_data.reset_index().set_index(df_plot.index.names)
df_plot[col] = current_data_reset[col]
    # drop rows where all values are NaN
    df_plot = df_plot.dropna(how="all")
# plot
# Create a custom color sequence (concatenating multiple palettes if needed)
custom_colors = px.colors.qualitative.Dark24 + px.colors.qualitative.Light24
text_cols = list(df_plot.index.names)
text_cols.remove(x)
df_plot_reset = df_plot.reset_index()
if len(text_cols) > 1:
df_plot_reset[",".join(text_cols)] = (
df_plot_reset[text_cols].astype(str).agg(", ".join, axis=1)
)
text_col = ",".join(text_cols)
if show_as == "figure":
_plot_method = getattr(px, plot_method)
df_plot_sorted = df_plot_reset.sort_values(by=x)
fig = _plot_method(
df_plot_sorted,
x=x,
y=y,
text=text_col if plot_method != "scatter" else None,
color=color_column,
color_discrete_sequence=custom_colors,
hover_data=text_cols,
)
# set connectgaps to True to connect the lines
fig.update_traces(connectgaps=True)
legend_title = "Evaluation Setup"
if metric_filters:
whitelist_filters_mapped = [label_mapping.get(fil, fil) for fil in metric_filters]
legend_title += f" ({', '.join(whitelist_filters_mapped)})"
text_cols_mapped = [label_mapping.get(col, col) for col in text_cols]
title = f"Impact of {', '.join(text_cols_mapped)} on Prediction Quality and Throughput"
if index_filters:
index_filters_mapped = [label_mapping.get(fil, fil) for fil in index_filters]
title += f" ({', '.join(index_filters_mapped)})"
if pareto_front:
title += " (Pareto Front)"
fig.update_layout(
xaxis_title="Throughput (1k chars/s)",
yaxis_title="Quality (F1)",
title=title,
            # move the title toward the horizontal center
            title_x=0.2,
# black title
title_font=dict(color="black"),
# change legend title
legend_title=legend_title,
font_family="Computer Modern",
# white background
plot_bgcolor="white",
paper_bgcolor="white",
)
update_axes_kwargs = dict(
tickfont=dict(color="black"),
title_font=dict(color="black"),
ticks="inside", # ensure tick markers are drawn
tickcolor="black",
tickwidth=1,
ticklen=10,
linecolor="black",
# show grid
gridcolor="lightgray",
)
fig.update_yaxes(**update_axes_kwargs)
fig.update_xaxes(**update_axes_kwargs)
fig.show()
elif show_as == "markdown":
# Print the DataFrame as a Markdown table
print(df_plot_reset.to_markdown(index=False, floatfmt=".4f"))
elif show_as == "json":
# Print the DataFrame as a JSON object
print(df_plot_reset.to_json(orient="columns", indent=4))
else:
raise ValueError(f"Unknown show_as value: {show_as}. Use 'figure', 'markdown' or 'json'.")
if __name__ == "__main__":
"""
    # Example usage 1 (pipeline model, data from: https://github.com/ArneBinder/pie-document-level/issues/388#issuecomment-2752829257):
python src/analysis/show_inference_params_on_quality_and_throughput.py \
--job-return-value-path-test logs/prediction/multiruns/default/2025-03-26_01-31-05/job_return_value.json \
--job-return-value-path-val logs/prediction/multiruns/default/2025-03-26_16-49-36/job_return_value.json \
--metric-data-path-test data/evaluation/argumentation_structure/inference_pipeline_test.json \
--metric-data-path-val data/evaluation/argumentation_structure/inference_pipeline_validation.json \
--metric-filters task=are discont_comp=true split=val
# Example usage 2 (joint model, data from: https://github.com/ArneBinder/pie-document-level/issues/390#issuecomment-2759888004)
python src/analysis/show_inference_params_on_quality_and_throughput.py \
--job-return-value-path-test logs/prediction/multiruns/default/2025-03-28_01-34-07/job_return_value.json \
--job-return-value-path-val logs/prediction/multiruns/default/2025-03-28_02-57-00/job_return_value.json \
--metric-data-path-test data/evaluation/argumentation_structure/inference_joint_test.json \
--metric-data-path-val data/evaluation/argumentation_structure/inference_joint_validation.json \
--metric-filters task=are discont_comp=true split=val \
--plot-method scatter
"""
parser = argparse.ArgumentParser()
parser.add_argument(
"--job-return-value-path-test",
type=str,
nargs="+",
required=True,
)
parser.add_argument(
"--job-return-value-path-val",
type=str,
nargs="+",
required=True,
)
parser.add_argument(
"--metric-data-path-test",
type=str,
nargs="+",
required=True,
)
parser.add_argument(
"--metric-data-path-val",
type=str,
nargs="+",
required=True,
)
parser.add_argument(
"--job-id-prefixes",
type=str,
nargs="*",
default=None,
)
parser.add_argument(
"--plot-method",
type=str,
default="line",
choices=["scatter", "line"],
help="Plot method to use (default: line)",
)
parser.add_argument(
"--color-column",
type=str,
default=None,
help="Column to use for colour coding (default: None)",
)
parser.add_argument(
"--metric-filters",
type=str,
nargs="*",
default=None,
help="Filters to apply to the metric data in the format 'key=value'",
)
parser.add_argument(
"--index-filters",
type=str,
nargs="*",
default=None,
help="Filters to apply to the index data in the format 'key=value'",
)
parser.add_argument(
"--index-blacklist",
type=str,
nargs="*",
default=None,
help="Blacklist to apply to the index data in the format 'key=value'",
)
parser.add_argument(
"--columns",
type=str,
nargs="*",
default=None,
help="Columns to plot (default: all)",
)
parser.add_argument(
"--pareto-front",
action="store_true",
help="Whether to show only the pareto front",
)
parser.add_argument(
"--show-as",
type=str,
default="figure",
choices=["figure", "markdown", "json"],
help="How to show the results (default: figure)",
)
kwargs = vars(parser.parse_args())
main(
char_total_test=383154,
char_total_val=182794,
label_mapping={
"max_argument_distance": "Max. Argument Distance",
"max_length": "Max. Length",
"num_beams": "Num. Beams",
"task=are": "ARE",
"discont_comp=true": "Discont. Comp.",
"split=val": "Validation Split",
},
**kwargs,
)