from typing import Type, List, Optional
import logging
import os

import torch
from transformers import AutoTokenizer, AutoConfig, BitsAndBytesConfig
from huggingface_hub import hf_hub_download
from peft import PeftModel

from multi_token.model_utils import fix_tokenizer, MultiTaskType
from multi_token.modalities.base_modality import Modality
from multi_token.language_models import LANGUAGE_MODEL_NAME_TO_CLASS
from multi_token.modalities import MODALITY_BUILDERS

def load_trained_lora_model(
    model_name_or_path: str,
    model_lora_path: str,
    model_cls: Optional[Type] = None,
    modalities: Optional[List[Modality]] = None,
    load_bits: int = 16,
    device_map: str = "auto",
    use_multi_task: MultiTaskType = MultiTaskType.NO_MULTI_TASK,
    tasks_config: Optional[str] = None,
):
    """Load a base model with its multimodal projectors and trained LoRA weights.

    Returns a (model, tokenizer) tuple ready for inference.
    """
    load_kwargs = {"device_map": device_map}

    # Select the precision / quantization scheme for the base model.
    if load_bits == 8:
        load_kwargs["load_in_8bit"] = True
    elif load_bits == 4:
        load_kwargs["load_in_4bit"] = True
        load_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
    elif load_bits == 16:
        load_kwargs["torch_dtype"] = torch.float16
    else:
        raise ValueError(f"Invalid load_bits: {load_bits}")
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)
    fix_tokenizer(tokenizer)

    # The saved config records which model class and modality builder were used
    # during training, so both can be resolved automatically if not provided.
    cfg = AutoConfig.from_pretrained(model_lora_path)
    if model_cls is None:
        model_cls = LANGUAGE_MODEL_NAME_TO_CLASS[cfg.model_cls]
    if modalities is None:
        if use_multi_task:
            modalities = MODALITY_BUILDERS[cfg.modality_builder](
                use_multi_task=use_multi_task, tasks_config=tasks_config
            )
        else:
            modalities = MODALITY_BUILDERS[cfg.modality_builder]()
logging.info(f"Loading base model from {model_name_or_path} as {load_bits} bits") | |
model = model_cls.from_pretrained( | |
model_name_or_path, low_cpu_mem_usage=True, config=cfg, **load_kwargs | |
) | |
model.modalities = modalities | |
logging.info(f"Loading projector weights for {[m.name for m in modalities]}") | |
    # Projector weights are stored separately from the LoRA adapter; load them
    # from the local checkpoint directory if present, otherwise fetch them from
    # the Hugging Face Hub.
    if os.path.exists(os.path.join(model_lora_path, "non_lora_trainables.bin")):
        non_lora_trainables = torch.load(
            os.path.join(model_lora_path, "non_lora_trainables.bin"), map_location="cuda"
        )
    else:
        local_fn = hf_hub_download(
            repo_id=model_lora_path,
            filename="non_lora_trainables.bin",
            repo_type="model",
        )
        non_lora_trainables = torch.load(local_fn, map_location="cuda")
    model.get_model().initialize_pretrained_modules(modalities, non_lora_trainables)
logging.info(f"Loading and merging LoRA weights from {model_lora_path}") | |
model = PeftModel.from_pretrained(model, model_lora_path) | |
if load_bits == 16: | |
# TODO: Figure out why this fails for other bit sizes | |
model = model.merge_and_unload() | |
model.eval() | |
return model, tokenizer | |
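

# Example usage (a minimal sketch; the LoRA repo id below is a hypothetical
# placeholder, not part of this module):
#
#   model, tokenizer = load_trained_lora_model(
#       model_name_or_path="mistralai/Mistral-7B-v0.1",
#       model_lora_path="your-username/your-multi-token-lora",  # hypothetical
#       load_bits=16,
#   )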