""" Download TinyLlama Model This script downloads the TinyLlama model from Hugging Face and prepares it for fine-tuning on SWIFT MT564 documentation. Usage: python download_tinyllama.py --model_name TinyLlama/TinyLlama-1.1B-Chat-v1.0 --output_dir ./data/models """ import os import argparse import logging from typing import Optional logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') logger = logging.getLogger(__name__) def parse_args(): parser = argparse.ArgumentParser(description="Download TinyLlama model from Hugging Face") parser.add_argument( "--model_name", type=str, default="TinyLlama/TinyLlama-1.1B-Chat-v1.0", help="Name of the TinyLlama model on Hugging Face Hub" ) parser.add_argument( "--output_dir", type=str, default="./data/models", help="Directory to save the downloaded model" ) parser.add_argument( "--use_auth_token", action="store_true", help="Use Hugging Face authentication token for downloading gated models" ) parser.add_argument( "--branch", type=str, default="main", help="Branch of the model repository to download from" ) parser.add_argument( "--check_integrity", action="store_true", help="Verify integrity of downloaded files" ) return parser.parse_args() def download_model( model_name: str, output_dir: str, use_auth_token: bool = False, branch: str = "main", check_integrity: bool = False ) -> Optional[str]: """ Download model and tokenizer from Hugging Face Hub Args: model_name: Name of the model on Hugging Face Hub output_dir: Directory to save the model use_auth_token: Whether to use Hugging Face token for gated models branch: Branch of the model repository check_integrity: Whether to verify integrity of downloaded files Returns: Path to the downloaded model or None if download failed """ try: # Import libraries here so the script doesn't fail if they're not installed import torch from transformers import AutoModelForCausalLM, AutoTokenizer from huggingface_hub import snapshot_download logger.info(f"Downloading model: {model_name}") os.makedirs(output_dir, exist_ok=True) # Create model directory model_output_dir = os.path.join(output_dir, model_name.split('/')[-1]) os.makedirs(model_output_dir, exist_ok=True) # Option 1: Use snapshot_download for more control if check_integrity: logger.info("Using snapshot_download with integrity checking") snapshot_download( repo_id=model_name, local_dir=model_output_dir, use_auth_token=use_auth_token if use_auth_token else None, revision=branch ) # Option 2: Use Transformers' download mechanism else: logger.info("Using Transformers' auto classes for downloading") # Download and save tokenizer tokenizer = AutoTokenizer.from_pretrained( model_name, use_auth_token=use_auth_token if use_auth_token else None, revision=branch ) tokenizer.save_pretrained(model_output_dir) logger.info(f"Tokenizer saved to {model_output_dir}") # Download and save model model = AutoModelForCausalLM.from_pretrained( model_name, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, use_auth_token=use_auth_token if use_auth_token else None, revision=branch, low_cpu_mem_usage=True ) model.save_pretrained(model_output_dir) logger.info(f"Model saved to {model_output_dir}") logger.info(f"Successfully downloaded model to {model_output_dir}") return model_output_dir except ImportError as e: logger.error(f"Required libraries not installed: {e}") logger.error("Please install required packages: pip install torch transformers huggingface_hub") return None except Exception as e: logger.error(f"Error downloading model: {e}") return None def main(): args = parse_args() # Check if HUGGING_FACE_TOKEN environment variable is set if args.use_auth_token and "HUGGING_FACE_TOKEN" not in os.environ: logger.warning("--use_auth_token flag is set but HUGGING_FACE_TOKEN environment variable is not found.") logger.warning("You can set it using: export HUGGING_FACE_TOKEN=your_token_here") # Download the model model_path = download_model( model_name=args.model_name, output_dir=args.output_dir, use_auth_token=args.use_auth_token, branch=args.branch, check_integrity=args.check_integrity ) if model_path: logger.info(f"Model downloaded successfully to: {model_path}") logger.info("You can now use this model for fine-tuning with the train_mt564_model.py script.") else: logger.error("Failed to download the model.") if __name__ == "__main__": main()