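"""Training utilities for SonicVerse / multi_token multimodal finetuning.

Defines the argument dataclasses, a Trainer subclass that saves modality
projector and task-head weights alongside checkpoints, and the
train_for_modalities() entry point.
"""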
from typing import Optional, List
from dataclasses import field, dataclass
import logging
import subprocess
import pathlib
import torch
import shutil
import glob
import os
import json

import transformers
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from transformers import Trainer

from multi_token.training_data import (
    DataArguments,
    LMMDataset,
    DataCollatorForSupervisedLMMDataset,
)
from multi_token.model_utils import (
    make_model_lora,
    get_peft_state,
    get_peft_state_non_lora,
    fix_tokenizer,
    MultiTaskType,
)
from multi_token.modalities.base_modality import Modality
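
# Markdown model card written to the output directory for each training run;
# the placeholders are filled in by train_for_modalities() below.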
README_TEMPLATE = """
---
license: apache-2.0
base_model: {base_model}
dataset: {dataset}
tags:
- finetuned
- multimodal
inference: false
---

These are weights for a version of `{base_model}` finetuned for multimodal applications.

### Modalities

{modalities}

### Usage

GitHub: https://github.com/sshh12/multi_token (includes training scripts and basic inference server)

### Dataset

{dataset} ({num_examples} examples)

```
{dataset_example}
```

### Training Device(s)

```
{training_devices_dump}
```

### Model

```
{repr_model}
```
"""
@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    remove_unused_columns: bool = field(default=False)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=512,
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )
    double_quant: bool = field(
        default=True,
        metadata={
            "help": "Compress the quantization statistics through double quantization."
        },
    )
    quant_type: str = field(
        default="nf4",
        metadata={
            "help": "Quantization data type to use. Should be one of `fp4` or `nf4`."
        },
    )
    pretrain_projectors: bool = field(default=False)
    pretrained_projectors_path: Optional[str] = field(default=None)
    pretrained_projectors_config: Optional[str] = field(default=None)
    bits: int = field(default=16, metadata={"help": "How many bits to use."})
    lora_enable: bool = False
    lora_r: int = 64
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    lora_weight_path: str = ""
    lora_bias: str = "none"

@dataclass
class ModelArguments:
    model_name_or_path: str = field(default="mistralai/Mistral-7B-Instruct-v0.1")
    model_cls: str = field(default="MistralLMMForCausalLM")
    modality_builder: str = field(default="vision_clip")
    use_multi_task: int = field(default=MultiTaskType.PROJECTED_MULTI_TASK)
    tasks_config: str = field(default="src/sonicverse/configs/tasks.json")
    model_lora_path: Optional[str] = field(default="amaai-lab/SonicVerse")
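
# Trainer subclass that also saves the model config and the non-LoRA trainable
# weights (modality projectors and task heads) next to every checkpoint.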
class LMMTrainer(Trainer):
    def _save_checkpoint(self, model, trial, metrics=None):
        checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
        run_dir = self._get_output_dir(trial=trial)
        output_dir = os.path.join(run_dir, checkpoint_folder)
        self._save_extras(output_dir)
        super(LMMTrainer, self)._save_checkpoint(model, trial, metrics)

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        self._save_extras(output_dir)
        super(LMMTrainer, self)._save(output_dir, state_dict)
        for unused_dir in glob.iglob(os.path.join(output_dir, "global_step*")):
            shutil.rmtree(unused_dir)

    def _save_extras(self, output_dir: Optional[str] = None):
        self.model.config.save_pretrained(output_dir)
        task_names = []
        for m in self.model.modalities:
            task_names += m.tasks["task_heads"].keys()
        non_lora_state_dict = get_peft_state_non_lora(
            self.model.named_parameters(), task_names
        )
        torch.save(
            non_lora_state_dict,
            os.path.join(output_dir, "non_lora_trainables.bin"),
        )
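
# Best-effort GPU inventory for the generated README; requires `nvidia-smi` on PATH.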
def _get_training_devices_dump() -> str:
    out = subprocess.check_output(
        ["nvidia-smi", "--query-gpu=gpu_name,gpu_bus_id,vbios_version", "--format=csv"]
    )
    return out.decode("utf-8").strip()

def train_for_modalities(
    model_cls,
    training_args: TrainingArguments,
    model_args: ModelArguments,
    train_data_args: DataArguments,
    evaluation_data_args: DataArguments,
    modalities: List[Modality],
):
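    """Finetune `model_cls` on the given modalities.

    Moves the modality encoders and the base model to the training device,
    builds the datasets and collator, optionally applies LoRA and loads
    pretrained projector weights, freezes/unfreezes parameters according to
    the multi-task configuration, writes a README and parameter dump to the
    output directory, and runs the Hugging Face Trainer (resuming from the
    latest checkpoint if one exists).
    """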
    for m in modalities:
        m.to(
            dtype=torch.bfloat16 if training_args.bf16 else torch.float16,
            device=training_args.device,
        )

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    fix_tokenizer(tokenizer)

    train_dataset = LMMDataset(train_data_args, tokenizer, modalities)
    evaluation_dataset = LMMDataset(evaluation_data_args, tokenizer, modalities)
    collator = DataCollatorForSupervisedLMMDataset(tokenizer, modalities)

    model = model_cls.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
    )
    model.to(
        dtype=torch.bfloat16 if training_args.bf16 else torch.float16,
        device=training_args.device,
    )
    model.modalities = modalities
    model.config.use_cache = False
    model.config.model_cls = model_cls.__name__
    model.config.modality_builder = model_args.modality_builder
    if training_args.gradient_checkpointing:
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        else:

            def make_inputs_require_grad(module, input, output):
                output.requires_grad_(True)

            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

    if model_args.model_lora_path:
        raise ValueError(
            "LoRA path not supported for training -- set the output path to an existing model to resume training"
        )

    if training_args.lora_enable:
        logging.info("Adding LoRA adapters...")
        model = make_model_lora(model, training_args)
    if training_args.pretrained_projectors_path:
        projector_weights_og = torch.load(
            training_args.pretrained_projectors_path, map_location="cpu"
        )
        if model_args.use_multi_task == MultiTaskType.SIMPLE_MULTI_TASK:
            projector_weights = {}
            for k, v in projector_weights_og.items():
                for m in modalities:
                    for task_name in m.tasks["task_heads"].keys():
                        if task_name in k:
                            projector_weights[k] = v
        else:
            projector_weights = {
                k: v for k, v in projector_weights_og.items() if "_lmm_projector" in k
            }
    elif training_args.pretrained_projectors_config:
        with open(training_args.pretrained_projectors_config, "r") as f:
            pretrained_weights_config = json.load(f)
        projector_weights = {}
        for pretrained_path_info in pretrained_weights_config["pretrained_paths"]:
            pretrained_path = pretrained_path_info["path"]
            components = pretrained_path_info["components"]
            use_prefix = pretrained_path_info["use_prefix"]
            prefix = pretrained_path_info["prefix"]
            pretrained_weights = torch.load(pretrained_path, map_location="cpu")
            for k, v in pretrained_weights.items():
                if any(component in k for component in components):
                    weight_key = k
                    if use_prefix:
                        weight_key = prefix + "." + k
                    projector_weights[weight_key] = v
    else:
        projector_weights = {}

    model.get_model().initialize_modules(modalities, projector_weights)
    task_names = []
    tasks = {}
    for m in model.modalities:
        if m.use_multi_task != MultiTaskType.NO_MULTI_TASK:
            tasks = m.tasks
            task_names += m.tasks["task_heads"].keys()

    if training_args.pretrain_projectors:
        model.requires_grad_(False)
        for m in modalities:
            if m.use_multi_task == MultiTaskType.SIMPLE_MULTI_TASK:
                for task_name in m.tasks["task_heads"].keys():
                    task_model = getattr(model.get_model(), m.name + "_" + task_name)
                    for p in task_model.parameters():
                        p.requires_grad = True
            elif m.use_multi_task == MultiTaskType.PROJECTED_MULTI_TASK:
                proj = getattr(model.get_model(), m.name + "_lmm_projector")
                if "backbone" in m.tasks.keys():
                    backbone = getattr(proj, "backbone")
                    for backbone_param in backbone.parameters():
                        backbone_param.requires_grad = tasks["backbone"]["requires_grad"]
                for task in task_names:
                    task_head = getattr(proj, task)
                    for task_head_param in task_head.parameters():
                        task_head_param.requires_grad = tasks["task_heads"][task]["requires_grad"]
                    if task in tasks["task_projectors"]:
                        task_projector = getattr(proj, task + "_projector")
                        for task_projector_param in task_projector.parameters():
                            task_projector_param.requires_grad = tasks["task_projectors"][task]["requires_grad"]
            else:
                proj = getattr(model.get_model(), m.name + "_lmm_projector")
                for p in proj.parameters():
                    p.requires_grad = True
    os.makedirs(training_args.output_dir, exist_ok=True)
    with open(
        os.path.join(training_args.output_dir, "model_named_parameters.txt"), "w"
    ) as f:
        for name, param in model.named_parameters():
            f.write(f"{name} {param.shape} {param.requires_grad}\n")

    with open(os.path.join(training_args.output_dir, "README.md"), "w") as f:
        modalities_text = [
            f"* {m.__class__.__name__} (use `{m.token}` in text and provide `{m.data_key}`, encoded as {m.token_width} tokens)"
            for m in modalities
        ]
        readme_text = README_TEMPLATE.format(
            base_model=model_args.model_name_or_path,
            dataset=train_data_args.dataset_path,
            dataset_example=repr(train_dataset.get_example()),
            num_examples=len(train_dataset),
            modalities="\n".join(modalities_text),
            training_devices_dump=_get_training_devices_dump(),
            repr_model=f"{model_cls.__name__}.model =\n\n{repr(model)}",
        )
        f.write(readme_text)
    trainer = LMMTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        data_collator=collator,
        train_dataset=train_dataset,
        eval_dataset=evaluation_dataset,
    )

    if list(pathlib.Path(training_args.output_dir).glob(f"{PREFIX_CHECKPOINT_DIR}-*")):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()

    trainer.save_state()
    model.config.use_cache = True

    model.config.save_pretrained(training_args.output_dir)
    state_dict = get_peft_state(model.named_parameters(), training_args.lora_bias)
    model.save_pretrained(training_args.output_dir, state_dict=state_dict)
    non_lora_state_dict = get_peft_state_non_lora(model.named_parameters(), task_names)
    torch.save(
        non_lora_state_dict,
        os.path.join(training_args.output_dir, "non_lora_trainables.bin"),
    )
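
# Illustrative usage (a sketch, not part of the original module): a typical
# entry-point script would parse ModelArguments, TrainingArguments, and
# DataArguments for the train/eval splits with transformers.HfArgumentParser,
# resolve model_args.model_cls and model_args.modality_builder to concrete
# classes via the project's registries, and then call:
#     train_for_modalities(model_cls, training_args, model_args,
#                          train_data_args, evaluation_data_args, modalities)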