# COLE/src/evaluation/evaluation_pipeline.py
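"""Evaluation pipeline for COLE: loads each requested Hugging Face model, runs it on the
configured benchmark tasks, and logs predictions and metrics to Weights & Biases and ./results."""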
import argparse
import gc
import logging
from datetime import datetime
import torch
import wandb
from tqdm import tqdm
from predictions.all_llms import llms
from src.evaluation.model_evaluator import ModelEvaluator
from src.model.hugging_face_model import HFLLMModel
from src.task.task_factory import tasks_factory
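# Command-line options for the evaluation run.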
parser = argparse.ArgumentParser()
parser.add_argument(
    "--test",
    help="If set, run a quick smoke test with only a small model and a few examples.",
    action="store_true",
)
parser.add_argument(
    "--max_examples",
    "-m",
    help="Maximum number of evaluation examples to use; defaults to None (use all examples).",
    type=int,
    default=None,
)
parser.add_argument(
    "--token",
    "-t",
    help="Hugging Face token used to download models from the Hub.",
    type=str,
    default=None,
)
parser.add_argument(
    "--models_name",
    "-mn",
    help=(
        "Model(s) to evaluate: either a group key from predictions.all_llms.llms "
        "or a comma-separated list of model names. Defaults to the 'all' group."
    ),
    type=str,
    default=None,
)
parser.add_argument(
"--batch_size",
help="The batch size to use during the evaluation.",
type=int,
default=16,
)
args = parser.parse_args()
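# Benchmark tasks to evaluate, instantiated through the task factory.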
tasks_names = [
"piaf",
"qfrblimp",
"allocine",
"qfrcola",
"gqnli",
"opus_parcus",
"paws_x",
"fquad",
"sickfr",
"sts22",
"xnli",
]
tasks = tasks_factory(tasks_names)
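# Resolve which models to evaluate: a named group from the llms registry, an explicit
# comma-separated list, or every registered model by default.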
models = []
if args.models_name is not None:
if args.models_name in llms:
models = llms[args.models_name]
else:
models = args.models_name.split(",")
else:
models = llms["all"]
logging.info("Starting Evaluation")
time_start = datetime.now()
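# Evaluate each model in turn; a failure is logged and the loop moves on to the next model.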
for model_name in tqdm(
models, total=len(models), desc="Processing LLM inference on tasks."
):
    try:
        logging.info("Creating model %s", model_name)
        model = HFLLMModel(model_name=model_name, batch_size=args.batch_size)
        evaluator = ModelEvaluator()
        logging.info("Evaluating model")
        exp_name = model_name
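        # Start one W&B run per model, recording the model name and task list in the config.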
wandb.init(
project="COLLE",
config={"model_name": model_name, "tasks": "; ".join(tasks_names)},
name=exp_name,
)
predictions_payload = evaluator.evaluate_subset(model, tasks, args.max_examples)
wandb.log(predictions_payload)
logging.info("Saving results")
evaluator.save_results("./results")
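        # Compute metrics over the evaluation results, persist them, and log them to W&B.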
metrics_payload = evaluator.compute_metrics()
evaluator.save_metrics("./results")
wandb.log(metrics_payload)
wandb.finish(exit_code=0)
except Exception as e:
error_message = f"Evaluation failed for model {model_name}: {e}"
logging.error(error_message)
wandb.finish(exit_code=1)
continue
finally:
# Memory cleaning
if "model" in locals():
del model
if "evaluator" in locals():
del evaluator
gc.collect()
torch.cuda.empty_cache()
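# Report total wall-clock time for the full evaluation run.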
time_end = datetime.now()
info_message = f"End time: {time_end}"
logging.info(info_message)
elapsed_time = time_end - time_start
info_message = f"Elapsed time: {elapsed_time}"
logging.info(info_message)