from enum import Enum
from typing import Dict, List
import logging

import torch


class MultiTaskType(Enum):
    NO_MULTI_TASK = 0
    SIMPLE_MULTI_TASK = 1
    PROJECTED_MULTI_TASK = 2


def _find_all_linear_names(model) -> List[str]:
    """Collect the names of all torch.nn.Linear submodules to use as LoRA target modules."""
    cls = torch.nn.Linear
    lora_module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, cls):
            names = name.split(".")
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])
    # The output head is kept out of LoRA adaptation.
    if "lm_head" in lora_module_names:
        lora_module_names.remove("lm_head")
    return list(lora_module_names)


def maybe_zero_3(param, ignore_status=False, name=None):
    """Gather a parameter that may be partitioned by DeepSpeed ZeRO-3 and return a CPU copy."""
    from deepspeed import zero
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

    if hasattr(param, "ds_id"):
        if param.ds_status == ZeroParamStatus.NOT_AVAILABLE and not ignore_status:
            logging.warning(
                f"{name}: param.ds_status is ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}"
            )
        with zero.GatheredParameters([param]):
            param = param.data.detach().cpu().clone()
    else:
        param = param.detach().cpu().clone()
    return param


def get_peft_state(named_params, bias) -> Dict:
    """Extract LoRA weights (and, depending on `bias`, bias terms) from named parameters."""
    if bias == "none":
        to_return = {k: t for k, t in named_params if "lora_" in k}
    elif bias == "all":
        to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
    elif bias == "lora_only":
        to_return = {}
        maybe_lora_bias = {}
        lora_bias_names = set()
        for k, t in named_params:
            if "lora_" in k:
                to_return[k] = t
                bias_name = k.split("lora_")[0] + "bias"
                lora_bias_names.add(bias_name)
            elif "bias" in k:
                maybe_lora_bias[k] = t
        # Keep only bias terms that belong to modules with LoRA weights.
        for k, t in maybe_lora_bias.items():
            if k in lora_bias_names:
                to_return[k] = t
    else:
        raise NotImplementedError()
    to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}
    return to_return


def get_peft_state_non_lora(named_params, task_names) -> Dict:
    """Extract non-LoRA parameters that are trainable or belong to a task-specific head."""
    to_return = {}
    for k, t in named_params:
        if "lora_" not in k:
            task_name_in_k = any(task_name in k for task_name in task_names)
            if t.requires_grad or task_name_in_k:
                to_return[k] = t
    to_return = {
        k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()
    }
    return to_return


def make_model_lora(model, training_args: "TrainingArguments"):
    """Wrap the model with LoRA adapters on all of its linear layers."""
    from peft import LoraConfig, get_peft_model

    lora_config = LoraConfig(
        r=training_args.lora_r,
        lora_alpha=training_args.lora_alpha,
        target_modules=_find_all_linear_names(model),
        lora_dropout=training_args.lora_dropout,
        bias=training_args.lora_bias,
        task_type="CAUSAL_LM",
    )
    if training_args.bits == 16:
        if training_args.bf16:
            model.to(torch.bfloat16)
        if training_args.fp16:
            model.to(torch.float16)
    model = get_peft_model(model, lora_config)
    return model


def fix_tokenizer(tokenizer):
    """Backfill missing special tokens with the tokenizer's unk token."""
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.unk_token
    if tokenizer.mask_token is None:
        tokenizer.mask_token = tokenizer.unk_token
    if tokenizer.cls_token is None:
        tokenizer.cls_token = tokenizer.unk_token
    if tokenizer.sep_token is None:
        tokenizer.sep_token = tokenizer.unk_token
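

# --- Usage sketch (illustrative only, under stated assumptions) ---
# A minimal sketch of wiring these helpers together: it assumes the
# `transformers` and `peft` packages are installed, uses a hypothetical
# checkpoint name, and stands in a SimpleNamespace for the real
# TrainingArguments object that make_model_lora expects (only the
# lora_r, lora_alpha, lora_dropout, lora_bias, bits, bf16, and fp16
# fields referenced above are provided).
if __name__ == "__main__":
    from types import SimpleNamespace
    from transformers import AutoModelForCausalLM, AutoTokenizer

    base = "meta-llama/Llama-2-7b-hf"  # hypothetical base checkpoint
    tokenizer = AutoTokenizer.from_pretrained(base)
    fix_tokenizer(tokenizer)  # backfill pad/mask/cls/sep with unk

    model = AutoModelForCausalLM.from_pretrained(base, torch_dtype=torch.bfloat16)
    args = SimpleNamespace(
        lora_r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        lora_bias="none",
        bits=16,
        bf16=True,
        fp16=False,
    )
    model = make_model_lora(model, args)
    model.print_trainable_parameters()

    # Collecting adapter weights goes through maybe_zero_3, which imports
    # deepspeed; uncomment only in an environment where deepspeed is available.
    # lora_state = get_peft_state(model.named_parameters(), bias=args.lora_bias)
    # torch.save(lora_state, "lora_adapter.bin")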