File size: 3,774 Bytes
6416f7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
#!/usr/bin/env python3
"""
Startup test script for Hugging Face Spaces deployment
This script helps debug model loading issues
"""

import os
import sys
import time
import logging
import traceback

# Root logging: INFO and above, each record prefixed with a timestamp and level.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(name=__name__)

def test_imports(packages=("torch", "transformers", "peft", "fastapi")):
    """Check that each required package imports, logging its version.

    Every package is attempted even after a failure, so a single run reports
    all missing dependencies instead of only the first one.

    Args:
        packages: Module names to import. Defaults to torch, transformers,
            peft and fastapi.

    Returns:
        True if every package imported successfully, False otherwise.
    """
    import importlib

    # Human-readable labels for the version log lines; modules not listed
    # here are logged under their import name.
    display_names = {
        "torch": "PyTorch",
        "transformers": "Transformers",
        "peft": "PEFT",
        "fastapi": "FastAPI",
    }

    logger.info("Testing imports...")

    failed = []
    for name in packages:
        try:
            module = importlib.import_module(name)
        except ImportError as e:
            logger.error(f"Failed to import {name}: {e}")
            failed.append(name)
            continue
        # Not every module defines __version__; a missing attribute should
        # not crash the whole check.
        version = getattr(module, "__version__", "unknown")
        logger.info(f"{display_names.get(name, name)} version: {version}")

    return not failed

def test_model_files(model_dir="./final-model", required_files=None):
    """Check that the model directory contains every required file.

    Logs the size of each file that is present and the full list of any that
    are missing.

    Args:
        model_dir: Directory expected to hold the model files. Defaults to
            ``./final-model``, resolved against the current working directory.
        required_files: File names that must exist directly inside
            ``model_dir``. Defaults to the adapter and tokenizer files listed
            in the body (adapter_config.json, adapter_model.safetensors,
            tokenizer.json, tokenizer_config.json, vocab.json).

    Returns:
        True if ``model_dir`` is a directory containing every required file as
        a regular file, False otherwise.
    """
    logger.info("Testing model files...")

    if required_files is None:
        required_files = (
            "adapter_config.json",
            "adapter_model.safetensors",
            "tokenizer.json",
            "tokenizer_config.json",
            "vocab.json",
        )

    # isdir rather than exists: a stray regular file at this path must fail
    # here instead of surfacing as a misleading list of "missing" files.
    if not os.path.isdir(model_dir):
        logger.error(
            f"Model directory {model_dir} does not exist or is not a directory"
        )
        return False

    missing_files = []
    for file_name in required_files:
        file_path = os.path.join(model_dir, file_name)
        # isfile rather than exists: a directory that happens to carry a
        # required file's name does not count as present.
        if os.path.isfile(file_path):
            size = os.path.getsize(file_path)
            logger.info(f"✓ {file_name} exists ({size} bytes)")
        else:
            missing_files.append(file_name)

    if missing_files:
        logger.error(f"Missing required files: {missing_files}")
        return False

    return True

def test_model_loading():
    """Load the model via ``model_utils.get_model`` and run one sample prediction.

    Logs how long loading and the sample prediction each take, plus the
    prediction result. No timeout is enforced: a load or prediction that hangs
    blocks this call.

    Returns:
        True if the model loaded and the sample prediction returned; False if
        any exception was raised, including ``model_utils`` failing to import.
    """
    logger.info("Testing model loading...")

    try:
        from model_utils import get_model

        # perf_counter is monotonic, so durations are immune to clock changes.
        start_time = time.perf_counter()
        model = get_model()
        load_time = time.perf_counter() - start_time
        logger.info(f"Model loaded successfully in {load_time:.2f} seconds")

        # Smoke-test a single prediction on a tiny, fixed input.
        test_question = "How many records are there?"
        test_headers = ["id", "name", "age"]

        start_time = time.perf_counter()
        result = model.predict(test_question, test_headers)
        predict_time = time.perf_counter() - start_time

        logger.info(f"Test prediction successful in {predict_time:.2f} seconds")
        logger.info(f"Generated SQL: {result}")
    except Exception as e:
        # Deliberately broad: this is a top-level diagnostic, so any failure is
        # reported (with its traceback, via logger.exception) rather than raised.
        logger.exception(f"Model loading failed: {e}")
        return False

    return True

def main():
    """Run the deployment checks in order; exit with status 1 at the first failure."""
    logger.info("Starting Hugging Face Spaces deployment tests...")

    # Each entry pairs a check with the message logged if it returns falsy.
    # Checks run in this order and stop at the first failure.
    checks = (
        (test_imports, "Import test failed"),
        (test_model_files, "Model files test failed"),
        (test_model_loading, "Model loading test failed"),
    )
    for check, failure_message in checks:
        if check():
            continue
        logger.error(failure_message)
        sys.exit(1)

    logger.info("All tests passed! Ready for deployment.")


if __name__ == "__main__":
    main()