# HoNLP_Project / download_model.py
# (header reconstructed from Hugging Face file-viewer residue: uploader
# darpanaswal, "Upload 7 files", revision 4fe7b26 verified)
import os

import torch
import huggingface_hub
from transformers import (AutoTokenizer,
AutoModelForCausalLM,
AutoModelForSeq2SeqLM,
BitsAndBytesConfig,
MBart50TokenizerFast,
MBartForConditionalGeneration)
# Prefer GPU when available; the loaders below move full-precision weights here.
device = "cuda" if torch.cuda.is_available() else "cpu"
def download_model(model_name: str):
    """Download and return a ``(model, tokenizer)`` pair for the requested model.

    Args:
        model_name: One of ``"mT5"``, ``"mBART50"``, or
            ``"Llama-3.2-1B-Instruct"``.

    Returns:
        Tuple of ``(model, tokenizer)`` ready for inference.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    if model_name == "mT5":
        # 4-bit quantization with double quantization keeps mT5-XL within
        # a single consumer GPU's memory; compute runs in fp16.
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
        )
        # device_map="auto" already places the quantized weights; calling
        # .to(device) on a bitsandbytes-quantized model is unsupported and
        # raises in transformers, so no explicit move here.
        model = AutoModelForSeq2SeqLM.from_pretrained(
            "google/mt5-xl",
            quantization_config=bnb_config,
            device_map="auto",
        )
        tokenizer = AutoTokenizer.from_pretrained("google/mt5-xl")
        return model, tokenizer
    elif model_name == "mBART50":
        model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50").to(device)
        tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50", src_lang="en_XX", tgt_lang="en_XX")
        return model, tokenizer
    elif model_name == "Llama-3.2-1B-Instruct":
        # SECURITY: the original embedded an obfuscated Hugging Face access
        # token directly in source. Credentials must never live in code --
        # the leaked token should be revoked on the account. Read it from
        # the environment instead; skip login if it is not set (e.g. the
        # user already ran `huggingface-cli login`).
        token = os.environ.get("HF_TOKEN")
        if token:
            huggingface_hub.login(token=token)
        model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B-Instruct").to(device)
        tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
        return model, tokenizer
    else:
        # Keep the message in sync with the names actually accepted above.
        raise ValueError(
            "Invalid model name. Choose from 'mT5', 'mBART50', "
            "'Llama-3.2-1B-Instruct'."
        )