import logging
import pathlib
from dataclasses import dataclass, field
from typing import Dict, List, Optional

import deepspeed
import torch
import transformers
from transformers import AutoConfig, AutoTokenizer

from blip3o.data import make_supervised_data_module
from blip3o.model import blip3oQwenForCausalLM
from blip3o.train.blip3o_trainer import blip3oTrainer
from blip3o.utils import rank0_print
from tabulate import tabulate

torch.multiprocessing.set_sharing_strategy("file_system")

local_rank = None


@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
    diffusion_name_or_path: Optional[str] = field(default="facebook/opt-125m")
    model_class_name: Optional[str] = field(default=None, metadata={"help": "Used to init model class, format is XXXXForCausalLM. e.g. currently XXXX is chosen from blip3oLlama, blip3oMixtral, blip3oMistral, Llama"})
    mm_tunable_parts: Optional[str] = field(default="mm_language_model")
    version: Optional[str] = field(default="v0")
    vision_tower: Optional[str] = field(default=None)
    vision_tower_pretrained: Optional[str] = field(default=None)
    mm_vision_select_layer: Optional[int] = field(default=-1)  # default to the last layer
    mm_use_im_start_end: bool = field(default=False)
    mm_patch_merge_type: Optional[str] = field(default="flat")
    mm_vision_select_feature: Optional[str] = field(default="patch")
    rope_scaling_factor: Optional[float] = field(default=None)
    rope_scaling_type: Optional[str] = field(default=None)
    use_pos_skipping: Optional[bool] = field(default=False)
    pos_skipping_range: Optional[int] = field(default=4096)
    delay_load: Optional[bool] = field(default=True)
    num_image_tokens: Optional[int] = field(default=-1)
    image_token_format: str = field(default="")
    num_scale_tokens: Optional[int] = field(default=3)
    scale_token_format: str = field(default="")
    load_embeddings_from_vision: Optional[bool] = field(default=False)


@dataclass
class DataArguments:
    data_path: str = field(default=None, metadata={"help": "Path to the training data, in blip3o's instruction.json format. Supports multiple json files via /path/to/{a,b,c}.json"})
    lazy_preprocess: bool = False
    is_multimodal: bool = False
    early_mix_text: bool = False
    image_folder: Optional[str] = field(default=None)
    image_aspect_ratio: str = "square"
    dataset_cls: str = field(default="blip3o")


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    remove_unused_columns: bool = field(default=False)
    mpt_attn_impl: Optional[str] = field(default="triton")
    model_max_length: int = field(
        default=4096,
        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
    )
    mm_vision_tower_lr: Optional[float] = None
    group_by_varlen: bool = field(default=False)
    group_by_modality_length: bool = field(default=False)
    group_by_modality_length_auto: bool = field(default=False)
    auto_find_batch_size: bool = field(default=False)
    gradient_checkpointing: bool = field(default=True)
    attn_implementation: str = field(default="flash_attention_2", metadata={"help": "Use transformers attention implementation."})
    dispatch_batches: Optional[bool] = field(default=None)
    split_batches: Optional[bool] = field(default=None)


def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
    trainer.accelerator.wait_for_everyone()
    torch.cuda.synchronize()
    # With deepspeed, let the trainer handle (sharded) checkpoint saving.
    if trainer.deepspeed:
        trainer.save_model(output_dir)
        return
    state_dict = trainer.model.state_dict()
    if trainer.args.should_save:
        cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
        del state_dict
        trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa


def get_model(model_args, training_args):
    customized_kwargs = {}
    overwrite_config = {}
    cfg_pretrained = AutoConfig.from_pretrained(model_args.model_name_or_path)

    if model_args.use_pos_skipping is not None and model_args.pos_skipping_range is not None:
        overwrite_config["use_pos_skipping"] = model_args.use_pos_skipping
        overwrite_config["pos_skipping_range"] = model_args.pos_skipping_range

    if model_args.rope_scaling_factor is not None and model_args.rope_scaling_type is not None:
        overwrite_config["rope_scaling"] = {
            "factor": model_args.rope_scaling_factor,
            "type": model_args.rope_scaling_type,
        }
        if training_args.model_max_length is None:
            training_args.model_max_length = cfg_pretrained.max_position_embeddings * model_args.rope_scaling_factor
            overwrite_config["max_sequence_length"] = training_args.model_max_length
        assert training_args.model_max_length == int(cfg_pretrained.max_position_embeddings * model_args.rope_scaling_factor), (
            f"model_max_length: {training_args.model_max_length}, max_position_embeddings: {cfg_pretrained.max_position_embeddings}, rope_scaling_factor: {model_args.rope_scaling_factor}"
        )

    if overwrite_config:
        assert cfg_pretrained is not None, "cfg_pretrained is None"
        rank0_print(f"Overwriting config with {overwrite_config}")
        for k, v in overwrite_config.items():
            setattr(cfg_pretrained, k, v)
        customized_kwargs["config"] = cfg_pretrained

    model = blip3oQwenForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        attn_implementation=training_args.attn_implementation,
        torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
        low_cpu_mem_usage=False,
        **customized_kwargs,
    )
    return model


def train():
    global local_rank

    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    local_rank = training_args.local_rank

    model = get_model(model_args, training_args)
    model.config.use_cache = False

    if model_args.rope_scaling_factor is not None and model_args.rope_scaling_type is not None:
        model.config.rope_scaling = {
            "factor": model_args.rope_scaling_factor,
            "type": model_args.rope_scaling_type,
        }

    if training_args.gradient_checkpointing:
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        else:
            # Fallback: force the embedding output to require grads so checkpointing works.
            def make_inputs_require_grad(module, input, output):
                output.requires_grad_(True)

            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
    )
    if tokenizer.unk_token is not None:
        tokenizer.pad_token = tokenizer.unk_token

    if model_args.vision_tower is None:
        model.get_model().initialize_vision_modules(model_args=model_args, fsdp=training_args.fsdp)

    vision_tower = model.get_vision_tower()
    vision_tower.to(dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device)

    data_args.image_processor = vision_tower.image_processor
    data_args.is_multimodal = True

    model.config.image_aspect_ratio = data_args.image_aspect_ratio
    model.config.diffusion_name_or_path = model_args.diffusion_name_or_path
    model.config.tokenizer_padding_side = tokenizer.padding_side
    model.config.tokenizer_model_max_length = tokenizer.model_max_length
    model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = model_args.mm_use_im_start_end
    model.config.mm_vision_tower_lr = training_args.mm_vision_tower_lr
    training_args.use_im_start_end = model_args.mm_use_im_start_end

    model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer)

    ### Decide which parts of the model to train.
    rank0_print(f"Using mm_tunable_parts: {model_args.mm_tunable_parts}")
    model.config.mm_tunable_parts = training_args.mm_tunable_parts = model_args.mm_tunable_parts
    # Freeze the entire model by default.
    model.requires_grad_(False)
    vision_tower.requires_grad_(False)
    vision_tower.eval()

    # Parse mm_tunable_parts (comma-separated; recognized values are
    # "mm_vision_tower", "mm_language_model", and "mm_embedding") to decide which parts to unfreeze.
    tunable_parts = model_args.mm_tunable_parts.split(",")
    if "mm_vision_tower" in tunable_parts:
        for name, param in model.named_parameters():
            if "vision_tower" in name:
                param.requires_grad_(True)
    if "mm_language_model" in tunable_parts:
        for name, param in model.named_parameters():
            if "vision_tower" not in name:
                param.requires_grad_(True)
    if "mm_embedding" in tunable_parts:
        for name, param in model.named_parameters():
            if "embed_tokens" in name or "lm_head" in name:
                param.requires_grad_(True)

    ## Freeze SANA except the caption projection.
    for name, param in model.named_parameters():
        if "sana" in name:
            param.requires_grad_(False)
    for name, param in model.named_parameters():
        if "caption" in name:
            param.requires_grad_(True)

    # Count parameters (use ds_numel when DeepSpeed ZeRO has partitioned the weights).
    total_params = sum(p.ds_numel if hasattr(p, "ds_numel") else p.numel() for p in model.parameters())
    trainable_params = sum(p.ds_numel if hasattr(p, "ds_numel") else p.numel() for p in model.parameters() if p.requires_grad)
    rank0_print(f"Total parameters: ~{total_params / 1e6:.2f}M")
    rank0_print(f"Trainable parameters: ~{trainable_params / 1e6:.2f}M")
    for name, p in model.named_parameters():
        if p.requires_grad:
            rank0_print(f"Trainable parameter: {name}")

    data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
    trainer = blip3oTrainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)

    if trainer.is_world_process_zero():
        stat = []
        for i, (n, p) in enumerate(trainer.model.named_parameters()):
            stat.append([i, n, p.shape, p.requires_grad])
        print(tabulate(stat, headers=["idx", "name", "shape", "trainable"]))

    # Resume automatically if a checkpoint already exists in the output directory.
    if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()
    trainer.save_state()

    model.config.use_cache = True
    safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)
    rank0_print(f"Model saved to {training_args.output_dir}")


if __name__ == "__main__":
    train()
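
# Example launch, as a sketch only: every path, checkpoint name, and config file below is a
# placeholder, not something defined by this file. transformers.HfArgumentParser exposes each
# dataclass field above as a `--field_name` command-line flag, and `--deepspeed` comes from
# transformers.TrainingArguments.
#
#   deepspeed train.py \
#       --deepspeed /path/to/zero_config.json \
#       --model_name_or_path /path/to/llm_or_blip3o_checkpoint \
#       --diffusion_name_or_path /path/to/diffusion_checkpoint \
#       --mm_tunable_parts mm_language_model \
#       --data_path /path/to/instruction.json \
#       --image_folder /path/to/images \
#       --output_dir ./checkpoints/blip3o \
#       --bf16 True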