import logging
from enum import Enum
from typing import Dict, List

import torch


class MultiTaskType(Enum):
    NO_MULTI_TASK = 0
    SIMPLE_MULTI_TASK = 1
    PROJECTED_MULTI_TASK = 2
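
# Usage sketch (illustrative): the enum can round-trip through the integer values
# a config file would store.
#
#     mode = MultiTaskType(2)
#     assert mode is MultiTaskType.PROJECTED_MULTI_TASK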


def _find_all_linear_names(model) -> List[str]:
    """Collect the names of all ``torch.nn.Linear`` submodules, for use as LoRA targets."""
    cls = torch.nn.Linear
    lora_module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, cls):
            names = name.split(".")
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])
    if "lm_head" in lora_module_names:
        # Do not LoRA-adapt the output head.
        lora_module_names.remove("lm_head")
    return list(lora_module_names)
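
# Usage sketch (the checkpoint name is only illustrative; any torch.nn.Module works):
#
#     from transformers import AutoModelForCausalLM
#     model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
#     print(sorted(_find_all_linear_names(model)))
#     # e.g. ['fc1', 'fc2', 'k_proj', 'out_proj', 'q_proj', 'v_proj']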


def maybe_zero_3(param, ignore_status=False, name=None):
    """Return a detached CPU copy of ``param``, gathering it first if it is
    partitioned by DeepSpeed ZeRO-3."""
    from deepspeed import zero
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

    if hasattr(param, "ds_id"):  # parameter is managed by DeepSpeed ZeRO-3
        if param.ds_status == ZeroParamStatus.NOT_AVAILABLE and not ignore_status:
            logging.warning(
                f"{name}: param.ds_status is {param.ds_status}; gathering it before copying"
            )
        with zero.GatheredParameters([param]):
            param = param.data.detach().cpu().clone()
    else:
        param = param.detach().cpu().clone()
    return param
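
# Usage sketch (illustrative): take a full CPU snapshot of a model's parameters in a
# way that also works when DeepSpeed ZeRO-3 has partitioned them across ranks.
#
#     cpu_state = {k: maybe_zero_3(v, ignore_status=True, name=k)
#                  for k, v in model.named_parameters()}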


def get_peft_state(named_params, bias) -> Dict:
    """Extract the LoRA part of a state dict, including bias terms per the ``bias`` mode."""
    if bias == "none":
        to_return = {k: t for k, t in named_params if "lora_" in k}
    elif bias == "all":
        to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
    elif bias == "lora_only":
        to_return = {}
        maybe_lora_bias = {}
        lora_bias_names = set()
        for k, t in named_params:
            if "lora_" in k:
                to_return[k] = t
                # The bias belonging to this LoRA module shares its name prefix.
                bias_name = k.split("lora_")[0] + "bias"
                lora_bias_names.add(bias_name)
            elif "bias" in k:
                maybe_lora_bias[k] = t
        # Keep only the biases that belong to LoRA-adapted modules.
        for k, t in maybe_lora_bias.items():
            if k in lora_bias_names:
                to_return[k] = t
    else:
        raise NotImplementedError()
    to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}
    return to_return
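
# Usage sketch (illustrative path; `training_args.lora_bias` is assumed to be one of
# "none" / "all" / "lora_only"): after LoRA training, persist only the adapter weights.
#
#     state_dict = get_peft_state(model.named_parameters(), training_args.lora_bias)
#     torch.save(state_dict, "lora_adapter.bin")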


def get_peft_state_non_lora(named_params, task_names) -> Dict:
    """Extract the non-LoRA weights worth saving: anything trainable, plus any
    parameter belonging to one of the named task heads."""
    to_return = {}
    for k, t in named_params:
        if "lora_" not in k:
            task_name_in_k = any(task_name in k for task_name in task_names)
            if t.requires_grad or task_name_in_k:
                to_return[k] = t
    to_return = {
        k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()
    }
    return to_return
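
# Usage sketch (illustrative; the task names are hypothetical head names): save the
# trainable non-LoRA weights, e.g. task heads, alongside the adapter.
#
#     non_lora_state = get_peft_state_non_lora(
#         model.named_parameters(), task_names=["key_detection", "tempo"]
#     )
#     torch.save(non_lora_state, "non_lora_trainables.bin")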


def make_model_lora(model, training_args: "TrainingArguments"):
    """Wrap ``model`` with LoRA adapters targeting every linear layer."""
    from peft import LoraConfig, get_peft_model

    lora_config = LoraConfig(
        r=training_args.lora_r,
        lora_alpha=training_args.lora_alpha,
        target_modules=_find_all_linear_names(model),
        lora_dropout=training_args.lora_dropout,
        bias=training_args.lora_bias,
        task_type="CAUSAL_LM",
    )
    if training_args.bits == 16:
        # Cast the base model to the requested half precision before adding adapters.
        if training_args.bf16:
            model.to(torch.bfloat16)
        if training_args.fp16:
            model.to(torch.float16)
    model = get_peft_model(model, lora_config)
    return model
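
# Usage sketch (illustrative): the attribute names read above (`lora_r`, `lora_alpha`,
# `lora_dropout`, `lora_bias`, `bits`, `bf16`, `fp16`) are assumed to exist on the
# training-arguments object.
#
#     model = make_model_lora(model, training_args)
#     model.print_trainable_parameters()  # PEFT helper: reports adapter size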


def fix_tokenizer(tokenizer):
    """Backfill special tokens that some tokenizers ship without, falling back to
    ``unk_token`` for each missing one."""
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.unk_token
    if tokenizer.mask_token is None:
        tokenizer.mask_token = tokenizer.unk_token
    if tokenizer.cls_token is None:
        tokenizer.cls_token = tokenizer.unk_token
    if tokenizer.sep_token is None:
        tokenizer.sep_token = tokenizer.unk_token
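
# Usage sketch (illustrative checkpoint): LLaMA-style tokenizers ship without a pad
# token, which breaks batched padding; fix_tokenizer backfills it from unk_token.
#
#     from transformers import AutoTokenizer
#     tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
#     fix_tokenizer(tokenizer)
#     assert tokenizer.pad_token is not None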