from typing import List, Dict
import logging

import torch
from enum import Enum

class MultiTaskType(Enum):
    NO_MULTI_TASK = 0
    SIMPLE_MULTI_TASK = 1
    PROJECTED_MULTI_TASK = 2

def _find_all_linear_names(model) -> List[str]:
    """Return the leaf names of every nn.Linear module in `model`, except lm_head."""
    cls = torch.nn.Linear
    lora_module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, cls):
            names = name.split(".")
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])
    # Exclude lm_head from the LoRA target modules.
    if "lm_head" in lora_module_names:
        lora_module_names.remove("lm_head")
    return list(lora_module_names)
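

# Hedged usage sketch for _find_all_linear_names, not part of the pipeline.
# The model name is an illustrative assumption; any Hugging Face causal LM
# built from nn.Linear layers (e.g. OPT) works. For opt-125m this should print
# the attention and MLP projections (q_proj, k_proj, v_proj, out_proj, fc1,
# fc2), with lm_head already removed.
def _example_find_all_linear_names():
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
    print(sorted(_find_all_linear_names(model)))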


def maybe_zero_3(param, ignore_status=False, name=None):
    """Return a CPU copy of `param`, gathering it first if it is ZeRO-3 partitioned."""
    from deepspeed import zero
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

    if hasattr(param, "ds_id"):
        if param.ds_status == ZeroParamStatus.NOT_AVAILABLE and not ignore_status:
            logging.warning(
                f"{name}: param has ds_status {param.ds_status}; gathering it anyway"
            )
        with zero.GatheredParameters([param]):
            param = param.data.detach().cpu().clone()
    else:
        param = param.detach().cpu().clone()
    return param
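

# Minimal sketch of maybe_zero_3 outside a ZeRO-3 run: a plain tensor has no
# `ds_id`, so only the else-branch fires and the parameter is detached and
# copied to CPU. Note that maybe_zero_3 imports deepspeed unconditionally, so
# deepspeed must be installed even for this path.
def _example_maybe_zero_3():
    p = torch.nn.Parameter(torch.randn(4, 4))
    cpu_copy = maybe_zero_3(p)
    assert cpu_copy.device.type == "cpu"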


def get_peft_state(named_params, bias) -> Dict:
    """Collect LoRA weights (and, depending on `bias`, bias terms) for saving."""
    if bias == "none":
        to_return = {k: t for k, t in named_params if "lora_" in k}
    elif bias == "all":
        to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
    elif bias == "lora_only":
        to_return = {}
        maybe_lora_bias = {}
        lora_bias_names = set()
        for k, t in named_params:
            if "lora_" in k:
                to_return[k] = t
                bias_name = k.split("lora_")[0] + "bias"
                lora_bias_names.add(bias_name)
            elif "bias" in k:
                maybe_lora_bias[k] = t
        # Iterate .items() (the original iterated the dict itself, which yields
        # keys only) and test each bias key against the collected names rather
        # than the stale `bias_name` left over from the previous loop.
        for k, t in maybe_lora_bias.items():
            if k in lora_bias_names:
                to_return[k] = t
    else:
        raise NotImplementedError()
    to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}
    return to_return
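

# Hedged sketch of a typical checkpointing call after LoRA training. `model`
# is an assumed PEFT-wrapped model and the filename is illustrative; only
# adapter weights (plus bias terms, depending on `bias`) land in the saved dict.
def _example_save_peft_state(model, path="lora_adapters.bin"):
    state = get_peft_state(model.named_parameters(), bias="none")
    torch.save(state, path)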


def get_peft_state_non_lora(named_params, task_names) -> Dict:
    """Collect non-LoRA weights that are trainable or belong to a named task head."""
    to_return = {}
    for k, t in named_params:
        if "lora_" not in k:
            task_name_in_k = any(task_name in k for task_name in task_names)
            if t.requires_grad or task_name_in_k:
                to_return[k] = t
    to_return = {
        k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()
    }
    return to_return
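

# Companion sketch to the one above: save everything that is not a LoRA weight
# but still needs to travel with the adapter checkpoint, e.g. trainable task
# heads. The task name and filename are illustrative assumptions.
def _example_save_non_lora_state(model, path="non_lora_trainables.bin"):
    state = get_peft_state_non_lora(
        model.named_parameters(), task_names=["score_head"]
    )
    torch.save(state, path)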


def make_model_lora(model, training_args: "TrainingArguments"):
    """Wrap `model` with LoRA adapters on every linear layer except lm_head."""
    from peft import LoraConfig, get_peft_model

    lora_config = LoraConfig(
        r=training_args.lora_r,
        lora_alpha=training_args.lora_alpha,
        target_modules=_find_all_linear_names(model),
        lora_dropout=training_args.lora_dropout,
        bias=training_args.lora_bias,
        task_type="CAUSAL_LM",
    )
    if training_args.bits == 16:
        if training_args.bf16:
            model.to(torch.bfloat16)
        if training_args.fp16:
            model.to(torch.float16)
    model = get_peft_model(model, lora_config)
    return model
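

# Hedged end-to-end sketch for make_model_lora. The project presumably extends
# transformers.TrainingArguments with the lora_* / bits / bf16 / fp16 fields it
# reads; the stand-in dataclass below only mimics those fields, and the model
# name and hyperparameters are illustrative.
def _example_make_model_lora():
    from dataclasses import dataclass

    from transformers import AutoModelForCausalLM

    @dataclass
    class _Args:  # stand-in for the project's TrainingArguments subclass
        lora_r: int = 8
        lora_alpha: int = 16
        lora_dropout: float = 0.05
        lora_bias: str = "none"
        bits: int = 16
        bf16: bool = False
        fp16: bool = True

    model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
    model = make_model_lora(model, _Args())
    model.print_trainable_parameters()  # PEFT helper: reports LoRA vs. total params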


def fix_tokenizer(tokenizer):
    """Fall back to unk_token for any special tokens the tokenizer is missing."""
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.unk_token
    if tokenizer.mask_token is None:
        tokenizer.mask_token = tokenizer.unk_token
    if tokenizer.cls_token is None:
        tokenizer.cls_token = tokenizer.unk_token
    if tokenizer.sep_token is None:
        tokenizer.sep_token = tokenizer.unk_token
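

# Sketch: GPT-2 ships without pad/mask/cls/sep tokens, so fix_tokenizer maps
# them all to its unk_token ("<|endoftext|>"). Caveat: if a tokenizer's
# unk_token is itself None, these assignments simply remain None.
def _example_fix_tokenizer():
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")
    fix_tokenizer(tok)
    assert tok.pad_token == tok.unk_token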