from typing import Optional

from pydantic import Field

from autotrain.trainers.common import AutoTrainParams


class VLMTrainingParams(AutoTrainParams):
    """
    VLMTrainingParams

    Attributes:
        model (str): Model name. Default is "google/paligemma-3b-pt-224".
        project_name (str): Output directory. Default is "project-name".
        data_path (str): Data path. Default is "data".
        train_split (str): Train data config. Default is "train".
        valid_split (Optional[str]): Validation data config. Default is None.
        trainer (str): Trainer type (captioning, vqa, segmentation, detection). Default is "vqa".
        log (str): Logging using experiment tracking. Default is "none".
        disable_gradient_checkpointing (bool): Disable gradient checkpointing. Default is False.
        logging_steps (int): Logging steps. Default is -1.
        eval_strategy (str): Evaluation strategy. Default is "epoch".
        save_total_limit (int): Save total limit. Default is 1.
        auto_find_batch_size (bool): Auto find batch size. Default is False.
        mixed_precision (Optional[str]): Mixed precision (fp16, bf16, or None). Default is None.
        lr (float): Learning rate. Default is 3e-5.
        epochs (int): Number of training epochs. Default is 1.
        batch_size (int): Training batch size. Default is 2.
        warmup_ratio (float): Warmup proportion. Default is 0.1.
        gradient_accumulation (int): Gradient accumulation steps. Default is 4.
        optimizer (str): Optimizer. Default is "adamw_torch".
        scheduler (str): Scheduler. Default is "linear".
        weight_decay (float): Weight decay. Default is 0.0.
        max_grad_norm (float): Max gradient norm. Default is 1.0.
        seed (int): Seed. Default is 42.
        quantization (Optional[str]): Quantization (int4, int8, or None). Default is "int4".
        target_modules (Optional[str]): Target modules. Default is "all-linear".
        merge_adapter (bool): Merge adapter. Default is False.
        peft (bool): Use PEFT. Default is False.
        lora_r (int): LoRA r. Default is 16.
        lora_alpha (int): LoRA alpha. Default is 32.
        lora_dropout (float): LoRA dropout. Default is 0.05.
        image_column (Optional[str]): Image column. Default is "image".
        text_column (str): Text (answer) column. Default is "text".
        prompt_text_column (Optional[str]): Prompt (prefix) column. Default is "prompt".
        push_to_hub (bool): Push to hub. Default is False.
        username (Optional[str]): Hugging Face username. Default is None.
        token (Optional[str]): Hugging Face token. Default is None.
    """

    model: str = Field("google/paligemma-3b-pt-224", title="Model name")
    project_name: str = Field("project-name", title="Output directory")

    # data params
    data_path: str = Field("data", title="Data path")
    train_split: str = Field("train", title="Train data config")
    valid_split: Optional[str] = Field(None, title="Validation data config")

    # trainer params
    trainer: str = Field("vqa", title="Trainer type")  # captioning, vqa, segmentation, detection
    log: str = Field("none", title="Logging using experiment tracking")
    disable_gradient_checkpointing: bool = Field(False, title="Disable gradient checkpointing")
    logging_steps: int = Field(-1, title="Logging steps")
    eval_strategy: str = Field("epoch", title="Evaluation strategy")
    save_total_limit: int = Field(1, title="Save total limit")
    auto_find_batch_size: bool = Field(False, title="Auto find batch size")
    mixed_precision: Optional[str] = Field(None, title="fp16, bf16, or None")
    lr: float = Field(3e-5, title="Learning rate")
    epochs: int = Field(1, title="Number of training epochs")
    batch_size: int = Field(2, title="Training batch size")
    warmup_ratio: float = Field(0.1, title="Warmup proportion")
    gradient_accumulation: int = Field(4, title="Gradient accumulation steps")
    optimizer: str = Field("adamw_torch", title="Optimizer")
    scheduler: str = Field("linear", title="Scheduler")
    weight_decay: float = Field(0.0, title="Weight decay")
    max_grad_norm: float = Field(1.0, title="Max gradient norm")
    seed: int = Field(42, title="Seed")

    # peft
    quantization: Optional[str] = Field("int4", title="int4, int8, or None")
    target_modules: Optional[str] = Field("all-linear", title="Target modules")
    merge_adapter: bool = Field(False, title="Merge adapter")
    peft: bool = Field(False, title="Use PEFT")
    lora_r: int = Field(16, title="LoRA r")
    lora_alpha: int = Field(32, title="LoRA alpha")
    lora_dropout: float = Field(0.05, title="LoRA dropout")

    # column mappings
    image_column: Optional[str] = Field("image", title="Image column")
    text_column: str = Field("text", title="Text (answer) column")
    prompt_text_column: Optional[str] = Field("prompt", title="Prompt (prefix) column")

    # push to hub
    push_to_hub: bool = Field(False, title="Push to hub")
    username: Optional[str] = Field(None, title="Hugging Face username")
    token: Optional[str] = Field(None, title="Hugging Face token")