Spaces:
Running
Running
File size: 5,415 Bytes
2c72e40 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
"""
Download TinyLlama Model
This script downloads the TinyLlama model from Hugging Face and prepares it
for fine-tuning on SWIFT MT564 documentation.
Usage:
python download_tinyllama.py --model_name TinyLlama/TinyLlama-1.1B-Chat-v1.0 --output_dir ./data/models
"""
import argparse
import logging
import os
from typing import Optional, Sequence
# Configure the root logger once at import time: INFO and above, with a
# timestamped "time - level - message" line format.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module-level logger used by every function in this script.
logger = logging.getLogger(name=__name__)
def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse the command-line options for the download script.

    Args:
        argv: Arguments to parse, excluding the program name. The default
            ``None`` makes argparse read ``sys.argv[1:]``, which is the behaviour
            when run as a script. Passing an explicit list lets other code and
            tests drive the parser.

    Returns:
        Namespace with ``model_name``, ``output_dir``, ``use_auth_token``,
        ``branch`` and ``check_integrity`` attributes.
    """
    parser = argparse.ArgumentParser(description="Download TinyLlama model from Hugging Face")
    parser.add_argument(
        "--model_name",
        type=str,
        default="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        help="Name of the TinyLlama model on Hugging Face Hub"
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./data/models",
        help="Directory to save the downloaded model"
    )
    parser.add_argument(
        "--use_auth_token",
        action="store_true",
        help="Use Hugging Face authentication token for downloading gated models"
    )
    parser.add_argument(
        "--branch",
        type=str,
        default="main",
        help="Branch of the model repository to download from"
    )
    parser.add_argument(
        "--check_integrity",
        action="store_true",
        help="Verify integrity of downloaded files"
    )
    return parser.parse_args(argv)
def _resolve_hf_token(use_auth_token: bool):
    """Return the value to pass as ``token=`` to the Hugging Face download APIs.

    ``None`` means anonymous access. When authentication is requested, the
    ``HUGGING_FACE_TOKEN`` environment variable (the one ``main()`` tells users
    to export) is used if it is set and non-empty. Otherwise ``True`` asks
    huggingface_hub to use the locally stored token, e.g. from
    ``huggingface-cli login``.
    """
    if not use_auth_token:
        return None
    return os.environ.get("HUGGING_FACE_TOKEN") or True


def download_model(
    model_name: str,
    output_dir: str,
    use_auth_token: bool = False,
    branch: str = "main",
    check_integrity: bool = False
) -> Optional[str]:
    """
    Download model and tokenizer from Hugging Face Hub

    Args:
        model_name: Repository id on the Hugging Face Hub, e.g. "org/name".
        output_dir: Parent directory. Files are written to
            ``output_dir/<name>``, where ``<name>`` is the last path component
            of ``model_name``.
        use_auth_token: Authenticate the download (needed for gated or private
            repositories). Uses ``HUGGING_FACE_TOKEN`` if set, otherwise the
            locally stored Hugging Face token.
        branch: Branch, tag or commit of the repository (passed as ``revision``).
        check_integrity: If True, mirror the repository files with
            ``snapshot_download``. Otherwise, load the tokenizer and model with
            the Transformers auto classes and re-save them with ``save_pretrained``.

    Returns:
        Path to the downloaded model or None if download failed
    """
    try:
        # Imported lazily so a missing dependency produces an install hint
        # instead of a crash at script start-up. Kept separate from the
        # download itself so ImportErrors raised deep inside transformers are
        # not misreported as "package not installed".
        from huggingface_hub import snapshot_download
        from transformers import AutoModelForCausalLM, AutoTokenizer
    except ImportError as e:
        logger.error(f"Required libraries not installed: {e}")
        logger.error("Please install required packages: pip install torch transformers huggingface_hub")
        return None

    token = _resolve_hf_token(use_auth_token)
    # rstrip so "org/name/" still yields "name" rather than an empty folder name.
    model_output_dir = os.path.join(output_dir, model_name.rstrip("/").split("/")[-1])

    try:
        logger.info(f"Downloading model: {model_name}")
        # makedirs creates output_dir as well as the per-model sub-directory.
        os.makedirs(model_output_dir, exist_ok=True)

        if check_integrity:
            # Option 1: mirror the repository files as published on the Hub.
            logger.info("Using snapshot_download with integrity checking")
            snapshot_download(
                repo_id=model_name,
                local_dir=model_output_dir,
                token=token,
                revision=branch,
            )
        else:
            # Option 2: load through Transformers and re-save with save_pretrained.
            logger.info("Using Transformers' auto classes for downloading")
            tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                token=token,
                revision=branch,
            )
            tokenizer.save_pretrained(model_output_dir)
            logger.info(f"Tokenizer saved to {model_output_dir}")

            # torch_dtype="auto" keeps the checkpoint's stored dtype. The saved
            # copy therefore no longer depends on whether this machine has CUDA,
            # and the weights are not lossily re-cast (e.g. bfloat16 -> float16).
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype="auto",
                token=token,
                revision=branch,
                low_cpu_mem_usage=True,
            )
            model.save_pretrained(model_output_dir)
            logger.info(f"Model saved to {model_output_dir}")
    except Exception as e:
        # Boundary handler: callers rely on a None return on failure, but keep
        # the traceback in the log so the cause can be diagnosed.
        logger.exception(f"Error downloading model: {e}")
        return None

    logger.info(f"Successfully downloaded model to {model_output_dir}")
    return model_output_dir
def main():
    """Command-line entry point: parse the arguments and download the model.

    Raises:
        SystemExit: with status 1 when the download fails, so shell scripts and
            CI pipelines can detect the failure. Success returns normally,
            which gives exit status 0.
    """
    args = parse_args()

    # Warn early when authentication is requested but the token variable the
    # user is told to export is missing.
    if args.use_auth_token and "HUGGING_FACE_TOKEN" not in os.environ:
        logger.warning("--use_auth_token flag is set but HUGGING_FACE_TOKEN environment variable is not found.")
        logger.warning("You can set it using: export HUGGING_FACE_TOKEN=your_token_here")

    # download_model() logs its own errors and returns None on failure.
    model_path = download_model(
        model_name=args.model_name,
        output_dir=args.output_dir,
        use_auth_token=args.use_auth_token,
        branch=args.branch,
        check_integrity=args.check_integrity
    )

    if not model_path:
        logger.error("Failed to download the model.")
        raise SystemExit(1)

    logger.info(f"Model downloaded successfully to: {model_path}")
    logger.info("You can now use this model for fine-tuning with the train_mt564_model.py script.")
# Run the downloader only when executed as a script, not when imported.
# (The stray " |" page-scrape residue after main() made this a SyntaxError.)
if __name__ == "__main__":
    main()