# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from typing import Any, Optional

from transformers import TrainingArguments

from .sft_config import SFTConfig


@dataclass
class GKDConfig(SFTConfig):
    """
    Configuration class for [`GKDTrainer`].

    This class includes only the parameters that are specific to GKD training. For a full list of training arguments,
    please refer to the [`~transformers.TrainingArguments`] and [`SFTConfig`] documentation.

    Args:
        temperature (`float`, *optional*, defaults to `0.9`):
            Temperature for sampling. The higher the temperature, the more random the completions.
        lmbda (`float`, *optional*, defaults to `0.5`):
            Lambda parameter that controls the student data fraction (i.e., the proportion of on-policy
            student-generated outputs).
        beta (`float`, *optional*, defaults to `0.5`):
            Interpolation coefficient between `0.0` and `1.0` of the Generalized Jensen-Shannon Divergence loss. When
            beta is `0.0`, the loss is the KL divergence. When beta is `1.0`, the loss is the Inverse KL Divergence.
        max_new_tokens (`int`, *optional*, defaults to `128`):
            Maximum number of tokens to generate per completion.
        teacher_model_name_or_path (`str` or `None`, *optional*, defaults to `None`):
            Model name or path of the teacher model. If `None`, the teacher model will be the same as the model being
            trained.
        teacher_model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
            Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the teacher model
            from a string.
        disable_dropout (`bool`, *optional*, defaults to `True`):
            Whether to disable dropout in the model.
        seq_kd (`bool`, *optional*, defaults to `False`):
            Whether to perform Sequence-Level KD (can be viewed as supervised fine-tuning on teacher-generated
            outputs).
    """

    _VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["teacher_model_init_kwargs"]

    temperature: float = field(
        default=0.9,
        metadata={"help": "Temperature for sampling. The higher the temperature, the more random the completions."},
    )
    lmbda: float = field(
        default=0.5,
        metadata={
            "help": "Lambda parameter that controls the student data fraction (i.e., the proportion of on-policy "
            "student-generated outputs)."
        },
    )
    beta: float = field(
        default=0.5,
        metadata={
            "help": "Interpolation coefficient between `0.0` and `1.0` of the Generalized Jensen-Shannon Divergence "
            "loss. When beta is `0.0`, the loss is the KL divergence. When beta is `1.0`, the loss is the Inverse KL "
            "Divergence."
        },
    )
    max_new_tokens: int = field(
        default=128,
        metadata={"help": "Maximum number of tokens to generate per completion."},
    )
    teacher_model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "Model name or path of the teacher model. If `None`, the teacher model will be the same as the "
            "model being trained."
        },
    )
    teacher_model_init_kwargs: Optional[dict[str, Any]] = field(
        default=None,
        metadata={
            "help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the "
            "teacher model from a string."
        },
    )
    disable_dropout: bool = field(
        default=True,
        metadata={"help": "Whether to disable dropout in the model."},
    )
    seq_kd: bool = field(
        default=False,
        metadata={
            "help": "Whether to perform Sequence-Level KD (can be viewed as supervised fine-tuning on "
            "teacher-generated outputs)."
        },
    )

    def __post_init__(self):
        super().__post_init__()

        # Check that lmbda and beta are in the range [0, 1].
        if self.lmbda < 0.0 or self.lmbda > 1.0:
            raise ValueError("lmbda must be in the range [0.0, 1.0].")
        if self.beta < 0.0 or self.beta > 1.0:
            raise ValueError("beta must be in the range [0.0, 1.0].")
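

# ---------------------------------------------------------------------------
# Illustrative sketch only, not part of the library surface: a toy version of
# the `beta` interpolation documented above, plus a typical instantiation of
# `GKDConfig`. The endpoint behaviour (beta=0 -> KL, beta=1 -> inverse KL)
# follows the docstring; the mixture formula used in between is an assumption,
# and `GKDTrainer` implements the actual loss. The teacher checkpoint name is
# a placeholder. Run as a module (`python -m <package>.gkd_config`) so that the
# relative import above resolves.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import math

    def _kl(p, q):
        # KL(p || q) for two discrete distributions given as probability lists.
        return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

    def generalized_jsd(teacher, student, beta):
        # beta = 0 -> KL(teacher || student); beta = 1 -> KL(student || teacher);
        # otherwise both sides are measured against the mixture
        # m = beta * teacher + (1 - beta) * student and combined with the same
        # weights (assumed formulation; see `GKDTrainer` for the actual loss).
        if beta == 0.0:
            return _kl(teacher, student)
        if beta == 1.0:
            return _kl(student, teacher)
        mixture = [beta * t + (1 - beta) * s for t, s in zip(teacher, student)]
        return beta * _kl(teacher, mixture) + (1 - beta) * _kl(student, mixture)

    teacher_probs = [0.7, 0.2, 0.1]
    student_probs = [0.5, 0.3, 0.2]
    for b in (0.0, 0.5, 1.0):
        print(f"beta={b}: {generalized_jsd(teacher_probs, student_probs, b):.4f}")

    # A typical configuration; `output_dir` is inherited from
    # `transformers.TrainingArguments`.
    config = GKDConfig(
        output_dir="gkd-model",
        temperature=0.9,  # higher -> more random sampled completions
        lmbda=0.5,  # fraction of on-policy (student-generated) training data
        beta=0.5,  # interpolates between forward and inverse KL as described above
        max_new_tokens=128,  # per-completion generation budget
        teacher_model_name_or_path="Qwen/Qwen2-0.5B-Instruct",  # placeholder teacher checkpoint
        seq_kd=False,  # set True for sequence-level KD on teacher generations
    )
    print(config.teacher_model_name_or_path)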