MT564AITraining / model /download_tinyllama.py
pareshmishra
Add full project source files for MT564 AI
2c72e40
"""
Download TinyLlama Model
This script downloads the TinyLlama model from Hugging Face and prepares it
for fine-tuning on SWIFT MT564 documentation.
Usage:
python download_tinyllama.py --model_name TinyLlama/TinyLlama-1.1B-Chat-v1.0 --output_dir ./data/models
"""
import os
import argparse
import logging
from typing import Optional
# Configure logging once at import time: INFO and above, each record
# prefixed with a timestamp and its level name.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-scoped logger used by every function in this script.
logger = logging.getLogger(__name__)
def parse_args(argv: Optional[list] = None) -> argparse.Namespace:
    """
    Parse the command-line options of the TinyLlama download script.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process arguments. When None (the default) argparse reads
            sys.argv[1:], so existing callers using parse_args() are unaffected.

    Returns:
        Namespace with the attributes model_name, output_dir,
        use_auth_token, branch and check_integrity.
    """
    parser = argparse.ArgumentParser(description="Download TinyLlama model from Hugging Face")

    # Where to download from and where to store the result.
    parser.add_argument(
        "--model_name",
        type=str,
        default="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        help="Name of the TinyLlama model on Hugging Face Hub"
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./data/models",
        help="Directory to save the downloaded model"
    )

    # Boolean switches default to False and are enabled by their presence.
    parser.add_argument(
        "--use_auth_token",
        action="store_true",
        help="Use Hugging Face authentication token for downloading gated models"
    )
    parser.add_argument(
        "--branch",
        type=str,
        default="main",
        help="Branch of the model repository to download from"
    )
    parser.add_argument(
        "--check_integrity",
        action="store_true",
        help="Verify integrity of downloaded files"
    )

    # argv=None makes argparse fall back to sys.argv[1:].
    return parser.parse_args(argv)
def download_model(
    model_name: str,
    output_dir: str,
    use_auth_token: bool = False,
    branch: str = "main",
    check_integrity: bool = False
) -> Optional[str]:
    """
    Download model and tokenizer from Hugging Face Hub

    The files are stored in a sub-directory of ``output_dir`` named after the
    last path component of ``model_name``.

    Args:
        model_name: Name of the model on Hugging Face Hub
        output_dir: Directory to save the model
        use_auth_token: Whether to authenticate the download. When True, the
            value of the HUGGING_FACE_TOKEN environment variable is used if
            it is set; otherwise the Hugging Face libraries are asked to use
            their own stored credentials.
        branch: Branch of the model repository
        check_integrity: When True the repository is fetched with
            huggingface_hub.snapshot_download; when False the tokenizer and
            model are loaded through the Transformers auto classes and
            re-saved with save_pretrained.

    Returns:
        Path to the downloaded model or None if download failed. Failures
        are logged, not raised.
    """
    try:
        # Import libraries here so the script doesn't fail if they're not installed
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer
        from huggingface_hub import snapshot_download

        logger.info(f"Downloading model: {model_name}")

        # Resolve the credential once. An explicit token string from the
        # environment wins; True delegates to the libraries' saved login;
        # None means anonymous access.
        token = None
        if use_auth_token:
            token = os.environ.get("HUGGING_FACE_TOKEN") or True

        # Strip trailing slashes so the last component is never empty.
        # makedirs also creates output_dir itself when it is missing.
        model_output_dir = os.path.join(output_dir, model_name.rstrip('/').split('/')[-1])
        os.makedirs(model_output_dir, exist_ok=True)

        # Option 1: Use snapshot_download for more control
        if check_integrity:
            # NOTE(review): no integrity-specific option is passed below; any
            # verification is whatever snapshot_download performs by default.
            logger.info("Using snapshot_download with integrity checking")
            snapshot_download(
                repo_id=model_name,
                local_dir=model_output_dir,
                # `token` is the current name of the former `use_auth_token` argument.
                token=token,
                revision=branch
            )
        # Option 2: Use Transformers' download mechanism
        else:
            logger.info("Using Transformers' auto classes for downloading")

            # Download and save tokenizer
            tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                token=token,
                revision=branch
            )
            tokenizer.save_pretrained(model_output_dir)
            logger.info(f"Tokenizer saved to {model_output_dir}")

            # Download and save model; half precision is only requested when
            # a CUDA device is available.
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                token=token,
                revision=branch,
                low_cpu_mem_usage=True
            )
            model.save_pretrained(model_output_dir)
            logger.info(f"Model saved to {model_output_dir}")

        logger.info(f"Successfully downloaded model to {model_output_dir}")
        return model_output_dir
    except ImportError as e:
        logger.error(f"Required libraries not installed: {e}")
        logger.error("Please install required packages: pip install torch transformers huggingface_hub")
        return None
    except Exception as e:
        # logger.exception keeps the traceback so the root cause is visible.
        logger.exception(f"Error downloading model: {e}")
        return None
def main() -> int:
    """
    Command-line entry point: parse the arguments and download the model.

    Returns:
        0 when download_model returned a model path, 1 when it reported a
        failure, so the result can be used as the process exit status.
    """
    args = parse_args()

    # Check if HUGGING_FACE_TOKEN environment variable is set
    if args.use_auth_token and "HUGGING_FACE_TOKEN" not in os.environ:
        logger.warning("--use_auth_token flag is set but HUGGING_FACE_TOKEN environment variable is not found.")
        logger.warning("You can set it using: export HUGGING_FACE_TOKEN=your_token_here")

    # Download the model
    model_path = download_model(
        model_name=args.model_name,
        output_dir=args.output_dir,
        use_auth_token=args.use_auth_token,
        branch=args.branch,
        check_integrity=args.check_integrity
    )

    # Guard clause: report the failure and signal it to the caller.
    if not model_path:
        logger.error("Failed to download the model.")
        return 1

    logger.info(f"Model downloaded successfully to: {model_path}")
    logger.info("You can now use this model for fine-tuning with the train_mt564_model.py script.")
    return 0


if __name__ == "__main__":
    # Propagate main()'s status so shells and CI can detect a failed download.
    raise SystemExit(main())