"""
Download TinyLlama Model

This script downloads the TinyLlama model from Hugging Face and prepares it
for fine-tuning on SWIFT MT564 documentation.

Usage:
    python download_tinyllama.py --model_name TinyLlama/TinyLlama-1.1B-Chat-v1.0 --output_dir ./data/models
"""
import argparse
import logging
import os
from typing import Optional

# Module-wide logger with timestamped, levelled output.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)
def parse_args(argv=None):
    """
    Parse command-line arguments for the download script.

    Args:
        argv: Optional list of argument strings. Defaults to ``sys.argv[1:]``
            when None, so existing callers are unaffected; passing an explicit
            list makes the parser unit-testable.

    Returns:
        argparse.Namespace with model_name, output_dir, use_auth_token,
        branch, and check_integrity attributes.
    """
    parser = argparse.ArgumentParser(description="Download TinyLlama model from Hugging Face")
    parser.add_argument(
        "--model_name",
        type=str,
        default="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        help="Name of the TinyLlama model on Hugging Face Hub"
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./data/models",
        help="Directory to save the downloaded model"
    )
    parser.add_argument(
        "--use_auth_token",
        action="store_true",
        help="Use Hugging Face authentication token for downloading gated models"
    )
    parser.add_argument(
        "--branch",
        type=str,
        default="main",
        help="Branch of the model repository to download from"
    )
    parser.add_argument(
        "--check_integrity",
        action="store_true",
        help="Verify integrity of downloaded files"
    )
    return parser.parse_args(argv)
def download_model(
    model_name: str,
    output_dir: str,
    use_auth_token: bool = False,
    branch: str = "main",
    check_integrity: bool = False
) -> Optional[str]:
    """
    Download model and tokenizer from Hugging Face Hub.

    Args:
        model_name: Repo id on the Hub (e.g. "TinyLlama/TinyLlama-1.1B-Chat-v1.0").
        output_dir: Directory under which a per-model subdirectory is created.
        use_auth_token: Authenticate the download. Uses the HUGGING_FACE_TOKEN
            environment variable when set, otherwise falls back to the cached
            Hugging Face CLI login.
        branch: Revision (branch/tag/commit) of the model repository.
        check_integrity: Use snapshot_download (which verifies downloaded
            files against the Hub) instead of the Transformers auto classes.

    Returns:
        Path to the downloaded model directory, or None if the download failed.
    """
    # Import heavy dependencies lazily so the script reports a friendly
    # message instead of crashing at import time when they are missing.
    try:
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer
        from huggingface_hub import snapshot_download
    except ImportError as e:
        logger.error(f"Required libraries not installed: {e}")
        logger.error("Please install required packages: pip install torch transformers huggingface_hub")
        return None

    # Fix: previously only the literal True was forwarded, which silently
    # ignored the HUGGING_FACE_TOKEN env var that main() tells users to set.
    # Prefer the explicit token; True falls back to the cached HF login.
    token = (os.environ.get("HUGGING_FACE_TOKEN") or True) if use_auth_token else None

    try:
        logger.info(f"Downloading model: {model_name}")
        os.makedirs(output_dir, exist_ok=True)
        # Per-model subdirectory, named after the repo's final path component.
        model_output_dir = os.path.join(output_dir, model_name.split('/')[-1])
        os.makedirs(model_output_dir, exist_ok=True)

        if check_integrity:
            # Option 1: snapshot_download gives more control and verifies
            # the downloaded files against the Hub.
            logger.info("Using snapshot_download with integrity checking")
            snapshot_download(
                repo_id=model_name,
                local_dir=model_output_dir,
                use_auth_token=token,
                revision=branch
            )
        else:
            # Option 2: Transformers' own download mechanism.
            logger.info("Using Transformers' auto classes for downloading")
            tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                use_auth_token=token,
                revision=branch
            )
            tokenizer.save_pretrained(model_output_dir)
            logger.info(f"Tokenizer saved to {model_output_dir}")

            # fp16 on CUDA hosts keeps memory use down; CPU requires fp32.
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                use_auth_token=token,
                revision=branch,
                low_cpu_mem_usage=True
            )
            model.save_pretrained(model_output_dir)
            logger.info(f"Model saved to {model_output_dir}")

        logger.info(f"Successfully downloaded model to {model_output_dir}")
        return model_output_dir
    except Exception as e:
        # Network failures, missing repos, auth errors, disk errors, etc.
        logger.error(f"Error downloading model: {e}")
        return None
def main():
    """Entry point: parse CLI arguments and download the requested model."""
    args = parse_args()

    # Warn early when authentication was requested but no token is available.
    token_missing = args.use_auth_token and "HUGGING_FACE_TOKEN" not in os.environ
    if token_missing:
        logger.warning("--use_auth_token flag is set but HUGGING_FACE_TOKEN environment variable is not found.")
        logger.warning("You can set it using: export HUGGING_FACE_TOKEN=your_token_here")

    model_path = download_model(
        model_name=args.model_name,
        output_dir=args.output_dir,
        use_auth_token=args.use_auth_token,
        branch=args.branch,
        check_integrity=args.check_integrity
    )

    # download_model returns None on failure, the output path on success.
    if not model_path:
        logger.error("Failed to download the model.")
        return
    logger.info(f"Model downloaded successfully to: {model_path}")
    logger.info("You can now use this model for fine-tuning with the train_mt564_model.py script.")
# Script entry point when run directly (not on import).
if __name__ == "__main__":
    main()