|
import logging
|
|
from typing import Dict, List, Union
|
|
|
|
from src.task.task import Task, Tasktype
|
|
|
|
|
|
# Task name -> (metric name, is_generative).
# Generative tasks are built with ``task_type=Tasktype.GENERATIVE``; every
# other task omits ``task_type`` and therefore gets ``Task``'s own default.
# A boolean flag (rather than the enum member) keeps this table free of any
# reference to project classes at import time.
_TASK_SPECS = {
    "allocine": ("accuracy", False),
    "fquad": ("fquad", True),
    "gqnli": ("accuracy", False),
    "opus_parcus": ("pearson", False),
    "paws_x": ("accuracy", False),
    "piaf": ("fquad", True),
    "qfrblimp": ("accuracy", False),
    "qfrcola": ("accuracy", False),
    "sickfr": ("pearson", False),
    "sts22": ("pearson", False),
    "xnli": ("accuracy", False),
    "expressions_quebecoises": ("accuracy", False),
}


def _fail(error: str) -> None:
    """Log ``error`` and raise it as a ``ValueError``."""
    logging.error(error)
    raise ValueError(error)


def _extract_task_names(task_names: Union[Dict, List[str]]) -> List[str]:
    """Flatten the factory input into an ordered list of task names.

    A dict input must hold its entries under the ``"tasks"`` key. Each entry
    is either a bare task name or a mapping whose first key is the task name.
    Any value stored under that key is ignored here. Any other input is
    iterated as-is.
    """
    if not isinstance(task_names, dict):
        return list(task_names)

    entries = task_names.get("tasks")
    if entries is None:
        _fail("Task configuration is missing a 'tasks' entry.")

    names = []
    for entry in entries:
        if isinstance(entry, str):
            names.append(entry)
        elif isinstance(entry, dict) and entry:
            # Only the first key names the task (matches the previous
            # ``list(task.keys())[0]`` behaviour, without building a list).
            names.append(next(iter(entry)))
        else:
            _fail(
                f"Invalid task entry {entry!r}; expected a task name "
                "or a non-empty mapping."
            )
    return names


def tasks_factory(task_names: Union[Dict, List[str]]) -> List["Task"]:
    """Create ``Task`` objects for the requested task names.

    Args:
        task_names: Either a list of task names, or a configuration dict whose
            ``"tasks"`` key holds a list of entries. Each entry is a bare task
            name or a mapping whose first key is the task name.

    Returns:
        One ``Task`` per requested name, in input order. Each ``Task`` gets the
        metric registered for that name in ``_TASK_SPECS``. Generative tasks
        (``fquad``, ``piaf``) also get ``task_type=Tasktype.GENERATIVE``.

    Raises:
        ValueError: If a task name is unknown, if a dict configuration has no
            ``"tasks"`` key, or if an entry is neither a string nor a non-empty
            mapping. The message is also logged at ERROR level.
    """
    names = _extract_task_names(task_names)

    # Resolve every name before constructing anything, so an unknown name
    # fails fast instead of after earlier Task objects were already built.
    specs = []
    for task_name in names:
        spec = _TASK_SPECS.get(task_name) if isinstance(task_name, str) else None
        if spec is None:
            _fail(f"Unknown task {task_name}.")
        specs.append(spec)

    tasks = []
    for task_name, (metric, generative) in zip(names, specs):
        kwargs = {"task_name": task_name, "metric": metric}
        if generative:
            kwargs["task_type"] = Tasktype.GENERATIVE
        tasks.append(Task(**kwargs))
    return tasks
|
|
|