import json
from dataclasses import dataclass
from typing import Optional

from autotrain.trainers.clm.params import LLMTrainingParams
from autotrain.trainers.extractive_question_answering.params import ExtractiveQuestionAnsweringParams
from autotrain.trainers.image_classification.params import ImageClassificationParams
from autotrain.trainers.image_regression.params import ImageRegressionParams
from autotrain.trainers.object_detection.params import ObjectDetectionParams
from autotrain.trainers.sent_transformers.params import SentenceTransformersParams
from autotrain.trainers.seq2seq.params import Seq2SeqParams
from autotrain.trainers.tabular.params import TabularParams
from autotrain.trainers.text_classification.params import TextClassificationParams
from autotrain.trainers.text_regression.params import TextRegressionParams
from autotrain.trainers.token_classification.params import TokenClassificationParams
from autotrain.trainers.vlm.params import VLMTrainingParams


HIDDEN_PARAMS = [
    "token",
    "project_name",
    "username",
    "task",
    "backend",
    "train_split",
    "valid_split",
    "text_column",
    "rejected_text_column",
    "prompt_text_column",
    "push_to_hub",
    "trainer",
    "model",
    "data_path",
    "image_path",
    "class_image_path",
    "revision",
    "tokenizer",
    "class_prompt",
    "num_class_images",
    "class_labels_conditioning",
    "resume_from_checkpoint",
    "dataloader_num_workers",
    "allow_tf32",
    "prior_generation_precision",
    "local_rank",
    "tokenizer_max_length",
    "rank",
    "xl",
    "checkpoints_total_limit",
    "validation_images",
    "validation_epochs",
    "num_validation_images",
    "validation_prompt",
    "sample_batch_size",
    "log",
    "image_column",
    "target_column",
    "id_column",
    "target_columns",
    "tokens_column",
    "tags_column",
    "objects_column",
    "sentence1_column",
    "sentence2_column",
    "sentence3_column",
    "question_column",
    "answer_column",
]

PARAMS = {}
PARAMS["llm"] = LLMTrainingParams(
    target_modules="all-linear",
    log="tensorboard",
    mixed_precision="fp16",
    quantization="int4",
    peft=True,
    block_size=1024,
    epochs=3,
    padding="right",
    chat_template="none",
    max_completion_length=128,
    distributed_backend="ddp",
).model_dump()

PARAMS["text-classification"] = TextClassificationParams(
    mixed_precision="fp16",
    log="tensorboard",
).model_dump()

PARAMS["st"] = SentenceTransformersParams(
    mixed_precision="fp16",
    log="tensorboard",
).model_dump()

PARAMS["image-classification"] = ImageClassificationParams(
    mixed_precision="fp16",
    log="tensorboard",
).model_dump()

PARAMS["image-object-detection"] = ObjectDetectionParams(
    mixed_precision="fp16",
    log="tensorboard",
).model_dump()

PARAMS["seq2seq"] = Seq2SeqParams(
    mixed_precision="fp16",
    target_modules="all-linear",
    log="tensorboard",
).model_dump()

PARAMS["tabular"] = TabularParams(
    categorical_imputer="most_frequent",
    numerical_imputer="median",
    numeric_scaler="robust",
).model_dump()

PARAMS["token-classification"] = TokenClassificationParams(
    mixed_precision="fp16",
    log="tensorboard",
).model_dump()

PARAMS["text-regression"] = TextRegressionParams(
    mixed_precision="fp16",
    log="tensorboard",
).model_dump()

PARAMS["image-regression"] = ImageRegressionParams(
    mixed_precision="fp16",
    log="tensorboard",
).model_dump()

PARAMS["vlm"] = VLMTrainingParams(
    mixed_precision="fp16",
    target_modules="all-linear",
    log="tensorboard",
    quantization="int4",
    peft=True,
    epochs=3,
).model_dump()

PARAMS["extractive-qa"] = ExtractiveQuestionAnsweringParams(
    mixed_precision="fp16",
    log="tensorboard",
    max_seq_length=512,
    max_doc_stride=128,
).model_dump()


@dataclass
class AppParams:
    """
    AppParams class is responsible for managing and processing parameters for various machine learning tasks.

    Attributes:
        job_params_json (str): JSON string containing job parameters.
        token (str): Authentication token.
        project_name (str): Name of the project.
        username (str): Username of the project owner.
        task (str): Type of task to be performed.
        data_path (str): Path to the dataset.
        base_model (str): Base model to be used.
        column_mapping (dict): Mapping of columns for the dataset.
        train_split (Optional[str]): Name of the training split. Default is None.
        valid_split (Optional[str]): Name of the validation split. Default is None.
        using_hub_dataset (Optional[bool]): Flag indicating if a hub dataset is used. Default is False.
        api (Optional[bool]): Flag indicating if the API is used. Default is False.

    Methods:
        __post_init__(): Validates the parameters after initialization.
        munge(): Processes the parameters based on the task type.
        _munge_common_params(): Processes common parameters for all tasks.
        _munge_params_sent_transformers(): Processes parameters for the sentence transformers task.
        _munge_params_llm(): Processes parameters for the large language model task.
        _munge_params_vlm(): Processes parameters for the vision-language model task.
        _munge_params_text_clf(): Processes parameters for the text classification task.
        _munge_params_extractive_qa(): Processes parameters for the extractive question answering task.
        _munge_params_text_reg(): Processes parameters for the text regression task.
        _munge_params_token_clf(): Processes parameters for the token classification task.
        _munge_params_seq2seq(): Processes parameters for the sequence-to-sequence task.
        _munge_params_img_clf(): Processes parameters for the image classification task.
        _munge_params_img_reg(): Processes parameters for the image regression task.
        _munge_params_img_obj_det(): Processes parameters for the image object detection task.
        _munge_params_tabular(): Processes parameters for the tabular data task.
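
    Example:
        A minimal, illustrative sketch of the intended flow; the token, model, dataset,
        and project names below are placeholder values, not AutoTrain defaults:

            app_params = AppParams(
                job_params_json='{"epochs": 3}',
                token="hf_...",
                project_name="my-project",
                username="my-username",
                task="text-classification",
                data_path="my-username/my-dataset",
                base_model="bert-base-uncased",
                column_mapping={"text": "text", "label": "label"},
                train_split="train",
                valid_split="validation",
                using_hub_dataset=True,
            )
            params = app_params.munge()  # returns a TextClassificationParams instance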
""" job_params_json: str token: str project_name: str username: str task: str data_path: str base_model: str column_mapping: dict train_split: Optional[str] = None valid_split: Optional[str] = None using_hub_dataset: Optional[bool] = False api: Optional[bool] = False def __post_init__(self): if self.using_hub_dataset and not self.train_split: raise ValueError("train_split is required when using a hub dataset") def munge(self): if self.task == "text-classification": return self._munge_params_text_clf() elif self.task == "seq2seq": return self._munge_params_seq2seq() elif self.task == "image-classification": return self._munge_params_img_clf() elif self.task == "image-object-detection": return self._munge_params_img_obj_det() elif self.task.startswith("tabular"): return self._munge_params_tabular() elif self.task.startswith("llm"): return self._munge_params_llm() elif self.task == "token-classification": return self._munge_params_token_clf() elif self.task == "text-regression": return self._munge_params_text_reg() elif self.task.startswith("st:"): return self._munge_params_sent_transformers() elif self.task == "image-regression": return self._munge_params_img_reg() elif self.task.startswith("vlm"): return self._munge_params_vlm() elif self.task == "extractive-qa": return self._munge_params_extractive_qa() else: raise ValueError(f"Unknown task: {self.task}") def _munge_common_params(self): _params = json.loads(self.job_params_json) _params["token"] = self.token _params["project_name"] = f"{self.project_name}" if "push_to_hub" not in _params: _params["push_to_hub"] = True _params["data_path"] = self.data_path _params["username"] = self.username return _params def _munge_params_sent_transformers(self): _params = self._munge_common_params() _params["model"] = self.base_model if "log" not in _params: _params["log"] = "tensorboard" if not self.using_hub_dataset: _params["sentence1_column"] = "autotrain_sentence1" _params["sentence2_column"] = "autotrain_sentence2" _params["sentence3_column"] = "autotrain_sentence3" _params["target_column"] = "autotrain_target" _params["valid_split"] = "validation" else: _params["sentence1_column"] = self.column_mapping.get( "sentence1" if not self.api else "sentence1_column", "sentence1" ) _params["sentence2_column"] = self.column_mapping.get( "sentence2" if not self.api else "sentence2_column", "sentence2" ) _params["sentence3_column"] = self.column_mapping.get( "sentence3" if not self.api else "sentence3_column", "sentence3" ) _params["target_column"] = self.column_mapping.get("target" if not self.api else "target_column", "target") _params["train_split"] = self.train_split _params["valid_split"] = self.valid_split trainer = self.task.split(":")[1] _params["trainer"] = trainer.lower() return SentenceTransformersParams(**_params) def _munge_params_llm(self): _params = self._munge_common_params() _params["model"] = self.base_model if not self.using_hub_dataset: _params["text_column"] = "autotrain_text" _params["prompt_text_column"] = "autotrain_prompt" _params["rejected_text_column"] = "autotrain_rejected_text" else: _params["text_column"] = self.column_mapping.get("text" if not self.api else "text_column", "text") _params["prompt_text_column"] = self.column_mapping.get( "prompt" if not self.api else "prompt_text_column", "prompt" ) _params["rejected_text_column"] = self.column_mapping.get( "rejected_text" if not self.api else "rejected_text_column", "rejected_text" ) _params["train_split"] = self.train_split if "log" not in _params: _params["log"] = "tensorboard" 
        trainer = self.task.split(":")[1]
        if trainer != "generic":
            _params["trainer"] = trainer.lower()
        if "quantization" in _params:
            if _params["quantization"] in ("none", "no"):
                _params["quantization"] = None
        return LLMTrainingParams(**_params)

    def _munge_params_vlm(self):
        _params = self._munge_common_params()
        _params["model"] = self.base_model
        if not self.using_hub_dataset:
            _params["text_column"] = "autotrain_text"
            _params["prompt_text_column"] = "autotrain_prompt"
            _params["image_column"] = "autotrain_image"
            _params["valid_split"] = "validation"
        else:
            _params["text_column"] = self.column_mapping.get("text" if not self.api else "text_column", "text")
            _params["prompt_text_column"] = self.column_mapping.get(
                "prompt" if not self.api else "prompt_text_column", "prompt"
            )
            _params["image_column"] = self.column_mapping.get(
                "image" if not self.api else "image_column", "image"
            )
            _params["train_split"] = self.train_split
            _params["valid_split"] = self.valid_split
        if "log" not in _params:
            _params["log"] = "tensorboard"
        trainer = self.task.split(":")[1]
        _params["trainer"] = trainer.lower()
        if "quantization" in _params:
            if _params["quantization"] in ("none", "no"):
                _params["quantization"] = None
        return VLMTrainingParams(**_params)

    def _munge_params_text_clf(self):
        _params = self._munge_common_params()
        _params["model"] = self.base_model
        if "log" not in _params:
            _params["log"] = "tensorboard"
        if not self.using_hub_dataset:
            _params["text_column"] = "autotrain_text"
            _params["target_column"] = "autotrain_label"
            _params["valid_split"] = "validation"
        else:
            _params["text_column"] = self.column_mapping.get("text" if not self.api else "text_column", "text")
            _params["target_column"] = self.column_mapping.get("label" if not self.api else "target_column", "label")
            _params["train_split"] = self.train_split
            _params["valid_split"] = self.valid_split
        return TextClassificationParams(**_params)

    def _munge_params_extractive_qa(self):
        _params = self._munge_common_params()
        _params["model"] = self.base_model
        if "log" not in _params:
            _params["log"] = "tensorboard"
        if not self.using_hub_dataset:
            _params["text_column"] = "autotrain_text"
            _params["question_column"] = "autotrain_question"
            _params["answer_column"] = "autotrain_answer"
            _params["valid_split"] = "validation"
        else:
            _params["text_column"] = self.column_mapping.get("text" if not self.api else "text_column", "text")
            _params["question_column"] = self.column_mapping.get(
                "question" if not self.api else "question_column", "question"
            )
            _params["answer_column"] = self.column_mapping.get("answer" if not self.api else "answer_column", "answer")
            _params["train_split"] = self.train_split
            _params["valid_split"] = self.valid_split
        return ExtractiveQuestionAnsweringParams(**_params)

    def _munge_params_text_reg(self):
        _params = self._munge_common_params()
        _params["model"] = self.base_model
        if "log" not in _params:
            _params["log"] = "tensorboard"
        if not self.using_hub_dataset:
            _params["text_column"] = "autotrain_text"
            _params["target_column"] = "autotrain_label"
            _params["valid_split"] = "validation"
        else:
            _params["text_column"] = self.column_mapping.get("text" if not self.api else "text_column", "text")
            _params["target_column"] = self.column_mapping.get("label" if not self.api else "target_column", "label")
            _params["train_split"] = self.train_split
            _params["valid_split"] = self.valid_split
        return TextRegressionParams(**_params)

    def _munge_params_token_clf(self):
        _params = self._munge_common_params()
        _params["model"] = self.base_model
        if "log" not in _params:
            _params["log"] = "tensorboard"
"tensorboard" if not self.using_hub_dataset: _params["tokens_column"] = "autotrain_text" _params["tags_column"] = "autotrain_label" _params["valid_split"] = "validation" else: _params["tokens_column"] = self.column_mapping.get("tokens" if not self.api else "tokens_column", "tokens") _params["tags_column"] = self.column_mapping.get("tags" if not self.api else "tags_column", "tags") _params["train_split"] = self.train_split _params["valid_split"] = self.valid_split return TokenClassificationParams(**_params) def _munge_params_seq2seq(self): _params = self._munge_common_params() _params["model"] = self.base_model if "log" not in _params: _params["log"] = "tensorboard" if not self.using_hub_dataset: _params["text_column"] = "autotrain_text" _params["target_column"] = "autotrain_label" _params["valid_split"] = "validation" else: _params["text_column"] = self.column_mapping.get("text" if not self.api else "text_column", "text") _params["target_column"] = self.column_mapping.get("label" if not self.api else "target_column", "label") _params["train_split"] = self.train_split _params["valid_split"] = self.valid_split return Seq2SeqParams(**_params) def _munge_params_img_clf(self): _params = self._munge_common_params() _params["model"] = self.base_model if "log" not in _params: _params["log"] = "tensorboard" if not self.using_hub_dataset: _params["image_column"] = "autotrain_image" _params["target_column"] = "autotrain_label" _params["valid_split"] = "validation" else: _params["image_column"] = self.column_mapping.get("image" if not self.api else "image_column", "image") _params["target_column"] = self.column_mapping.get("label" if not self.api else "target_column", "label") _params["train_split"] = self.train_split _params["valid_split"] = self.valid_split return ImageClassificationParams(**_params) def _munge_params_img_reg(self): _params = self._munge_common_params() _params["model"] = self.base_model if "log" not in _params: _params["log"] = "tensorboard" if not self.using_hub_dataset: _params["image_column"] = "autotrain_image" _params["target_column"] = "autotrain_label" _params["valid_split"] = "validation" else: _params["image_column"] = self.column_mapping.get("image" if not self.api else "image_column", "image") _params["target_column"] = self.column_mapping.get("target" if not self.api else "target_column", "target") _params["train_split"] = self.train_split _params["valid_split"] = self.valid_split return ImageRegressionParams(**_params) def _munge_params_img_obj_det(self): _params = self._munge_common_params() _params["model"] = self.base_model if "log" not in _params: _params["log"] = "tensorboard" if not self.using_hub_dataset: _params["image_column"] = "autotrain_image" _params["objects_column"] = "autotrain_objects" _params["valid_split"] = "validation" else: _params["image_column"] = self.column_mapping.get("image" if not self.api else "image_column", "image") _params["objects_column"] = self.column_mapping.get( "objects" if not self.api else "objects_column", "objects" ) _params["train_split"] = self.train_split _params["valid_split"] = self.valid_split return ObjectDetectionParams(**_params) def _munge_params_tabular(self): _params = self._munge_common_params() _params["model"] = self.base_model if not self.using_hub_dataset: _params["id_column"] = "autotrain_id" _params["valid_split"] = "validation" if len(self.column_mapping["label"]) == 1: _params["target_columns"] = ["autotrain_label"] else: _params["target_columns"] = [ "autotrain_label_" + str(i) for i in 
        else:
            _params["id_column"] = self.column_mapping.get("id" if not self.api else "id_column", "id")
            _params["train_split"] = self.train_split
            _params["valid_split"] = self.valid_split
            _params["target_columns"] = self.column_mapping.get("label" if not self.api else "target_columns", "label")

        if len(_params["categorical_imputer"].strip()) == 0 or _params["categorical_imputer"].lower() == "none":
            _params["categorical_imputer"] = None
        if len(_params["numerical_imputer"].strip()) == 0 or _params["numerical_imputer"].lower() == "none":
            _params["numerical_imputer"] = None
        if len(_params["numeric_scaler"].strip()) == 0 or _params["numeric_scaler"].lower() == "none":
            _params["numeric_scaler"] = None

        if "classification" in self.task:
            _params["task"] = "classification"
        else:
            _params["task"] = "regression"

        return TabularParams(**_params)


def get_task_params(task, param_type):
    """
    Retrieve task-specific parameters while filtering out hidden parameters based on the task and parameter type.

    Args:
        task (str): The task identifier, which can include prefixes like "llm", "st:", "vlm:", etc.
        param_type (str): The type of parameters to retrieve, typically "basic" or other types.

    Returns:
        dict: A dictionary of task-specific parameters with hidden parameters filtered out.

    Notes:
        - The function handles various task prefixes and adjusts the task and trainer variables accordingly.
        - Hidden parameters are filtered out based on the task and parameter type.
        - Additional hidden parameters are defined for specific tasks and trainers.
    """
    if task.startswith("llm"):
        trainer = task.split(":")[1].lower()
        task = task.split(":")[0].lower()

    if task.startswith("st:"):
        trainer = task.split(":")[1].lower()
        task = task.split(":")[0].lower()

    if task.startswith("vlm:"):
        trainer = task.split(":")[1].lower()
        task = task.split(":")[0].lower()

    if task.startswith("tabular"):
        task = "tabular"

    if task not in PARAMS:
        return {}

    task_params = PARAMS[task]
    task_params = {k: v for k, v in task_params.items() if k not in HIDDEN_PARAMS}

    if task == "llm":
        more_hidden_params = []
        if trainer == "sft":
            more_hidden_params = [
                "model_ref",
                "dpo_beta",
                "add_eos_token",
                "max_prompt_length",
                "max_completion_length",
            ]
        elif trainer == "reward":
            more_hidden_params = [
                "model_ref",
                "dpo_beta",
                "add_eos_token",
                "max_prompt_length",
                "max_completion_length",
                "unsloth",
            ]
        elif trainer == "orpo":
            more_hidden_params = [
                "model_ref",
                "dpo_beta",
                "add_eos_token",
                "unsloth",
            ]
        elif trainer == "generic":
            more_hidden_params = [
                "model_ref",
                "dpo_beta",
                "max_prompt_length",
                "max_completion_length",
            ]
        elif trainer == "dpo":
            more_hidden_params = [
                "add_eos_token",
                "unsloth",
            ]
        if param_type == "basic":
            more_hidden_params.extend(
                [
                    "padding",
                    "use_flash_attention_2",
                    "disable_gradient_checkpointing",
                    "logging_steps",
                    "eval_strategy",
                    "save_total_limit",
                    "auto_find_batch_size",
                    "warmup_ratio",
                    "weight_decay",
                    "max_grad_norm",
                    "seed",
                    "quantization",
                    "merge_adapter",
                    "lora_r",
                    "lora_alpha",
                    "lora_dropout",
                    "max_completion_length",
                ]
            )
        task_params = {k: v for k, v in task_params.items() if k not in more_hidden_params}

    if task == "text-classification" and param_type == "basic":
        more_hidden_params = [
            "warmup_ratio",
            "weight_decay",
            "max_grad_norm",
            "seed",
            "logging_steps",
            "auto_find_batch_size",
            "save_total_limit",
            "eval_strategy",
            "early_stopping_patience",
            "early_stopping_threshold",
        ]
        task_params = {k: v for k, v in task_params.items() if k not in more_hidden_params}

    if task == "extractive-qa" and param_type == "basic":
        more_hidden_params = [
            "warmup_ratio",
            "weight_decay",
            "max_grad_norm",
            "seed",
            "logging_steps",
            "auto_find_batch_size",
            "save_total_limit",
            "eval_strategy",
            "early_stopping_patience",
            "early_stopping_threshold",
        ]
        task_params = {k: v for k, v in task_params.items() if k not in more_hidden_params}

    if task == "st" and param_type == "basic":
        more_hidden_params = [
            "warmup_ratio",
            "weight_decay",
            "max_grad_norm",
            "seed",
            "logging_steps",
            "auto_find_batch_size",
            "save_total_limit",
            "eval_strategy",
            "early_stopping_patience",
            "early_stopping_threshold",
        ]
        task_params = {k: v for k, v in task_params.items() if k not in more_hidden_params}

    if task == "vlm" and param_type == "basic":
        more_hidden_params = [
            "warmup_ratio",
            "weight_decay",
            "max_grad_norm",
            "seed",
            "logging_steps",
            "auto_find_batch_size",
            "save_total_limit",
            "eval_strategy",
            "early_stopping_patience",
            "early_stopping_threshold",
            "quantization",
            "lora_r",
            "lora_alpha",
            "lora_dropout",
        ]
        task_params = {k: v for k, v in task_params.items() if k not in more_hidden_params}

    if task == "text-regression" and param_type == "basic":
        more_hidden_params = [
            "warmup_ratio",
            "weight_decay",
            "max_grad_norm",
            "seed",
            "logging_steps",
            "auto_find_batch_size",
            "save_total_limit",
            "eval_strategy",
            "early_stopping_patience",
            "early_stopping_threshold",
        ]
        task_params = {k: v for k, v in task_params.items() if k not in more_hidden_params}

    if task == "image-classification" and param_type == "basic":
        more_hidden_params = [
            "warmup_ratio",
            "weight_decay",
            "max_grad_norm",
            "seed",
            "logging_steps",
            "auto_find_batch_size",
            "save_total_limit",
            "eval_strategy",
            "early_stopping_patience",
            "early_stopping_threshold",
        ]
        task_params = {k: v for k, v in task_params.items() if k not in more_hidden_params}

    if task == "image-regression" and param_type == "basic":
        more_hidden_params = [
            "warmup_ratio",
            "weight_decay",
            "max_grad_norm",
            "seed",
            "logging_steps",
            "auto_find_batch_size",
            "save_total_limit",
            "eval_strategy",
            "early_stopping_patience",
            "early_stopping_threshold",
        ]
        task_params = {k: v for k, v in task_params.items() if k not in more_hidden_params}

    if task == "image-object-detection" and param_type == "basic":
        more_hidden_params = [
            "warmup_ratio",
            "weight_decay",
            "max_grad_norm",
            "seed",
            "logging_steps",
            "auto_find_batch_size",
            "save_total_limit",
            "eval_strategy",
            "early_stopping_patience",
            "early_stopping_threshold",
        ]
        task_params = {k: v for k, v in task_params.items() if k not in more_hidden_params}

    if task == "seq2seq" and param_type == "basic":
        more_hidden_params = [
            "warmup_ratio",
            "weight_decay",
            "max_grad_norm",
            "seed",
            "logging_steps",
            "auto_find_batch_size",
            "save_total_limit",
            "eval_strategy",
            "quantization",
            "lora_r",
            "lora_alpha",
            "lora_dropout",
            "target_modules",
            "early_stopping_patience",
            "early_stopping_threshold",
        ]
        task_params = {k: v for k, v in task_params.items() if k not in more_hidden_params}

    if task == "token-classification" and param_type == "basic":
        more_hidden_params = [
            "warmup_ratio",
            "weight_decay",
            "max_grad_norm",
            "seed",
            "logging_steps",
            "auto_find_batch_size",
            "save_total_limit",
            "eval_strategy",
            "early_stopping_patience",
            "early_stopping_threshold",
        ]
        task_params = {k: v for k, v in task_params.items() if k not in more_hidden_params}

    return task_params
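

# Illustrative usage sketch of how the helper above is typically queried from a UI
# layer; the task string and param_type value below are example inputs, not required
# constants.
if __name__ == "__main__":
    # "basic" hides the advanced knobs (LoRA rank, scheduler details, etc.) from the
    # default view; any other param_type keeps everything except HIDDEN_PARAMS.
    basic_llm_fields = get_task_params("llm:sft", param_type="basic")
    print(sorted(basic_llm_fields))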