import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)


def model_tokenizer_factory(
    model_name: str,
    huggingface_token: str,
):
    """Load a causal LM and its tokenizer with 4-bit NF4 quantization."""
    # bitsandbytes config
    USE_NESTED_QUANT = True  # use_nested_quant: double quantization, saves extra memory
    BNB_4BIT_COMPUTE_DTYPE = "bfloat16"  # bnb_4bit_compute_dtype: dtype used for matmuls
    compute_dtype = getattr(torch, BNB_4BIT_COMPUTE_DTYPE)

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_use_double_quant=USE_NESTED_QUANT,
    )

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        token=huggingface_token,
        quantization_config=bnb_config,  # load_in_8bit removed: it conflicts with 4-bit config
        trust_remote_code=True,
        attn_implementation="flash_attention_2",
        torch_dtype=compute_dtype,  # keep non-quantized modules in the compute dtype (was float16, mismatching bfloat16)
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name, token=huggingface_token)

    model.eval()  # inference only: disable dropout and other training-mode behavior
    return model, tokenizer
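

# Minimal usage sketch. The model ID and the HF_TOKEN environment variable below
# are illustrative assumptions, not part of the factory itself; substitute your
# own checkpoint and token.
if __name__ == "__main__":
    import os

    model, tokenizer = model_tokenizer_factory(
        model_name="meta-llama/Llama-2-7b-hf",  # hypothetical model ID
        huggingface_token=os.environ["HF_TOKEN"],  # assumes token exported in env
    )

    # Generate a short completion to sanity-check the quantized model.
    inputs = tokenizer("The capital of France is", return_tensors="pt").to(model.device)
    with torch.no_grad():
        output_ids = model.generate(**inputs, max_new_tokens=20)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))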