File size: 5,415 Bytes
2c72e40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
"""
Download TinyLlama Model

This script downloads the TinyLlama model from Hugging Face and prepares it
for fine-tuning on SWIFT MT564 documentation.

Usage:
    python download_tinyllama.py --model_name TinyLlama/TinyLlama-1.1B-Chat-v1.0 --output_dir ./data/models
"""

import os
import argparse
import logging
from typing import Optional

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

def parse_args(argv=None):
    """
    Parse command-line arguments for the model downloader.

    Args:
        argv: Optional list of argument strings. Defaults to None, in which
            case argparse reads sys.argv[1:] as before. Accepting an explicit
            list keeps the CLI behavior unchanged while making the parser
            unit-testable.

    Returns:
        argparse.Namespace with model_name, output_dir, use_auth_token,
        branch, and check_integrity attributes.
    """
    parser = argparse.ArgumentParser(description="Download TinyLlama model from Hugging Face")
    parser.add_argument(
        "--model_name",
        type=str,
        default="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        help="Name of the TinyLlama model on Hugging Face Hub"
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./data/models",
        help="Directory to save the downloaded model"
    )
    parser.add_argument(
        "--use_auth_token",
        action="store_true",
        help="Use Hugging Face authentication token for downloading gated models"
    )
    parser.add_argument(
        "--branch",
        type=str,
        default="main",
        help="Branch of the model repository to download from"
    )
    parser.add_argument(
        "--check_integrity",
        action="store_true",
        help="Verify integrity of downloaded files"
    )
    return parser.parse_args(argv)

def download_model(
    model_name: str,
    output_dir: str,
    use_auth_token: bool = False,
    branch: str = "main",
    check_integrity: bool = False
) -> Optional[str]:
    """
    Download model and tokenizer from Hugging Face Hub.

    Args:
        model_name: Name of the model on Hugging Face Hub (e.g. "org/repo")
        output_dir: Directory to save the model
        use_auth_token: Whether to authenticate. Reads the token from the
            HUGGING_FACE_TOKEN environment variable when set; otherwise
            falls back to the locally cached Hugging Face credentials.
        branch: Branch (revision) of the model repository
        check_integrity: Whether to use snapshot_download (which verifies
            downloaded files) instead of the Transformers auto classes

    Returns:
        Path to the downloaded model directory, or None if download failed.
    """
    try:
        # Import libraries here so the script doesn't fail if they're not installed
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer
        from huggingface_hub import snapshot_download

        logger.info(f"Downloading model: {model_name}")
        os.makedirs(output_dir, exist_ok=True)

        # Save under a subdirectory named after the repo (drops the org prefix).
        model_output_dir = os.path.join(output_dir, model_name.split('/')[-1])
        os.makedirs(model_output_dir, exist_ok=True)

        # BUG FIX: the original passed use_auth_token=True, which never used
        # the HUGGING_FACE_TOKEN environment variable that main() tells the
        # user to export. Prefer the explicit token; fall back to True so the
        # libraries use locally cached credentials (hf auth login).
        auth = None
        if use_auth_token:
            auth = os.environ.get("HUGGING_FACE_TOKEN") or True

        # Option 1: Use snapshot_download for more control
        if check_integrity:
            logger.info("Using snapshot_download with integrity checking")
            snapshot_download(
                repo_id=model_name,
                local_dir=model_output_dir,
                use_auth_token=auth,
                revision=branch
            )

        # Option 2: Use Transformers' download mechanism
        else:
            logger.info("Using Transformers' auto classes for downloading")
            # Download and save tokenizer
            tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                use_auth_token=auth,
                revision=branch
            )
            tokenizer.save_pretrained(model_output_dir)
            logger.info(f"Tokenizer saved to {model_output_dir}")

            # Download and save model; fp16 halves memory when a GPU is
            # available, fp32 is the safe CPU default.
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                use_auth_token=auth,
                revision=branch,
                low_cpu_mem_usage=True
            )
            model.save_pretrained(model_output_dir)
            logger.info(f"Model saved to {model_output_dir}")

        logger.info(f"Successfully downloaded model to {model_output_dir}")
        return model_output_dir

    except ImportError as e:
        logger.error(f"Required libraries not installed: {e}")
        logger.error("Please install required packages: pip install torch transformers huggingface_hub")
        return None
    except Exception as e:
        # Network failures, auth errors, missing repos, etc. are reported to
        # the caller as None rather than crashing the script.
        logger.error(f"Error downloading model: {e}")
        return None

def main():
    """Entry point: parse CLI arguments and download the requested model."""
    args = parse_args()

    # Warn early when authentication was requested but no token is visible.
    if args.use_auth_token and "HUGGING_FACE_TOKEN" not in os.environ:
        logger.warning("--use_auth_token flag is set but HUGGING_FACE_TOKEN environment variable is not found.")
        logger.warning("You can set it using: export HUGGING_FACE_TOKEN=your_token_here")

    # Kick off the download; returns the local path on success, None on failure.
    model_path = download_model(
        model_name=args.model_name,
        output_dir=args.output_dir,
        use_auth_token=args.use_auth_token,
        branch=args.branch,
        check_integrity=args.check_integrity
    )

    if model_path is None:
        logger.error("Failed to download the model.")
    else:
        logger.info(f"Model downloaded successfully to: {model_path}")
        logger.info("You can now use this model for fine-tuning with the train_mt564_model.py script.")

# Guard so importing this module does not trigger a download.
if __name__ == "__main__":
    main()