COLE / src /model /model_factory.py
Yurhu's picture
Initial snapshot upload
75ec748 verified
import torch
from transformers import (
AutoTokenizer,
BitsAndBytesConfig,
AutoModelForCausalLM,
)
def model_tokenizer_factory(
    model_name: str,
    huggingface_token: str,
):
    """Load a 4-bit quantized causal language model and its tokenizer.

    The model is loaded through ``AutoModelForCausalLM.from_pretrained`` with
    a bitsandbytes NF4 configuration (nested/double quantization enabled) and
    FlashAttention-2 requested as the attention implementation. The model is
    switched to evaluation mode before being returned.

    Args:
        model_name: Model identifier forwarded as-is to ``from_pretrained``
            for both the model and the tokenizer.
        huggingface_token: Access token forwarded as ``token`` to both
            ``from_pretrained`` calls.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # One dtype for both the 4-bit matmul compute path and the modules that
    # stay unquantized. Mixing float16 weights with a bfloat16 compute dtype
    # forces a cast into and out of every quantized linear layer and can
    # overflow when bfloat16 results are cast back to float16.
    compute_dtype = torch.bfloat16

    bnb_configs = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_use_double_quant=True,  # nested quantization
    )

    # NOTE(review): trust_remote_code=True lets the model repository run its
    # own Python code at load time; only use with repositories you trust.
    # The quantization mode is fully described by `quantization_config`, so
    # the legacy `load_in_8bit` / `load_in_4bit` kwargs are not passed.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        token=huggingface_token,
        quantization_config=bnb_configs,
        trust_remote_code=True,
        attn_implementation="flash_attention_2",
        torch_dtype=compute_dtype,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name, token=huggingface_token)

    # Inference only: switch modules such as dropout to evaluation behavior.
    model.eval()
    return model, tokenizer